# MM-Reg: Photonics Dataset Experiment

Baseline VAE vs MM-Reg VAE for latent diffusion on binary photonics images.

## Architecture
```
Image (1x64x64) -> SimpleConvVAE -> 64-d latent -> MLP Diffusion -> Generate
```

## Pipeline
1. **Load Data**: Load `imagenorm.npy`, resize to 64x64, threshold to binary
2. **PCA Reference**: Compute PCA embeddings for MM-Reg
3. **Train VAEs**: Baseline (no MM-Reg) vs MM-Reg
4. **Evaluate VAEs**: Reconstruction quality, distance correlation
5. **Train Diffusion**: MLP denoiser on 64-d latents from each VAE
6. **Generate & Compare**: Sample quality from both pipelines

**Hypothesis**: MM-Reg creates a smoother latent space, so a diffusion model trained on its latents learns faster and generates better samples.

## 1. Setup

In [None]:
import sys, os
sys.path.insert(0, os.path.abspath('..'))

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import json
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import tqdm
from scipy.stats import pearsonr, spearmanr
from sklearn.decomposition import PCA

from src.models.losses import MMRegLoss, pairwise_distances, get_upper_triangular
from src.models.diffusion import MLPDenoiser, GaussianDiffusion
from src.diffusion_trainer import DiffusionTrainer

# Pick the compute device once; later cells move models and tensors to it.
use_gpu = torch.cuda.is_available()
device = 'cuda' if use_gpu else 'cpu'
print(f"Device: {device}")
if use_gpu:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Experiment hyperparameters, grouped by pipeline stage.
CONFIG = dict(
    # Data
    image_size=64,        # images are resized to image_size x image_size
    train_split=0.8,      # fraction of samples used for training
    batch_size=32,

    # PCA
    pca_components=64,    # dimensionality of the MM-Reg reference embedding

    # VAE
    latent_dim=64,
    vae_epochs=200,
    vae_lr=1e-3,
    lambda_mm=1.0,        # weight of the MM-Reg term (only applied when MM-Reg is on)
    beta=0.001,           # weight of the KL term

    # Diffusion
    diffusion_epochs=300,
    diffusion_lr=1e-4,
    diffusion_timesteps=1000,
)

for key, value in CONFIG.items():
    print(f"  {key}: {value}")

## 2. Load & Preprocess Photonics Data

In [None]:
# Load raw data
raw = np.load('../imagenorm.npy')
print(f"Raw shape: {raw.shape}, dtype: {raw.dtype}")
print(f"Range: [{raw.min():.4f}, {raw.max():.4f}]")

# Resize to 64x64 and threshold to binary
from PIL import Image

size = CONFIG['image_size']


def _resize_and_binarize(img_2d, out_size, threshold=0.5):
    """Bilinearly resize one 2-D image to (out_size, out_size) and binarize it to {0, 1}."""
    resized = Image.fromarray(img_2d).resize((out_size, out_size), Image.BILINEAR)
    # NOTE(review): the fixed 0.5 threshold assumes `raw` is normalized to
    # [0, 1]; check the printed range above.
    return (np.array(resized, dtype=np.float32) > threshold).astype(np.float32)


data = np.stack([_resize_and_binarize(img, size) for img in raw])  # (N, size, size)
print(f"Processed: {data.shape}, unique values: {np.unique(data)}")

# Convert to tensor: (N, 1, size, size)
data_tensor = torch.from_numpy(data).unsqueeze(1).float()
print(f"Tensor shape: {data_tensor.shape}")

# Train/val split. It is seeded so that re-running the notebook reproduces the
# same partition; otherwise checkpoints saved by an earlier run could be
# "validated" on images they were trained on. The seed is stored in CONFIG so
# it is recorded with the other hyperparameters.
split_seed = CONFIG.setdefault('split_seed', 0)
n = len(data_tensor)
n_train = int(n * CONFIG['train_split'])
perm = torch.randperm(n, generator=torch.Generator().manual_seed(split_seed))
train_idx, val_idx = perm[:n_train], perm[n_train:]

train_images = data_tensor[train_idx]
val_images = data_tensor[val_idx]
print(f"Train: {train_images.shape[0]}, Val: {val_images.shape[0]}")

In [None]:
# Spot-check preprocessing: the first eight training images after resize + threshold.
fig, axes = plt.subplots(1, 8, figsize=(16, 2))
for col, ax in enumerate(axes):
    sample = train_images[col, 0]
    ax.imshow(sample, cmap='gray')
    ax.axis('off')
    ax.set_title(f'#{col}')
plt.suptitle('Processed Training Samples (64x64 binary)', fontsize=12)
plt.tight_layout()
plt.show()

## 3. Compute PCA Reference Embeddings

In [None]:
# PCA reference embeddings for MM-Reg. The basis is fit on the training images
# only; validation images are projected with that same basis.
n_comp = CONFIG['pca_components']
pca = PCA(n_components=n_comp)

# One pixel vector per image.
train_flat = train_images.reshape(len(train_images), -1).numpy()
val_flat = val_images.reshape(len(val_images), -1).numpy()

train_pca = torch.from_numpy(pca.fit_transform(train_flat)).float()
val_pca = torch.from_numpy(pca.transform(val_flat)).float()

print(f"Train PCA: {train_pca.shape}, Val PCA: {val_pca.shape}")
print(f"Explained variance: {pca.explained_variance_ratio_.sum():.4f}")

## 4. Define SimpleConvVAE

Lightweight convolutional VAE for single-channel 64x64 binary images.
It encodes directly to a flat latent vector; a pretrained Stable Diffusion (SD) VAE is not needed for such simple images.

In [None]:
class SimpleConvVAE(nn.Module):
    """Convolutional VAE: 1x64x64 image -> flat latent vector -> 1x64x64 image.

    The encoder halves the spatial size four times with stride-2 convolutions
    (64 -> 4), so inputs must be single-channel 64x64 images. The decoder
    mirrors it and ends in a sigmoid, so reconstructions lie in (0, 1).

    Args:
        latent_dim: Number of latent features (size of mu, logvar and z).
    """

    def __init__(self, latent_dim=64):
        super().__init__()
        self.latent_dim = latent_dim

        # Encoder: 1x64x64 -> 256x4x4
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1), nn.BatchNorm2d(32), nn.ReLU(True),   # -> 32x32
            nn.Conv2d(32, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True),  # -> 16x16
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),# -> 8x8
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(True),# -> 4x4
        )
        self.flat_dim = 256 * 4 * 4
        self.fc_mu = nn.Linear(self.flat_dim, latent_dim)
        self.fc_logvar = nn.Linear(self.flat_dim, latent_dim)

        # Decoder: latent_dim -> 1x64x64
        self.fc_decode = nn.Linear(latent_dim, self.flat_dim)
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(True),  # -> 8x8
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(True),    # -> 16x16
            nn.ConvTranspose2d(64, 32, 4, 2, 1), nn.BatchNorm2d(32), nn.ReLU(True),     # -> 32x32
            nn.ConvTranspose2d(32, 1, 4, 2, 1), nn.Sigmoid(),                            # -> 64x64
        )

    def encode(self, x):
        """Map a (B, 1, 64, 64) batch to posterior parameters (mu, logvar), each (B, latent_dim)."""
        # flatten(start_dim=1) keeps the batch dimension: an input of the wrong
        # spatial size now fails in fc_mu instead of being silently re-chunked
        # into extra "samples" by view(-1, flat_dim).
        h = self.encoder(x).flatten(start_dim=1)
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) in training mode; return mu in eval mode."""
        if self.training:
            return mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
        return mu

    def decode(self, z):
        """Map a (B, latent_dim) batch of latents to (B, 1, 64, 64) images in (0, 1)."""
        h = self.fc_decode(z).view(-1, 256, 4, 4)
        return self.decoder(h)

    def forward(self, x):
        """Return (x_recon, mu, logvar, z) for an image batch x."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar, z

# Smoke test: push a random (2, 1, 64, 64) batch through a fresh model and
# report the parameter count and output shapes.
_vae = SimpleConvVAE(CONFIG['latent_dim'])
_recon, _mu, _logvar, _z = _vae(torch.randn(2, 1, 64, 64))
_n_params = sum(param.numel() for param in _vae.parameters())
print(f"VAE params: {_n_params:,}")
print(f"Input: (2, 1, 64, 64) -> Recon: {_recon.shape}, mu: {_mu.shape}")
del _vae, _recon, _mu, _logvar, _z, _n_params

## 5. VAE Training Function

In [None]:
# Shared MM-Reg loss instance. train_vae calls it as mm_loss_fn(z, pca_emb) when
# MM-Reg is enabled. NOTE(review): the 'correlation' variant presumably
# correlates latent and PCA pairwise distances (see src.models.losses) - confirm.
mm_loss_fn = MMRegLoss(variant='correlation')


# Smallest batch for which the MM-Reg term is computed. A batch of k samples has
# only k*(k-1)/2 pairwise distances, so k < 3 leaves fewer than two distances,
# which is too few for a correlation. (This assumes the 'correlation' MMRegLoss
# correlates pairwise distances, as the imported helpers suggest; on such a batch
# it would presumably return NaN and poison the weights through backward().)
_MIN_MM_BATCH = 3


def _mm_term(z, pca_emb, lam, device):
    """Return the MM-Reg loss for one batch, or a zero tensor (without calling
    ``mm_loss_fn``) when MM-Reg is disabled or the batch is too small."""
    if lam <= 0 or z.shape[0] < _MIN_MM_BATCH:
        return torch.tensor(0.0, device=device)
    return mm_loss_fn(z, pca_emb)


def _vae_batch_losses(model, imgs, pca_emb, beta, lam, device):
    """Run one forward pass and return ``(loss, recon, kl, mm)`` as tensors."""
    x_recon, mu, logvar, z = model(imgs)
    recon = F.mse_loss(x_recon, imgs)
    # Analytic KL(N(mu, exp(logvar)) || N(0, I)): summed over latent dims,
    # averaged over the batch.
    kl = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1).mean()
    mm = _mm_term(z, pca_emb, lam, device)
    return recon + beta * kl + lam * mm, recon, kl, mm


def _weighted_means(totals, count):
    """Divide batch-size-weighted metric sums by the sample count (NaN if no samples)."""
    if count == 0:
        return {key: float('nan') for key in totals}
    return {key: value / count for key, value in totals.items()}


def train_vae(model, train_images, train_pca, val_images, val_pca,
              config, use_mmreg=False, device='cuda', save_dir=None):
    """Train a SimpleConvVAE with optional MM-Reg.

    Each epoch shuffles the training set and minimizes
    ``recon + beta * KL + lambda_mm * MM`` with Adam. It then evaluates the
    same objective on the validation set in eval mode.

    Args:
        model: VAE whose forward returns ``(x_recon, mu, logvar, z)``.
        train_images, val_images: Image tensors, row-aligned with the PCA tensors.
        train_pca, val_pca: PCA reference embeddings (only used by MM-Reg).
        config: Dict with 'vae_lr', 'batch_size', 'beta', 'lambda_mm', 'vae_epochs'.
        use_mmreg: If False the MM-Reg weight is 0 and ``mm_loss_fn`` is never called.
        device: Device the model and batches are moved to.
        save_dir: If given, 'best.pt' (lowest val loss), 'final.pt' and
            'history.json' are written there.

    Returns:
        ``{'train': [...], 'val': [...]}``, with one dict per epoch holding the
        batch-size-weighted averages of 'loss', 'recon', 'kl' and 'mm'.

    The model is left with its final-epoch weights; 'best.pt' is only saved to disk.
    """
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=config['vae_lr'])
    bs = config['batch_size']
    beta = config['beta']
    lam = config['lambda_mm'] if use_mmreg else 0.0
    metric_names = ('loss', 'recon', 'kl', 'mm')

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    history = {'train': [], 'val': []}
    best_val = float('inf')

    for epoch in range(config['vae_epochs']):
        # --- Train ---
        model.train()
        perm = torch.randperm(len(train_images))
        totals = dict.fromkeys(metric_names, 0.0)
        n_seen = 0

        for i in range(0, len(train_images), bs):
            idx = perm[i:i + bs]
            imgs = train_images[idx].to(device)
            pca_emb = train_pca[idx].to(device)

            losses = _vae_batch_losses(model, imgs, pca_emb, beta, lam, device)

            optimizer.zero_grad()
            losses[0].backward()
            optimizer.step()

            # Weight by batch size so a short final batch counts proportionally.
            batch_n = imgs.shape[0]
            for name, value in zip(metric_names, losses):
                totals[name] += value.item() * batch_n
            n_seen += batch_n

        train_metrics = _weighted_means(totals, n_seen)
        history['train'].append(train_metrics)

        # --- Validate ---
        model.eval()
        totals = dict.fromkeys(metric_names, 0.0)
        n_seen = 0
        with torch.no_grad():
            for i in range(0, len(val_images), bs):
                imgs = val_images[i:i + bs].to(device)
                pca_emb = val_pca[i:i + bs].to(device)

                losses = _vae_batch_losses(model, imgs, pca_emb, beta, lam, device)

                batch_n = imgs.shape[0]
                for name, value in zip(metric_names, losses):
                    totals[name] += value.item() * batch_n
                n_seen += batch_n

        val_metrics = _weighted_means(totals, n_seen)
        history['val'].append(val_metrics)

        if val_metrics['loss'] < best_val:
            best_val = val_metrics['loss']
            if save_dir:
                torch.save(model.state_dict(), os.path.join(save_dir, 'best.pt'))

        if (epoch + 1) % 50 == 0 or epoch == 0:
            tag = 'MMReg' if use_mmreg else 'Baseline'
            print(f"[{tag}] Epoch {epoch+1}/{config['vae_epochs']} - "
                  f"Train: loss={train_metrics['loss']:.4f} recon={train_metrics['recon']:.4f} "
                  f"kl={train_metrics['kl']:.2f} mm={train_metrics['mm']:.4f} | "
                  f"Val: loss={val_metrics['loss']:.4f} recon={val_metrics['recon']:.4f}")

    if save_dir:
        torch.save(model.state_dict(), os.path.join(save_dir, 'final.pt'))
        with open(os.path.join(save_dir, 'history.json'), 'w') as f:
            json.dump(history, f, indent=2)

    return history

## 6. Train Baseline VAE (no MM-Reg)

In [None]:
print("="*60)
print("TRAINING BASELINE VAE (no MM-Reg)")
print("="*60)

vae_baseline = SimpleConvVAE(CONFIG['latent_dim'])
hist_baseline = train_vae(
    vae_baseline, train_images, train_pca, val_images, val_pca,
    CONFIG, use_mmreg=False, device=device,
    save_dir='../checkpoints/photonics_baseline'
)

## 7. Train MM-Reg VAE

In [None]:
print("="*60)
print("TRAINING MM-REG VAE")
print("="*60)

vae_mmreg = SimpleConvVAE(CONFIG['latent_dim'])
hist_mmreg = train_vae(
    vae_mmreg, train_images, train_pca, val_images, val_pca,
    CONFIG, use_mmreg=True, device=device,
    save_dir='../checkpoints/photonics_mmreg'
)

## 8. Evaluate VAEs

In [None]:
def evaluate_vae(model, images, pca_emb, name, device, batch_size=64,
                 max_corr_samples=200):
    """Evaluate reconstruction quality and latent/PCA distance agreement.

    The model runs in eval mode. Reconstruction is scored as per-pixel MSE of
    the raw decoder output, and as pixel accuracy after thresholding that
    output at 0.5. Distance agreement is the Pearson and Spearman correlation
    between the pairwise distances of the encoder means and those of
    ``pca_emb``, over the first ``max_corr_samples`` samples.

    Args:
        model: VAE whose forward returns ``(x_recon, mu, logvar, z)``.
        images: Image tensor, row-aligned with ``pca_emb``.
        pca_emb: PCA reference embeddings.
        name: Label used in the printed report.
        device: Device batches are moved to.
        batch_size: Batch size for the forward passes.
        max_corr_samples: Maximum number of samples used for the distance correlation.

    Returns:
        Dict with 'mse', 'pixel_acc', 'pearson', 'spearman' and 'latents'
        (the encoder means, gathered on the CPU).
    """
    model.eval()
    latents = []
    total_mse = 0.0
    total_acc = 0.0

    with torch.no_grad():
        for start in range(0, len(images), batch_size):
            batch = images[start:start + batch_size].to(device)
            x_recon, mu, _, _ = model(batch)
            total_mse += F.mse_loss(x_recon, batch, reduction='sum').item()
            # Pixel accuracy (threshold at 0.5)
            pred_binary = (x_recon > 0.5).float()
            total_acc += (pred_binary == batch).float().sum().item()
            latents.append(mu.cpu())

    all_z = torch.cat(latents, dim=0)
    pixel_count = images.numel()

    # Distance correlation on a subset: the number of pairs grows quadratically.
    n_corr = min(max_corr_samples, len(all_z))
    d_z = get_upper_triangular(pairwise_distances(all_z[:n_corr])).numpy()
    d_pca = get_upper_triangular(pairwise_distances(pca_emb[:n_corr])).numpy()
    pearson, _ = pearsonr(d_z, d_pca)
    spearman, _ = spearmanr(d_z, d_pca)

    results = {
        'mse': total_mse / pixel_count,
        'pixel_acc': total_acc / pixel_count,
        'pearson': pearson,
        'spearman': spearman,
        'latents': all_z,
    }

    print(f"\n{name} Results:")
    print(f"  Recon MSE:      {results['mse']:.6f}")
    print(f"  Pixel Accuracy: {results['pixel_acc']:.4f}")
    print(f"  Pearson corr:   {results['pearson']:.4f}")
    print(f"  Spearman corr:  {results['spearman']:.4f}")

    return results


# Evaluate both VAEs on the validation split (baseline first).
res_baseline, res_mmreg = (
    evaluate_vae(vae, val_images, val_pca, label, device)
    for vae, label in ((vae_baseline, "Baseline"), (vae_mmreg, "MM-Reg"))
)

In [None]:
# Visualize reconstructions
def plot_reconstructions(model, images, title, device, n_show=8):
    """Plot the first ``n_show`` images in three rows: original, raw
    reconstruction, and reconstruction thresholded at 0.5."""
    model.eval()
    imgs = images[:n_show].to(device)
    with torch.no_grad():
        recon, _, _, _ = model(imgs)
    imgs, recon = imgs.cpu(), recon.cpu()
    n_cols = imgs.shape[0]

    rows = [
        ('Original', imgs),
        ('Recon (raw)', recon),
        ('Recon (binary)', (recon > 0.5).float()),
    ]
    fig, axes = plt.subplots(len(rows), n_cols, figsize=(2 * n_cols, 6), squeeze=False)
    fig.suptitle(title, fontsize=14)
    for r, (label, row_imgs) in enumerate(rows):
        for c in range(n_cols):
            ax = axes[r, c]
            ax.imshow(row_imgs[c, 0], cmap='gray', vmin=0, vmax=1)
            # Hide ticks and frame instead of calling ax.axis('off'), which
            # would also hide the y-label used as the row caption.
            ax.set(xticks=[], yticks=[])
            for spine in ax.spines.values():
                spine.set_visible(False)
        axes[r, 0].set_ylabel(label, fontsize=11)
    plt.tight_layout()
    plt.show()


for _vae_model, _caption in ((vae_baseline, "Baseline VAE Reconstructions"),
                             (vae_mmreg, "MM-Reg VAE Reconstructions")):
    plot_reconstructions(_vae_model, val_images, _caption, device)

In [None]:
# Distance correlation scatter plots: latent vs PCA pairwise distances over the
# first (up to) 200 validation samples, one panel per VAE.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panels = (('Baseline', res_baseline['latents']), ('MM-Reg', res_mmreg['latents']))

for ax, (name, z) in zip(axes, panels):
    n_pts = min(200, len(z))
    latent_dists = get_upper_triangular(pairwise_distances(z[:n_pts])).numpy()
    pca_dists = get_upper_triangular(pairwise_distances(val_pca[:n_pts])).numpy()
    r = pearsonr(latent_dists, pca_dists)[0]

    ax.scatter(pca_dists, latent_dists, alpha=0.05, s=2)
    ax.set_xlabel('PCA distances')
    ax.set_ylabel('Latent distances')
    ax.set_title(f'{name} (r={r:.4f})')
    ax.grid(True, alpha=0.3)

plt.suptitle('Pairwise Distance Correlation (Val Set)', fontsize=14)
plt.tight_layout()
plt.show()

## 9. Encode Datasets to Latent Vectors

In [None]:
@torch.no_grad()
def encode_all(model, images, device, batch_size=64):
    """Return the encoder means for ``images``, concatenated on the CPU.

    The model is put in eval mode and fed ``batch_size`` images at a time.
    """
    model.eval()
    starts = range(0, len(images), batch_size)
    means = [model.encode(images[s:s + batch_size].to(device))[0].cpu() for s in starts]
    return torch.cat(means, dim=0)


# Encoder means for both splits under each VAE (baseline first, then MM-Reg).
train_z_baseline, val_z_baseline = (
    encode_all(vae_baseline, split, device) for split in (train_images, val_images)
)
print(f"Baseline latents - Train: {train_z_baseline.shape}, Val: {val_z_baseline.shape}")

train_z_mmreg, val_z_mmreg = (
    encode_all(vae_mmreg, split, device) for split in (train_images, val_images)
)
print(f"MM-Reg latents   - Train: {train_z_mmreg.shape}, Val: {val_z_mmreg.shape}")

## 10. Train Diffusion on Baseline Latents

In [None]:
print("="*60)
print("TRAINING DIFFUSION ON BASELINE LATENTS")
print("="*60)

diffusion_baseline = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'], device=device
)

mlp_baseline = MLPDenoiser(
    input_dim=CONFIG['latent_dim'],
    hidden_dim=512,
    num_layers=4,
    time_emb_dim=128
).to(device)

print(f"MLP params: {sum(p.numel() for p in mlp_baseline.parameters()):,}")

opt_diff_base = torch.optim.AdamW(mlp_baseline.parameters(), lr=CONFIG['diffusion_lr'])

trainer_diff_baseline = DiffusionTrainer(
    model=mlp_baseline,
    diffusion=diffusion_baseline,
    optimizer=opt_diff_base,
    train_latents=train_z_baseline,
    val_latents=val_z_baseline,
    batch_size=CONFIG['batch_size'],
    device=device,
    use_amp=False,
    save_dir='../checkpoints/photonics_diffusion_baseline'
)

trainer_diff_baseline.train(num_epochs=CONFIG['diffusion_epochs'])

## 11. Train Diffusion on MM-Reg Latents

In [None]:
print("="*60)
print("TRAINING DIFFUSION ON MM-REG LATENTS")
print("="*60)

diffusion_mmreg = GaussianDiffusion(
    num_timesteps=CONFIG['diffusion_timesteps'], device=device
)

mlp_mmreg = MLPDenoiser(
    input_dim=CONFIG['latent_dim'],
    hidden_dim=512,
    num_layers=4,
    time_emb_dim=128
).to(device)

opt_diff_mm = torch.optim.AdamW(mlp_mmreg.parameters(), lr=CONFIG['diffusion_lr'])

trainer_diff_mmreg = DiffusionTrainer(
    model=mlp_mmreg,
    diffusion=diffusion_mmreg,
    optimizer=opt_diff_mm,
    train_latents=train_z_mmreg,
    val_latents=val_z_mmreg,
    batch_size=CONFIG['batch_size'],
    device=device,
    use_amp=False,
    save_dir='../checkpoints/photonics_diffusion_mmreg'
)

trainer_diff_mmreg.train(num_epochs=CONFIG['diffusion_epochs'])

## 12. Generate Samples & Compare

In [None]:
# Draw latent vectors from each trained diffusion model (baseline first).
n_generate = 16

print("Generating from Baseline Diffusion...")
gen_z_baseline = trainer_diff_baseline.generate_samples(num_samples=n_generate)

print("Generating from MM-Reg Diffusion...")
gen_z_mmreg = trainer_diff_mmreg.generate_samples(num_samples=n_generate)

In [None]:
def _show_image_grid(images, title, transform, n_rows=2, n_cols=8):
    """Show up to n_rows * n_cols single-channel images (N, 1, H, W) in one figure.

    ``transform`` is applied to each (H, W) image before display. Grid cells
    beyond the number of images are hidden rather than left as empty frames.
    """
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(2 * n_cols, 2 * n_rows), squeeze=False)
    fig.suptitle(title, fontsize=14)
    for i, ax in enumerate(axes.flat):
        if i < images.shape[0]:
            ax.imshow(transform(images[i, 0]), cmap='gray', vmin=0, vmax=1)
        ax.axis('off')
    plt.tight_layout()
    plt.show()


@torch.no_grad()
def decode_and_plot(vae, z_vectors, title, device):
    """Decode latent vectors with ``vae`` and plot up to 16 of them twice:
    first the raw decoder output clamped to [0, 1], then thresholded at 0.5."""
    vae.eval()
    images = vae.decode(z_vectors.to(device)).cpu()[:16]
    _show_image_grid(images, title, lambda img: img.clamp(0, 1))
    _show_image_grid(images, f"{title} (thresholded)", lambda img: (img > 0.5).float())


for _vae_model, _z_gen, _caption in (
    (vae_baseline, gen_z_baseline, "Generated Samples (Baseline VAE + Diffusion)"),
    (vae_mmreg, gen_z_mmreg, "Generated Samples (MM-Reg VAE + Diffusion)"),
):
    decode_and_plot(_vae_model, _z_gen, _caption, device)

In [None]:
# Side-by-side: Real vs Baseline Generated vs MM-Reg Generated
n_cols = 8

# Decode the first n_cols generated latents of each pipeline; they are shown thresholded.
vae_baseline.eval()
with torch.no_grad():
    baseline_imgs = vae_baseline.decode(gen_z_baseline[:n_cols].to(device)).cpu()
vae_mmreg.eval()
with torch.no_grad():
    mmreg_imgs = vae_mmreg.decode(gen_z_mmreg[:n_cols].to(device)).cpu()

rows = [
    ('Real', val_images[:n_cols]),
    ('Baseline', (baseline_imgs > 0.5).float()),
    ('MM-Reg', (mmreg_imgs > 0.5).float()),
]

fig, axes = plt.subplots(len(rows), n_cols, figsize=(2 * n_cols, 6), squeeze=False)
for r, (label, row_imgs) in enumerate(rows):
    for c in range(n_cols):
        ax = axes[r, c]
        ax.imshow(row_imgs[c, 0], cmap='gray', vmin=0, vmax=1)
        # Hide ticks and frame instead of calling ax.axis('off'), which would
        # also hide the y-label used as the row caption.
        ax.set(xticks=[], yticks=[])
        for spine in ax.spines.values():
            spine.set_visible(False)
    axes[r, 0].set_ylabel(label, fontsize=12)

plt.suptitle('Real vs Generated Photonics Patterns', fontsize=14)
plt.tight_layout()
plt.show()

## 13. Latent Space Interpolation

Interpolate linearly between the latent codes of two validation images and decode each intermediate point.
If the hypothesis holds, the MM-Reg latent space should produce smoother, more realistic transitions than the baseline.

In [None]:
@torch.no_grad()
def plot_interpolation(vae, z_start, z_end, title, device, steps=8):
    """Decode ``steps`` evenly spaced points on the line from ``z_start`` to
    ``z_end`` and plot them, thresholded at 0.5, in a single row."""
    vae.eval()
    alphas = torch.linspace(0, 1, steps)
    z_interp = torch.stack([z_start * (1 - a) + z_end * a for a in alphas]).to(device)
    imgs = vae.decode(z_interp).cpu()

    # squeeze=False keeps `axes` a 2-D array even for steps == 1, where
    # plt.subplots would otherwise return a single, non-iterable Axes.
    fig, axes = plt.subplots(1, steps, figsize=(2 * steps, 2), squeeze=False)
    fig.suptitle(title, fontsize=12)
    for i, ax in enumerate(axes[0]):
        ax.imshow((imgs[i, 0] > 0.5).float(), cmap='gray', vmin=0, vmax=1)
        ax.set_title(f'a={alphas[i].item():.1f}')
        ax.axis('off')
    plt.tight_layout()
    plt.show()


# Interpolate between validation samples 0 and 5 in each latent space.
i_start, i_end = 0, 5
z0_b, z1_b = val_z_baseline[i_start], val_z_baseline[i_end]
z0_m, z1_m = val_z_mmreg[i_start], val_z_mmreg[i_end]

plot_interpolation(vae_baseline, z0_b, z1_b, "Baseline Interpolation", device)
plot_interpolation(vae_mmreg, z0_m, z1_m, "MM-Reg Interpolation", device)

## 14. Summary & Training Curves

In [None]:
def _plot_curve_panel(ax, histories, key, ylabel, title):
    """Plot ``key`` from four per-epoch histories on ``ax``: baseline train/val
    (blue, solid/dashed), then MM-Reg train/val (red, solid/dashed)."""
    styles = (('b-', 'Baseline Train'), ('b--', 'Baseline Val'),
              ('r-', 'MM-Reg Train'), ('r--', 'MM-Reg Val'))
    for hist, (fmt, label) in zip(histories, styles):
        ax.plot([h[key] for h in hist], fmt, alpha=0.7, label=label)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(1, 3, figsize=(18, 5))
vae_histories = (hist_baseline['train'], hist_baseline['val'],
                 hist_mmreg['train'], hist_mmreg['val'])

# VAE total loss and reconstruction loss
_plot_curve_panel(axes[0], vae_histories, 'loss', 'Loss', 'VAE Total Loss')
_plot_curve_panel(axes[1], vae_histories, 'recon', 'Recon MSE', 'VAE Reconstruction Loss')

# Diffusion loss
diff_hist_base = trainer_diff_baseline.train_history
diff_hist_mm = trainer_diff_mmreg.train_history
diff_val_base = trainer_diff_baseline.val_history
diff_val_mm = trainer_diff_mmreg.val_history
_plot_curve_panel(axes[2], (diff_hist_base, diff_val_base, diff_hist_mm, diff_val_mm),
                  'loss', 'Loss', 'Diffusion Training Loss')

plt.tight_layout()
plt.savefig('../checkpoints/photonics_training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
def to_python(v):
    """Return ``v`` as a plain Python number for JSON serialization.

    Objects that have an ``.item()`` method (NumPy scalars, 0-d tensors) are
    unwrapped with it; anything else goes through ``float()``.
    """
    if hasattr(v, 'item'):
        return v.item()
    return float(v)

def _vae_summary(res):
    """JSON-friendly subset of an evaluate_vae() result (drops the latents)."""
    return {
        'recon_mse': to_python(res['mse']),
        'pixel_acc': to_python(res['pixel_acc']),
        'pearson': to_python(res['pearson']),
        'spearman': to_python(res['spearman']),
    }


# Final-epoch diffusion losses, converted with to_python so json.dump also works
# if the trainer records tensors or NumPy scalars.
diffusion_final_loss = {
    'baseline_train': to_python(diff_hist_base[-1]['loss']),
    'baseline_val': to_python(diff_val_base[-1]['loss']),
    'mmreg_train': to_python(diff_hist_mm[-1]['loss']),
    'mmreg_val': to_python(diff_val_mm[-1]['loss']),
}

base_val = diffusion_final_loss['baseline_val']
mm_val = diffusion_final_loss['mmreg_val']
# Relative reduction of the final diffusion val loss; undefined for a zero baseline.
# NOTE(review): the two losses are measured on different (unstandardized) latent
# spaces, so this ratio is only a rough indicator.
improvement = (base_val - mm_val) / base_val * 100 if base_val != 0 else None

summary = {
    'config': CONFIG,
    'dataset': {'total': len(data_tensor), 'train': len(train_images), 'val': len(val_images)},
    'vae_results': {
        'baseline': _vae_summary(res_baseline),
        'mmreg': _vae_summary(res_mmreg),
    },
    'diffusion_final_loss': diffusion_final_loss,
    'diffusion_val_improvement_pct': improvement,
}

print("="*60)
print("PHOTONICS EXPERIMENT SUMMARY")
print("="*60)
print(f"\nDataset: {summary['dataset']}")

print("\nVAE Comparison (val set):")
print(f"  Baseline - MSE: {res_baseline['mse']:.6f}, Acc: {res_baseline['pixel_acc']:.4f}, "
      f"Pearson: {res_baseline['pearson']:.4f}")
print(f"  MM-Reg   - MSE: {res_mmreg['mse']:.6f}, Acc: {res_mmreg['pixel_acc']:.4f}, "
      f"Pearson: {res_mmreg['pearson']:.4f}")

print("\nDiffusion Final Val Loss:")
print(f"  Baseline: {base_val:.6f}")
print(f"  MM-Reg:   {mm_val:.6f}")

if improvement is None:
    print("\nDiffusion improvement: undefined (baseline val loss is 0)")
else:
    print(f"\nDiffusion improvement: {improvement:+.1f}%")

os.makedirs('../checkpoints', exist_ok=True)
with open('../checkpoints/photonics_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)

print("\nSaved to ../checkpoints/photonics_summary.json")