
[T5Tokenizer] add prepare_seq2seq_batch method #6122

Merged
9 commits merged into huggingface:master on Aug 17, 2020

Conversation

patil-suraj (Contributor):

This PR adds a prepare_seq2seq_batch method to T5Tokenizer, as per the proposal in #6080.

@sshleifer
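For reference, a minimal usage sketch of the new method (the example texts and max_length are illustrative; the returned keys follow the docstring added in this PR):

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
src_texts = ["translate English to German: The house is wonderful."]
tgt_texts = ["Das Haus ist wunderbar."]

batch = tokenizer.prepare_seq2seq_batch(
    src_texts, tgt_texts=tgt_texts, max_length=64, return_tensors="pt"
)
# BatchEncoding with input_ids, attention_mask, decoder_input_ids, decoder_attention_mask
print(sorted(batch.keys()))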


def set_tgt_special_tokens(self) -> None:
    self.prefix_tokens = [self.pad_token_id]
    self.suffix_tokens = [self.eos_token_id]
patil-suraj (Contributor Author):

not entirely sure about adding eos automatically. What do you think @sshleifer ?

sshleifer (Contributor):

  1. I wouldn't do eos in this PR. I think for that we need to either
     a) get to the bottom of why it impacts zero-shot translation performance, or
     b) add a flag to support not adding it (for backward compatibility / zero-shot tasks).

  2. Do we have evidence that adding a prefix token on the decoder side is helpful?

patil-suraj (Contributor Author):

> Do we have evidence that adding a prefix token on the decoder side is helpful?

Yes, T5Model does this in the _shift_right method, and the original TF T5 implementation does the same. AFAIK seq2seq decoders use a special start token: in BART the tokenizer automatically adds bos, while T5 has no bos and instead uses the pad token as the decoder start id.
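For context, a rough sketch of the shift-right idea (illustrative only, not the library's actual _shift_right code): the target ids are shifted one position to the right and the decoder start token, which for T5 is the pad token, is prepended.

import torch

def shift_right(labels: torch.Tensor, decoder_start_token_id: int) -> torch.Tensor:
    # e.g. labels [a, b, c, eos] -> decoder inputs [start, a, b, c]
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id  # T5 passes pad_token_id here
    return shifted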

patil-suraj (Contributor Author), Jul 29, 2020:

> get to the bottom of why it impacts zero shot translation performance

I will remove it for now and wait for this issue to be solved.

src/transformers/tokenization_t5.py (resolved conversation)


]
expected_src_tokens = [71, 307, 8986, 21, 4505, 51, 52, 1707, 5, 1]
batch = tokenizer.prepare_seq2seq_batch(
src_text, tgt_texts=tgt_text, max_length=len(expected_src_tokens), return_tensors=FRAMEWORK
sshleifer (Contributor):

More cases to test:

  • test the max_target_length kwarg: allow it to be passed through and have it affect decoder_input_ids.shape[1]
  • empty tgt_texts
  • empty src_texts -> should raise something
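For illustration, hedged sketches of what these cases could look like (the assertions and the exception handling are assumptions, not the final merged tests):

from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
src_text = ["A long paragraph for summarization."]
tgt_text = ["A short summary."]

# max_target_length passed through: caps the target side only
batch = tokenizer.prepare_seq2seq_batch(
    src_text, tgt_texts=tgt_text, max_length=32, max_target_length=8, return_tensors="pt"
)
print(batch["decoder_input_ids"].shape)  # second dimension should be at most 8

# no tgt_texts: only encoder-side keys are returned
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors="pt")
print("decoder_input_ids" in batch)  # expected: False

# empty src_texts: expected to raise (exact exception type not specified in the review)
try:
    tokenizer.prepare_seq2seq_batch([], return_tensors="pt")
except Exception as err:
    print(type(err).__name__)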

patil-suraj (Contributor Author):

Thanks, I will cover these cases.

patil-suraj (Contributor Author):

> empty tgt_texts

For this, can I just check that input_ids and attention_mask are returned, and no decoder_input_ids or decoder_attention_mask?

sshleifer (Contributor):

these tests look great now!

sshleifer (Contributor) left a comment:

one nit, otherwise LGTM

for k, v in decoder_inputs.items():
    model_inputs[f"decoder_{k}"] = v

self.set_src_special_tokens()
sshleifer (Contributor):

(nit) Stylistically, I would just say self.prefix_tokens = [] and self.prefix_tokens = [self.pad_token_id] directly, to avoid adding a layer of abstraction.

sgugger (Collaborator):

Same, unless you expect people to have to subclass your work to inject some custom behavior.
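To make the nit concrete, a sketch of what the inlined style could look like inside prepare_seq2seq_batch (argument names follow the diff context above; this is an illustration, not the merged implementation):

def prepare_seq2seq_batch(self, src_texts, tgt_texts=None, max_length=None,
                          max_target_length=None, padding="longest",
                          return_tensors=None, **kwargs):
    self.prefix_tokens = []  # no prefix token on the encoder side
    model_inputs = self(src_texts, max_length=max_length, padding=padding,
                        return_tensors=return_tensors, **kwargs)
    if tgt_texts is None:
        return model_inputs
    self.prefix_tokens = [self.pad_token_id]  # decoder inputs start with pad
    decoder_inputs = self(tgt_texts, max_length=max_target_length or max_length,
                          padding=padding, return_tensors=return_tensors, **kwargs)
    for k, v in decoder_inputs.items():
        model_inputs[f"decoder_{k}"] = v
    self.prefix_tokens = []  # reset so later calls to the tokenizer are unaffected
    return model_inputs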

src_text = ["A long paragraph for summrization.", "Another paragraph for summrization."]
batch = tokenizer.prepare_seq2seq_batch(src_text, return_tensors=FRAMEWORK)
# check if input_ids are returned and no decoder_input_ids
self.assertIn("input_ids", batch.keys())
sshleifer (Contributor):

(nit) don't think you need .keys

patil-suraj (Contributor Author):

aah, right. in works for dict keys by default. Thanks 😀

self.assertIsInstance(batch, BatchEncoding)
self.assertEqual(batch.input_ids.shape, (2, 512))

def test_eos_in_input(self):
sshleifer (Contributor):

Would be cool to migrate one or more of the integration tests in test_modeling_t5.py to the new method.
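For example, a hedged sketch of what a migrated integration check might look like (the prompt and generation call are illustrative):

from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

batch = tokenizer.prepare_seq2seq_batch(
    ["translate English to German: How old are you?"], return_tensors="pt"
)
generated = model.generate(
    input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]
)
print(tokenizer.batch_decode(generated, skip_special_tokens=True))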

@patil-suraj changed the title from "[WIP] [T5Tokenizer] add prepare_seq2seq_batch method" to "[T5Tokenizer] add prepare_seq2seq_batch method" on Jul 31, 2020
sgugger (Collaborator) left a comment:

Very nice, thanks! I have some nits on the docs.

"""
Build model inputs from a sequence or a pair of sequence for sequence classification tasks
by concatenating and adding special tokens. The special tokens depend on calling source text or target text.
An T5 sequence has the following format, where ``X`` represents the sequence:
sgugger (Collaborator):

Suggested change:
- An T5 sequence has the following format, where ``X`` represents the sequence:
+ A T5 sequence has the following format, where ``X`` represents the sequence:

Args:
token_ids_0 (:obj:`List[int]`):
List of IDs to which the special tokens will be added
token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
sgugger (Collaborator):

Suggested change:
- token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
+ token_ids_1 (:obj:`List[int]`, `optional`):

(we only indicate real default values. If something is optional, the None default value is expected).

Optional second list of IDs for sequence pairs.

Returns:
:obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
sgugger (Collaborator):

Suggested change:
- :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
+ :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.

**kwargs,
) -> BatchEncoding:
"""Prepare a batch that can be passed directly to an instance of T5Model.
Arguments:
sgugger (Collaborator):

Please specify the argument types with the same style as above; also make sure you document all arguments (return_tensors is not documented).

**kwargs: passed to self.__call__

Returns:
:obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
sgugger (Collaborator):

Suggested change:
- :obj:`BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.
+ :class:`~transformers.BatchEncoding`: with keys input_ids, attention_mask, decoder_input_ids, decoder_attention_mask.


patil-suraj (Contributor Author):

@sshleifer, @sgugger I have made the changes addressing your suggestions. Thanks!

sshleifer (Contributor) left a comment:

LGTM


self.assertNotIn("decoder_attention_mask", batch)

def test_max_target_length(self):
    tokenizer = T5Tokenizer.from_pretrained("t5-small")
sshleifer (Contributor):

tip: you can use

@cached_property
def default_tok(self):
    return T5Tokenizer.from_pretrained("t5-small")

To only initialize once. This barely matters for tokenizers. More useful for models, where __init__ can take 20 seconds.

@sshleifer self-assigned this on Aug 17, 2020
codecov bot commented on Aug 17, 2020:

Codecov Report

Merging #6122 into master will increase coverage by 0.08%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master    #6122      +/-   ##
==========================================
+ Coverage   78.51%   78.59%   +0.08%     
==========================================
  Files         146      146              
  Lines       26326    26347      +21     
==========================================
+ Hits        20669    20708      +39     
+ Misses       5657     5639      -18     
Impacted Files Coverage Δ
src/transformers/tokenization_t5.py 96.73% <100.00%> (+0.96%) ⬆️
src/transformers/modeling_tf_gpt2.py 65.42% <0.00%> (-29.91%) ⬇️
src/transformers/tokenization_xlnet.py 66.66% <0.00%> (-23.43%) ⬇️
src/transformers/modeling_tf_utils.py 84.09% <0.00%> (-4.88%) ⬇️
src/transformers/modeling_tf_pytorch_utils.py 88.05% <0.00%> (-1.26%) ⬇️
src/transformers/file_utils.py 82.20% <0.00%> (-0.29%) ⬇️
src/transformers/generation_tf_utils.py 85.71% <0.00%> (-0.26%) ⬇️
src/transformers/generation_utils.py 97.11% <0.00%> (+0.28%) ⬆️
src/transformers/tokenization_openai.py 84.09% <0.00%> (+12.87%) ⬆️
src/transformers/modeling_tf_distilbert.py 98.79% <0.00%> (+33.89%) ⬆️

Continue to review full report at Codecov.

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 92f8ce2...a84bb5b.

patil-suraj (Contributor Author):

@sshleifer, @patrickvonplaten, all green :)

@sshleifer sshleifer merged commit 407da12 into huggingface:master Aug 17, 2020
@patil-suraj patil-suraj deleted the t5-tok-seq2seq-batch branch August 17, 2020 17:59
Zigur pushed a commit to Zigur/transformers that referenced this pull request Oct 26, 2020
fabiocapsouza pushed a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020
fabiocapsouza added a commit to fabiocapsouza/transformers that referenced this pull request Nov 15, 2020