# 05_diffusion_xray.ipynb — Diffusion model training

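Entrenamiento mínimo de un DDPM para imágenes de rayos X (64×64 píxeles, 1 canal). Por ahora se usa `FakeData` de torchvision como conjunto de datos provisional; al final se calcula un FID aproximado (proxy) con características de ResNet-18.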


In [None]:
import sys
if '.' not in sys.path: sys.path.append('.')
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from src.models.diffusion import SimpleUNet, DDPM
from src.utils.visualization import show_images


In [None]:
# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
img_size = 64
batch_size = 32

# PIL image -> tensor in [0, 1] -> shift/scale to [-1, 1] via (x - 0.5) / 0.5.
preprocess = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5], [0.5]),
])

# TODO: replace FakeData with a real X-ray dataset (e.g., RSNA, NIH).
dataset = datasets.FakeData(
    size=512,
    image_size=(1, img_size, img_size),
    transform=preprocess,
)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=2)

# Single-channel network wrapped by the project's DDPM (200 timesteps);
# only the inner network's parameters are handed to the optimizer.
model = SimpleUNet(channels=1, base=32).to(device)
ddpm = DDPM(model, img_size=img_size, channels=1, timesteps=200).to(device)
opt = torch.optim.Adam(model.parameters(), lr=2e-4)


In [None]:
# Train the network: for each batch draw one random timestep per image and
# minimise the loss returned by ddpm.p_losses; show 16 samples after each epoch.
num_epochs = 1
for epoch in range(num_epochs):
    model.train()
    epoch_loss, n_batches = 0.0, 0
    for x, _ in loader:
        # No extra normalization here: the dataset transform already applies
        # ToTensor + Normalize([0.5], [0.5]), so x is in [-1, 1]. Normalizing
        # again would shift the data to [-3, 1].
        x = x.to(device)
        t = torch.randint(0, ddpm.T, (x.size(0),), device=device)
        loss, _ = ddpm.p_losses(x, t)
        opt.zero_grad()
        loss.backward()
        opt.step()
        epoch_loss += loss.item()
        n_batches += 1
    # Report the epoch mean rather than only the last batch; max() guards an
    # empty loader.
    print('epoch', epoch, 'mean loss', epoch_loss / max(n_batches, 1))
    # Switch to eval mode for sampling so layers such as BatchNorm/Dropout (if
    # SimpleUNet has any) behave deterministically; train() is restored at the
    # start of the next epoch.
    model.eval()
    with torch.no_grad():
        samples = ddpm.sample(16, device=device)
    show_images(samples.cpu(), nrow=4, title=f'Epoch {epoch}')


In [None]:
# Proxy FID: compare ResNet-18 features of held-out (fake) real images with
# features of DDPM samples.
from src.utils.metrics import fid
from src.utils.visualization import plot_curves
from torchvision import models, transforms as T
from torch.utils.data import DataLoader
import os, torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the validation set (proxy). FakeData seeds each item from
# `index + random_offset`. With the default offset of 0, these 256 images would
# be exactly the first 256 training images. Offsetting by len(dataset) keeps
# them disjoint from the training set.
full_ds = datasets.FakeData(
    size=256,
    image_size=(1, img_size, img_size),
    random_offset=len(dataset),
    transform=transforms.Compose([transforms.ToTensor(), transforms.Normalize([0.5], [0.5])]),
)
val_dl = DataLoader(full_ds, batch_size=batch_size, shuffle=False, num_workers=2)

# Feature extractor: ResNet-18 with its final (fc) child removed. If the
# pretrained weights cannot be loaded, fall back to random weights. The FID is
# then only a rough sanity signal.
try:
    resnet = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
except Exception as e:
    print('Warning: could not load pretrained weights:', e)
    resnet = models.resnet18(weights=None)
resnet = resnet.eval().to(device)
feat = torch.nn.Sequential(*(list(resnet.children())[:-1]))

# Resize to 224x224, replicate a single channel to RGB, then apply the ImageNet
# mean/std that torchvision's pretrained ResNet weights expect.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]
prep = T.Compose([
    T.Resize((224, 224)),
    T.Lambda(lambda t: t.repeat(1, 3, 1, 1) if t.size(1) == 1 else t),
    T.Normalize(_IMAGENET_MEAN, _IMAGENET_STD),
])


def get_feats(x, *, in_range=(-1.0, 1.0)):
    """Return flattened ResNet features for a batch of images.

    Args:
        x: image batch, assumed (B, C, H, W) with C == 1 or 3, already on
            ``device``, with values nominally in ``in_range``.
        in_range: (low, high) value range of ``x``. Defaults to [-1, 1], the
            range produced by ``Normalize([0.5], [0.5])``.

    Returns:
        2-D tensor (B, D) of pooled features, flattened after the batch dim.
    """
    lo, hi = in_range
    # Fixed rescale to [0, 1]. The old check (`x.min() < 0`) depended on the
    # data and could scale batches inconsistently. Clamp in case generated
    # samples fall outside the nominal range.
    x = ((x - lo) / (hi - lo)).clamp(0, 1)
    x = prep(x)
    with torch.no_grad():
        return feat(x).flatten(1)


# Collect real features.
real_feats = []
for xr, _ in val_dl:
    xr = xr.to(device)
    real_feats.append(get_feats(xr))
real_feats = torch.cat(real_feats, dim=0)

# Generate as many samples as real images, in batch_size chunks to bound
# memory, and collect their features. NOTE(review): assumes ddpm.sample returns
# images in [-1, 1], matching the training data range; confirm.
model.eval()
n_fake = real_feats.size(0)
fake_chunks, fake_feats = [], []
with torch.no_grad():
    for start in range(0, n_fake, batch_size):
        chunk = ddpm.sample(min(batch_size, n_fake - start), device=device)
        fake_chunks.append(chunk)
        fake_feats.append(get_feats(chunk))
fake = torch.cat(fake_chunks, dim=0)
fake_feats = torch.cat(fake_feats, dim=0)

fid_val = fid(real_feats, fake_feats)
print({'FID_proxy': fid_val})
plot_curves({'FID_proxy': [fid_val]}, title='Diffusion validation (FID proxy)')
