## Project Introduction
In this project, we will look into Score Matching method of training an energy based model.

In [22]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt
import os
import sys
import math
from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


Using device: cuda


In [23]:
class NoisyCIFAR10(torch.utils.data.Dataset):
    def __init__(self, sigma_levels, train=True):
        super().__init__()
        self.data = datasets.CIFAR10(
            root="./data",
            train=train,
            download=True,
            transform=transforms.ToTensor()
        )
        self.sigma_levels = torch.tensor(sigma_levels).float()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        x, _ = self.data[idx]

        sigma = self.sigma_levels[torch.randint(0, len(self.sigma_levels), (1,))].item()

        noise = torch.randn_like(x) * sigma
        x_noisy = x + noise

        return x, x_noisy, sigma


In [47]:
class GaussianFourierProjection(torch.nn.Module):
    """Encode log(sigma) using random Fourier features."""
    def __init__(self, embedding_size=128, scale=1.0):
        super().__init__()
        self.W = torch.randn(embedding_size // 2, dtype=torch.float32) * scale

    def forward(self, sigma):
        # sigma: [B]
        sigma = sigma.view(-1, 1).float()  # Explicit float32 conversion
        x_proj = sigma * self.W.to(sigma.device) * 2 * math.pi
        return torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)

In [49]:
class ScoreNet(nn.Module):
    def __init__(self, embedding_size=128):
        super().__init__()
        self.embedding = GaussianFourierProjection(embedding_size)

        self.cond_proj = nn.Sequential(
            nn.Linear(embedding_size, 128),
            nn.ReLU()
        )

        self.net = nn.Sequential(
            nn.Conv2d(3 + 1, 64, 3, padding=1),  # 3 image channels + 1 noise channel
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(512, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(256, 128, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(128, 64, 3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 3, 3, padding=1)
        )

    def forward(self, x_noisy, sigma):
        emb = self.embedding(torch.log(sigma))
        cond = self.cond_proj(emb)

        cond = cond.view(-1, 128, 1, 1).expand(-1, 128, x_noisy.shape[2], x_noisy.shape[3])

        sigma_map = sigma.view(-1, 1, 1, 1).expand(-1, 1, x_noisy.shape[2], x_noisy.shape[3])
        input = torch.cat([x_noisy, sigma_map], dim=1)  # [B, 4, H, W]

        return self.net(input)


In [50]:
def dsm_loss(score_model, x, x_noisy, sigma):
    """
    x: clean image [B, 3, H, W]
    x_noisy: noisy image [B, 3, H, W]
    sigma: noise level [B]
    """
    # Ensure all inputs are float32
    x = x.float()
    x_noisy = x_noisy.float()
    sigma = sigma.float()

    z = (x_noisy - x) / sigma.view(-1, 1, 1, 1)  # [B, 3, 32, 32]

    # Model prediction
    score_pred = score_model(x_noisy, sigma)

    target = -z
    
    loss = ((score_pred - target) ** 2).sum(dim=(1, 2, 3))  # per-sample loss
    return loss.mean()  # average over batch

In [51]:
@torch.no_grad()
def annealed_langevin_sample(score_model, sigmas, num_steps=20, step_size=0.01, shape=(16, 3, 32, 32)):
    score_model.eval()
    x = torch.randn(*shape).to(device)

    for sigma in reversed(sigmas):  # large noise → small noise
        sigma = torch.tensor([sigma] * shape[0], device=device)

        for _ in range(num_steps):
            x.requires_grad = True
            score = score_model(x, sigma)
            x = x + 0.5 * step_size**2 * score
            x = x + step_size * torch.randn_like(x)
            x = x.detach()

    return x.clamp(0, 1)  # Clamp to [0, 1] range for display


def show_samples(x, nrow=4):
    x = x.cpu()
    fig, axes = plt.subplots(nrow, nrow, figsize=(6, 6))
    for i, ax in enumerate(axes.flat):
        img = x[i].permute(1, 2, 0).numpy()
        ax.imshow(img)
        ax.axis("off")
    plt.tight_layout()
    plt.show()



In [52]:
SIGMA_LEVELS = torch.exp(torch.linspace(math.log(0.01), math.log(1.0), 10))
num_epochs = 5000

train_dataset = NoisyCIFAR10(sigma_levels=SIGMA_LEVELS, train=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True, num_workers=2)

  self.sigma_levels = torch.tensor(sigma_levels).float()


In [53]:
model = ScoreNet().to(device)
# Make sure all model parameters are float32
for param in model.parameters():
    param.data = param.data.float()
    
optimizer = torch.optim.Adam(model.parameters(), lr=2e-5)

for epoch in range(num_epochs):
    total_loss = 0
    for x, x_noisy, sigma in train_loader:
        x = x.to(device).float()  # Ensure float32
        x_noisy = x_noisy.to(device).float()  # Ensure float32
        sigma = sigma.to(device).float()  # Ensure float32

        # 1. Forward pass + DSM loss
        loss = dsm_loss(model, x, x_noisy, sigma)

        # 2. Backprop
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item() * x.size(0)

    avg_loss = total_loss / len(train_loader.dataset)
    
    if epoch % 100 == 0:
        print(f"Epoch {epoch+1}: DSM Loss = {avg_loss:.4f}")
        samples = annealed_langevin_sample(model, sigmas=SIGMA_LEVELS, shape=(16, 3, 32, 32))
        show_samples(samples)

KeyboardInterrupt: 