Improve checkpointing for Zero stage 1#5478
Conversation
Is this PR enabling any new scenario for checkpointing, such as fp32 -> fp16 and vice-versa? Or this is just a performance improvement for the existing scenarios |
|
@baijumeswani FYI |
The fp16-> fp32 support already exists. This PR removes dependency of zero aggregation on optimizer state being present in the state_dict. This allows only saving the model weights for zero_1. |
df4070c to
06ac0bb
Compare
|
The test graph bert_toy_postprocessed.onnx had to be run through onnxruntime/tools/python/remove_initializer_from_input.py to move the initializers bert.embeddings.position_embeddings.weight and bert.embeddings.word_embeddings.weight from inputs as trainable weights should not be expected to be overridden, and zero partitioning is conditioned upon the initializers not being in graph inputs. Additionally, the script orttraining/orttraining/test/python/orttrainer_bert_toy_onnx_ckpt_gen.py has been added to generate the zero checkpoints required for test testToyBertCheckpointLoadZero() |
06ac0bb to
a1ca6e9
Compare
change has been addressed, dismissing to unblock
@ashbhandare Are you familiar with the PyTorch flexible API specs? It is a new frontend for ORT which requires all graph initializers to be passed as graph inputs. The initializers come from the original pytorch model and passed into ORT, so the ORT backend is actually stateless in this sense, as it will only compute stuff on top of inputs. @mrry Do you think this behavior is compatible with |
7506db9 to
2340098
Compare
If the initializers are inputs in the flexible API, will the optimizer be handled by ORT? If not, zero partitioning for stage 1 should not happen within ORT and this change will not be touched. If yes, we could enable the older way of adding a 'View' for the flexible API alone. |
ae1fb60 to
59a2698
Compare
|
/azp run orttraining-linux-gpu-ci-pipeline |
59a2698 to
5bc5911
Compare
fa20163 to
eac63e6
Compare
This PR does the folllowing changes:
The correctness has been verified with the test added through #5476:
CUDA_VISIBLE_DEVICES=0,1,2,3 mpirun -n 4 --tag-output python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_zero