Description
Describe the bug
Hello, I was trying to add conditioning to a pre-trained UNet2DModel and retrain it with that condition. I assumed I could do this just by passing the condition through the class_labels parameter. After some time I realized that if the UNet is initialized without class_embed_type (so that self.class_embedding is None), the forward method silently ignores class_labels. The relevant code is in unet_2d.py, lines 285-293:
if self.class_embedding is not None:
    if class_labels is None:
        raise ValueError("class_labels should be provided when doing class conditioning")

    if self.config.class_embed_type == "timestep":
        class_labels = self.time_proj(class_labels)

    class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
    emb = emb + class_emb
Before looking at unet_2d.py, I assumed the model already had an embedding that would be used whenever class_labels is passed, so that is how I was trying to use it. I could open a feature request or a pull request for this, but I think there should at least be a warning or an error: as it stands, you can pass a parameter and it is simply ignored. I added a snippet for reproduction below. It shows that passing class_labels does not change the output at all.
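To make the suggestion concrete, here is a minimal sketch of the kind of check I have in mind. It is not code from diffusers: forward_with_label_check and UnconditionalStub are hypothetical names made up for illustration, and the stub only mimics a model whose class_embedding is None.

```python
class UnconditionalStub:
    """Hypothetical stand-in for a model built without a class embedding."""

    class_embedding = None

    def __call__(self, sample, timestep, class_labels=None):
        # Mirrors the reported behavior: class_labels is accepted but never used.
        return sample


def forward_with_label_check(model, sample, timestep, class_labels=None):
    """Hypothetical wrapper that refuses class_labels the model cannot use."""
    if class_labels is not None and getattr(model, "class_embedding", None) is None:
        raise ValueError(
            "class_labels was passed, but the model has no class embedding; "
            "the labels would be ignored."
        )
    return model(sample, timestep, class_labels=class_labels)


stub = UnconditionalStub()

# Without labels the call goes through unchanged.
forward_with_label_check(stub, sample=1.0, timestep=0)

# With labels the wrapper fails loudly instead of silently dropping them.
try:
    forward_with_label_check(stub, sample=1.0, timestep=0, class_labels=[3])
except ValueError as err:
    print(f"Rejected: {err}")
```

Whether this should be an error or only a warning inside the library's forward method is a design question for the maintainers; the wrapper just shows the check itself.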
Reproduction
import torch
from diffusers import UNet2DModel

# Model built without class_embed_type, so it has no class embedding.
unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    layers_per_block=2,
    block_out_channels=(128, 128, 256, 256, 256),
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "DownBlock2D",
        "AttnDownBlock2D",
        "DownBlock2D",
    ),
    up_block_types=(
        "UpBlock2D",
        "AttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
)

batch_size = 8
rand_noise = torch.rand(batch_size, 3, 32, 32)
timesteps = torch.randint(low=0, high=1000, size=(batch_size,))
# One label per sample in the batch.
class_labels = torch.randint(low=0, high=1000, size=(batch_size,)).long()

out_1 = unet(rand_noise, timesteps, class_labels=class_labels).sample
out_2 = unet(rand_noise, timesteps).sample

# Identical outputs mean class_labels had no effect.
print(torch.equal(out_1, out_2))
Logs
No response
System Info
Colab
Who can help?
No response