In [18]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

# Ensure that the necessary modules are imported
import torchvision
from torchvision import transforms
from tqdm import tqdm


In [19]:
# Pick the compute device: CUDA when PyTorch can see a GPU, CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cpu


In [20]:
# Define the U-Net model (from the code provided earlier)
class Unet(torch.nn.Module):
    """
    A simple U-Net architecture for MNIST that takes an input image and time.

    The image (1 channel) and the time value, broadcast to a constant 28x28
    plane, are stacked into a 2-channel input. Five encoder stages reduce the
    spatial size 28 -> 14 -> 7 -> 4 -> 2; five decoder stages mirror them and
    concatenate the matching encoder output (skip connection) before every
    stage except the first. The output has the same flattened shape as the
    image input.
    """

    def __init__(self):
        super().__init__()
        in_ch = 2  # 1 image channel + 1 broadcast time channel (see forward)
        chs = [32, 64, 128, 256, 256]
        # Encoder. The layer order inside each Sequential is part of the
        # state_dict key layout, so it must not be changed if existing
        # checkpoints are to remain loadable.
        self._convs = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_ch, chs[0], kernel_size=3, padding=1
                    ),  # (batch, 32, 28, 28)
                    torch.nn.LogSigmoid(),  # (batch, 32, 28, 28)
                ),
                torch.nn.Sequential(
                    torch.nn.MaxPool2d(kernel_size=2, stride=2),  # (batch, 32, 14, 14)
                    torch.nn.Conv2d(
                        chs[0], chs[1], kernel_size=3, padding=1
                    ),  # (batch, 64, 14, 14)
                    torch.nn.LogSigmoid(),  # (batch, 64, 14, 14)
                ),
                torch.nn.Sequential(
                    torch.nn.MaxPool2d(kernel_size=2, stride=2),  # (batch, 64, 7, 7)
                    torch.nn.Conv2d(
                        chs[1], chs[2], kernel_size=3, padding=1
                    ),  # (batch, 128, 7, 7)
                    torch.nn.LogSigmoid(),  # (batch, 128, 7, 7)
                ),
                torch.nn.Sequential(
                    # padding=1 so the odd 7x7 map pools to 4x4 instead of 3x3
                    torch.nn.MaxPool2d(
                        kernel_size=2, stride=2, padding=1
                    ),  # (batch, 128, 4, 4)
                    torch.nn.Conv2d(
                        chs[2], chs[3], kernel_size=3, padding=1
                    ),  # (batch, 256, 4, 4)
                    torch.nn.LogSigmoid(),  # (batch, 256, 4, 4)
                ),
                torch.nn.Sequential(
                    torch.nn.MaxPool2d(kernel_size=2, stride=2),  # (batch, 256, 2, 2)
                    torch.nn.Conv2d(
                        chs[3], chs[4], kernel_size=3, padding=1
                    ),  # (batch, 256, 2, 2)
                    torch.nn.LogSigmoid(),  # (batch, 256, 2, 2)
                ),
            ]
        )
        # Decoder. Every stage after the first receives its predecessor's
        # output concatenated with an encoder output, hence the `* 2` on the
        # input channel counts.
        self._tconvs = torch.nn.ModuleList(
            [
                torch.nn.Sequential(
                    # input is the output of convs[4]
                    torch.nn.ConvTranspose2d(
                        chs[4],
                        chs[3],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=1,
                    ),  # (batch, 256, 4, 4)
                    torch.nn.LogSigmoid(),
                ),
                torch.nn.Sequential(
                    # input is the output from the above sequential concatenated with the output from convs[3]
                    torch.nn.ConvTranspose2d(
                        chs[3] * 2,
                        chs[2],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=0,  # 4 -> 7 (odd target size)
                    ),  # (batch, 128, 7, 7)
                    torch.nn.LogSigmoid(),
                ),
                torch.nn.Sequential(
                    # input is the output from the above sequential concatenated with the output from convs[2]
                    torch.nn.ConvTranspose2d(
                        chs[2] * 2,
                        chs[1],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=1,
                    ),  # (batch, 64, 14, 14)
                    torch.nn.LogSigmoid(),
                ),
                torch.nn.Sequential(
                    # input is the output from the above sequential concatenated with the output from convs[1]
                    torch.nn.ConvTranspose2d(
                        chs[1] * 2,
                        chs[0],
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        output_padding=1,
                    ),  # (batch, 32, 28, 28)
                    torch.nn.LogSigmoid(),
                ),
                torch.nn.Sequential(
                    # input is the output from the above sequential concatenated with the output from convs[0]
                    torch.nn.Conv2d(
                        chs[0] * 2, chs[0], kernel_size=3, padding=1
                    ),  # (batch, 32, 28, 28)
                    torch.nn.LogSigmoid(),
                    torch.nn.Conv2d(
                        chs[0], 1, kernel_size=3, padding=1
                    ),  # (batch, 1, 28, 28)
                ),
            ]
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """
        Run the U-Net on flattened images and their time values.

        Args:
            x: flattened single-channel 28x28 images, shape (..., 28 * 28).
            t: one time value per image, shape (..., 1), with the same
                leading dimensions as ``x``.

        Returns:
            Tensor of shape (..., 28 * 28), the flattened single-channel
            output of the last decoder stage.
        """
        x2 = torch.reshape(x, (*x.shape[:-1], 1, 28, 28))  # (..., 1, 28, 28)
        # Broadcast the scalar time into a constant plane so it can be stacked
        # with the image as a second input channel.
        tt = t[..., None, None].expand(*t.shape[:-1], 1, 28, 28)  # (..., 1, 28, 28)
        signal = torch.cat((x2, tt), dim=-3)  # (..., 2, 28, 28)

        # Encoder: keep every output except the deepest one for the skips.
        signals = []
        last = len(self._convs) - 1
        for i, conv in enumerate(self._convs):
            signal = conv(signal)
            if i < last:
                signals.append(signal)

        # Decoder: stage i > 0 consumes the encoder output of matching
        # resolution, i.e. signals[-i] (deepest stored skip first).
        for i, tconv in enumerate(self._tconvs):
            if i > 0:
                signal = torch.cat((signal, signals[-i]), dim=-3)
            signal = tconv(signal)

        return torch.reshape(signal, (*signal.shape[:-3], -1))  # (..., 1 * 28 * 28)


In [None]:
# Define the noise schedule (consistent with training)
def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """
    Build a linearly spaced beta (noise variance) schedule.

    Args:
        timesteps: number of diffusion steps, i.e. the length of the result.
        beta_start: value of the first beta (default 1e-4).
        beta_end: value of the last beta (default 0.02).

    Returns:
        1-D tensor of ``timesteps`` values evenly spaced from ``beta_start``
        to ``beta_end`` inclusive.
    """
    return torch.linspace(beta_start, beta_end, timesteps)


T = 200  # Number of diffusion steps

# Derive every schedule-dependent table once, up front.
betas = linear_beta_schedule(timesteps=T)
alphas = 1.0 - betas

# Running product of alphas, plus the same series shifted right by one step
# (with 1.0 prepended as the value "before" the first step).
alphas_cumprod = alphas.cumprod(dim=0)
alphas_cumprod_prev = torch.cat(
    (torch.tensor([1.0], device=alphas.device), alphas_cumprod[:-1])
)

# Coefficients of x0 and of the noise in the closed-form forward process.
sqrt_alphas_cumprod = alphas_cumprod.sqrt()
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod).sqrt()

# Scaling factor applied in the reverse (sampling) update.
sqrt_recip_alphas = (1.0 / alphas).sqrt()

# Move all tables to the compute device in one go.
(
    betas,
    alphas,
    alphas_cumprod,
    alphas_cumprod_prev,
    sqrt_alphas_cumprod,
    sqrt_one_minus_alphas_cumprod,
    sqrt_recip_alphas,
) = (
    table.to(device)
    for table in (
        betas,
        alphas,
        alphas_cumprod,
        alphas_cumprod_prev,
        sqrt_alphas_cumprod,
        sqrt_one_minus_alphas_cumprod,
        sqrt_recip_alphas,
    )
)

# Load the MNIST dataset
transform = transforms.Compose(
    [
        transforms.ToTensor(),  # float tensor with values k / 255 in [0, 1]
        # Uniform dequantization: spread each of the 256 intensity levels over
        # its own bin, (k + u) / 256 with u ~ U[0, 1). Adding u / 255 to the
        # already-scaled value instead would overshoot the [0, 1] range.
        transforms.Lambda(lambda x: (x * 255.0 + torch.rand_like(x)) / 256.0),
        transforms.Lambda(lambda x: (x - 0.5) * 2.0),  # Map from [0, 1) -> [-1, 1)
        transforms.Lambda(lambda x: x.flatten()),  # One flat vector per image
    ]
)


train_dataset = torchvision.datasets.MNIST(
    root="./data", train=True, transform=transform, download=True
)
# drop_last=True keeps every batch at exactly 128 samples.
train_dataloader = torch.utils.data.DataLoader(
    train_dataset, batch_size=128, shuffle=True, drop_last=True
)

# Instantiate the model
model = Unet().to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Build the loss module once instead of re-creating it on every step.
criterion = nn.MSELoss()

# Training loop
num_epochs = 1  # Adjust as needed
model.train()
for epoch in range(num_epochs):
    print(f"Epoch {epoch+1}/{num_epochs}")
    # Accumulate on-device so the per-step loss does not force a host sync.
    running_loss = torch.zeros((), device=device)
    num_batches = 0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()
        x0, _ = batch  # the dataset transform already flattens: x0 is [B, 28*28]
        x0 = x0.to(device)
        B = x0.shape[0]
        x0 = x0.reshape(B, -1)  # No-op for flat input; also accepts [B, 1, 28, 28]

        # Sample one random timestep index in [0, T) per batch element
        t = torch.randint(0, T, (B,), device=device)

        # Noise to be mixed into x0 (randn_like inherits x0's device and dtype)
        noise = torch.randn_like(x0)

        # Closed-form forward process:
        #   x_t = sqrt(alphas_cumprod[t]) * x0 + sqrt(1 - alphas_cumprod[t]) * noise
        sqrt_alpha_t = sqrt_alphas_cumprod[t].unsqueeze(1)
        sqrt_one_minus_alpha_t = sqrt_one_minus_alphas_cumprod[t].unsqueeze(1)
        xt = sqrt_alpha_t * x0 + sqrt_one_minus_alpha_t * noise

        # The model takes time as a float in [0, 1) with shape [B, 1]
        t_norm = (t.float() / T).unsqueeze(1)

        # Predict the noise that was added and regress it onto the true noise
        predicted_noise = model(xt, t_norm)
        loss = criterion(predicted_noise, noise)
        loss.backward()
        optimizer.step()

        running_loss += loss.detach()
        num_batches += 1

    # Report the mean over the epoch rather than the last (noisy) batch loss;
    # max(..., 1) keeps this well-defined if the loader yields no batches.
    print(f"Loss at epoch {epoch+1}: {(running_loss / max(num_batches, 1)).item()}")


  model.load_state_dict(torch.load("ddpm_unet_mnist.pth", map_location=device))


Epoch 1/1


 12%|█▏        | 58/468 [00:21<02:28,  2.75it/s]


KeyboardInterrupt: 

In [10]:
# Persist only the learned parameters (the state_dict), not the module object
model_path = "ddpm_unet_mnist.pth"
trained_weights = model.state_dict()
torch.save(trained_weights, model_path)
print(f"Model saved as {model_path}")


Model saved as ddpm_unet_mnist.pth


In [11]:
@torch.no_grad()
def p_sample(model, x, t):
    """
    Perform one reverse-diffusion step: sample x_{t-1} given x_t.

    Relies on the module-level schedule tables (``betas``,
    ``sqrt_one_minus_alphas_cumprod``, ``sqrt_recip_alphas``), on ``T`` and on
    ``device``.

    Args:
        model: noise-prediction network called as ``model(x, t_float)`` with
            ``t_float`` of shape [B, 1] holding ``t / T``.
        x: current samples x_t, shape [B, D].
        t: integer timestep index shared by the whole batch.

    Returns:
        Tensor with the same shape as ``x``: the posterior mean plus
        ``sqrt(beta_t)``-scaled Gaussian noise when ``t > 0``, or the mean
        alone when ``t`` is not positive.
    """
    # Keep a plain Python int for the branch below so it never has to read a
    # value back from the device.
    step = int(t)
    t_index = torch.tensor([step], device=device).long()
    t_float = (t_index.float() / T).expand(x.size(0), 1)  # Shape [B, 1]

    # Predict noise using the model
    eps_theta = model(x, t_float)

    # Per-step coefficients, each of shape [1, 1] so they broadcast over x
    beta_t = betas[t_index].unsqueeze(1)
    sqrt_one_minus_alphas_cumprod_t = sqrt_one_minus_alphas_cumprod[t_index].unsqueeze(1)
    sqrt_recip_alphas_t = sqrt_recip_alphas[t_index].unsqueeze(1)

    # Posterior mean: remove the predicted noise contribution and rescale
    x_pred = sqrt_recip_alphas_t * (
        x - beta_t / sqrt_one_minus_alphas_cumprod_t * eps_theta
    )

    # The final step is deterministic; every other step re-injects noise
    if step > 0:
        return x_pred + torch.sqrt(beta_t) * torch.randn_like(x)
    return x_pred


In [None]:
@torch.no_grad()
def p_sample_loop(model, shape):
    """
    Run the reverse diffusion process to generate samples.

    Args:
        model: noise-prediction network, passed through to ``p_sample``.
        shape: shape of the batch to generate, e.g. ``(num_samples, 28 * 28)``.

    Returns:
        Tensor of the requested shape on ``device``, obtained by starting from
        standard Gaussian noise and applying ``p_sample`` for
        t = T-1, T-2, ..., 0.
    """
    x = torch.randn(shape, device=device)  # Start from pure noise
    # Iterate over a descending range rather than reversed(range(T)): a range
    # has a len(), so tqdm can show the total, percentage and ETA.
    for t in tqdm(range(T - 1, -1, -1), desc="Sampling"):
        x = p_sample(model, x, t)
    return x


In [None]:
# Instantiate the model and load the trained weights
model = Unet().to(device)
# weights_only=True restricts unpickling to tensors and plain containers, so a
# tampered checkpoint file cannot execute arbitrary code while being loaded.
state_dict = torch.load(
    "ddpm_unet_mnist.pth", map_location=device, weights_only=True
)
model.load_state_dict(state_dict)
model.eval()  # Set the model to evaluation mode

In [None]:
# Generate samples
num_samples = 16  # Number of images to generate
img_shape = (num_samples, 28 * 28)  # Flattened image shape

# Run the sampling process
generated_images = p_sample_loop(model, img_shape)

# Reshape and denormalize the images
generated_images = generated_images.cpu().numpy()
generated_images = np.clip(generated_images, -1.0, 1.0)  # Ensure values are in [-1, 1]
generated_images = (generated_images + 1.0) / 2.0  # Scale to [0, 1]
generated_images = generated_images.reshape(num_samples, 28, 28)

# Derive a near-square grid from num_samples so changing the sample count
# does not require editing the plotting code (16 samples -> 4 x 4).
grid_cols = int(np.ceil(np.sqrt(num_samples)))
grid_rows = int(np.ceil(num_samples / grid_cols))

# Plot the images
fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(6, 6), squeeze=False)
for i, ax in enumerate(axes.flatten()):
    if i < num_samples:
        # Fixed vmin/vmax: the data is already in [0, 1]; without them imshow
        # would stretch each image to its own min/max.
        ax.imshow(generated_images[i], cmap="gray", vmin=0.0, vmax=1.0)
    ax.axis("off")  # Also blanks any unused cells in the last row
plt.tight_layout()
plt.show()