[fix] a bug in loading flatten state dict (#1025)
- Thanks to Alexei for spotting this.
- Added a triggering test case.
- Enhanced the tests with more comments and docstrings.

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Jul 12, 2022
1 parent 5b5db28 commit 03f3e83
Showing 2 changed files with 105 additions and 31 deletions.
6 changes: 4 additions & 2 deletions fairscale/nn/misc/flatten_params_wrapper.py
@@ -494,8 +494,10 @@ def load_state_dict(
Load a state dict. If necessary, ``unflatten_params`` will be called to
match the input state_dict.
"""
# unflatten the module automatically if the state_dict is non-flat
if self.is_flattened and "flat_param_0" not in state_dict:
# Unflatten the module automatically if the state_dict is non-flat.
# Note, we check the flat_param_ prefix since custom names can be given and flat_param_0 is
# not always in the state dict's key list.
if self.is_flattened and not any(k.startswith("flat_param_") for k in state_dict.keys()):
# This object is flattened but the state_dict is not. So we unflatten and load.
with self.unflatten_params():
return super().load_state_dict(state_dict, strict)
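As an aside, here is a minimal sketch (not part of this commit) of why the old check can misfire when custom flat-param names are used. The key names below are hypothetical; the only assumption, taken from the new comment above, is that flat-param keys share the flat_param_ prefix while flat_param_0 need not be present:

# Hypothetical key list from a wrapper created with custom flat_param_names.
state_dict_keys = ["flat_param_layer1", "flat_param_layer2", "some_buffer"]

old_check = "flat_param_0" in state_dict_keys  # False: a flat dict would wrongly be treated as non-flat
new_check = any(k.startswith("flat_param_") for k in state_dict_keys)  # True: loaded as a flat dict
print(old_check, new_check)  # False True
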
130 changes: 101 additions & 29 deletions tests/nn/misc/test_flatten_params_wrapper.py
@@ -21,6 +21,8 @@ def _get_module_init_fns(self):
return [
self._get_basic_linear_module,
self._get_shared_params_transformer,
self._get_2_flatten_group_linear_module,
self._get_2_flatten_group_linear_module_with_names,
]

def _get_empty_module(self, seed=0):
@@ -37,6 +39,8 @@ def get_input(device, dtype):
return torch.rand(1).to(device=device, dtype=dtype)

module.get_input = get_input
module.param_list = None # No param_list to FPW.
module.flat_param_names = None # No flat_param_names to FPW.
return module

def _get_transformer(self, seed=0):
@@ -57,6 +61,8 @@ def get_input(device, dtype):
return (src, tgt)

module.get_input = get_input
module.param_list = None # No param_list to FPW.
module.flat_param_names = None # No flat_param_names to FPW.
return module

def _get_shared_params_transformer(self, seed=0):
@@ -79,22 +85,57 @@ def get_input(device, dtype):
return (torch.rand(8, 4).to(device=device, dtype=dtype),)

module.get_input = get_input
module.param_list = None # No param_list to FPW.
module.flat_param_names = None # No flat_param_names to FPW.
return module

def _get_output(self, module):
def _get_2_flatten_group_linear_module(self, seed=0):
module = torch.nn.Sequential(
torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
torch.nn.Linear(16, 4),
)

def get_input(device, dtype):
torch.manual_seed(1) # keep everything deterministic
return (torch.rand(8, 4).to(device=device, dtype=dtype),)

module.get_input = get_input
assert len(module) == 2, "next line assumes a len==2 sequential module"
module.param_list = [list(module[0].parameters()), list(module[1].parameters())]
module.flat_param_names = None # No flat_param_names to FPW.
return module

def _get_2_flatten_group_linear_module_with_names(self, seed=0):
module = torch.nn.Sequential(
torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
torch.nn.Linear(16, 4),
)

def get_input(device, dtype):
torch.manual_seed(1) # keep everything deterministic
return (torch.rand(8, 4).to(device=device, dtype=dtype),)

module.get_input = get_input
assert len(module) == 2, "next line assumes a len==2 sequential module"
module.param_list = [list(module[0].parameters()), list(module[1].parameters())]
module.flat_param_names = ["layer1", "layer2"]
return module

def _compute_output(self, module):
device = next(module.parameters()).device
dtype = next(module.parameters()).dtype
input = module.get_input(device, dtype)
return module(*input)

def _get_pnorm_after_step(self, module):
optim = torch.optim.SGD(module.parameters(), lr=0.01)
loss = self._get_output(module).sum()
loss = self._compute_output(module).sum()
loss.backward()
optim.step()
return torch.norm(torch.stack([p.detach().norm() for p in module.parameters()]))

def _test_num_params(self, module):
"""Make sure the total numel of the params is the same after flattening."""
ref_num_params = sum(p.numel() for p in module.parameters())

flat_module = FlattenParamsWrapper(module)
@@ -104,13 +145,14 @@ def _test_num_params(self, module):
assert flat_num_params == flat_module.flat_param.numel()

def _test_output(self, module):
ref_output = self._get_output(module)
ref_output = self._compute_output(module)

flat_module = FlattenParamsWrapper(module)
flat_output = self._get_output(flat_module)
flat_output = self._compute_output(flat_module)
assert objects_are_equal(ref_output, flat_output)

def test_partial_flattening(self):
"""Test that some parameters are flattened while others are left non-flattened."""
module = self._get_transformer()
num_params = sum(p.numel() for p in module.parameters())

@@ -139,6 +181,7 @@ def test_partial_flattening(self):
assert all(p.dtype == new_dtype for p in module.encoder.layers[0].parameters())

def test_two_flattening_group(self):
"""Test two flatten groups."""
module = self._get_transformer()
num_params = sum(p.numel() for p in module.parameters())

@@ -153,8 +196,9 @@ def test_flatten_nothing(self):
assert sum(p.numel() for p in module.parameters()) == num_params

def test_flatten_nothing(self):
"""Test the case where nothing is flattened."""
module = self._get_transformer()
ref_out = self._get_output(module)
ref_out = self._compute_output(module)
ref_state_dict = module.state_dict()
for k, v in ref_state_dict.items():
ref_state_dict[k] = v.clone()
@@ -163,10 +207,11 @@ def test_flatten_nothing(self):
assert ref_state_dict.keys() == fpw_state_dict.keys()
for k, v in ref_state_dict.items():
torch.testing.assert_allclose(v, fpw_state_dict[k])
fpw_out = self._get_output(module)
fpw_out = self._compute_output(module)
torch.testing.assert_allclose(ref_out, fpw_out)

def test_empty_module(self):
"""Test module without any param."""
module = self._get_empty_module()
in_data = torch.rand(1)
ref_out = module(in_data)
@@ -223,69 +268,96 @@ def test_load_state_dict(self):
for module_init_fn in self._get_module_init_fns():
module = module_init_fn()
ref_state_dict = module.state_dict()
ref_output = self._get_output(module)
ref_output = self._compute_output(module)

module = module_init_fn(seed=1234)
flat_module = FlattenParamsWrapper(module)
flat_module = FlattenParamsWrapper(
module, param_list=module.param_list, flat_param_names=module.flat_param_names
)

# This should work without the unflatten_params context manager
flat_module.load_state_dict(ref_state_dict)
flat_output = self._get_output(flat_module)
flat_output = self._compute_output(flat_module)
assert objects_are_equal(ref_output, flat_output)

# And it should work with the context manager too
with flat_module.unflatten_params():
flat_module.load_state_dict(ref_state_dict)
flat_output = self._get_output(flat_module)
flat_output = self._compute_output(flat_module)
assert objects_are_equal(ref_output, flat_output)

def test_flat_state_dict(self):
"""Test that flat state dict can be reloaded and produces the same results."""
for module_init_fn in self._get_module_init_fns():
flat_module = FlattenParamsWrapper(module_init_fn())
ref_output = self._get_output(flat_module)
orig_module = module_init_fn()
flat_module = FlattenParamsWrapper(
orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
)
ref_output = self._compute_output(flat_module)

flat_state_dict = flat_module.flat_state_dict()

new_module = FlattenParamsWrapper(module_init_fn(seed=1234))
orig_module = module_init_fn(seed=1234)
new_module = FlattenParamsWrapper(
orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
)
new_module.load_state_dict(flat_state_dict)
new_output = self._get_output(new_module)
new_output = self._compute_output(new_module)

assert objects_are_equal(ref_output, new_output)

def test_unflatten_params(self):
"""Test using external flat-param tensors as the backing data for the module's params."""
for module_init_fn in self._get_module_init_fns():
module = FlattenParamsWrapper(module_init_fn())
orig_module = module_init_fn()
module = FlattenParamsWrapper(
orig_module, param_list=orig_module.param_list, flat_param_names=orig_module.flat_param_names
)

# Keep a set of buffer keys to be used for verification below.
buffers = {k.replace("_fpw_module.", "") for k, _ in module.named_buffers()}

def clone_state_dict():
"""Return a copy of the module's current state via state_dict() API."""
return OrderedDict((k, v.clone()) for k, v in module.state_dict().items())

ref_flat_param = module.flat_param.clone()
ref_flat_params = [fp.clone() for fp in module.flat_params]
# Get the current state as a reference.
with module.unflatten_params():
ref_state_dict = clone_state_dict()
assert not torch.all(ref_flat_param == 0)
for ref_fp in ref_flat_params:
assert not torch.all(ref_fp == 0.0) # Should not all be 0s.

# confirm that unflatten_params reflects values from new_flat_param
new_flat_param = torch.full_like(module.flat_param, fill_value=42.0)
with module.unflatten_params(flat_params=[new_flat_param]):
# get new_state_dict with supplied new_flat_params.
new_flat_params = [torch.full_like(fp, fill_value=42.0) for fp in module.flat_params]
with module.unflatten_params(flat_params=new_flat_params):
new_state_dict = clone_state_dict()
assert new_state_dict.keys() == ref_state_dict.keys()
for k, v in new_state_dict.items():
if k in buffers: # buffers are not changed
torch.testing.assert_allclose(v, ref_state_dict[k])
else: # params reflect new_flat_param value
torch.testing.assert_allclose(v, torch.ones_like(v) * 42.0)

# confirm that unflatten_params reflects values from new_flat_param
assert new_state_dict.keys() == ref_state_dict.keys()
for k, v in new_state_dict.items():
if k in buffers: # buffers are not changed
torch.testing.assert_allclose(v, ref_state_dict[k])
else: # params reflect new_flat_param value
torch.testing.assert_allclose(v, torch.ones_like(v) * 42.0)

# after context manager exits, we go back to previous (reference) state
torch.testing.assert_allclose(module.flat_param, ref_flat_param)
assert len(module.flat_params) == len(ref_flat_params)
for i in range(len(module.flat_params)):
torch.testing.assert_allclose(module.flat_params[i], ref_flat_params[i])

# get another copy of state from the module (without external backing data)
with module.unflatten_params():
ref_state_dict2 = clone_state_dict()
assert objects_are_equal(ref_state_dict, ref_state_dict2)

# Verify it is still the same.
assert objects_are_equal(ref_state_dict, ref_state_dict2)

# if we load the new_state_dict, then the flat param should match new_flat_param
module.load_state_dict(new_state_dict)
torch.testing.assert_allclose(module.flat_param, new_flat_param)
assert len(module.flat_params) == len(new_flat_params)
for i in range(len(module.flat_params)):
torch.testing.assert_allclose(module.flat_params[i], new_flat_params[i])


@unittest.skipIf(not torch.cuda.is_available(), "test requires a GPU")
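
For reference, a usage sketch in the spirit of the updated tests. This is a sketch under the assumptions that FlattenParamsWrapper is importable from fairscale.nn.misc.flatten_params_wrapper and that the constructor accepts the param_list and flat_param_names keyword arguments shown above; it is not an excerpt from the repository.

import torch
from fairscale.nn.misc.flatten_params_wrapper import FlattenParamsWrapper

def make_module(seed):
    # Same shape as the new test helpers: two groups of linear layers.
    torch.manual_seed(seed)
    return torch.nn.Sequential(
        torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 16)),
        torch.nn.Linear(16, 4),
    )

def wrap(module):
    # Two flatten groups with custom names, as in _get_2_flatten_group_linear_module_with_names.
    return FlattenParamsWrapper(
        module,
        param_list=[list(module[0].parameters()), list(module[1].parameters())],
        flat_param_names=["layer1", "layer2"],
    )

src = wrap(make_module(seed=0))
flat_sd = src.flat_state_dict()  # with custom names, the keys need not include flat_param_0

dst = wrap(make_module(seed=1234))  # differently initialized
dst.load_state_dict(flat_sd)  # exercises the prefix check fixed in this commit

torch.manual_seed(1)
x = torch.rand(8, 4)
assert torch.allclose(src(x), dst(x))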

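Similarly, a minimal sketch of the external backing-data behavior exercised by test_unflatten_params: inside unflatten_params(flat_params=...), the module's params read from the supplied tensors, and the original flat params are restored when the context manager exits. Same import assumption as above; the wrapped module is an arbitrary small stand-in.

import torch
from fairscale.nn.misc.flatten_params_wrapper import FlattenParamsWrapper

torch.manual_seed(0)
wrapper = FlattenParamsWrapper(
    torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.Linear(8, 4))
)

# Externally supplied flat params, filled with 42s.
new_flat_params = [torch.full_like(fp, 42.0) for fp in wrapper.flat_params]

with wrapper.unflatten_params(flat_params=new_flat_params):
    # Params now reflect the supplied tensors; capture them via state_dict().
    sd = {k: v.clone() for k, v in wrapper.state_dict().items()}

# After the context manager exits, the original (random) flat params are back.
assert all(not torch.all(fp == 42.0) for fp in wrapper.flat_params)

# Loading the captured (non-flat) state dict writes the 42s into the flat params.
wrapper.load_state_dict(sd)
assert all(torch.all(fp == 42.0) for fp in wrapper.flat_params)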