# Diffusion-regularised UNet Autoencoder

Modular training and evaluation notebook: autoencoder (AE) warm-start, joint decoder–denoiser training, evaluation metrics, visualisations, and ablation studies.

In [None]:
# CELL 0: Environment setup (Colab, optional)
import subprocess
import sys

IN_COLAB = 'google.colab' in sys.modules


def pip_install_cmd(*args):
    """Build the argv for a quiet ``pip install`` run by this kernel's interpreter.

    ``sys.executable -m pip`` targets the same environment the notebook imports
    from (a bare ``pip`` on PATH may belong to a different Python), and an
    argument list never goes through a shell, so requirement extras such as
    ``diffusers[torch]`` reach pip verbatim instead of being treated as a
    glob pattern.

    Args:
        *args: Requirement specifiers and/or pip options, e.g. ``"--upgrade", "pip"``.

    Returns:
        The command as a list of strings suitable for ``subprocess``.
    """
    return [sys.executable, "-m", "pip", "install", "-q", *args]


def pip_install(*args):
    """Run ``pip install -q *args`` and fail loudly if pip fails.

    A ``!pip`` shell escape ignores pip's exit status. This function raises
    instead, so a broken install stops the notebook here rather than
    surfacing later as an ImportError.

    Raises:
        subprocess.CalledProcessError: if pip exits with a non-zero status.
    """
    subprocess.check_call(pip_install_cmd(*args))


# Only install on Colab; local environments are expected to be set up already.
if IN_COLAB:
    pip_install("--upgrade", "pip")
    pip_install("torch", "torchvision", "torchaudio")
    pip_install("diffusers[torch]", "transformers", "accelerate", "einops")
    pip_install(
        "scikit-image", "scikit-learn", "umap-learn", "matplotlib", "seaborn",
        "tqdm", "lpips", "pytorch-msssim", "pandas",
    )
    pip_install("torch-fidelity")

In [None]:
# CELL 1: Global setup & imports
import warnings

# Silence warnings before the heavy imports below so their import-time
# warnings are hidden too.
# NOTE(review): this suppresses *every* warning for the rest of the session,
# including deprecation and numerical warnings raised during training.
# Consider narrowing it by category or module once the noisy sources are known.
warnings.filterwarnings("ignore")

import matplotlib.pyplot as plt

# Figure typography, applied globally to every figure created afterwards.
# Math text is rendered with the Computer Modern ("cm") font set; ordinary
# text uses matplotlib's generic serif family. Note that "mathtext.rm" is
# only consulted when "mathtext.fontset" is "custom", so it has no effect
# alongside the "cm" setting.
plt.rcParams["font.size"] = 11
plt.rcParams["font.family"] = "serif"
plt.rcParams["mathtext.fontset"] = "cm"
plt.rcParams["mathtext.rm"] = "serif"

# Project entry points, one per pipeline stage run by the cells that follow.
from train_ae import train_ae
from train_joint import train_joint
from eval import run_eval
from ablation import run_ablation

In [None]:
# CELL 2: Autoencoder warm-start training
# Calls the project's train_ae() with its default arguments and keeps the
# return value in `history_ae`.
# NOTE(review): judging by the name, the return value is presumably a
# per-epoch loss/metric history — confirm against train_ae. None of the cells
# shown here read `history_ae` afterwards, and later stages are called without
# it, so trained weights are presumably handed on via files on disk — verify.
history_ae = train_ae()

In [None]:
# CELL 3: Joint training (decoder + denoiser)
# Calls the project's train_joint() with its default arguments and keeps the
# return value in `history_joint`.
# NOTE(review): `history_ae` from the previous cell is not passed in, so the
# warm-started autoencoder is presumably picked up from a saved checkpoint —
# confirm in train_joint, and run CELL 2 first (or ensure that checkpoint exists).
history_joint = train_joint()

In [None]:
# CELL 4: Evaluation and visualisations (PSNR, SSIM, LPIPS, FID, grids, UMAP, histograms)
# Calls the project's run_eval() with its default arguments. The result is
# not stored; because the call is the cell's last expression, a non-None
# return value is echoed as the cell output.
# NOTE(review): no model or history is passed in, so run_eval presumably
# loads whatever the training cells saved — confirm before running this cell
# on its own.
run_eval()

In [None]:
# CELL 5: Ablation experiments (latent dims, schedules, with/without diffusion)
# WARNING: This can be computationally expensive.
# Calls the project's run_ablation() with its default arguments; any non-None
# return value is displayed as the cell output.
# NOTE(review): as the final cell, this runs on every "Run all" — consider
# guarding it behind a flag if routine reruns should skip the ablations.
run_ablation()