Implementing a simple diffusion model for image generation involves several key steps:

* defining the noise schedule,
* constructing the UNet architecture,
* setting up the training and sampling processes.


**1. Prerequisites**

Ensure you have the necessary libraries installed:

```bash
pip install torch torchvision matplotlib
```

**2. Define the Noise Schedule**

The noise schedule determines how noise is added during the forward diffusion process. A common approach is to use a linear schedule.


In [6]:
import torch
import numpy as np

def linear_beta_schedule(timesteps, beta_start=0.0001, beta_end=0.02):
    """Build a linearly spaced beta (noise variance) schedule.

    Args:
        timesteps: Number of values to generate.
        beta_start: First value of the schedule.
        beta_end: Last value of the schedule.

    Returns:
        A 1-D tensor of ``timesteps`` values evenly spaced from
        ``beta_start`` to ``beta_end`` (both endpoints included).
    """
    schedule = torch.linspace(start=beta_start, end=beta_end, steps=timesteps)
    return schedule

# Forward-process schedule used by the training and sampling cells below.
timesteps = 1000
# Per-step noise variances.
betas = linear_beta_schedule(timesteps)
# Per-step signal-retention factors.
alphas = 1.0 - betas
# Running product of the alphas: entry t is alphas[0] * ... * alphas[t].
alpha_hats = alphas.cumprod(dim=0)

## 3. Construct the UNet Model

The UNet architecture is commonly used in diffusion models for its ability to capture multi-scale features.



In [7]:
import torch.nn as nn

class UNet(nn.Module):
    """Convolutional encoder-decoder with skip connections (U-Net).

    The encoder applies one double-convolution block per entry in
    ``features``, halving the spatial resolution with max-pooling after each
    block. The decoder mirrors it: every stage upsamples with a transposed
    convolution, concatenates the matching encoder output along the channel
    axis and applies another double-convolution block.

    NOTE: ``forward`` receives only the image tensor; no timestep information
    is fed to the network.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels of the output tensor.
        features: Channel widths of the encoder stages, from the highest
            resolution to the lowest. Must contain at least one entry.

    Raises:
        ValueError: If ``features`` is empty.
    """

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super().__init__()
        # An immutable default avoids sharing one list object between calls;
        # copying to a list accepts any sequence the caller passes.
        features = list(features)
        if not features:
            raise ValueError("features must contain at least one channel width")

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Encoder: one double-conv block per stage.
        for feature in features:
            self.encoder.append(self._block(in_channels, feature))
            in_channels = feature

        # Decoder: modules are stored in (upsample, double-conv) pairs, so
        # even indices hold transposed convolutions and odd indices hold the
        # block applied after concatenation with the skip connection.
        for feature in reversed(features):
            self.decoder.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.decoder.append(self._block(feature * 2, feature))

        self.bottleneck = self._block(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network on ``x`` and return a tensor of the same spatial size.

        Args:
            x: Input tensor laid out as (batch, in_channels, height, width).

        Returns:
            Tensor with ``out_channels`` channels and the spatial size of ``x``.
        """
        skip_connections = []

        for layer in self.encoder:
            x = layer(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Decoder stages consume the encoder outputs deepest-first.
        for stage, skip_connection in enumerate(reversed(skip_connections)):
            upsample = self.decoder[2 * stage]
            refine = self.decoder[2 * stage + 1]

            x = upsample(x)
            # Pooling floors odd sizes, so the upsampled map can be smaller
            # than the skip connection; only the spatial dims can differ.
            if x.shape[2:] != skip_connection.shape[2:]:
                x = torch.nn.functional.interpolate(x, size=skip_connection.shape[2:])
            x = torch.cat((skip_connection, x), dim=1)
            x = refine(x)

        return self.final_conv(x)

    def _block(self, in_channels, out_channels):
        """Return two 3x3 same-padding convolutions, each followed by ReLU."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )


## 4. Training Loop

Train the model to predict the noise added at each timestep.


In [8]:
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


In [9]:
# Training configuration
batch_size = 64
learning_rate = 1e-4
epochs = 100

# Preprocessing pipeline: resize to 64x64, convert to a tensor, then
# normalize the single channel with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.Resize(size=(64, 64)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

# MNIST training split, downloaded into ./data when not already present.
dataset = datasets.MNIST(root='data', train=True, download=True, transform=transform)
dataloader = DataLoader(dataset=dataset, shuffle=True, batch_size=batch_size)


## Model, optimizer, and loss function


In [10]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = UNet(in_channels=1, out_channels=1).to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()

# The schedule tensors are created on the CPU. Indexing a CPU tensor with a
# batch of timestep indices that lives on the GPU raises a RuntimeError, so
# build device-resident lookup tables once. This also hoists the square
# roots out of the inner loop.
sqrt_alpha_hats = torch.sqrt(alpha_hats).to(device)
sqrt_one_minus_alpha_hats = torch.sqrt(1.0 - alpha_hats).to(device)

# Training loop
# Make sure the model is in training mode even if this cell is re-run after
# sampling, which switches it to eval mode.
model.train()
for epoch in range(epochs):
    epoch_loss = 0.0
    num_batches = 0

    for images, _ in dataloader:
        images = images.to(device)
        # One random timestep per image (torch.randint already yields int64).
        t = torch.randint(0, timesteps, (images.size(0),), device=device)
        noise = torch.randn_like(images)
        # Forward diffusion in closed form:
        # x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * noise
        noisy_images = (
            sqrt_alpha_hats[t][:, None, None, None] * images
            + sqrt_one_minus_alpha_hats[t][:, None, None, None] * noise
        )

        optimizer.zero_grad()
        # NOTE: the model receives only the noisy images; t is not passed in.
        noise_pred = model(noisy_images)
        loss = criterion(noise_pred, noise)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than the last batch's loss;
    # max() keeps the line valid if the dataloader yields no batches.
    print(f'Epoch [{epoch+1}/{epochs}], Loss: {epoch_loss / max(num_batches, 1):.4f}')


Output: `KeyboardInterrupt` (the training cell above was interrupted manually).

## 5. Sampling (Image Generation)

Generate new images by starting from random noise and iteratively denoising.


In [None]:
import matplotlib.pyplot as plt

def sample(model, timesteps, image_size, device, *, n_samples=1, channels=1):
    """Generate images by iteratively denoising Gaussian noise.

    Starts from standard-normal noise and applies the reverse update once for
    every step from ``timesteps - 1`` down to ``0``, reading the coefficients
    from the module-level ``alphas``, ``alpha_hats`` and ``betas`` tensors.
    No extra noise is added on the final step (``t == 0``).

    The model is switched to eval mode and is left in eval mode on return.

    Args:
        model: Network called as ``model(x)`` to predict the noise in ``x``.
        timesteps: Number of reverse steps to run; also the highest index
            (minus one) read from the schedule tensors.
        image_size: Height and width of the square images to generate.
        device: Device on which the samples are created.
        n_samples: Number of images to generate in one batch.
        channels: Number of channels per generated image.

    Returns:
        Tensor of shape ``(n_samples, channels, image_size, image_size)``.
    """
    model.eval()
    with torch.no_grad():
        x = torch.randn((n_samples, channels, image_size, image_size), device=device)
        for t in reversed(range(timesteps)):
            # Fresh noise for every step except the last one.
            z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
            alpha = alphas[t]
            alpha_hat = alpha_hats[t]
            beta = betas[t]

            predicted_noise = model(x)
            # Remove the predicted noise, rescale, then add the step's noise.
            x = (
                1 / torch.sqrt(alpha) * (
                    x - (1 - alpha) / torch.sqrt(1 - alpha_hat) * predicted_noise
                ) + torch.sqrt(beta) * z
            )
    return x


## Generate and display an image

In [None]:
# Draw one 64x64 sample, move it to the CPU and drop the singleton
# dimensions so it can be shown as a 2-D grayscale image.
generated_image = sample(model, timesteps, 64, device)
generated_image = generated_image.cpu().squeeze()
plt.imshow(generated_image, cmap='gray')
plt.show()

