[feat] Gracefully handle local/global state dict queries (#89)
Return either the local or the global optimizer state when queried, depending on whether a prior consolidation took place
blefaudeux committed Sep 15, 2020
1 parent 3d7f524 commit d16e9f6
Showing 2 changed files with 34 additions and 10 deletions.
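For orientation before the diff, here is a minimal sketch (not part of the commit) of the behaviour it introduces: without a prior consolidation, `state_dict()` now returns the local shard and tags it, instead of asserting. The single-process "gloo" group and the `fairscale.optim.OSS` import path are assumptions borrowed from the test setup; on a real multi-replica run, `consolidate_state_dict()` would be called on every rank first.

```python
# A minimal sketch of the new fallback, assuming a single-process "gloo" group
# so OSS can resolve its rank (mirroring the test file further down).
import torch
import torch.distributed as dist
from fairscale.optim import OSS

if not dist.is_initialized():
    dist.init_process_group(
        backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
    )

x = torch.tensor([1.0], requires_grad=True)
opt = OSS([x], lr=0.1)

# No consolidate_state_dict() yet: instead of raising an assert, state_dict()
# now returns the local shard, flagged so load_state_dict() can tell it apart.
sd = opt.state_dict()
assert sd["local_state_dict"] is True

# After consolidate_state_dict() has run on every replica, the same call on the
# recipient rank returns the global state with "local_state_dict" set to False.
```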
30 changes: 20 additions & 10 deletions fairscale/optim/oss.py
@@ -144,13 +144,18 @@ def state_dict(self) -> Dict[str, Any]:
"""
Return the last known global optimizer state, which consists of a list of the shards.
NOTE: This is limited to the replica which was responsible for the consolidation.
NOTE:
- If the state has not been consolidated, this returns a shard's worth, not the global state.
- Returning the global state is limited to the replica which was responsible for the consolidation.
The state may also not be up to date, depending on when `consolidate_state_dict` was last called.
"""

assert (
len(self._all_states) > 0
), "The optimizer state is not materialized, please call consolidate_state_dict on every replica beforehand"
if len(self._all_states) == 0:
logging.warning("Optimizer state has not been consolidated. Returning the local state")
logging.warning("Please call `consolidate_state_dict()` beforehand if you meant to save the global state")
state_dict = self.local_state_dict()
state_dict["local_state_dict"] = True
return state_dict

# Flatten the param_groups, save the partition which logs the rank <> shard correspondence
partition: List[Tuple[int, int]] = []
@@ -167,6 +172,7 @@ def state_dict(self) -> Dict[str, Any]:
"state": [s["state"] for s in self._all_states],
"param_groups": param_groups,
"partition": partition,
"local_state_dict": False,
}

def load_local_state_dict(self, state_dict: dict) -> None:
@@ -196,12 +202,16 @@ def load_local_state_dict(self, state_dict: dict) -> None:
def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
""" Restore the global parameter groups as well as the shard """

# Get this optimizer's param_groups shard
param_groups = state_dict["param_groups"][
state_dict["partition"][self.rank][0] : state_dict["partition"][self.rank][1]
]
# Dispatch this rank's state dictionary to the wrapped shard optimizer
self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups})
# Check whether we got a local or global dict
if state_dict["local_state_dict"]:
self.load_local_state_dict(state_dict)
else:
# Get this optimizer's param_groups shard
param_groups = state_dict["param_groups"][
state_dict["partition"][self.rank][0] : state_dict["partition"][self.rank][1]
]
# Dispatch this rank's state dictionary to the wrapped shard optimizer
self.load_local_state_dict({"state": state_dict["state"][self.rank], "param_groups": param_groups})

def add_param_group(self, param_group: dict) -> None:
super().add_param_group(param_group)
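As a toy illustration of the global dictionary handled above (values invented, plain Python rather than fairscale code): the saved `partition` records one `(start, stop)` pair per rank, and on load each rank slices its own `param_groups` back out, exactly as the new `else` branch does.

```python
# Toy global state dict, shaped like the one built in state_dict() above.
global_sd = {
    "local_state_dict": False,
    "partition": [(0, 2), (2, 3)],  # rank 0 owns param groups 0-1, rank 1 owns group 2
    "param_groups": [{"lr": 0.1}, {"lr": 0.1}, {"lr": 0.01}],
    "state": [{"rank0": "shard"}, {"rank1": "shard"}],
}

rank = 1
start, stop = global_sd["partition"][rank]
my_param_groups = global_sd["param_groups"][start:stop]  # [{"lr": 0.01}]
my_state = global_sd["state"][rank]                       # {"rank1": "shard"}
# load_local_state_dict({"state": my_state, "param_groups": my_param_groups})
# would then restore this rank's shard, as in the new else branch.
```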
14 changes: 14 additions & 0 deletions tests/optim/test_oss.py
@@ -144,6 +144,20 @@ def test_local_state_dict():
assert x == torch.tensor([0.9], device=DEVICE)


def test_implicit_local_state_dict():
x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], lr=0.1)
local_state_dict = o.state_dict()
o = optim.OSS([x], lr=0.01)
o.load_state_dict(local_state_dict)
# We should now be using a lr of 0.1.
assert o.optim.param_groups[0]["lr"] == 0.1
assert o.param_groups[0]["lr"] == 0.1
x.backward()
o.step()
assert x == torch.tensor([0.9], device=DEVICE)


def run_test_add_param_group(rank, world_size):
dist_init(rank, world_size)
params = []
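Finally, a hedged sketch of the checkpointing pattern the new warnings point users toward: consolidate on every replica, then save the global dict from one rank. Treating rank 0 as the consolidation recipient and the file path are assumptions, not something this commit prescribes.

```python
import torch
import torch.distributed as dist

def save_global_checkpoint(oss_optimizer, path="oss_checkpoint.pt"):
    # Must run on every replica, per the docstring above; rank 0 as the
    # recipient of the consolidated state is an assumption here.
    oss_optimizer.consolidate_state_dict()
    if dist.get_rank() == 0:
        # Global dict: "local_state_dict" is False and "partition" is included.
        torch.save(oss_optimizer.state_dict(), path)
```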
