Document translation re revised #59
base: main
Conversation
…g in batch_sampler
- Removing unnecessary code
- Adding plenty of comments
- Grouping task/component arguments into dataclasses
- Updating the task for fairseq version 0.12.1
- Removed interleaving caching (it was failing)
…ut to the encoding function is a string, not a list.
… translation. A large change in how the training data is loaded: now we read all the files in the data-dir and attempt to load based on a prefix string.
- Checking whether noised examples become too long (needs rework on List[examples] to make sense)
- Adding extra padding for max_seq_len when deciding if examples are too long
- Changing the behaviour of max_seq_len validation in the BT dataset
21c2e4e to f7f2b89
```python
# Experimental: add BT information
bt_info = self.encoder.encode("BT") if is_bt else torch.tensor([], dtype=torch.long)
with data_utils.numpy_seed(self.seed, self.epoch, index):
    insert_sep = np.random.randint(2, dtype=bool)
```
Change this to `should_insert_sep`.
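A minimal sketch of the rename, lifted from the seeded block above (the standalone wrapper function and its name are illustrative, not part of the PR):

```python
import numpy as np
from fairseq.data import data_utils


def sample_should_insert_sep(seed: int, epoch: int, index: int) -> bool:
    # deterministic per-example coin flip, as in the seeded block above;
    # the descriptive name makes the later branch self-explanatory
    with data_utils.numpy_seed(seed, epoch, index):
        should_insert_sep = np.random.randint(2, dtype=bool)
    return bool(should_insert_sep)
```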
```python
# This language code handling is like the mBart-50 model and nllb-200
src_out = torch.cat(
    [torch.tensor([self.dictionary.index(src_langs[0])])]
```
I think this would be more readable with an explicit step with a named variable instead, like `source_lang_code_as_bpe` or something.
This also applies to the target language code.
```python
    [torch.tensor([self.dictionary.index(tgt_langs[0])])] + tgt_out + [torch.tensor([self.dictionary.eos()])]
)
```
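A sketch of that named-variable refactor for the target side (the helper `wrap_with_lang_code` is hypothetical; `dictionary` stands for the fairseq `Dictionary` used above):

```python
import torch


def wrap_with_lang_code(dictionary, lang_code: str, pieces: list) -> torch.Tensor:
    # each step gets an explicit name instead of being inlined into torch.cat
    lang_code_as_bpe = torch.tensor([dictionary.index(lang_code)])
    eos = torch.tensor([dictionary.eos()])
    return torch.cat([lang_code_as_bpe] + pieces + [eos])


# e.g. tgt_out = wrap_with_lang_code(self.dictionary, tgt_langs[0], tgt_out)
```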
```python
if len(src_out) > self.max_seq_len or len(tgt_out) > self.max_seq_len:
```
Do we want truncation as the default behavior, or is the default behavior 'undefined'?
I.e., we could supply this as a config parameter.
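If it were exposed as a config parameter, a sketch following fairseq's dataclass-config pattern could look like this (the field name and the choice values are assumptions, not part of the PR):

```python
from dataclasses import dataclass, field

from fairseq.dataclass import ChoiceEnum, FairseqDataclass

OVERLONG_ACTION_CHOICES = ChoiceEnum(["truncate", "skip", "error"])


@dataclass
class DocumentTranslationConfig(FairseqDataclass):
    # behaviour for examples that exceed max_seq_len
    overlong_action: OVERLONG_ACTION_CHOICES = field(
        default="truncate",
        metadata={"help": "what to do with examples longer than max_seq_len"},
    )
```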
Also, I'm not sure about this truncation strategy; is it better to remove the middle rather than the end?
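For comparison, dropping the middle instead of the end could be sketched as follows (the function name is hypothetical):

```python
import torch


def truncate_middle(tokens: torch.Tensor, max_len: int) -> torch.Tensor:
    # keep the head and the tail so that the language-code prefix and the
    # EOS suffix both survive; only the middle of the sequence is dropped
    if len(tokens) <= max_len:
        return tokens
    head = max_len // 2
    tail = max_len - head
    return torch.cat([tokens[:head], tokens[-tail:]])
```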
```python
tgt_string = self.src_dict.string(example["target"])
print(f"{self.encoder.bpe.decode(src_string)}")
print(f"{self.encoder.bpe.decode(tgt_string)}")
print()
```
This seems to be superseded by `log_example` and could be removed?
```python
return np.cumsum(lengths) - lengths


class KEYS:
```
KEYS and many of the other definitions here should probably be moved to a separate file.
```python
data_dir_path = paths[(epoch - 1) % len(paths)]

# if a split contains a comma, we should crash - since that is no longer supported
assert "," not in split, "Split should not contain a comma"
```
But when `split` contains the CLI input of `--train-subset`, it is allowed to have commas according to the function documentation.
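If the documented comma-separated form should still be supported, one option is to iterate instead of asserting (a sketch only; `load_single_split` is a hypothetical helper):

```python
def load_all_splits(split: str, epoch: int) -> list:
    # accept the comma-separated form from --train-subset and delegate
    # each subset name to a single-split loader
    return [load_single_split(sub_split, epoch) for sub_split in split.split(",")]
```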
```python
    file_type = "src"
elif file_type == lang2:
    file_type = "tgt"
return {
```
This would be better as a namedtuple (even if the namedtuple constructor is defined locally).
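A sketch of the `namedtuple` version, with fields matching the dict keys visible in the hunk below (the type name `DatasetFiles` is illustrative):

```python
from collections import namedtuple

# fields mirror the keys of the returned dict
DatasetFiles = namedtuple("DatasetFiles", ["src_path", "tgt_path", "align_path", "is_bt"])
```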
"tgt_path": tgt_dataset["path"], | ||
"align_path": align_dataset["path"] if align_dataset is not None else None, | ||
"is_bt": datasets[0]["name"].startswith(self.cfg.bt_subset), | ||
} |
This should also be a namedtuple.
Or a dataclass.
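As a dataclass, an equivalent sketch with the same assumed fields:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DatasetFiles:
    src_path: str
    tgt_path: str
    align_path: Optional[str]
    is_bt: bool
```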
```python
if len(datasets) != 1:
    parallel_datasets = [dataset for dataset in datasets if not dataset.is_bt]
    bt_datasets = [dataset for dataset in datasets if dataset.is_bt]
```
Could this not be done inside the constructor of `IndexedParallelBTDocumentsDataset`?
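For instance, the partitioning could move into the constructor (a sketch; the real signature of `IndexedParallelBTDocumentsDataset` may differ):

```python
class IndexedParallelBTDocumentsDataset:
    def __init__(self, datasets, *args, **kwargs):
        # partition once here so callers can pass the mixed list directly
        self.parallel_datasets = [d for d in datasets if not d.is_bt]
        self.bt_datasets = [d for d in datasets if d.is_bt]
```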
```python
else:
    dataset = datasets[0]

dataset.set_epoch(1)
```
Add a comment that this can take a while, since it causes an interleave-datasets call.
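For example (the comment wording is a suggestion):

```python
# NOTE: this can take a while, since it triggers an interleave of the
# underlying datasets
dataset.set_epoch(1)
```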
A very large refactoring of the document translation code to make it