## Classifier-Free Guidance

In the previous notebooks, we've seen how to build conditional diffusion models that can generate specific types of images based on labels (like generating a specific digit from MNIST). However, there's a common problem with this approach: **the model sometimes ignores or only partially follows the conditioning signal**.

This happens because the model has a natural tendency to generate "generic valid images" rather than being strongly steered by the label. The model learns that generating any plausible image is acceptable, so it doesn't always prioritize the specific condition we provide.

**Classifier-Free Guidance (CFG)** is a powerful technique that solves this problem by training the model to perform both conditional and unconditional generation simultaneously. During inference, we can then blend these two predictions to get stronger adherence to the condition while maintaining high-quality outputs.

In this notebook, we'll:
- Understand the intuition behind classifier-free guidance
- Modify our conditional U-Net to support both conditional and unconditional generation
- Train the model with label dropout to learn both modes
- Implement the guidance mechanism during sampling
- Explore how different guidance scales affect generation quality

### The Core Idea

The key insight of classifier-free guidance is to train a single model that can operate in two modes:
1. **Conditional mode**: $\epsilon_\theta(x_t, t, y)$ - predicts noise given the noisy image, timestep, and condition (label)
2. **Unconditional mode**: $\epsilon_\theta(x_t, t)$ - predicts noise given only the noisy image and timestep

During training, we randomly drop the condition (replace it with a special "null" token) so the model learns both behaviors. At inference time, we compute both predictions and combine them using a guidance scale $w$:

$$\epsilon_{final} = \epsilon_\theta(x_t, t) + w \cdot (\epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t))$$

This can be rewritten as:
$$\epsilon_{final} = (1-w) \cdot \epsilon_\theta(x_t, t) + w \cdot \epsilon_\theta(x_t, t, y)$$

**Intuition**: The difference $(\epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t))$ represents the "direction" that the condition pushes the generation. By scaling this difference and adding it to the unconditional prediction, we amplify the effect of the condition.

**Guidance Scale Behavior**:
- $w = 0$: Pure unconditional generation (ignores the label completely)
- $w = 1$: Standard conditional generation (no guidance amplification)
- $w > 1$: Stronger adherence to the condition (higher values = stronger guidance, but can sometimes reduce diversity)
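
As a quick numerical check of the two equivalent forms and the special cases listed above, here is a minimal sketch. The helper name `apply_guidance` and the random tensors standing in for the two model outputs are assumptions made for illustration; they are not part of the notebook's model code.

```python
import torch

def apply_guidance(eps_uncond, eps_cond, w):
    """Blend unconditional and conditional noise predictions with scale w."""
    return eps_uncond + w * (eps_cond - eps_uncond)

torch.manual_seed(0)
# Random tensors stand in for the two model outputs (illustrative only)
eps_uncond = torch.randn(4, 1, 32, 32)
eps_cond = torch.randn(4, 1, 32, 32)

w = 3.0
guided = apply_guidance(eps_uncond, eps_cond, w)
# Second form of the formula: (1 - w) * uncond + w * cond
rewritten = (1 - w) * eps_uncond + w * eps_cond

print("Both forms agree:", torch.allclose(guided, rewritten, atol=1e-4))
print("w=0 gives uncond:", torch.allclose(apply_guidance(eps_uncond, eps_cond, 0.0), eps_uncond))
print("w=1 gives cond:  ", torch.allclose(apply_guidance(eps_uncond, eps_cond, 1.0), eps_cond, atol=1e-5))
```

For $w > 1$ the result lies beyond the conditional prediction along the line from the unconditional one, which is the "amplification" described above.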

The beauty of this approach is that we don't need a separate classifier or any additional models, as everything is learned within a single network.

Reference paper: [_Classifier-Free Diffusion Guidance_](https://arxiv.org/abs/2207.12598)

### Step 1: Imports and Setup

We'll reuse the same libraries and utilities from our previous image generation notebook.

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import torchvision
import torchvision.transforms as transforms

In [None]:
def sample_t(x_0, t, alpha_bars):
    epsilon = torch.randn_like(x_0)
    x_t = torch.sqrt(alpha_bars[t]) * x_0 + torch.sqrt(1-alpha_bars[t]) * epsilon
    
    return x_t, epsilon

class SinusoidalEmbedding(nn.Module):
    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        
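    def forward(self, t):
        device = t.device
        half_dim = self.embedding_dim // 2
        emb = torch.zeros(t.shape[0], self.embedding_dim, device=device)

        # Frequencies omega_i = 10000^(-2i / embedding_dim), computed for all i at once
        i = torch.arange(half_dim, device=device, dtype=torch.float32)
        const = torch.tensor(10000.0, device=device)
        omega = torch.exp(-(2 * i / self.embedding_dim) * torch.log(const))

        # (batch, 1) * (1, half_dim) -> (batch, half_dim)
        angles = t[:, None].float() * omega[None, :]
        emb[:, 0:2 * half_dim:2] = torch.sin(angles)  # even columns
        emb[:, 1:2 * half_dim:2] = torch.cos(angles)  # odd columns

        return emb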

In [None]:
class Block(nn.Module):
    def __init__(self, in_ch, out_ch, time_emb_dim):
        super().__init__()
        
        # 1. First Convolution: Change channels (e.g. 1 -> 32)
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.relu = nn.ReLU()
        
        # 2. Time Projection: Map time to match 'out_ch'
        self.time_mlp = nn.Linear(time_emb_dim, out_ch)
        
        # 3. Second Convolution: Refine features
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_ch)
        
    def forward(self, x, t):
        # First Conv
        h = self.conv1(x)
        h = self.bn1(h)
        h = self.relu(h)
        
        # Add Time Embedding
        # (Batch, Time_Dim) -> (Batch, Out_Ch) -> (Batch, Out_Ch, 1, 1)
        time_emb = self.time_mlp(t)
        time_emb = time_emb[(..., ) + (None, ) * 2] # Broadcast to 4D
        h = h + time_emb
        
        # Second Conv
        h = self.conv2(h)
        h = self.bn2(h)
        h = self.relu(h)
        
        return h

Now we define the full U-Net architecture. The key modification for classifier-free guidance is in the label embedding: we use **11 classes** instead of 10. The 11th class (index 10) represents the "null" or "unconditional" token that we'll use during training when we drop the label.
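
To make the role of the extra index concrete, the sketch below looks up embeddings for two real labels and for the null token. `NUM_CLASSES` and `NULL_LABEL` are names introduced here for illustration; the model in this notebook hard-codes the values 11 and 10 instead.

```python
import torch
import torch.nn as nn

NUM_CLASSES = 10          # digits 0-9
NULL_LABEL = NUM_CLASSES  # index 10 is reserved for "no condition"

# One extra row in the embedding table holds the learned null embedding
label_embedding = nn.Embedding(NUM_CLASSES + 1, 32)

labels = torch.tensor([3, 7, NULL_LABEL])
emb = label_embedding(labels)
print(emb.shape)  # one 32-dimensional vector per label, including the null token
```

With only 10 rows in the table, looking up index 10 would raise an index error, which is why the embedding needs 11 entries.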


In [None]:
class CFGUnet(nn.Module):
    def __init__(self):
        super().__init__()
        # Time Embedding
        self.time_mlp = nn.Sequential(
            SinusoidalEmbedding(32),
            nn.Linear(32, 32),
            nn.ReLU()
        )
        
        self.label_embedding = nn.Embedding(11, 32)
        
        # 1. Down Path
        self.down1 = Block(1, 32, 32)
        self.down2 = Block(32, 64, 32)
        self.down3 = Block(64, 128, 32)
        
        self.pool = nn.MaxPool2d(2)
        
        # 2. Bottleneck
        self.bottleneck = Block(128, 256, 32)
        
        # 3. Up Path
        # We separate the Upsampling (ConvTrans) from the Block (Processing)
        
        # Up 1: 4x4 -> 8x8
        self.up_trans1 = nn.ConvTranspose2d(256, 256, 4, 2, 1) 
        self.up1 = Block(256 + 128, 128, 32) # In: Bottle + Skip(x3)
        
        # Up 2: 8x8 -> 16x16
        self.up_trans2 = nn.ConvTranspose2d(128, 128, 4, 2, 1)
        self.up2 = Block(128 + 64, 64, 32)   # In: Prev + Skip(x2)
        
        # Up 3: 16x16 -> 32x32
        self.up_trans3 = nn.ConvTranspose2d(64, 64, 4, 2, 1)
        self.up3 = Block(64 + 32, 32, 32)    # In: Prev + Skip(x1)
        
        # Final projection
        self.final = nn.Conv2d(32, 1, 3, padding=1)

    def forward(self, x, t, label):
        t = self.time_mlp(t)
        label_emb = self.label_embedding(label)
        
        t = t + label_emb
        
        # --- Down Path ---
        x1 = self.down1(x, t)        # (32, 32, 32)
        x_p1 = self.pool(x1)         # (32, 16, 16)
        
        x2 = self.down2(x_p1, t)     # (64, 16, 16)
        x_p2 = self.pool(x2)         # (64, 8, 8)
        
        x3 = self.down3(x_p2, t)     # (128, 8, 8)
        x_p3 = self.pool(x3)         # (128, 4, 4)
        
        # --- Bottleneck ---
        x = self.bottleneck(x_p3, t) # (256, 4, 4)
        
        # --- Up Path ---
        
        # Step 1: Upsample -> Concat -> Process
        x = self.up_trans1(x)                          # (256, 8, 8)
        x = self.up1(torch.cat((x, x3), dim=1), t)     # (128, 8, 8)
        
        # Step 2
        x = self.up_trans2(x)                          # (128, 16, 16)
        x = self.up2(torch.cat((x, x2), dim=1), t)     # (64, 16, 16)
        
        # Step 3
        x = self.up_trans3(x)                          # (64, 32, 32)
        x = self.up3(torch.cat((x, x1), dim=1), t)     # (32, 32, 32)
        
        return self.final(x)

We'll use the same MNIST dataset and data loading function as before. The dataset provides images and their corresponding digit labels (0-9), which we'll use for conditional generation.


In [None]:
def load_mnist(image_size=32, batch_size=128, device='cpu'):
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.ToTensor(),
        transforms.Lambda(lambda t: (t * 2) - 1)
    ])
    dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    
    return dataset, dataloader

In [None]:
IMAGE_SIZE = 32
BATCH_SIZE = 128

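# Prefer CUDA, then Apple MPS, and fall back to CPU
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")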

dataset, dataloader = load_mnist(image_size=IMAGE_SIZE, batch_size=BATCH_SIZE, device=device)

We'll also need our helper function to add noise to images during training:


In [None]:
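def get_noisy_image(x_0, t):
    # Closed-form forward diffusion for a batch of images (B, C, H, W).
    # Relies on the global `alpha_bars` tensor, which is defined in the
    # training setup cell below and must exist before this function is called.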
    sqrt_alpha_bar = torch.sqrt(alpha_bars[t])[:, None, None, None]
    sqrt_one_minus_alpha_bar = torch.sqrt(1 - alpha_bars[t])[:, None, None, None]
    
    epsilon = torch.randn_like(x_0)
    x_t = sqrt_alpha_bar * x_0 + sqrt_one_minus_alpha_bar * epsilon
    return x_t, epsilon


### Step 2: Training Setup

Now we configure the training parameters. The key addition here is `LABEL_DROPOUT`, which controls the probability that we'll replace a real label with the null token (index 10) during training. This is what allows the model to learn both conditional and unconditional generation.

A typical value for label dropout is around 0.1-0.2, meaning 10-20% of training samples will be unconditional. This ensures the model sees enough unconditional examples to learn that mode, while still primarily learning conditional generation.
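
Label dropout can be sketched in isolation to confirm that roughly the expected share of labels ends up as the null token. The helper `drop_labels` and the constant `NULL_LABEL` are hypothetical names used only in this example, and the random labels stand in for a real MNIST batch.

```python
import torch

LABEL_DROPOUT = 0.1
NULL_LABEL = 10  # index of the null token (matches the 11-entry embedding)

def drop_labels(labels, p=LABEL_DROPOUT):
    """Replace each label with NULL_LABEL independently with probability p."""
    drop_mask = torch.rand(labels.shape[0], device=labels.device) < p
    return torch.where(drop_mask, torch.full_like(labels, NULL_LABEL), labels)

torch.manual_seed(0)
labels = torch.randint(0, 10, (10_000,))  # stand-in for real digit labels
dropped = drop_labels(labels)

frac_null = (dropped == NULL_LABEL).float().mean().item()
print(f"Fraction replaced by the null token: {frac_null:.3f}")
```

The printed fraction should land close to `LABEL_DROPOUT`; labels that were not dropped are left unchanged.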


In [None]:
TIMESTEPS = 300
EPOCHS = 5 

LABEL_DROPOUT = 0.1 # Probability of dropping label conditioning

# --- 3. Model & Utils ---
model = CFGUnet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.MSELoss()

# Recalculate alphas/betas for the image process
betas = torch.linspace(1e-4, 0.02, TIMESTEPS).to(device)
alphas = 1 - betas
alpha_bars = torch.cumprod(alphas, dim=0)

In [None]:
# --- 4. Training Loop ---
print("Starting Training...")
model.train()

for epoch in range(EPOCHS):
    epoch_losses = []
    for step, (images, labels) in enumerate(dataloader):
        images = images.to(device)
        labels = labels.to(device)
        
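        # Apply label dropout: replace each label with the null token (index 10)
        # independently with probability LABEL_DROPOUT (vectorized, no Python loop)
        drop_mask = torch.rand(labels.shape[0], device=device) < LABEL_DROPOUT
        labels = torch.where(drop_mask, torch.full_like(labels, 10), labels)
        
        # 1. Sample timesteps (one per image in the batch)
        t = torch.randint(0, TIMESTEPS, (images.shape[0],), device=device).long()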
        
        # 2. Add Noise
        x_t, epsilon = get_noisy_image(images, t)
        
        # 3. Predict Noise
        pred_epsilon = model(x_t, t, labels)
        
        # 4. Optimize
        loss = loss_fn(pred_epsilon, epsilon)
        epoch_losses.append(loss.item())
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if step % 100 == 0:
            print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")
    print(f"Average Epoch {epoch} Loss: {sum(epoch_losses) / len(epoch_losses):.4f}")
print("Training Complete!")

### Step 3: Sampling with Classifier-Free Guidance

Now comes the exciting part: using classifier-free guidance during generation! The sampling function works as follows:

1. **Start with pure noise** for the images we want to generate
2. **For each timestep**, we need to compute both conditional and unconditional predictions:
   - We duplicate the noisy images and timesteps
   - We create two sets of labels: the real labels (0-9) and null labels (all 10s)
   - We pass everything through the model in one batch to get both predictions
3. **Apply the guidance formula**: We combine the predictions using the guidance scale
4. **Update the images** using the combined prediction, just like in standard reverse diffusion

Batching the two predictions together means we call the model once per timestep instead of twice. Note that the batch is twice as large, so each step still costs roughly as much as two forward passes. The benefit is simpler code and typically better hardware utilization, not fewer computations.
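
The concatenate, single call, split pattern from step 2 can be sketched on its own. `ToyEps` is a hypothetical stand-in for the trained network (it only shares the `(x, t, label)` call signature), and `guided_eps` is a helper name introduced for this example.

```python
import torch
import torch.nn as nn

class ToyEps(nn.Module):
    """Hypothetical stand-in for the U-Net: same call signature, trivial body."""
    def __init__(self):
        super().__init__()
        self.label_embedding = nn.Embedding(11, 1)

    def forward(self, x, t, label):
        # Shift the input by a per-label scalar so the two halves differ
        return x + self.label_embedding(label)[:, :, None, None]

@torch.no_grad()
def guided_eps(model, x, t, labels, w, null_label=10):
    # Duplicate inputs: first half uses real labels, second half the null label
    x_in = torch.cat([x, x], dim=0)
    t_in = torch.cat([t, t], dim=0)
    labels_in = torch.cat([labels, torch.full_like(labels, null_label)], dim=0)
    # One model call, then split back into conditional / unconditional halves
    cond, uncond = model(x_in, t_in, labels_in).chunk(2, dim=0)
    return uncond + w * (cond - uncond)

toy_model = ToyEps().eval()
x = torch.randn(10, 1, 32, 32)
t = torch.full((10,), 5, dtype=torch.long)
labels = torch.arange(10)

eps = guided_eps(toy_model, x, t, labels, w=2.0)
print(eps.shape)  # same shape as x: the doubled batch is only used internally
```

In eval mode the batched call should give the same result as two separate calls, since each sample is processed independently.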


In [None]:
@torch.no_grad()
def sample_mnist(model):
    model.eval()
    n_samples = 10 # We want exactly 10 digits (0-9)
    
    # 1. Start with pure noise for 10 images
    x = torch.randn(n_samples, 1, IMAGE_SIZE, IMAGE_SIZE).to(device)
    
    # Generate specific labels 0 through 9
    labels = torch.arange(10, device=device).long()
    
    # Create the Null labels (all 10s)
    null_labels = torch.full((n_samples,), 10, device=device).long()
    
    # Combine labels once (0..9, 10..10)
    labels_in = torch.cat((labels, null_labels), dim=0)
    
    # 2. Loop backwards
    for t in range(TIMESTEPS - 1, -1, -1):
        # --- PREPARE INPUTS ---
        # Double the noise x just for the model pass
        x_in = torch.cat([x, x], dim=0)
        
        # Double the timestep t
        t_batch = torch.full((n_samples,), t, device=device, dtype=torch.long)
        t_in = torch.cat((t_batch, t_batch), dim=0)
        
        # --- MODEL PASS ---
        # We feed in batch of 20
        predicted_noise = model(x_in, t_in, labels_in)
        
        # --- CLASSIFIER FREE GUIDANCE ---
        # Split the output back into size 10
        cond_pred = predicted_noise[:n_samples]
        uncond_pred = predicted_noise[n_samples:]
        
        # Combine using the formula
        # GUIDANCE_SCALE is read from a global variable that must be set before calling sample_mnist
        epsilon = uncond_pred + GUIDANCE_SCALE * (cond_pred - uncond_pred)
        
        # --- UPDATE STEP ---
        # Get constants reshaped to (1, 1, 1, 1) for broadcasting
        alpha_t = alphas[t].view(1, 1, 1, 1)
        alpha_bar_t = alpha_bars[t].view(1, 1, 1, 1)
        beta_t = betas[t].view(1, 1, 1, 1)
        sigma_t = torch.sqrt(beta_t)
        
        if t > 0:
            z = torch.randn_like(x)
        else:
            z = torch.zeros_like(x)
        
        # Update the original x (size 10) using the combined epsilon (size 10)
        x = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / torch.sqrt(1 - alpha_bar_t)) * epsilon) + sigma_t * z
        
    return x
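
For reference, the update line inside the loop corresponds to the standard DDPM reverse step with the guided noise estimate $\epsilon_{final}$ substituted for the model output, using the variance choice $\sigma_t^2 = \beta_t$ made in the code:

$$x_{t-1} = \frac{1}{\sqrt{\alpha_t}}\left(x_t - \frac{1-\alpha_t}{\sqrt{1-\bar{\alpha}_t}}\,\epsilon_{final}\right) + \sigma_t z, \qquad z \sim \mathcal{N}(0, I)$$

At the final step ($t = 0$ in the code's zero-based indexing) the noise term is dropped, i.e. $z = 0$.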

### Step 4: Generating Images with Guidance

Let's test our trained model! We'll generate one of each digit (0-9) using a guidance scale of 2.0. This means we're amplifying the conditional signal by a factor of 2, which should result in stronger adherence to the specified labels compared to standard conditional generation ($w=1.0$).


In [None]:

GUIDANCE_SCALE = 2.0
# --- Run the Sampling ---
print("Sampling from the model...")
generated_images = sample_mnist(model)

# --- Visualization ---
# Un-normalize from [-1, 1] back to [0, 1]
generated_images = (generated_images + 1) / 2
generated_images = generated_images.clamp(0, 1).cpu()

# Plot as a grid
fig, axes = plt.subplots(2, 5, figsize=(12, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(generated_images[i, 0], cmap='gray')
    ax.axis('off')
plt.suptitle("Generated Digits")
plt.show()

Look at the generated digits above. Compared with plain conditional sampling ($w = 1$, no guidance), the guided samples should follow their labels more closely and look cleaner.

### Exploring Different Guidance Scales

One of the most interesting aspects of classifier-free guidance is how the guidance scale affects generation quality. Let's generate images at multiple guidance scales to see the effect:

- **$w = 0.0$**: Pure unconditional generation (the model ignores labels completely)
- **$w = 1.0$**: Standard conditional generation (no guidance amplification)
- **$w = 2.0, 3.0, 4.0$**: Increasingly stronger guidance

As we increase the guidance scale, we typically see:
- **Stronger adherence** to the specified labels
- **Potentially higher quality** samples (the model is more "confident" in following the condition)
- **Reduced diversity** (samples concentrate on the most typical versions of each digit)

However, guidance scales that are too high can lead to artifacts or over-saturated images. The best value depends on the task and model; values between 2 and 4 are a reasonable starting range for this notebook.
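
The diversity trade-off can be illustrated on a toy problem unrelated to the MNIST model: a 1D mixture of two unit-variance Gaussians centered at -2 and +2, where the "label" selects the right-hand component. At a single noise level, combining scores as uncond + $w$ (cond - uncond) corresponds to the density $p(x)^{1-w}\,p(x \mid y)^{w}$ up to normalization, so we can evaluate it on a grid and watch the spread change with $w$. The mixture, the grid, and all function names below are assumptions chosen for illustration.

```python
import math

def normal_pdf(x, mean):
    return math.exp(-0.5 * (x - mean) ** 2) / math.sqrt(2 * math.pi)

def p_uncond(x):
    # Equal mixture of N(-2, 1) and N(+2, 1)
    return 0.5 * normal_pdf(x, -2.0) + 0.5 * normal_pdf(x, 2.0)

def p_cond(x):
    # Conditioning on the "label" selects the right-hand component
    return normal_pdf(x, 2.0)

def guided_mean_std(w, lo=-10.0, hi=10.0, n=4001):
    """Mean and std of the density proportional to p(x)^(1-w) * p(x|y)^w on a grid."""
    xs = [lo + (hi - lo) * i / (n - 1) for i in range(n)]
    weights = [p_uncond(x) ** (1 - w) * p_cond(x) ** w for x in xs]
    total = sum(weights)
    mean = sum(x * wt for x, wt in zip(xs, weights)) / total
    var = sum((x - mean) ** 2 * wt for x, wt in zip(xs, weights)) / total
    return mean, math.sqrt(var)

for w in [0.0, 1.0, 2.0, 4.0]:
    mean, std = guided_mean_std(w)
    print(f"w={w}: mean={mean:.3f}, std={std:.3f}")
```

With $w = 0$ the full mixture is recovered (both modes, large spread), with $w = 1$ the conditional component, and for $w > 1$ the mass shifts further away from the other mode while the spread narrows somewhat. This is the same trade-off between label adherence and diversity described above, in miniature.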


In [None]:
guidance_scales = [0.0, 1.0, 2.0, 3.0, 4.0]
all_generated_images = []

for scale in guidance_scales:
    GUIDANCE_SCALE = scale  # Assuming GUIDANCE_SCALE is a global variable used in sample_mnist
    print(f"Sampling from the model with GUIDANCE_SCALE={scale}...")
    images = sample_mnist(model)
    images = (images + 1) / 2
    images = images.clamp(0, 1).cpu()
    all_generated_images.append(images)

# Plot each guidance scale's samples in a row (one column per digit)
fig, axes = plt.subplots(len(guidance_scales), 10, figsize=(16, 5))
for row, (scale, images) in enumerate(zip(guidance_scales, all_generated_images)):
    for col in range(10):
        ax = axes[row, col] if len(guidance_scales) > 1 else axes[col]
        ax.imshow(images[col, 0], cmap='gray')
        # Hide ticks but keep the axis on: ax.axis('off') would also hide the row label
        ax.set_xticks([])
        ax.set_yticks([])
    axes[row, 0].set_ylabel(f"Scale {scale}", fontsize=12)

plt.suptitle("Generated Digits at Different Guidance Scales", fontsize=16)
plt.tight_layout(rect=[0, 0.02, 1, 0.95])
plt.show()


## Final Thoughts

Classifier-free guidance is one of the most important techniques in modern diffusion models. It was a key innovation that made models like DALL-E 2 and Stable Diffusion so effective at following text prompts and other conditions.

**Key Concepts:**
- **Dual Training**: The model learns both conditional and unconditional generation by randomly dropping conditions during training
- **Guidance Formula**: At inference, we blend conditional and unconditional predictions to amplify the effect of the condition
- **Single Model**: Unlike other guidance methods, CFG doesn't require a separate classifier, as everything is learned in one network
- **Guidance Scale**: A hyperparameter that controls how strongly the model follows the condition (higher = stronger adherence, but potentially less diverse)

**What Makes This Powerful:**
The elegance of classifier-free guidance lies in its simplicity. By training the model to handle both conditional and unconditional cases, we get a natural way to control generation strength at inference time. The guidance mechanism effectively "steers" the generation in the direction specified by the condition, without requiring any additional models or complex training procedures.

**Why It Works:**
The difference $(\epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t))$ represents the "signal" that the condition provides. When we scale this difference and add it to the unconditional prediction, we're essentially asking: "What would the unconditional model do, plus an amplified version of how the condition changes that prediction?" This results in generation that follows the condition more strongly while maintaining the quality learned from unconditional training.
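
One way to make this precise, under the standard assumption that the noise prediction is a scaled score estimate, $\epsilon_\theta(x_t, t) \approx -\sqrt{1-\bar{\alpha}_t}\,\nabla_{x_t}\log p(x_t)$, is through Bayes' rule:

$$\nabla_{x_t}\log p(y \mid x_t) = \nabla_{x_t}\log p(x_t \mid y) - \nabla_{x_t}\log p(x_t)$$

Multiplying by $-\sqrt{1-\bar{\alpha}_t}$ gives the difference term, and substituting it into the guidance formula gives:

$$\epsilon_\theta(x_t, t, y) - \epsilon_\theta(x_t, t) \approx -\sqrt{1-\bar{\alpha}_t}\,\nabla_{x_t}\log p(y \mid x_t)$$

$$\epsilon_{final} \approx -\sqrt{1-\bar{\alpha}_t}\left[\nabla_{x_t}\log p(x_t) + w\,\nabla_{x_t}\log p(y \mid x_t)\right]$$

The difference term therefore plays the role of the gradient of an implicit classifier, and $w$ scales how strongly that implicit classifier pulls the sample toward the label.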

**Next Steps:**
This notebook covered classifier-free guidance for discrete labels, but the same principle extends to:
- **Text conditioning**: Using text embeddings (like CLIP) instead of discrete labels
- **Multi-conditioning**: Combining multiple conditions (e.g., text + style + class)
