# Figure/Table Generator
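Configure the run directories and slice indices in the config cell below. The notebook then regenerates the paper's results tables, training curves, and qualitative panels from saved metrics and checkpoints (no training is run here).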


In [31]:
from pathlib import Path
import json
import pandas as pd
import matplotlib.pyplot as plt
import torch
import numpy as np
import sys

# Resolve the repository root whether the kernel was started in the repo root or in notebooks/.
PROJECT_ROOT = Path.cwd().resolve()
if PROJECT_ROOT.name == 'notebooks':
    PROJECT_ROOT = PROJECT_ROOT.parent
SRC_ROOT = PROJECT_ROOT / 'src'
# Prepend rather than append: the project's packages have generic names (`data`, `models`),
# so src/ must take precedence over any same-named package that is already importable.
if str(SRC_ROOT) not in sys.path:
    sys.path.insert(0, str(SRC_ROOT))

# ==== User config ====
FIG_DIR = PROJECT_ROOT / 'report' / 'paper' / 'figures'  # every table/figure below is written here
FIG_DIR.mkdir(parents=True, exist_ok=True)

# Run directories. When a name is left at its default, the code below swaps in the newest
# timestamped '<timestamp>_<name>' directory under results/ if one exists.
ZERO_FILL_RUN = PROJECT_ROOT / 'results' / 'zerofill_sweep'
REAL_RUN = PROJECT_ROOT / 'results' / 'realunet_R6_seed0'
COMPLEX_RUN = PROJECT_ROOT / 'results' / 'complexunet_R6_seed0'

# Quick smoke-test runs (uncomment to use these instead of the full runs above):
# ZERO_FILL_RUN = PROJECT_ROOT / 'results' / 'zerofill_quick'
# REAL_RUN = PROJECT_ROOT / 'results' / 'realunet_quick'
# COMPLEX_RUN = PROJECT_ROOT / 'results' / 'complexunet_quick'

CKPT_CHOICE = 'best'  # 'best' -> checkpoints/best.pt; any other value -> checkpoints/latest.pt
QUAL_INDICES = [50, 250, 630]    # validation-slice indices rendered in the qualitative panel (one row each)
SHOW_MODELS = ['zf','real','complex','gt']  # column order of the qualitative panel
ZF_ACCEL = 6   # acceleration passed to EquispacedMasker for the baseline mask / fallback mask
ZF_ACS = 20    # ACS setting passed to EquispacedMasker for the baseline mask / fallback mask
ZF_GRID_INDICES = [50]  # validation-slice indices for the zero-fill mask-sweep grid (one figure each)
MAX_VAL_SLICES = None  # optional cap on validation slices loaded by load_dataset (None = all)

def latest_timestamped(prefix: str, root=None):
    """Return the newest run directory named ``<anything>_<prefix>``, or None if there is none.

    Despite the parameter name, ``prefix`` is matched as a *suffix*: the glob pattern is
    ``*_{prefix}``, so ``'zerofill_sweep'`` matches ``<timestamp>_zerofill_sweep``.
    "Newest" means the last match in lexicographic order, which is chronological only when
    the leading part is a sortable timestamp. Plain files with a matching name are ignored,
    so a stray file can never be selected as a run directory.

    Args:
        prefix: run-name suffix to search for.
        root: directory to search; defaults to ``PROJECT_ROOT / 'results'``.
    """
    base = PROJECT_ROOT / 'results' if root is None else Path(root)
    candidates = sorted(p for p in base.glob(f'*_{prefix}') if p.is_dir())
    return candidates[-1] if candidates else None

def _prefer_latest(run_dir: Path, default_name: str) -> Path:
    """Swap a default-named run directory for its newest timestamped copy, if any.

    A run_dir whose name was edited away from ``default_name`` is returned unchanged, as is
    run_dir itself when latest_timestamped finds no matching directory.
    """
    if run_dir.name != default_name:
        return run_dir
    newest = latest_timestamped(default_name)
    return newest if newest else run_dir


ZERO_FILL_RUN = _prefer_latest(ZERO_FILL_RUN, 'zerofill_sweep')
REAL_RUN = _prefer_latest(REAL_RUN, 'realunet_R6_seed0')
COMPLEX_RUN = _prefer_latest(COMPLEX_RUN, 'complexunet_R6_seed0')

print('Using runs:')
for run_label, run_path in [('ZERO_FILL_RUN', ZERO_FILL_RUN), ('REAL_RUN', REAL_RUN), ('COMPLEX_RUN', COMPLEX_RUN)]:
    print(run_label, run_path)


Using runs:
ZERO_FILL_RUN /home/gdegeron/Desktop/ece570-tinyreproductions/results/zerofill_sweep
REAL_RUN /home/gdegeron/Desktop/ece570-tinyreproductions/results/realunet_R6_seed0
COMPLEX_RUN /home/gdegeron/Desktop/ece570-tinyreproductions/results/complexunet_R6_seed0


## Helpers


In [32]:
from data.dataset import SingleCoilDataset
from data.masking import EquispacedMasker
from models.real_unet import RealUnet
from models.cx_unet import ComplexUnet
import matplotlib.image as mpimg

def _read_required_csv(path: Path) -> pd.DataFrame:
    """Read ``path`` as CSV, raising FileNotFoundError(path) when it does not exist."""
    if not path.exists():
        raise FileNotFoundError(path)
    return pd.read_csv(path)


def _load_run_metrics(run_dir: Path, filename: str) -> pd.DataFrame:
    """Read ``<run_dir>/metrics/<filename>`` and tag every row with the run directory."""
    run_dir = Path(run_dir)
    df = _read_required_csv(run_dir / 'metrics' / filename)
    # Lets frames from several runs be concatenated and still told apart.
    df['run_dir'] = str(run_dir)
    return df


def load_epoch_metrics(run_dir: Path):
    """Load ``metrics/epoch_metrics.csv`` of a run, adding a ``run_dir`` column.

    Raises:
        FileNotFoundError: if the CSV is missing.
    """
    return _load_run_metrics(run_dir, 'epoch_metrics.csv')


def load_step_metrics(run_dir: Path):
    """Load ``metrics/step_metrics.csv`` of a run, adding a ``run_dir`` column.

    Raises:
        FileNotFoundError: if the CSV is missing.
    """
    return _load_run_metrics(run_dir, 'step_metrics.csv')


def load_zero_fill(run_dir: Path):
    """Load ``zero_fill_metrics.csv`` from the top level of a zero-fill sweep run (no extra column).

    Raises:
        FileNotFoundError: if the CSV is missing.
    """
    return _read_required_csv(Path(run_dir) / 'zero_fill_metrics.csv')

def best_row(df: pd.DataFrame):
    """Return the row (a Series) with the smallest ``val_loss``; ties resolve to the first occurrence."""
    best_label = df['val_loss'].idxmin()
    return df.loc[best_label]

def load_config(run_dir: Path):
    """Load a run's saved configuration dict.

    ``config.json`` is preferred; ``config.yaml`` is used only when the JSON file is absent.
    Both are parsed with :func:`json.load`, so the YAML fallback works only when that file
    is JSON-formatted (JSON is valid YAML, but block-style YAML is not valid JSON).

    Raises:
        FileNotFoundError: if neither config file exists (the message names both).
        json.JSONDecodeError: if the selected file is not valid JSON.
    """
    run_dir = Path(run_dir)
    candidates = (run_dir / 'config.json', run_dir / 'config.yaml')
    cfg_path = next((p for p in candidates if p.exists()), None)
    if cfg_path is None:
        raise FileNotFoundError(f"No config.json or config.yaml found in {run_dir}")
    with open(cfg_path, 'r', encoding='utf-8') as f:
        return json.load(f)

def make_masker(cfg: dict):
    """Build the EquispacedMasker described by ``cfg['mask']``.

    Any missing piece (the whole ``mask`` section, or its ``accel`` / ``acs`` keys) falls
    back to the notebook-level ZF_ACCEL / ZF_ACS constants.
    """
    mask_section = cfg['mask'] if 'mask' in cfg else {'accel': ZF_ACCEL, 'acs': ZF_ACS}
    accel = mask_section.get('accel', ZF_ACCEL)
    acs = mask_section.get('acs', ZF_ACS)
    return EquispacedMasker(accel=accel, acs=acs)

def load_dataset(cfg: dict, split: str):
    """Build the SingleCoilDataset for ``split`` using the run's mask settings.

    ``split == 'train'`` reads ``cfg['train_folder']``; any other value reads
    ``cfg['val_folder']``. For ``'val'`` only, the dataset is truncated to its first
    MAX_VAL_SLICES items when that cap is set.
    """
    masker = make_masker(cfg)
    folder = cfg['train_folder'] if split == 'train' else cfg['val_folder']
    dataset = SingleCoilDataset(folder, mask_func=masker)
    cap = MAX_VAL_SLICES
    if split != 'val' or cap is None:
        return dataset
    from torch.utils.data import Subset
    return Subset(dataset, list(range(min(len(dataset), cap))))

def load_model(run_dir: Path, cfg: dict, model_type: str):
    """Rebuild a trained U-Net from ``run_dir`` and return it in eval mode.

    ``model_type == 'complex'`` builds a ComplexUnet; any other value builds a RealUnet,
    which additionally receives ``cfg['width_scale']`` (default 1.0). Both use
    ``cfg['features']`` (default [16, 32, 64, 128, 256]). Weights are read from the
    ``'model'`` entry of checkpoints/best.pt when CKPT_CHOICE == 'best', otherwise
    checkpoints/latest.pt, with ``map_location='cpu'``.
    """
    features = cfg.get('features', [16, 32, 64, 128, 256])
    if model_type != 'complex':
        model = RealUnet(in_channels=1, out_channels=1, features=features,
                         width_scale=cfg.get('width_scale', 1.0))
    else:
        model = ComplexUnet(in_channels=1, out_channels=1, features=features)
    which = 'best' if CKPT_CHOICE == 'best' else 'latest'
    # NOTE(review): torch.load unpickles the file; only point this at trusted checkpoints.
    checkpoint = torch.load(run_dir / 'checkpoints' / f'{which}.pt', map_location='cpu')
    model.load_state_dict(checkpoint['model'])
    model.eval()
    return model

def to_mag(x: torch.Tensor):
    """Element-wise magnitude |x|: the modulus for complex tensors, the absolute value for real ones."""
    return torch.abs(x)

def psnr_db(x: torch.Tensor, y: torch.Tensor, eps: float = 1e-8, data_range: float = 1.0):
    """Peak signal-to-noise ratio between ``x`` and ``y`` in dB, as a 0-d tensor.

    PSNR = 10 * log10(data_range**2 / (MSE + eps)). The default ``data_range=1.0`` matches
    inputs already divided by the ground-truth maximum (as recon_and_metrics does); pass
    the true peak value for un-normalised data.

    Args:
        x, y: tensors of the same shape (or broadcastable shapes).
        eps: added to the MSE so identical inputs give a finite value instead of +inf.
        data_range: peak signal value used in the numerator.
    """
    mse = torch.mean((x - y) ** 2)
    return 10.0 * torch.log10((data_range ** 2) / (mse + eps))

def ssim_simple(x: torch.Tensor, y: torch.Tensor, C1: float = 0.01**2, C2: float = 0.03**2):
    """Global SSIM: one score computed from whole-image statistics, returned as a 0-d tensor.

    No sliding/Gaussian window is used — means, (biased) variances and the covariance are
    taken over all elements at once — so values are not directly comparable to windowed
    SSIM implementations. The default C1 = 0.01**2 and C2 = 0.03**2 correspond to a data
    range of 1, i.e. they assume normalised inputs.
    """
    mean_x = x.mean()
    mean_y = y.mean()
    dev_x = x - mean_x
    dev_y = y - mean_y
    var_x = (dev_x ** 2).mean()
    var_y = (dev_y ** 2).mean()
    cov_xy = (dev_x * dev_y).mean()
    numerator = (2 * mean_x * mean_y + C1) * (2 * cov_xy + C2)
    denominator = (mean_x ** 2 + mean_y ** 2 + C1) * (var_x + var_y + C2)
    return numerator / (denominator + 1e-8)

def compute_metrics(recon_mag, gt_mag):
    """Return {'psnr', 'ssim', 'l1'} of ``recon_mag`` against ``gt_mag`` as plain Python floats.

    PSNR/SSIM use psnr_db / ssim_simple with their defaults (data range 1), so both inputs
    are expected to be normalised magnitudes; 'l1' is the mean absolute difference.
    """
    psnr = psnr_db(recon_mag, gt_mag).item()
    ssim = ssim_simple(recon_mag, gt_mag).item()
    l1 = torch.mean(torch.abs(recon_mag - gt_mag)).item()
    return {'psnr': float(psnr), 'ssim': float(ssim), 'l1': float(l1)}

def recon_and_metrics(ds, idx, model=None):
    """Reconstruct item ``idx`` of ``ds`` and score it against the ground truth.

    ``ds[idx]`` must yield ``(masked, target)``. With ``model=None`` the "reconstruction" is
    the masked (zero-filled) input itself; otherwise ``model`` is run without gradients on
    ``masked`` with a leading dimension added (presumably the batch axis) and removed again.
    All magnitudes are divided by the ground-truth maximum (clamped to >= 1e-8).

    Returns:
        (recon_mag, gt_mag, zf_mag, metrics, zf_metrics), where the metric dicts come from
        compute_metrics for the reconstruction and for the zero-filled input respectively.
    """
    masked, target = ds[idx]
    with torch.no_grad():
        recon = masked if model is None else model(masked.unsqueeze(0)).squeeze(0)
    gt_mag, recon_mag, zf_mag = (to_mag(t).cpu() for t in (target, recon, masked))
    scale = gt_mag.max().clamp_min(1e-8)
    gt_mag, recon_mag, zf_mag = gt_mag / scale, recon_mag / scale, zf_mag / scale
    return recon_mag, gt_mag, zf_mag, compute_metrics(recon_mag, gt_mag), compute_metrics(zf_mag, gt_mag)

def qualitative_panel(indices, real_run, cx_run, display_range=(0.0, 1.0)):
    """Render one figure of reconstructions and save it to FIG_DIR.

    Rows are the validation slices in ``indices``; columns follow SHOW_MODELS
    ('zf', 'real', 'complex', 'gt'). Each non-GT panel is titled with its PSNR/SSIM
    against the ground truth.

    Args:
        indices: validation-slice indices to render, one row each.
        real_run: run directory of the real-valued U-Net. Its config also supplies the
            validation dataset and mask used for every column.
        cx_run: run directory of the complex-valued U-Net.
        display_range: ``(vmin, vmax)`` applied to every panel so all images share one
            grey-level mapping. recon_and_metrics divides every image by the GT maximum,
            so ``(0, 1)`` spans the ground truth. Pass ``None`` to autoscale each panel
            independently (which makes contrast incomparable across panels).

    Raises:
        ValueError: if ``indices`` is empty or SHOW_MODELS contains an unknown entry.
    """
    indices = list(indices)
    if not indices:
        raise ValueError('qualitative_panel: indices must contain at least one slice index')
    cols = list(SHOW_MODELS)
    known = ('zf', 'real', 'complex', 'gt')
    unknown = [col for col in cols if col not in known]
    if unknown:
        raise ValueError(f'qualitative_panel: unknown SHOW_MODELS entries {unknown}; expected a subset of {known}')
    vmin, vmax = display_range if display_range is not None else (None, None)

    real_cfg = load_config(real_run)
    cx_cfg = load_config(cx_run)
    # NOTE(review): both models are evaluated on the dataset/mask from the *real* run's
    # config; this assumes the two runs share val_folder and mask settings — confirm.
    ds = load_dataset(real_cfg, split='val')
    real_model = load_model(real_run, real_cfg, 'real') if 'real' in cols else None
    cx_model = load_model(cx_run, cx_cfg, 'complex') if 'complex' in cols else None

    fig, axes = plt.subplots(len(indices), len(cols), figsize=(3.5*len(cols), 3.5*len(indices)), squeeze=False)
    try:
        for r, idx in enumerate(indices):
            # With model=None, recon_and_metrics returns the zero-filled input as "recon".
            recon_real, gt_mag, zf_mag, met_real, met_zf = recon_and_metrics(ds, idx, real_model)
            recon_cx, _, _, met_cx, _ = recon_and_metrics(ds, idx, cx_model)
            panels = {
                'zf': (zf_mag, f"Zero Fill\nPSNR {met_zf['psnr']:.2f} SSIM {met_zf['ssim']:.3f}"),
                'real': (recon_real, f"Real U-Net\nPSNR {met_real['psnr']:.2f} SSIM {met_real['ssim']:.3f}"),
                'complex': (recon_cx, f"Complex U-Net\nPSNR {met_cx['psnr']:.2f} SSIM {met_cx['ssim']:.3f}"),
                'gt': (gt_mag, "Ground Truth"),
            }
            for c, col in enumerate(cols):
                image, title = panels[col]
                ax = axes[r][c]
                ax.imshow(image.squeeze(), cmap='gray', vmin=vmin, vmax=vmax)
                ax.set_title(title)
                ax.axis('off')
        fig.tight_layout()
        out_path = FIG_DIR / f"fig_qualitative_idx_{'_'.join(map(str, indices))}.png"
        fig.savefig(out_path, dpi=200)
    finally:
        # Release the figure even if loading/rendering fails, so repeated runs don't leak memory.
        plt.close(fig)
    print('Saved', out_path)

def zero_fill_grid(run_dir: Path, indices, display_range=(0.0, 1.0)):
    """Save one zero-filled mask-sweep figure per slice index to FIG_DIR.

    The sweep comes from the run config's ``mask_grid`` entry (keys ``accels`` and ``acs``),
    falling back to the single setting ZF_ACCEL / ZF_ACS. Rows are ACS values and columns
    are acceleration factors. Each panel shows the zero-filled magnitude image (divided by
    the GT maximum, via recon_and_metrics) titled with its PSNR/SSIM against the ground truth.

    Args:
        run_dir: zero-fill run directory whose config supplies ``val_folder`` and ``mask_grid``.
        indices: validation-slice indices; one figure file is written per index.
        display_range: ``(vmin, vmax)`` shared by every panel, so different masks are shown
            with the same grey-level mapping; ``None`` autoscales each panel independently.

    Raises:
        ValueError: if the mask grid has no accelerations or no ACS values.
    """
    cfg = load_config(run_dir)
    # `or` also covers an explicit null/empty mask_grid in the config.
    mask_grid = cfg.get('mask_grid') or {'accels': [ZF_ACCEL], 'acs': [ZF_ACS]}
    accels = list(mask_grid.get('accels', [ZF_ACCEL]))
    acs_list = list(mask_grid.get('acs', [ZF_ACS]))
    if not accels or not acs_list:
        raise ValueError(f'zero_fill_grid: empty mask grid (accels={accels}, acs={acs_list})')
    vmin, vmax = display_range if display_range is not None else (None, None)
    # Build one dataset per mask rather than re-masking an already-masked image.
    datasets = {(a, acs): SingleCoilDataset(cfg['val_folder'], mask_func=EquispacedMasker(accel=a, acs=acs))
                for a in accels for acs in acs_list}
    for idx in indices:
        fig, axes = plt.subplots(len(acs_list), len(accels), figsize=(3*len(accels), 3*len(acs_list)), squeeze=False)
        try:
            for i, acs in enumerate(acs_list):
                for j, a in enumerate(accels):
                    # model=None: the returned "recon" is the normalised zero-filled image.
                    zf, _, _, metrics, _ = recon_and_metrics(datasets[(a, acs)], idx)
                    ax = axes[i][j]
                    ax.imshow(zf.squeeze(), cmap='gray', vmin=vmin, vmax=vmax)
                    ax.set_title(f"R={a}, ACS={acs}\nPSNR {metrics['psnr']:.2f} SSIM {metrics['ssim']:.3f}", fontsize=8)
                    ax.axis('off')
            fig.tight_layout()
            out_path = FIG_DIR / f"fig_zerofill_idx_{idx}.png"
            fig.savefig(out_path, dpi=200)
        finally:
            plt.close(fig)
        print('Saved', out_path)




## Tables


In [33]:
real_df = load_epoch_metrics(REAL_RUN)
cx_df = load_epoch_metrics(COMPLEX_RUN)
zf_df = load_zero_fill(ZERO_FILL_RUN)

TABLE_METRICS = ['psnr', 'ssim', 'l1']


def _table_row(label, metrics_row):
    """Build one main-table record: the model label followed by its PSNR / SSIM / L1 values."""
    return {'model': label, **{name: metrics_row[name] for name in TABLE_METRICS}}


main_rows = []
# Zero-fill baseline at the configured mask (ZF_ACCEL / ZF_ACS).
zf_match = zf_df[(zf_df['accel'] == ZF_ACCEL) & (zf_df['acs'] == ZF_ACS)]
if zf_match.empty:
    print(f'WARNING: no zero-fill row with accel={ZF_ACCEL}, acs={ZF_ACS} in {ZERO_FILL_RUN}; '
          'omitting Zero-fill from the main table')
else:
    if len(zf_match) > 1:
        print(f'WARNING: {len(zf_match)} zero-fill rows match accel={ZF_ACCEL}, acs={ZF_ACS}; using the first')
    zf_r = zf_match.iloc[0]
    main_rows.append(_table_row('Zero-fill', zf_r))
# Trained models are reported at the epoch with the lowest validation loss (not the best PSNR/SSIM).
r_best = best_row(real_df)
main_rows.append(_table_row('Real U-Net', r_best))
c_best = best_row(cx_df)
main_rows.append(_table_row('Complex U-Net', c_best))
main_table = pd.DataFrame(main_rows)
main_table_path = FIG_DIR / 'table_main_results.csv'
main_table.to_csv(main_table_path, index=False)
display(main_table)
print('Saved', main_table_path)

zf_table_path = FIG_DIR / 'table_zerofill_metrics.csv'
zf_df.to_csv(zf_table_path, index=False)
print('Saved', zf_table_path)


Unnamed: 0,model,psnr,ssim,l1
0,Zero-fill,21.782419,0.878368,0.058358
1,Real U-Net,25.099617,0.785107,0.046196
2,Complex U-Net,25.494175,0.887243,0.040798


Saved /home/gdegeron/Desktop/ece570-tinyreproductions/report/paper/figures/table_main_results.csv
Saved /home/gdegeron/Desktop/ece570-tinyreproductions/report/paper/figures/table_zerofill_metrics.csv


## Training curves (epoch-level train/val loss, validation PSNR)


In [34]:
# Epoch-level curves from epoch_metrics.csv (real_df / cx_df, loaded in the Tables cell):
# train/val loss on the left, validation PSNR on the right, one colour per model.
CURVE_RUNS = [('Real', real_df, 'C0'), ('Complex', cx_df, 'C1')]
fig, ax = plt.subplots(1, 2, figsize=(12, 4))

for label, metrics_df, color in CURVE_RUNS:
    if metrics_df.empty:
        continue
    if 'train_loss' in metrics_df:
        ax[0].plot(metrics_df['epoch'], metrics_df['train_loss'], label=f'{label} train loss', alpha=0.8, color=color)
    if 'val_loss' in metrics_df:
        ax[0].plot(metrics_df['epoch'], metrics_df['val_loss'], label=f'{label} val loss', alpha=0.8, color=color, linestyle='--')
    if 'psnr' in metrics_df:
        ax[1].plot(metrics_df['epoch'], metrics_df['psnr'], label=f'{label} PSNR', color=color)

# NOTE(review): the 'L1 loss' label assumes the training loss is L1 — confirm against the training config.
for panel, title, ylabel in [(ax[0], 'Loss vs epoch (train/val)', 'L1 loss'),
                             (ax[1], 'PSNR vs epoch (val)', 'PSNR (dB)')]:
    panel.set_title(title)
    panel.set_xlabel('Epoch')
    panel.set_ylabel(ylabel)
    # Only draw a legend when something was plotted (avoids matplotlib's "no artists" warning).
    if panel.get_legend_handles_labels()[0]:
        panel.legend()
    panel.grid(True, alpha=0.3)
fig.tight_layout()
train_fig_path = FIG_DIR / 'fig_training_curve.png'
fig.savefig(train_fig_path, dpi=200)
plt.close(fig)
print('Saved', train_fig_path)


Saved /home/gdegeron/Desktop/ece570-tinyreproductions/report/paper/figures/fig_training_curve.png


## Qualitative panels (selected validation slices)


In [35]:
# One row per slice in QUAL_INDICES, columns ordered by SHOW_MODELS.
qualitative_panel(indices=QUAL_INDICES, real_run=REAL_RUN, cx_run=COMPLEX_RUN)


Saved /home/gdegeron/Desktop/ece570-tinyreproductions/report/paper/figures/fig_qualitative_idx_50_250_630.png


## Zero-fill qualitative grid (mask sweep)


In [36]:
# One figure per slice in ZF_GRID_INDICES, sweeping the masks listed in the run's mask_grid.
zero_fill_grid(run_dir=ZERO_FILL_RUN, indices=ZF_GRID_INDICES)


Saved /home/gdegeron/Desktop/ece570-tinyreproductions/report/paper/figures/fig_zerofill_idx_50.png
