-
Notifications
You must be signed in to change notification settings - Fork 6.3k
/
replace_dataset.py
36 lines (28 loc) · 1.36 KB
/
replace_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from . import BaseWrapperDataset
class ReplaceDataset(BaseWrapperDataset):
"""Replaces tokens found in the dataset by a specified replacement token
Args:
dataset (~torch.utils.data.Dataset): dataset to replace tokens in
replace_map(Dictionary[int,int]): map of token to replace -> replacement token
offsets (List[int]): do not replace tokens before (from left if pos, right if neg) this offset. should be
as many as the number of objects returned by the underlying dataset __getitem__ method.
"""
def __init__(self, dataset, replace_map, offsets):
super().__init__(dataset)
assert len(replace_map) > 0
self.replace_map = replace_map
self.offsets = offsets
def __getitem__(self, index):
item = self.dataset[index]
is_tuple = isinstance(item, tuple)
srcs = item if is_tuple else [item]
for offset, src in zip(self.offsets, srcs):
for k, v in self.replace_map.items():
src_off = src[offset:] if offset >= 0 else src[:offset]
src_off.masked_fill_(src_off == k, v)
item = srcs if is_tuple else srcs[0]
return item