[Don't merge yet][T5, Bart] Allow t5 torch trace #6268
Conversation
Looks innocuous to me, is there some test this allows us to enable for jit and T5?
Codecov Report

```diff
@@            Coverage Diff             @@
##           master    #6268      +/-   ##
==========================================
- Coverage   79.79%   78.50%    -1.29%
==========================================
  Files         148      148
  Lines       27196    27196
==========================================
- Hits        21701    21351     -350
- Misses       5495     5845     +350
```

Continue to review full report at Codecov.
tests/test_modeling_common.py (outdated)

```diff
@@ -245,15 +245,18 @@ def _create_and_check_torchscript(self, config, inputs_dict):
     inputs = self._prepare_for_class(inputs_dict, model_class)["input_ids"]  # Let's keep only input_ids

     try:
-        traced_gpt2 = torch.jit.trace(model, inputs)
+        if model.__class__.__name__ in ["T5Model", "T5ForConditionalGeneration"]:
```
Quite hacky here, but I didn't see another way....
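For context, a minimal sketch of what this branch presumably amounts to (hedged: reusing the encoder `inputs` as `decoder_input_ids` is an assumption about the test, not a confirmed detail of the PR):

```python
# Sketch only: encoder-decoder models such as T5 need decoder_input_ids as a
# second positional tensor, so the trace call feeds the same ids to both slots.
if model.__class__.__name__ in ["T5Model", "T5ForConditionalGeneration"]:
    traced_model = torch.jit.trace(model, (inputs, inputs))
else:
    traced_model = torch.jit.trace(model, inputs)
```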
Isn't it more general as `is_encoder_decoder`?
Think Bart can run without requiring the `decoder_input_ids` @sshleifer... but I guess it would be cleaner to call it `encoder_decoder` here... we will have to slightly change Bart then.
If I change to `is_encoder_decoder=True` and do the corresponding changes for Bart, Bart hits an assert:

```python
assert attn_output.size() == (bsz * self.num_heads, tgt_len, self.head_dim)
```

Not really sure what's going on there... do you have an idea @sshleifer?
I'll push the changes -> let's see if we want to change Bart accordingly or revert to the T5 hack...
I wrote that assert and should have written an error message, but I can't understand the issue without looking more closely. Does it matter that the second arg to `BartForConditionalGeneration.forward` is `attention_mask`?
It's outdated, seems like you figured it out! I will add a message to my assert.
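A sketch of what such a message could look like (hypothetical wording, not the actual commit; the placeholder values only make the snippet self-contained):

```python
import torch

# Names mirror Bart's attention code; the values here are placeholders.
bsz, num_heads, tgt_len, head_dim = 2, 4, 8, 16
attn_output = torch.zeros(bsz * num_heads, tgt_len, head_dim)

assert attn_output.size() == (bsz * num_heads, tgt_len, head_dim), (
    f"attn_output must have shape {(bsz * num_heads, tgt_len, head_dim)}, "
    f"but got {attn_output.size()}"
)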
```diff
@@ -282,7 +282,7 @@ class T5ModelTest(ModelTesterMixin, unittest.TestCase):
     all_model_classes = (T5Model, T5ForConditionalGeneration) if is_torch_available() else ()
     all_generative_model_classes = (T5ForConditionalGeneration,) if is_torch_available() else ()
     test_pruning = False
-    test_torchscript = False
+    test_torchscript = True
```
turn test on
(Branch force-pushed from `1352cc6` to `0038b24`.)
After digging a bit deeper into why the Bart tests fail, I think the reason is the Bart cache (see `transformers/tests/test_modeling_common.py`, line 252 in `ac001c4`). A bug was filed for this problem: #6348.

PR should be good for merge IMO. @LysandreJik @sshleifer @sgugger - would be great if you could take a quick second look.
LGTM. I'll do my part after this merges!
I'm a little concerned that jit cares about arg ordering, but Python doesn't care about kwarg ordering, so there might be some very confusing bugs in the future. Don't have a great idea of how to fix that and maintain the same API.
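To make the concern concrete, a small self-contained sketch (toy module, not the actual test code) of how a traced module binds tensors by position:

```python
import torch


class Toy(torch.nn.Module):
    def forward(self, input_ids, decoder_input_ids):
        # The result depends on which slot each tensor fills.
        return input_ids * 2 + decoder_input_ids


toy = Toy()
a, b = torch.ones(3), torch.zeros(3)

# Eager Python: keyword order is irrelevant.
assert torch.equal(toy(input_ids=a, decoder_input_ids=b), toy(decoder_input_ids=b, input_ids=a))

# Traced: tensors are bound by position, so swapping them silently
# produces a different result instead of raising an error.
traced = torch.jit.trace(toy, (a, b))
print(traced(a, b))  # tensor([2., 2., 2.])
print(traced(b, a))  # tensor([1., 1., 1.]) -- swapped slots, no warning
```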
```diff
     except RuntimeError:
         self.fail("Couldn't trace module.")

     with tempfile.TemporaryDirectory() as tmp_dir_name:
         pt_file_name = os.path.join(tmp_dir_name, "traced_model.pt")

         try:
-            torch.jit.save(traced_gpt2, pt_file_name)
+            torch.jit.save(traced_model, pt_file_name)
```
(out of scope) Why is the try/except -> `self.fail` pattern useful? Without try/except you get an error that traces back to a line with the word `save` in it. I just looked at `self.fail` and it turns things into `AssertionError`s, which seems like strictly less info.
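For illustration, a toy version of the two patterns (the failing call is deliberate; the behavior described is standard `unittest`, not something introduced by this PR):

```python
import unittest

import torch


class Demo(unittest.TestCase):
    def test_with_fail(self):
        try:
            torch.jit.save(None, "model.pt")  # deliberately broken call
        except Exception:
            # self.fail raises AssertionError with this message; the reader
            # has to dig through the chained traceback for the root cause.
            self.fail("Couldn't save module.")

    def test_without_fail(self):
        # The underlying error propagates directly, with a traceback that
        # points at the exact failing line.
        torch.jit.save(None, "model.pt")


if __name__ == "__main__":
    unittest.main()
```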
LGTM, great work!
Putting this on hold for now as it introduces a breaking change.
Any updates on this?
The problem is that it breaks backwards compatibility in the sense that the positional arguments of Bart and T5 are changed. At the moment this is the only option to make torch tracing work for Bart and T5 though... there might be a possibility to trace a wrapper around the model - see pytorch/pytorch#14455. But this currently leads to another problem, which is probably related to our PyTorch models not being scriptable at the moment.
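For reference, a hedged sketch of that wrapper idea (`TraceWrapper` is a hypothetical name, not code from this PR or from pytorch/pytorch#14455):

```python
import torch


class TraceWrapper(torch.nn.Module):
    """Hypothetical wrapper that pins down a positional signature for tracing,
    so the wrapped model's own argument order would not have to change."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, input_ids, decoder_input_ids):
        # Pass the tensors on as keyword arguments: only the wrapper's
        # signature (not the model's) dictates the positional order.
        outputs = self.model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        return outputs[0]


# Usage, assuming `model`, `input_ids` and `decoder_input_ids` exist:
# traced = torch.jit.trace(TraceWrapper(model), (input_ids, decoder_input_ids))
```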
This PR would fix #5647. It's not a great solution IMO though.

The problem with torch script is that one cannot pass keyword arguments, but has to pass positional arguments, and it is not possible to pass `None` because every input is required to be a tensor. Because T5 requires both `input_ids` and `decoder_input_ids`, these two arguments should arguably be placed as the first two arguments. There might be use cases though where the same error would occur and which we could not solve then, e.g. if one wants to input `input_embeds`.

Maybe @LysandreJik @sgugger have a better idea.