[zero-3] print warning once and support torch parameter #2127

Merged
merged 10 commits on Aug 12, 2022
14 changes: 10 additions & 4 deletions deepspeed/runtime/zero/parameter_offload.py
@@ -20,6 +20,9 @@ def is_builtin_type(obj):
     return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"
 
 
+warned = False
+
+
 #apply torch.autograd.Function that calls a backward_function to tensors in output
 def _apply_to_tensors_only(module, functional, backward_function, outputs):
     if isinstance(outputs, (tuple, list)):
@@ -44,10 +47,13 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
         return functional.apply(module, backward_function, outputs)
     else:
         if not is_builtin_type(outputs):
-            logger.warning(
-                f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
-                "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
-                "output tensors and therefore may not get triggered properly.")
+            global warned
+            if not warned:
+                logger.warning(
+                    f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
+                    "The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
+                    "output tensors and therefore may not get triggered properly.")
+                warned = True
         return outputs


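For reference, below is a minimal, standalone sketch of the warn-once pattern this change introduces: a module-level flag guards the logger call so the ZeRO-3 warning is emitted at most once per process rather than once per module call. The helper name _warn_unknown_output_type is illustrative only and not part of DeepSpeed's API.

    import logging

    logger = logging.getLogger(__name__)

    # Module-level flag: flipped the first time the warning fires so the
    # message is printed at most once per process, however many modules
    # produce outputs of an unrecognized type.
    warned = False


    def _warn_unknown_output_type(outputs):
        """Illustrative helper mirroring the PR's warn-once logic."""
        global warned
        if not warned:
            logger.warning(
                f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors "
                "embedded in it cannot be detected. The ZeRO-3 hooks designed to trigger before or "
                "after backward pass of the module relies on knowing the input and output tensors "
                "and therefore may not get triggered properly.")
            warned = True

Because the flag lives at module scope, every call site in parameter_offload.py shares the same state; the first offending module logs the warning and all later ones skip it.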