# 03_pix2pix_mri.ipynb — T1→T2 synthesis

Traducción de T1 a T2 con Pix2Pix (pares alineados). Incluye métricas SSIM/PSNR.


In [None]:
import sys
if '.' not in sys.path: sys.path.append('.')
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
from src.models.pix2pix import UNetGenerator, PatchGANDiscriminator
from src.utils.metrics import ssim, psnr
from src.utils.visualization import show_images


In [None]:
class PairedMRIDataset(Dataset):
    """Aligned T1/T2 MRI image pairs for supervised (Pix2Pix) training.

    Expected layout: ``root/T1/*.png`` and ``root/T2/*.png``. Pairs are formed
    by position after sorting each folder's paths, so the i-th T1 file is
    matched with the i-th T2 file. Both folders must use names that sort in
    the same order; only the counts are checked here, not the names.

    Each item is ``(t1, t2)``: two float tensors of shape
    ``(1, img_size, img_size)`` scaled to [-1, 1]. ``ToTensor`` maps 8-bit
    pixels to [0, 1], then ``Normalize(0.5, 0.5)`` maps them to [-1, 1].

    Args:
        root: Directory containing the ``T1`` and ``T2`` subfolders.
        img_size: Side length every image is resized to (square output).

    Raises:
        ValueError: If no T1 images are found, or the number of T1 and T2
            images differs. This is an explicit raise rather than ``assert``,
            because assertions are stripped when Python runs with ``-O``.
    """

    def __init__(self, root, img_size=256):
        self.root = Path(root)
        self.t1 = sorted((self.root / 'T1').glob('*.png'))
        self.t2 = sorted((self.root / 'T2').glob('*.png'))
        # Fail early with a readable message. Otherwise an empty dataset
        # surfaces later as an obscure sampler/DataLoader error.
        if not self.t1:
            raise ValueError(f'No se encontraron imágenes .png en {self.root / "T1"}')
        if len(self.t1) != len(self.t2):
            raise ValueError(
                f'Número de pares T1/T2 debe coincidir '
                f'({len(self.t1)} T1 vs {len(self.t2)} T2)'
            )
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.t1)

    def _load(self, path):
        """Open one image as 8-bit grayscale and apply the shared transform."""
        # The context manager closes the file handle promptly, even when the
        # DataLoader runs several workers that each open many files.
        with Image.open(path) as img:
            return self.tf(img.convert('L'))

    def __getitem__(self, idx):
        return self._load(self.t1[idx]), self._load(self.t2[idx])

# TODO: change this path to your paired dataset.
# Expected layout: dataset_root/T1/*.png and dataset_root/T2/*.png
root = 'data/paired_mri'
img_size = 256
batch_size = 4
train_ds = PairedMRIDataset(root, img_size=img_size)
# NOTE(review): num_workers > 0 with a Dataset class defined inside a notebook
# can fail under the 'spawn' multiprocessing start method (default on Windows
# and macOS). Set num_workers=0 if the workers crash on startup.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)

# Build each network once with the same explicit keyword arguments, then move
# it to the selected device. This keeps the CPU and CUDA paths identical.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G = UNetGenerator(in_channels=1, out_channels=1).to(device)
# D is called as D(input, output) with two 1-channel tensors. in_channels=2
# presumably means D concatenates them internally; confirm in src.models.pix2pix.
D = PatchGANDiscriminator(in_channels=2).to(device)
opt_g = optim.Adam(G.parameters(), lr=2e-4, betas=(0.5, 0.999))
opt_d = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
L1 = nn.L1Loss()
# BCEWithLogitsLoss applies the sigmoid itself. This assumes D returns raw
# logits, not probabilities (verify against the model definition).
BCE = nn.BCEWithLogitsLoss()


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G.to(device); D.to(device)

num_epochs = 1   # short demo run; increase for real training
lambda_l1 = 100  # weight of the L1 reconstruction term in the generator loss
log_every = 50   # iterations between preview grids / metric printouts
n_show = 4       # max samples shown per row in the preview grid

G.train(); D.train()
for epoch in range(num_epochs):
    for i, (a, b) in enumerate(train_dl):
        # a: T1 input batch, b: T2 target batch
        a, b = a.to(device), b.to(device)

        # ---- Discriminator step: real (a, b) pairs -> 1, generated (a, z) -> 0 ----
        opt_d.zero_grad()
        pred_real = D(a, b)
        z = G(a)
        # detach() keeps the discriminator loss from backpropagating into G.
        pred_fake = D(a, z.detach())
        loss_d = BCE(pred_real, torch.ones_like(pred_real)) + BCE(pred_fake, torch.zeros_like(pred_fake))
        loss_d.backward()
        opt_d.step()

        # ---- Generator step: fool D and stay close to the target (L1) ----
        opt_g.zero_grad()
        pred_fake = D(a, z)
        loss_g = BCE(pred_fake, torch.ones_like(pred_fake)) + lambda_l1 * L1(z, b)
        # This backward also leaves gradients on D's parameters (opt_g only
        # steps G). opt_d.zero_grad() clears them at the next iteration.
        loss_g.backward()
        opt_g.step()

        if i % log_every == 0:
            with torch.no_grad():
                z_eval = z.detach()
                # Use the actual batch size so each grid row stays
                # input / prediction / target, even for small batches.
                k = min(n_show, a.size(0))
                grid = torch.cat([a[:k], z_eval[:k], b[:k]], dim=0)
                show_images(grid, nrow=k, title=f'Epoch {epoch} iter {i}')
                # NOTE(review): z and b are in [-1, 1] here; confirm that
                # ssim/psnr in src.utils.metrics assume that data range.
                ssim_val = float(ssim(z_eval, b))
                psnr_val = float(psnr(z_eval, b))
            print('D', loss_d.item(), 'G', loss_g.item(), 'SSIM', ssim_val, 'PSNR', psnr_val)
