In [11]:
import os
import numpy as np
import csv
from transformers import T5Tokenizer
from collections import Counter


# Dataset layout: the split names, and for each split the relative paths of
# its processed TSV files (one holding sentences, one holding the
# context / current-sentence / next-sentence tuples).
dataset = dict(
    parts=['train', 'valid', 'test'],
    sentences=dict(
        train='processed/train_sentences.tsv',
        valid='processed/valid_sentences.tsv',
        test='processed/test_sentences.tsv',
    ),
    tuples=dict(
        train='processed/train_tuples.tsv',
        valid='processed/valid_tuples.tsv',
        test='processed/test_tuples.tsv',
    ),
)

# Work from the project root so the relative data paths used by this notebook
# resolve. The root defaults to the original author's location; set the
# COND_TEXT_GEN_ROOT environment variable to run the notebook elsewhere.
os.chdir(os.environ.get('COND_TEXT_GEN_ROOT', '/collection/ka2khan/thesis/Cond_Text_Gen'))
print(os.getcwd())

data_path = 'data/DailyDialog'

# Per-split parallel lists: entry i of contexts[part], curr_sents[part] and
# next_sents[part] all come from row i of that split's tuples TSV
# (column 0, column 1 and column 2 respectively).
contexts = {}
curr_sents = {}
next_sents = {}
for part in dataset['parts']:
    contexts[part] = []
    curr_sents[part] = []
    next_sents[part] = []

    file_path = os.path.join(data_path, dataset['tuples'][part])
    print(f"Reading file: {file_path}")
    # The csv module expects file objects opened with newline=''; the encoding
    # is made explicit instead of depending on the platform locale.
    with open(file_path, newline='', encoding='utf-8') as f_obj:
        reader = csv.reader(f_obj, delimiter='\t')
        for row in reader:
            # Raise instead of assert so the check is not stripped under `python -O`.
            if len(row) != 3:
                raise ValueError(
                    f'Error! row does not contain exactly three items! Count: {len(row)} '
                    f'({file_path}, line {reader.line_num})'
                )

            context, curr_sent, next_sent = row
            contexts[part].append(context)
            curr_sents[part].append(curr_sent)
            next_sents[part].append(next_sent)


# Measure tokenized lengths of model inputs and targets across all splits,
# then print the maxima and a histogram of target lengths.
tokenizer = T5Tokenizer.from_pretrained('t5-large')
input_lengths = []
output_lengths = []
for part in dataset['parts']:
    # curr_sents[part] and next_sents[part] are parallel lists, so zip pairs
    # each input sentence with the sentence that follows it.
    for sent, next_sent in zip(curr_sents[part], next_sents[part]):
        # NOTE(review): ' </s>' is appended by hand here, while newer
        # transformers versions of T5Tokenizer.encode also append EOS by
        # default (add_special_tokens=True), so the EOS may be counted twice.
        # Confirm against the installed version and the training-time input format.
        input_seq = 'generate response query: ' + sent + ' </s>'
        input_lengths.append(len(tokenizer.encode(input_seq)))
        output_lengths.append(len(tokenizer.encode(next_sent)))

print(np.max(input_lengths))
print(np.max(output_lengths))

# Target-length histogram, most frequent length first.
counter = Counter(output_lengths)
for length, count in counter.most_common():
    print(f'{length} -> {count}')



/collection/ka2khan/thesis/Cond_Text_Gen
Reading file: data/DailyDialog/processed/train_tuples.tsv
Reading file: data/DailyDialog/processed/valid_tuples.tsv
Reading file: data/DailyDialog/processed/test_tuples.tsv
329
333
9 -> 4346
10 -> 4227
8 -> 4177
11 -> 4167
12 -> 3969
7 -> 3869
13 -> 3635
14 -> 3396
6 -> 3187
15 -> 3165
16 -> 3011
17 -> 2734
18 -> 2551
19 -> 2402
20 -> 2245
4 -> 1995
21 -> 1993
22 -> 1905
23 -> 1723
24 -> 1604
5 -> 1603
25 -> 1516
26 -> 1329
3 -> 1253
27 -> 1176
28 -> 1077
29 -> 961
30 -> 862
31 -> 829
32 -> 696
33 -> 684
34 -> 648
35 -> 587
36 -> 530
38 -> 483
37 -> 454
39 -> 432
40 -> 360
41 -> 331
43 -> 321
42 -> 304
44 -> 294
45 -> 257
46 -> 246
48 -> 240
47 -> 229
49 -> 195
50 -> 184
52 -> 169
51 -> 168
53 -> 153
54 -> 150
55 -> 113
56 -> 106
59 -> 105
57 -> 90
58 -> 88
60 -> 85
62 -> 79
63 -> 77
61 -> 73
65 -> 61
67 -> 54
66 -> 53
68 -> 47
73 -> 41
64 -> 41
70 -> 41
71 -> 39
69 -> 36
75 -> 32
72 -> 32
74 -> 30
80 -> 24
77 -> 24
81 -> 22
76 -> 21
78 -> 21
2 