From d8aebdda9a3dc947845bdf16e7d1b56e11595e4e Mon Sep 17 00:00:00 2001
From: charchit7
Date: Wed, 9 Oct 2024 12:34:33 +0530
Subject: [PATCH 1/3] gatherparams bug

---
 src/diffusers/training_utils.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 57bd9074870c..0842aed5c3af 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -24,6 +24,9 @@
 if is_transformers_available():
     import transformers
 
+    if transformers.deepspeed.is_deepspeed_zero3_enabled():
+        import deepspeed
+
 if is_peft_available():
     from peft import set_peft_model_state_dict
 
@@ -431,14 +434,12 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
         one_minus_decay = 1 - decay
 
         context_manager = contextlib.nullcontext
-        if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
-            import deepspeed
 
         if self.foreach:
             if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                 context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
 
-            with context_manager():
+            with context_manager:
                 params_grad = [param for param in parameters if param.requires_grad]
                 s_params_grad = [
                     s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad

From 67d1fea1fe1aa8be1ddfd758e7d4902beead155b Mon Sep 17 00:00:00 2001
From: charchit7
Date: Sat, 12 Oct 2024 17:59:12 +0530
Subject: [PATCH 2/3] calling context lib object

---
 src/diffusers/training_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 0842aed5c3af..850839f96c91 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -24,7 +24,7 @@
 if is_transformers_available():
     import transformers
 
-    if transformers.deepspeed.is_deepspeed_zero3_enabled():
+    if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
         import deepspeed
 
 if is_peft_available():
@@ -433,7 +433,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
         self.cur_decay_value = decay
         one_minus_decay = 1 - decay
 
-        context_manager = contextlib.nullcontext
+        context_manager = contextlib.nullcontext()
 
         if self.foreach:
             if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():

From d935b49fe620632ccb7f273e27f2efdf49fd57c1 Mon Sep 17 00:00:00 2001
From: charchit7
Date: Mon, 14 Oct 2024 07:36:38 +0530
Subject: [PATCH 3/3] fix

---
 src/diffusers/training_utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 850839f96c91..e1b14dcc5163 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -461,7 +461,7 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
                 if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                     context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
 
-                with context_manager():
+                with context_manager:
                     if param.requires_grad:
                         s_param.sub_(one_minus_decay * (s_param - param))
                     else: