# 06_t1_to_t2_synthesis.ipynb — T1→T2 generation

Goal:
- Load paired T1/T2 PNGs (256×256, grayscale).
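- Use a conditional generator (Pix2Pix U-Net) to synthesize T2 from T1, loading a saved checkpoint when one exists and otherwise running a brief warm-start fine-tune.
- Evaluate with SSIM/PSNR, compare with a MONAI UNet baseline (`M`), and save example image grids.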
- Includes Colab-friendly setup and automatic fallbacks to sample data.


In [None]:
# Colab setup: clone repo, install deps, and create sample data.
# Outside Colab nothing is cloned: the current working directory is used as-is, so
# requirements.txt and scripts/ are resolved relative to it (run from the repo root).
import os, sys, subprocess
from pathlib import Path

IN_COLAB = False
try:
    import google.colab  # type: ignore  # only importable inside a Colab runtime
    IN_COLAB = True
except ImportError:
    pass

repo_url = "https://github.com/julietam/GenerativeAI_Medical_Images.git"
workdir = Path("/content/GenerativeAI_Medical_Images") if IN_COLAB else Path.cwd()


def run_best_effort(cmd, what):
    """Run *cmd* without raising on failure; print a warning if it exits non-zero.

    Returns True when the command exited with status 0.
    """
    result = subprocess.run(cmd, check=False)
    if result.returncode != 0:
        print(f"Warning: {what} failed (exit code {result.returncode}); continuing.")
    return result.returncode == 0


if IN_COLAB:
    if not workdir.exists():
        subprocess.run(["git", "clone", repo_url, str(workdir)], check=True)
    else:
        # Best-effort refresh: a diverged checkout is kept rather than failing the cell.
        run_best_effort(["git", "-C", str(workdir), "pull", "--ff-only"], "git pull")
    os.chdir(workdir)

run_best_effort([sys.executable, "-m", "pip", "install", "-q", "-r", "requirements.txt"],
                "pip install -r requirements.txt")
run_best_effort([sys.executable, "-m", "pip", "install", "-q", "torchmetrics", "monai", "pandas"],
                "pip install torchmetrics monai pandas")

sample_script = Path("scripts/make_sample_data.py")
if sample_script.exists():
    run_best_effort([sys.executable, str(sample_script)], "sample data generation")
else:
    print("Warning:", sample_script, "not found; skipping sample data generation.")
print("Setup done. CWD =", Path.cwd())


In [None]:
# Data paths: prefer real pairs in data/paired_mri; fall back to data/sample_mri_pairs.
from pathlib import Path
import subprocess, sys

paired_root = Path('data/paired_mri')
sample_root = Path('data/sample_mri_pairs')
sample_script = Path('scripts/make_sample_data.py')
t1_dir = paired_root/'T1'; t2_dir = paired_root/'T2'


def has_pngs(d):
    """Return True if directory *d* holds at least one ``*.png`` (False if *d* is missing)."""
    return any(Path(d).glob('*.png'))


if not (has_pngs(t1_dir) and has_pngs(t2_dir)):
    # Real pairs are missing or one modality is empty. Try (best-effort) to generate the
    # sample pairs, and switch to them only if both modalities are then present.
    if sample_script.exists():
        try:
            subprocess.run([sys.executable, str(sample_script)], check=False)
        except OSError as e:  # interpreter could not be launched
            print('Warning: could not generate sample data:', e)
    else:
        print('Warning: could not generate sample data:', sample_script, 'not found')
    if has_pngs(sample_root/'T1') and has_pngs(sample_root/'T2'):
        paired_root = sample_root
        t1_dir = sample_root/'T1'; t2_dir = sample_root/'T2'
    else:
        print('Warning: no complete T1/T2 PNG pairs found in', paired_root, 'or', sample_root)
print('Using pairs from', paired_root)
n_t1 = len(list(t1_dir.glob('*.png'))); n_t2 = len(list(t2_dir.glob('*.png')))
print('Found', n_t1, 'T1 and', n_t2, 'T2')
if n_t1 != n_t2:
    print('Warning: T1/T2 counts differ; the dataset requires equal numbers of T1 and T2 PNGs.')


In [None]:
# Dataset and transforms (grayscale → [-1,1])
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

img_size = 256
# Grayscale → resize to img_size×img_size → ToTensor ([0,1]) → Normalize(0.5, 0.5) maps to [-1,1].
tf = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((img_size, img_size)),
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])


class PairedMRIDataset(Dataset):
    """Paired T1/T2 PNGs read from ``<root>/T1`` and ``<root>/T2``.

    Files are paired by position after sorting each folder by path, so the i-th T1
    image is matched with the i-th T2 image. Each item is ``(A, B)`` = (T1, T2),
    both produced by the module-level ``tf`` transform.

    Raises:
        ValueError: if the two folders hold different numbers of PNGs, or none at all.
    """

    def __init__(self, root):
        root = Path(root)
        self.a = sorted((root/'T1').glob('*.png'))
        self.b = sorted((root/'T2').glob('*.png'))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(self.a) != len(self.b) or not self.a:
            raise ValueError(
                f'Need paired PNGs under T1/ and T2/ '
                f'(found {len(self.a)} T1 and {len(self.b)} T2 in {root})'
            )
        # Pairing is purely positional; point it out when sorted file names disagree.
        mismatched = [(x.name, y.name) for x, y in zip(self.a, self.b) if x.name != y.name]
        if mismatched:
            print(f'Note: {len(mismatched)} of {len(self.a)} T1/T2 pairs have different file names '
                  f'(e.g. {mismatched[0][0]} vs {mismatched[0][1]}); pairing is by sorted order.')

    def __len__(self):
        return len(self.a)

    @staticmethod
    def _load(path):
        """Load *path* as 8-bit grayscale and apply ``tf``; the file is closed before returning."""
        with Image.open(path) as im:
            return tf(im.convert('L'))

    def __getitem__(self, i):
        return self._load(self.a[i]), self._load(self.b[i])


ds = PairedMRIDataset(paired_root)
# NOTE(review): num_workers > 0 sends this notebook-defined class to worker processes; with
# the 'spawn'/'forkserver' start methods (e.g. Windows/macOS defaults) workers may fail to
# find it. Set num_workers=0 if the loader errors there.
dl = DataLoader(ds, batch_size=4, shuffle=False, num_workers=2)
print('Dataset size =', len(ds))


In [None]:
# Model: Pix2Pix U-Net generator; load checkpoint if available or do a brief fine-tune
from src.models.pix2pix import UNetGenerator, PatchGANDiscriminator
from torch import nn, optim

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = UNetGenerator(in_channels=1, out_channels=1).to(device)
# Checkpoints are tried in order; the first one that loads successfully is used.
ckpt_candidates = [
    Path('outputs/pix2pix/best.pt'),
    Path('outputs/pix2pix/last.pt'),
]
warmup_steps = 50  # number of D/G update pairs for the fallback warm-start

loaded = False
for c in ckpt_candidates:
    if not c.exists():
        continue
    try:
        # torch.load unpickles the file: only point this at trusted checkpoints.
        # NOTE(review): a strict load_state_dict that fails on a key mismatch may already
        # have copied some tensors into G before raising.
        G.load_state_dict(torch.load(c, map_location=device))
    except Exception as e:  # best-effort: try the next candidate, then fall back to warm-start
        print('Failed to load', c, e)
        continue
    print('Loaded checkpoint:', c)
    loaded = True
    break

if not loaded:
    print('No checkpoint found. Running a very brief fine-tune (few iterations) to warm-start...')
    D = PatchGANDiscriminator(in_channels=2).to(device)
    opt_g = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
    opt_d = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
    L1 = nn.L1Loss(); BCE = nn.BCEWithLogitsLoss()
    G.train(); D.train()
    steps = 0
    # Cycle over the loader until warmup_steps updates are done. Otherwise a small dataset
    # (such as the generated sample pairs) would stop after one short pass. The len(dl)
    # guard prevents an endless loop on an empty loader.
    while steps < warmup_steps and len(dl) > 0:
        for A, B in dl:
            A, B = A.to(device), B.to(device)
            # Discriminator step: real (A, B) pairs → 1, generated (A, G(A)) pairs → 0.
            opt_d.zero_grad()
            Z = G(A)
            pred_real = D(A, B)
            pred_fake = D(A, Z.detach())  # detach so this step does not backprop into G
            loss_d = BCE(pred_real, torch.ones_like(pred_real)) + BCE(pred_fake, torch.zeros_like(pred_fake))
            loss_d.backward(); opt_d.step()
            # Generator step: adversarial term plus L1 reconstruction weighted by 100.
            opt_g.zero_grad()
            pred_fake = D(A, Z)
            loss_g = BCE(pred_fake, torch.ones_like(pred_fake)) + 100 * L1(Z, B)
            loss_g.backward(); opt_g.step()
            steps += 1
            if steps >= warmup_steps:
                break
    print('Warm-start finished after', steps, 'steps')

# Inference mode for everything downstream (checkpoint or warm-started weights).
G.eval()


In [None]:
# Evaluate both models and build a metrics table
import os
import pandas as pd
from src.utils.visualization import show_images
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio

os.makedirs('outputs/t1_to_t2', exist_ok=True)


def evaluate(forward):
    """Run *forward* over every batch of ``dl`` and score predictions against the T2 targets.

    Args:
        forward: callable mapping a T1 batch (on ``device``, values in [-1, 1]) to a
            predicted T2 batch of the same shape.

    Returns:
        ``(ssim, psnr, cache)``. ``ssim`` and ``psnr`` are per-image averages over the whole
        loader, computed on images rescaled to [0, 1]; both are 0.0 if ``dl`` is empty.
        ``cache`` holds ``(A, Z, B)`` for up to the first 4 samples of the first batch, on
        CPU and in the raw [-1, 1] scale, or is ``None`` if ``dl`` is empty.
    """
    # Accumulate with update()/compute() so every image counts equally. Averaging per-batch
    # scores would over-weight a short final batch.
    ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
    # dim=(1, 2, 3): PSNR per image, averaged across the dataset by compute().
    psnr_metric = PeakSignalNoiseRatio(data_range=1.0, dim=(1, 2, 3)).to(device)
    cache = None
    n_batches = 0
    with torch.no_grad():
        for A, B in dl:
            A, B = A.to(device), B.to(device)
            Z = forward(A)
            # Clamp so predictions respect data_range=1.0 after the [-1,1] → [0,1] mapping.
            Z01 = (Z.clamp(-1, 1) + 1) / 2
            B01 = (B + 1) / 2
            ssim_metric.update(Z01, B01)
            psnr_metric.update(Z01, B01)
            n_batches += 1
            if cache is None:
                cache = (A[:4].cpu(), Z[:4].cpu(), B[:4].cpu())
    if n_batches == 0:
        return 0.0, 0.0, None
    return float(ssim_metric.compute()), float(psnr_metric.compute()), cache

def save_grid(cache, title, save_path):
    """Show and save an (input, prediction, target) grid from ``evaluate``'s cache, if present."""
    if cache is None:
        return
    a, z, b = cache
    show_images(torch.cat([a, z, b], dim=0), nrow=4, title=title, save_path=save_path)


results = []

# Pix2Pix
ssim_g, psnr_g, cache_g = evaluate(lambda x: G(x))
save_grid(cache_g, 'Pix2Pix: A, G(A), B', 'outputs/t1_to_t2/grid_pix2pix.png')
results.append({'model': 'Pix2Pix-UNet', 'SSIM': ssim_g, 'PSNR': psnr_g})

# MONAI UNet: no cell shown here creates `M`, so look it up instead of referencing it
# directly. A direct reference raised NameError and prevented the table from being written.
# NOTE(review): assumes M, when present, is already on `device` and in eval mode.
monai_model = globals().get('M')
if monai_model is not None:
    ssim_m, psnr_m, cache_m = evaluate(lambda x: monai_model(x))
    save_grid(cache_m, 'MONAI UNet: A, M(A), B', 'outputs/t1_to_t2/grid_monai_unet.png')
    results.append({'model': 'MONAI-UNet (L1)', 'SSIM': ssim_m, 'PSNR': psnr_m})
else:
    print('MONAI UNet `M` is not defined; skipping it in the metrics table.')

# Table
df = pd.DataFrame(results)
print(df)
df.to_csv('outputs/t1_to_t2/metrics.csv', index=False)
