diff --git a/src/diffusers/models/unet_2d.py b/src/diffusers/models/unet_2d.py index db6d3a5dce3f..38e26422e2a7 100644 --- a/src/diffusers/models/unet_2d.py +++ b/src/diffusers/models/unet_2d.py @@ -291,6 +291,8 @@ def forward( class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb + elif self.class_embedding is None and class_labels is not None: + raise ValueError("class_embedding needs to be initialized in order to use class conditioning") # 2. pre-process skip_sample = sample