Skip to content

Commit

Permalink
Add size method to BacktranslationDataset + misc fixes (#325)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #325

RoundRobinZipDataset requires size(index) method implemented in every dataset used. Also added missing return statements in a few methods.

Reviewed By: liezl200

Differential Revision: D10457159

fbshipit-source-id: 01856eb455f2f3a21e7fb723129ff35fbe29e0ae
  • Loading branch information
Deepak Gopinath authored and facebook-github-bot committed Oct 23, 2018
1 parent 1aae5f6 commit 613ffee
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 5 deletions.
18 changes: 13 additions & 5 deletions fairseq/data/backtranslation_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def __init__(
"""
self.tgt_dataset = language_pair_dataset.LanguagePairDataset(
src=tgt_dataset,
src_sizes=None,
src_sizes=tgt_dataset.sizes,
src_dict=tgt_dict,
tgt=None,
tgt_sizes=None,
Expand Down Expand Up @@ -141,19 +141,19 @@ def collater(self, samples):

def get_dummy_batch(self, num_tokens, max_positions):
""" Just use the tgt dataset get_dummy_batch """
self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)
return self.tgt_dataset.get_dummy_batch(num_tokens, max_positions)

def num_tokens(self, index):
""" Just use the tgt dataset num_tokens """
self.tgt_dataset.num_tokens(index)
return self.tgt_dataset.num_tokens(index)

def ordered_indices(self):
""" Just use the tgt dataset ordered_indices """
self.tgt_dataset.ordered_indices
return self.tgt_dataset.ordered_indices()

def valid_size(self, index, max_positions):
""" Just use the tgt dataset size """
self.tgt_dataset.valid_size(index, max_positions)
return self.tgt_dataset.valid_size(index, max_positions)

def _generate_hypotheses(self, sample):
"""
Expand All @@ -171,3 +171,11 @@ def _generate_hypotheses(self, sample):
),
)
return hypos

def size(self, index):
"""Return an example's size as a float or tuple. This value is used when
filtering a dataset with ``--max-positions``.
Here, we return src dataset size as tgt dataset size as an approximation.
We do not know src size until we backtranslate and generate src sentences.
"""
return (self.tgt_dataset.size(index), self.tgt_dataset.size(index))
1 change: 1 addition & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,7 @@ class TestDataset(torch.utils.data.Dataset):
def __init__(self, data):
super().__init__()
self.data = data
self.sizes = None

def __getitem__(self, index):
return self.data[index]
Expand Down

0 comments on commit 613ffee

Please sign in to comment.