
Deprecate prepare_seq2seq_batch #10287

Merged: 4 commits into master on Feb 22, 2021

Conversation

@sgugger (Collaborator) commented on Feb 19, 2021

What does this PR do?

This PR officially deprecates prepare_seq2seq_batch to prepare its removal in Transformers v5. As discussed before, the proper way to prepare data for sequence-to-sequence tasks is to:

  • call the tokenizer on the inputs
  • call the tokenizer on the targets inside the as_target_tokenizer context manager

When only dealing with input texts without targets, just using the tokenizer call works perfectly well.

For the mBART and mBART-50 tokenizers, the source and target languages can be specified at init, or changed at any time by setting the .src_lang and .tgt_lang attributes.

Here is a full example showing how to port old code using prepare_seq2seq_batch to the new way in the case of an mBART tokenizer (remove the mentions of src_lang and tgt_lang for other tokenizers):

tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
batch = tokenizer.prepare_seq2seq_batch(src_texts, tgt_texts, padding=True, truncation=True, src_lang="en_XX", tgt_lang="ro_RO", return_tensors="pt")

becomes

tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
batch = tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
with tokenizer.as_target_tokenizer():
    targets = tokenizer(tgt_texts, padding=True, truncation=True, return_tensors="pt")
batch["labels"] = targets["input_ids"]

The languages can be changed at any time with

tokenizer.src_lang = new_src_code
tokenizer.tgt_lang = new_tgt_code

This PR fixes a few things in MBartTokenizer and MBartTokenizerFast so that the new API works completely, and removes all mentions of prepare_seq2seq_batch from the documentation and tests (except the test of that method itself in the common tests). It was already no longer used in the seq2seq example run_seq2seq.

Comment on lines +88 to +89
def __init__(self, *args, tokenizer_file=None, src_lang=None, tgt_lang=None, **kwargs):
super().__init__(*args, tokenizer_file=tokenizer_file, src_lang=src_lang, tgt_lang=tgt_lang, **kwargs)
sgugger (Collaborator, Author):

Add the ability to set src_lang and tgt_lang at init.

Comment on lines +111 to +119
@property
def src_lang(self) -> str:
return self._src_lang

@src_lang.setter
def src_lang(self, new_src_lang: str) -> None:
self._src_lang = new_src_lang
self.set_src_lang_special_tokens(self._src_lang)

sgugger (Collaborator, Author):

Add the proper setter for src_lang.
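
For illustration, a minimal sketch of what this setter enables (the checkpoint name is taken from the PR description; the rest is a hypothetical usage, not code from the PR):

tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro')
# Assigning to .src_lang goes through the new setter, which immediately
# re-applies the source-language special tokens via set_src_lang_special_tokens.
tokenizer.src_lang = "ro_RO"
batch = tokenizer(["Bună dimineața"], return_tensors="pt")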

Comment on lines -134 to -138
# test None max_target_length
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=32, padding="max_length", return_tensors="pt"
)
self.assertEqual(32, batch["labels"].shape[1])
sgugger (Collaborator, Author):

This second part is no longer relevant to test.

@LysandreJik (Member) left a comment:

Cool that you upgraded the tests as well. Nothing to say, apart from the fact that it would be great to have some ">>> " everywhere.
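
For reference, what the reviewer is asking for would look like this on the example from the PR description (same code, just rendered as a doctest with ">>> " prompts; src_texts and tgt_texts are assumed to be lists of strings):

>>> tokenizer = MBartTokenizer.from_pretrained('facebook/mbart-large-en-ro', src_lang="en_XX", tgt_lang="ro_RO")
>>> batch = tokenizer(src_texts, padding=True, truncation=True, return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
...     targets = tokenizer(tgt_texts, padding=True, truncation=True, return_tensors="pt")
>>> batch["labels"] = targets["input_ids"]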

docs/source/model_doc/mbart.rst (outdated, resolved, 2 threads)
@@ -85,10 +85,10 @@ Usage Example
 ]

 model_name = 'google/pegasus-xsum'
-torch_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
Member:

Would love me some >>>

@patil-suraj (Contributor) left a comment:

LGTM! Thanks a lot for working on this :)

Left a few comments.

src/transformers/models/marian/tokenization_marian.py (outdated, resolved)
src/transformers/models/rag/modeling_rag.py (outdated, resolved, 3 threads)

-assert batch.labels.shape == (2, 5)
-assert len(batch) == 3  # input_ids, attention_mask, labels. Other things make by BartModel
+assert targets["input_ids"].shape == (2, 5)
+assert len(batch) == 2  # input_ids, attention_mask. Other things make by BartModel
patil-suraj (Contributor):

(nit) I think we can remove the "Other things make by BartModel" part of the comment.

Comment on lines 40 to 44
>>> inputs = tokenizer([article], return_tensors="pt")
>>> with tokenizer.as_target_tokenizer():
... labels = tokenizer([summary], return_tensors="pt")

>>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
patil-suraj (Contributor):

(nit) I think it would be better to pass article and summary as either a string or a list in all examples, for consistency across the docs. Some examples use lists and some pass the single string directly.
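
As a small illustration of the point (hypothetical snippet, assuming article is a single string): with return_tensors="pt" both call styles yield a batch of size 1, so the docs are free to standardize on either one.

inputs_from_str = tokenizer(article, return_tensors="pt")     # pass the string directly
inputs_from_list = tokenizer([article], return_tensors="pt")  # wrap it in a list
assert inputs_from_str["input_ids"].shape == inputs_from_list["input_ids"].shape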

src/transformers/models/rag/modeling_rag.py (outdated, resolved)

@patrickvonplaten (Contributor) left a comment:

Very clean - I like it

sgugger and others added 2 commits February 22, 2021 11:46
Co-authored-by: Lysandre Debut <lysandre@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
@sgugger sgugger merged commit 9e147d3 into master Feb 22, 2021
@sgugger sgugger deleted the deprecate_prepare_seq2seq branch February 22, 2021 17:36

@zartdinov commented:

Hi all! Sorry, but this seems cleaner (see feature request #14255):

encoded_train_dataset = train_dataset.map(
    lambda batch: tokenizer.prepare_seq2seq_batch(
        batch['text'], batch['summary'], padding='max_length', truncation=True, max_length=256, max_target_length=64
    ),
    batched=True,
    remove_columns=train_dataset.column_names,
)
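
For comparison, here is a rough sketch of the same preprocessing written against the new API (the 'text'/'summary' column names and length limits are copied from the snippet above; the preprocess helper is made up for illustration):

def preprocess(batch):
    # Tokenize the inputs, then the targets inside as_target_tokenizer.
    model_inputs = tokenizer(batch['text'], padding='max_length', truncation=True, max_length=256)
    with tokenizer.as_target_tokenizer():
        labels = tokenizer(batch['summary'], padding='max_length', truncation=True, max_length=64)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

encoded_train_dataset = train_dataset.map(
    preprocess,
    batched=True,
    remove_columns=train_dataset.column_names,
)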
