diff --git a/torch/distributed/fsdp/fully_sharded_data_parallel.py b/torch/distributed/fsdp/fully_sharded_data_parallel.py index 5f924e09aa61..faacac28c573 100644 --- a/torch/distributed/fsdp/fully_sharded_data_parallel.py +++ b/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -634,6 +634,8 @@ def set_state_dict_type( state_dict_type (StateDictType): the desired ``state_dict_type`` to set. state_dict_config (Optional[StateDictConfig]): the configuration for the target ``state_dict_type``. + optim_state_dict_config (Optional[OptimStateDictConfig]): the configuration for the optimizer state dict + Returns: A StateDictSettings that include the previous state_dict type and configuration for the module.