# Stable Diffusion Multi-Task

In [None]:
!pip install torch_fidelity

Collecting torch_fidelity
  Downloading torch_fidelity-0.3.0-py3-none-any.whl.metadata (2.0 kB)
Downloading torch_fidelity-0.3.0-py3-none-any.whl (37 kB)
Installing collected packages: torch_fidelity
Successfully installed torch_fidelity-0.3.0


In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Single Task

In [None]:
# Tiny Diffusion (DDPM-style) — single-head with validation + metrics logging/plots
# One-cell, Colab-friendly script. Everything is saved under `save_dir`.

from __future__ import annotations
import math, os, csv
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
from tqdm import tqdm

# Headless plotting
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# -------------------------
# Time embedding utilities
# -------------------------
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
    def forward(self, t: torch.Tensor):
        device = t.device
        half = self.dim // 2
        emb = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.time_mlp = None
        if time_dim is not None:
            self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch))
        self.act = nn.SiLU()
        self.norm1 = nn.GroupNorm(4, out_ch)
        self.norm2 = nn.GroupNorm(4, out_ch)
    def forward(self, x, t_emb=None):
        x = self.conv1(x)
        if self.time_mlp is not None and t_emb is not None:
            x = x + self.time_mlp(t_emb)[:, :, None, None]
        x = self.norm1(x); x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x); x = self.act(x)
        return x

class TinyUNet(nn.Module):
    def __init__(self, in_channels=1, base=32, time_dim=128):
        super().__init__()
        self.time_dim = time_dim
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
        )
        self.inc = ConvBlock(in_channels, base, time_dim)
        self.down1 = nn.Sequential(nn.Conv2d(base, base, 3, stride=2, padding=1), nn.SiLU())
        self.block1 = ConvBlock(base, base * 2, time_dim)
        self.down2 = nn.Sequential(nn.Conv2d(base * 2, base * 2, 3, stride=2, padding=1), nn.SiLU())
        self.block2 = ConvBlock(base * 2, base * 4, time_dim)
        self.mid = ConvBlock(base * 4, base * 4, time_dim)
        self.up1 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        self.block_up1 = ConvBlock(base * 4, base * 2, time_dim)
        self.up2 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.block_up2 = ConvBlock(base * 2, base, time_dim)
        self.outc = nn.Conv2d(base, in_channels, 1)
    def forward(self, x, t):
        t_emb = self.time_mlp(t)
        x0 = self.inc(x, t_emb)
        x1 = self.down1(x0); x1 = self.block1(x1, t_emb)
        x2 = self.down2(x1); x2 = self.block2(x2, t_emb)
        m = self.mid(x2, t_emb)
        u1 = self.up1(m); u1 = torch.cat([u1, x1], dim=1); u1 = self.block_up1(u1, t_emb)
        u2 = self.up2(u1); u2 = torch.cat([u2, x0], dim=1); u2 = self.block_up2(u2, t_emb)
        return self.outc(u2)

@dataclass
class DiffusionConfig:
    timesteps: int = 1000
    beta_start: float = 1e-4
    beta_end: float = 0.02

class DDPM:
    def __init__(self, cfg: DiffusionConfig):
        self.cfg = cfg
        self.register_buffers()
    def register_buffers(self):
        T = self.cfg.timesteps
        betas = torch.linspace(self.cfg.beta_start, self.cfg.beta_end, T)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
        self.posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
    def q_sample(self, x0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_ac = self._extract(self.sqrt_alphas_cumprod, t, x0.shape)
        sqrt_om = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x0.shape)
        return sqrt_ac * x0 + sqrt_om * noise, noise
    @staticmethod
    def _extract(a, t, x_shape):
        b = t.shape[0]
        out = a.gather(-1, t).float().view(b, *((1,) * (len(x_shape) - 1)))
        return out

def p_losses(model, ddpm: DDPM, x0, t):
    x_t, noise = ddpm.q_sample(x0, t)
    pred = model(x_t, t.float())
    return F.mse_loss(pred, noise)

# ---- Data ----
def make_dataloader(name: str, batch_size: int, img_size: int, channels: int, val_split: float = 0.05):
    tfm = [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
    if channels == 3:
        tfm = [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize((0.5,)*3, (0.5,)*3)]
    tfm = transforms.Compose(tfm)
    if name.lower() == "mnist":
        ds = datasets.MNIST(root="./data", train=True, download=True, transform=tfm)
    elif name.lower() == "cifar10":
        ds = datasets.CIFAR10(root="./data", train=True, download=True, transform=tfm)
    else:
        raise ValueError("Unsupported dataset.")
    val_size = max(1, int(len(ds) * val_split))
    train_size = len(ds) - val_size
    train_ds, val_ds = torch.utils.data.random_split(ds, [train_size, val_size])
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, val_loader

# ---- Sampling ----
@torch.no_grad()
def sample(model, ddpm: DDPM, shape, device, n_steps=1000, save_path="samples.png"):
    model.eval()
    T = ddpm.cfg.timesteps
    x = torch.randn(shape, device=device)
    for i in tqdm(reversed(range(T)), total=T, desc="sampling"):
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        beta_t = ddpm._extract(ddpm.betas, t, x.shape)
        sqrt_one_minus_ac = ddpm._extract(ddpm.sqrt_one_minus_alphas_cumprod, t, x.shape)
        sqrt_recip_alpha = ddpm._extract(ddpm.sqrt_recip_alphas, t, x.shape)
        model_pred = model(x, t.float())  # predicts eps
        x0_pred = (x - sqrt_one_minus_ac * model_pred) * sqrt_recip_alpha
        alphas = ddpm._extract(ddpm.alphas, t, x.shape)
        alphas_cum = ddpm._extract(ddpm.alphas_cumprod, t, x.shape)
        alphas_cum_prev = ddpm._extract(ddpm.alphas_cumprod_prev, t, x.shape)
        posterior_var = ddpm._extract(ddpm.posterior_variance, t, x.shape)
        posterior_mean = (
            (beta_t * torch.sqrt(alphas_cum_prev) / (1.0 - alphas_cum)) * x0_pred
            + ((torch.sqrt(alphas) * (1.0 - alphas_cum_prev)) / (1.0 - alphas_cum)) * x
        )
        noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
        x = posterior_mean + torch.sqrt(posterior_var) * noise
    grid = vutils.make_grid((x.clamp(-1, 1) + 1) * 0.5, nrow=int(math.sqrt(shape[0])))
    vutils.save_image(grid, save_path)
    return save_path, x.detach().cpu()  # also return tensor batch for validation

# ---- Validation utils ----
def denorm(x, channels):
    # Convert from [-1,1] to [0,1]; replicate to 3 ch for grayscale if needed
    x = (x.clamp(-1, 1) + 1) * 0.5
    if channels == 1:
        x = x.repeat(1, 3, 1, 1)
    return x

@torch.no_grad()
def evaluate_mse(model, ddpm: DDPM, val_loader, device):
    model.eval()
    total, count = 0.0, 0
    for imgs, _ in val_loader:
        imgs = imgs.to(device)
        t = torch.randint(0, ddpm.cfg.timesteps, (imgs.size(0),), device=device)
        x_t, noise = ddpm.q_sample(imgs, t)
        pred = model(x_t, t.float())
        loss = F.mse_loss(pred, noise, reduction='sum')
        total += loss.item()
        count += imgs.numel()
    return total / max(1, count)

@torch.no_grad()
def dump_images(tensor_bchw, out_dir: str, prefix: str = "img"):
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for i, img in enumerate(tensor_bchw):
        vutils.save_image(img, out / f"{prefix}_{i:05d}.png")

def compute_fid(real_dir: str, fake_dir: str, device: torch.device):
    try:
        from torch_fidelity import calculate_metrics
    except Exception:
        print("torch-fidelity not found. Install with: pip install torch-fidelity")
        return float('nan')
    metrics = calculate_metrics(
        input1=real_dir, input2=fake_dir,
        fid=True, isc=False, kid=False,
        cuda=device.type == 'cuda', batch_size=64, verbose=False,
    )
    return float(metrics.get('frechet_inception_distance', float('nan')))

# ---- Metrics logging / plotting ----
def ema_series(values, decay=0.98):
    if not values:
        return []
    out, m = [], values[0]
    for v in values:
        m = decay * m + (1 - decay) * v
        out.append(m)
    return out

def save_curves(train_steps, train_losses, val_epochs, val_mses, val_fids, out_dir: Path):
    out_dir.mkdir(parents=True, exist_ok=True)
    # CSVs
    with open(out_dir / "train_loss.csv", "w", newline="") as f:
        w = csv.writer(f); w.writerow(["step", "loss"]); w.writerows(zip(train_steps, train_losses))
    with open(out_dir / "val_metrics.csv", "w", newline="") as f:
        w = csv.writer(f); w.writerow(["epoch", "val_mse", "fid"]); w.writerows(zip(val_epochs, val_mses, val_fids))
    # Plots
    plt.figure()
    plt.plot(train_steps, train_losses, label="loss (raw)")
    plt.plot(train_steps, ema_series(train_losses), label="loss (EMA)")
    plt.xlabel("step"); plt.ylabel("loss"); plt.title("Training loss"); plt.legend(); plt.tight_layout()
    plt.savefig(out_dir / "training_loss.png", dpi=150); plt.close()

    plt.figure()
    plt.plot(val_epochs, val_mses, label="val MSE")
    plt.plot(val_epochs, val_fids, label="FID")
    plt.xlabel("epoch"); plt.ylabel("metric"); plt.title("Validation metrics"); plt.legend(); plt.tight_layout()
    plt.savefig(out_dir / "val_metrics.png", dpi=150); plt.close()

# ---- Train ----
def train(
    save_dir: str = "runs/exp1",   # <--- all outputs go under this folder
    data: str = "mnist",           # "mnist" or "cifar10"
    epochs: int = 1,
    batch_size: int = 128,
    lr: float = 2e-4,
    timesteps: int = 200,
    beta_start: float = 1e-4,
    beta_end: float = 0.02,
    base: int = 32,
    time_dim: int = 128,
    n_sample: int = 64,            # more samples for better FID estimate
    sample_every: int = 500,
    val_split: float = 0.05,
    fid_eval_images: int = 1024,   # number of real/fake images to use for FID (reduce if slow)
):
    # --- Paths ---
    save_dir = Path(save_dir)
    images_dir = save_dir / "images"
    metrics_dir = save_dir / "metrics"
    fid_real_dir = save_dir / "fid" / "real"
    fid_fake_dir = save_dir / "fid" / "fake"
    images_dir.mkdir(parents=True, exist_ok=True)
    metrics_dir.mkdir(parents=True, exist_ok=True)
    fid_real_dir.mkdir(parents=True, exist_ok=True)
    fid_fake_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    channels = 1 if data.lower() == "mnist" else 3
    img_size = 28 if data.lower() == "mnist" else 32

    train_loader, val_loader = make_dataloader(data, batch_size, img_size, channels, val_split=val_split)
    model = TinyUNet(in_channels=channels, base=base, time_dim=time_dim).to(device)
    ddpm = DDPM(DiffusionConfig(timesteps=timesteps, beta_start=beta_start, beta_end=beta_end))
    optim = torch.optim.AdamW(model.parameters(), lr=lr)

    # Logging buffers
    train_steps, train_losses = [], []
    val_epochs, val_mses, val_fids = [], [], []

    step = 0
    model.train()
    for epoch in range(epochs):
        pbar = tqdm(train_loader, desc=f"epoch {epoch+1}/{epochs}")
        for imgs, _ in pbar:
            imgs = imgs.to(device)
            t = torch.randint(0, ddpm.cfg.timesteps, (imgs.size(0),), device=device)
            loss = p_losses(model, ddpm, imgs, t)
            optim.zero_grad(); loss.backward(); optim.step()

            step += 1
            loss_val = float(loss.item())
            train_steps.append(step); train_losses.append(loss_val)
            pbar.set_postfix({"loss": f"{loss_val:.4f}"})

            if step % sample_every == 0:
                with torch.no_grad():
                    grid_path = images_dir / f"samples_step{step}.png"
                    _path, _ = sample(
                        model, ddpm,
                        shape=(n_sample, channels, img_size, img_size),
                        device=device, n_steps=ddpm.cfg.timesteps,
                        save_path=str(grid_path),
                    )

        # --- Validation at end of epoch ---
        val_mse = evaluate_mse(model, ddpm, val_loader, device)

        # Prepare real images for FID
        # (re-create dirs each epoch to avoid mixing from previous runs)
        for d in [fid_real_dir, fid_fake_dir]:
            for f in d.glob("*.png"): f.unlink()

        collected, imgs_accum = 0, []
        for imgs, _ in val_loader:
            imgs_accum.append(imgs)
            collected += imgs.size(0)
            if collected >= fid_eval_images:
                break
        real_imgs = torch.cat(imgs_accum, dim=0)[:fid_eval_images]
        real_imgs = denorm(real_imgs, channels)
        dump_images(real_imgs, str(fid_real_dir), prefix="real")

        # generate fake (tile until fid_eval_images)
        _, fake_batch = sample(
            model, ddpm,
            shape=(min(fid_eval_images, n_sample), channels, img_size, img_size),
            device=device, n_steps=ddpm.cfg.timesteps,
            save_path=str(images_dir / f"samples_val_epoch{epoch+1}.png"),
        )
        fake_list = [denorm(fake_batch.cpu(), channels)]
        while sum(x.size(0) for x in fake_list) < fid_eval_images:
            _, fb = sample(
                model, ddpm,
                shape=(min(fid_eval_images - sum(x.size(0) for x in fake_list), n_sample),
                       channels, img_size, img_size),
                device=device, n_steps=ddpm.cfg.timesteps, save_path=str(images_dir / "_tmp.png"),
            )
            fake_list.append(denorm(fb.cpu(), channels))
        fake_imgs = torch.cat(fake_list, dim=0)[:fid_eval_images]
        dump_images(fake_imgs, str(fid_fake_dir), prefix="fake")

        fid_score = compute_fid(str(fid_real_dir), str(fid_fake_dir), device)
        print(f"[val] epoch {epoch+1}: MSE={val_mse:.6f} | FID={fid_score:.2f}")

        # log epoch metrics + save curves/CSVs
        val_epochs.append(epoch + 1)
        val_mses.append(float(val_mse))
        val_fids.append(float(fid_score))
        save_curves(train_steps, train_losses, val_epochs, val_mses, val_fids, metrics_dir)

    # Final sample grid
    with torch.no_grad():
        path, _ = sample(
            model, ddpm,
            shape=(n_sample, channels, img_size, img_size),
            device=device, n_steps=ddpm.cfg.timesteps,
            save_path=str(images_dir / "samples_final.png"),
        )
    print(f"Saved final samples to {path}")

    # Final curves (redundant but harmless)
    save_curves(train_steps, train_losses, val_epochs, val_mses, val_fids, metrics_dir)

# ---- Run ----
if __name__ == "__main__":
    # Customize the destination folder here:
    train(save_dir="/content/drive/MyDrive/prototypes/mtsd_exp/mnist_st_run1", data="mnist", epochs=10)

100%|██████████| 9.91M/9.91M [00:00<00:00, 59.2MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.69MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 14.5MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 9.38MB/s]
epoch 1/10: 100%|██████████| 446/446 [14:20<00:00,  1.93s/it, loss=0.0547]
sampling: 100%|██████████| 200/200 [01:08<00:00,  2.90it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.01it/s]
sampling: 100%|██████████| 200/200 [01:04<00:00,  3.10it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.02it/s]
sampling: 100%|██████████| 200/200 [01:07<00:00,  2.96it/s]
sampling: 100%|██████████| 200/200 [01:05<00:00,  3.04it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.03it/s]
sampling: 100%|██████████| 200/200 [01:04<00:00,  3.10it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.02it/s]
sampling: 100%|██████████| 200/200 [01:04<00:00,  3.08it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.02it/s]
sampling: 100%|██████████| 200/200 [01:04<00:00,  3.0

[val] epoch 1: MSE=0.075917 | FID=217.34


epoch 2/10:  12%|█▏        | 53/446 [01:46<12:32,  1.92s/it, loss=0.0689]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:09,  2.85it/s][A
sampling:   1%|          | 2/200 [00:00<01:04,  3.07it/s][A
sampling:   2%|▏         | 3/200 [00:00<01:00,  3.24it/s][A
sampling:   2%|▏         | 4/200 [00:01<00:59,  3.31it/s][A
sampling:   2%|▎         | 5/200 [00:01<00:57,  3.37it/s][A
sampling:   3%|▎         | 6/200 [00:01<00:58,  3.32it/s][A
sampling:   4%|▎         | 7/200 [00:02<00:57,  3.37it/s][A
sampling:   4%|▍         | 8/200 [00:02<00:56,  3.40it/s][A
sampling:   4%|▍         | 9/200 [00:02<00:56,  3.40it/s][A
sampling:   5%|▌         | 10/200 [00:03<00:56,  3.38it/s][A
sampling:   6%|▌         | 11/200 [00:03<00:55,  3.39it/s][A
sampling:   6%|▌         | 12/200 [00:03<00:55,  3.40it/s][A
sampling:   6%|▋         | 13/200 [00:03<00:56,  3.33it/s][A
sampling:   7%|▋         | 14/200 [00:04<00:55,  3.38it/s][A
sampling:   8%

[val] epoch 2: MSE=0.061174 | FID=182.75


epoch 3/10:  24%|██▍       | 107/446 [03:32<10:08,  1.80s/it, loss=0.0661]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:27,  2.28it/s][A
sampling:   1%|          | 2/200 [00:00<01:27,  2.26it/s][A
sampling:   2%|▏         | 3/200 [00:01<01:29,  2.21it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:27,  2.24it/s][A
sampling:   2%|▎         | 5/200 [00:02<01:27,  2.23it/s][A
sampling:   3%|▎         | 6/200 [00:02<01:27,  2.21it/s][A
sampling:   4%|▎         | 7/200 [00:03<01:26,  2.22it/s][A
sampling:   4%|▍         | 8/200 [00:03<01:17,  2.47it/s][A
sampling:   4%|▍         | 9/200 [00:03<01:10,  2.70it/s][A
sampling:   5%|▌         | 10/200 [00:04<01:05,  2.90it/s][A
sampling:   6%|▌         | 11/200 [00:04<01:02,  3.00it/s][A
sampling:   6%|▌         | 12/200 [00:04<01:00,  3.12it/s][A
sampling:   6%|▋         | 13/200 [00:04<00:58,  3.21it/s][A
sampling:   7%|▋         | 14/200 [00:05<00:56,  3.28it/s][A
sampling:   8

[val] epoch 3: MSE=0.060356 | FID=156.09


epoch 4/10:  36%|███▌      | 161/446 [05:16<09:36,  2.02s/it, loss=0.0527]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<00:58,  3.41it/s][A
sampling:   1%|          | 2/200 [00:00<00:58,  3.38it/s][A
sampling:   2%|▏         | 3/200 [00:00<00:58,  3.39it/s][A
sampling:   2%|▏         | 4/200 [00:01<00:59,  3.29it/s][A
sampling:   2%|▎         | 5/200 [00:01<00:58,  3.32it/s][A
sampling:   3%|▎         | 6/200 [00:01<00:57,  3.35it/s][A
sampling:   4%|▎         | 7/200 [00:02<00:58,  3.32it/s][A
sampling:   4%|▍         | 8/200 [00:02<00:57,  3.35it/s][A
sampling:   4%|▍         | 9/200 [00:02<00:57,  3.33it/s][A
sampling:   5%|▌         | 10/200 [00:02<00:56,  3.35it/s][A
sampling:   6%|▌         | 11/200 [00:03<00:56,  3.33it/s][A
sampling:   6%|▌         | 12/200 [00:03<00:56,  3.35it/s][A
sampling:   6%|▋         | 13/200 [00:03<00:55,  3.35it/s][A
sampling:   7%|▋         | 14/200 [00:04<00:55,  3.32it/s][A
sampling:   8

[val] epoch 4: MSE=0.055761 | FID=154.35


epoch 5/10:  48%|████▊     | 215/446 [06:59<06:52,  1.79s/it, loss=0.0511]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:27,  2.27it/s][A
sampling:   1%|          | 2/200 [00:00<01:27,  2.27it/s][A
sampling:   2%|▏         | 3/200 [00:01<01:28,  2.22it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:25,  2.28it/s][A
sampling:   2%|▎         | 5/200 [00:02<01:14,  2.60it/s][A
sampling:   3%|▎         | 6/200 [00:02<01:09,  2.78it/s][A
sampling:   4%|▎         | 7/200 [00:02<01:05,  2.96it/s][A
sampling:   4%|▍         | 8/200 [00:02<01:02,  3.09it/s][A
sampling:   4%|▍         | 9/200 [00:03<01:00,  3.13it/s][A
sampling:   5%|▌         | 10/200 [00:03<00:59,  3.21it/s][A
sampling:   6%|▌         | 11/200 [00:03<00:58,  3.25it/s][A
sampling:   6%|▌         | 12/200 [00:04<00:57,  3.29it/s][A
sampling:   6%|▋         | 13/200 [00:04<00:57,  3.25it/s][A
sampling:   7%|▋         | 14/200 [00:04<00:56,  3.29it/s][A
sampling:   8

[val] epoch 5: MSE=0.054092 | FID=150.96


epoch 6/10:  60%|██████    | 269/446 [08:51<05:29,  1.86s/it, loss=0.0452]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:03,  3.16it/s][A
sampling:   1%|          | 2/200 [00:00<01:00,  3.30it/s][A
sampling:   2%|▏         | 3/200 [00:00<00:59,  3.29it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:01,  3.19it/s][A
sampling:   2%|▎         | 5/200 [00:01<01:00,  3.22it/s][A
sampling:   3%|▎         | 6/200 [00:01<00:59,  3.24it/s][A
sampling:   4%|▎         | 7/200 [00:02<01:00,  3.21it/s][A
sampling:   4%|▍         | 8/200 [00:02<01:08,  2.81it/s][A
sampling:   4%|▍         | 9/200 [00:03<01:14,  2.56it/s][A
sampling:   5%|▌         | 10/200 [00:03<01:16,  2.47it/s][A
sampling:   6%|▌         | 11/200 [00:03<01:18,  2.40it/s][A
sampling:   6%|▌         | 12/200 [00:04<01:21,  2.32it/s][A
sampling:   6%|▋         | 13/200 [00:04<01:21,  2.28it/s][A
sampling:   7%|▋         | 14/200 [00:05<01:21,  2.29it/s][A
sampling:   8

[val] epoch 6: MSE=0.051865 | FID=135.59


epoch 7/10:  72%|███████▏  | 323/446 [10:36<03:49,  1.86s/it, loss=0.0495]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:33,  2.13it/s][A
sampling:   1%|          | 2/200 [00:00<01:29,  2.20it/s][A
sampling:   2%|▏         | 3/200 [00:01<01:18,  2.52it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:10,  2.80it/s][A
sampling:   2%|▎         | 5/200 [00:01<01:05,  2.97it/s][A
sampling:   3%|▎         | 6/200 [00:02<01:03,  3.04it/s][A
sampling:   4%|▎         | 7/200 [00:02<01:02,  3.09it/s][A
sampling:   4%|▍         | 8/200 [00:02<01:01,  3.14it/s][A
sampling:   4%|▍         | 9/200 [00:03<00:59,  3.21it/s][A
sampling:   5%|▌         | 10/200 [00:03<01:00,  3.16it/s][A
sampling:   6%|▌         | 11/200 [00:03<00:58,  3.22it/s][A
sampling:   6%|▌         | 12/200 [00:03<00:58,  3.24it/s][A
sampling:   6%|▋         | 13/200 [00:04<00:58,  3.22it/s][A
sampling:   7%|▋         | 14/200 [00:04<00:56,  3.26it/s][A
sampling:   8

[val] epoch 7: MSE=0.052811 | FID=149.33


epoch 8/10:  85%|████████▍ | 377/446 [12:26<02:27,  2.14s/it, loss=0.0545]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:01,  3.26it/s][A
sampling:   1%|          | 2/200 [00:00<01:00,  3.26it/s][A
sampling:   2%|▏         | 3/200 [00:00<00:59,  3.29it/s][A
sampling:   2%|▏         | 4/200 [00:01<00:59,  3.31it/s][A
sampling:   2%|▎         | 5/200 [00:01<00:58,  3.34it/s][A
sampling:   3%|▎         | 6/200 [00:01<00:59,  3.26it/s][A
sampling:   4%|▎         | 7/200 [00:02<00:58,  3.28it/s][A
sampling:   4%|▍         | 8/200 [00:02<00:58,  3.31it/s][A
sampling:   4%|▍         | 9/200 [00:02<00:58,  3.27it/s][A
sampling:   5%|▌         | 10/200 [00:03<00:57,  3.30it/s][A
sampling:   6%|▌         | 11/200 [00:03<00:56,  3.33it/s][A
sampling:   6%|▌         | 12/200 [00:03<00:56,  3.33it/s][A
sampling:   6%|▋         | 13/200 [00:03<00:56,  3.28it/s][A
sampling:   7%|▋         | 14/200 [00:04<00:56,  3.31it/s][A
sampling:   8

[val] epoch 8: MSE=0.050956 | FID=122.69


epoch 9/10:  97%|█████████▋| 431/446 [14:03<00:28,  1.87s/it, loss=0.0505]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:03,  3.13it/s][A
sampling:   1%|          | 2/200 [00:00<01:01,  3.24it/s][A
sampling:   2%|▏         | 3/200 [00:00<00:59,  3.31it/s][A
sampling:   2%|▏         | 4/200 [00:01<00:59,  3.29it/s][A
sampling:   2%|▎         | 5/200 [00:01<00:59,  3.29it/s][A
sampling:   3%|▎         | 6/200 [00:01<00:58,  3.32it/s][A
sampling:   4%|▎         | 7/200 [00:02<00:57,  3.36it/s][A
sampling:   4%|▍         | 8/200 [00:02<00:57,  3.32it/s][A
sampling:   4%|▍         | 9/200 [00:02<01:05,  2.92it/s][A
sampling:   5%|▌         | 10/200 [00:03<01:13,  2.59it/s][A
sampling:   6%|▌         | 11/200 [00:03<01:16,  2.48it/s][A
sampling:   6%|▌         | 12/200 [00:04<01:18,  2.41it/s][A
sampling:   6%|▋         | 13/200 [00:04<01:20,  2.31it/s][A
sampling:   7%|▋         | 14/200 [00:05<01:21,  2.29it/s][A
sampling:   8

[val] epoch 9: MSE=0.049995 | FID=133.59


epoch 10/10: 100%|██████████| 446/446 [14:30<00:00,  1.95s/it, loss=0.0390]
sampling: 100%|██████████| 200/200 [01:11<00:00,  2.81it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.03it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.00it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.02it/s]
sampling: 100%|██████████| 200/200 [01:05<00:00,  3.06it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  2.99it/s]
sampling: 100%|██████████| 200/200 [01:05<00:00,  3.08it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.00it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.02it/s]
sampling: 100%|██████████| 200/200 [01:05<00:00,  3.06it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  2.99it/s]
sampling: 100%|██████████| 200/200 [01:05<00:00,  3.07it/s]
sampling: 100%|██████████| 200/200 [01:07<00:00,  2.95it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  2.99it/s]
sampling: 100%|██████████| 200/200 [01:06<00:00,  3.03it/s]
sampling: 100%|█████████

[val] epoch 10: MSE=0.049288 | FID=110.55


sampling: 100%|██████████| 200/200 [01:06<00:00,  3.00it/s]


Saved final samples to /content/drive/MyDrive/prototypes/mtsd_exp/mnist_st_run1/images/samples_final.png


# Multi-Task

In [3]:
# Tiny Diffusion (DDPM-style) — multi-task (ε & x0 heads) with validation + metrics logging/plots
# One-cell, Colab-friendly script. Everything is saved under configurable `save_dir`.

from __future__ import annotations
import math, os, csv
from dataclasses import dataclass
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, utils as vutils
from tqdm import tqdm

# Headless plotting for Colab/servers
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# -------------------------
# Time embedding utilities
# -------------------------
class SinusoidalPositionEmbeddings(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
    def forward(self, t: torch.Tensor):
        device = t.device
        half = self.dim // 2
        emb = math.log(10000) / (half - 1)
        emb = torch.exp(torch.arange(half, device=device) * -emb)
        emb = t[:, None] * emb[None, :]
        emb = torch.cat([emb.sin(), emb.cos()], dim=-1)
        return emb

class ConvBlock(nn.Module):
    def __init__(self, in_ch, out_ch, time_dim=None):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        self.time_mlp = None
        if time_dim is not None:
            self.time_mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_dim, out_ch))
        self.act = nn.SiLU()
        self.norm1 = nn.GroupNorm(4, out_ch)
        self.norm2 = nn.GroupNorm(4, out_ch)
    def forward(self, x, t_emb=None):
        x = self.conv1(x)
        if self.time_mlp is not None and t_emb is not None:
            x = x + self.time_mlp(t_emb)[:, :, None, None]
        x = self.norm1(x); x = self.act(x)
        x = self.conv2(x)
        x = self.norm2(x); x = self.act(x)
        return x

class TinyUNet(nn.Module):
    def __init__(self, in_channels=1, base=32, time_dim=128):
        super().__init__()
        self.time_dim = time_dim
        self.time_mlp = nn.Sequential(
            SinusoidalPositionEmbeddings(time_dim),
            nn.Linear(time_dim, time_dim),
            nn.SiLU(),
        )
        self.inc = ConvBlock(in_channels, base, time_dim)
        self.down1 = nn.Sequential(nn.Conv2d(base, base, 3, stride=2, padding=1), nn.SiLU())
        self.block1 = ConvBlock(base, base * 2, time_dim)
        self.down2 = nn.Sequential(nn.Conv2d(base * 2, base * 2, 3, stride=2, padding=1), nn.SiLU())
        self.block2 = ConvBlock(base * 2, base * 4, time_dim)
        self.mid = ConvBlock(base * 4, base * 4, time_dim)
        self.up1 = nn.ConvTranspose2d(base * 4, base * 2, 2, stride=2)
        self.block_up1 = ConvBlock(base * 4, base * 2, time_dim)
        self.up2 = nn.ConvTranspose2d(base * 2, base, 2, stride=2)
        self.block_up2 = ConvBlock(base * 2, base, time_dim)
        # Two lightweight heads:
        self.out_eps = nn.Conv2d(base, in_channels, 1)   # ε (noise)
        self.out_x0  = nn.Conv2d(base, in_channels, 1)   # x0 (clean)
    def forward(self, x, t):
        t_emb = self.time_mlp(t)
        x0 = self.inc(x, t_emb)
        x1 = self.down1(x0); x1 = self.block1(x1, t_emb)
        x2 = self.down2(x1); x2 = self.block2(x2, t_emb)
        m = self.mid(x2, t_emb)
        u1 = self.up1(m); u1 = torch.cat([u1, x1], dim=1); u1 = self.block_up1(u1, t_emb)
        u2 = self.up2(u1); u2 = torch.cat([u2, x0], dim=1); u2 = self.block_up2(u2, t_emb)
        return {"eps": self.out_eps(u2), "x0": self.out_x0(u2)}

@dataclass
class DiffusionConfig:
    timesteps: int = 1000
    beta_start: float = 1e-4
    beta_end: float = 0.02

class DDPM:
    def __init__(self, cfg: DiffusionConfig):
        self.cfg = cfg
        self.register_buffers()
    def register_buffers(self):
        T = self.cfg.timesteps
        betas = torch.linspace(self.cfg.beta_start, self.cfg.beta_end, T)
        alphas = 1.0 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
        self.betas = betas
        self.alphas = alphas
        self.alphas_cumprod = alphas_cumprod
        self.alphas_cumprod_prev = alphas_cumprod_prev
        self.sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - alphas_cumprod)
        self.sqrt_recip_alphas = torch.sqrt(1.0 / alphas)
        self.posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
    def q_sample(self, x0, t, noise=None):
        if noise is None:
            noise = torch.randn_like(x0)
        sqrt_ac = self._extract(self.sqrt_alphas_cumprod, t, x0.shape)
        sqrt_om = self._extract(self.sqrt_one_minus_alphas_cumprod, t, x0.shape)
        return sqrt_ac * x0 + sqrt_om * noise, noise
    @staticmethod
    def _extract(a, t, x_shape):
        b = t.shape[0]
        out = a.gather(-1, t).float().view(b, *((1,) * (len(x_shape) - 1)))
        return out

# ---- Loss with consistency (multi-task) ----
def p_losses_multi(model, ddpm: DDPM, x0, t, w_eps=1.0, w_x0=1.0, w_consistency=0.1):
    # Diffuse x0 to xt
    x_t, noise = ddpm.q_sample(x0, t)
    preds = model(x_t, t.float())
    pred_eps = preds["eps"]
    pred_x0  = preds["x0"]

    # Primary targets
    loss_eps = F.mse_loss(pred_eps, noise)
    loss_x0  = F.mse_loss(pred_x0, x0)

    # Optional consistency: tie the two heads with the DDPM relation
    sqrt_ac = ddpm._extract(ddpm.sqrt_alphas_cumprod, t, x_t.shape)
    sqrt_om = ddpm._extract(ddpm.sqrt_one_minus_alphas_cumprod, t, x_t.shape)
    x0_from_eps = (x_t - sqrt_om * pred_eps) / (sqrt_ac + 1e-8)
    eps_from_x0 = (x_t - sqrt_ac * pred_x0) / (sqrt_om + 1e-8)
    cons1 = F.mse_loss(x0_from_eps.detach(), pred_x0)
    cons2 = F.mse_loss(eps_from_x0.detach(), pred_eps)
    loss_cons = 0.5 * (cons1 + cons2)

    total = w_eps * loss_eps + w_x0 * loss_x0 + w_consistency * loss_cons
    return total, {
        "loss_eps": loss_eps.detach(),
        "loss_x0": loss_x0.detach(),
        "loss_cons": loss_cons.detach()
    }

# ---- Data ----
def make_dataloader(name: str, batch_size: int, img_size: int, channels: int, val_split: float = 0.05):
    tfm = [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize(0.5, 0.5)]
    if channels == 3:
        tfm = [transforms.Resize(img_size), transforms.ToTensor(), transforms.Normalize((0.5,)*3, (0.5,)*3)]
    tfm = transforms.Compose(tfm)
    if name.lower() == "mnist":
        ds = datasets.MNIST(root="./data", train=True, download=True, transform=tfm)
    elif name.lower() == "cifar10":
        ds = datasets.CIFAR10(root="./data", train=True, download=True, transform=tfm)
    else:
        raise ValueError("Unsupported dataset.")
    val_size = max(1, int(len(ds) * val_split))
    train_size = len(ds) - val_size
    train_ds, val_ds = torch.utils.data.random_split(ds, [train_size, val_size])
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
    return train_loader, val_loader

# ---- Sampling (uses ε head) ----
@torch.no_grad()
def sample(model, ddpm: DDPM, shape, device, n_steps=1000, save_path="samples.png"):
    model.eval()
    T = ddpm.cfg.timesteps
    x = torch.randn(shape, device=device)
    for i in tqdm(reversed(range(T)), total=T, desc="sampling"):
        t = torch.full((shape[0],), i, device=device, dtype=torch.long)
        beta_t = ddpm._extract(ddpm.betas, t, x.shape)
        sqrt_one_minus_ac = ddpm._extract(ddpm.sqrt_one_minus_alphas_cumprod, t, x.shape)
        sqrt_recip_alpha = ddpm._extract(ddpm.sqrt_recip_alphas, t, x.shape)
        preds = model(x, t.float())
        pred_eps = preds["eps"]
        x0_pred = (x - sqrt_one_minus_ac * pred_eps) * sqrt_recip_alpha
        alphas = ddpm._extract(ddpm.alphas, t, x.shape)
        alphas_cum = ddpm._extract(ddpm.alphas_cumprod, t, x.shape)
        alphas_cum_prev = ddpm._extract(ddpm.alphas_cumprod_prev, t, x.shape)
        posterior_var = ddpm._extract(ddpm.posterior_variance, t, x.shape)
        posterior_mean = (
            (beta_t * torch.sqrt(alphas_cum_prev) / (1.0 - alphas_cum)) * x0_pred
            + ((torch.sqrt(alphas) * (1.0 - alphas_cum_prev)) / (1.0 - alphas_cum)) * x
        )
        noise = torch.randn_like(x) if i > 0 else torch.zeros_like(x)
        x = posterior_mean + torch.sqrt(posterior_var) * noise
    grid = vutils.make_grid((x.clamp(-1, 1) + 1) * 0.5, nrow=int(math.sqrt(shape[0])))
    vutils.save_image(grid, save_path)
    return save_path, x.detach().cpu()

# ---- Validation utils ----
def denorm(x, channels):
    x = (x.clamp(-1, 1) + 1) * 0.5
    if channels == 1:
        x = x.repeat(1, 3, 1, 1)
    return x

@torch.no_grad()
def evaluate_mse_multi(model, ddpm: DDPM, val_loader, device):
    """Returns tuple: (mse_eps, mse_x0) averaged over pixels."""
    model.eval()
    tot_eps, tot_x0, denom = 0.0, 0.0, 0
    for imgs, _ in val_loader:
        imgs = imgs.to(device)
        t = torch.randint(0, ddpm.cfg.timesteps, (imgs.size(0),), device=device)
        x_t, noise = ddpm.q_sample(imgs, t)
        preds = model(x_t, t.float())
        mse_eps = F.mse_loss(preds["eps"], noise, reduction='sum')
        mse_x0  = F.mse_loss(preds["x0"], imgs, reduction='sum')
        tot_eps += mse_eps.item()
        tot_x0  += mse_x0.item()
        denom   += imgs.numel()
    denom = max(1, denom)
    return (tot_eps / denom), (tot_x0 / denom)

@torch.no_grad()
def dump_images(tensor_bchw, out_dir: str, prefix: str = "img"):
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for i, img in enumerate(tensor_bchw):
        vutils.save_image(img, out / f"{prefix}_{i:05d}.png")

def compute_fid(real_dir: str, fake_dir: str, device: torch.device):
    try:
        from torch_fidelity import calculate_metrics
    except Exception:
        print("torch-fidelity not found. Install with: pip install torch-fidelity")
        return float('nan')
    metrics = calculate_metrics(
        input1=real_dir, input2=fake_dir,
        fid=True, isc=False, kid=False,
        cuda=device.type == 'cuda', batch_size=64, verbose=False,
    )
    return float(metrics.get('frechet_inception_distance', float('nan')))

# ---- Metrics logging / plotting ----
def ema_series(values, decay=0.98):
    if not values:
        return []
    out, m = [], values[0]
    for v in values:
        m = decay * m + (1 - decay) * v
        out.append(m)
    return out

def save_curves_multi(train_steps, total_losses, eps_losses, x0_losses, cons_losses,
                      val_epochs, val_mse_eps, val_mse_x0, val_fids, out_dir: Path):
    out_dir.mkdir(parents=True, exist_ok=True)
    # CSVs
    with open(out_dir / "train_losses.csv", "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["step", "total", "loss_eps", "loss_x0", "loss_cons"])
        for i in range(len(train_steps)):
            w.writerow([train_steps[i],
                        total_losses[i],
                        eps_losses[i] if i < len(eps_losses) else "",
                        x0_losses[i] if i < len(x0_losses) else "",
                        cons_losses[i] if i < len(cons_losses) else ""])
    with open(out_dir / "val_metrics.csv", "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(["epoch", "val_mse_eps", "val_mse_x0", "fid"])
        w.writerows(zip(val_epochs, val_mse_eps, val_mse_x0, val_fids))

    # Plots
    # 1) Training total loss (raw + EMA)
    plt.figure()
    plt.plot(train_steps, total_losses, label="total (raw)")
    plt.plot(train_steps, ema_series(total_losses), label="total (EMA)")
    plt.xlabel("step"); plt.ylabel("loss"); plt.title("Training loss (total)"); plt.legend(); plt.tight_layout()
    plt.savefig(out_dir / "training_loss_total.png", dpi=150); plt.close()

    # 2) Training component losses (EMA to reduce noise)
    plt.figure()
    if eps_losses: plt.plot(train_steps[:len(eps_losses)], ema_series(eps_losses), label="loss_eps (EMA)")
    if x0_losses:  plt.plot(train_steps[:len(x0_losses)],  ema_series(x0_losses),  label="loss_x0 (EMA)")
    if cons_losses:plt.plot(train_steps[:len(cons_losses)],ema_series(cons_losses), label="loss_cons (EMA)")
    plt.xlabel("step"); plt.ylabel("loss"); plt.title("Training loss components"); plt.legend(); plt.tight_layout()
    plt.savefig(out_dir / "training_loss_components.png", dpi=150); plt.close()

    # 3) Validation metrics per epoch
    plt.figure()
    plt.plot(val_epochs, val_mse_eps, label="val MSE (eps)")
    plt.plot(val_epochs, val_mse_x0, label="val MSE (x0)")
    plt.plot(val_epochs, val_fids, label="FID")
    plt.xlabel("epoch"); plt.ylabel("metric"); plt.title("Validation metrics"); plt.legend(); plt.tight_layout()
    plt.savefig(out_dir / "val_metrics.png", dpi=150); plt.close()

# ---- Train ----
def train(
    save_dir: str = "runs/exp_multitask",
    data: str = "mnist",           # "mnist" or "cifar10"
    epochs: int = 1,
    batch_size: int = 128,
    lr: float = 2e-4,
    timesteps: int = 200,
    beta_start: float = 1e-4,
    beta_end: float = 0.02,
    base: int = 32,
    time_dim: int = 128,
    n_sample: int = 64,            # more samples for better FID estimate
    sample_every: int = 500,
    val_split: float = 0.05,
    fid_eval_images: int = 1024,   # number of real/fake images to use for FID (reduce if slow)
    w_eps: float = 1.0,
    w_x0: float = 1.0,
    w_consistency: float = 0.1,
):
    # --- Paths ---
    save_dir = Path(save_dir)
    images_dir = save_dir / "images"
    metrics_dir = save_dir / "metrics"
    fid_real_dir = save_dir / "fid" / "real"
    fid_fake_dir = save_dir / "fid" / "fake"
    images_dir.mkdir(parents=True, exist_ok=True)
    metrics_dir.mkdir(parents=True, exist_ok=True)
    fid_real_dir.mkdir(parents=True, exist_ok=True)
    fid_fake_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    channels = 1 if data.lower() == "mnist" else 3
    img_size = 28 if data.lower() == "mnist" else 32

    train_loader, val_loader = make_dataloader(data, batch_size, img_size, channels, val_split=val_split)
    model = TinyUNet(in_channels=channels, base=base, time_dim=time_dim).to(device)
    ddpm = DDPM(DiffusionConfig(timesteps=timesteps, beta_start=beta_start, beta_end=beta_end))
    optim = torch.optim.AdamW(model.parameters(), lr=lr)

    # Logging buffers
    train_steps, total_losses, eps_losses, x0_losses, cons_losses = [], [], [], [], []
    val_epochs, val_mse_eps, val_mse_x0, val_fids = [], [], [], []

    step = 0
    model.train()
    for epoch in range(epochs):
        pbar = tqdm(train_loader, desc=f"epoch {epoch+1}/{epochs}")
        for imgs, _ in pbar:
            imgs = imgs.to(device)
            t = torch.randint(0, ddpm.cfg.timesteps, (imgs.size(0),), device=device)
            total, parts = p_losses_multi(model, ddpm, imgs, t, w_eps=w_eps, w_x0=w_x0, w_consistency=w_consistency)
            optim.zero_grad(); total.backward(); optim.step()

            step += 1
            total_val = float(total.item())
            train_steps.append(step); total_losses.append(total_val)
            eps_losses.append(float(parts["loss_eps"]))
            x0_losses.append(float(parts["loss_x0"]))
            cons_losses.append(float(parts["loss_cons"]))
            pbar.set_postfix({"loss": f"{total_val:.4f}"})

            if step % sample_every == 0:
                with torch.no_grad():
                    grid_path = images_dir / f"samples_step{step}.png"
                    _path, _ = sample(
                        model, ddpm,
                        shape=(n_sample, channels, img_size, img_size),
                        device=device, n_steps=ddpm.cfg.timesteps,
                        save_path=str(grid_path),
                    )

        # --- Validation at end of epoch ---
        mse_eps, mse_x0 = evaluate_mse_multi(model, ddpm, val_loader, device)

        # Reset FID dirs (avoid mixing from prior epochs)
        for d in [fid_real_dir, fid_fake_dir]:
            for f in d.glob("*.png"): f.unlink()

        # Collect real images
        collected, imgs_accum = 0, []
        for imgs, _ in val_loader:
            imgs_accum.append(imgs)
            collected += imgs.size(0)
            if collected >= fid_eval_images:
                break
        real_imgs = torch.cat(imgs_accum, dim=0)[:fid_eval_images]
        real_imgs = denorm(real_imgs, channels)
        dump_images(real_imgs, str(fid_real_dir), prefix="real")

        # Generate fake images (tile until fid_eval_images)
        _, fake_batch = sample(
            model, ddpm,
            shape=(min(fid_eval_images, n_sample), channels, img_size, img_size),
            device=device, n_steps=ddpm.cfg.timesteps,
            save_path=str(images_dir / f"samples_val_epoch{epoch+1}.png"),
        )
        fake_list = [denorm(fake_batch.cpu(), channels)]
        while sum(x.size(0) for x in fake_list) < fid_eval_images:
            _, fb = sample(
                model, ddpm,
                shape=(min(fid_eval_images - sum(x.size(0) for x in fake_list), n_sample),
                       channels, img_size, img_size),
                device=device, n_steps=ddpm.cfg.timesteps, save_path=str(images_dir / "_tmp.png"),
            )
            fake_list.append(denorm(fb.cpu(), channels))
        fake_imgs = torch.cat(fake_list, dim=0)[:fid_eval_images]
        dump_images(fake_imgs, str(fid_fake_dir), prefix="fake")

        fid_score = compute_fid(str(fid_real_dir), str(fid_fake_dir), device)
        print(f"[val] epoch {epoch+1}: MSE_eps={mse_eps:.6f} | MSE_x0={mse_x0:.6f} | FID={fid_score:.2f}")

        # Log + save curves/CSVs
        val_epochs.append(epoch + 1)
        val_mse_eps.append(float(mse_eps))
        val_mse_x0.append(float(mse_x0))
        val_fids.append(float(fid_score))
        save_curves_multi(train_steps, total_losses, eps_losses, x0_losses, cons_losses,
                          val_epochs, val_mse_eps, val_mse_x0, val_fids, metrics_dir)

    # Final sample grid
    with torch.no_grad():
        path, _ = sample(
            model, ddpm,
            shape=(n_sample, channels, img_size, img_size),
            device=device, n_steps=ddpm.cfg.timesteps,
            save_path=str(images_dir / "samples_final.png"),
        )
    print(f"Saved final samples to {path}")

    # Final curves (redundant but harmless)
    save_curves_multi(train_steps, total_losses, eps_losses, x0_losses, cons_losses,
                      val_epochs, val_mse_eps, val_mse_x0, val_fids, metrics_dir)

# ---- Run ----
if __name__ == "__main__":
    # Customize destination folder and options here:
    train(save_dir="/content/drive/MyDrive/prototypes/mtsd_exp/mnist_mt_run1/", data="mnist", epochs=10)

100%|██████████| 9.91M/9.91M [00:00<00:00, 39.9MB/s]
100%|██████████| 28.9k/28.9k [00:00<00:00, 1.11MB/s]
100%|██████████| 1.65M/1.65M [00:00<00:00, 9.43MB/s]
100%|██████████| 4.54k/4.54k [00:00<00:00, 1.87MB/s]
epoch 1/10: 100%|██████████| 446/446 [15:26<00:00,  2.08s/it, loss=0.6037]
sampling: 100%|██████████| 200/200 [01:14<00:00,  2.68it/s]
sampling: 100%|██████████| 200/200 [01:16<00:00,  2.61it/s]
sampling: 100%|██████████| 200/200 [01:12<00:00,  2.77it/s]
sampling: 100%|██████████| 200/200 [01:10<00:00,  2.83it/s]
sampling: 100%|██████████| 200/200 [01:12<00:00,  2.77it/s]
sampling: 100%|██████████| 200/200 [01:10<00:00,  2.84it/s]
sampling: 100%|██████████| 200/200 [01:12<00:00,  2.77it/s]
sampling: 100%|██████████| 200/200 [01:10<00:00,  2.85it/s]
sampling: 100%|██████████| 200/200 [01:11<00:00,  2.79it/s]
sampling: 100%|██████████| 200/200 [01:11<00:00,  2.80it/s]
sampling: 100%|██████████| 200/200 [01:12<00:00,  2.77it/s]
sampling: 100%|██████████| 200/200 [01:10<00:00,  2.8

[val] epoch 1: MSE_eps=0.095872 | MSE_x0=0.054170 | FID=207.71


epoch 2/10:  12%|█▏        | 53/446 [01:57<13:55,  2.13s/it, loss=0.2614]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:29,  2.21it/s][A
sampling:   1%|          | 2/200 [00:00<01:12,  2.73it/s][A
sampling:   2%|▏         | 3/200 [00:01<01:06,  2.94it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:03,  3.08it/s][A
sampling:   2%|▎         | 5/200 [00:01<01:04,  3.04it/s][A
sampling:   3%|▎         | 6/200 [00:02<01:04,  3.01it/s][A
sampling:   4%|▎         | 7/200 [00:02<01:04,  3.00it/s][A
sampling:   4%|▍         | 8/200 [00:02<01:04,  3.00it/s][A
sampling:   4%|▍         | 9/200 [00:03<01:01,  3.10it/s][A
sampling:   5%|▌         | 10/200 [00:03<00:59,  3.17it/s][A
sampling:   6%|▌         | 11/200 [00:03<00:59,  3.16it/s][A
sampling:   6%|▌         | 12/200 [00:03<01:00,  3.11it/s][A
sampling:   6%|▋         | 13/200 [00:04<01:00,  3.07it/s][A
sampling:   7%|▋         | 14/200 [00:04<01:01,  3.02it/s][A
sampling:   8%

[val] epoch 2: MSE_eps=0.071473 | MSE_x0=0.046228 | FID=187.86


epoch 3/10:  24%|██▍       | 107/446 [04:01<12:43,  2.25s/it, loss=0.1214]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:07,  2.96it/s][A
sampling:   1%|          | 2/200 [00:00<01:04,  3.08it/s][A
sampling:   2%|▏         | 3/200 [00:00<01:02,  3.13it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:02,  3.15it/s][A
sampling:   2%|▎         | 5/200 [00:01<01:03,  3.09it/s][A
sampling:   3%|▎         | 6/200 [00:01<01:02,  3.11it/s][A
sampling:   4%|▎         | 7/200 [00:02<01:01,  3.13it/s][A
sampling:   4%|▍         | 8/200 [00:02<01:03,  3.04it/s][A
sampling:   4%|▍         | 9/200 [00:02<01:02,  3.07it/s][A
sampling:   5%|▌         | 10/200 [00:03<01:01,  3.09it/s][A
sampling:   6%|▌         | 11/200 [00:03<01:01,  3.06it/s][A
sampling:   6%|▌         | 12/200 [00:03<01:01,  3.08it/s][A
sampling:   6%|▋         | 13/200 [00:04<01:00,  3.07it/s][A
sampling:   7%|▋         | 14/200 [00:04<01:01,  3.01it/s][A
sampling:   8

[val] epoch 3: MSE_eps=0.067932 | MSE_x0=0.043209 | FID=155.80


epoch 4/10:  36%|███▌      | 161/446 [05:47<09:56,  2.09s/it, loss=0.1942]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:34,  2.11it/s][A
sampling:   1%|          | 2/200 [00:00<01:35,  2.08it/s][A
sampling:   2%|▏         | 3/200 [00:01<01:21,  2.41it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:18,  2.51it/s][A
sampling:   2%|▎         | 5/200 [00:02<01:14,  2.63it/s][A
sampling:   3%|▎         | 6/200 [00:02<01:09,  2.79it/s][A
sampling:   4%|▎         | 7/200 [00:02<01:07,  2.87it/s][A
sampling:   4%|▍         | 8/200 [00:02<01:04,  2.96it/s][A
sampling:   4%|▍         | 9/200 [00:03<01:05,  2.92it/s][A
sampling:   5%|▌         | 10/200 [00:03<01:05,  2.89it/s][A
sampling:   6%|▌         | 11/200 [00:04<01:06,  2.86it/s][A
sampling:   6%|▌         | 12/200 [00:04<01:03,  2.94it/s][A
sampling:   6%|▋         | 13/200 [00:04<01:02,  2.99it/s][A
sampling:   7%|▋         | 14/200 [00:05<01:02,  2.99it/s][A
sampling:   8

[val] epoch 4: MSE_eps=0.071516 | MSE_x0=0.041497 | FID=160.04


epoch 5/10:  48%|████▊     | 215/446 [07:33<07:25,  1.93s/it, loss=0.1132]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:36,  2.05it/s][A
sampling:   1%|          | 2/200 [00:00<01:32,  2.15it/s][A
sampling:   2%|▏         | 3/200 [00:01<01:31,  2.16it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:32,  2.12it/s][A
sampling:   2%|▎         | 5/200 [00:02<01:31,  2.13it/s][A
sampling:   3%|▎         | 6/200 [00:02<01:29,  2.16it/s][A
sampling:   4%|▎         | 7/200 [00:03<01:30,  2.12it/s][A
sampling:   4%|▍         | 8/200 [00:03<01:30,  2.13it/s][A
sampling:   4%|▍         | 9/200 [00:04<01:19,  2.39it/s][A
sampling:   5%|▌         | 10/200 [00:04<01:13,  2.60it/s][A
sampling:   6%|▌         | 11/200 [00:04<01:08,  2.77it/s][A
sampling:   6%|▌         | 12/200 [00:04<01:05,  2.86it/s][A
sampling:   6%|▋         | 13/200 [00:05<01:02,  2.98it/s][A
sampling:   7%|▋         | 14/200 [00:05<01:00,  3.05it/s][A
sampling:   8

[val] epoch 5: MSE_eps=0.058867 | MSE_x0=0.039373 | FID=135.84


epoch 6/10:  60%|██████    | 269/446 [09:22<06:09,  2.08s/it, loss=0.0873]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:09,  2.85it/s][A
sampling:   1%|          | 2/200 [00:00<01:07,  2.92it/s][A
sampling:   2%|▏         | 3/200 [00:01<01:04,  3.04it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:02,  3.12it/s][A
sampling:   2%|▎         | 5/200 [00:01<01:01,  3.18it/s][A
sampling:   3%|▎         | 6/200 [00:01<01:03,  3.06it/s][A
sampling:   4%|▎         | 7/200 [00:02<01:04,  3.01it/s][A
sampling:   4%|▍         | 8/200 [00:02<01:04,  3.00it/s][A
sampling:   4%|▍         | 9/200 [00:02<01:03,  3.02it/s][A
sampling:   5%|▌         | 10/200 [00:03<01:01,  3.09it/s][A
sampling:   6%|▌         | 11/200 [00:03<00:59,  3.15it/s][A
sampling:   6%|▌         | 12/200 [00:03<01:00,  3.09it/s][A
sampling:   6%|▋         | 13/200 [00:04<01:03,  2.96it/s][A
sampling:   7%|▋         | 14/200 [00:04<01:12,  2.57it/s][A
sampling:   8

[val] epoch 6: MSE_eps=0.057027 | MSE_x0=0.036565 | FID=142.79


epoch 7/10:  72%|███████▏  | 323/446 [11:13<04:28,  2.18s/it, loss=0.1727]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:03,  3.15it/s][A
sampling:   1%|          | 2/200 [00:00<01:01,  3.21it/s][A
sampling:   2%|▏         | 3/200 [00:00<01:04,  3.07it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:05,  2.98it/s][A
sampling:   2%|▎         | 5/200 [00:01<01:07,  2.90it/s][A
sampling:   3%|▎         | 6/200 [00:02<01:06,  2.90it/s][A
sampling:   4%|▎         | 7/200 [00:02<01:06,  2.90it/s][A
sampling:   4%|▍         | 8/200 [00:02<01:07,  2.86it/s][A
sampling:   4%|▍         | 9/200 [00:03<01:06,  2.89it/s][A
sampling:   5%|▌         | 10/200 [00:03<01:05,  2.90it/s][A
sampling:   6%|▌         | 11/200 [00:03<01:05,  2.90it/s][A
sampling:   6%|▌         | 12/200 [00:04<01:02,  3.01it/s][A
sampling:   6%|▋         | 13/200 [00:04<01:00,  3.11it/s][A
sampling:   7%|▋         | 14/200 [00:04<01:00,  3.10it/s][A
sampling:   8

[val] epoch 7: MSE_eps=0.054594 | MSE_x0=0.036943 | FID=142.20


epoch 8/10:  85%|████████▍ | 377/446 [13:51<02:22,  2.07s/it, loss=0.1410]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:16,  2.62it/s][A
sampling:   1%|          | 2/200 [00:00<01:13,  2.71it/s][A
sampling:   2%|▏         | 3/200 [00:01<01:10,  2.81it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:07,  2.90it/s][A
sampling:   2%|▎         | 5/200 [00:01<01:12,  2.68it/s][A
sampling:   3%|▎         | 6/200 [00:02<01:23,  2.34it/s][A
sampling:   4%|▎         | 7/200 [00:02<01:28,  2.19it/s][A
sampling:   4%|▍         | 8/200 [00:03<01:27,  2.18it/s][A
sampling:   4%|▍         | 9/200 [00:03<01:29,  2.15it/s][A
sampling:   5%|▌         | 10/200 [00:04<01:28,  2.14it/s][A
sampling:   6%|▌         | 11/200 [00:04<01:28,  2.13it/s][A
sampling:   6%|▌         | 12/200 [00:05<01:26,  2.18it/s][A
sampling:   6%|▋         | 13/200 [00:05<01:27,  2.14it/s][A
sampling:   7%|▋         | 14/200 [00:06<01:28,  2.10it/s][A
sampling:   8

[val] epoch 8: MSE_eps=0.053356 | MSE_x0=0.034980 | FID=135.50


epoch 9/10:  97%|█████████▋| 431/446 [15:09<00:32,  2.15s/it, loss=0.0944]
sampling:   0%|          | 0/200 [00:00<?, ?it/s][A
sampling:   0%|          | 1/200 [00:00<01:03,  3.12it/s][A
sampling:   1%|          | 2/200 [00:00<01:01,  3.21it/s][A
sampling:   2%|▏         | 3/200 [00:00<01:00,  3.24it/s][A
sampling:   2%|▏         | 4/200 [00:01<01:01,  3.20it/s][A
sampling:   2%|▎         | 5/200 [00:01<01:00,  3.24it/s][A
sampling:   3%|▎         | 6/200 [00:01<00:59,  3.27it/s][A
sampling:   4%|▎         | 7/200 [00:02<01:00,  3.22it/s][A
sampling:   4%|▍         | 8/200 [00:02<00:59,  3.24it/s][A
sampling:   4%|▍         | 9/200 [00:02<00:58,  3.26it/s][A
sampling:   5%|▌         | 10/200 [00:03<00:58,  3.22it/s][A
sampling:   6%|▌         | 11/200 [00:03<00:58,  3.23it/s][A
sampling:   6%|▌         | 12/200 [00:03<00:57,  3.27it/s][A
sampling:   6%|▋         | 13/200 [00:04<00:57,  3.26it/s][A
sampling:   7%|▋         | 14/200 [00:04<00:58,  3.21it/s][A
sampling:   8

[val] epoch 9: MSE_eps=0.053928 | MSE_x0=0.035097 | FID=137.25


epoch 10/10: 100%|██████████| 446/446 [15:39<00:00,  2.11s/it, loss=0.2777]
sampling: 100%|██████████| 200/200 [01:17<00:00,  2.58it/s]
sampling: 100%|██████████| 200/200 [01:14<00:00,  2.67it/s]
sampling: 100%|██████████| 200/200 [01:12<00:00,  2.76it/s]
sampling: 100%|██████████| 200/200 [01:11<00:00,  2.78it/s]
sampling: 100%|██████████| 200/200 [01:12<00:00,  2.74it/s]
sampling: 100%|██████████| 200/200 [01:11<00:00,  2.80it/s]
sampling: 100%|██████████| 200/200 [01:12<00:00,  2.77it/s]
sampling: 100%|██████████| 200/200 [01:17<00:00,  2.59it/s]
sampling: 100%|██████████| 200/200 [01:11<00:00,  2.82it/s]
sampling: 100%|██████████| 200/200 [01:11<00:00,  2.80it/s]
sampling: 100%|██████████| 200/200 [01:12<00:00,  2.75it/s]
sampling: 100%|██████████| 200/200 [01:11<00:00,  2.79it/s]
sampling: 100%|██████████| 200/200 [01:11<00:00,  2.80it/s]
sampling: 100%|██████████| 200/200 [01:12<00:00,  2.77it/s]
sampling: 100%|██████████| 200/200 [01:11<00:00,  2.79it/s]
sampling: 100%|█████████

[val] epoch 10: MSE_eps=0.051901 | MSE_x0=0.036182 | FID=121.44


sampling: 100%|██████████| 200/200 [01:14<00:00,  2.68it/s]


Saved final samples to /content/drive/MyDrive/prototypes/mtsd_exp/mnist_mt_run1/images/samples_final.png
