[Not for merge] fp8allgather debug #1147
With @jspark1105's commits enabling FP8 allgather, we can run `test_te.py` and also local training without PP. However, enabling PP exposes some issues with FP8 allgather that need to be fixed. This diff copies the changes from @jspark1105's PR and adds the fixes we need in fairscale. The fixes we need are the following (most are naive and need better implementations):
Fix 1: When running `model_chunk._rebuild_full_params_recursive()` in `xlformers/src/model_parallel_core/pipeline_parallel/fwd_bwd_schedules.py`, we need to pass the FP8-training-related settings into the context. All changes in xlformers are included in this commit.
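A minimal sketch of the idea, assuming the settings are threaded through TE's `fp8_autocast` context so the rebuild/allgather path sees the same FP8 state as the forward pass. The helper name and the `fp8_enabled`/`fp8_recipe` arguments are illustrative, not the actual xlformers code:

```python
import transformer_engine.pytorch as te

def rebuild_full_params_with_fp8(model_chunk, fp8_enabled, fp8_recipe):
    """Rebuild full params under the same FP8 context as the forward pass."""
    if fp8_enabled:
        # Make the FP8 recipe visible to the rebuild/allgather path, so
        # the wrapped FSDP modules can gather weights in FP8 instead of
        # the higher-precision dtype.
        with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
            model_chunk._rebuild_full_params_recursive()
    else:
        model_chunk._rebuild_full_params_recursive()
```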
Fix 2: In TransformerEngine, we don't need to return the weight gradients for FP8 training, since the gradients are accumulated into `.main_grad`. All changes in TE are included in this commit.
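A minimal sketch of what "not returning the weight gradient" looks like in a custom autograd function. This is illustrative, not TE's actual fused GEMM path; the `fuse_wgrad_accumulation` flag here just mirrors TE's existing option of the same name:

```python
import torch

class FusedWgradLinear(torch.autograd.Function):
    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        # Accumulate wgrad into weight.main_grad when it exists.
        ctx.fuse_wgrad_accumulation = hasattr(weight, "main_grad")
        return inp @ weight.t()

    @staticmethod
    def backward(ctx, grad_out):
        inp, weight = ctx.saved_tensors
        grad_input = grad_out @ weight
        if ctx.fuse_wgrad_accumulation:
            # Gradients accumulate across microbatches in main_grad
            # (kept in higher precision), so autograd must not also
            # populate weight.grad: return None for the weight slot.
            weight.main_grad.add_(grad_out.t() @ inp)
            return grad_input, None
        return grad_input, grad_out.t() @ inp
```

Returning `None` keeps autograd from also writing a separate `weight.grad`, which would double-count the gradient already accumulated into `main_grad`.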
Fix 3: `FlattenParamsWrapper` creates views of the parameters on every forward pass. This is unnecessary when we are not resharding after forward, and it breaks FP8 allgather + PP: we create `.main_grad` at the beginning of the forward, but only the latest views of the parameters remain accessible, so the earlier views are no longer reachable.
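A naive sketch of the workaround: skip rebuilding the views when the flat parameter's storage has not changed, so the tensors that `.main_grad` was attached to stay reachable across the PP microbatch forwards. `_unflatten_params_as_views` is fairscale's existing view-creation path; the caching attribute is made up for illustration:

```python
def maybe_unflatten_params_as_views(wrapper):
    """Only rebuild parameter views if the flat param storage changed."""
    flat_param = wrapper.flat_param
    if getattr(wrapper, "_views_data_ptr", None) == flat_param.data_ptr():
        # Storage unchanged (no reshard happened): reuse the existing
        # views, keeping the earlier .main_grad tensors accessible.
        return
    wrapper._unflatten_params_as_views()
    wrapper._views_data_ptr = flat_param.data_ptr()
```

A cleaner fix would probably invalidate this cache explicitly on reshard, but the above matches the "naive" spirit of this diff.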
Fix 4: We should not call `_free_fp16_param_shard` in `_post_backward_hook`. The FP16 shard needs to be kept, since each backward pass needs to use it.
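A minimal sketch of the gating inside fairscale's post-backward hook; the `keep_fp16_shard_for_pp` flag is an assumption for illustration (the actual diff may simply delete the call):

```python
def _post_backward_hook(self, param, *unused):
    ...  # existing gradient reduce-scatter / accumulation logic
    # Under PP, backward runs once per microbatch, and every backward
    # needs the FP16 shard to rebuild the full params; freeing it after
    # the first backward breaks the remaining microbatches.
    if not getattr(self, "keep_fp16_shard_for_pp", False):  # assumed flag
        self._free_fp16_param_shard([param])
```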