[DCP] Update DCP to use the updated FSDP optim state_dict APIs (#95303)
Pull Request resolved: pytorch/pytorch#95303
Approved by: https://github.com/fegin
wz337 authored and cyyever committed Feb 25, 2023
1 parent 88fb1cc commit dd8df34
Showing 3 changed files with 6 additions and 6 deletions.
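
In short, the commit swaps the sharded-only optimizer state helpers for the unified FSDP APIs: FSDP.sharded_optim_state_dict becomes FSDP.optim_state_dict, and FSDP.flatten_sharded_optim_state_dict becomes FSDP.optim_state_dict_to_load. Below is a minimal save-side sketch of the new pattern, not part of the commit itself; it assumes torch.distributed is already initialized, model is an FSDP-wrapped module with optimizer optim, and CHECKPOINT_DIR is a placeholder path.

import torch.distributed.checkpoint as dist_cp
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import StateDictType

CHECKPOINT_DIR = "checkpoint"  # placeholder path, not from this commit

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    state_dict = {
        "model": model.state_dict(),
        # New API; this commit replaces FSDP.sharded_optim_state_dict(model, optim).
        "optim": FSDP.optim_state_dict(model, optim),
    }
    dist_cp.save_state_dict(
        state_dict=state_dict,
        storage_writer=dist_cp.FileSystemWriter(CHECKPOINT_DIR),
    )
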
4 changes: 2 additions & 2 deletions test/distributed/checkpoint/test_2d_fsdp_dt_checkpoint.py
@@ -139,7 +139,7 @@ def _test_fsdp_dt_checkpoint(self, fsdp_pg=None) -> None:
         with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
             state_dict = {
                 "model": model.state_dict(),
-                "optim": FSDP.sharded_optim_state_dict(model, optim),
+                "optim": FSDP.optim_state_dict(model, optim),
             }

             dist_cp.save_state_dict(
@@ -181,7 +181,7 @@ def _test_fsdp_dt_checkpoint(self, fsdp_pg=None) -> None:
                 optimizer_key="optim",
                 storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
             )
-            flattened_osd = FSDP.flatten_sharded_optim_state_dict(
+            flattened_osd = FSDP.optim_state_dict_to_load(
                 optim_state["optim"], model_2, optim_2
             )
             optim_2.load_state_dict(flattened_osd)
4 changes: 2 additions & 2 deletions test/distributed/checkpoint/test_fsdp_optim_state.py
@@ -40,7 +40,7 @@ def test_distributed_tensor_planner(self) -> None:
         with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
             state_dict = {
                 "model": model.state_dict(),
-                "optim": FSDP.sharded_optim_state_dict(model, optim),
+                "optim": FSDP.optim_state_dict(model, optim),
             }

             dist_cp.save_state_dict(
@@ -80,7 +80,7 @@ def test_distributed_tensor_planner(self) -> None:
             storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
         )

-        flattened_osd = FSDP.flatten_sharded_optim_state_dict(
+        flattened_osd = FSDP.optim_state_dict_to_load(
            optim_state["optim"], model_2, optim_2
        )
        optim_2.load_state_dict(flattened_osd)
4 changes: 2 additions & 2 deletions torch/distributed/checkpoint/optimizer.py
@@ -213,7 +213,7 @@ def load_sharded_optimizer_state_dict(
         >>> # Save
         >>> with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
         >>>     state_dict = {
-        >>>         "optimizer": FSDP.sharded_optim_state_dict(model, optim, optim_params),
+        >>>         "optimizer": FSDP.optim_state_dict(model, optim),
         >>>         "model": model.state_dict()
         >>>     }
         >>> dist_cp.save_state_dict(
@@ -241,7 +241,7 @@ def load_sharded_optimizer_state_dict(
         >>>     storage_reader=dist_cp.FileSystemReader("checkpoint"),
         >>> )
         >>>
-        >>> flattened_osd = FSDP.flatten_sharded_optim_state_dict(
+        >>> flattened_osd = FSDP.optim_state_dict_to_load(
         >>>     optim_state["optimizer"], model, optim
         >>> )
         >>>
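
Continuing the sketch above under the same assumptions, the load side restores the model weights first, reads the sharded optimizer state, and then converts it into the flattened form that the optimizer's load_state_dict expects; optim_state_dict_to_load takes over this conversion from the old flatten_sharded_optim_state_dict. model_2 and optim_2 stand for a freshly constructed FSDP module and optimizer being restored.

from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict

with FSDP.state_dict_type(model_2, StateDictType.SHARDED_STATE_DICT):
    # Load and apply the model weights first.
    model_state = {"model": model_2.state_dict()}
    dist_cp.load_state_dict(
        state_dict=model_state,
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    )
    model_2.load_state_dict(model_state["model"])

    # Read the sharded optimizer state from the same checkpoint.
    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=model_state["model"],
        optimizer_key="optim",
        storage_reader=dist_cp.FileSystemReader(CHECKPOINT_DIR),
    )
    # New API; this commit replaces FSDP.flatten_sharded_optim_state_dict.
    flattened_osd = FSDP.optim_state_dict_to_load(
        optim_state["optim"], model_2, optim_2
    )
    optim_2.load_state_dict(flattened_osd)
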
