From b6bf773315974d44a918e91e654234e821efe223 Mon Sep 17 00:00:00 2001 From: kesimeg <48391912+kesimeg@users.noreply.github.com> Date: Sun, 15 Oct 2023 22:45:19 +0300 Subject: [PATCH 1/3] fix une2td ignoring class_labels --- src/diffusers/models/unet_2d.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index db6d3a5dce3f..4e14e7b8b394 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -291,7 +291,9 @@ def forward( class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb - + elif self.class_embedding is None and class_labels is not None: + raise ValueError("class_embedding needs to be initialized to use class conditioning") + # 2. pre-process skip_sample = sample sample = self.conv_in(sample) From 0e8038ce5a1f6f5ffe5dc9e41656aa89e388be61 Mon Sep 17 00:00:00 2001 From: Ege Date: Tue, 17 Oct 2023 19:51:02 +0300 Subject: [PATCH 2/3] unet2.py error message updated --- src/diffusers/models/unet_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index 4e14e7b8b394..e4aea20d3a33 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -292,7 +292,7 @@ def forward( class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb elif self.class_embedding is None and class_labels is not None: - raise ValueError("class_embedding needs to be initialized to use class conditioning") + raise ValueError("class_embedding needs to be initialized in order to use class conditioning") # 2. pre-process skip_sample = sample From 92b5243023f79c27c0dadafa508f5671c9ad2773 Mon Sep 17 00:00:00 2001 From: Ege Date: Tue, 17 Oct 2023 20:27:49 +0300 Subject: [PATCH 3/3] style and quality changes --- src/diffusers/models/unet_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index e4aea20d3a33..38e26422e2a7 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -293,7 +293,7 @@ def forward( emb = emb + class_emb elif self.class_embedding is None and class_labels is not None: raise ValueError("class_embedding needs to be initialized in order to use class conditioning") - + # 2. pre-process skip_sample = sample sample = self.conv_in(sample)