Conversation

@xmfan xmfan (Member) commented Aug 12, 2025

@meta-cla meta-cla bot added the CLA Signed label Aug 12, 2025
@xmfan xmfan requested review from ezyang, fmassa and wconstab August 12, 2025 01:26
@xmfan xmfan marked this pull request as ready for review August 12, 2025 01:26
@ezyang ezyang (Contributor) left a comment

Checking with @avikchaudhuri @tugsbayasgalan to make sure this doesn't make the eventual precompile rework harder.

@ezyang ezyang (Contributor) left a comment

I'm OK with using export to Torch IR. I'll leave it up to fmassa to decide when we should merge this (e.g., does deepcopy need to be fixed first?).

@tugsbayasgalan (Contributor) commented

> Checking with @avikchaudhuri @tugsbayasgalan to make sure this doesn't make the eventual precompile rework harder.

Since we just plan to swap out export_to_torch_ir with the eventual correct API that also works with precompile, I think this change is fine.

@fmassa fmassa (Contributor) commented Aug 19, 2025

@xmfan I was looking into implementing deepcopy for our parametrization so that we can get this merged, but I realized that there might be other issues here.

The graph captured by _export_to_torch_ir for the example_autoparallel.py code, once mixed precision is added back, looks like the following:

```
graph():
    %l_args_0_ : [num_users=1] = placeholder[target=arg0]
    %to : [num_users=2] = call_method[target=to](args = (%l_args_0_, torch.bfloat16), kwargs = {})
    %l__self___wq_weight : [num_users=1] = get_attr[target=L__self___wq_weight]
    %l__self___wk_weight : [num_users=1] = get_attr[target=L__self___wk_weight]
    %l__self___wv_weight : [num_users=1] = get_attr[target=L__self___wv_weight]
    %l__self___wo_weight : [num_users=1] = get_attr[target=L__self___wo_weight]
    %wrap_body_0 : [num_users=1] = get_attr[target=wrap_body_0]
    %tag_activation_checkpoint : [num_users=1] = call_function[target=torch.ops.higher_order.tag_activation_checkpoint](args = (%wrap_body_0, %to, %l__self___wq_weight, %l__self___wk_weight, %l__self___wv_weight, %l__self___wo_weight), kwargs = {use_reentrant: False})
    %o : [num_users=1] = call_function[target=operator.getitem](args = (%tag_activation_checkpoint, 0), kwargs = {})
    %o0 : [num_users=2] = call_function[target=operator.add](args = (%o, %to), kwargs = {})
    %l__self___w1 : [num_users=1] = get_attr[target=L__self___w1]
    %o_1 : [num_users=1] = call_method[target=forward](args = (%l__self___w1, %o0), kwargs = {})
    %o_2 : [num_users=1] = call_function[target=torch.nn.functional.relu](args = (%o_1,), kwargs = {})
    %l__self___w2 : [num_users=1] = get_attr[target=L__self___w2]
    %o_3 : [num_users=1] = call_method[target=forward](args = (%l__self___w2, %o_2), kwargs = {})
    %o_4 : [num_users=1] = call_function[target=operator.add](args = (%o0, %o_3), kwargs = {})
    %output : [num_users=1] = call_method[target=to](args = (%o_4, torch.bfloat16), kwargs = {})
    return [output]
```

You can see that the input and output casts get properly called, but the weight parametrizations disappear.
My guess is that this is because the weight names have been renamed (like L__self___wq_weight), so we never call into the parametrization.

Am I reading this right?
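
For context, the weight parametrization being discussed follows PyTorch's torch.nn.utils.parametrize pattern; a minimal illustrative sketch (the CastToBF16 class and the toy linear module are assumptions, not AutoParallel's actual implementation):

```python
import torch
import torch.nn as nn
from torch.nn.utils import parametrize


class CastToBF16(nn.Module):
    """Parametrization that exposes the weight as bfloat16 at access time."""

    def forward(self, weight: torch.Tensor) -> torch.Tensor:
        return weight.to(torch.bfloat16)


linear = nn.Linear(16, 16, bias=False)
# unsafe=True because the parametrization changes the tensor's dtype.
parametrize.register_parametrization(linear, "weight", CastToBF16(), unsafe=True)

# linear.weight now routes through CastToBF16.forward; if a traced graph
# reads the raw parameter via a renamed get_attr target instead, the cast
# is silently skipped -- which matches the symptom in the graph above.
out = linear(torch.randn(2, 16, dtype=torch.bfloat16))
```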

@xmfan xmfan force-pushed the xmfan/ac_tagging branch from 53c773c to 839b69f Compare August 19, 2025 17:14
@xmfan xmfan (Member, Author) commented Aug 19, 2025

@fmassa the casts show up properly if I enable torch._dynamo.config.install_free_tensors=True: https://gist.github.com/xmfan/5176d488358e77943dbca0ecd6fe0005. The missing casts were caused by a divergence between torch.export's dynamo and torch.compile's dynamo: https://github.com/pytorch/pytorch/blob/eba20d2d748cb17dce9aa26e5513e4567bfd8282/torch/_dynamo/variables/builder.py#L1881-L1901
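
A minimal sketch of the workaround being described, assuming the internal _export_to_torch_ir entry point this PR uses (internal API; the import path and call signature here are assumptions):

```python
import torch
import torch.nn as nn

# Flag mentioned above: it aligns export-time dynamo with torch.compile's
# handling of free tensors, so the parametrization-inserted casts show up
# in the captured graph.
torch._dynamo.config.install_free_tensors = True


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    def forward(self, x):
        return torch.relu(self.linear(x))


# Internal frontend referenced in this thread; this import path and call
# signature are assumptions and may differ across PyTorch versions.
from torch.export._trace import _export_to_torch_ir

gm = _export_to_torch_ir(Toy(), (torch.randn(2, 8),))
print(gm.graph)
```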

@xmfan xmfan force-pushed the xmfan/ac_tagging branch from 839b69f to f799148 Compare August 19, 2025 20:56
@fmassa fmassa (Contributor) commented Aug 20, 2025

I validated that with this change we get back the dtype cast hooks.

But it looks like the test failures are related, and this change seems to be breaking init_weights for now.

Also, am I understanding correctly that this PR will also make AutoParallel support only tuple inputs?

@xmfan xmfan (Member, Author) commented Aug 21, 2025

For init_weights: dynamo encodes the FQNs a certain way (self.linear.weight -> self____modules__linear____parameters__weight). IIRC we want the FQNs to match eager exactly for model saving/loading purposes? @fmassa
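
To make the FQN concern concrete, a hypothetical sketch of undoing the mangling in this example (the demangle_fqn helper and the exact rules it assumes are illustrative; this is the kind of manual restore that is called too brittle below):

```python
import re


def demangle_fqn(dynamo_name: str) -> str:
    """Best-effort reversal of the mangling shown above, e.g.
    'self____modules__linear____parameters__weight' -> 'linear.weight'.
    Hypothetical helper: dynamo's real mangling rules may differ."""
    name = re.sub(r"____(modules|parameters|buffers)__", ".", dynamo_name)
    if name.startswith("self."):
        name = name[len("self."):]
    return name


assert demangle_fqn("self____modules__linear____parameters__weight") == "linear.weight"
```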

@xmfan xmfan changed the title Directly use _export_to_torch_ir strict_mode to support AC tags Migrate to dynamo frontend (_export_to_torch_ir) to support AC tags Aug 21, 2025
@fmassa fmassa (Contributor) commented Aug 21, 2025

> For init_weights: dynamo encodes the FQNs a certain way (self.linear.weight -> self____modules__linear____parameters__weight). IIRC we want the FQNs to match eager exactly for model saving/loading purposes? @fmassa

@xmfan Yes, we would want to keep the same FQNs for saving/loading.

@xmfan xmfan (Member, Author) commented Aug 21, 2025

Scrap that. It's too brittle to manually restore FQNs. Keep using the export stuff.

Scrap this. Apparently it's okay: people are working on AOT precompile, and we'll just move to that.

I synced with @tugsbayasgalan; there are some refactors upcoming to export, so it'd be best if we stuck to public APIs. Given that strict=False works for trunk, there's probably not much of a gap.

So I'm breaking this PR up:
1. I'll ensure that torch.export(strict=True) works for what we currently have landed in autoparallel: #104 (see the sketch below)
2. Tugsuu is looking at getting the AC HOP to proxy itself during pre-dispatch; the tests in this PR should pass after that
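
For step 1, a minimal sketch of what exercising strict-mode export against an activation-checkpointed module could look like (the toy Block module is illustrative; per step 2, the checkpointed path may not trace end-to-end until the AC HOP proxies itself through pre-dispatch):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.w1 = nn.Linear(16, 16)
        self.w2 = nn.Linear(16, 16)

    def _inner(self, x):
        return self.w2(torch.relu(self.w1(x)))

    def forward(self, x):
        # Non-reentrant checkpointing; under dynamo this is captured as the
        # tag_activation_checkpoint HOP seen in the graph earlier in the thread.
        return checkpoint(self._inner, x, use_reentrant=False)


ep = torch.export.export(Block(), (torch.randn(2, 16),), strict=True)
print(ep.graph_module.graph)
```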

@xmfan xmfan mentioned this pull request Aug 21, 2025
@xmfan xmfan force-pushed the xmfan/ac_tagging branch from 36edeb4 to 300608f Compare August 26, 2025 07:04
@xmfan xmfan changed the title Migrate to dynamo frontend (_export_to_torch_ir) to support AC tags Migrate to strict mode export to support AC tags Aug 26, 2025
@xmfan xmfan changed the title Migrate to strict mode export to support AC tags Migrate to strict mode export (dynamo) to support AC tags Aug 26, 2025
@xmfan xmfan changed the title Migrate to strict mode export (dynamo) to support AC tags Migrate to strict mode export (dynamo) to support AC tags and HOPs Aug 26, 2025
@fmassa fmassa (Contributor) left a comment

LGTM once tests pass, thanks for working on it!

Also, could you add some more comments on the need for the monkey_patch?

```python
# mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
# MP policy causing some deepcopy issues
# mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.float32)
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
```
Contributor commented on the snippet above:

Does the code also work with reduce_dtype=torch.float32? If yes, can you set it as the default? That is the setup we mostly use, so it would be good to have it as the default.

Member Author replied:

Yes, it works; I'll update these comments.
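
A sketch of the default being requested here, assuming MixedPrecisionPolicy is FSDP2's (the import path below is an assumption; the example script may import it from elsewhere):

```python
import torch

# Assumed import path; the example script may import this from elsewhere.
from torch.distributed.fsdp import MixedPrecisionPolicy

# Default requested in the review: bf16 params with fp32 reductions,
# since that is the setup used most often.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,
    reduce_dtype=torch.float32,
)
```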



```python
@contextmanager
def monkey_patch_export_verifier():
```
Contributor commented on the snippet above:

Can you explain a bit (in code as well) why we currently need this?

Is this something that you expect we will remove in the future, or is it something that is meant to stay?

Member Author replied:

Added a comment. It's something we will remove in the future, either when:

  • export offers a mode to drop the serializability constraints, or
  • we move to the precompile frontend
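
Since the diff isn't shown here, a hypothetical sketch of the monkey-patching pattern under discussion; the patch target (torch._export.verifier.Verifier.check) and its signature are assumptions, not the PR's actual implementation:

```python
from contextlib import contextmanager
from unittest import mock

import torch._export.verifier as _verifier


@contextmanager
def monkey_patch_export_verifier():
    """Temporarily relax the export verifier so graphs that are fine for
    AutoParallel but fail export's serializability checks are not rejected.
    Hypothetical sketch: the real helper may patch a different target."""
    with mock.patch.object(_verifier.Verifier, "check", lambda self, *a, **k: None):
        yield


# Usage: keep the relaxed checks scoped to the export call.
# with monkey_patch_export_verifier():
#     ep = torch.export.export(model, (example_input,), strict=True)
```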

@xmfan xmfan force-pushed the xmfan/ac_tagging branch from e2a6491 to 1a8463f Compare August 28, 2025 17:34
@xmfan xmfan force-pushed the xmfan/ac_tagging branch from 1a8463f to 997920d Compare August 28, 2025 23:19
@xmfan xmfan merged commit bf39515 into main Aug 28, 2025
6 checks passed
@fmassa fmassa deleted the xmfan/ac_tagging branch August 29, 2025 08:08