In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np

# Load MNIST
def _to_signed_unit_range(img):
    """Rescale an image tensor from [0, 1] to [-1, 1]."""
    return (img - 0.5) * 2


# Preprocessing: convert to a tensor, then rescale to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Lambda(_to_signed_unit_range),
])

# MNIST training split, downloaded into ./data on first use.
mnist = datasets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)
# Shuffled mini-batches of 128 images.
loader = DataLoader(dataset=mnist, batch_size=128, shuffle=True)

# Diffusion Hyperparameters
T = 1000  # number of noise steps
# Linear variance schedule from 1e-4 to 0.02. These stay NumPy arrays because
# the rest of the script reads per-step scalars from them.
beta = np.linspace(1e-4, 0.02, T)
alpha = 1. - beta
alpha_bar = np.cumprod(alpha)  # running product of alpha over the steps


# Forward process (q(x_t | x_0)): Adds noise
def q_sample(x_start, t, noise=None):
    """Return a noised version of ``x_start`` at diffusion step(s) ``t``.

    Computes ``sqrt(alpha_bar[t]) * x_start + sqrt(1 - alpha_bar[t]) * noise``.

    Args:
        x_start: Clean input tensor. When ``t`` holds one step per sample,
            the first dimension of ``x_start`` is treated as the batch.
        t: Either a single step (Python/NumPy int or 0-d tensor) applied to
            the whole input, or a 1-D integer tensor/array with one step per
            sample. Tensors may live on any device.
        noise: Optional noise tensor shaped like ``x_start``. Standard normal
            noise is drawn when omitted.

    Returns:
        A tensor with the same shape and device as ``x_start``.
    """
    if noise is None:
        noise = torch.randn_like(x_start)

    # `alpha_bar` is a NumPy array, so it cannot be indexed by a CUDA tensor;
    # bring the indices to host memory first.
    if torch.is_tensor(t):
        t = t.detach().cpu().numpy()

    coeff_dtype = x_start.dtype if x_start.is_floating_point() else torch.float32
    a_bar = torch.as_tensor(
        np.asarray(alpha_bar[t]), dtype=coeff_dtype, device=x_start.device
    )
    if a_bar.dim() > 0:
        # One coefficient per sample: reshape to (B, 1, ..., 1) so it
        # broadcasts over every non-batch dimension of `x_start`.
        a_bar = a_bar.view(-1, *([1] * (x_start.dim() - 1)))

    return torch.sqrt(a_bar) * x_start + torch.sqrt(1 - a_bar) * noise

# Simple model for denoising (like a baby U-Net)
class DenoiseModel(nn.Module):
    """Small convolutional noise predictor conditioned on the diffusion step.

    The network is a stack of three 3x3 convolutions (1 -> 32 -> 32 -> 1
    channels) that preserves the spatial size of its input. The timestep is
    turned into a sinusoidal embedding, passed through a small MLP and added
    as a per-channel bias after the first convolution, so the prediction can
    depend on the noise level.

    Args:
        time_dim: Size of the sinusoidal timestep embedding. Must be a
            positive even number.

    Raises:
        ValueError: If ``time_dim`` is not a positive even number.
    """

    def __init__(self, time_dim=32):
        super().__init__()
        if time_dim < 2 or time_dim % 2 != 0:
            raise ValueError("time_dim must be a positive even number")
        hidden = 32
        # The conv stack keeps its original layout, so its parameter names
        # (net.0, net.2, net.4) are unchanged.
        self.net = nn.Sequential(
            nn.Conv2d(1, hidden, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, hidden, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(hidden, 1, 3, padding=1)
        )
        self.time_dim = time_dim
        # Maps the sinusoidal embedding to one bias value per hidden channel.
        self.time_mlp = nn.Sequential(
            nn.Linear(time_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
        )

    def _timestep_embedding(self, t, device):
        """Return sin/cos features of shape (N, time_dim) for timestep(s) ``t``."""
        steps = torch.as_tensor(t, device=device).float().reshape(-1, 1)
        half = self.time_dim // 2
        # Geometrically spaced frequencies from 1 down towards 1/10000.
        exponent = torch.arange(half, dtype=torch.float32, device=device) / half
        freqs = torch.exp(-float(np.log(10000.0)) * exponent)
        angles = steps * freqs
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)

    def forward(self, x, t_emb=None):
        """Predict the noise contained in ``x``.

        Args:
            x: Input batch with a single channel, shape (B, 1, H, W).
            t_emb: Diffusion step(s) for the batch: an int, a 0-d tensor, or
                a tensor with one entry per sample. When ``None`` the
                timestep conditioning is skipped and only the conv stack runs.

        Returns:
            A tensor with the same shape as ``x``.
        """
        if t_emb is None:
            return self.net(x)

        emb = self._timestep_embedding(t_emb, x.device).to(x.dtype)
        time_bias = self.time_mlp(emb).view(-1, self.net[0].out_channels, 1, 1)
        # Inject the timestep after the first convolution, then run the
        # remaining layers of the stack.
        hidden = self.net[0](x) + time_bias
        return self.net[1:](hidden)

# Use the GPU when one is available; otherwise fall back to the CPU instead of
# failing inside `.cuda()`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = DenoiseModel().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# Mean squared error, used to compare predicted noise against the true noise.
loss_fn = nn.MSELoss()

# Training
epochs = 5
# Run on whatever device the model already lives on rather than assuming CUDA.
train_device = next(model.parameters()).device
# Keep the schedule as a float32 tensor on that device. The NumPy array cannot
# be indexed by a CUDA tensor, and its float64 values would promote the noisy
# images to double, which the float32 model rejects.
alpha_bar_tensor = torch.as_tensor(alpha_bar, dtype=torch.float32, device=train_device)

for epoch in range(epochs):
    for images, _ in loader:
        images = images.to(train_device)
        # One random diffusion step per image in the batch.
        t = torch.randint(0, T, (images.size(0),), device=train_device).long()
        noise = torch.randn_like(images)

        # Forward process: mix each image with noise according to its step.
        a_bar = alpha_bar_tensor[t].view(-1, 1, 1, 1)
        x_noisy = torch.sqrt(a_bar) * images + torch.sqrt(1 - a_bar) * noise

        # The model is trained to recover the noise that was added.
        noise_pred = model(x_noisy, t)
        loss = loss_fn(noise_pred, noise)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Reports the loss of the last batch of the epoch.
    print(f"Epoch {epoch+1} Loss: {loss.item():.4f}")

# Reverse Sampling
@torch.no_grad()
def sample(model, n_samples=16):
    """Generate images by running the reverse diffusion process.

    Starts from standard normal noise of shape (n_samples, 1, 28, 28) and
    iterates from step ``T - 1`` down to ``0``, each time removing the noise
    predicted by ``model`` and, for every step except the last, adding fresh
    noise scaled by ``sqrt(beta[t])``.

    Args:
        model: Callable taking ``(x, t)`` where ``t`` is a long tensor of
            shape (n_samples,), and returning the predicted noise.
        n_samples: Number of images to generate.

    Returns:
        A tensor of shape (n_samples, 1, 28, 28) on the device used for
        sampling.
    """
    # Sample on the model's device instead of hard-coding CUDA. Callables
    # without parameters fall back to CUDA when available, otherwise the CPU.
    try:
        device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    x = torch.randn(n_samples, 1, 28, 28, device=device)
    for t in reversed(range(T)):
        t_tensor = torch.full((n_samples,), t, dtype=torch.long, device=device)
        predicted_noise = model(x, t_tensor)

        # Per-step scalars from the schedule, as plain Python floats.
        alpha_t = float(alpha[t])
        alpha_bar_t = float(alpha_bar[t])
        beta_t = float(beta[t])

        noise_scale = (1 - alpha_t) / np.sqrt(1 - alpha_bar_t)
        x = (1 / np.sqrt(alpha_t)) * (x - noise_scale * predicted_noise)
        if t > 0:
            # No noise is added on the final step.
            x += np.sqrt(beta_t) * torch.randn_like(x)
    return x

# Plot samples
# Generate a batch of digits and move it to host memory for matplotlib.
samples = sample(model).cpu().numpy()

# Show the first ten generated images side by side, without axes.
n_shown = 10
fig, axes = plt.subplots(nrows=1, ncols=n_shown, figsize=(15, 2))
for ax, digit in zip(axes, samples[:n_shown]):
    ax.imshow(digit[0], cmap='gray')  # index 0 selects the single channel
    ax.axis('off')
plt.suptitle("Stable Diffusion Generated Digits")
plt.show()
