UNet2DModel ignores class_labels parameter in forward method #5330

@kesimeg

Describe the bug

Hello, I was trying to add conditioning to a pre-trained UNet2DModel and re-train it with that conditioning. I assumed I could do this simply by passing the condition through the class_labels parameter. After some time I realized that if the UNet model is not initialized with class_embed_type, the forward method completely ignores the class_labels parameter. The relevant code is in unet_2d.py, lines 285-293:

    if self.class_embedding is not None:
        if class_labels is None:
            raise ValueError("class_labels should be provided when doing class conditioning")

        if self.config.class_embed_type == "timestep":
            class_labels = self.time_proj(class_labels)

        class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
        emb = emb + class_emb

Before looking at unet_2d.py, I assumed the model already had an embedding that would be used whenever the class_labels parameter is passed, so I was trying to use it that way. Maybe I can open a feature request or a pull request for this, but I think there should at least be a warning or an error: as it stands, you can pass a parameter and it may be silently ignored. I have added a snippet for reproduction below. It shows that passing class_labels does not change the output at all.
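A minimal sketch of the kind of guard suggested above, written as a standalone helper rather than as a change to the library. The function name `warn_if_class_labels_ignored` is hypothetical, and the only thing it assumes about the model is the `class_embedding` attribute that appears in the quoted forward code. The `SimpleNamespace` objects are stand-ins so the sketch runs without loading a real model.

```python
import warnings
from types import SimpleNamespace


def warn_if_class_labels_ignored(model, class_labels):
    """Warn and return True when class_labels would be silently dropped.

    Hypothetical helper: it only inspects `model.class_embedding`, the
    attribute checked at the top of the quoted forward code.
    """
    ignored = (
        class_labels is not None
        and getattr(model, "class_embedding", None) is None
    )
    if ignored:
        warnings.warn(
            "class_labels was passed, but the model has no class embedding; "
            "the labels will be ignored.",
            UserWarning,
            stacklevel=2,
        )
    return ignored


# Stand-ins for a model without and with a class embedding (assumptions for
# illustration only; a real UNet2DModel instance could be passed instead).
unconditioned = SimpleNamespace(class_embedding=None)
conditioned = SimpleNamespace(class_embedding=object())

print(warn_if_class_labels_ignored(unconditioned, [1, 2, 3]))
print(warn_if_class_labels_ignored(conditioned, [1, 2, 3]))
```

Calling such a check before the forward pass would surface the mismatch immediately. Whether the library should warn or raise in this situation is a design decision for the maintainers.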

Reproduction

    import torch
    from diffusers import UNet2DModel

    # No class_embed_type is given, so the model has no class embedding.
    unet = UNet2DModel(
        sample_size=32,
        in_channels=3,
        out_channels=3,
        layers_per_block=2,
        block_out_channels=(128, 128, 256, 256, 256),
        down_block_types=(
            "DownBlock2D",
            "DownBlock2D",
            "DownBlock2D",
            "AttnDownBlock2D",
            "DownBlock2D",
        ),
        up_block_types=(
            "UpBlock2D",
            "AttnUpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
            "UpBlock2D",
        ),
    )

    batch_size = 8
    rand_noise = torch.rand(batch_size, 3, 32, 32)
    timesteps = torch.randint(low=0, high=1000, size=(batch_size,))
    # One label per sample, shape (batch_size,)
    class_labels = torch.randint(low=0, high=1000, size=(batch_size,)).long()

    out_1 = unet(rand_noise, timesteps, class_labels=class_labels).sample
    out_2 = unet(rand_noise, timesteps).sample
    # Equal outputs mean class_labels had no effect.
    print(torch.equal(out_1, out_2))

Logs

No response

System Info

Colab

Who can help?

No response

Labels: bug, stale