ORTModule GraphTransitionManager#19007
Merged
Conversation
baijumeswani
reviewed
Jan 8, 2024
pengwa
added a commit
that referenced
this pull request
Jan 16, 2024
## Dependency

#19007

## ORTModule memory-efficient gradient management

Previously I tried to solve the coarse-grained gradient accumulation/update problem in ORTModule with #8979, but that solution was not fully validated with DDP, or when user hooks are registered on torch parameter gradient accumulation. This PR addresses the problem with a similar approach to PR #8979, i.e. triggering gradient accumulation as soon as ORT has computed the grad; but instead of using an AccumulateGrad op, this time an ONNX PythonOp operator is used, which internally calls param.backward(grad) and thus handles all related hooks correctly.

## Design

Check the details from https://microsoftapc-my.sharepoint.com/:p:/g/personal/pengwa_microsoft_com/EaaBq4EzsFhOmsDEXCG7Ba4Bb9bwd0O2sFV_JXJ4jBLYLA?e=7Sz2g8&nav=eyJzSWQiOjI3MSwiY0lkIjozMjE4NzI1NDIzfQ

## Convergence Validation

Differences are mostly on the order of 0.000x, sometimes 0.00x, which may come from the different order in which gradient application happens before or after this change (on DeepSpeed ZeRO stage 2).

## TODO

Consolidate this logic with Stage 3's similar logic.
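The point about routing accumulation through param.backward(grad) so hooks still fire can be illustrated with a toy model. This is plain Python, not the actual torch or ORT APIs; all names below are illustrative only:

```python
# Toy illustration (not torch): a parameter whose gradient-accumulation
# hooks only fire when accumulation is routed through backward(), which
# mirrors why calling param.backward(grad) handles hooks correctly while
# writing param.grad directly would bypass them.
class ToyParam:
    def __init__(self):
        self.grad = 0.0
        self._hooks = []  # callbacks to run after each accumulation

    def register_post_accumulate_hook(self, fn):
        self._hooks.append(fn)

    def backward(self, grad):
        # Accumulate, then notify hooks (e.g. DDP reducers, user callbacks).
        self.grad += grad
        for fn in self._hooks:
            fn(self)

calls = []
p = ToyParam()
p.register_post_accumulate_hook(lambda param: calls.append(param.grad))

p.grad += 1.0    # direct write: hook never fires
p.backward(2.0)  # routed through backward: hook fires with accumulated grad
print(p.grad, calls)  # 3.0 [3.0]
```

In real torch, the analogous hook points are those attached to gradient accumulation on leaf parameters, which an explicit AccumulateGrad op would not invoke.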
pengwa
commented
Feb 21, 2024
Contributor
Author
Sorry for the late response.
…pengwa/refactor_io
Contributor
construct_inputs and restore_outputs can probably be done by calling tree_flatten and tree_unflatten in https://github.com/pytorch/pytorch/blob/15bd81bfafa86fec9d675e7f071c867c852ebe8f/torch/utils/_pytree.py#L799.
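For context, the torch.utils._pytree helpers referenced here flatten an arbitrarily nested container into a list of leaves plus a spec, and rebuild the original structure from them. A minimal pure-Python sketch of that behavior (the real tree_flatten/tree_unflatten handle more container types and return a TreeSpec object):

```python
# Minimal sketch of pytree-style flatten/unflatten over lists, tuples, and
# dicts. tree_flatten returns (leaves, spec); tree_unflatten(leaves, spec)
# rebuilds the original nested structure from the flat leaf list.
def tree_flatten(obj):
    if isinstance(obj, (list, tuple)):
        leaves, specs = [], []
        for item in obj:
            sub_leaves, sub_spec = tree_flatten(item)
            leaves += sub_leaves
            specs.append(sub_spec)
        return leaves, (type(obj), specs)
    if isinstance(obj, dict):
        leaves, specs = [], {}
        for key, value in obj.items():
            sub_leaves, sub_spec = tree_flatten(value)
            leaves += sub_leaves
            specs[key] = sub_spec
        return leaves, (dict, specs)
    return [obj], None  # a leaf: no spec needed

def tree_unflatten(leaves, spec):
    it = iter(leaves)

    def build(sub_spec):
        if sub_spec is None:
            return next(it)  # consume the next leaf
        typ, children = sub_spec
        if typ is dict:
            return {k: build(s) for k, s in children.items()}
        return typ(build(s) for s in children)  # list or tuple

    return build(spec)

nested = {"x": [1, (2, 3)], "y": 4}
leaves, spec = tree_flatten(nested)
print(leaves)                        # [1, 2, 3, 4]
print(tree_unflatten(leaves, spec))  # {'x': [1, (2, 3)], 'y': 4}
```

Construct_inputs/restore_outputs need exactly this round-trip: flatten user inputs into the positional list ORT expects, then restore ORT's flat outputs into the user's original structure.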
Contributor
Author
Thanks @wschin.
rohan11235813
pushed a commit
to quadric-io/onnxruntime
that referenced
this pull request
Aug 19, 2025
rohan11235813
pushed a commit
to quadric-io/onnxruntime
that referenced
this pull request
Sep 15, 2025
Problem
Currently, the codebase contains logic for model re-export checks and graph builder re-initialization checks. Ideally, these operations should behave like a state machine, but in the current implementation the relevant states are checked or set in scattered locations. This fragmentation makes it hard to tell when a re-export or re-initialization will be triggered. For clarity and maintainability, these states should be consolidated into a cohesive component rather than dispersed within the current graph execution manager.
Furthermore, model export and the post-export processing for stage 3 support or memory-efficient gradient management introduce considerable complexity. To improve the codebase's structure, it would be beneficial to extract these functionalities into a dedicated component, separating them from the current graph execution manager.
As part of this effort, it is also essential to address inconsistencies in how input/output flatten/unflatten operations are handled. Currently, several functions perform these operations recursively, each with a slightly different implementation, so different parts of the code support different input/output data types and structures. This pull request simplifies these operations into a set of primitive functions, ensuring uniformity. This streamlines the code and makes it easier to stay consistent when fixing bugs or adding support for new data types. One thing to mention here: input/output handling is deeply bound to the graph transition described above, so it is difficult to make this change separately.
While this logic is complex, the codebase benefits from an extensive suite of unit tests covering all possible branches. Despite the intricacies, ensuring all tests pass has been a time-intensive but necessary part of this development effort.
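One way to read the "state machine" point above: gather every flag that can force a re-export or graph-builder re-initialization into one small component with a single decision point, instead of checking them in scattered places. A hypothetical sketch under that reading; the class and attribute names here are illustrative, not the actual ORTModule API:

```python
# Hypothetical sketch: consolidate re-export / re-initialization triggers
# into one object, so a single method decides when a transition happens.
from dataclasses import dataclass


@dataclass
class TransitionState:
    exported_schema: object = None   # input schema seen at last export
    device: object = None            # device seen at last export
    needs_reinitialize: bool = True  # graph builder not yet (re)built

    def need_reexport(self, schema, device) -> bool:
        # Re-export whenever anything the exported graph depends on changed.
        return self.exported_schema != schema or self.device != device

    def mark_exported(self, schema, device) -> None:
        self.exported_schema = schema
        self.device = device
        self.needs_reinitialize = True  # new model => rebuild graph builder


state = TransitionState()
assert state.need_reexport(("float32", (2, 4)), "cuda:0")       # first run
state.mark_exported(("float32", (2, 4)), "cuda:0")
assert not state.need_reexport(("float32", (2, 4)), "cuda:0")   # unchanged
assert state.need_reexport(("float32", (8, 4)), "cuda:0")       # shape changed
```

With every trigger funneled through need_reexport, the conditions under which a re-export fires can be read (and unit-tested) in one place.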
Design
- Introduce `GraphTransitionManager` and put all model export and post-export processing logic in it.
- Introduce `PostExportProcessedModelInfo`, which contains all the information we need to pass to ORT to build the gradient graph (currently we do the same for training and evaluating, but ideally we should not do it for evaluating; let's keep this behavior as it is now and make the change later).
- The `GraphTransitionManager` instance is a property of `GraphExecutionManager` (e.g. `TrainingManager` or `InferenceManager`).
- Use `self._graph_transition_manager._post_export_processed_model_info.construct_inputs` to construct the list of inputs used for ORT runs.
- Use `self._graph_transition_manager._post_export_processed_model_info.restore_outputs(user_outputs)` to restore the outputs in the original PyTorch output structure.

Motivation and Context