# Training Analysis

Post-training analysis of the VAE + LCFM pipeline.

## Table of Contents
1. [Setup](#setup)
2. [VAE Results](#vae-results)
   - 2.1. [Loss Curves](#vae-loss)
   - 2.2. [Reconstruction & Residuals](#vae-recon)
   - 2.3. [Random Samples](#vae-random)
3. [LCFM Results](#lcfm-results)
   - 3.1. [Loss Curves](#lcfm-loss)
   - 3.2. [Sample Quality Progression](#lcfm-prog)
   - 3.3. [Conditional Generation](#lcfm-gen)
   - 3.4. [Random Generation](#lcfm-random)
   - 3.5. [Flow Trajectory](#lcfm-traj)
4. [Summary](#summary)
   - [Latent Space Visualization](#latent-pca)
   - [Per-Band Pixel Distributions](#pixel-dist)
   - [Summary Statistics](#summary-stats)


## 1. Setup <a id="setup"></a>

In [None]:
%matplotlib widget

import torch
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

from galgenai.models import VAE, VAEEncoder, LCFM
from galgenai.data.hsc import get_dataset_and_loaders
from galgenai import get_device
from datasets import load_from_disk


def strip_compile_prefix(state_dict):
    """Return a copy of ``state_dict`` without torch.compile() key prefixes.

    A key that starts with ``"_orig_mod."`` loses exactly one leading
    occurrence of that prefix; every other key is kept unchanged. Values
    are passed through untouched and the original key order is preserved.
    """
    marker = "_orig_mod."
    cleaned = {}
    for name, value in state_dict.items():
        if name.startswith(marker):
            name = name[len(marker):]
        cleaned[name] = value
    return cleaned

In [None]:
# ---- USER-EDITABLE PARAMETERS ----
# The names defined in this cell are read by the later cells of this
# notebook, so edit the values here rather than inside those cells.

# Paths
data_path = "../data/hsc_mmu_mini/"  # dataset directory given to load_from_disk
output_dir = Path("../pipeline_output")  # root of checkpoints, encoder.pt, samples

# Model hyperparameters (must match training)
in_channels = 5  # presumably one channel per band in `bands` -- TODO confirm
latent_dim = 32
input_size = 64  # passed as `nx` to the data loader and `input_size` to models
lcfm_base_channels = 64
lcfm_beta = 0.001

# Visualization parameters
n_plot = 4  # number of galaxies (rows) in the VAE and random-generation grids
n_gen_lcfm = 7  # number of LCFM samples drawn per conditioning batch
num_ode_steps = 50  # number of integration steps used when sampling the LCFM

# Band names and colormap
bands = [f"hsc-{b}" for b in "grizy"]  # "hsc-g", "hsc-r", ..., "hsc-y"
n_bands = len(bands)
cmap = plt.get_cmap("magma")
# Invalid pixels (e.g. NaN from dividing by a zero mask) are drawn mid-grey.
cmap.set_bad("0.5")

In [None]:
# Select the compute device and build the dataset plus its data loaders.
device = get_device()
print(f"Using device: {device}")

# The raw dataset is read from `data_path`; `nx` is set to the model input
# size. NOTE(review): `split=0.8` is presumably the training fraction --
# confirm against get_dataset_and_loaders.
dataset_raw = load_from_disk(data_path)
dataset, train_loader, val_loader = get_dataset_and_loaders(
    dataset_raw,
    nx=input_size,
    batch_size=128,
    num_workers=0,
    split=0.8,
)
print(f"Dataset size: {len(dataset)} samples")
print(f"Train batches: {len(train_loader)}, Val batches: {len(val_loader)}")

In [None]:
# Restore the VAE from its best checkpoint and report basic statistics.
# The checkpoint is read as a dict with "model_state_dict", "loss_history",
# "global_step" and "best_loss" entries.
vae_ckpt_path = output_dir / "vae" / "checkpoints" / "best.pt"
vae_ckpt = torch.load(vae_ckpt_path, map_location=device, weights_only=False)
vae_loss_history = vae_ckpt["loss_history"]

# Rebuild the architecture with the training hyperparameters, then load the
# weights after removing any torch.compile() key prefix.
vae = VAE(in_channels=in_channels, latent_dim=latent_dim, input_size=input_size)
vae_weights = strip_compile_prefix(vae_ckpt["model_state_dict"])
vae.load_state_dict(vae_weights)
vae.to(device).eval()

n_vae_params = sum(param.numel() for param in vae.parameters())
print(f"VAE parameters: {n_vae_params:,}")
print(f"VAE training steps: {vae_ckpt['global_step']}")
print(f"VAE best loss: {vae_ckpt['best_loss']:.4e}")

In [None]:
# Rebuild the LCFM around a VAE encoder and restore its best checkpoint.
# The encoder weights come from a separate file; the LCFM checkpoint is read
# as a dict with "model_state_dict", "loss_history", "global_step" and
# "best_loss" entries.
encoder_path = output_dir / "encoder.pt"
encoder = VAEEncoder(
    in_channels=in_channels,
    latent_dim=latent_dim,
    input_size=input_size,
)
encoder_weights = torch.load(encoder_path, map_location=device)
encoder.load_state_dict(encoder_weights)

lcfm_ckpt_path = output_dir / "lcfm" / "checkpoints" / "best.pt"
lcfm_ckpt = torch.load(lcfm_ckpt_path, map_location=device, weights_only=False)
lcfm_loss_history = lcfm_ckpt["loss_history"]

lcfm = LCFM(
    vae_encoder=encoder,
    latent_dim=latent_dim,
    in_channels=in_channels,
    input_size=input_size,
    base_channels=lcfm_base_channels,
    beta=lcfm_beta,
)
lcfm_weights = strip_compile_prefix(lcfm_ckpt["model_state_dict"])
lcfm.load_state_dict(lcfm_weights)
lcfm.to(device).eval()

# Count all parameters and the subset that still requires gradients.
lcfm_param_sizes = [(p.numel(), p.requires_grad) for p in lcfm.parameters()]
n_lcfm_total = sum(size for size, _ in lcfm_param_sizes)
n_lcfm_trainable = sum(size for size, trainable in lcfm_param_sizes if trainable)
print(f"LCFM total parameters: {n_lcfm_total:,}")
print(f"LCFM trainable parameters: {n_lcfm_trainable:,}")
print(f"LCFM training steps: {lcfm_ckpt['global_step']}")
print(f"LCFM best loss: {lcfm_ckpt['best_loss']:.4e}")

## 2. VAE Results <a id="vae-results"></a>

### 2.1. Loss Curves <a id="vae-loss"></a>

In [None]:
# VAE loss curves (epoch-based, train + val)
# Each history entry is read as a dict holding "epoch", "total_loss",
# "recon_loss" and "kl_loss"; the "val_*" keys are optional per entry.
vae_epochs = [e["epoch"] for e in vae_loss_history]

vae_total = [e["total_loss"] for e in vae_loss_history]
vae_recon = [e["recon_loss"] for e in vae_loss_history]
vae_kl = [e["kl_loss"] for e in vae_loss_history]

vae_val_total = [
    e["val_total_loss"] for e in vae_loss_history if "val_total_loss" in e
]
vae_val_recon = [
    e["val_recon_loss"] for e in vae_loss_history if "val_recon_loss" in e
]
vae_val_kl = [e["val_kl_loss"] for e in vae_loss_history if "val_kl_loss" in e]
# Epochs at which a validation pass was recorded (x-axis of the val curves).
vae_val_epochs = [
    e["epoch"] for e in vae_loss_history if "val_total_loss" in e
]

fig, axs = plt.subplots(1, 3, figsize=(12, 3.5))

for ax, train, val, val_ep, title in zip(
    axs,
    [vae_total, vae_recon, vae_kl],
    [vae_val_total, vae_val_recon, vae_val_kl],
    [vae_val_epochs] * 3,
    ["Total Loss", "Recon Loss", "KL Divergence"],
    strict=True,
):
    # Draw both curves with semilogy so the y-axis is logarithmic whether or
    # not validation entries exist, matching the LCFM loss figure. Note that
    # non-positive loss values cannot be shown on a log axis.
    ax.semilogy(vae_epochs, train, label="Train")
    if val:
        ax.semilogy(val_ep, val, label="Val")
    ax.set_xlabel("Epoch")
    ax.set_title(title)
    ax.legend()

fig.suptitle("VAE Training Loss", fontsize=14)
fig.tight_layout()

### 2.2. Reconstruction & Residuals <a id="vae-recon"></a>

Compare with [3.3. LCFM Conditional Generation](#lcfm-gen).

In [None]:
# VAE reconstruction grid: every galaxy gets a "Data" row followed by a
# "Reconst." row, with one column per band.
plt.close("all")

# First validation batch, stacked to [3, batch, 5, 64, 64] and cut down to
# the first n_plot galaxies; index 0 along the first axis is fed to the VAE.
imgs = torch.stack(next(iter(val_loader)))[:, :n_plot]
with torch.no_grad():
    recs, _, _ = vae(imgs[0].to(device))

fig, axs = plt.subplots(2 * n_plot, n_bands, figsize=(8, 10))

# Band names label the top row only.
for j, band in enumerate(bands):
    axs[0, j].set_title(band)

for i in range(n_plot):
    img, _, msk = imgs[:, i].cpu()
    rec = recs[i].cpu()
    data_row, rec_row = axs[2 * i], axs[2 * i + 1]
    for j in range(n_bands):
        # Both panels share the colour range of the data image so they can
        # be compared directly; dividing by the mask follows the original
        # masking convention of this notebook.
        shared = dict(
            origin="lower",
            cmap=cmap,
            vmin=img[j].min(),
            vmax=img[j].max(),
        )
        data_row[j].imshow(img[j] / msk[j], **shared)
        rec_row[j].imshow(rec[j] / msk[j], **shared)
    data_row[0].set_ylabel("Data")
    rec_row[0].set_ylabel("Reconst.")

for ax in axs.flat:
    ax.xaxis.set_ticks([])
    ax.yaxis.set_ticks([])
fig.suptitle("Reconstruction vs truth (validation set)")
fig.subplots_adjust(left=0.05, right=0.95, bottom=0.02, top=0.9)

In [None]:
# Residual maps: (data - reconstruction) * mask
# Reuses `imgs` and `recs` produced by the reconstruction-grid cell.
cmap_resid = plt.get_cmap("RdBu_r")
cmap_resid.set_bad("0.5")

fig, axs = plt.subplots(n_plot, n_bands, figsize=(8, 6))

# Band names label the top row only.
for j, band in enumerate(bands):
    axs[0, j].set_title(band)

for i in range(n_plot):
    img, _, msk = imgs[:, i].cpu()
    rec = recs[i].cpu()
    for j in range(n_bands):
        residual = (img[j] - rec[j]) * msk[j]
        # Symmetric colour limits keep zero residual at the colormap centre.
        limit = residual.abs().max().item()
        axs[i, j].imshow(
            residual,
            origin="lower",
            cmap=cmap_resid,
            vmin=-limit,
            vmax=limit,
        )
    axs[i, 0].set_ylabel(f"Galaxy #{i + 1}")

for ax in axs.flat:
    ax.xaxis.set_ticks([])
    ax.yaxis.set_ticks([])
fig.suptitle("Residuals: (Data - Reconstruction) * Mask")
fig.subplots_adjust(left=0.05, right=0.95, bottom=0.02, top=0.9)

### 2.3. Random Samples <a id="vae-random"></a>

Compare with [3.4. LCFM Random Generation](#lcfm-random).

In [None]:
# VAE random generation from N(0, I) prior: one row per generated galaxy,
# one column per band, each panel on its own colour scale.
fig, axs = plt.subplots(n_plot, n_bands, figsize=(8, 6))
with torch.no_grad():
    gen_imgs = vae.generate(n_plot, device)

# Band names label the top row only.
for j, band in enumerate(bands):
    axs[0, j].set_title(band)

for i in range(n_plot):
    img = gen_imgs[i].cpu()
    for j in range(n_bands):
        axs[i, j].imshow(img[j], origin="lower", cmap=cmap)
    axs[i, 0].set_ylabel(f"Galaxy #{i + 1}")

for ax in axs.flat:
    ax.xaxis.set_ticks([])
    ax.yaxis.set_ticks([])
fig.suptitle("Randomly generated images")
fig.subplots_adjust(left=0.05, right=0.95, bottom=0.02, top=0.9)

## 3. LCFM Results <a id="lcfm-results"></a>

### 3.1. Loss Curves <a id="lcfm-loss"></a>

In [None]:
# LCFM loss curves (step-based, train + val)
# Each history entry is read as a dict holding "step", "total_loss",
# "flow_loss" and "kl_loss"; the "val_*" keys are optional per entry.
lcfm_steps = [e["step"] for e in lcfm_loss_history]
lcfm_total = [e["total_loss"] for e in lcfm_loss_history]
lcfm_flow = [e["flow_loss"] for e in lcfm_loss_history]
lcfm_kl = [e["kl_loss"] for e in lcfm_loss_history]

lcfm_val_total = [
    e["val_total_loss"] for e in lcfm_loss_history if "val_total_loss" in e
]
lcfm_val_flow = [
    e["val_flow_loss"] for e in lcfm_loss_history if "val_flow_loss" in e
]
lcfm_val_kl = [
    e["val_kl_loss"] for e in lcfm_loss_history if "val_kl_loss" in e
]
# Steps at which a validation pass was recorded (x-axis of the val curves).
lcfm_val_steps = [
    e["step"] for e in lcfm_loss_history if "val_total_loss" in e
]

# One panel per loss term: (title, training series, validation series).
lcfm_panels = [
    ("Total Loss", lcfm_total, lcfm_val_total),
    ("Flow Loss", lcfm_flow, lcfm_val_flow),
    ("KL Divergence", lcfm_kl, lcfm_val_kl),
]

fig, axs = plt.subplots(1, 3, figsize=(12, 3.5))

for ax, (title, train, val) in zip(axs, lcfm_panels, strict=True):
    ax.semilogy(lcfm_steps, train, label="Train")
    # The validation curve is drawn only when validation entries exist.
    if val:
        ax.semilogy(lcfm_val_steps, val, label="Val")
    ax.set_xlabel("Step")
    ax.set_title(title)
    ax.legend()

fig.suptitle("LCFM Training Loss", fontsize=14)
fig.tight_layout()

### 3.2. Sample Quality Progression <a id="lcfm-prog"></a>

In [None]:
# Load pre-saved LCFM samples at different training steps
sample_dir = output_dir / "lcfm" / "samples"


def _sample_step_key(path):
    """Sort key ordering sample files by the step number in their name.

    The text after the last underscore of the file stem is read as the
    training step. Numeric steps sort first, in numeric order, so that e.g.
    step 200 precedes step 1000; any other suffix sorts afterwards,
    alphabetically.
    """
    suffix = path.stem.rsplit("_", 1)[-1]
    if suffix.isdecimal():
        return (0, int(suffix), "")
    return (1, 0, suffix)


# A plain sorted() would order the paths as strings, which misorders steps
# whose numbers have different digit counts.
sample_files = sorted(
    sample_dir.glob("samples_step_*.pt"),
    key=_sample_step_key,
)
print(f"Found {len(sample_files)} sample files")

n_steps_show = len(sample_files)
# Show first sample from each step
i_show = 0

if not sample_files:
    # plt.subplots() rejects a grid with zero rows, so skip the figure.
    print(f"No sample files found in {sample_dir}; skipping plot")
else:
    # One row per training step; the left half of the columns shows the
    # conditioning image and the right half the generated sample.
    # squeeze=False keeps `axs` two-dimensional even for a single file.
    fig, axs = plt.subplots(
        n_steps_show,
        2 * n_bands,
        figsize=(14, 2 * n_steps_show),
        squeeze=False,
    )

    for row, sf in enumerate(sample_files):
        # Each file is read as a dict with "samples" and "conditioning".
        # NOTE(review): weights_only=False unpickles arbitrary objects; only
        # load sample files written by your own training runs.
        data = torch.load(sf, map_location="cpu", weights_only=False)
        samp = data["samples"][i_show]  # [5, 64, 64]
        cond = data["conditioning"][i_show]  # [5, 64, 64]

        step_num = sf.stem.split("_")[-1]

        for j in range(n_bands):
            # Both panels share the conditioning image's colour range.
            vmin, vmax = cond[j].min(), cond[j].max()
            axs[row, j].imshow(
                cond[j],
                origin="lower",
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
            )
            axs[row, n_bands + j].imshow(
                samp[j],
                origin="lower",
                cmap=cmap,
                vmin=vmin,
                vmax=vmax,
            )
            if row == 0:
                axs[0, j].set_title(bands[j])
                axs[0, n_bands + j].set_title(bands[j])

        axs[row, 0].set_ylabel(f"Step {step_num}")

    for ax in axs.flatten():
        ax.xaxis.set_ticks([])
        ax.yaxis.set_ticks([])

    fig.text(0.27, 0.99, "Conditioning", ha="center", va="top", fontsize=13)
    fig.text(0.73, 0.99, "Generated", ha="center", va="top", fontsize=13)
    fig.suptitle("LCFM Sample Quality Progression", fontsize=14, y=1.02)
    fig.tight_layout()

### 3.3. Conditional Generation <a id="lcfm-gen"></a>

Compare with [2.2. VAE Reconstruction & Residuals](#vae-recon).

In [None]:
# LCFM conditional samples: the first validation image is repeated
# n_gen_lcfm times and used as the conditioning for every generated sample.
con_img = torch.stack(next(iter(val_loader)))[0, 0]
con_imgs = torch.stack([con_img] * n_gen_lcfm).to(device)

with torch.no_grad():
    lcfm_gen = lcfm.sample(con_imgs, num_steps=num_ode_steps)

# Row 0 shows the conditioning image, the remaining rows the samples.
fig, axs = plt.subplots(1 + n_gen_lcfm, n_bands, figsize=(8, 10))
for i in range(1 + n_gen_lcfm):
    is_original = i == 0
    img = con_img.cpu() if is_original else lcfm_gen[i - 1].cpu()
    for j, band in enumerate(bands):
        # Every row uses the conditioning image's per-band colour range.
        reference = con_img[j].cpu()
        axs[i, j].imshow(
            img[j],
            origin="lower",
            cmap=cmap,
            vmin=reference.min(),
            vmax=reference.max(),
        )
        if is_original:
            axs[0, j].set_title(band)
    axs[i, 0].set_ylabel("Original" if is_original else "Generated")

for ax in axs.flat:
    ax.xaxis.set_ticks([])
    ax.yaxis.set_ticks([])
fig.suptitle("Generated galaxies vs their conditional")
fig.subplots_adjust(left=0.05, right=0.95, bottom=0.02, top=0.9)

### 3.4. Random Generation <a id="lcfm-random"></a>

Compare with [2.3. VAE Random Samples](#vae-random).

In [None]:
# LCFM unconditional generation: sample latents from N(0, I)
# instead of encoding a real image
with torch.no_grad():
    # Latent vectors and the starting noise images are both standard normal.
    f = torch.randn(n_plot, latent_dim, device=device)
    x = torch.randn(
        n_plot,
        in_channels,
        input_size,
        input_size,
        device=device,
    )
    # Fixed-step Euler integration of the velocity field over [0, 1).
    # NOTE(review): assumes velocity_net takes (x, f, t) with the same time
    # convention as lcfm.sample -- confirm against the model code.
    dt = 1.0 / num_ode_steps
    for t_now in (step * dt for step in range(num_ode_steps)):
        t = torch.full((n_plot,), t_now, device=device)
        x = x + lcfm.velocity_net(x, f, t) * dt
    lcfm_random = x

fig, axs = plt.subplots(n_plot, n_bands, figsize=(8, 6))

# Band names label the top row only.
for j, band in enumerate(bands):
    axs[0, j].set_title(band)

for i in range(n_plot):
    img = lcfm_random[i].cpu()
    for j in range(n_bands):
        axs[i, j].imshow(img[j], origin="lower", cmap=cmap)
    axs[i, 0].set_ylabel(f"Galaxy #{i + 1}")

for ax in axs.flat:
    ax.xaxis.set_ticks([])
    ax.yaxis.set_ticks([])
fig.suptitle("LCFM unconditional generation (prior latents)")
fig.subplots_adjust(left=0.05, right=0.95, bottom=0.02, top=0.9)

### 3.5. Flow Trajectory <a id="lcfm-traj"></a>

In [None]:
# ODE trajectory visualization (single band: i-band, index 2)
# The first validation image, kept as a batch of one, is the conditioning.
cond_img = torch.stack(next(iter(val_loader)))[0, 0:1].to(device)

with torch.no_grad():
    final, trajectory = lcfm.sample(
        cond_img,
        num_steps=num_ode_steps,
        return_trajectory=True,
    )

# Pick ~8 evenly spaced time steps from trajectory
n_traj_show = 8
band_idx = 2  # i-band
last_idx = len(trajectory) - 1
indices = np.linspace(0, last_idx, n_traj_show, dtype=int)

fig, axs = plt.subplots(
    1,
    1 + n_traj_show,
    figsize=(2 * (1 + n_traj_show), 2),
)

# The first panel holds the conditioning image, the rest the snapshots.
cond_ax, *traj_axes = axs
cond_ax.imshow(cond_img[0, band_idx].cpu(), origin="lower", cmap=cmap)
cond_ax.set_title("Cond.")

# Each snapshot is labelled with its position along the trajectory,
# expressed as a fraction between 0 and 1.
for ax, idx in zip(traj_axes, indices):
    ax.imshow(
        trajectory[idx][0, band_idx].cpu(),
        origin="lower",
        cmap=cmap,
    )
    ax.set_title(f"t={idx / last_idx:.2f}")

for ax in axs:
    ax.xaxis.set_ticks([])
    ax.yaxis.set_ticks([])
fig.suptitle(f"Flow Trajectory ({bands[band_idx]})", fontsize=13)
fig.tight_layout()

## 4. Summary <a id="summary"></a>

### Latent Space Visualization <a id="latent-pca"></a>

In [None]:
# PCA of VAE latent space
# Encode validation images until at least n_collect latent means are
# gathered, then project them onto their first two principal components and
# colour each point by the total flux of its image.
all_mu = []
all_flux = []
n_collect = 500
n_seen = 0

vae.eval()
with torch.no_grad():
    for batch in val_loader:
        data = torch.stack(batch)[0].to(device)
        mu, _ = vae.encoder(data)
        # Sum over every non-batch axis: one flux value per image.
        total_flux = data.sum(dim=(1, 2, 3))
        all_mu.append(mu.cpu())
        all_flux.append(total_flux.cpu())
        n_seen += mu.shape[0]
        if n_seen >= n_collect:
            break

all_mu = torch.cat(all_mu)[:n_collect]
all_flux = torch.cat(all_flux)[:n_collect]

# PCA via torch.pca_lowrank
mu_centered = all_mu - all_mu.mean(dim=0)
U, S, V = torch.pca_lowrank(mu_centered, q=2)
projected = mu_centered @ V  # [n_collect, 2]

# Explained variance ratio
# Normalise by the variance of ALL latent dimensions (the squared Frobenius
# norm of the centred matrix). Summing only the two retained singular values
# would make PC1 + PC2 add up to 100% regardless of the data.
total_var = mu_centered.pow(2).sum()
explained = (S**2) / total_var

fig, ax = plt.subplots(figsize=(6, 5))
# NOTE(review): log10 yields NaN/-inf for images whose total flux is not
# positive -- confirm whether such images can occur in this dataset.
sc = ax.scatter(
    projected[:, 0].numpy(),
    projected[:, 1].numpy(),
    c=torch.log10(all_flux).numpy(),
    s=8,
    alpha=0.7,
    cmap="viridis",
)
fig.colorbar(sc, ax=ax, label="log10(total flux)")
ax.set_xlabel(f"PC1 ({explained[0]:.1%} var)")
ax.set_ylabel(f"PC2 ({explained[1]:.1%} var)")
ax.set_title("VAE Latent Space (PCA)")
fig.tight_layout()

### Per-Band Pixel Distributions <a id="pixel-dist"></a>

In [None]:
# Pixel value histograms: real vs VAE recon vs VAE random vs LCFM
# Collect a batch of real, reconstructed, random, and LCFM images
val_batch = torch.stack(next(iter(val_loader)))
real_data = val_batch[0].cpu()  # [batch, 5, 64, 64]

with torch.no_grad():
    vae_recs, _, _ = vae(real_data.to(device))
    vae_recs = vae_recs.cpu()
    vae_rand = vae.generate(real_data.shape[0], device).cpu()

    # Only the first n_gen_lcfm real images are used as LCFM conditioning.
    lcfm_cond = real_data[:n_gen_lcfm].to(device)
    lcfm_samp = lcfm.sample(lcfm_cond, num_steps=num_ode_steps).cpu()

# Histogram label -> image set, in drawing order.
hist_sources = [
    ("Real", real_data),
    ("VAE Recon", vae_recs),
    ("VAE Random", vae_rand),
    ("LCFM", lcfm_samp),
]

fig, axs = plt.subplots(1, n_bands, figsize=(14, 3))
for j, band in enumerate(bands):
    pixels = {
        label: images[:, j].flatten().numpy() for label, images in hist_sources
    }
    # All histograms of a band share bins spanning 0 to the 99th percentile
    # of the real pixel values.
    p99 = np.percentile(pixels["Real"], 99)
    bin_edges = np.linspace(0, p99, 80)

    for label, values in pixels.items():
        axs[j].hist(
            values,
            bins=bin_edges,
            density=True,
            alpha=0.5,
            label=label,
        )
    axs[j].set_yscale("log")
    axs[j].set_title(band)
    if j == 0:
        axs[j].legend(fontsize=7)

fig.suptitle("Per-Band Pixel Distributions", fontsize=14)
fig.tight_layout()

### Summary Statistics <a id="summary-stats"></a>

In [None]:
# Summary statistics
def _print_run_stats(ckpt, history):
    """Print step count, best loss and final losses of one training run.

    ``ckpt`` is read for "global_step" and "best_loss"; the last entry of
    ``history`` (if any) is read for "total_loss" and, when present,
    "val_total_loss".
    """
    print(f"  Training steps: {ckpt['global_step']}")
    print(f"  Best loss: {ckpt['best_loss']:.4e}")
    if not history:
        return
    final_entry = history[-1]
    print(f"  Final train loss: {final_entry['total_loss']:.4e}")
    if "val_total_loss" in final_entry:
        print(f"  Final val loss: {final_entry['val_total_loss']:.4e}")


banner = "=" * 60
print(banner)
print("TRAINING ANALYSIS SUMMARY")
print(banner)

print("\nVAE")
print(f"  Parameters: {n_vae_params:,}")
_print_run_stats(vae_ckpt, vae_loss_history)

print("\nLCFM")
print(f"  Total parameters: {n_lcfm_total:,}")
print(f"  Trainable parameters: {n_lcfm_trainable:,}")
_print_run_stats(lcfm_ckpt, lcfm_loss_history)

# Locations of the artifacts this notebook reads.
print("\nArtifact Paths")
artifact_paths = [
    ("VAE ckpts", output_dir / "vae" / "checkpoints"),
    ("Encoder", output_dir / "encoder.pt"),
    ("LCFM ckpts", output_dir / "lcfm" / "checkpoints"),
    ("LCFM samples", output_dir / "lcfm" / "samples"),
]
for label, location in artifact_paths:
    print(f"  {label}: {location}")