
Document translation re revised #59

Open
HaukurPall wants to merge 20 commits into main from document-translation-re-revised

Conversation

HaukurPall (Collaborator):

A very large refactoring of the document translation code to:

  • Make it multilingual
  • Better support backtranslation data
  • Support training file inputs as prefixes

haukurb and others added 17 commits September 12, 2023 11:28
- Removing unnecessary code
- Adding plenty of comments
- Grouping task/component arguments into dataclasses
- Updating the task for fairseq version 0.12.1
- Removed interleaving caching (it was failing)
…ut to the encoding function is a string, not a list.
… translation. A large change in how the training data is loaded: now we read all the files in the data-dir and attempt to load based on a prefix string.
- Checking whether noised examples become too long (needs rework on List[examples] to make sense)
- Adding extra padding for max_seq_len when deciding if examples are too long
- Changing behaviour of max_seq_len validation in BT dataset
HaukurPall force-pushed the document-translation-re-revised branch from 21c2e4e to f7f2b89 on September 12, 2023 at 11:28
```python
# Experimental: add BT information
bt_info = self.encoder.encode("BT") if is_bt else torch.tensor([], dtype=torch.long)
with data_utils.numpy_seed(self.seed, self.epoch, index):
    insert_sep = np.random.randint(2, dtype=bool)
```
Member:

change this to "should_insert_sep"
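
A minimal sketch of the suggested rename, with the surrounding code as in the diff above:

```python
with data_utils.numpy_seed(self.seed, self.epoch, index):
    # randint(2) draws 0 or 1, i.e. a coin flip; the name now says what it decides.
    should_insert_sep = np.random.randint(2, dtype=bool)
```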


```python
# This language code handling is like the mBart-50 model and nllb-200
src_out = torch.cat(
    [torch.tensor([self.dictionary.index(src_langs[0])])]
```
Member:

I think this would be more readable with an explicit step with a named variable instead; like source_lang_code_as_bpe or something.

Member:

also applies to target language code

```python
    [torch.tensor([self.dictionary.index(tgt_langs[0])])] + tgt_out + [torch.tensor([self.dictionary.eos()])]
)
```
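
A hedged sketch of the suggested refactor, assuming src_out / tgt_out are lists of tensors at this point; the intermediate names (source_lang_code, target_lang_code, eos) are illustrative, not from the PR:

```python
# Named steps for the mBart-50 / nllb-200-style language-code handling.
source_lang_code = torch.tensor([self.dictionary.index(src_langs[0])])
target_lang_code = torch.tensor([self.dictionary.index(tgt_langs[0])])
eos = torch.tensor([self.dictionary.eos()])

src_out = torch.cat([source_lang_code] + src_out + [eos])
tgt_out = torch.cat([target_lang_code] + tgt_out + [eos])
```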

```python
if len(src_out) > self.max_seq_len or len(tgt_out) > self.max_seq_len:
```
Member:

Do we want truncation as the default behavior or is the default behavior 'undefined'?

Member:

I.e. we could supply this as a config parameter

Member:

Also not sure about this truncation strategy; is it better to remove the middle rather than the end?
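
A minimal sketch of making the strategy configurable, as these comments suggest; the truncate helper and the strategy parameter are hypothetical, not in this PR:

```python
import torch

def truncate(tokens: torch.Tensor, max_len: int, strategy: str = "end") -> torch.Tensor:
    """Truncate a 1-D token tensor to at most max_len tokens."""
    if len(tokens) <= max_len:
        return tokens
    if strategy == "end":
        # Drop the tail.
        return tokens[:max_len]
    if strategy == "middle":
        # Keep the head and the tail, drop the middle.
        head = max_len // 2
        tail = max_len - head
        return torch.cat([tokens[:head], tokens[-tail:]])
    raise ValueError(f"Unknown truncation strategy: {strategy}")
```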

```python
tgt_string = self.src_dict.string(example["target"])
print(f"{self.encoder.bpe.decode(src_string)}")
print(f"{self.encoder.bpe.decode(tgt_string)}")
print()
```
Member:

This seems to be superseded by log_example and could be removed?

```python
return np.cumsum(lengths) - lengths
```
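
For reference, this is the exclusive prefix sum: each item's start offset given the array of lengths. A quick illustration:

```python
import numpy as np

lengths = np.array([3, 5, 2])
np.cumsum(lengths)            # array([ 3,  8, 10])
np.cumsum(lengths) - lengths  # array([0, 3, 8]) -- the start offset of each item
```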


```python
class KEYS:
```
Member:

KEYS and many of the other definitions here should probably be moved to a separate file.

```python
data_dir_path = paths[(epoch - 1) % len(paths)]

# if a split contains a comma, we should crash - since that is no longer supported
assert "," not in split, "Split should not contain a comma"
```
Member:

But when split contains the CLI input of --train-subset, it is allowed to contain commas according to the function documentation.
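
If comma-separated subsets should be supported here, one illustrative way to reconcile this (the load_single_split helper and datasets list are hypothetical, purely for sketching):

```python
# Handle each comma-separated subset from --train-subset instead of asserting.
for sub_split in split.split(","):
    datasets.append(self.load_single_split(sub_split, data_dir_path))
```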

file_type = "src"
elif file_type == lang2:
file_type = "tgt"
return {
Member:

This would be better as a namedtuple (even if the namedtuple constructor is defined locally).

"tgt_path": tgt_dataset["path"],
"align_path": align_dataset["path"] if align_dataset is not None else None,
"is_bt": datasets[0]["name"].startswith(self.cfg.bt_subset),
}
Member:

This should also be a namedtuple.

Member:

Or dataclass
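
A minimal sketch of the dataclass alternative; the DatasetPaths name and the src_path field are assumptions based on the visible keys:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass(frozen=True)
class DatasetPaths:
    src_path: str
    tgt_path: str
    align_path: Optional[str]
    is_bt: bool

# Illustrative construction replacing the dict literal above:
example = DatasetPaths(
    src_path="data/train.src",
    tgt_path="data/train.tgt",
    align_path=None,
    is_bt=False,
)
```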


```python
if len(datasets) != 1:
    parallel_datasets = [dataset for dataset in datasets if not dataset.is_bt]
    bt_datasets = [dataset for dataset in datasets if dataset.is_bt]
```
Member:

Could this not be done inside the constructor of IndexedParallelBTDocumentsDataset?
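
A hedged sketch of that suggestion: let the constructor do the partition (only the partition logic is shown; other constructor arguments are omitted):

```python
class IndexedParallelBTDocumentsDataset:
    def __init__(self, datasets, *args, **kwargs):
        # Partition parallel vs. backtranslation datasets here, not at the call site.
        self.parallel_datasets = [d for d in datasets if not d.is_bt]
        self.bt_datasets = [d for d in datasets if d.is_bt]
```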

```python
else:
    dataset = datasets[0]

dataset.set_epoch(1)
```
Member:

Add comment that this can take a while since it causes an interleave datasets call.
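
In place, the suggested comment could look like:

```python
# Note: this can take a while, since it triggers an interleave-datasets call.
dataset.set_epoch(1)
```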
