# 05_diffusion_xray.ipynb — Diffusion model training

Entrenamiento mínimo de DDPM para imágenes de rayos X (64x64, 1 canal).


In [None]:
import sys
if '.' not in sys.path: sys.path.append('.')
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from src.models.diffusion import SimpleUNet, DDPM
from src.utils.visualization import show_images


In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
# Square image resolution and mini-batch size, reused by the cells below.
img_size, batch_size = 64, 32
# Use local PNGs if present (e.g. data/sample_unpaired_ct); otherwise fall back to FakeData.
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class PNGFolderDataset(Dataset):
    """Unlabelled dataset over every ``*.png`` file directly inside ``root``.

    Each item is ``(image, 0)``. The image is converted to single-channel
    grayscale, resized to ``img_size x img_size`` and turned into a tensor.
    It is then normalized with mean 0.5 / std 0.5, which maps the [0, 1]
    output of ``ToTensor`` to [-1, 1]. The constant ``0`` is a dummy label so
    batches unpack as ``(x, _)`` like torchvision datasets.
    """

    def __init__(self, root, img_size):
        # Sorted so the index -> file mapping is deterministic across runs.
        self.paths = sorted(Path(root).glob('*.png'))
        self.tf = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, i):
        # Open via a context manager so the file handle is released as soon as
        # the pixels are decoded; ``convert`` returns a new, fully loaded image.
        with Image.open(self.paths[i]) as im:
            img = im.convert('L')
        return self.tf(img), 0
root_ct = Path('data/sample_unpaired_ct')
if not any(root_ct.glob('*.png')):
    # No local PNGs yet: try to generate sample data with the helper script.
    import subprocess
    import sys
    try:
        result = subprocess.run([sys.executable, 'scripts/make_sample_data.py'], check=False)
    except Exception as e:
        print('Warning: sample gen failed:', e)
    else:
        # check=False never raises on a non-zero exit status, so report it
        # explicitly instead of silently falling through to FakeData.
        if result.returncode != 0:
            print('Warning: sample gen failed: exit code', result.returncode)
if any(root_ct.glob('*.png')):
    dataset = PNGFolderDataset(root_ct, img_size)
else:
    # Fallback: random 1-channel images, normalized like the PNG dataset
    # (ToTensor -> [0, 1], Normalize(0.5, 0.5) -> [-1, 1]).
    dataset = datasets.FakeData(size=512, image_size=(1,img_size,img_size), transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5],[0.5])]))
# NOTE(review): on spawn-based platforms (Windows, macOS) DataLoader workers may be
# unable to unpickle a Dataset class defined inside a notebook; use num_workers=0 if so.
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)
model = SimpleUNet(channels=1, base=32).to(device)
ddpm = DDPM(model, img_size=img_size, channels=1, timesteps=200).to(device)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)


In [None]:
for epoch in range(1):
    model.train()
    total_loss, n_batches = 0.0, 0
    for x, _ in loader:
        # Batches are already in [-1, 1]: both dataset branches end with
        # Normalize([0.5], [0.5]), so no further rescaling is applied here
        # (rescaling again would push the inputs to [-3, 1]).
        x = x.to(device)
        # One uniformly random diffusion timestep in [0, ddpm.T) per image.
        t = torch.randint(0, ddpm.T, (x.size(0),), device=device)
        loss, _ = ddpm.p_losses(x, t)
        opt.zero_grad(); loss.backward(); opt.step()
        total_loss += loss.item()
        n_batches += 1
    # Report the mean over the epoch rather than only the last batch; nan if the
    # loader yielded nothing (previously `loss` would have been undefined).
    mean_loss = total_loss / n_batches if n_batches else float('nan')
    print('epoch', epoch, 'loss', mean_loss)
    # Inference mode for preview samples (matters if the UNet has dropout/norm layers).
    model.eval()
    with torch.no_grad():
        samples = ddpm.sample(16, device=device)
        show_images(samples.cpu(), nrow=4, title=f'Epoch {epoch}')


In [None]:
from torchmetrics.image.fid import FrechetInceptionDistance
from src.utils.visualization import plot_curves, show_images
from torch.utils.data import DataLoader
import os, torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from torch.utils.data import Subset


def _to_fid_input(x):
    """Map a batch in [-1, 1] to the [0, 1], 3-channel layout FID expects.

    ``FrechetInceptionDistance(normalize=True)`` takes float images in [0, 1],
    and its Inception backbone only accepts 3-channel input. Grayscale batches
    are therefore repeated across channels. Clamping keeps diffusion samples
    that overshoot [-1, 1] inside the range the metric expects.
    """
    x01 = ((x + 1) / 2).clamp(0, 1)
    if x01.size(1) == 1:
        x01 = x01.repeat(1, 3, 1, 1)
    return x01


# Reference ("real") set: up to 256 images from the same dataset the model was
# trained on. This way FID compares samples to the target distribution, not to
# unrelated random FakeData images.
full_ds = Subset(dataset, range(min(256, len(dataset))))
val_dl = DataLoader(full_ds, batch_size=batch_size, shuffle=False, num_workers=2)
fid = FrechetInceptionDistance(normalize=True).to(device)
# real updates
for xr, _ in val_dl:
    fid.update(_to_fid_input(xr.to(device)), real=True)
# fake updates: sample in batch_size chunks to bound peak memory, keep a CPU copy
model.eval()
n_fake = len(full_ds)
chunks = []
with torch.no_grad():
    for start in range(0, n_fake, batch_size):
        chunk = ddpm.sample(min(batch_size, n_fake - start), device=device)
        fid.update(_to_fid_input(chunk.to(device)), real=False)
        chunks.append(chunk.cpu())
fake = torch.cat(chunks)
fid_val = float(fid.compute())
print({'FID': fid_val})
os.makedirs('outputs/diffusion', exist_ok=True)
plot_curves({'FID':[fid_val]}, title='Diffusion validation (TorchMetrics FID)', save_path='outputs/diffusion/metrics.png')
show_images(fake[:16].cpu(), nrow=4, title='Diffusion samples (val)', save_path='outputs/diffusion/samples.png')
