# MRI Super‑Resolution – Unified Workflow Notebook

This notebook replaces all individual CLI scripts with a **single, configuration‑driven workflow** covering:

1. **Setup & Configuration**  
2. **Data Exploration**  
3. **Model Training**  
4. **Evaluation**  
5. **Inference / Visualisation**

Edit YAML files in `configs/` to modify parameters — the notebook simply *loads* them and passes a frozen `CfgNode` around.

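> **Tip:** Run the sections in order the first time — later sections depend on earlier ones (evaluation uses `trainer`, inference uses `evaluator`). The resulting state (e.g. `trainer`, `evaluator`, `model`) is kept in top-level variables, so afterwards you can re-run any single section without rerunning everything.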

In [None]:
# --- 1. Setup -----------------------------------------------------------------
# Standard library
import argparse  # used by the training/evaluation cells to build CLI-style args
import importlib
import json
import os
import pprint
import sys
from pathlib import Path

# Third party
import torch
import yaml

# `src/` and `configs/` are resolved relative to the current working directory,
# so the notebook is expected to be started from the project root.
PROJECT_ROOT = Path().resolve()
SRC_DIR = PROJECT_ROOT / "src"
if str(SRC_DIR) not in sys.path:
    # Prepend so project modules take precedence over same-named installed packages.
    sys.path.insert(0, str(SRC_DIR))

from utils.config import load_cfg  # project config loader (src/utils/config)

cfg_path = PROJECT_ROOT / "configs" / "experiment.yaml"
if not cfg_path.is_file():
    # Fail early with an actionable message instead of an opaque loader error.
    raise FileNotFoundError(
        f"Config not found at {cfg_path}; start the notebook from the project root."
    )
CFG = load_cfg(cfg_path)

print("Loaded configuration:")
pprint.pprint(CFG)

# Keep the selected device around so later cells can reuse it.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Running on device: {DEVICE}")

## Why a YAML‑only config?

* **Single source of truth** – no hidden defaults scattered across scripts.  
* **Reproducibility** – every experiment can be re‑run by pointing to the same YAML.  
* **Composability** – configs can inherit from a *base* config and override only what changes.

In [None]:
# --- 2. Data Exploration ------------------------------------------------------
import matplotlib.pyplot as plt
import numpy as np

from utils.data import create_dataloaders

# Loader settings for quick visual inspection: one sample per batch, with the
# worker count read from the config (4 when `data.workers` is not set).
data_cfg = CFG['data']
loader_kwargs = dict(
    data_dir=data_cfg['dir'],
    config_path=data_cfg['transforms'],
    loader_to_create='both',
    batch_size=1,
    num_workers=data_cfg.get('workers', 4),
)
train_loader, val_loader = create_dataloaders(**loader_kwargs)

def show_sample(loader, title):
    """Plot the first (LR, HR) pair yielded by *loader* side by side.

    Args:
        loader: an iterable yielding ``(lr, hr)`` batches. Each image is
            indexed as ``img[0, 0]`` (first item, first channel) — assumes a
            ``(batch, channel, H, W)`` layout; TODO confirm against utils.data.
        title: figure title.
    """
    lr, hr = next(iter(loader))
    fig, axes = plt.subplots(1, 2, figsize=(6, 3))
    for ax, img, label in zip(axes, (lr, hr), ('LR', 'HR')):
        ax.imshow(img[0, 0], cmap='gray')
        ax.set_title(label)
        ax.axis('off')
    fig.suptitle(title)
    plt.show()


show_sample(train_loader, "Training sample")
# Both loaders were created above; inspect the validation split as well.
show_sample(val_loader, "Validation sample")

In [None]:
# --- 3. Model Training --------------------------------------------------------
import argparse  # imported here so this cell does not rely on another cell's imports

from train_refactored import Trainer

# Trainer is given an argparse.Namespace: expose the `train` config section as
# if its keys had been passed as command-line flags.
train_args = argparse.Namespace(**CFG['train'])
trainer = Trainer(train_args)
trainer.fit()

In [None]:
# --- 4. Evaluation ------------------------------------------------------------
import argparse  # imported here so this cell does not rely on another cell's imports
from pathlib import Path

from evaluate_refactored import Evaluator

# Work on a copy: the config may be frozen, and merging into a dict avoids the
# "multiple values for keyword argument 'checkpoint'" TypeError that
# Namespace(**cfg, checkpoint=...) raises when the eval config sets it too.
eval_args = dict(CFG['eval'])
if 'trainer' in globals():
    # Evaluate the best checkpoint expected in the trainer's output directory.
    # Path() tolerates output_dir being either a str or a Path.
    eval_args['checkpoint'] = str(
        Path(trainer.output_dir) / f"{CFG['train']['model']}_best.pth"
    )
elif 'checkpoint' not in eval_args:
    raise RuntimeError(
        "No trainer in this session and no 'checkpoint' in CFG['eval']; "
        "run the training section first or set eval.checkpoint in the YAML."
    )

evaluator = Evaluator(argparse.Namespace(**eval_args))
metrics = evaluator.run()

In [None]:
# --- 5. Inference -------------------------------------------------------------
from inference import preprocess_image
from PIL import Image
import torch.nn.functional as F
import numpy as np

from utils.visualization import plot_comparison

input_dir = Path(CFG['inference']['input_dir'])
# sorted() makes the chosen test image deterministic; glob order is unspecified.
tif_paths = sorted(input_dir.glob('*.tif'))
if not tif_paths:
    raise FileNotFoundError(f"No .tif images found in {input_dir}")
test_img_path = tif_paths[0]

# Bicubic pre-upsampling factor; falls back to 2 (the previously hard-coded
# value) when the inference config does not define `scale`.
scale = CFG['inference'].get('scale', 2)

model = evaluator.model
model.eval()  # inference mode for dropout/BatchNorm layers, if the model has any

lr = preprocess_image(test_img_path).to(evaluator.device)
lr_up = F.interpolate(lr, scale_factor=scale, mode='bicubic', align_corners=False)
with torch.no_grad():
    sr = model(lr_up)

# Visualise. No ground-truth HR image is loaded here, so SR is passed in the HR
# slot as a placeholder: the "HR" panel is NOT a real reference.
plot_comparison(lr_up.cpu(), sr.cpu(), sr.cpu())