/
datasets.py
1039 lines (901 loc) · 35.6 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding: utf-8
"""
Dataset module
"""
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
from torch.utils.data import BatchSampler, DataLoader, Dataset, Sampler
from joeynmt.batch import Batch
from joeynmt.config import ConfigurationError
from joeynmt.helpers import read_list_from_file
from joeynmt.helpers_for_ddp import (
DistributedSubsetSampler,
RandomSubsetSampler,
get_logger,
use_ddp,
)
from joeynmt.tokenizers import BasicTokenizer
# Module-level logger, created via `get_logger` from `joeynmt.helpers_for_ddp`.
logger = get_logger(__name__)
# Default device for `BaseDataset.collate_fn` / `BaseDataset.make_iter`
# when the caller does not pass one.
CPU_DEVICE = torch.device("cpu")
class BaseDataset(Dataset):
    """
    BaseDataset which loads and looks up data.
    - holds pointer to tokenizers, encoding functions.

    :param path: path to data directory
    :param src_lang: source language code, i.e. `en`
    :param trg_lang: target language code, i.e. `de`
    :param split: name of the data split; `"train"` requires `has_trg`
    :param has_trg: bool indicator if trg exists
    :param has_prompt: per-language bool indicators if prompt exists
    :param tokenizer: per-language tokenizer objects
    :param sequence_encoder: per-language encoding functions
    :param random_subset: size of the random subset; -1 means no subsampling
    """

    # pylint: disable=too-many-instance-attributes

    def __init__(
        self,
        path: str,
        src_lang: str,
        trg_lang: str,
        split: str = "train",
        has_trg: bool = False,
        has_prompt: Dict[str, bool] = None,
        tokenizer: Dict[str, BasicTokenizer] = None,
        sequence_encoder: Dict[str, Callable] = None,
        random_subset: int = -1
    ):
        self.path = path
        self.src_lang = src_lang
        self.trg_lang = trg_lang
        self.has_trg = has_trg
        self.split = split
        if self.split == "train":
            assert self.has_trg

        # These three mappings are subscripted by language code throughout the
        # class, so a `None` default is replaced by a fresh per-language dict
        # (one dict per attribute; they must not alias each other).
        self.tokenizer = (
            {src_lang: None, trg_lang: None} if tokenizer is None else tokenizer
        )
        self.sequence_encoder = (
            {src_lang: None, trg_lang: None}
            if sequence_encoder is None else sequence_encoder
        )
        self.has_prompt = (
            {src_lang: False, trg_lang: False} if has_prompt is None else has_prompt
        )

        # for random subsampling
        self.random_subset = random_subset
        self.indices = None  # range(self.__len__())
        # Note: self.indices is kept sorted, even if shuffle = True in make_iter()
        # (Sampler yields permuted indices)
        self.seed = 1  # random seed for generator

    def reset_indices(self, random_subset: int = None):
        """
        Reset `self.indices` to the full index range of the loaded data.

        Should be called after data are loaded; otherwise `self.__len__()`
        is undefined.

        :param random_subset: if given, overrides `self.random_subset`
        """
        self.indices = list(range(self.__len__()))
        if random_subset is not None:
            self.random_subset = random_subset

        if 0 < self.random_subset:
            assert (self.split != "test" and self.random_subset < self.__len__()), \
                ("Can only subsample from train or dev set "
                 f"larger than {self.random_subset}.")

    def load_data(self, path: Path, **kwargs) -> Any:
        """
        load data
        - preprocessing (lowercasing etc) is applied here.
        """
        raise NotImplementedError

    def get_item(self, idx: int, lang: str, is_train: bool = None) -> List[str]:
        """
        seek one src/trg item of the given index.
        - tokenization is applied here.
        - length-filtering, bpe-dropout etc also triggered if self.split == "train"

        :param idx: index of the instance to look up
        :param lang: language code selecting the src or trg side
        :param is_train: overrides the train flag passed to the tokenizer;
            defaults to `self.split == "train"`
        :return: token list (prompt + sep_token + line if a prompt exists),
            or None if the tokenizer returned None and there is no prompt
        """

        # workaround if tokenizer prepends an extra escape symbol before lang_tag ...
        def _remove_escape(item):
            tok = self.tokenizer[lang]
            if (
                item is not None and tok is not None
                and len(item) > 1  # need both item[0] and item[1]
                and item[0] == tok.SPACE_ESCAPE and item[1] in tok.lang_tags
            ):
                return item[1:]
            return item

        line, prompt = self.lookup_item(idx, lang)
        is_train = self.split == "train" if is_train is None else is_train
        item = _remove_escape(self.tokenizer[lang](line, is_train=is_train))

        if self.has_prompt[lang] and prompt is not None:
            prompt = _remove_escape(self.tokenizer[lang](prompt, is_train=False))
            prompt = prompt if prompt is not None else []
            item = item if item is not None else []

            max_length = self.tokenizer[lang].max_length
            if 0 < max_length < len(prompt) + len(item) + 1:
                # truncate prompt: keep only its last `budget` tokens.
                # `seq[-0:]` would return the whole sequence and a negative
                # budget would slice from the front, so both are handled
                # explicitly instead of relying on negative slicing.
                budget = max(max_length - len(item) - 1, 0)
                if prompt and prompt[0] in self.tokenizer[lang].lang_tags:
                    # the leading language tag is always preserved
                    tail = budget - 1
                    prompt = [prompt[0]] + (prompt[-tail:] if tail > 0 else [])
                else:
                    prompt = prompt[-budget:] if budget > 0 else []

            item = prompt + [self.tokenizer[lang].sep_token] + item
        return item

    def lookup_item(self, idx: int, lang: str) -> Tuple[str, str]:
        """Return the raw `(line, prompt)` pair of the given index and language."""
        raise NotImplementedError

    def __getitem__(self, idx: Union[int, str]) -> Tuple[int, List[str], List[str]]:
        """
        lookup one item pair of the given index.

        :param idx: index of the instance to lookup
        :return:
            - index  # needed to recover the original order
            - tokenized src sentences
            - tokenized trg sentences
        :raises KeyError: if `idx` is not smaller than the dataset length
        """
        # valid indices are 0 .. len-1, hence `>=`
        if idx >= self.__len__():
            raise KeyError(idx)

        src, trg = None, None
        src = self.get_item(idx=idx, lang=self.src_lang)
        if self.has_trg or self.has_prompt[self.trg_lang]:
            trg = self.get_item(idx=idx, lang=self.trg_lang)
            if trg is None:
                # a missing trg invalidates the whole pair
                src = None
        return idx, src, trg

    def get_list(self,
                 lang: str,
                 tokenized: bool = False,
                 subsampled: bool = True) -> Union[List[str], List[List[str]]]:
        """get data column-wise."""
        raise NotImplementedError

    @property
    def src(self) -> List[str]:
        """get detokenized preprocessed data in src language."""
        return self.get_list(self.src_lang, tokenized=False, subsampled=True)

    @property
    def trg(self) -> List[str]:
        """get detokenized preprocessed data in trg language."""
        return (
            self.get_list(self.trg_lang, tokenized=False, subsampled=True)
            if self.has_trg else []
        )

    def collate_fn(
        self,
        batch: List[Tuple],
        pad_index: int,
        eos_index: int,
        device: torch.device = CPU_DEVICE,
    ) -> Batch:
        """
        Custom collate function.
        See https://pytorch.org/docs/stable/data.html#dataloader-collate-fn for details.
        Please override the batch class here. (not in TrainManager)

        :param batch: list of `(idx, src, trg)` tuples as returned by `__getitem__`
        :param pad_index: forwarded to `Batch`
        :param eos_index: forwarded to `Batch`
        :param device: forwarded to `Batch`
        :return: joeynmt batch object
        """
        idx, src_list, trg_list = zip(*batch)
        assert len(batch) == len(src_list) == len(trg_list), (len(batch), len(src_list))
        assert all(s is not None for s in src_list), src_list
        src, src_length, src_prompt_mask = self.sequence_encoder[
            self.src_lang](src_list, bos=False, eos=True)

        if self.has_trg or self.has_prompt[self.trg_lang]:
            if self.has_trg:
                assert all(t is not None for t in trg_list), trg_list
            trg, _, trg_prompt_mask = self.sequence_encoder[self.trg_lang](
                trg_list, bos=True, eos=self.has_trg
            )  # no EOS if not self.has_trg
        else:
            assert all(t is None for t in trg_list)
            trg, trg_prompt_mask = None, None  # Note: we don't need trg_length!

        return Batch(
            src=torch.tensor(src).long(),
            src_length=torch.tensor(src_length).long(),
            src_prompt_mask=(
                torch.tensor(src_prompt_mask).long()
                if self.has_prompt[self.src_lang] else None
            ),
            trg=torch.tensor(trg).long() if trg else None,
            trg_prompt_mask=(
                torch.tensor(trg_prompt_mask).long()
                if self.has_prompt[self.trg_lang] else None
            ),
            indices=torch.tensor(idx).long(),
            device=device,
            pad_index=pad_index,
            eos_index=eos_index,
            is_train=self.split == "train",
        )

    def make_iter(
        self,
        batch_size: int,
        batch_type: str = "sentence",
        seed: int = 42,
        shuffle: bool = False,
        num_workers: int = 0,
        pad_index: int = 1,
        eos_index: int = 3,
        device: torch.device = CPU_DEVICE,
        generator_state: torch.Tensor = None,
    ) -> DataLoader:
        """
        Returns a torch DataLoader for a torch Dataset. (no bucketing)

        :param batch_size: size of the batches the iterator prepares
        :param batch_type: measure batch size by sentence count or by token count
        :param seed: random seed for shuffling
        :param shuffle: whether to shuffle the order of sequences before each epoch
            (for testing, no effect even if set to True; generator is
            still used for random subsampling, but not for permutation!)
        :param num_workers: number of cpus for multiprocessing
        :param pad_index: forwarded to `collate_fn`
        :param eos_index: forwarded to `collate_fn`
        :param device: forwarded to `collate_fn`
        :param generator_state: if given, restored into the sampler's generator
        :return: torch DataLoader
        :raises ConfigurationError: if `batch_type` is unknown
        """
        shuffle = shuffle and self.split == "train"

        # for decoding in DDP, we cannot use TokenBatchSampler
        if use_ddp() and self.split != "train":
            assert batch_type == "sentence", self

        generator = torch.Generator()
        generator.manual_seed(seed)
        if generator_state is not None:
            generator.set_state(generator_state)

        # define sampler which yields an integer
        sampler: Sampler[int]
        if use_ddp():  # use ddp
            sampler = DistributedSubsetSampler(
                self, shuffle=shuffle, drop_last=True, generator=generator
            )
        else:
            sampler = RandomSubsetSampler(self, shuffle=shuffle, generator=generator)

        # batch sampler which yields a list of integers
        if batch_type == "sentence":
            batch_sampler = SentenceBatchSampler(
                sampler, batch_size=batch_size, drop_last=False, seed=seed
            )
        elif batch_type == "token":
            batch_sampler = TokenBatchSampler(
                sampler, batch_size=batch_size, drop_last=False, seed=seed
            )
        else:
            raise ConfigurationError(f"{batch_type}: Unknown batch type")

        # initialize generator seed
        batch_sampler.set_seed(seed)  # set seed and resample

        # ensure that sequence_encoder (padding func) exists
        assert self.sequence_encoder[self.src_lang] is not None
        if self.has_trg:
            assert self.sequence_encoder[self.trg_lang] is not None

        # data iterator
        return DataLoader(
            dataset=self,
            batch_sampler=batch_sampler,
            collate_fn=partial(
                self.collate_fn,
                eos_index=eos_index,
                pad_index=pad_index,
                device=device
            ),
            num_workers=num_workers
        )

    def __len__(self) -> int:
        raise NotImplementedError

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(split={self.split}, len={self.__len__()}, "
            f'src_lang="{self.src_lang}", trg_lang="{self.trg_lang}", '
            f"has_trg={self.has_trg}, random_subset={self.random_subset}, "
            f"has_src_prompt={self.has_prompt[self.src_lang]}, "
            f"has_trg_prompt={self.has_prompt[self.trg_lang]})"
        )
class PlaintextDataset(BaseDataset):
    """
    PlaintextDataset which stores plain text pairs.
    - used for text file data in the format of one sentence per line.
    """

    def __init__(
        self,
        path: str,
        src_lang: str,
        trg_lang: str,
        split: str = "train",
        has_trg: bool = False,
        has_prompt: Dict[str, bool] = None,
        tokenizer: Dict[str, BasicTokenizer] = None,
        sequence_encoder: Dict[str, Callable] = None,
        random_subset: int = -1,
        **kwargs
    ):
        super().__init__(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=split,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset
        )

        # load data
        self.data = self.load_data(path, **kwargs)
        self.reset_indices()

    def load_data(self, path: str, **kwargs) -> Any:
        """
        Read `<path>.<src_lang>` (and `<path>.<trg_lang>` if `self.has_trg`).

        :param path: file path without the language suffix
        :return: dict mapping language code to list of preprocessed lines
        """

        def _pre_process(seq, lang):
            # empty lines are skipped only when a tokenizer is available
            if self.tokenizer[lang] is not None:
                seq = [self.tokenizer[lang].pre_process(s) for s in seq if len(s) > 0]
            return seq

        path = Path(path)
        src_file = path.with_suffix(f"{path.suffix}.{self.src_lang}")
        assert src_file.is_file(), f"{src_file} not found. Abort."

        src_list = read_list_from_file(src_file)
        data = {self.src_lang: _pre_process(src_list, self.src_lang)}

        if self.has_trg:
            trg_file = path.with_suffix(f"{path.suffix}.{self.trg_lang}")
            assert trg_file.is_file(), f"{trg_file} not found. Abort."

            trg_list = read_list_from_file(trg_file)
            data[self.trg_lang] = _pre_process(trg_list, self.trg_lang)
            # NOTE(review): this compares the raw line counts; `_pre_process`
            # may drop empty lines on one side only -- confirm that
            # `read_list_from_file` never yields empty lines.
            assert len(src_list) == len(trg_list)
        return data

    def lookup_item(self, idx: int, lang: str) -> Tuple[str, str]:
        """
        Return `(line, prompt)`; prompt is None unless `self.data` has a
        `"<lang>_prompt"` entry.

        :raises ValueError: if the lookup fails for any reason
        """
        try:
            line = self.data[lang][idx]
            prompt = (
                self.data[f"{lang}_prompt"][idx]
                if f"{lang}_prompt" in self.data else None
            )
            return line, prompt
        except Exception as e:
            # the first positional argument of `logger.error` is the format
            # string; idx and the exception go in as lazy %-args
            logger.error("idx=%s: %s", idx, e)
            raise ValueError from e

    def get_list(self,
                 lang: str,
                 tokenized: bool = False,
                 subsampled: bool = True) -> Union[List[str], List[List[str]]]:
        """
        Return list of preprocessed sentences in the given language.
        (not length-filtered, no bpe-dropout)
        """
        indices = self.indices if subsampled else range(self.__len__())
        item_list = []
        for idx in indices:
            item, _ = self.lookup_item(idx, lang)
            if tokenized:
                item = self.tokenizer[lang](item, is_train=False)
            item_list.append(item)
        assert len(indices) == len(item_list), (len(indices), len(item_list))
        return item_list

    def __len__(self) -> int:
        return len(self.data[self.src_lang])
class TsvDataset(BaseDataset):
    """
    TsvDataset which handles data in tsv format.
    - file_name should be specified without extension `.tsv`
    - needs src_lang and trg_lang (i.e. `en`, `de`) in header.
    see: test/data/toy/dev.tsv
    """

    def __init__(
        self,
        path: str,
        src_lang: str,
        trg_lang: str,
        split: str = "train",
        has_trg: bool = False,
        has_prompt: Dict[str, bool] = None,
        tokenizer: Dict[str, BasicTokenizer] = None,
        sequence_encoder: Dict[str, Callable] = None,
        random_subset: int = -1,
        **kwargs
    ):
        super().__init__(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=split,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset
        )

        # load tsv file
        self.df = self.load_data(path, **kwargs)
        self.reset_indices()

    def load_data(self, path: str, **kwargs) -> Any:
        """
        Read `<path>.tsv` into a pandas DataFrame and preprocess its columns.

        Side effects: sets `self.has_trg = False` if the trg column is absent,
        and sets `self.has_prompt[lang] = True` for each `"<lang>_prompt"`
        column found.

        :param path: file path without the `.tsv` extension
        :return: preprocessed DataFrame (rows containing NaN are dropped)
        :raises ImportError: if pandas is not installed
        """
        path = Path(path)
        file_path = path.with_suffix(f"{path.suffix}.tsv")
        assert file_path.is_file(), f"{file_path} not found. Abort."

        try:
            import pandas as pd  # pylint: disable=import-outside-toplevel

            # TODO: use `chunksize` for online data loading.
            df = pd.read_csv(
                file_path.as_posix(),
                sep="\t",
                header=0,
                encoding="utf-8",
                index_col=None
            )
            df = df.dropna()
            df = df.reset_index()

            assert self.src_lang in df.columns
            df[self.src_lang
               ] = df[self.src_lang].apply(self.tokenizer[self.src_lang].pre_process)

            if self.trg_lang not in df.columns:
                self.has_trg = False
                assert self.split == "test"
            if self.has_trg:
                df[self.trg_lang] = df[self.trg_lang].apply(
                    self.tokenizer[self.trg_lang].pre_process
                )

            if f"{self.src_lang}_prompt" in df.columns:
                self.has_prompt[self.src_lang] = True
                df[f"{self.src_lang}_prompt"] = df[f"{self.src_lang}_prompt"].apply(
                    self.tokenizer[self.src_lang].pre_process, allow_empty=True
                )
            if f"{self.trg_lang}_prompt" in df.columns:
                self.has_prompt[self.trg_lang] = True
                df[f"{self.trg_lang}_prompt"] = df[f"{self.trg_lang}_prompt"].apply(
                    self.tokenizer[self.trg_lang].pre_process, allow_empty=True
                )
            return df

        except ImportError as e:
            logger.error(e)
            raise ImportError from e

    def lookup_item(self, idx: int, lang: str) -> Tuple[str, str]:
        """
        Return `(line, prompt)` of the row at position `idx`; prompt is None
        if the row has no `"<lang>_prompt"` entry.

        :raises ValueError: if the lookup fails for any reason
        """
        try:
            row = self.df.iloc[idx]
            line = row[lang]
            prompt = row.get(f"{lang}_prompt", None)
            return line, prompt
        except Exception as e:
            # the first positional argument of `logger.error` is the format
            # string; idx and the exception go in as lazy %-args
            logger.error("idx=%s: %s", idx, e)
            raise ValueError from e

    def get_list(self,
                 lang: str,
                 tokenized: bool = False,
                 subsampled: bool = True) -> Union[List[str], List[List[str]]]:
        """
        Return the column of the given language as a list, optionally
        tokenized and optionally restricted to `self.indices`.
        """
        indices = self.indices if subsampled else range(self.__len__())
        df = self.df.iloc[indices]
        return (
            df[lang].apply(self.tokenizer[lang]).to_list()
            if tokenized else df[lang].to_list()
        )

    def __len__(self) -> int:
        return len(self.df)
class StreamDataset(BaseDataset):
    """
    StreamDataset which interacts with stream inputs.
    - called by `translate()` func in `prediction.py`.
    """

    # pylint: disable=unused-argument
    def __init__(
        self,
        path: str,
        src_lang: str,
        trg_lang: str,
        split: str = "test",
        has_trg: bool = False,
        has_prompt: Dict[str, bool] = None,
        tokenizer: Dict[str, BasicTokenizer] = None,
        sequence_encoder: Dict[str, Callable] = None,
        random_subset: int = -1,
        **kwargs
    ):
        super().__init__(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=split,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset
        )
        # place holder
        self.cache = []

    def _split_at_sep(self, line: str, prompt: str, lang: str, sep_token: str):
        """
        Split string at sep_token

        The split only happens when no explicit prompt was passed. Both parts
        are preprocessed, and `self.has_prompt[lang]` is set to True whenever
        a non-empty prompt is present.

        :param line: (non-empty) input string
        :param prompt: input prompt
        :param lang: language code selecting the tokenizer
        :param sep_token: separator token; None disables splitting
        :return: `(line, prompt)` after preprocessing
        """
        if (
            sep_token is not None and line is not None and sep_token in line
            and prompt is None
        ):
            # NOTE(review): raises ValueError if sep_token occurs more than once
            line, prompt = line.split(sep_token)

        if line:
            line = self.tokenizer[lang].pre_process(line, allow_empty=False)
        if prompt:
            prompt = self.tokenizer[lang].pre_process(prompt, allow_empty=True)
            self.has_prompt[lang] = True
        return line, prompt

    def set_item(
        self,
        src_line: str,
        trg_line: Optional[str] = None,
        src_prompt: Optional[str] = None,
        trg_prompt: Optional[str] = None
    ) -> None:
        """
        Set input text to the cache.

        :param src_line: (non-empty) str
        :param trg_line: Optional[str]
        :param src_prompt: Optional[str]
        :param trg_prompt: Optional[str]
        """
        assert isinstance(src_line, str) and src_line.strip() != "", \
            "The input sentence is empty! Please make sure " \
            "that you are feeding a valid input."

        src_line, src_prompt = self._split_at_sep(
            src_line, src_prompt, self.src_lang, self.tokenizer[self.src_lang].sep_token
        )
        assert src_line is not None

        trg_line, trg_prompt = self._split_at_sep(
            trg_line, trg_prompt, self.trg_lang, self.tokenizer[self.trg_lang].sep_token
        )
        if self.has_trg:
            assert trg_line is not None

        self.cache.append((src_line, trg_line, src_prompt, trg_prompt))
        # the cache grew, so the index list has to be rebuilt
        self.reset_indices()

    def lookup_item(self, idx: int, lang: str) -> Tuple[str, str]:
        """
        Return the cached `(line, prompt)` pair for the given language.

        :raises ValueError: if the lookup fails for any reason
        """
        # pylint: disable=no-else-return
        try:
            assert lang in [self.src_lang, self.trg_lang]
            if lang == self.trg_lang:
                assert self.has_trg or self.has_prompt[lang]

            src_line, trg_line, src_prompt, trg_prompt = self.cache[idx]
            if lang == self.src_lang:
                return src_line, src_prompt
            elif lang == self.trg_lang:
                return trg_line, trg_prompt
            else:
                raise ValueError
        except Exception as e:
            # the first positional argument of `logger.error` is the format
            # string; idx and the exception go in as lazy %-args
            logger.error("idx=%s: %s", idx, e)
            raise ValueError from e

    def reset_cache(self):
        """Empty the cache and rebuild the (now empty) index list."""
        self.cache = []
        self.reset_indices()

    def __len__(self) -> int:
        return len(self.cache)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(split={self.split}, len={len(self.cache)}, "
            f'src_lang="{self.src_lang}", trg_lang="{self.trg_lang}", '
            f"has_trg={self.has_trg}, random_subset={self.random_subset}, "
            f"has_src_prompt={self.has_prompt[self.src_lang]}, "
            f"has_trg_prompt={self.has_prompt[self.trg_lang]})"
        )
class BaseHuggingfaceDataset(BaseDataset):
    """
    Wrapper for Huggingface's dataset object
    cf.) https://huggingface.co/docs/datasets
    """
    COLUMN_NAME = "sentence"  # dummy column name. should be overriden.

    def __init__(
        self,
        path: str,
        src_lang: str,
        trg_lang: str,
        has_trg: bool = True,
        has_prompt: Dict[str, bool] = None,
        tokenizer: Dict[str, BasicTokenizer] = None,
        sequence_encoder: Dict[str, Callable] = None,
        random_subset: int = -1,
        **kwargs,
    ):
        # `kwargs` must contain "split"; it is also forwarded to `load_data()`
        super().__init__(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=kwargs["split"],
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
        )
        # load data
        self.dataset = self.load_data(path, **kwargs)
        self._kwargs = kwargs  # should contain arguments passed to `load_dataset()`

        self.reset_indices()

    def load_data(self, path: str, **kwargs) -> Any:
        """
        Load a Huggingface dataset, either from a local directory saved to
        disk (detected via the datasets state/dict json files) or through
        `load_dataset(path, **kwargs)`.

        :param path: local directory or dataset name
        :return: a `datasets.Dataset` containing the `COLUMN_NAME` feature
        :raises ImportError: if the `datasets` package is not installed
        """
        # pylint: disable=import-outside-toplevel
        try:
            from datasets import Dataset as Dataset_hf
            from datasets import DatasetDict, config, load_dataset, load_from_disk
            if Path(path, config.DATASET_STATE_JSON_FILENAME).exists() \
                    or Path(path, config.DATASETDICT_JSON_FILENAME).exists():
                hf_dataset = load_from_disk(path)
                if isinstance(hf_dataset, DatasetDict):
                    assert kwargs["split"] in hf_dataset
                    hf_dataset = hf_dataset[kwargs["split"]]
            else:
                hf_dataset = load_dataset(path, **kwargs)
            assert isinstance(hf_dataset, Dataset_hf)
            assert self.COLUMN_NAME in hf_dataset.features
            return hf_dataset

        except ImportError as e:
            logger.error(e)
            raise ImportError from e

    def lookup_item(self, idx: int, lang: str) -> Tuple[str, str]:
        """
        Return `(line, prompt)` of the row `idx`; prompt is None if the row
        has no `"<lang>_prompt"` entry.

        :raises ValueError: if the lookup fails for any reason
        """
        try:
            line = self.dataset[idx]
            assert lang in line[self.COLUMN_NAME], (line, lang)
            prompt = line.get(f"{lang}_prompt", None)
            return line[self.COLUMN_NAME][lang], prompt
        except Exception as e:
            # the first positional argument of `logger.error` is the format
            # string; idx and the exception go in as lazy %-args
            logger.error("idx=%s: %s", idx, e)
            raise ValueError from e

    def get_list(self,
                 lang: str,
                 tokenized: bool = False,
                 subsampled: bool = True) -> Union[List[str], List[List[str]]]:
        """
        Return the column of the given language, optionally tokenized and
        optionally restricted to `self.indices`.
        """
        dataset = self.dataset
        if subsampled:
            # build the lookup set once: membership tests on the list would
            # cost O(len(indices)) per row
            keep = set(self.indices)
            dataset = dataset.filter(lambda x, idx: idx in keep, with_indices=True)
            assert len(dataset) == len(self.indices), (len(dataset), len(self.indices))

        if tokenized:

            def _tok(item):
                item[f'tok_{lang}'] = self.tokenizer[lang](item[self.COLUMN_NAME][lang])
                return item

            return dataset.map(_tok, desc=f"Tokenizing {lang}...")[f'tok_{lang}']
        return dataset.flatten()[f'{self.COLUMN_NAME}.{lang}']

    def __len__(self) -> int:
        return self.dataset.num_rows

    def __repr__(self) -> str:
        ret = (
            f"{self.__class__.__name__}(len={self.__len__()}, "
            f'src_lang="{self.src_lang}", trg_lang="{self.trg_lang}", '
            f"has_trg={self.has_trg}, random_subset={self.random_subset}, "
            f"has_src_prompt={self.has_prompt[self.src_lang]}, "
            f"has_trg_prompt={self.has_prompt[self.trg_lang]}"
        )
        for k, v in self._kwargs.items():
            ret += f", {k}={v}"
        ret += ")"
        return ret
class HuggingfaceTranslationDataset(BaseHuggingfaceDataset):
    """
    Wrapper for Huggingface's `datasets.features.Translation` class
    cf.) https://github.com/huggingface/datasets/blob/master/src/datasets/features/translation.py
    """  # noqa
    COLUMN_NAME = "translation"

    def load_data(self, path: str, **kwargs) -> Any:
        """
        Load the dataset through the parent class, verify that the
        translation column covers the requested languages, drop rows with
        missing or empty text and preprocess the remaining rows.
        """
        dataset = super().load_data(path=path, **kwargs)
        column = self.COLUMN_NAME
        src_code, trg_code = self.src_lang, self.trg_lang

        # pylint: disable=import-outside-toplevel
        try:
            from datasets.features import Translation as Translation_hf
            feature = dataset.features[self.COLUMN_NAME]
            assert isinstance(feature, Translation_hf), \
                f"Data type mismatch. Please cast `{self.COLUMN_NAME}` column to " \
                "datasets.features.Translation class."
            assert self.src_lang in dataset.features[self.COLUMN_NAME].languages
            if self.has_trg:
                assert self.trg_lang in dataset.features[self.COLUMN_NAME].languages
        except ImportError as e:
            logger.error(e)
            raise ImportError from e

        def _has_text(text):
            # a usable entry is neither None nor empty
            return text is not None and len(text) > 0

        def _drop_nan(item):
            # keep a row only if src (and trg, when expected) carry text
            pair = item[column]
            if not self.has_trg:
                return _has_text(pair[src_code])
            return _has_text(pair[src_code]) and _has_text(pair[trg_code])

        # preprocess (lowercase, pretokenize, etc.) + validity check
        def _pre_process(item):
            src_tok = self.tokenizer[src_code]
            item[column][src_code] = src_tok.pre_process(item[column][src_code])
            if self.has_trg:
                trg_tok = self.tokenizer[trg_code]
                item[column][trg_code] = trg_tok.pre_process(item[column][trg_code])
            for code in (src_code, trg_code):
                if self.has_prompt[code]:
                    key = f"{code}_prompt"
                    item[key] = self.tokenizer[code].pre_process(
                        item[key], allow_empty=True
                    )
            return item

        without_nan = dataset.filter(_drop_nan, desc="Dropping NaN...")
        return without_nan.map(_pre_process, desc="Preprocessing...")
def build_dataset(
    dataset_type: str,
    path: str,
    src_lang: str,
    trg_lang: str,
    split: str,
    tokenizer: Dict = None,
    sequence_encoder: Dict = None,
    has_prompt: Dict = None,
    random_subset: int = -1,
    **kwargs,
):
    """
    Builds a dataset.

    :param dataset_type: (str) one of {`plain`, `tsv`, `stream`, `huggingface`}
    :param path: (str) either a local file name or
        dataset name to download from remote
    :param src_lang: (str) language code for source
    :param trg_lang: (str) language code for target
    :param split: (str) one of {`train`, `dev`, `test`}
    :param tokenizer: tokenizer objects for both source and target
    :param sequence_encoder: encoding functions for both source and target
    :param has_prompt: prompt indicators
    :param random_subset: (int) number of random subset; -1 means no subsampling
    :return: loaded Dataset
    :raises ConfigurationError: if `dataset_type` is unknown
    """
    dataset = None
    has_trg = True  # by default, we expect src-trg pairs

    def _placeholder() -> Dict:
        # A fresh dict per argument: the datasets write into `has_prompt` in
        # place, which must not leak into `tokenizer` / `sequence_encoder`.
        return {src_lang: None, trg_lang: None}

    tokenizer = _placeholder() if tokenizer is None else tokenizer
    sequence_encoder = _placeholder() if sequence_encoder is None else sequence_encoder
    has_prompt = _placeholder() if has_prompt is None else has_prompt

    if dataset_type == "plain":
        if not Path(path).with_suffix(f"{Path(path).suffix}.{trg_lang}").is_file():
            has_trg = False  # no target is given -> create dataset from src only
        dataset = PlaintextDataset(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=split,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
            **kwargs,
        )
    elif dataset_type == "tsv":
        dataset = TsvDataset(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split=split,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
            **kwargs,
        )
    elif dataset_type == "stream":
        # stream input is always treated as a trg-less test set, no subsampling
        dataset = StreamDataset(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            split="test",
            has_trg=False,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=-1,
            **kwargs,
        )
    elif dataset_type == "huggingface":
        # "split" should be specified in kwargs
        if "split" not in kwargs:
            kwargs["split"] = "validation" if split == "dev" else split
        dataset = HuggingfaceTranslationDataset(
            path=path,
            src_lang=src_lang,
            trg_lang=trg_lang,
            has_trg=has_trg,
            has_prompt=has_prompt,
            tokenizer=tokenizer,
            sequence_encoder=sequence_encoder,
            random_subset=random_subset,
            **kwargs,
        )
    else:
        raise ConfigurationError(f"{dataset_type}: Unknown dataset type.")
    return dataset
class SentenceBatchSampler(BatchSampler):
    """
    Groups the indices yielded by a base sampler into mini-batches holding a
    fixed number of instances. Instances whose src side comes back as None
    from the dataset are skipped.

    :param sampler: Base sampler. Can be any iterable object
    :param batch_size: Size of mini-batch.
    :param drop_last: If `True`, the sampler will drop the last batch if its size
        would be less than `batch_size`
    :param seed: random seed stored on the sampler
    """

    def __init__(self, sampler: Sampler, batch_size: int, drop_last: bool, seed: int):
        super().__init__(sampler, batch_size, drop_last)
        self.seed = seed

    @property
    def num_samples(self) -> int:
        """
        Number of samples the base sampler currently covers.

        Falls back to the length of the dataset's `indices` when the sampler
        does not implement `__len__`. Use len(dataset) to retrieve the
        original dataset length.
        """
        source = self.sampler.data_source
        assert source.indices is not None
        try:
            count = len(self.sampler)
        except NotImplementedError:
            count = len(source.indices)
        return count

    def __iter__(self):
        dataset = self.sampler.data_source
        pending = []
        for idx in self.sampler:
            src = dataset[idx][1]
            if src is None:
                continue  # drop instance
            pending.append(idx)
            if len(pending) >= self.batch_size:
                yield pending
                pending = []

        # leftover instances that did not fill a whole batch
        if pending:
            if self.drop_last:
                logger.warning(f"Drop indices {pending}.")
            else:
                yield pending

    def __len__(self) -> int:
        total, size = self.num_samples, self.batch_size
        # floor division when dropping the last batch, ceiling otherwise
        return total // size if self.drop_last else (total + size - 1) // size

    def set_seed(self, seed: int) -> None:
        """Store the seed on the dataset and propagate it to the base sampler."""
        assert seed is not None, seed
        source = self.sampler.data_source
        source.seed = seed

        if hasattr(self.sampler, 'set_seed'):
            self.sampler.set_seed(seed)  # set seed and resample
        elif hasattr(self.sampler, 'generator'):
            self.sampler.generator.manual_seed(seed)

        if self.num_samples < len(self.sampler.data_source):
            logger.info(
                "Sample random subset from %s data: n=%d, seed=%d",
                self.sampler.data_source.split, self.num_samples, seed
            )

    def reset(self) -> None:
        """Reset the base sampler if it supports it."""
        if hasattr(self.sampler, 'reset'):
            self.sampler.reset()

    def get_state(self):
        """Return the state of the base sampler's generator, or None if it has none."""
        if not hasattr(self.sampler, 'generator'):
            return None
        return self.sampler.generator.get_state()

    def set_state(self, state) -> None:
        """Restore the state of the base sampler's generator, if it has one."""
        if hasattr(self.sampler, 'generator'):
            self.sampler.generator.set_state(state)
class TokenBatchSampler(SentenceBatchSampler):
"""
Wraps another sampler to yield a mini-batch of indices based on num of tokens
(incl. padding). An instance longer than dataset.max_len or shorter than
dataset.min_len will be filtered out.
* no bucketing implemented
.. warning::