Migrate to strict mode export (dynamo) to support AC tags and HOPs #93
Conversation
ezyang left a comment
Checking with @avikchaudhuri @tugsbayasgalan to make sure this doesn't make the eventual precompile rework harder.
ezyang left a comment
I'm OK with the use of export to torch IR. I'll leave it up to fmassa to decide when we should merge this (e.g., does deepcopy need to be fixed first?).
Since we plan to swap out export_to_torch_ir for the eventual correct API that also works with precompile, I think this change is fine.
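For reference, here is a minimal sketch (not the PR's code) of what strict-mode (dynamo-based) export looks like through the public torch.export API; the PR itself goes through the internal export_to_torch_ir path, and the model and inputs below are placeholders.

```python
# Illustrative only: strict=True routes export through TorchDynamo, the frontend
# this PR migrates to so that AC tags and other higher-order ops can be captured.
# `model` and `example_inputs` are stand-ins, not taken from the PR.
import torch

model = torch.nn.Linear(16, 16)
example_inputs = (torch.randn(2, 16),)

# Strict (dynamo) export, as opposed to strict=False (non-strict tracing).
ep = torch.export.export(model, example_inputs, strict=True)
print(ep.graph_module.graph)
```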
@xmfan I was looking into implementing the … The captured graph from graph():
%l_args_0_ : [num_users=1] = placeholder[target=arg0]
%to : [num_users=2] = call_method[target=to](args = (%l_args_0_, torch.bfloat16), kwargs = {})
%l__self___wq_weight : [num_users=1] = get_attr[target=L__self___wq_weight]
%l__self___wk_weight : [num_users=1] = get_attr[target=L__self___wk_weight]
%l__self___wv_weight : [num_users=1] = get_attr[target=L__self___wv_weight]
%l__self___wo_weight : [num_users=1] = get_attr[target=L__self___wo_weight]
%wrap_body_0 : [num_users=1] = get_attr[target=wrap_body_0]
%tag_activation_checkpoint : [num_users=1] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body_0, %to, %l__self___wq_weight, %l__self___wk_weight, %l__self___wv_weight, %l__self___wo_weight), kwargs = {use_reentrant: False})
%o : [num_users=1] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 0), kwargs = {})
%o0 : [num_users=2] = call_function[target=operator.add](args = (%o, %to), kwargs = {})
%l__self___w1 : [num_users=1] = get_attr[target=L__self___w1]
%o_1 : [num_users=1] = call_method[target=forward](args = (%l__self___w1, %o0), kwargs = {})
%o_2 : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%o_1,), kwargs = {})
%l__self___w2 : [num_users=1] = get_attr[target=L__self___w2]
%o_3 : [num_users=1] = call_method[target=forward](args = (%l__self___w2, %o_2), kwargs = {})
%o_4 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o_3), kwargs = {})
%output : [num_users=1] = call_method[target=to](args = (%o_4, torch.bfloat16), kwargs = {})
return [output]

You can see that the input and output casts get properly called, but the weight parametrizations disappear. Am I reading this right?
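To make the question concrete, here is a hedged, self-contained sketch (not the AutoParallel code; CastTo and the Linear layer are made up for illustration) of how a param-dtype cast implemented as a weight parametrization would be expected to surface as a .to(...) node when the module is traced:

```python
# Hypothetical illustration: if the param_dtype cast is implemented as a weight
# parametrization, every access to `lin.weight` goes through CastTo.forward, so a
# dynamo trace of the matmul should contain a `.to(torch.bfloat16)` node.
import torch
import torch.nn as nn
import torch.nn.utils.parametrize as parametrize

class CastTo(nn.Module):
    def __init__(self, dtype):
        super().__init__()
        self.dtype = dtype

    def forward(self, w):
        return w.to(self.dtype)

lin = nn.Linear(16, 16, bias=False)
# unsafe=True because the parametrization changes the tensor's dtype.
parametrize.register_parametrization(lin, "weight", CastTo(torch.bfloat16), unsafe=True)

out = lin(torch.randn(2, 16).to(torch.bfloat16))
print(out.dtype)  # torch.bfloat16: the weight was cast through the parametrization
```

If the trace inlines the parametrized property access, a cast should appear next to each weight get_attr in the dump above; the question is whether the export path drops it.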
(force-pushed from 53c773c to 839b69f)
@fmassa I see the casts properly if I enable …
(force-pushed from 839b69f to f799148)
I validated that with this change we get back the dtype cast hooks. But it looks like the test failures are related, and this seems to be breaking …

Also, am I understanding this right that this PR will also make AutoParallel only support tuple inputs?
For the …
Scrap that. It's too brittle to manually restore FQNs. Keep using the export stuff.
(force-pushed from 36edeb4 to 300608f)
fmassa left a comment
LGTM once tests pass, thanks for working on it!
Also, could you add some more comments on the need for the monkey_patch?
examples/example_autoparallel.py
Outdated
# mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
# MP policy causing some deepcopy issues
# mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
Does the code also work with reduce_dtype=torch.float32? If yes, can you make it the default? That is the setup we mostly use, so it would be good to have it as the default.
Yes, it works; I'll update these comments.
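For concreteness, the updated default being discussed would presumably match the commented-out line already in the example; the import path below is an assumption and may differ from what example_autoparallel.py actually uses.

```python
# Assumed import path for FSDP2's policy class; the example file may differ.
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy

# bf16 params with fp32 gradient reduction: the setup requested as the default.
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
```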
@contextmanager
def monkey_patch_export_verifier():
Can you explain a bit (in code as well) why we currently need this?
Is this something that you expect we will remove in the future or is it something that is meant to stay?
Added a comment. It's something we will remove in the future, either when:
- export offers a mode to drop the serializability constraints, or
- the precompile frontend is available.
(See the sketch below for the general pattern.)
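A rough sketch of the monkey-patch pattern under discussion, assuming the patched target is the export verifier's check; the exact attribute the PR patches is an assumption here, not taken from the diff.

```python
# Hypothetical sketch only: temporarily relax the export verifier so that ops which
# are not serializable do not fail strict export, then restore the original check.
# The patched target (torch._export.verifier.Verifier.check) is an assumption.
from contextlib import contextmanager
from unittest.mock import patch

@contextmanager
def monkey_patch_export_verifier():
    with patch("torch._export.verifier.Verifier.check", lambda self, ep: None):
        yield
```

Wrapping the patch in a contextmanager keeps it scoped to the export call, which is consistent with the plan above to delete it once export no longer enforces the serializability constraints.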
(force-pushed from e2a6491 to 1a8463f)
(force-pushed from 1a8463f to 997920d)
Depends on: pytorch/pytorch#161479