FSDP: better traceback for dtype assertion (#912)
sshleifer committed Jan 18, 2022
1 parent 6b2f992 commit fef4423
Showing 1 changed file with 10 additions and 2 deletions.

fairscale/nn/misc/flatten_params_wrapper.py
@@ -279,8 +279,14 @@ def _init_flatten_params(
         shared_param_memo: Dict[nn.Parameter, Tuple[str, nn.Module, str]] = {}
         shared_param_infos = []
         params = []
+        fp32 = []
+        fp16 = []
         for module_name, m in self.named_modules():
             for n, p in m.named_parameters(recurse=False):
+                if p.dtype != torch.float16:
+                    fp32.append(module_name)
+                else:
+                    fp16.append(module_name)
                 if p is not None and (m, n) in p_set:
                     if p in shared_param_memo:
                         mname, shared_m, shared_n = shared_param_memo[p]
@@ -290,8 +296,10 @@
                         param_infos.append((module_name, m, n))
                         params.append(p)
         del shared_param_memo
-
-        assert len(set(p.dtype for p in params)) == 1, "expects all parameters to have same dtype"
+        fp16_msg, fp32_msg = ",".join(fp16), ",".join(fp32)
+        assert (
+            len(set(p.dtype for p in params)) == 1
+        ), f"expects all parameters to have same dtype: fp32: {fp32_msg} \n fp16: {fp16_msg} "
         assert len(set(p.requires_grad for p in params)) == 1, "expects all parameters to have same requires_grad"
         assert len(params) == len(set(params)), "params list should not have dups"
         return params, param_infos, shared_param_infos
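For illustration, here is a minimal standalone sketch (not fairscale's API) of the check this commit improves: it scans a model whose parameters deliberately mix dtypes and raises the new, more informative assertion message. The toy model and the resulting module names are hypothetical.

    import torch
    import torch.nn as nn

    # Toy model with deliberately mixed parameter dtypes (hypothetical example).
    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    model[1].half()  # second Linear now holds fp16 parameters

    fp32, fp16, params = [], [], []
    for module_name, m in model.named_modules():
        for n, p in m.named_parameters(recurse=False):
            # Record which module each parameter came from, bucketed by dtype,
            # mirroring the bookkeeping this commit adds.
            (fp32 if p.dtype != torch.float16 else fp16).append(module_name)
            params.append(p)

    fp16_msg, fp32_msg = ",".join(fp16), ",".join(fp32)
    assert (
        len(set(p.dtype for p in params)) == 1
    ), f"expects all parameters to have same dtype: fp32: {fp32_msg} \n fp16: {fp16_msg} "
    # AssertionError: expects all parameters to have same dtype: fp32: 0,0
    #  fp16: 1,1

Before this change, the assertion only said "expects all parameters to have same dtype", leaving the user to hunt for the offending submodule; listing the fp32 and fp16 module names turns the failure into an actionable traceback.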
