
[Don't merge yet][T5, Bart] Allow t5 torch trace #6268

Conversation

patrickvonplaten (Contributor) commented Aug 5, 2020

This PR would fix #5647.

It's not a great solution IMO though.
The problem with TorchScript is that one cannot pass keyword arguments, only positional arguments, and it is not possible to pass None, because every input is required to be a tensor.
Because T5 requires both input_ids and decoder_input_ids, these two should arguably be placed as the first two arguments.
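
For illustration, a minimal sketch of what tracing looks like under the ordering proposed here, where decoder_input_ids is the second positional argument (the checkpoint name and example inputs are illustrative assumptions, not taken from this PR):

import torch
from transformers import T5ForConditionalGeneration, T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("t5-small")
# torchscript=True clones the tied input/output embeddings so the trace can be saved
model = T5ForConditionalGeneration.from_pretrained("t5-small", torchscript=True)
model.eval()

input_ids = tokenizer("translate English to German: hello", return_tensors="pt").input_ids
decoder_input_ids = tokenizer("hallo", return_tensors="pt").input_ids

# torch.jit.trace only takes a tuple of example tensors:
# no keyword arguments and no None placeholders.
traced = torch.jit.trace(model, (input_ids, decoder_input_ids))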

There might be use cases, though, where the same error would occur and which we could not solve this way, e.g. when one wants to pass inputs_embeds.

Maybe @LysandreJik @sgugger have a better idea.

sgugger (Collaborator) commented Aug 5, 2020

Looks innocuous to me. Is there some test this allows us to enable for jit and T5?

codecov bot commented Aug 5, 2020

Codecov Report

Merging #6268 into master will decrease coverage by 1.28%.
The diff coverage is n/a.


@@            Coverage Diff             @@
##           master    #6268      +/-   ##
==========================================
- Coverage   79.79%   78.50%   -1.29%     
==========================================
  Files         148      148              
  Lines       27196    27196              
==========================================
- Hits        21701    21351     -350     
- Misses       5495     5845     +350     
Impacted Files                               Coverage Δ
src/transformers/modeling_bart.py            95.76% <ø> (ø)
src/transformers/modeling_t5.py              84.46% <ø> (+1.13%) ⬆️
src/transformers/modeling_tf_mobilebert.py   24.55% <0.00%> (-70.09%) ⬇️
src/transformers/tokenization_t5.py          71.83% <0.00%> (-23.95%) ⬇️
src/transformers/tokenization_openai.py      71.21% <0.00%> (-12.88%) ⬇️
src/transformers/file_utils.py               82.44% <0.00%> (+0.25%) ⬆️
src/transformers/tokenization_dpr.py         57.65% <0.00%> (+4.50%) ⬆️
src/transformers/modeling_tf_flaubert.py     87.73% <0.00%> (+63.19%) ⬆️

Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 9f57e39...ac001c4.

@@ -245,15 +245,18 @@ def _create_and_check_torchscript(self, config, inputs_dict):
inputs = self._prepare_for_class(inputs_dict, model_class)["input_ids"] # Let's keep only input_ids

try:
-    traced_gpt2 = torch.jit.trace(model, inputs)
+    if model.__class__.__name__ in ["T5Model", "T5ForConditionalGeneration"]:
patrickvonplaten (Author):

Quite hacky here, but I didn't see another way....

sgugger (Collaborator):

Wouldn't checking is_encoder_decoder be more general?

patrickvonplaten (Author):

I think Bart can run without requiring decoder_input_ids @sshleifer... but I guess it would be cleaner to use is_encoder_decoder here... we will have to slightly change Bart then.

patrickvonplaten (Author):

If I change it to is_encoder_decoder=True and make the corresponding changes for Bart, Bart hits an assert:

assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)

Not really sure what's going on there... do you have an idea @sshleifer?

patrickvonplaten (Author):

I'll push the changes -> let's see if we want to change Bart accordingly or revert to the T5 hack...

sshleifer (Contributor):

I wrote that assert and should have written an error message, but I can't understand the issue without looking more closely. Does it matter that the second arg to BartForConditionalGeneration.forward is attention_mask?

sshleifer (Contributor):

It's outdated; seems like you figured it out! I will add a message to my assert.

@@ -282,7 +282,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
     test_pruning = False
-    test_torchscript = False
+    test_torchscript = True
patrickvonplaten (Author):

turn test on

patrickvonplaten (Author) commented:

After digging a bit deeper into why the Bart tests fail, I think the reason is that the Bart cache/past_key_value data structure is not compatible with torchscript. For now the Bart tests pass because returning the cache/past_key_value is disabled - see:

model.config.use_cache = False

A bug was filed for this problem: #6348.
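
A rough sketch of that workaround (checkpoint name and input are illustrative assumptions, not the actual test code): switch the cache off before tracing, since returning past key/values is what torchscript cannot handle.

import torch
from transformers import BartModel, BartTokenizer

tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
model = BartModel.from_pretrained("facebook/bart-base", torchscript=True)
model.config.use_cache = False  # returning the cache breaks torch.jit.trace
model.eval()

input_ids = tokenizer("Hello world!", return_tensors="pt").input_ids
# Bart can build decoder_input_ids internally, so tracing with input_ids alone works
traced = torch.jit.trace(model, input_ids)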

@patrickvonplaten patrickvonplaten changed the title [T5] Allow t5 torch trace [T5, Bart] Allow t5 torch trace Aug 8, 2020
patrickvonplaten (Author) commented:

The PR should be good to merge IMO. @LysandreJik @sshleifer @sgugger - it would be great if you could take a quick second look.

sshleifer (Contributor) left a comment:

LGTM. I'll do my part after this merges!

I'm a little concerned that jit cares about arg ordering while Python doesn't care about kwarg ordering, so there might be some very confusing bugs in the future. I don't have a great idea of how to fix this while maintaining the same API.
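
A self-contained toy illustration of that concern (mine, not from the PR): eager PyTorch resolves keyword arguments by name, while a traced module only knows positions, so swapped inputs run without any error.

import torch

class Toy(torch.nn.Module):
    def forward(self, input_ids, decoder_input_ids):
        # the two inputs play different roles, like encoder vs. decoder ids
        return input_ids * 2 + decoder_input_ids

model = Toy()
a, b = torch.ones(2), torch.zeros(2)

# Eager mode: keyword order is irrelevant, both calls are identical.
assert torch.equal(model(decoder_input_ids=b, input_ids=a), model(a, b))

# Traced module: only positions exist, so passing the tensors in the
# wrong order silently computes the wrong thing instead of erroring.
traced = torch.jit.trace(model, (a, b))
print(traced(b, a))  # b * 2 + a, not a * 2 + b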

except RuntimeError:
    self.fail("Couldn't trace module.")

with tempfile.TemporaryDirectory() as tmp_dir_name:
    pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

    try:
-        torch.jit.save(traced_gpt2, pt_file_name)
+        torch.jit.save(traced_model, pt_file_name)
sshleifer (Contributor):

(Out of scope) Why is the try/except -> self.fail pattern useful? Without the try/except you get an error that traces back to a line with the word save in it. I just looked at self.fail, and it turns things into AssertionErrors, which seems like strictly less info.

sgugger (Collaborator) left a comment:

LGTM, great work!

patrickvonplaten (Author) commented:

Putting this on hold for now as it introduces a breaking change.

@patrickvonplaten patrickvonplaten changed the title [T5, Bart] Allow t5 torch trace [Don't merge yet][T5, Bart] Allow t5 torch trace Aug 11, 2020
misrasaurabh1 (Contributor) commented:

Any updates on this?

patrickvonplaten (Author) commented Sep 3, 2020

The problem is that it breaks backwards compatibility, in the sense that the positional arguments of Bart and T5 are changed. At the moment this is the only option to make torch tracing work for Bart and T5 though... there might be a possibility to trace a wrapper around the model instead - see pytorch/pytorch#14455. But this currently leads to another problem, which is probably related to our PyTorch models not being scriptable at the moment.
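
The shape of that wrapper idea, as a hedged sketch (the wrapper class is hypothetical, not part of transformers, and as noted above this route currently hits another problem): the wrapper pins down the positional-to-keyword mapping so the underlying forward still receives keyword arguments.

import torch
from transformers import T5ForConditionalGeneration

class TraceWrapper(torch.nn.Module):  # hypothetical helper for illustration
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, decoder_input_ids):
        # map the fixed positional inputs back to keyword arguments
        return self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)

model = T5ForConditionalGeneration.from_pretrained("t5-small", torchscript=True)
model.eval()
input_ids = torch.tensor([[13, 2, 1]])
decoder_input_ids = torch.tensor([[0]])
traced = torch.jit.trace(TraceWrapper(model), (input_ids, decoder_input_ids))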
