Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[zero-3] print warning once and support torch parameter #2127

Merged
merged 10 commits into from
Aug 12, 2022
16 changes: 12 additions & 4 deletions deepspeed/runtime/zero/parameter_offload.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from deepspeed.runtime.zero.partition_parameters import _init_external_params
from deepspeed.runtime.zero.partition_parameters import *
from deepspeed.runtime.zero.partitioned_param_coordinator import PartitionedParameterCoordinator, iter_params
from deepspeed import comm as dist

FWD_MODULE_STACK = list()

Expand All @@ -21,6 +22,10 @@ def is_builtin_type(obj):
return obj.__class__.__module__ == '__builtin__' or obj.__class__.__module__ == "builtins"


# ensure we only warn once, otherwise every iteration will trigger a warning
warned = False


#apply torch.autograd.Function that calls a backward_function to tensors in output
def _apply_to_tensors_only(module, functional, backward_function, outputs):
if isinstance(outputs, (tuple, list)):
Expand All @@ -45,10 +50,13 @@ def _apply_to_tensors_only(module, functional, backward_function, outputs):
return functional.apply(module, backward_function, outputs)
else:
if not is_builtin_type(outputs):
logger.warning(
f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
"The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
"output tensors and therefore may not get triggered properly.")
global warned
if not warned and dist.get_rank() == 0:
logger.warning(
f"A module has unknown inputs or outputs type ({type(outputs)}) and the tensors embedded in it cannot be detected. "
"The ZeRO-3 hooks designed to trigger before or after backward pass of the module relies on knowing the input and "
"output tensors and therefore may not get triggered properly.")
warned = True
return outputs


Expand Down