-
Notifications
You must be signed in to change notification settings - Fork 579
Add missing fields to KJT's PyTree flatten/unflatten logic for VBE KJT #2952
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
This pull request was exported from Phabricator. Differential Revision: D74295924 |
meta-pytorch#2952) Summary: # Context * Currently torchrec IR serializer does not support exporting variable batch KJT, because the `stride_per_rank_per_rank` and `inverse_indices` fields are needed for deserializing VBE KJTs but they are included in the KJT's PyTree flatten/unflatten function. * The diff updates KJT's PyTree flatten/unflatten function to include `stride_per_rank_per_rank` and `inverse_indices`. # Ref Differential Revision: D74295924
|
This pull request was exported from Phabricator. Differential Revision: D74295924 |
meta-pytorch#2952) Summary: # Context * Currently torchrec IR serializer does not support exporting variable batch KJT, because the `stride_per_rank_per_rank` and `inverse_indices` fields are needed for deserializing VBE KJTs but they are included in the KJT's PyTree flatten/unflatten function. * The diff updates KJT's PyTree flatten/unflatten function to include `stride_per_rank_per_rank` and `inverse_indices`. # Ref Differential Revision: D74295924
|
This pull request was exported from Phabricator. Differential Revision: D74295924 |
meta-pytorch#2952) Summary: # Context * Currently torchrec IR serializer does not support exporting variable batch KJT, because the `stride_per_rank_per_rank` and `inverse_indices` fields are needed for deserializing VBE KJTs but they are included in the KJT's PyTree flatten/unflatten function. * The diff updates KJT's PyTree flatten/unflatten function to include `stride_per_rank_per_rank` and `inverse_indices`. # Ref Reviewed By: TroyGarden Differential Revision: D74295924
|
This pull request was exported from Phabricator. Differential Revision: D74295924 |
meta-pytorch#2952) Summary: Pull Request resolved: meta-pytorch#2952 # Context * Currently torchrec IR serializer does not support exporting variable batch KJT, because the `stride_per_rank_per_rank` and `inverse_indices` fields are needed for deserializing VBE KJTs but they are included in the KJT's PyTree flatten/unflatten function. * The diff updates KJT's PyTree flatten/unflatten function to include `stride_per_rank_per_rank` and `inverse_indices`. # Ref Reviewed By: TroyGarden Differential Revision: D74295924
b4dfa13 to
1bc6ae0
Compare
meta-pytorch#2952) Summary: # Context * Currently torchrec IR serializer does not support exporting variable batch KJT, because the `stride_per_rank_per_rank` and `inverse_indices` fields are needed for deserializing VBE KJTs but they are included in the KJT's PyTree flatten/unflatten function. * The diff updates KJT's PyTree flatten/unflatten function to include `stride_per_rank_per_rank` and `inverse_indices`. # Ref Reviewed By: TroyGarden Differential Revision: D74295924
|
This pull request was exported from Phabricator. Differential Revision: D74295924 |
|
This pull request was exported from Phabricator. Differential Revision: D74295924 |
meta-pytorch#2952) Summary: # Context * Currently torchrec IR serializer does not support exporting variable batch KJT, because the `stride_per_rank_per_rank` and `inverse_indices` fields are needed for deserializing VBE KJTs but they are included in the KJT's PyTree flatten/unflatten function. * The diff updates KJT's PyTree flatten/unflatten function to include `stride_per_rank_per_rank` and `inverse_indices`. # Ref Reviewed By: TroyGarden Differential Revision: D74295924
|
This pull request was exported from Phabricator. Differential Revision: D74295924 |
meta-pytorch#2952) Summary: # Context * Currently torchrec IR serializer does not support exporting variable batch KJT, because the `stride_per_rank_per_rank` and `inverse_indices` fields are needed for deserializing VBE KJTs but they are included in the KJT's PyTree flatten/unflatten function. * The diff updates KJT's PyTree flatten/unflatten function to include `stride_per_rank_per_rank` and `inverse_indices`. # Ref Reviewed By: TroyGarden Differential Revision: D74295924
|
This pull request was exported from Phabricator. Differential Revision: D74295924 |
|
This pull request has been reverted by d797031. |
Summary:
Context
stride_per_rank_per_rankandinverse_indicesfields are needed for deserializing VBE KJTs but they are included in the KJT's PyTree flatten/unflatten function.stride_per_rank_per_rankandinverse_indices.Ref
Differential Revision: D74295924