Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Remove activation checkpointing tag to get correct FQNs (pytorch#124698)
Fixes pytorch#124546 When setting `use_orig_params = False` and using activation checkpointing, the FQN mapping as retrieved by the `_get_fqns` function is incorrect because the prefix that is added to the name of each activation checkpointed module, `_checkpoint_wrapped_module`, can still be present. I think this is an edge case with the `_get_fqns` function that was not addressed by this previous commit pytorch#118119. Without the change, the list of object names for an activation checkpointed module with FSDP (and `use_orig_params=False`) can be something like: ``` ['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_checkpoint_wrapped_module', '_flat_param'] ``` Which will incorrectly return just one FQN, `{'model.transformer.blocks.0._flat_param'}`, when all the FQNs of the parameters of the transformer block should be returned. With the change, the list of object names will now have `_checkpoint_wrapped_module` removed: ``` ['model', '_fsdp_wrapped_module', 'transformer', 'blocks', '0', '_fsdp_wrapped_module', '_flat_param'] ``` And the FQNs are correctly retrieved and returned in `_get_fqns` when [this condition](https://github.com/pytorch/pytorch/blob/ea61c9cb299b6dfebc57dc9d8821c34321d568ab/torch/distributed/checkpoint/state_dict.py#L168) is satisfied. The correct FQNs are: ``` {'model.transformer.blocks.0.attn.Wqkv.bias', 'model.transformer.blocks.0.ffn.up_proj.bias', 'model.transformer.blocks.0.attn.out_proj.weight', 'model.transformer.blocks.0.norm_2.weight', 'model.transformer.blocks.0.ffn.down_proj.weight', 'model.transformer.blocks.0.attn.Wqkv.weight', 'model.transformer.blocks.0.norm_2.bias', 'model.transformer.blocks.0.ffn.up_proj.weight', 'model.transformer.blocks.0.ffn.down_proj.bias', 'model.transformer.blocks.0.norm_1.bias', 'model.transformer.blocks.0.norm_1.weight', 'model.transformer.blocks.0.attn.out_proj.bias'} ``` Pull Request resolved: pytorch#124698 Approved by: https://github.com/Skylion007
- Loading branch information