Skip to content

Improve checkpointing for Zero stage 1#5478

Merged
ashbhandare merged 18 commits into
masterfrom
aibhanda/zero_1_ckpt
Dec 7, 2020
Merged

Improve checkpointing for Zero stage 1#5478
ashbhandare merged 18 commits into
masterfrom
aibhanda/zero_1_ckpt

Conversation

@ashbhandare
Copy link
Copy Markdown
Contributor

This PR does the folllowing changes:

  1. Completely shard the FP32 weight in case of an fp16 run
  2. Simplify the aggregation logic for Zero checkpoints

The correctness has been verified with the test added through #5476:
CUDA_VISIBLE_DEVICES=0,1,2,3 mpirun -n 4 --tag-output python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_zero

Comment thread orttraining/orttraining/python/training/checkpoint.py Outdated
@thiagocrepaldi
Copy link
Copy Markdown
Contributor

This PR does the folllowing changes:

1. Completely shard the FP32 weight in case of an fp16 run

2. Simplify the aggregation logic for Zero checkpoints

The correctness has been verified with the test added through #5476:
CUDA_VISIBLE_DEVICES=0,1,2,3 mpirun -n 4 --tag-output python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_zero

Is this PR enabling any new scenario for checkpointing, such as fp32 -> fp16 and vice-versa? Or this is just a performance improvement for the existing scenarios

@thiagocrepaldi
Copy link
Copy Markdown
Contributor

@baijumeswani FYI

Comment thread orttraining/orttraining/core/session/training_session.cc Outdated
Comment thread orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc Outdated
Comment thread orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc
jessebenson
jessebenson previously approved these changes Oct 15, 2020
@ashbhandare
Copy link
Copy Markdown
Contributor Author

This PR does the folllowing changes:

1. Completely shard the FP32 weight in case of an fp16 run

2. Simplify the aggregation logic for Zero checkpoints

The correctness has been verified with the test added through #5476:
CUDA_VISIBLE_DEVICES=0,1,2,3 mpirun -n 4 --tag-output python orttraining_run_bert_pretrain.py ORTBertPretrainTest.test_pretrain_zero

Is this PR enabling any new scenario for checkpointing, such as fp32 -> fp16 and vice-versa? Or this is just a performance improvement for the existing scenarios

The fp16-> fp32 support already exists. This PR removes dependency of zero aggregation on optimizer state being present in the state_dict. This allows only saving the model weights for zero_1.

@ashbhandare ashbhandare force-pushed the aibhanda/zero_1_ckpt branch 2 times, most recently from df4070c to 06ac0bb Compare October 29, 2020 16:07
@ashbhandare
Copy link
Copy Markdown
Contributor Author

The test graph bert_toy_postprocessed.onnx had to be run through onnxruntime/tools/python/remove_initializer_from_input.py to move the initializers bert.embeddings.position_embeddings.weight and bert.embeddings.word_embeddings.weight from inputs as trainable weights should not be expected to be overridden, and zero partitioning is conditioned upon the initializers not being in graph inputs.

Additionally, the script orttraining/orttraining/test/python/orttrainer_bert_toy_onnx_ckpt_gen.py has been added to generate the zero checkpoints required for test testToyBertCheckpointLoadZero()

Comment thread orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc Outdated
Comment thread orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc Outdated
Comment thread orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc Outdated
Comment thread orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc Outdated
Comment thread orttraining/orttraining/core/graph/zero_optimizer_graph_builder.cc Outdated
Comment thread orttraining/orttraining/python/checkpointing_utils.py
Comment thread orttraining/orttraining/python/checkpointing_utils.py Outdated
Comment thread orttraining/orttraining/python/checkpointing_utils.py
Comment thread orttraining/orttraining/python/training/checkpoint.py
Comment thread orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py Outdated
baijumeswani
baijumeswani previously approved these changes Oct 30, 2020
Copy link
Copy Markdown
Contributor

@baijumeswani baijumeswani left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good.

@ashbhandare ashbhandare dismissed thiagocrepaldi’s stale review October 30, 2020 16:33

change has been addressed, dismissing to unblock

@thiagocrepaldi
Copy link
Copy Markdown
Contributor

thiagocrepaldi commented Nov 2, 2020

d zero partitioning is conditioned upo

@ashbhandare Are you familiar with the PyTorch flexible API specs? It is a new frontend for ORT which requires all graph initializers to be passed as graph inputs. The initializers come from the original pytorch model and passed into ORT, so the ORT backend is actually stateless in this sense, as it will only compute stuff on top of inputs.

@mrry Do you think this behavior is compatible with ORTModule design? Do we intend to support ZeRO on the flexible API?

Comment thread orttraining/orttraining/python/checkpointing_utils.py
Comment thread orttraining/orttraining/python/training/checkpoint.py
Comment thread orttraining/orttraining/test/python/orttrainer_bert_toy_onnx_ckpt_gen.py Outdated
Comment thread orttraining/orttraining/test/python/orttrainer_bert_toy_onnx_ckpt_gen.py Outdated
Comment thread orttraining/orttraining/test/python/orttrainer_bert_toy_onnx_ckpt_gen.py Outdated
Comment thread orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py Outdated
Comment thread orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py Outdated
Comment thread orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py Outdated
Comment thread orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py Outdated
Comment thread orttraining/orttraining/test/python/orttraining_test_orttrainer_bert_toy_onnx.py Outdated
@ashbhandare
Copy link
Copy Markdown
Contributor Author

d zero partitioning is conditioned upo

@ashbhandare Are you familiar with the PyTorch flexible API specs? It is a new frontend for ORT which requires all graph initializers to be passed as graph inputs. The initializers come from the original pytorch model and passed into ORT, so the ORT backend is actually stateless in this sense, as it will only compute stuff on top of inputs.

@mrry Do you think this behavior is compatible with ORTModule design? Do we intend to support ZeRO on the flexible API?

If the initializers are inputs in the flexible API, will the optimizer be handled by ORT? If not, zero partitioning for stage 1 should not happen within ORT and this change will not be touched. If yes, we could enable the older way of adding a 'View' for the flexible API alone.

@ashbhandare ashbhandare force-pushed the aibhanda/zero_1_ckpt branch 3 times, most recently from ae1fb60 to 59a2698 Compare December 2, 2020 17:53
@thiagocrepaldi
Copy link
Copy Markdown
Contributor

/azp run orttraining-linux-gpu-ci-pipeline

thiagocrepaldi
thiagocrepaldi previously approved these changes Dec 3, 2020
Comment thread onnxruntime/test/python/onnxruntime_test_ort_trainer.py Outdated
@ashbhandare ashbhandare merged commit 7cebf76 into master Dec 7, 2020
@ashbhandare ashbhandare deleted the aibhanda/zero_1_ckpt branch December 7, 2020 17:16
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants