[bugfix] OSS + Apex (#136)
* fixing the issue with respect to Apex, validated with Latte; Classy would need another pass
blefaudeux committed Oct 14, 2020
1 parent 6d802f5 commit 37c686e
Showing 2 changed files with 29 additions and 5 deletions.
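
For context, the core of the fix is a two-direction sync between the param groups that OSS exposes and the ones held by the wrapped (sharded) optimizer: values changed from the outside, e.g. by a scheduler, flow inward, while keys added by the wrapped optimizer, as Apex-style optimizers do, flow back out. Below is a minimal standalone sketch of that idea using plain dicts and illustrative names (sync_param_groups, exposed, wrapped are not fairscale identifiers):

# Minimal sketch of a two-way param-group sync; the "params" entry is sharded
# per rank and must never be copied between the two views.
from typing import Any, Dict, List

def sync_param_groups(
    global_groups: List[Dict[str, Any]],
    local_groups: List[Dict[str, Any]],
    local_to_global: bool = False,
) -> None:
    for global_group, local_group in zip(global_groups, local_groups):
        # Sync everything but the parameters
        for k in filter(lambda x: x != "params", local_group.keys()):
            if local_to_global:
                # Surface keys that the wrapped optimizer added during its step()
                global_group[k] = local_group[k]
            elif k in global_group.keys():
                # Push externally altered values (e.g. a scheduler's new lr) down
                local_group[k] = global_group[k]

# The inner optimizer added "new_key"; local_to_global surfaces it on the outside.
exposed = [{"params": [], "lr": 0.1}]
wrapped = [{"params": [], "lr": 0.1, "new_key": 0.5}]
sync_param_groups(exposed, wrapped, local_to_global=True)
assert exposed[0]["new_key"] == 0.5
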
19 changes: 14 additions & 5 deletions fairscale/optim/oss.py
@@ -196,6 +196,9 @@ def step(self, closure: Optional[Callable[[], float]] = None, **kwargs: Any) ->
) in self.per_device_params.items(): # all the params on this device (inc all ranks)
self._broadcast_params(self._broadcast_buffers[device], device_params)

# Sync hypothetical new results from the wrapped optimizer to the exposed param_groups
self._sync_param_groups(local_to_global=True)

return loss

def local_state_dict(self) -> dict:
@@ -334,12 +337,18 @@ def add_param_group(self, param_group: dict) -> None:
if len(param_groups) == len(self.optim.param_groups) + 1:
self.optim.add_param_group(param_groups[-1])

def _sync_param_groups(self) -> None:
"""Sync learning rate and other optimizer attributes (needed to support schedulers)."""
def _sync_param_groups(self, local_to_global: bool = False) -> None:
"""Sync learning rate and other optimizer attributes (needed to support schedulers).
If the global param groups have been altered, and we want to make sure that the
wrapped optimizer uses the up to date version.
Conversely if the wrapped optimizer has new keys, we expose them through the global param groups"""

for global_group, local_group in zip(self.param_groups, self.optim.param_groups):
for k in local_group.keys():
if k != "params":
# Params have been sharded and should not be synced here
# Sync everything but the parameters
for k in filter(lambda x: x != "params", local_group.keys()):
if local_to_global:
global_group[k] = local_group[k]
elif k in global_group.keys():
local_group[k] = global_group[k]

def _collect_sharded_states(self) -> List[Dict[str, Any]]:
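
The pre-existing direction of the sync is the one called out in the docstring: a scheduler mutates the exposed param_groups, and the wrapped optimizer picks up the new value when step() is called. A single-process sketch of that path (the gloo backend, env:// rendezvous, port number and import path are assumptions of this example, not part of the patch):

import os
import torch
import torch.distributed as dist
from fairscale.optim import OSS

# Single-process process group so that OSS can broadcast to itself (CPU + gloo).
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29501"  # illustrative free port
dist.init_process_group(backend="gloo", rank=0, world_size=1)

x = torch.tensor([1.0], requires_grad=True)
o = OSS([x], torch.optim.SGD, lr=0.1)
scheduler = torch.optim.lr_scheduler.StepLR(o, step_size=1, gamma=0.1)

for _ in range(2):
    o.zero_grad()
    (x ** 2).sum().backward()
    o.step()          # the exposed lr is synced down to the wrapped SGD
    scheduler.step()  # mutates the exposed param_groups only

print(o.param_groups[0]["lr"])  # 0.001 after two decays

dist.destroy_process_group()
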
15 changes: 15 additions & 0 deletions tests/optim/test_oss.py
@@ -123,6 +123,21 @@ def step(self, closure=None, kwarg=[]):
assert x == torch.tensor([0.9], device=DEVICE)


def test_step_with_extra_inner_key():
class SGDWithNewKey(torch.optim.SGD):
# Dummy optimizer which adds a new key to the param groups
def step(self, closure=None):
super().step()
self.param_groups[0]["new_key"] = 0.1

x = torch.tensor([1.0], device=DEVICE, requires_grad=True)
o = optim.OSS([x], SGDWithNewKey, lr=0.1)
x.backward()
o.step()
assert o.param_groups[0]["new_key"] == 0.1
assert x == torch.tensor([0.9], device=DEVICE)


def test_step_without_closure():
class SGDWithoutClosure(torch.optim.SGD):
def step(self):
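
To run a check like the new test outside the test module, a process group has to exist first; a single-process sketch (backend, rendezvous and port are assumptions of this example, while the OSS call mirrors the test above):

import os
import torch
import torch.distributed as dist
from fairscale.optim import OSS

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29502"  # illustrative free port
dist.init_process_group(backend="gloo", rank=0, world_size=1)

class SGDWithNewKey(torch.optim.SGD):
    # Mimics an Apex-style optimizer that adds a key to its param groups in step()
    def step(self, closure=None):
        super().step()
        self.param_groups[0]["new_key"] = 0.1

x = torch.tensor([1.0], requires_grad=True)
o = OSS([x], SGDWithNewKey, lr=0.1)
x.backward()
o.step()

# Before this commit the key only existed on the wrapped optimizer;
# the local_to_global sync now exposes it on the OSS wrapper as well.
print(o.param_groups[0]["new_key"])  # 0.1

dist.destroy_process_group()
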
