In [2]:
import torch
from diffusers import UNet2DConditionModel

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Small conditional UNet for (45, 1) single-channel inputs, conditioned via
# cross-attention on vectors of size 16.
unet_config = dict(
    sample_size=(45, 1),                   # (height, width) of the generated sample
    in_channels=1,                         # input channels (e.g., latents or features)
    out_channels=1,                        # output channels
    layers_per_block=2,
    block_out_channels=(4, 8, 16, 16),     # number of channels in each block
    down_block_types=(
        "DownBlock2D",
        "DownBlock2D",
        "CrossAttnDownBlock2D",
        "CrossAttnDownBlock2D",
    ),
    up_block_types=(
        "CrossAttnUpBlock2D",
        "CrossAttnUpBlock2D",
        "UpBlock2D",
        "UpBlock2D",
    ),
    norm_num_groups=4,
    cross_attention_dim=16,                # size of the text embedding or conditioning vector
)
model = UNet2DConditionModel(**unet_config)

# Count all parameters and the subset that will receive gradients.
total_params = 0
trainable_params = 0
for param in model.parameters():
    total_params += param.numel()
    if param.requires_grad:
        trainable_params += param.numel()

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

Total parameters: 160,505
Trainable parameters: 160,505


In [6]:
# Shape smoke test: 32 single-channel 45x1 inputs, conditioned on a 45-token
# sequence of 16-dim vectors (last dim must equal cross_attention_dim=16).
batch = 32
x = torch.randn(batch, 1, 45, 1)
print(x.shape)

# Integer timesteps, as a noise scheduler would supply them (torch.randn would
# give negative, fractional values that never occur in training).
# NOTE(review): assumes a 1000-step noise schedule — confirm against the scheduler used.
timesteps = torch.randint(0, 1000, (batch,))

cond_tokens = torch.randn(batch, 45, 16)  # (batch, seq_len, cross_attention_dim)

with torch.no_grad():  # shape check only; no autograd graph needed
    out = model(x, timestep=timesteps, encoder_hidden_states=cond_tokens)
print(out.sample.shape)

assert out.sample.shape == x.shape, "UNet output should have the same shape as its input"

torch.Size([32, 1, 45, 1])
torch.Size([32, 1, 45, 1])


In [11]:
# Same smoke test with a taller (64x1) input and a single conditioning token,
# to check that the model is not tied to the configured sample_size or seq_len.
batch = 32
x = torch.randn(batch, 1, 64, 1)
print(x.shape)

# Integer timesteps, as a noise scheduler would supply them.
# NOTE(review): assumes a 1000-step noise schedule — confirm against the scheduler used.
timesteps = torch.randint(0, 1000, (batch,))

cond_tokens = torch.randn(batch, 1, 16)  # (batch, seq_len=1, cross_attention_dim)

with torch.no_grad():  # shape check only; no autograd graph needed
    out = model(x, timestep=timesteps, encoder_hidden_states=cond_tokens)
print(out.sample.shape)

assert out.sample.shape == x.shape, "UNet output should have the same shape as its input"

torch.Size([32, 1, 64, 1])
torch.Size([32, 1, 64, 1])


In [93]:
batch_size = 2
height = 45
width = 1
# Last dim of any conditioning tensor; must equal the UNet's cross_attention_dim (16).
cross_attention_dim = 16

# Input (e.g., noisy latent image), laid out as (batch, channels=1, height, width)
x = torch.randn(batch_size, 1, height, width)

# One integer timestep per batch element: 10, 20, ... (e.g., from a diffusion schedule).
# Derived from batch_size so changing the batch size cannot desynchronise it.
timesteps = (torch.arange(batch_size, dtype=torch.long) + 1) * 10

# Dummy conditioning vector (e.g., text embedding, energy embedding, etc.).
# Uses cross_attention_dim; a mismatched last dim would make the model's cross-attention fail.
encoder_hidden_states = torch.randn(batch_size, 1, cross_attention_dim)  # (batch, seq_len, cross_attention_dim)

In [96]:
# Every conditioning tensor shares one shape: (batch, seq_len=1, cross_attention_dim=16).
cond_shape = (batch_size, 1, 16)

# Constant conditionings at two very different magnitudes.
cond_low = torch.full(cond_shape, 0.1)     # low conditioning value
cond_high = torch.full(cond_shape, 100.0)  # high conditioning value

# Two random conditionings, each made reproducible by its own seed.
torch.manual_seed(42)
cond_rand1 = torch.randn(cond_shape)
torch.manual_seed(7)
cond_rand2 = torch.randn(cond_shape)

In [97]:
# Run the same input/timesteps through the model under each conditioning,
# in the order low, high, rand1, rand2, without tracking gradients.
with torch.no_grad():
    out_low, out_high, out_rand1, out_rand2 = [
        model(x, timestep=timesteps, encoder_hidden_states=cond)
        for cond in (cond_low, cond_high, cond_rand1, cond_rand2)
    ]

In [99]:
out_low.sample.size()

torch.Size([2, 1, 45, 1])

In [98]:
# Mean absolute difference between outputs whose calls differ only in
# encoder_hidden_states; a value close to zero suggests the conditioning has
# little influence on the output.
diff = (out_low.sample - out_high.sample).abs().mean()
print("Avg difference between low and high condition outputs:", diff.item())

# The two random conditionings were computed too; compare them as well so the
# low/high gap can be judged against a random-vs-random baseline.
diff_rand = (out_rand1.sample - out_rand2.sample).abs().mean()
print("Avg difference between two random condition outputs:", diff_rand.item())

Avg difference between low and high condition outputs: 0.05425228178501129


In [37]:
import torch

# Example tensor laid out as (batch, layers, channel) = (2, 4, 1)
before_x = torch.randn(2, 4, 1)

# Move the channel axis to dim 1 and append a width axis, giving the UNet's
# (batch, channels, height, width) layout with layers as height: (2, 1, 4, 1)
after_x = before_x.permute(0, 2, 1).unsqueeze(-1)

assert after_x.shape == (before_x.shape[0], before_x.shape[2], before_x.shape[1], 1)

In [38]:
before_x.size()

torch.Size([2, 4, 1])

In [21]:
after_x.size()

torch.Size([2, 1, 4, 1])