# Conditioned diffusion models
Imagine you want to train a diffusion model on a dataset with 10 classes. How do you control which class is generated? This lab shows one way to add *conditioning information* to a diffusion model. Specifically, we’ll train a class-conditioned diffusion model on FashionMNIST, where we can specify which class we’d like the model to generate at inference time.

## Setup and Data Prep

In [None]:
# Install the diffusers library (quiet mode)
%pip install -q diffusers

In [None]:
import torch
import torchvision
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from diffusers import DDPMScheduler, UNet2DModel
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
# Identify and choose device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

In [None]:
# Load the FashionMNIST dataset, which contains images and numerical class labels
dataset = torchvision.datasets.FashionMNIST(root="FashionMNIST/", train=True, download=True, transform=torchvision.transforms.ToTensor())

In [None]:
# Create a small dataloader so we can preview a few examples
train_dataloader = DataLoader(dataset, batch_size=10, shuffle=True)

# Showing examples from dataset
x, y = next(iter(train_dataloader))
print('Input shape:', x.shape)
print('Labels:', y)
plt.imshow(torchvision.utils.make_grid(x)[0], cmap='Greys');

## Creating a Class-Conditioned UNet

The way we’ll feed in the class conditioning is as follows:

- Create a standard UNet2DModel with some additional input channels
- Map the class label to a learned vector of shape (`class_emb_size`) via an embedding layer
- Concatenate this information as extra channels for the internal UNet input with `net_input = torch.cat((x, class_cond), 1)`
- Feed this net_input (which has (`class_emb_size+1`) channels in total) into the UNet to get the final prediction


In this example I’ve set `class_emb_size` to 4, but this is completely arbitrary. You could explore setting it to 1 (to see if it still works), setting it to 10 (to match the number of classes), or replacing the learned `nn.Embedding` with a simple one-hot encoding of the class label.
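As a rough sketch of the one-hot variant mentioned above, the learned embedding could be swapped for `F.one_hot`, which produces one channel per class (so the UNet would then need `in_channels=1 + num_classes`). The helper name `one_hot_cond` is made up for illustration and is not part of this notebook.

```python
import torch
from torch.nn import functional as F

def one_hot_cond(class_labels, num_classes, h, w):
    """Hypothetical helper: turn integer labels into (bs, num_classes, h, w) conditioning channels."""
    one_hot = F.one_hot(class_labels, num_classes=num_classes).float()  # (bs, num_classes)
    # Add two spatial dimensions and broadcast the vector over every pixel
    return one_hot[:, :, None, None].expand(-1, -1, h, w)

x = torch.randn(3, 1, 28, 28)            # a stand-in batch of images
labels = torch.tensor([0, 4, 9])         # one label per image
cond = one_hot_cond(labels, num_classes=10, h=28, w=28)
net_input = torch.cat((x, cond), 1)      # image channel plus 10 class channels
print(net_input.shape)
```

The rest of this notebook sticks with the learned embedding.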

This is what the implementation looks like:

In [None]:
class class_conditioned_unet(nn.Module):
  def __init__(self, num_classes=10, class_emb_size=4):
    super().__init__()

    # Embedding layer will map the class label to a vector of size class_emb_size
    self.class_emb = nn.Embedding(num_classes, class_emb_size)

    # self.model is an unconditional UNet with extra input channels to accept the conditioning information (the class embedding)
    self.model = UNet2DModel(
        sample_size=28,           # the target image resolution
        in_channels=1 + class_emb_size, # Additional input for class condition
        out_channels=1,           # the number of output channels (1 for grayscale)
        layers_per_block=2,       # how many ResNet layers to use per UNet block
        block_out_channels=(32, 64, 64),
        down_block_types=(
            "DownBlock2D",        # Regular ResNet downsampling block
            "AttnDownBlock2D",    # Downsampling block with spatial self-attention
            "AttnDownBlock2D",
        ),
        up_block_types=(
            "AttnUpBlock2D",
            "AttnUpBlock2D",      # a ResNet upsampling block with spatial self-attention
            "UpBlock2D",          # a regular ResNet upsampling block
          ),
    )

  # The forward method takes the class labels as an additional argument
  def forward(self, x, t, class_labels):
    # Shape of x: (bs, 1, 28, 28)
    bs, ch, h, w = x.shape

    # Get the class conditioning into the right shape to pass as extra input channels
    class_cond = self.class_emb(class_labels) # Map to embedding dimension: (bs, class_emb_size)
    class_cond = class_cond.view(bs, class_cond.shape[1], 1, 1).expand(bs, class_cond.shape[1], h, w) # Broadcast over the image: (bs, class_emb_size, 28, 28)

    # Net input is now x and class_cond concatenated together along dimension 1
    net_input = torch.cat((x, class_cond), 1) # (bs, 1 + class_emb_size, 28, 28)

    # Feed this to the UNet alongside the timestep and return the prediction
    return self.model(net_input, t).sample # (bs, 1, 28, 28)

If any of the shapes or transforms are confusing, add in print statements to show the relevant shapes and check that they match your expectations. I’ve also annotated the shapes of some intermediate variables in the hopes of making things clearer.
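For example, the conditioning steps can be checked in isolation, without building the UNet at all. This is a minimal sketch using a random stand-in batch; the batch size of 8 is an arbitrary choice for illustration.

```python
import torch
from torch import nn

bs, class_emb_size, h, w = 8, 4, 28, 28
x = torch.randn(bs, 1, h, w)                 # stand-in for a batch of images
class_labels = torch.randint(0, 10, (bs,))   # random labels in 0..9

class_emb = nn.Embedding(10, class_emb_size)
class_cond = class_emb(class_labels)
print(class_cond.shape)  # torch.Size([8, 4])

# Reshape to (bs, emb, 1, 1), then broadcast the same vector to every pixel
class_cond = class_cond.view(bs, class_emb_size, 1, 1).expand(bs, class_emb_size, h, w)
print(class_cond.shape)  # torch.Size([8, 4, 28, 28])

net_input = torch.cat((x, class_cond), 1)
print(net_input.shape)   # torch.Size([8, 5, 28, 28])
```

Each of the extra channels is constant across the image, so every pixel sees the same class vector.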

## Training and Sampling

Where previously we’d do something like `prediction = unet(x, t)`, we’ll now pass the correct labels as a third argument during training (`prediction = unet(x, t, y)`). At inference we can pass whatever labels we want and, if all goes well, the model should generate images that match. Here `y` holds the FashionMNIST class labels, which are integers from 0 to 9.
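To make the integer labels easier to read, they can be mapped to class names. The list below follows the label order used by torchvision's FashionMNIST dataset (the loaded `dataset.classes` attribute should hold the same names); `label_name` is a hypothetical helper added only for illustration.

```python
# Class names in label order (index 0 to 9), as in torchvision's FashionMNIST
FASHION_MNIST_CLASSES = [
    "T-shirt/top", "Trouser", "Pullover", "Dress", "Coat",
    "Sandal", "Shirt", "Sneaker", "Bag", "Ankle boot",
]

def label_name(label):
    """Hypothetical helper: map an integer label (0-9) to its class name."""
    return FASHION_MNIST_CLASSES[int(label)]

print(label_name(7))  # Sneaker
```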

The training loop is very similar to the unconditioned diffusion model. We’re now predicting the noise to match the objective expected by the default DDPMScheduler which we’re using to add noise during training and to generate samples at inference time. Training takes a while - speeding this up could be a fun mini-project, but most of you can probably just skim the code (and indeed this whole notebook) without running it since we’re just illustrating an idea.
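To make the noising step concrete, here is a minimal plain-PyTorch sketch of the standard DDPM forward process, where the noisy image is a weighted mix of the clean image and Gaussian noise. It assumes a simple linear beta schedule rather than the `squaredcos_cap_v2` schedule used in this notebook, and `add_noise_sketch` is a made-up name, not the scheduler's own code.

```python
import torch

num_train_timesteps = 1000
# Assumption: a linear beta schedule, chosen here only for simplicity
betas = torch.linspace(1e-4, 0.02, num_train_timesteps)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)  # abar_t for every timestep

def add_noise_sketch(x0, noise, timesteps):
    """x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise, with one t per sample."""
    abar = alphas_cumprod[timesteps].view(-1, 1, 1, 1)
    return abar.sqrt() * x0 + (1 - abar).sqrt() * noise

x0 = torch.rand(4, 1, 28, 28) * 2 - 1         # stand-in images scaled to (-1, 1)
noise = torch.randn_like(x0)
timesteps = torch.tensor([0, 250, 500, 999])  # from barely noisy to almost pure noise
noisy_x = add_noise_sketch(x0, noise, timesteps)
print(noisy_x.shape)
```

The network sees `noisy_x` and the timestep and is trained to output `noise`, which is why the loss compares the prediction against the noise rather than against the clean image.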

In [None]:
# Create the noise scheduler
noise_scheduler = DDPMScheduler(num_train_timesteps=1000, beta_schedule='squaredcos_cap_v2')

In [None]:
# Train dataloader
train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

# Number of epochs
n_epochs = 5

# Our network
net = class_conditioned_unet().to(device)

# Our loss function
loss_fn = nn.MSELoss()

# Adam optimizer with a learning rate of 1e-3
opt = torch.optim.Adam(net.parameters(), lr=1e-3)

# Save loss for plot later
losses = []

# Traditional training loop
for epoch in range(n_epochs):
    for x, y in tqdm(train_dataloader):

        # Retrieve the data and corrupt it with noise
        x = x.to(device) * 2 - 1 # Map pixel values from (0, 1) to (-1, 1)
        y = y.to(device)
        noise = torch.randn_like(x) # Generate noise
        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (x.shape[0],)).long().to(device) # Random timestep per image (upper bound is exclusive): how far the image is turned into noise
        noisy_x = noise_scheduler.add_noise(x, noise, timesteps) # Add noise for the given timesteps

        # Get the model prediction
        pred = net(noisy_x, timesteps, y)

        # Calculate the loss
        loss = loss_fn(pred, noise)

        # Backprop and update the params:
        opt.zero_grad()
        loss.backward()
        opt.step()

        # Store loss
        losses.append(loss.item())

    # Average of the last 50
    avg_loss = sum(losses[-50:])/50
    print(f'Finished epoch {epoch}. Average of the last 50 loss values: {avg_loss:05f}')

# View the loss curve
plt.plot(losses)

Once training finishes, we can sample some images, feeding in different labels as our conditioning:

In [None]:
def prediction(n_per_class, n_classes):
  # n_per_class: number of images to generate for each class
  # n_classes: number of classes to generate, i.e. labels 0 .. n_classes - 1
  # Start from pure Gaussian noise, matching the noise used during training
  x = torch.randn(n_per_class * n_classes, 1, 28, 28).to(device)
  # Tile the labels so that each column of the grid below shows one class
  y = torch.arange(n_classes).repeat(n_per_class).to(device)

  # Sampling loop
  for t in tqdm(noise_scheduler.timesteps):

    # Get the model's noise prediction
    with torch.no_grad():
      residual = net(x, t, y)

    # Update the sample with one scheduler step
    x = noise_scheduler.step(residual, t, x).prev_sample

  # Show the results
  fig, ax = plt.subplots(1, 1, figsize=(12, 12))
  ax.imshow(torchvision.utils.make_grid(x.detach().cpu().clip(-1, 1), nrow=n_classes)[0], cmap='Greys')

In [None]:
n_per_class = int(input("Number of images to generate for each category (try 1): "))
n_classes = int(input("Number of clothing categories to generate (try 10 to generate categories 0, 1, ..., 9): "))
prediction(n_per_class, n_classes)
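Each `noise_scheduler.step` call uses the model's noise prediction to move the sample a little closer to a clean image. One ingredient of that update is inverting the noising formula to estimate the clean image from the current sample. The sketch below shows only that inversion, assuming a linear beta schedule and a hypothetical helper `predict_x0`; the scheduler's actual step does more than this (for example, it blends the estimate with the current sample and adds fresh noise).

```python
import torch

# Assumption: a linear beta schedule, not the cosine schedule used in this notebook
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def predict_x0(x_t, noise_pred, t):
    """Hypothetical helper: invert x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    abar = alphas_cumprod[t]
    return (x_t - (1 - abar).sqrt() * noise_pred) / abar.sqrt()

# Build a noisy sample from a known clean image, then recover the clean image
x0 = torch.rand(2, 1, 28, 28) * 2 - 1
noise = torch.randn_like(x0)
t = 300
abar = alphas_cumprod[t]
x_t = abar.sqrt() * x0 + (1 - abar).sqrt() * noise

x0_est = predict_x0(x_t, noise, t)  # using the true noise as a "perfect" prediction
print((x0_est - x0).abs().max())
```

With a perfect noise prediction the estimate matches the clean image up to floating-point error; a trained model only approximates the noise, which is why sampling takes many small steps.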

In [None]:
# Exercise (optional): Tweak the learning rate, batch size and number of epochs.
# Can you get some decent-looking fashion images with less training time than the example above?