# Experiment Runner — Automated Batch
Runs the zero-filled baseline and the real- and complex-valued U-Net experiments sequentially, so that unattended overnight jobs log everything needed for the paper.


## Workflow
1. Configure experiments in the `EXPERIMENTS` list (the third code cell, after the helper definitions).
2. Execute the final cell to iterate through each run.
3. Inspect `results/` in the morning for metrics, checkpoints, and figures.


In [5]:
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from pathlib import Path
import csv
import json
from datetime import datetime

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import yaml
from torch.utils.data import DataLoader, Subset

# Resolve the repository root whether the notebook is launched from the repo
# root itself or from its `notebooks/` subdirectory.
PROJECT_ROOT = Path.cwd().resolve()
if PROJECT_ROOT.name == 'notebooks':
    PROJECT_ROOT = PROJECT_ROOT.parent
SRC_ROOT = PROJECT_ROOT / 'src'
RESULTS_ROOT = PROJECT_ROOT / 'results'
# Every run later writes a timestamped subdirectory under results/.
RESULTS_ROOT.mkdir(exist_ok=True)

import sys
# Put src/ at the FRONT of the import path. sys.path is searched in order and
# the project packages have generic names (`data`, `models`, `training`); when
# appended at the end, an installed package with the same name would win.
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))

from data.dataset import SingleCoilDataset
from data.masking import EquispacedMasker
from models.real_unet import RealUnet
from models.cx_unet import ComplexUnet
from training.utils import train_loop, test_loop


In [6]:

def init_run(run_tag: str, config: dict) -> Path:
    """Create a fresh, timestamped run directory and snapshot its config.

    The directory ``RESULTS_ROOT/<YYYYmmdd-HHMMSS>_<run_tag>`` is created with
    ``checkpoints/``, ``qualitative/`` and ``tensors/`` subfolders, and
    ``config`` is written into it both as ``config.yaml`` and ``config.json``.

    Args:
        run_tag: Human-readable suffix for the directory name.
        config: Experiment configuration; must be YAML- and JSON-serialisable.

    Returns:
        Path of the run directory.
    """
    stamp = datetime.now().strftime('%Y%m%d-%H%M%S')
    out_dir = RESULTS_ROOT / f"{stamp}_{run_tag}"
    # parents=True also creates out_dir itself, so the sibling folders below
    # only need their direct parent to exist.
    (out_dir / 'checkpoints').mkdir(parents=True, exist_ok=True)
    for sub_name in ('qualitative', 'tensors'):
        (out_dir / sub_name).mkdir(exist_ok=True)
    with open(out_dir / 'config.yaml', 'w') as yaml_file:
        yaml.safe_dump(config, yaml_file)
    with open(out_dir / 'config.json', 'w') as json_file:
        json.dump(config, json_file, indent=2)
    return out_dir


def to_mag(x: torch.Tensor) -> torch.Tensor:
    """Return the element-wise magnitude of ``x`` (the modulus for complex input)."""
    return torch.abs(x)


def psnr_db(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Peak signal-to-noise ratio of ``x`` vs ``y`` in dB, assuming a peak value of 1.0.

    The MSE is pooled over every element of the inputs (including any batch
    dimension), so both should already be normalised to a maximum of 1.
    ``eps`` keeps the logarithm finite when ``x`` and ``y`` are identical.
    """
    squared_error = (x - y).pow(2)
    mse = squared_error.mean()
    return 10.0 * torch.log10(1.0 / (mse + eps))


def ssim_simple(x: torch.Tensor, y: torch.Tensor, C1: float = 0.01**2, C2: float = 0.03**2) -> torch.Tensor:
    """Single-window (global) SSIM between ``x`` and ``y``.

    Means, variances and the covariance are computed once over *all*
    elements of the inputs (including any batch dimension) instead of over
    local sliding windows as in the standard windowed SSIM, so values are not
    directly comparable to windowed-SSIM numbers. ``C1``/``C2`` are the usual
    stabilising constants for a dynamic range of 1.
    """
    x_mean = x.mean()
    y_mean = y.mean()
    x_dev = x - x_mean
    y_dev = y - y_mean
    # Population (biased) second moments.
    var_x = (x_dev ** 2).mean()
    var_y = (y_dev ** 2).mean()
    cov_xy = (x_dev * y_dev).mean()
    numerator = (2 * x_mean * y_mean + C1) * (2 * cov_xy + C2)
    denominator = (x_mean**2 + y_mean**2 + C1) * (var_x + var_y + C2)
    return numerator / (denominator + 1e-8)


@torch.no_grad()
def evaluate_metrics(model, dataloader, device, num_batches: int | None = 2):
    """Average per-image PSNR (dB) and global SSIM of ``model`` on ``dataloader``.

    Every image in a batch is scored on its own: prediction and ground-truth
    magnitudes are each normalised to a peak of 1 *per image*, PSNR/SSIM are
    computed per image, and the scores are averaged over all evaluated
    images. This mirrors the per-slice protocol of the zero-filled baseline
    and makes the result independent of batch size and batch composition.

    Args:
        model: Network mapping a batch of masked inputs to reconstructions.
        dataloader: Yields ``(masked, target)`` batches, batch dimension first.
        device: Device the model lives on.
        num_batches: Stop after this many batches (at least one batch is
            always processed); ``None`` evaluates the whole loader.

    Returns:
        ``(mean_psnr, mean_ssim)`` as Python floats; ``(0.0, 0.0)`` when the
        loader yields nothing.
    """
    model.eval()
    psnrs: list[float] = []
    ssim_scores: list[float] = []
    for batch_idx, (masked, target) in enumerate(dataloader):
        recon = model(masked.to(device))
        gt_batch = to_mag(target.to(device))
        pred_batch = to_mag(recon)
        # Normalising by the batch-wide maximum (and pooling the MSE over the
        # batch) would make each image's score depend on which other slices
        # happen to share its batch, so iterate over the batch dimension.
        for gt, pred in zip(gt_batch, pred_batch):
            gt = gt / gt.max().clamp_min(1e-8)
            pred = pred / pred.max().clamp_min(1e-8)
            psnrs.append(psnr_db(pred, gt).item())
            ssim_scores.append(ssim_simple(pred, gt).item())
        if num_batches is not None and batch_idx + 1 >= num_batches:
            break
    if not psnrs:
        return 0.0, 0.0
    n_images = len(psnrs)
    return float(sum(psnrs) / n_images), float(sum(ssim_scores) / n_images)


@torch.no_grad()
def save_qualitative(run_dir: Path, model, dataset, device, epoch: int, sample_idx: int | None = None):
    """Save a zero-filled / model / ground-truth comparison figure for one sample.

    The figure is written to ``run_dir/qualitative/epochNNN.png``; each panel
    shows a magnitude image normalised to a peak of 1. Nothing is written
    when ``dataset`` is empty.

    Args:
        run_dir: Run directory containing a ``qualitative/`` subfolder.
        model: Network mapping a batch of masked inputs to reconstructions.
        dataset: Indexable dataset returning ``(masked, target)`` pairs.
        device: Device the model lives on.
        epoch: Epoch number, used in the title and the file name.
        sample_idx: Index of the sample to plot, wrapped modulo
            ``len(dataset)``. ``None`` picks ``epoch * 11`` (also wrapped), so
            a different sample is shown each epoch.
    """
    model.eval()
    n_items = len(dataset)
    if n_items == 0:
        return
    requested = sample_idx if sample_idx is not None else epoch * 11
    # Wrap into range so a fixed qualitative_index (e.g. 50) does not raise
    # IndexError when the validation subset has fewer slices than that.
    idx = requested % n_items
    masked, target = dataset[idx]
    recon = model(masked.unsqueeze(0).to(device))
    images = [to_mag(masked).cpu(), to_mag(recon[0]).cpu(), to_mag(target).cpu()]
    images = [img / img.max().clamp_min(1e-8) for img in images]
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    try:
        for ax, img, title in zip(axes, images, ['Zero-filled', 'Model', 'Ground Truth']):
            ax.imshow(img.squeeze().numpy(), cmap='gray')
            ax.set_title(title)
            ax.axis('off')
        fig.suptitle(f'Epoch {epoch}')
        fig.tight_layout()
        fig.savefig(run_dir / 'qualitative' / f'epoch{epoch:03d}.png', dpi=200)
    finally:
        # Always release the figure; otherwise a failing save would leak one
        # figure per epoch over a long run.
        plt.close(fig)


def select_indices(n_total: int, n_pick: int) -> np.ndarray:
    n = min(n_total, n_pick)
    if n_total == 0:
        return np.array([], dtype=int)
    return np.linspace(0, n_total - 1, num=n, dtype=int)


def _hide_axes_frame(ax) -> None:
    """Remove ticks and spines from ``ax`` while keeping its title and axis labels visible."""
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


def save_zero_fill_grid(cfg: dict, run_dir: Path):
    """Plot ground truth next to zero-filled images for every mask setting.

    The figure has one row per acceleration in ``cfg['mask_grid']['accels']``.
    The first column shows the ground-truth magnitude, labelled ``R=<accel>``.
    The remaining columns show the zero-filled magnitude for each value in
    ``cfg['mask_grid']['acs']``.

    Every image is normalised to a peak of 1. The slice shown is
    ``cfg['grid_index']`` (default 0), clamped to the last slice of the
    dataset. The figure is saved to ``run_dir/zero_fill_grid.png``. Nothing is
    written when either grid list is empty.
    """
    accels = cfg['mask_grid']['accels']
    acs_list = cfg['mask_grid']['acs']
    if not accels or not acs_list:
        # plt.subplots rejects a zero-row grid, and there is nothing to compare.
        return
    idx = cfg.get('grid_index', 0)
    n_rows = len(accels)
    n_cols = len(acs_list) + 1
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
    try:
        for r_i, accel in enumerate(accels):
            for a_i, acs in enumerate(acs_list):
                ds = SingleCoilDataset(cfg['val_folder'], mask_func=EquispacedMasker(accel=accel, acs=acs))
                masked, target = ds[min(idx, len(ds) - 1)]
                if a_i == 0:
                    # The ground truth does not depend on the mask, so draw it once per row.
                    gt = to_mag(target).cpu()
                    gt = gt / gt.max().clamp_min(1e-8)
                    ax_gt = axs[r_i, 0]
                    ax_gt.imshow(gt.squeeze().numpy(), cmap='gray')
                    ax_gt.set_title('Ground Truth', fontsize=8)
                    # ax.axis('off') would also hide the y-label, so the R=...
                    # row label never appeared; strip only ticks and spines.
                    _hide_axes_frame(ax_gt)
                    ax_gt.set_ylabel(f'R={accel}', fontsize=8)
                zf = to_mag(masked).cpu()
                zf = zf / zf.max().clamp_min(1e-8)
                ax = axs[r_i, a_i + 1]
                ax.imshow(zf.squeeze().numpy(), cmap='gray')
                ax.set_title(f'ACS={acs}', fontsize=8)
                ax.axis('off')
        fig.tight_layout()
        fig.savefig(run_dir / 'zero_fill_grid.png', dpi=200)
    finally:
        plt.close(fig)


def _mean_std(values: list[float], prefix: str) -> dict:
    """Return ``{'<prefix>_mean': ..., '<prefix>_std': ...}`` using the population std.

    Empty input yields NaN instead of 0.0 so a missing measurement cannot be
    mistaken for a real score (e.g. a 0 dB PSNR) in the results table.
    """
    if not values:
        nan = float('nan')
        return {f'{prefix}_mean': nan, f'{prefix}_std': nan}
    return {f'{prefix}_mean': float(np.mean(values)), f'{prefix}_std': float(np.std(values))}


def run_zero_fill_experiment(cfg: dict):
    """Measure zero-filled reconstruction quality over a grid of undersampling masks.

    Slices are taken from ``cfg['val_folder']``. For every pair of
    acceleration and ``acs`` value in ``cfg['mask_grid']``, up to
    ``cfg['val_subset']`` slices (default 64), spread evenly over the
    dataset, are scored against the ground truth.

    Both magnitude images are normalised to a peak of 1 per slice. The
    metrics are PSNR (dB), the global SSIM of ``ssim_simple``, and the mean
    absolute (L1) error.

    Writes ``zero_fill_metrics.csv`` into a new run directory: one row per
    mask setting, with the mean and population std of each metric (NaN when
    no slice was scored). Also writes ``zero_fill_grid.png``, displays the
    table, and returns it as a DataFrame.
    """
    run_dir = init_run(cfg['run_tag'], cfg)
    slice_count = cfg.get('val_subset', 64)
    rows = []
    for accel in cfg['mask_grid']['accels']:
        for acs in cfg['mask_grid']['acs']:
            ds = SingleCoilDataset(cfg['val_folder'], mask_func=EquispacedMasker(accel=accel, acs=acs))
            idxs = select_indices(len(ds), slice_count)
            psnrs, ssims, l1s = [], [], []
            for idx in idxs:
                masked, target = ds[int(idx)]
                gt = to_mag(target).float()
                zf = to_mag(masked).float()
                gt = gt / gt.max().clamp_min(1e-8)
                zf = zf / zf.max().clamp_min(1e-8)
                psnrs.append(psnr_db(zf, gt).item())
                ssims.append(ssim_simple(zf, gt).item())
                l1s.append(torch.mean(torch.abs(zf - gt)).item())
            rows.append({
                'accel': accel,
                'acs': acs,
                'N': len(idxs),
                **_mean_std(psnrs, 'PSNR'),
                **_mean_std(ssims, 'SSIM'),
                **_mean_std(l1s, 'L1'),
            })
    df = pd.DataFrame(rows)
    df.to_csv(run_dir / 'zero_fill_metrics.csv', index=False)
    display(df)  # `display` is provided by the IPython/Jupyter runtime, not imported here.
    save_zero_fill_grid(cfg, run_dir)
    return df


def build_subset(dataset, subset_size):
    """Return the first ``subset_size`` items of ``dataset``, or the dataset itself.

    ``None`` returns ``dataset`` unchanged. Otherwise a ``Subset`` over the
    indices ``0 .. min(len(dataset), subset_size) - 1`` is returned. This is a
    prefix in the dataset's own order, not a random sample. If the dataset is
    ordered by volume (not visible here — confirm), the prefix covers only the
    first few volumes.
    """
    if subset_size is not None:
        keep = min(len(dataset), subset_size)
        return Subset(dataset, list(range(keep)))
    return dataset


# Column order of metrics.csv (one row per epoch).
_METRIC_FIELDS = ['epoch', 'train_loss', 'val_loss', 'psnr', 'ssim']


def _select_device() -> torch.device:
    """Return Apple MPS if available, else CUDA if available, else CPU."""
    if torch.backends.mps.is_available():
        return torch.device('mps')
    if torch.cuda.is_available():
        return torch.device('cuda')
    return torch.device('cpu')


def _build_model(cfg: dict, device: torch.device) -> torch.nn.Module:
    """Instantiate the U-Net selected by ``cfg['model']`` (``'complex'``, else real) on ``device``."""
    features = cfg.get('features', [32, 64, 128, 256, 512])
    if cfg.get('model', 'real') == 'complex':
        net = ComplexUnet(in_channels=1, out_channels=1, features=features)
    else:
        net = RealUnet(in_channels=1, out_channels=1, features=features, width_scale=cfg.get('width_scale', 1.0))
    return net.to(device)


def run_model_experiment(cfg: dict):
    """Train one U-Net configuration and log everything to a new run directory.

    Required ``cfg`` keys: ``run_tag``, ``train_folder``, ``val_folder``,
    ``batch_size``, ``num_workers``, ``learning_rate``, ``epochs``.

    Optional keys (defaults in parentheses):
    - ``seed`` (0)
    - ``mask`` ({'accel': 4, 'acs': 24})
    - ``train_subset`` / ``val_subset`` (None = full set)
    - ``model`` ('real')
    - ``features``
    - ``width_scale`` (1.0, real U-Net only)
    - ``qualitative_index`` (0)

    After every epoch it:
    - appends a row to ``metrics.csv``;
    - rewrites ``metrics_summary.json`` (best validation loss and epoch);
    - saves ``checkpoints/latest.pt`` with model and optimizer state, so an
      interrupted run can be resumed;
    - updates ``checkpoints/best.pt`` when the validation L1 loss improves;
    - saves a qualitative comparison figure.
    """
    seed = cfg.get('seed', 0)
    torch.manual_seed(seed)
    # Also seed NumPy's global RNG in case the masker/dataset draws from it
    # (their code is not visible here); harmless if they do not.
    np.random.seed(seed)
    device = _select_device()
    run_dir = init_run(cfg['run_tag'], cfg)

    mask_cfg = cfg.get('mask', {'accel': 4, 'acs': 24})
    masker = EquispacedMasker(accel=mask_cfg['accel'], acs=mask_cfg['acs'])
    train_set = build_subset(SingleCoilDataset(cfg['train_folder'], mask_func=masker), cfg.get('train_subset'))
    val_set = build_subset(SingleCoilDataset(cfg['val_folder'], mask_func=masker), cfg.get('val_subset'))
    train_loader = DataLoader(train_set, batch_size=cfg['batch_size'], shuffle=True, num_workers=cfg['num_workers'])
    val_loader = DataLoader(val_set, batch_size=cfg['batch_size'], shuffle=False, num_workers=cfg['num_workers'])

    model = _build_model(cfg, device)
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg['learning_rate'])

    def loss_fn(pred, target):
        # Mean absolute error; for complex tensors .abs() is the modulus.
        return (pred - target).abs().mean()

    metrics_path = run_dir / 'metrics.csv'
    with metrics_path.open('w', newline='') as f:
        csv.DictWriter(f, fieldnames=_METRIC_FIELDS).writeheader()
    summary_path = run_dir / 'metrics_summary.json'
    ckpt_dir = run_dir / 'checkpoints'
    qual_idx = cfg.get('qualitative_index', 0)
    best_val = float('inf')
    best_epoch = None

    for epoch in range(1, cfg['epochs'] + 1):
        # evaluate_metrics / save_qualitative call model.eval() at the end of
        # every epoch; switch back explicitly in case train_loop does not (its
        # body is not visible here), so dropout/batch-norm train correctly.
        model.train()
        train_loss = train_loop(model, train_loader, optimizer, loss_fn, device)
        val_loss = test_loop(model, val_loader, loss_fn, device)
        psnr, ssim = evaluate_metrics(model, val_loader, device)
        # Re-open in append mode each epoch so finished epochs survive a crash.
        with metrics_path.open('a', newline='') as f:
            csv.DictWriter(f, fieldnames=_METRIC_FIELDS).writerow(
                {'epoch': epoch, 'train_loss': train_loss, 'val_loss': val_loss, 'psnr': psnr, 'ssim': ssim}
            )
        improved = val_loss < best_val
        if improved:
            best_val = val_loss
            best_epoch = epoch
        summary = {
            'best_val_loss': float(best_val),
            'best_epoch': best_epoch,
            'last_epoch': epoch,
            'device': str(device),
        }
        summary_path.write_text(json.dumps(summary, indent=2))
        torch.save(
            {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch},
            ckpt_dir / 'latest.pt',
        )
        if improved:
            torch.save({'model': model.state_dict(), 'epoch': epoch}, ckpt_dir / 'best.pt')
        save_qualitative(run_dir, model, val_set, device, epoch, sample_idx=qual_idx)
        print(f"[{cfg['run_tag']}] Epoch {epoch}/{cfg['epochs']} | train {train_loss:.4f} | val {val_loss:.4f} | psnr {psnr:.2f} | ssim {ssim:.3f}")


def run_experiment(cfg: dict):
    exp_type = cfg['type']
    print(f"=== Running {cfg['run_tag']} ({exp_type}) ===")
    if exp_type == 'zero_fill':
        run_zero_fill_experiment(cfg)
    elif exp_type == 'model':
        run_model_experiment(cfg)
    else:
        raise ValueError(f'Unknown experiment type: {exp_type}')


In [12]:

# Experiments executed in list order by the final cell. Each dict is passed to
# run_experiment(): 'type' selects the runner ('zero_fill' or 'model') and
# 'run_tag' names the results/<timestamp>_<run_tag>/ directory.
# Commented-out entries are kept for reference; uncomment to include them.
EXPERIMENTS = [
    # {
    #     'run_tag': 'zerofill_sweep',
    #     'type': 'zero_fill',
    #     'val_folder': str(PROJECT_ROOT / 'data' / 'singlecoil_val'),
    #     'val_subset': 64,
    #     'grid_index': 50,
    #     'mask_grid': {'accels': [2, 4, 6, 8], 'acs': [4, 12, 20]},
    # },

    # {
    #     'run_tag': 'realunet_R6_seed0',
    #     'type': 'model',
    #     'model': 'real',
    #     'train_folder': str(PROJECT_ROOT / 'data' / 'singlecoil_train'),
    #     'val_folder': str(PROJECT_ROOT / 'data' / 'singlecoil_val'),
    #     'mask': {'accel': 6, 'acs': 20},
    #     'train_subset': 16384,
    #     'val_subset': 1024,
    #     'batch_size': 4, 
    #     'num_workers': 4,
    #     'epochs': 1,  # train_subset / batch_size = 16384 / 4 = 4096 steps per epoch
    #     'learning_rate': 1e-3,
    #     'width_scale': 1.42,  # presumably matches the complex U-Net's parameter count — confirm
    #     'features': [16, 32, 64, 128, 256],
    #     'seed': 0,
    #     'qualitative_index': 50,
    # },
    {
        'run_tag': 'complexunet_R6_seed0',
        'type': 'model',
        'model': 'complex',
        'train_folder': str(PROJECT_ROOT / 'data' / 'singlecoil_train'),
        'val_folder': str(PROJECT_ROOT / 'data' / 'singlecoil_val'),
        'mask': {'accel': 6, 'acs': 20},
        'train_subset': 8192,
        'val_subset': 1024,
        'batch_size': 2,
        'num_workers': 4,
        'epochs': 1,  # train_subset / batch_size = 8192 / 2 = 4096 steps per epoch
        'learning_rate': 1e-3,
        'features': [16, 32, 64, 128, 256],
        'seed': 0,
        'qualitative_index': 50,
    },
]
# Bare expression so the notebook echoes the active configuration.
EXPERIMENTS


[{'run_tag': 'complexunet_R6_seed0',
  'type': 'model',
  'model': 'complex',
  'train_folder': '/home/gdegeron/Desktop/ece570-tinyreproductions/data/singlecoil_train',
  'val_folder': '/home/gdegeron/Desktop/ece570-tinyreproductions/data/singlecoil_val',
  'mask': {'accel': 6, 'acs': 20},
  'train_subset': 8192,
  'val_subset': 1024,
  'batch_size': 2,
  'num_workers': 4,
  'epochs': 1,
  'learning_rate': 0.001,
  'features': [16, 32, 64, 128, 256],
  'seed': 0,
  'qualitative_index': 50}]

In [None]:
# Run every configured experiment in order. A crash in one run is reported
# (full traceback) and the batch moves on, so a single failure does not waste
# the rest of an overnight job. Failures are raised together at the end so
# they cannot go unnoticed. KeyboardInterrupt is not an Exception subclass, so
# interrupting the kernel still stops the whole batch immediately.
import traceback

failed_runs = []
for cfg in EXPERIMENTS:
    try:
        run_experiment(cfg)
    except Exception:
        traceback.print_exc()
        failed_runs.append(cfg.get('run_tag', '<no run_tag>'))
if failed_runs:
    raise RuntimeError(f'{len(failed_runs)} experiment(s) failed: {failed_runs}')

=== Running complexunet_R6_seed0 (model) ===


train:  59%|█████▉    | 2437/4096 [12:56<08:48,  3.14it/s]