diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 57bd9074870c..e1b14dcc5163 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -24,6 +24,9 @@
 if is_transformers_available():
     import transformers
 
+    if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
+        import deepspeed
+
 if is_peft_available():
     from peft import set_peft_model_state_dict
 
@@ -430,15 +433,13 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
         self.cur_decay_value = decay
         one_minus_decay = 1 - decay
 
-        context_manager = contextlib.nullcontext
-        if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
-            import deepspeed
+        context_manager = contextlib.nullcontext()
 
         if self.foreach:
             if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                 context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
 
-            with context_manager():
+            with context_manager:
                 params_grad = [param for param in parameters if param.requires_grad]
                 s_params_grad = [
                     s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
@@ -460,7 +461,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
                 if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                     context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
 
-                with context_manager():
+                with context_manager:
                     if param.requires_grad:
                         s_param.sub_(one_minus_decay * (s_param - param))
                     else:
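
Why the patch rebinds `context_manager`: the old code assigned the `contextlib.nullcontext` class and entered it with `with context_manager():`, which only works while the name still points at a callable class. Once the ZeRO-3 branch rebinds the name to a `deepspeed.zero.GatheredParameters` instance, calling it raises TypeError, since context-manager instances are entered directly. A minimal standalone sketch of the corrected pattern, assuming a boolean `zero3_enabled` in place of the real transformers/DeepSpeed checks (not part of the patch):

import contextlib

# Hypothetical stand-in for is_transformers_available() and
# transformers.integrations.deepspeed.is_deepspeed_zero3_enabled().
zero3_enabled = False

# Bind a context-manager *instance* up front, not the nullcontext class.
context_manager = contextlib.nullcontext()
if zero3_enabled:
    import deepspeed  # only importable when DeepSpeed is installed

    # GatheredParameters(...) is likewise an instance; with the old
    # `with context_manager():` this branch would fail, because
    # context-manager instances are not callable.
    context_manager = deepspeed.zero.GatheredParameters([], modifier_rank=None)

# Entering the bare name now works on both branches.
with context_manager:
    pass  # EMA shadow-parameter update goes here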