<a href="https://colab.research.google.com/github/kramerkraus/2155-CP3-mkraus/blob/main/diffusionmodel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


# --------------------------------------------------
# 1. Small MLP backbone (denoiser)
# --------------------------------------------------

class MLPDenoiser(nn.Module):
    """Small MLP that predicts the diffusion noise for each feature.

    The noised vector, a normalized time step and the observation mask are
    concatenated along the feature axis; the network outputs one value per
    feature (the predicted noise).
    """

    def __init__(self, input_dim=37, hidden=256, t_scale=1000.0):
        """
        input_dim: number of features D.
        hidden: width of the two hidden layers.
        t_scale: divisor applied to the raw time step before it enters the
            network. Defaults to 1000.0 (the previously hard-coded value);
            passing the number of diffusion steps maps t into [0, 1).
        """
        super().__init__()
        self.t_scale = float(t_scale)
        self.net = nn.Sequential(
            nn.Linear(input_dim + 1 + input_dim, hidden),  # x + t + mask
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, input_dim),
        )

    def forward(self, x_t, t, mask):
        """
        x_t: noised data, shape (batch, D)
        t: time step(s). Accepts shape (batch, 1), (batch,), (1, 1), (1,),
           a 0-d tensor or a Python number; a single time step is broadcast
           across the whole batch.
        mask: binary mask (1 = observed, 0 = missing), shape (batch, D);
           float or bool.

        Returns the predicted noise, shape (batch, D).
        """
        batch = x_t.shape[0]
        # Bring t to a (N, 1) float column on the same device as the data.
        t = torch.as_tensor(t, device=x_t.device).to(x_t.dtype).reshape(-1, 1)
        if t.shape[0] == 1 and batch != 1:
            # One shared time step for the whole batch (as used when sampling);
            # without this torch.cat fails on the mismatched batch dimension.
            t = t.expand(batch, 1)
        t = t / self.t_scale  # normalize time
        inp = torch.cat([x_t, t, mask.to(x_t.dtype)], dim=1)
        return self.net(inp)


# --------------------------------------------------
# 2. Diffusion core (betas, schedules, noise)
# --------------------------------------------------

def make_beta_schedule(T=1000, start=1e-4, end=0.02):
    """Return a linear noise (beta) schedule.

    T: number of diffusion steps, i.e. the length of the returned 1-D tensor.
    start, end: first and last beta; the values in between are evenly spaced.
    """
    betas = torch.linspace(start=start, end=end, steps=T)
    return betas


class Diffusion(nn.Module):
    """DDPM-style diffusion model that imputes missing features.

    Entries with mask == 1 are treated as observed and are held at their
    given values during training and at every reverse step; only entries
    with mask == 0 are diffused and denoised.
    """

    def __init__(self, input_dim, timesteps=1000):
        """
        input_dim: number of features D per sample.
        timesteps: number of diffusion steps T.
        """
        super().__init__()
        self.T = timesteps
        betas = make_beta_schedule(timesteps)
        alphas = 1.0 - betas
        alphas_cum = torch.cumprod(alphas, dim=0)  # \bar{alpha}_t

        # Buffers move with .to(device) and are saved in state_dict,
        # but are not trainable parameters.
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cum", alphas_cum)

        self.model = MLPDenoiser(input_dim=input_dim)

    @staticmethod
    def _keep_observed(x, x_obs, mask):
        """Return x with the observed entries (mask != 0) taken from x_obs.

        torch.where is used instead of x * (1 - mask) + x_obs * mask so that
        NaN/inf placeholders in the missing entries of x_obs cannot leak in
        (NaN * 0 is still NaN).
        """
        return torch.where(mask.bool(), x_obs, x)

    # -------------------------
    # q(x_t | x_0)
    # -------------------------
    def q_sample(self, x0, t, noise=None):
        """Diffuse x0 to step t in closed form.

        x0: clean data, shape (batch, D)
        t: integer step indices, shape (batch,)
        noise: optional Gaussian noise shaped like x0; drawn if omitted.

        Returns (x_t, noise).
        """
        if noise is None:
            noise = torch.randn_like(x0)
        a_bar = self.alphas_cum[t].unsqueeze(1)  # (batch, 1), broadcasts over D
        return torch.sqrt(a_bar) * x0 + torch.sqrt(1 - a_bar) * noise, noise

    # -------------------------
    # Training step
    # -------------------------
    def forward(self, x0, mask):
        """Return the noise-prediction loss for one batch.

        x0 : clean data (batch, D)
        mask : 1 = observed, 0 = missing (float or bool)
        """
        B = x0.shape[0]
        device = x0.device
        mask_f = mask.to(x0.dtype)

        # Random time t for each sample
        t = torch.randint(0, self.T, (B,), device=device)

        # Noise forward
        xt, noise = self.q_sample(x0, t, noise=None)

        # Condition on observed values
        xt = self._keep_observed(xt, x0, mask_f)

        # Predict noise
        noise_pred = self.model(xt, t.unsqueeze(1).float(), mask_f)

        # Loss only on missing entries. NOTE: .mean() divides by all B * D
        # entries, including observed ones that contribute 0, so the loss
        # scale depends on the fraction of missing values.
        loss = ((noise_pred - noise) ** 2 * (1 - mask_f)).mean()
        return loss

    # -------------------------
    # Sampling / imputation
    # -------------------------
    @torch.no_grad()
    def sample(self, x_obs, mask):
        """Impute the missing entries of x_obs by ancestral DDPM sampling.

        x_obs: data of shape (batch, D); missing entries may hold anything
            (they are overwritten).
        mask: 1 = observed, 0 = missing, shape (batch, D).

        Returns a tensor shaped like x_obs whose observed entries equal x_obs
        exactly and whose missing entries are sampled.
        """
        mask_f = mask.to(x_obs.dtype)
        B = x_obs.shape[0]
        x = torch.randn_like(x_obs)

        for t in reversed(range(self.T)):
            bt = self.betas[t]
            at = self.alphas[t]
            a_bar = self.alphas_cum[t]

            # Conditioner: always respect observed values
            x = self._keep_observed(x, x_obs, mask_f)

            # One time step per row so batches larger than 1 work.
            t_in = torch.full((B, 1), float(t), device=x.device, dtype=x.dtype)
            noise_pred = self.model(x, t_in, mask_f)

            # DDPM update: x_{t-1} = (x_t - beta_t / sqrt(1 - abar_t) * eps) / sqrt(alpha_t)
            coef1 = 1 / torch.sqrt(at)
            coef2 = (1 - at) / torch.sqrt(1 - a_bar)
            x = coef1 * (x - coef2 * noise_pred)

            if t > 0:
                # sigma_t^2 = beta_t
                x = x + torch.sqrt(bt) * torch.randn_like(x)

        # The last update above also moves the observed entries; clamp them
        # back so the returned imputation agrees with x_obs exactly.
        return self._keep_observed(x, x_obs, mask_f)


In [None]:
# `device` is used here and in later cells but is not defined in any visible
# cell; use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): MLPDenoiser divides t by 1000 by default while this model
# uses 500 steps, so the network's time input only spans [0, 0.5).
model = Diffusion(input_dim=37, timesteps=500).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-4)

In [None]:
num_epochs = 200

# Assumes train_loader yields (data, mask) pairs with mask 1 = observed,
# 0 = missing, as Diffusion.forward expects — TODO confirm against its definition.
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    n_batches = 0

    for batch, mask in train_loader:
        batch = batch.to(device)
        mask = mask.to(device)

        loss = model(batch, mask)

        opt.zero_grad()
        loss.backward()
        opt.step()

        running_loss += loss.item()
        n_batches += 1

    if n_batches == 0:
        # Otherwise the print below would reference an undefined value.
        raise ValueError("train_loader yielded no batches")

    # Report the epoch mean rather than the (noisy) last-batch loss.
    print(f"epoch {epoch} | loss {running_loss / n_batches:.4f}")

In [None]:
# NOTE(review): `i` (the test-row index), X_test_imputed and X_test_missing_mask
# are not defined in any visible cell — they must be defined before this runs.
num_samples = 20

# Test row whose missing entries (originally -1) were filled with placeholder
# values; sampling overwrites those entries.
x_incomplete = X_test_imputed[i]
# assumes X_test_missing_mask is 0 where a value is observed, so this yields
# the model's convention: 1 = observed, 0 = missing — confirm.
mask = (X_test_missing_mask[i] == 0)

# as_tensor accepts both NumPy arrays and tensors without a copy warning.
x_incomplete = torch.as_tensor(x_incomplete, dtype=torch.float32, device=device)
mask = torch.as_tensor(mask, dtype=torch.float32, device=device)

model.eval()
samples = []
for _ in range(num_samples):
    x_gen = model.sample(x_incomplete.unsqueeze(0), mask.unsqueeze(0))
    samples.append(x_gen.cpu().numpy())

samples = np.stack(samples)   # (num_samples, 1, 37)