# 03_pix2pix_mri.ipynb — T1→T2 synthesis

Traducción de T1 a T2 con Pix2Pix (pares alineados). Incluye métricas SSIM/PSNR.


In [None]:
import sys
if '.' not in sys.path: sys.path.append('.')
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
from src.models.pix2pix import UNetGenerator, PatchGANDiscriminator
from src.utils.metrics import ssim, psnr
from src.utils.visualization import show_images


In [None]:
class PairedMRIDataset(Dataset):
    """Aligned T1/T2 MRI image pairs read from ``<root>/T1/*.png`` and ``<root>/T2/*.png``.

    Pairs are formed by position after sorting each folder's paths, so the i-th
    T1 file is matched with the i-th T2 file (file names are not compared).
    Each image is converted to single-channel grayscale, resized to
    ``img_size x img_size`` and normalized from [0, 1] to [-1, 1].

    Args:
        root: dataset root directory containing ``T1/`` and ``T2/`` subfolders.
        img_size: side length in pixels of the square output images.

    Raises:
        ValueError: if ``T1/`` and ``T2/`` contain a different number of PNG files.
    """

    def __init__(self, root, img_size=256):
        self.root = Path(root)
        self.t1 = sorted((self.root / 'T1').glob('*.png'))
        self.t2 = sorted((self.root / 'T2').glob('*.png'))
        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if len(self.t1) != len(self.t2):
            raise ValueError(
                f'Número de pares T1/T2 debe coincidir: '
                f'{len(self.t1)} T1 vs {len(self.t2)} T2 en {self.root}'
            )
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
        ])

    def __len__(self):
        return len(self.t1)

    def _load(self, path):
        """Open *path*, convert it to 8-bit grayscale and apply the tensor transform."""
        # The context manager guarantees the file handle is released; Pillow only
        # auto-closes after load() for single-frame images.
        with Image.open(path) as img:
            return self.tf(img.convert('L'))

    def __getitem__(self, idx):
        """Return the ``(t1, t2)`` tensor pair at position *idx*."""
        return self._load(self.t1[idx]), self._load(self.t2[idx])

# TODO: point this at your paired dataset.
# Expected layout: <root>/T1/*.png and <root>/T2/*.png (pairs matched by sorted order).
root = 'data/paired_mri'
# Fallback: if `root` has no pairs, use data/sample_mri_pairs, generating it first if needed.
import subprocess
from torch.utils.data import random_split

sample_root = Path('data/sample_mri_pairs')


def _has_pairs(dataset_root):
    """Return True when both ``T1/`` and ``T2/`` under *dataset_root* contain at least one PNG."""
    base = Path(dataset_root)
    return any((base / 'T1').glob('*.png')) and any((base / 'T2').glob('*.png'))


if not _has_pairs(root):
    if not _has_pairs(sample_root):
        # Only run the generator when the sample set is missing, so re-running this cell is cheap.
        try:
            proc = subprocess.run([sys.executable, 'scripts/make_sample_data.py'], check=False)
        except (OSError, subprocess.SubprocessError) as e:
            print('Warning: could not generate sample data:', e)
        else:
            if proc.returncode != 0:
                print('Warning: scripts/make_sample_data.py exited with code', proc.returncode)
    if _has_pairs(sample_root):
        root = str(sample_root)

img_size = 256
batch_size = 4
val_fraction = 0.1  # share of pairs held out from training for validation
split_seed = 0      # fixed seed -> the same pairs are held out on every run

full_ds = PairedMRIDataset(root, img_size=img_size)
if len(full_ds) == 0:
    # Fail early with a clear message instead of DataLoader's generic "num_samples" error.
    raise FileNotFoundError(f'No paired PNGs found: expected {root}/T1/*.png and {root}/T2/*.png')
# Hold out the validation pairs *before* training so G and D never see them
# (with a single pair there is nothing to hold out).
n_val = max(1, int(val_fraction * len(full_ds))) if len(full_ds) > 1 else 0
train_ds, _ = random_split(
    full_ds,
    [len(full_ds) - n_val, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = UNetGenerator(in_channels=1, out_channels=1).to(device)
# D scores (input, output) pairs; see the D(a, b) calls in the training loop.
D = PatchGANDiscriminator(in_channels=2).to(device)
opt_g = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
L1 = nn.L1Loss()
BCE = nn.BCEWithLogitsLoss()


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G.to(device); D.to(device)
# The validation cell calls G.eval(); switch back so any dropout/normalization
# layers behave as in training when this cell is re-run.
G.train(); D.train()

num_epochs = 1
lambda_l1 = 100  # weight of the L1 reconstruction term relative to the adversarial term
log_every = 50   # iterations between image grids / metric prints

for epoch in range(num_epochs):
    for i, (a, b) in enumerate(train_dl):
        a, b = a.to(device), b.to(device)

        # ---- Discriminator: real (a, b) pairs -> 1, generated (a, G(a)) pairs -> 0 ----
        opt_d.zero_grad()
        pred_real = D(a, b)
        z = G(a)
        pred_fake = D(a, z.detach())  # detach: D's loss must not backprop into G
        loss_d = BCE(pred_real, torch.ones_like(pred_real)) + BCE(pred_fake, torch.zeros_like(pred_fake))
        loss_d.backward(); opt_d.step()

        # ---- Generator: fool D while staying close to the target in L1 ----
        opt_g.zero_grad()
        pred_fake = D(a, z)
        loss_g = BCE(pred_fake, torch.ones_like(pred_fake)) + lambda_l1 * L1(z, b)
        loss_g.backward(); opt_g.step()

        if i % log_every == 0:
            # Logging only: no autograd graph is needed for the grid or the metrics.
            with torch.no_grad():
                s = torch.cat([a[:4], z[:4], b[:4]], dim=0)
                show_images(s, nrow=4, title=f'Epoch {epoch} iter {i}')
                # NOTE(review): z and b are in [-1, 1] here; confirm src.utils.metrics
                # ssim/psnr expect that range (the validation cell rescales to [0, 1]).
                print('D', float(loss_d), 'G', float(loss_g), 'SSIM', float(ssim(z, b)), 'PSNR', float(psnr(z, b)))


In [None]:
from src.utils.visualization import plot_curves, show_images
from torch.utils.data import random_split
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
import os

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G.eval()

# Build a reproducible 10% validation split. These pairs are only a true
# hold-out if training excluded the same pairs (same fraction and seed).
val_fraction = 0.1
split_seed = 0
full_ds = PairedMRIDataset(root, img_size=img_size)
n_val = max(1, int(val_fraction * len(full_ds)))
n_train = len(full_ds) - n_val
train_subset, val_ds = random_split(
    full_ds,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(split_seed),
)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2)

ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)

# Batch scores are weighted by batch size so a short final batch does not skew the mean.
ssim_total = 0.0
psnr_total = 0.0
n_seen = 0
with torch.no_grad():
    for a, b in val_dl:
        a, b = a.to(device), b.to(device)
        z = G(a)
        # Dataset tensors are normalized to [-1, 1]; map to [0, 1] to match data_range=1.0.
        # Clamp G's output so any overshoot cannot fall outside the declared range.
        z01 = ((z + 1) / 2).clamp(0.0, 1.0)
        b01 = (b + 1) / 2
        k = a.size(0)
        ssim_total += ssim_metric(z01, b01).item() * k
        psnr_total += psnr_metric(z01, b01).item() * k
        n_seen += k
val_ssim = ssim_total / n_seen if n_seen else 0.0
val_psnr = psnr_total / n_seen if n_seen else 0.0
print({'val_ssim': val_ssim, 'val_psnr': val_psnr})

history = {'val_ssim': [val_ssim], 'val_psnr': [val_psnr]}
os.makedirs('outputs/pix2pix', exist_ok=True)
plot_curves(history, title='Pix2Pix validation', save_path='outputs/pix2pix/metrics.png')

if len(val_dl) > 0:
    samp_a, samp_b = next(iter(val_dl))
else:
    # Fall back to the last batch seen (a, b from the most recent loop).
    samp_a, samp_b = a.cpu()[:4], b.cpu()[:4]
# no_grad: without it samp_z would require grad and could break tensor->numpy plotting.
with torch.no_grad():
    samp_z = G(samp_a.to(device)).cpu()
show_images(
    torch.cat([samp_a.cpu()[:4], samp_z[:4], samp_b.cpu()[:4]], dim=0),
    nrow=4,
    title='Val: A, G(A), B',
    save_path='outputs/pix2pix/val_grid.png',
)
