# 04_cyclegan_ct_mri.ipynb — MRI↔CT translation

Traducción no pareada entre dominios con CycleGAN (MRI ↔ CT).


In [None]:
import sys
if '.' not in sys.path: sys.path.append('.')
import torch
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from pathlib import Path
from src.models.cyclegan import ResnetGenerator, NLayerDiscriminator
from src.utils.visualization import show_images


In [None]:
class UnpairedDataset(Dataset):
    """Unpaired two-domain image dataset backed by two folders of PNG files.

    Item ``idx`` combines the ``idx``-th image of domain A with the ``idx``-th
    image of domain B (both in sorted filename order). The shorter domain
    wraps around via modulo, so the dataset length is that of the larger
    folder.

    Args:
        rootA: Directory holding the domain-A ``*.png`` images.
        rootB: Directory holding the domain-B ``*.png`` images.
        img_size: Side length in pixels; every image is resized to
            ``img_size x img_size``.

    Raises:
        FileNotFoundError: If either directory contains no ``*.png`` file.
    """

    def __init__(self, rootA, rootB, img_size=256):
        self.A = sorted(Path(rootA).glob('*.png'))
        self.B = sorted(Path(rootB).glob('*.png'))
        # Fail fast with a readable message: an empty domain would otherwise
        # surface later as a ZeroDivisionError from the modulo in __getitem__.
        for root, files in ((rootA, self.A), (rootB, self.B)):
            if not files:
                raise FileNotFoundError(f"No '*.png' images found in {str(root)!r}")
        # Single-channel tensor, resized, then (x - 0.5) / 0.5 normalisation.
        self.tf = transforms.Compose([
            transforms.Grayscale(1),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        """Return the size of the larger domain."""
        return max(len(self.A), len(self.B))

    def _load(self, path):
        """Open ``path``, convert it to 8-bit grayscale and apply the transform."""
        # The context manager releases the file handle deterministically.
        with Image.open(path) as img:
            return self.tf(img.convert('L'))

    def __getitem__(self, idx):
        """Return the ``(a, b)`` tensor pair for ``idx``; each domain wraps around."""
        a = self._load(self.A[idx % len(self.A)])
        b = self._load(self.B[idx % len(self.B)])
        return a, b

# TODO: point these paths at your MRI (domain A) and CT (domain B) folders
rootA, rootB = 'data/sample_unpaired_mri', 'data/sample_unpaired_ct'
# Auto-generate sample data when either folder has no PNG images.
from pathlib import Path
import subprocess
import sys
if not any(Path(rootA).glob('*.png')) or not any(Path(rootB).glob('*.png')):
    # Best effort: a failure here only prints a warning and never aborts the cell.
    try:
        result = subprocess.run([sys.executable, 'scripts/make_sample_data.py'], check=False)
    except (OSError, subprocess.SubprocessError) as e:
        print('Warning: could not generate sample data:', e)
    else:
        # check=False means a failing script does not raise, so report its
        # exit status explicitly instead of ignoring it.
        if result.returncode != 0:
            print('Warning: could not generate sample data:',
                  f'scripts/make_sample_data.py exited with code {result.returncode}')
img_size = 256
batch_size = 2
train_ds = UnpairedDataset(rootA, rootB, img_size)
# NOTE(review): with num_workers > 0 the dataset is handed to worker processes;
# on platforms that use the "spawn" start method a class defined inside a
# notebook may fail to pickle. Set num_workers=0 if that happens.
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=2)

# One generator per translation direction and one discriminator per domain.
# NOTE(review): the literal 1s are presumably channel counts (grayscale) —
# confirm against src.models.cyclegan.
G_AB = ResnetGenerator(1, 1)
G_BA = ResnetGenerator(1, 1)
D_A = NLayerDiscriminator(1)
D_B = NLayerDiscriminator(1)

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for _net in (G_AB, G_BA, D_A, D_B):
    _net.to(device)

# Both generators share one Adam optimizer and both discriminators share
# another, with identical hyper-parameters.
_adam_cfg = dict(lr=2e-4, betas=(0.5, 0.999))
opt_g = optim.Adam([*G_AB.parameters(), *G_BA.parameters()], **_adam_cfg)
opt_d = optim.Adam([*D_A.parameters(), *D_B.parameters()], **_adam_cfg)

# L1 for the cycle-consistency and identity terms; BCE-with-logits for the
# adversarial term.
cycle = nn.L1Loss()
ident = nn.L1Loss()
bce = nn.BCEWithLogitsLoss()


In [None]:
# The evaluation cell switches the generators to eval(); notebook cells are
# frequently re-run out of order, so put every network back in training mode.
for _net in (G_AB, G_BA, D_A, D_B):
    _net.train()


def _gan_loss(logits, is_real):
    """BCE-with-logits between ``logits`` and an all-ones / all-zeros target.

    Takes the discriminator output as an argument so each discriminator is run
    once per term, instead of a second time just to build the target tensor.
    """
    target = torch.ones_like(logits) if is_real else torch.zeros_like(logits)
    return bce(logits, target)


for epoch in range(1):
    for i, (a, b) in enumerate(train_dl):
        a, b = a.to(device), b.to(device)
        # Translate once per step. The generators are not updated until the
        # G step below, so the same outputs serve both steps.
        fake_b = G_AB(a)
        fake_a = G_BA(b)
        # ---- Train D ----
        opt_d.zero_grad()
        # detach(): the discriminator update must not backpropagate into G.
        loss_d = (_gan_loss(D_A(a), True) + _gan_loss(D_A(fake_a.detach()), False)
                  + _gan_loss(D_B(b), True) + _gan_loss(D_B(fake_b.detach()), False))
        loss_d.backward()
        opt_d.step()
        # ---- Train G (adversarial + cycle + identity) ----
        opt_g.zero_grad()
        rec_a = G_BA(fake_b)
        rec_b = G_AB(fake_a)
        # Generators are rewarded when the discriminators score fakes as real.
        adv = _gan_loss(D_B(fake_b), True) + _gan_loss(D_A(fake_a), True)
        cyc = cycle(rec_a, a) + cycle(rec_b, b)
        # Identity term: feeding a target-domain image should leave it unchanged.
        idt = ident(G_AB(b), b) + ident(G_BA(a), a)
        loss_g = adv + 10*cyc + 0.5*idt
        loss_g.backward()
        opt_g.step()
        if i % 50 == 0:
            print(f'E{epoch} I{i} | D {loss_d.item():.3f} G {loss_g.item():.3f}')
            with torch.no_grad():
                show_images(torch.cat([a[:4], fake_b[:4], rec_a[:4]], dim=0), nrow=4, title='A->B->A')
                show_images(torch.cat([b[:4], fake_a[:4], rec_b[:4]], dim=0), nrow=4, title='B->A->B')


In [None]:
from src.utils.visualization import plot_curves, show_images
from torch.utils.data import random_split
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
import os, torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G_AB.eval(); G_BA.eval()

# NOTE(review): the validation subset is drawn from the same folders the
# training cell iterates over in full, so these are not held-out metrics.
full_ds = UnpairedDataset(rootA, rootB, img_size)
n_val = max(1, int(0.1*len(full_ds)))
n_train = len(full_ds)-n_val
# Seed the split so the validation subset (and the metrics) are reproducible.
_split_rng = torch.Generator().manual_seed(0)
_, val_ds = random_split(full_ds, [n_train, n_val], generator=_split_rng)
val_dl = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=2)

ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)


def _mean(values):
    """Arithmetic mean of ``values`` as a float; 0.0 for an empty list."""
    return float(sum(values)/len(values)) if values else 0.0


ssim_A, psnr_A, ssim_B, psnr_B = [], [], [], []
_preview = None  # tensors of the first validation batch, reused for the figures
with torch.no_grad():
    for a, b in val_dl:
        a, b = a.to(device), b.to(device)
        fake_b = G_AB(a); rec_a = G_BA(fake_b)
        fake_a = G_BA(b); rec_b = G_AB(fake_a)
        # Shift by +1 and halve so values match the metrics' data_range=1.0
        # (inputs were normalised with mean 0.5 / std 0.5).
        ra01, a01 = (rec_a+1)/2, (a+1)/2
        rb01, b01 = (rec_b+1)/2, (b+1)/2
        ssim_A.append(ssim_metric(ra01, a01).item()); psnr_A.append(psnr_metric(ra01, a01).item())
        ssim_B.append(ssim_metric(rb01, b01).item()); psnr_B.append(psnr_metric(rb01, b01).item())
        if _preview is None:
            _preview = (a, fake_b, rec_a, b, fake_a, rec_b)

# Each entry is a one-element list holding the mean over validation batches.
h = {
 'cycle_ssim_A': [_mean(ssim_A)],
 'cycle_psnr_A': [_mean(psnr_A)],
 'cycle_ssim_B': [_mean(ssim_B)],
 'cycle_psnr_B': [_mean(psnr_B)],
}
print(h)
os.makedirs('outputs/cyclegan', exist_ok=True)
plot_curves(h, title='CycleGAN validation (TorchMetrics)', save_path='outputs/cyclegan/metrics.png')

# Save qualitative round-trip examples from the first validation batch. The
# tensors were produced under no_grad above, so no autograd graph is built and
# the loader is not iterated a second time.
if _preview is not None:
    a, fb, ra, b, fa, rb = _preview
    show_images(torch.cat([a[:4], fb[:4], ra[:4]], dim=0), nrow=4, title='A->B->A (val)', save_path='outputs/cyclegan/a_b_a.png')
    show_images(torch.cat([b[:4], fa[:4], rb[:4]], dim=0), nrow=4, title='B->A->B (val)', save_path='outputs/cyclegan/b_a_b.png')
