In [33]:
import json
import random

from collections import Counter

import pandas as pd
from sklearn.model_selection import train_test_split

from utils.data2seq import Dial2seq, SequencePreprocessor

## Merging Topical-Chat (TC) and DailyDialog (DD) and splitting the entity-annotated samples into train, val and test sets

In [7]:
# One dialogue-to-sequence converter per annotated corpus.
# NOTE(review): the meaning of the second argument (3) is defined by Dial2seq,
# presumably the number of dialogue turns per sequence; confirm in utils.data2seq.
topical_sequencer, daily_sequencer = (
    Dial2seq('data/topical_chat_annotated.json', 3),
    Dial2seq('data/daily_dialogue_annotated.json', 3),
)

In [9]:
# Turn each corpus into sequences (DailyDialog first, then Topical-Chat).
daily, topical = daily_sequencer.transform(), topical_sequencer.transform()

In [10]:
daily_preproc = SequencePreprocessor()  # dedicated preprocessor for DailyDialog sequences

In [11]:
# Preprocess DailyDialog sequences; the cell shows how many samples remain.
daily_dataset = daily_preproc.transform(daily)
len(daily_dataset)

29656

In [14]:
# Topical-Chat gets its own preprocessor instance; the cell shows the sample count.
topical_preproc = SequencePreprocessor()
topical_dataset = topical_preproc.transform(topical)
len(topical_dataset)

84140

In [15]:
# Route every preprocessed sample by whether its target ('predict') carries an
# entity label: samples without one go to `midas_dataset`, samples with one go
# to `midas_entity_dataset`. DailyDialog samples are visited before Topical-Chat
# samples, so both lists keep the corpus order daily -> topical.
midas_dataset = []
midas_entity_dataset = []

for source_dataset in (daily_dataset, topical_dataset):
    for sample in source_dataset:
        if sample['predict']['entity']['label']:
            midas_entity_dataset.append(sample)
        else:
            midas_dataset.append(sample)

In [16]:
# Every preprocessed sample must land in exactly one of the two buckets.
assert len(midas_dataset) + len(midas_entity_dataset) == len(daily_dataset) + len(topical_dataset)
len(midas_dataset), len(midas_entity_dataset)

(101327, 12469)

In [35]:
# Shuffle a copy so re-running this cell always yields the same split. Shuffling
# `midas_entity_dataset` in place would reshuffle it on every run and silently
# change train/val/test. The copy starts in the same order and uses the same
# seed, so a fresh kernel gets the identical permutation.
midas_entity_shuffled = list(midas_entity_dataset)
random.Random(42).shuffle(midas_entity_shuffled)

# 80/10/10 split: hold out 20%, then halve the hold-out into val and test.
# (train_test_split shuffles too; the extra shuffle above keeps the split
# identical to the files already written to data/annotated.)
train, val_test = train_test_split(midas_entity_shuffled, test_size=0.2, random_state=42)
val, test = train_test_split(val_test, test_size=0.5, random_state=42)

In [62]:
%%bash
# -p: succeed if the directory already exists, so the notebook can be re-run.
mkdir -p data/annotated

In [63]:
# Write each split as pretty-printed UTF-8 JSON; ensure_ascii=False keeps
# non-ASCII characters as-is instead of \uXXXX escapes.
for output_path, split_samples in (
    ('data/annotated/train.json', train),
    ('data/annotated/val.json', val),
    ('data/annotated/test.json', test),
):
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(split_samples, f, ensure_ascii=False, indent=2)

## Stats for train/val/test

In [37]:
# Keep only each sample's prediction target ('predict'), which holds the
# 'midas' label and the 'entity' annotation used in the stats below.
target_train, target_val, target_test = (
    [sample['predict'] for sample in split] for split in (train, val, test)
)

In [38]:
# Flatten the nested targets into columns (e.g. 'midas', 'entity.label').
# Built from the splits directly instead of from the target_* lists, so the cell
# can be re-run. Overwriting target_* with DataFrames would otherwise make a
# second run pass DataFrames, not lists of dicts, to json_normalize.
target_train = pd.json_normalize([sample['predict'] for sample in train])
target_val = pd.json_normalize([sample['predict'] for sample in val])
target_test = pd.json_normalize([sample['predict'] for sample in test])

### Midas stats

In [39]:
# MIDAS label distribution in the train split.
target_train.midas.value_counts()

opinion                  4048
statement                4013
yes_no_question           875
pos_answer                254
comment                   176
command                   168
open_question_factual     162
open_question_opinion     152
neg_answer                 82
complaint                  24
dev_command                10
appreciation                8
other_answers               3
Name: midas, dtype: int64

In [40]:
# MIDAS label distribution in the val split.
target_val.midas.value_counts()

statement                503
opinion                  502
yes_no_question          106
open_question_opinion     27
pos_answer                26
command                   26
comment                   23
open_question_factual     18
neg_answer                 8
complaint                  6
dev_command                2
Name: midas, dtype: int64

In [41]:
# MIDAS label distribution in the test split.
target_test.midas.value_counts()

statement                519
opinion                  501
yes_no_question           88
pos_answer                36
open_question_factual     30
comment                   23
open_question_opinion     21
command                   18
neg_answer                 5
complaint                  5
dev_command                1
Name: midas, dtype: int64

**Validate that every MIDAS label in val and test also appears in train**

In [58]:
# MIDAS labels present in val but never seen in train (NaN ignored, as value_counts does).
unseen_val_midas = set(target_val['midas'].dropna()) - set(target_train['midas'].dropna())
assert not unseen_val_midas, f'MIDAS labels in val but not in train: {unseen_val_midas}'
unseen_val_midas

set()

In [59]:
# MIDAS labels present in test but never seen in train (NaN ignored, as value_counts does).
unseen_test_midas = set(target_test['midas'].dropna()) - set(target_train['midas'].dropna())
assert not unseen_test_midas, f'MIDAS labels in test but not in train: {unseen_test_midas}'
unseen_test_midas

set()

### Entity stats

In [44]:
# Entity-label distribution in the train split.
target_train.loc[:, 'entity.label'].value_counts()

person                 2254
videoname              1517
location               1137
organization            825
device                  645
duration                634
genre                   541
sport                   508
number                  464
sportteam               370
softwareapplication     315
vehicle                 165
event                   143
position                130
date                    130
year                     88
gamename                 55
party                    40
bookname                 10
songname                  4
Name: entity.label, dtype: int64

In [43]:
# Entity-label distribution in the val split.
target_val.loc[:, 'entity.label'].value_counts()

person                 281
videoname              206
location               140
organization            98
device                  78
duration                77
genre                   73
sport                   62
number                  56
sportteam               54
softwareapplication     34
event                   24
vehicle                 17
position                15
date                    15
year                    10
gamename                 5
party                    2
Name: entity.label, dtype: int64

In [45]:
# Entity-label distribution in the test split.
target_test.loc[:, 'entity.label'].value_counts()

person                 295
videoname              176
location               145
organization            86
device                  84
genre                   77
duration                77
number                  61
sport                   61
sportteam               43
softwareapplication     42
vehicle                 25
event                   23
position                16
date                    15
year                     8
gamename                 7
party                    5
bookname                 1
Name: entity.label, dtype: int64

In [48]:
# NOTE(review): identical to the train entity-label cell earlier in this section;
# it adds no new information and is a candidate for deletion.
target_train.loc[:, 'entity.label'].value_counts()

person                 2254
videoname              1517
location               1137
organization            825
device                  645
duration                634
genre                   541
sport                   508
number                  464
sportteam               370
softwareapplication     315
vehicle                 165
event                   143
position                130
date                    130
year                     88
gamename                 55
party                    40
bookname                 10
songname                  4
Name: entity.label, dtype: int64

**Validate that every entity label in val and test also appears in train**

In [56]:
# Entity labels present in val but never seen in train (NaN ignored, as value_counts does).
unseen_val_entities = set(target_val['entity.label'].dropna()) - set(target_train['entity.label'].dropna())
assert not unseen_val_entities, f'Entity labels in val but not in train: {unseen_val_entities}'
unseen_val_entities

set()

In [57]:
# Entity labels present in test but never seen in train (NaN ignored, as value_counts does).
unseen_test_entities = set(target_test['entity.label'].dropna()) - set(target_train['entity.label'].dropna())
assert not unseen_test_entities, f'Entity labels in test but not in train: {unseen_test_entities}'
unseen_test_entities

set()