-
Notifications
You must be signed in to change notification settings - Fork 5
/
inputter.py
596 lines (499 loc) · 21.2 KB
/
inputter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
# -*- coding: utf-8 -*-
"""
Defining general functions for inputters
"""
import glob
import os
import codecs
from collections import Counter, defaultdict, OrderedDict
from itertools import count
import torch
import torchtext.data
import torchtext.vocab
from dataset_base import UNK_WORD, PAD_WORD, BOS_WORD, EOS_WORD
from text_dataset import TextDataset
from image_dataset import ImageDataset
from audio_dataset import AudioDataset
from ..utils.logging import logger
import gc
def _getstate(self):
    """Return the pickle state of a Vocab with `stoi` copied into a plain dict.

    The instance itself is not modified; only the returned state differs.
    """
    state = dict(vars(self))
    state['stoi'] = dict(self.stoi)
    return state
def _setstate(self, state):
    """Restore a pickled Vocab, re-wrapping `stoi` so missing tokens look up as 0."""
    vars(self).update(state)
    self.stoi = defaultdict(lambda: 0, self.stoi)
# Install the pickling hooks on torchtext's Vocab: `stoi` is pickled as a
# plain dict (_getstate) and rebuilt as a defaultdict on load (_setstate).
setattr(torchtext.vocab.Vocab, '__getstate__', _getstate)
setattr(torchtext.vocab.Vocab, '__setstate__', _setstate)
def get_fields(data_type, n_src_features, n_tgt_features):
    """
    Build the Field objects for a dataset of the given type.

    Args:
        data_type: type of the source input. Options are [text|img|audio].
        n_src_features: the number of source features to
            create `torchtext.data.Field` for.
        n_tgt_features: the number of target features to
            create `torchtext.data.Field` for.

    Returns:
        A dictionary whose keys are strings and whose values are the
        corresponding Field objects.

    Raises:
        ValueError: if `data_type` is not one of the supported options.
    """
    if data_type == 'text':
        dataset_cls = TextDataset
    elif data_type == 'img':
        dataset_cls = ImageDataset
    elif data_type == 'audio':
        dataset_cls = AudioDataset
    else:
        raise ValueError("Data type not implemented")
    return dataset_cls.get_fields(n_src_features, n_tgt_features)
def load_fields_from_vocab(vocab, data_type="text"):
    """
    Load Field objects from `vocab.pt` file.

    Args:
        vocab: (name, Vocab) pairs or a mapping of them, as produced by
            `save_fields_to_vocab`.
        data_type: type of the source input. Options are [text|img|audio].

    Returns:
        The dict from `get_fields`, with each named field's `vocab` set.
    """
    vocab = dict(vocab)
    fields = get_fields(data_type,
                        len(collect_features(vocab, 'src')),
                        len(collect_features(vocab, 'tgt')))
    for name, field_vocab in vocab.items():
        # Hack. Can't pickle defaultdict :(
        field_vocab.stoi = defaultdict(lambda: 0, field_vocab.stoi)
        fields[name].vocab = field_vocab
    return fields
def save_fields_to_vocab(fields):
    """
    Save Vocab objects in Field objects to `vocab.pt` file.

    Args:
        fields (dict): mapping of field name to Field (or None).

    Returns:
        A list of (name, Vocab) pairs for every non-None field that has a
        `vocab` attribute set on the instance, in `fields` iteration order.
    """
    # No stoi conversion is needed here: the Vocab `__getstate__` hook in
    # this module already pickles `stoi` as a plain dict.
    return [(k, f.vocab) for k, f in fields.items()
            if f is not None and 'vocab' in f.__dict__]
def merge_vocabs(vocabs, vocab_size=None, min_frequency=1):
    """
    Merge individual vocabularies (assumed to be generated from disjoint
    documents) into a larger vocabulary.

    Args:
        vocabs: `torchtext.vocab.Vocab` vocabularies to be merged
        vocab_size: `int` the final vocabulary size. `None` for no limit.
        min_frequency: `int` minimum frequency for word to be retained.

    Return:
        `torchtext.vocab.Vocab`
    """
    # Fold with Counter `+` (the same operation `sum` performs), starting
    # from an empty Counter.
    merged = Counter()
    for vocab in vocabs:
        merged = merged + vocab.freqs
    specials = [UNK_WORD, PAD_WORD, BOS_WORD, EOS_WORD]
    return torchtext.vocab.Vocab(merged, specials=specials,
                                 max_size=vocab_size,
                                 min_freq=min_frequency)
def get_num_features(data_type, corpus_file, side):
    """
    Count the features on one side of a corpus.

    Args:
        data_type (str): type of the source input.
            Options are [text|img|audio].
        corpus_file (str): file path to get the features.
        side (str): for source or for target.

    Returns:
        number of features on `side`.

    Raises:
        ValueError: if `data_type` is not one of the supported options.
    """
    assert side in ["src", "tgt"]
    if data_type not in ('text', 'img', 'audio'):
        raise ValueError("Data type not implemented")
    if data_type == 'text':
        return TextDataset.get_num_features(corpus_file, side)
    if data_type == 'img':
        return ImageDataset.get_num_features(corpus_file, side)
    return AudioDataset.get_num_features(corpus_file, side)
def make_features(batch, side, data_type='text'):
    """
    Build the model input for one side of a batch.

    Args:
        batch: a batch object whose attributes hold the `side` data (either
            a tensor or a tuple whose first element is the tensor) and
            optional `<side>_feat_<j>` feature tensors.
        side (str): for source or for target.
        data_type (str): type of the source input.
            Options are [text|img|audio].

    Returns:
        For text: the word tensor and its feature tensors stacked along a
        new dim 2 (assuming each is len x batch, the result is
        len x batch x (1 + n_feats)). For other data types: the raw data.
    """
    assert side in ['src', 'tgt']
    if isinstance(batch.__dict__[side], tuple):
        data = batch.__dict__[side][0]
    else:
        data = batch.__dict__[side]

    if data_type != 'text':
        # Feature tensors are only used for text inputs.
        return data

    feat_start = side + "_feat_"

    def _feat_order(key):
        # Numeric order (feat_2 before feat_10) so feature positions match
        # the order of collect_features / collect_feature_vocabs; a plain
        # string sort breaks once there are 10 or more features.
        suffix = key[len(feat_start):]
        return (0, int(suffix)) if suffix.isdigit() else (1, suffix)

    keys = sorted((k for k in batch.__dict__ if k.startswith(feat_start)),
                  key=_feat_order)
    levels = [data] + [batch.__dict__[k] for k in keys]
    return torch.cat([level.unsqueeze(2) for level in levels], 2)
def collect_features(fields, side="src"):
    """
    Collect features from Field object.

    Returns the keys `<side>_feat_0`, `<side>_feat_1`, ... that are present
    in `fields`, stopping at the first missing index.
    """
    assert side in ["src", "tgt"]
    prefix = side + "_feat_"
    keys = []
    idx = 0
    while prefix + str(idx) in fields:
        keys.append(prefix + str(idx))
        idx += 1
    return keys
def collect_feature_vocabs(fields, side):
    """
    Collect feature Vocab objects from Field object.

    Returns `fields[<side>_feat_j].vocab` for j = 0, 1, ... up to the first
    index that is missing from `fields`.
    """
    assert side in ['src', 'tgt']
    prefix = side + "_feat_"
    vocabs = []
    idx = 0
    while True:
        name = prefix + str(idx)
        if name not in fields:
            return vocabs
        vocabs.append(fields[name].vocab)
        idx += 1
def build_dataset(fields, data_type, src_data_iter=None, src_path=None,
                  src_dir=None, tgt_data_iter=None, tgt_path=None,
                  src_seq_length=0, tgt_seq_length=0,
                  src_seq_length_trunc=0, tgt_seq_length_trunc=0,
                  dynamic_dict=True, sample_rate=0,
                  window_size=0, window_stride=0, window=None,
                  normalize_audio=True, use_filter_pred=True,
                  image_channel_size=3):
    """
    Build src/tgt examples iterator from corpus files, also extract
    number of features.

    The source side is read according to `data_type`. The target side is
    always read as text.

    Returns:
        A TextDataset, ImageDataset or AudioDataset.

    Raises:
        ValueError: if `data_type` is not one of text/img/audio, or for
            audio when `src_data_iter` is given or `src_path` is None.
    """
    # Fail fast: an unknown data_type previously surfaced as an
    # UnboundLocalError from the helper below.
    if data_type not in ('text', 'img', 'audio'):
        raise ValueError("Data type not implemented")

    def _make_examples_nfeats_tpl(data_type, src_data_iter, src_path, src_dir,
                                  src_seq_length_trunc, sample_rate,
                                  window_size, window_stride,
                                  window, normalize_audio,
                                  image_channel_size=3):
        """
        Process the corpus into (example_dict iterator, num_feats) tuple
        on source side for different 'data_type'.
        """
        if data_type == 'text':
            src_examples_iter, num_src_feats = \
                TextDataset.make_text_examples_nfeats_tpl(
                    src_data_iter, src_path, src_seq_length_trunc, "src")
        elif data_type == 'img':
            src_examples_iter, num_src_feats = \
                ImageDataset.make_image_examples_nfeats_tpl(
                    src_data_iter, src_path, src_dir, image_channel_size)
        else:  # 'audio' (data_type was validated by the caller)
            if src_data_iter:
                raise ValueError("Data iterator for AudioDataset isn't\n"
                                 "implemented")
            if src_path is None:
                raise ValueError("AudioDataset requires a non None path")
            src_examples_iter, num_src_feats = \
                AudioDataset.make_audio_examples_nfeats_tpl(
                    src_path, src_dir, sample_rate,
                    window_size, window_stride, window,
                    normalize_audio)
        return src_examples_iter, num_src_feats

    src_examples_iter, num_src_feats = \
        _make_examples_nfeats_tpl(data_type, src_data_iter, src_path, src_dir,
                                  src_seq_length_trunc, sample_rate,
                                  window_size, window_stride,
                                  window, normalize_audio,
                                  image_channel_size=image_channel_size)

    # For all data types, the tgt side corpus is in form of text.
    tgt_examples_iter, num_tgt_feats = \
        TextDataset.make_text_examples_nfeats_tpl(
            tgt_data_iter, tgt_path, tgt_seq_length_trunc, "tgt")

    if data_type == 'text':
        dataset = TextDataset(fields, src_examples_iter, tgt_examples_iter,
                              num_src_feats, num_tgt_feats,
                              src_seq_length=src_seq_length,
                              tgt_seq_length=tgt_seq_length,
                              dynamic_dict=dynamic_dict,
                              use_filter_pred=use_filter_pred)
    elif data_type == 'img':
        dataset = ImageDataset(fields, src_examples_iter, tgt_examples_iter,
                               num_src_feats, num_tgt_feats,
                               tgt_seq_length=tgt_seq_length,
                               use_filter_pred=use_filter_pred,
                               image_channel_size=image_channel_size)
    else:  # 'audio'
        dataset = AudioDataset(fields, src_examples_iter, tgt_examples_iter,
                               tgt_seq_length=tgt_seq_length,
                               use_filter_pred=use_filter_pred)
    return dataset
def _build_field_vocab(field, counter, **kwargs):
    """Create `field.vocab` from `counter` via `field.vocab_cls`.

    The field's unk/pad/init/eos tokens are passed as specials, in that
    order, skipping None values and duplicates. Extra keyword arguments are
    forwarded to `vocab_cls`.
    """
    specials = []
    for token in (field.unk_token, field.pad_token, field.init_token,
                  field.eos_token):
        if token is not None and token not in specials:
            specials.append(token)
    field.vocab = field.vocab_cls(counter, specials=specials, **kwargs)
def build_vocab(train_dataset_files, fields, data_type, share_vocab,
                src_vocab_path, src_vocab_size, src_words_min_frequency,
                tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency):
    """
    Build the vocabularies of `fields` from serialized training shards.

    Token counts are accumulated over the examples of every shard in
    `train_dataset_files`. When a src/tgt vocabulary file is given, its
    tokens are used instead of corpus counts for that side (in file order),
    and the corresponding `*_vocab_size` becomes the number of tokens read.

    Args:
        train_dataset_files: a list of train dataset pt file.
        fields (dict): fields to build vocab for. Updated in place; for
            "img"/"audio" data the "src" entry is removed.
        data_type: "text", "img" or "audio"?
        share_vocab(bool): share source and target vocabulary? Only applied
            for "text" data.
        src_vocab_path(string): Path to src vocabulary file.
        src_vocab_size(int): size of the source vocabulary.
        src_words_min_frequency(int): the minimum frequency needed to
            include a source word in the vocabulary.
        tgt_vocab_path(string): Path to tgt vocabulary file.
        tgt_vocab_size(int): size of the target vocabulary.
        tgt_words_min_frequency(int): the minimum frequency needed to
            include a target word in the vocabulary.

    Returns:
        Dict of Fields (the same `fields` object).

    Raises:
        ValueError: if `train_dataset_files` is empty. The feature counts
            are read from the last shard, so at least one is required.
    """
    if not train_dataset_files:
        raise ValueError("build_vocab requires at least one train dataset file")

    counter = {}

    # Drop src from fields to lower memory use when training with image/audio.
    if data_type == 'img' or data_type == 'audio':
        fields.pop("src")

    for k in fields:
        counter[k] = Counter()

    # Load vocabulary
    src_vocab = load_vocabulary(src_vocab_path, tag="source")
    if src_vocab is not None:
        src_vocab_size = len(src_vocab)
        logger.info('Loaded source vocab has %d tokens.' % src_vocab_size)
        for i, token in enumerate(src_vocab):
            # keep the order of tokens specified in the vocab file by
            # adding them to the counter with decreasing counting values
            counter['src'][token] = src_vocab_size - i

    tgt_vocab = load_vocabulary(tgt_vocab_path, tag="target")
    if tgt_vocab is not None:
        tgt_vocab_size = len(tgt_vocab)
        logger.info('Loaded target vocab has %d tokens.' % tgt_vocab_size)
        for i, token in enumerate(tgt_vocab):
            counter['tgt'][token] = tgt_vocab_size - i

    last_index = len(train_dataset_files) - 1
    for index, path in enumerate(train_dataset_files):
        dataset = torch.load(path)
        logger.info(" * reloading %s." % path)
        for ex in dataset.examples:
            for k in fields:
                if not fields[k].sequential:
                    continue
                elif k == 'src' and src_vocab:
                    continue
                elif k == 'tgt' and tgt_vocab:
                    continue
                counter[k].update(getattr(ex, k, None))

        # Free every shard except the last one, which is still needed below
        # for its feature counts.
        if index < last_index:
            dataset.examples = None
            del dataset
            gc.collect()

    _build_field_vocab(fields["tgt"], counter["tgt"],
                       max_size=tgt_vocab_size,
                       min_freq=tgt_words_min_frequency)
    logger.info(" * tgt vocab size: %d." % len(fields["tgt"].vocab))

    # All datasets have same num of n_tgt_features,
    # getting the last one is OK.
    for j in range(dataset.n_tgt_feats):
        key = "tgt_feat_" + str(j)
        _build_field_vocab(fields[key], counter[key])
        logger.info(" * %s vocab size: %d." % (key,
                                               len(fields[key].vocab)))

    if data_type == 'text':
        _build_field_vocab(fields["src"], counter["src"],
                           max_size=src_vocab_size,
                           min_freq=src_words_min_frequency)
        logger.info(" * src vocab size: %d." % len(fields["src"].vocab))

        # All datasets have same num of n_src_features,
        # getting the last one is OK.
        for j in range(dataset.n_src_feats):
            key = "src_feat_" + str(j)
            _build_field_vocab(fields[key], counter[key])
            logger.info(" * %s vocab size: %d." %
                        (key, len(fields[key].vocab)))

        # Merge the input and output vocabularies. Only meaningful for text:
        # for img/audio the "src" field was removed above.
        if share_vocab:
            # `tgt_vocab_size` is ignored when sharing vocabularies
            logger.info(" * merging src and tgt vocab...")
            merged_vocab = merge_vocabs(
                [fields["src"].vocab, fields["tgt"].vocab],
                vocab_size=src_vocab_size,
                min_frequency=src_words_min_frequency)
            fields["src"].vocab = merged_vocab
            fields["tgt"].vocab = merged_vocab

    return fields
def load_vocabulary(vocabulary_path, tag=""):
    """
    Loads a vocabulary from the given path.

    Each non-blank line contributes its first whitespace-separated token.

    :param vocabulary_path: path to load vocabulary from
    :param tag: tag for vocabulary (only used for logging)
    :return: vocabulary (list of tokens) or None if path is null/empty
    :raises RuntimeError: if a path is given but does not exist
    """
    if not vocabulary_path:
        return None
    logger.info("Loading {} vocabulary from {}".format(tag,
                                                       vocabulary_path))
    if not os.path.exists(vocabulary_path):
        raise RuntimeError(
            "{} vocabulary not found at {}!".format(tag, vocabulary_path))
    words = []
    with codecs.open(vocabulary_path, 'r', 'utf-8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if stripped:
                words.append(stripped.split()[0])
    return words
class OrderedIterator(torchtext.data.Iterator):
    """ Ordered Iterator Class

    Training: examples are read in pools of 100 * batch_size, each pool is
    sorted by `sort_key` and split into batches, and the batches of a pool
    are shuffled with `random_shuffler`. Otherwise: batches are formed in
    data order and each batch is sorted by `sort_key`.
    """

    def create_batches(self):
        """ Create batches """
        if not self.train:
            self.batches = [
                sorted(minibatch, key=self.sort_key)
                for minibatch in torchtext.data.batch(
                    self.data(), self.batch_size, self.batch_size_fn)]
            return

        def _pool(data, random_shuffler):
            pool_size = self.batch_size * 100
            for pool in torchtext.data.batch(data, pool_size):
                minibatches = list(torchtext.data.batch(
                    sorted(pool, key=self.sort_key),
                    self.batch_size, self.batch_size_fn))
                for minibatch in random_shuffler(minibatches):
                    yield minibatch

        # Lazy generator: pools are only read as batches are consumed.
        self.batches = _pool(self.data(), self.random_shuffler)
class DatasetLazyIter(object):
    """ An Ordered Dataset Iterator, supporting multiple datasets,
        and lazy loading.

    Args:
        datasets (iterable): the datasets to iterate over. When given a
            generator (e.g. from `lazily_load_dataset`) they are loaded
            lazily, one at a time; any other iterable (e.g. a list) is
            also accepted.
        fields (dict): fields dict for the datasets.
        batch_size (int): batch size.
        batch_size_fn: custom batch process function.
        device: the GPU device.
        is_train (bool): train or valid?
    """

    def __init__(self, datasets, fields, batch_size, batch_size_fn,
                 device, is_train):
        self.datasets = datasets
        self.fields = fields
        self.batch_size = batch_size
        self.batch_size_fn = batch_size_fn
        self.device = device
        self.is_train = is_train

        # One shared iterator over the datasets. For a generator `iter()`
        # returns the generator itself (same behaviour as before); a list no
        # longer fails on `next()`.
        self._dataset_iter = iter(datasets)
        self.cur_iter = self._next_dataset_iterator(self._dataset_iter)
        # We have at least one dataset.
        assert self.cur_iter is not None

    def __iter__(self):
        # Continues from the dataset loaded in __init__; once exhausted,
        # further iteration yields nothing.
        while self.cur_iter is not None:
            for batch in self.cur_iter:
                yield batch
            self.cur_iter = self._next_dataset_iterator(self._dataset_iter)

    def __len__(self):
        # We return the len of cur_dataset, otherwise we need to load
        # all datasets to determine the real len, which loses the benefit
        # of lazy loading.
        assert self.cur_iter is not None
        return len(self.cur_iter)

    def _next_dataset_iterator(self, dataset_iter):
        """Free the current dataset, load the next one and wrap it in an
        OrderedIterator. Returns None when `dataset_iter` is exhausted."""
        # Drop the current dataset for decreasing memory
        if hasattr(self, "cur_dataset"):
            self.cur_dataset.examples = None
            gc.collect()
            del self.cur_dataset
            gc.collect()

        try:
            self.cur_dataset = next(dataset_iter)
        except StopIteration:
            return None

        # We clear `fields` when saving, restore when loading.
        self.cur_dataset.fields = self.fields

        # Sort batch by decreasing lengths of sentence required by pytorch.
        # sort=False means "Use dataset's sortkey instead of iterator's".
        return OrderedIterator(
            dataset=self.cur_dataset, batch_size=self.batch_size,
            batch_size_fn=self.batch_size_fn,
            device=self.device, train=self.is_train,
            sort=False, sort_within_batch=True,
            repeat=False)
def build_dataset_iter(datasets, fields, opt, is_train=True):
    """
    This returns user-defined train/validate data iterator for the trainer
    to iterate over. We implement simple ordered iterator strategy here,
    but more sophisticated strategy like curriculum learning is ok too.

    Args:
        datasets: iterable of datasets (see `DatasetLazyIter`).
        fields (dict): fields dict for the datasets.
        opt: options; reads `batch_size`, `valid_batch_size`, `batch_type`
            and `gpu_ranks`.
        is_train (bool): train or valid?

    Returns:
        A `DatasetLazyIter`.
    """
    batch_size = opt.batch_size if is_train else opt.valid_batch_size

    if is_train and opt.batch_type == "tokens":
        # Longest src/tgt length in the batch being assembled. Kept in a
        # closure rather than module globals, so separate iterators cannot
        # clobber each other's state.
        longest = {"src": 0, "tgt": 0}

        def batch_size_fn(new, count, sofar):
            """
            In token batching scheme, the number of sequences is limited
            such that the total number of src/tgt tokens (including padding)
            in a batch <= batch_size
            """
            # Reset current longest length at a new batch (count=1)
            if count == 1:
                longest["src"] = 0
                longest["tgt"] = 0
            # Src: <bos> w1 ... wN <eos>
            longest["src"] = max(longest["src"], len(new.src) + 2)
            # Tgt: w1 ... wN <eos>
            longest["tgt"] = max(longest["tgt"], len(new.tgt) + 1)
            src_elements = count * longest["src"]
            tgt_elements = count * longest["tgt"]
            return max(src_elements, tgt_elements)
    else:
        batch_size_fn = None

    device = "cuda" if opt.gpu_ranks else "cpu"

    return DatasetLazyIter(datasets, fields, batch_size, batch_size_fn,
                           device, is_train)
def lazily_load_dataset(corpus_type, opt):
    """
    Dataset generator. Don't do extra stuff here, like printing,
    because they will be postponed to the first loading time.

    Args:
        corpus_type: 'train' or 'valid'
        opt: options; `opt.data` is the path prefix of the .pt files.

    Returns:
        A generator of datasets, loaded one at a time. Sharded files
        `<data>.<corpus_type>.<N>.pt` are yielded in increasing N. If there
        are none, the single file `<data>.<corpus_type>.pt` is loaded.
    """
    import re

    assert corpus_type in ["train", "valid"]

    def _lazy_dataset_loader(pt_file, corpus_type):
        dataset = torch.load(pt_file)
        logger.info('Loading %s dataset from %s, number of examples: %d' %
                    (corpus_type, pt_file, len(dataset)))
        return dataset

    def _shard_key(path):
        # Sort by the numeric shard index so `.10.pt` follows `.9.pt`; a
        # plain string sort would place it right after `.1.pt`. Names
        # without a pure numeric index sort last, by path.
        match = re.search(r'\.(\d+)\.pt$', path)
        if match is None:
            return (1, 0, path)
        return (0, int(match.group(1)), path)

    pts = sorted(glob.glob(opt.data + '.' + corpus_type + '.[0-9]*.pt'),
                 key=_shard_key)
    if pts:
        for pt in pts:
            yield _lazy_dataset_loader(pt, corpus_type)
    else:
        # Only one inputters.*Dataset, simple!
        pt = opt.data + '.' + corpus_type + '.pt'
        yield _lazy_dataset_loader(pt, corpus_type)
def _load_fields(dataset, data_type, opt, checkpoint):
    """Load the Fields from `checkpoint['vocab']` or `<opt.data>.vocab.pt`.

    Only fields whose names appear as attributes of the dataset's first
    example are kept. The vocabulary sizes are logged.
    """
    if checkpoint is None:
        vocab = torch.load(opt.data + '.vocab.pt')
    else:
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
    fields = load_fields_from_vocab(vocab, data_type)

    fields = {name: field for name, field in fields.items()
              if name in dataset.examples[0].__dict__}

    if data_type != 'text':
        logger.info(' * vocabulary size. target = %d' %
                    (len(fields['tgt'].vocab)))
    else:
        logger.info(' * vocabulary size. source = %d; target = %d' %
                    (len(fields['src'].vocab), len(fields['tgt'].vocab)))

    return fields
def _collect_report_features(fields):
    """Return the (src, tgt) lists of feature keys present in `fields`."""
    return (collect_features(fields, side='src'),
            collect_features(fields, side='tgt'))