In [None]:
# Core utilities and system tools
from tqdm import tqdm
import numpy as np
import os
import random
import math
from typing import Tuple
import warnings
import time

# Suppress specific TorchIO loader warnings (non-critical)
warnings.filterwarnings("ignore", message=".*torchio.*SubjectsLoader.*")

# Memory management
import gc

# PyTorch core
import torch
import torch.nn as nn
from torch.nn import functional as F
import torch.utils.checkpoint as cp  # activation checkpointing to save memory

# Training acceleration (multi-GPU, mixed precision, etc.)
from accelerate import Accelerator

# Medical imaging augmentations
import torchio as tio

# Optimization and scheduling
from torch.optim import Adam
from transformers import get_linear_schedule_with_warmup
from diffusers import DDIMScheduler  # diffusion scheduler

# Data loading
from torch.utils.data import Dataset, DataLoader

# Logging
from torch.utils.tensorboard import SummaryWriter

# Visualization
import matplotlib.pyplot as plt


In [None]:
class VolumeDataset(Dataset):
    """Dataset of pre-saved ``.pt`` volumes with optional TorchIO augmentation.

    Each ``.pt`` file in ``pt_dir`` is expected to hold a dict with the keys
    ``"volume"`` (a tensor) and ``"ds_name"`` (the source dataset name).

    Args:
        pt_dir: directory scanned (non-recursively) for ``*.pt`` files.
        train: when True, random spatial and gamma augmentations are applied
            and the sampled gamma is appended to the returned label.

    ``__getitem__`` returns ``(volume_z_score, label)`` where the volume has a
    leading channel axis and the label is ``f"{ds_name}_gamma_{gamma}"``.
    """

    def __init__(self, pt_dir, train=True):
        self.pt_dir = pt_dir
        # Sorted so that index -> file mapping is stable across runs.
        self.pt_files = sorted([os.path.join(pt_dir, f) for f in os.listdir(pt_dir) if f.endswith(".pt")])
        self.train = train

        # Spatial augmentation pool (heavy 3D transforms); 1 or 2 of them are
        # sampled per item in __getitem__.
        self.transform_list = [
            tio.RandomAffine(scales=(0.8, 1.1), degrees=8, isotropic=True, center='image', p=1),
            tio.RandomElasticDeformation(num_control_points=7, max_displacement=10, p=1),
            tio.RandomFlip(axes=(0,), p=1),
        ]

        # Discrete log-gamma shifts for intensity augmentation. The chosen value
        # becomes part of the class label, so it is drawn from a fixed list.
        self.gamma_values = [-0.4, -0.3, -0.2, -0.1,
                             0, 0.1, 0.2, 0.3, 0.4]

    def __len__(self):
        """Number of ``.pt`` files found in ``pt_dir``."""
        return len(self.pt_files)

    def __getitem__(self, idx):
        """Load one volume, optionally augment it, and z-score normalize it."""
        data = torch.load(self.pt_files[idx])
        volume, ds_name = data["volume"], data["ds_name"]

        # Add the channel axis TorchIO expects in front of the spatial axes.
        volume = volume.unsqueeze(0)

        if self.train:
            # Random selection of 1 or 2 distinct spatial transforms.
            n_transforms = random.choice([1, 2])
            chosen = random.sample(self.transform_list, n_transforms)
            transform = tio.Compose(chosen)

            # TorchIO applies volumetric transforms on Subject wrappers.
            subject = tio.Subject(vol=tio.ScalarImage(tensor=volume))
            subject = transform(subject)
            volume = subject['vol'].data.clone()

            # A degenerate (lo == hi) range makes RandomGamma use exactly the
            # sampled log-gamma, so the label below matches the transform.
            gamma_val = random.choice(self.gamma_values)
            gamma_transform = tio.RandomGamma(log_gamma=(gamma_val, gamma_val), p=1.0)
            subject = tio.Subject(vol=tio.ScalarImage(tensor=volume))
            subject = gamma_transform(subject)
            volume = subject['vol'].data.clone()

            # Label records which intensity transform was applied.
            ds_name = f"{ds_name}_gamma_{gamma_val}"
        else:
            # No augmentation: labelled as the identity gamma.
            ds_name = f"{ds_name}_gamma_0"

        # Z-score normalization (global over the full volume). The std is
        # clamped so that a constant volume yields zeros instead of NaN/Inf.
        v_mean = volume.mean()
        v_std = volume.std(unbiased=False).clamp_min(1e-8)
        volume_z_score = (volume - v_mean) / v_std

        return volume_z_score, ds_name


In [None]:
# Training split: augmentation enabled, order reshuffled every epoch.
train_dataset = VolumeDataset(
    "/NAS/coolio/Barnabe/CODES/diffusion_classifier_guidance/iguane_pt_train_dataset",
    train=True,
)
train_loader = DataLoader(
    train_dataset,
    batch_size=5,
    shuffle=True,
    num_workers=5,
    pin_memory=True,
    persistent_workers=False,
    prefetch_factor=3,
)

# Test split: no augmentation, fixed order.
test_dataset = VolumeDataset(
    "/NAS/coolio/Barnabe/CODES/diffusion_classifier_guidance/iguane_pt_test_dataset",
    train=False,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=5,
    shuffle=False,
    num_workers=5,
    pin_memory=True,
    persistent_workers=False,
    prefetch_factor=3,
)

## Labels

In [None]:
# Same discrete log-gamma values as VolumeDataset.gamma_values: the label
# strings built here must match the ones the dataset emits.
gamma_values = [-0.4, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.4]

# Original dataset names (before the "_gamma_*" suffix is appended).
# NOTE(review): names are collected from the test split only; a name that
# appears only in the training split would be missing from ds2id -- confirm.
original_ds_names = {
    torch.load(pt_file)["ds_name"] for pt_file in tqdm(test_dataset.pt_files)
}

# Every (dataset name, gamma) combination is one class label.
all_labels = {
    f"{name}_gamma_{gamma}"
    for name in tqdm(original_ds_names)
    for gamma in gamma_values
}

# String label <-> integer id. Ids start at 1; id 0 is left unused here.
ds2id = {name: idx for idx, name in enumerate(sorted(all_labels), start=1)}
id2ds = {idx: name for name, idx in ds2id.items()}

n_classes = len(all_labels)
print(f"Number of classes: {n_classes}")


## Anatomical representation

In [None]:
def _gaussian_kernel_3d(sigma: float, device: torch.device, dtype=torch.float32):
    # Build separable 3D Gaussian kernel (or identity if sigma=0)
    if sigma <= 0:
        k = torch.zeros((1,1,1,1,1), device=device, dtype=dtype)
        k[0,0,0,0,0] = 1.0
        return k
    radius = int(max(1, torch.ceil(3 * torch.tensor(sigma)).item()))
    coords = torch.arange(-radius, radius+1, device=device, dtype=dtype)
    g1d = torch.exp(-(coords**2) / (2 * sigma**2))
    g1d = g1d / g1d.sum()
    g3d = g1d[:,None,None] * g1d[None,:,None] * g1d[None,None,:]
    return g3d.unsqueeze(0).unsqueeze(0)  # (1,1,D,H,W)

def _pad_for_conv(kernel):
    # Compute symmetric padding for "same" 3D convolution
    kD, kH, kW = kernel.shape[-3:]
    return (kW//2, kW//2, kH//2, kH//2, kD//2, kD//2)

def _conv3d_same(x, kernel):
    """Convolve ``x`` with ``kernel`` after replicate-padding the borders.

    The padding amounts come from ``_pad_for_conv(kernel)``, so the spatial
    size of the result matches the input for the kernels built in this file.
    """
    padded = F.pad(x, _pad_for_conv(kernel), mode='replicate')
    return F.conv3d(padded, kernel)

def gaussian_blur3d(x, sigma: float):
    """Blur ``x`` with a 3D Gaussian of standard deviation ``sigma``.

    A non-positive ``sigma`` returns ``x`` itself (no copy). Otherwise the
    kernel is built on the device and dtype of ``x``.
    """
    if sigma <= 0:
        return x
    return _conv3d_same(
        x,
        _gaussian_kernel_3d(sigma, device=x.device, dtype=x.dtype),
    )

def _sobel_kernels_3d(device, dtype=torch.float32):
    # 3D central-difference Sobel-like kernels
    kx = torch.zeros((1,1,3,3,3), device=device, dtype=dtype)
    ky = torch.zeros_like(kx)
    kz = torch.zeros_like(kx)
    kx[0,0,1,1,0] = -1.0; kx[0,0,1,1,2] = 1.0
    ky[0,0,1,0,1] = -1.0; ky[0,0,1,2,1] = 1.0
    kz[0,0,0,1,1] = -1.0; kz[0,0,2,1,1] = 1.0
    return kx, ky, kz

def gradient_magnitude_torch(x, sigma=0.0):
    """Gradient magnitude of ``x`` from central differences along three axes.

    When ``sigma > 0`` the input is Gaussian-blurred first. Borders are
    replicate-padded so the output has the same spatial size as the input.
    A small constant (1e-12) is added under the square root.
    """
    smoothed = gaussian_blur3d(x, sigma) if sigma > 0 else x

    kx, ky, kz = _sobel_kernels_3d(smoothed.device, dtype=smoothed.dtype)
    padded = F.pad(smoothed, _pad_for_conv(kx), mode='replicate')

    dx = F.conv3d(padded, kx)
    dy = F.conv3d(padded, ky)
    dz = F.conv3d(padded, kz)

    squared_norm = dx * dx + dy * dy + dz * dz
    return torch.sqrt(squared_norm + 1e-12)

def _global_quantile(tensor, q, max_samples=2_000_000):
    """
    Compute approximate global quantile using a deterministic stride-subsample
    when full flattening would exceed memory limits.
    """
    flat = tensor.view(-1)
    n = flat.numel()
    if n <= max_samples:
        return torch.quantile(flat, q)
    step = int(n // max_samples) + 1
    sample = flat[::step]
    return torch.quantile(sample, q)

def make_structural_anatomy_map(batch_imgs: torch.Tensor,
                                grad_sigmas=(0.5, 2.0),
                                hf_sigma=1.0,
                                smooth_sigma=1.0,
                                normalize_percentiles=(1.0, 99.0)):
    """Build a single-channel 3D structural map with the same size as the input.

    Args:
        batch_imgs: tensor of shape (B, 1, D, H, W) (asserted).
        grad_sigmas: the two Gaussian pre-blur sigmas used for the two
            gradient-magnitude channels.
        hf_sigma: sigma of the blur used for the high-frequency residual
            ``|I - G(I)|``.
        smooth_sigma: sigma of the final smoothing; skipped when <= 0.
        normalize_percentiles: lower/upper percentiles (in percent) used for
            the robust rescaling.

    Returns:
        Tensor with the same shape as ``batch_imgs``. Values are clamped to
        [-1, 1] before the optional final smoothing.

    Note:
        The percentiles are computed over the whole batch at once, not per
        sample, so the result for one volume depends on the others in the batch.
    """
    assert batch_imgs.ndim == 5 and batch_imgs.shape[1] == 1

    # Gradient magnitude at two scales.
    grad_fine = gradient_magnitude_torch(batch_imgs, sigma=grad_sigmas[0])
    grad_coarse = gradient_magnitude_torch(batch_imgs, sigma=grad_sigmas[1])

    # High-frequency residual |I - G_sigma(I)|.
    residual = torch.abs(batch_imgs - gaussian_blur3d(batch_imgs, sigma=hf_sigma))

    # Fixed-weight fusion of the three cues.
    combined = 0.5 * grad_fine + 0.3 * grad_coarse + 0.2 * residual

    # Robust rescaling: [lo, hi] percentile window -> [0, 1] -> [-1, 1].
    pct_lo, pct_hi = normalize_percentiles
    lo = _global_quantile(combined, pct_lo / 100.0)
    hi = _global_quantile(combined, pct_hi / 100.0)
    normed = ((combined - lo) / (hi - lo + 1e-6)).clamp(0, 1) * 2 - 1

    # Optional final smoothing.
    if smooth_sigma > 0:
        normed = gaussian_blur3d(normed, sigma=smooth_sigma)

    return normed


## Utils

In [None]:
def save_checkpoint(epoch, step, model, embedder, optimizer, checkpoint_dir, accelerator):
    """Write a training checkpoint from the main process.

    Args:
        epoch: epoch number stored in the checkpoint and used in the file name.
        step: global step stored in the checkpoint.
        model: (possibly Accelerate-wrapped) model; it is unwrapped before its
            state dict is taken.
        embedder: module whose ``state_dict()`` is stored as-is.
            NOTE(review): unlike ``model`` it is not unwrapped; if it is
            wrapped for distributed training the keys may carry a wrapper
            prefix -- confirm this matches how the checkpoint is loaded.
        optimizer: optimizer whose state dict is stored.
        checkpoint_dir: output directory; created if it does not exist yet.
        accelerator: object providing ``wait_for_everyone``, ``unwrap_model``
            and ``is_main_process``.

    The file is named ``ckpt_ep{epoch:04d}_crop_coarse.pt``.
    """
    # All processes synchronize before the main one serializes the weights.
    accelerator.wait_for_everyone()
    unwrapped = accelerator.unwrap_model(model)
    if accelerator.is_main_process:
        # torch.save does not create missing parent directories.
        os.makedirs(checkpoint_dir, exist_ok=True)
        ckpt = {
            "model_state_dict": unwrapped.state_dict(),
            "embedder_state_dict": embedder.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "epoch": epoch,
            "step": step,
        }
        fname = os.path.join(checkpoint_dir, f"ckpt_ep{epoch:04d}_crop_coarse.pt")
        torch.save(ckpt, fname)
        print(f"[checkpoint] saved -> {fname}")

In [None]:
def load_checkpoint_if_exists(resume_from, model_diffusion, embedder, optimizer, accelerator):
    """Restore model, embedder and optimizer state from a checkpoint file.

    Must be called AFTER ``accelerator.prepare(...)``: the model is unwrapped
    here before its state dict is applied, while the embedder and optimizer
    state dicts are loaded into the objects as passed in.

    Args:
        resume_from: path to a checkpoint file, or None.
        model_diffusion: (possibly wrapped) diffusion model.
        embedder: label embedding module.
        optimizer: optimizer to restore.
        accelerator: object providing ``device`` and ``unwrap_model``.

    Returns:
        ``(start_epoch, global_step)``. ``(0, 0)`` when ``resume_from`` is
        None or does not point to an existing file (a notice is printed in the
        latter case); otherwise the ``"epoch"`` / ``"step"`` values stored in
        the checkpoint (0 if a key is absent).

    NOTE(review): the stored ``"epoch"`` is returned unchanged as the start
    epoch. Whether that epoch should be repeated or skipped on resume depends
    on when ``save_checkpoint`` is called -- confirm against the training loop.
    """
    if resume_from is None:
        return 0, 0  # start at epoch 0, global_step 0

    # Honour the "if exists" contract instead of failing inside torch.load.
    if not os.path.isfile(resume_from):
        print(f"Checkpoint {resume_from} not found -> start_epoch=0, global_step=0")
        return 0, 0

    ckpt = torch.load(resume_from, map_location=accelerator.device)

    # Must unwrap Accelerator-managed model before applying state_dict
    unwrapped_model = accelerator.unwrap_model(model_diffusion)
    unwrapped_model.load_state_dict(ckpt["model_state_dict"])

    embedder.load_state_dict(ckpt["embedder_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])

    start_epoch = ckpt.get("epoch", 0)
    global_step = ckpt.get("step", 0)

    print(f"Resumed from checkpoint {resume_from} -> start_epoch={start_epoch}, global_step={global_step}")

    return start_epoch, global_step


In [None]:
def get_3_slices(volume_np):
    """Return the three central orthogonal slices of a 3D array.

    Args:
        volume_np: array with exactly three dimensions (D1, D2, D3).

    Returns:
        Tuple of the middle slice along axis 0, axis 1 and axis 2, in that
        order. The original comments call them sagittal / coronal / axial --
        NOTE(review): that naming assumes a specific volume orientation.
    """
    n0, n1, n2 = volume_np.shape
    return (
        volume_np[n0 // 2],
        volume_np[:, n1 // 2],
        volume_np[:, :, n2 // 2],
    )

In [None]:
def _fit_bounds_to_patch(lo, hi, patch_len, axis_len):
    """Widen the inclusive interval [lo, hi] to at least ``patch_len`` voxels.

    The interval is grown around its current position and shifted where needed
    so that it stays inside ``[0, axis_len - 1]``. Requires
    ``patch_len <= axis_len``. Intervals that are already long enough are
    returned unchanged.
    """
    extent = hi - lo + 1
    if extent >= patch_len:
        return lo, hi
    deficit = patch_len - extent
    lo = max(0, lo - deficit // 2)
    hi = lo + patch_len - 1
    if hi > axis_len - 1:
        hi = axis_len - 1
        lo = hi - patch_len + 1
    return lo, hi


def random_crop_3d_pair(batch_imgs, anatomic_cond, patch_size=(64, 64, 64)):
    """
    Randomly extracts the same 3D patch from `batch_imgs` and `anatomic_cond`.

    For each sample, the patch origin is drawn uniformly inside the bounding
    box of the voxels that differ from the sample's minimum value (used as a
    foreground mask). If that box is smaller than the patch along an axis
    (or the mask is empty), the box is widened to the patch size within the
    volume instead of failing.

    Args:
        batch_imgs: Tensor of shape (B, 1, D, H, W)
        anatomic_cond: Tensor of shape (B, 1, D, H, W)
        patch_size: tuple (pd, ph, pw) defining patch depth, height, width

    Returns:
        patch_imgs: Tensor of shape (B, 1, pd, ph, pw)
        patch_cond: Tensor of shape (B, 1, pd, ph, pw)

    Raises:
        ValueError: if the patch is larger than the volume along any axis.
    """
    assert batch_imgs.shape == anatomic_cond.shape
    B, C, D, H, W = batch_imgs.shape
    pd, ph, pw = patch_size

    if pd > D or ph > H or pw > W:
        raise ValueError(
            f"patch_size {tuple(patch_size)} exceeds volume size {(D, H, W)}"
        )

    patch_imgs_list = []
    patch_cond_list = []

    for b in range(B):
        # Foreground mask: every voxel that is not at the sample minimum.
        vol = batch_imgs[b, 0]
        mask3 = vol != vol.min()

        if mask3.any():
            # Inclusive bounding box of the foreground.
            zs, ys, xs = torch.where(mask3)
            zmin, zmax = int(zs.min().item()), int(zs.max().item())
            ymin, ymax = int(ys.min().item()), int(ys.max().item())
            xmin, xmax = int(xs.min().item()), int(xs.max().item())
        else:
            # Constant volume: no foreground, sample from the whole volume.
            zmin, zmax = 0, D - 1
            ymin, ymax = 0, H - 1
            xmin, xmax = 0, W - 1

        # Guarantee the sampling window can hold a full patch.
        zmin, zmax = _fit_bounds_to_patch(zmin, zmax, pd, D)
        ymin, ymax = _fit_bounds_to_patch(ymin, ymax, ph, H)
        xmin, xmax = _fit_bounds_to_patch(xmin, xmax, pw, W)

        Dc, Hc, Wc = zmax - zmin + 1, ymax - ymin + 1, xmax - xmin + 1

        # Random patch origin inside the window, then shifted to volume coords.
        z0 = zmin + torch.randint(0, Dc - pd + 1, (1,)).item()
        y0 = ymin + torch.randint(0, Hc - ph + 1, (1,)).item()
        x0 = xmin + torch.randint(0, Wc - pw + 1, (1,)).item()

        # Same window for image and conditioning so the pair stays aligned.
        patch_imgs_list.append(batch_imgs[b:b+1, :, z0:z0+pd, y0:y0+ph, x0:x0+pw])
        patch_cond_list.append(anatomic_cond[b:b+1, :, z0:z0+pd, y0:y0+ph, x0:x0+pw])

    # Stack patches back into batch
    patch_imgs = torch.cat(patch_imgs_list, dim=0)
    patch_cond = torch.cat(patch_cond_list, dim=0)

    return patch_imgs, patch_cond


## Diffusion model

In [None]:
import math
from typing import Tuple, Sequence, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F


# ---------------------------
# Helpers: timestep embedding
# ---------------------------
def timestep_embedding(timesteps: torch.Tensor, dim: int, max_period: int = 10000):
    """
    Sinusoidal embedding of diffusion timesteps.

    Args:
        timesteps: (B,) tensor of timesteps (cast to float internally).
        dim: embedding size.
        max_period: controls the lowest frequency of the sinusoids.

    Returns:
        (B, dim) tensor laid out as [cos | sin]; when ``dim`` is odd a single
        zero column is appended at the end.
    """
    n_freqs = dim // 2
    exponent = (
        -math.log(max_period)
        * torch.arange(0, n_freqs, dtype=torch.float32, device=timesteps.device)
        / n_freqs
    )
    freqs = torch.exp(exponent)
    angles = timesteps.float()[:, None] * freqs[None, :]
    embedding = torch.cat((angles.cos(), angles.sin()), dim=-1)
    if dim % 2:
        embedding = F.pad(embedding, (0, 1))
    return embedding  # (B, dim)


# ---------------------------
# Basic blocks in 3D
# ---------------------------
class Conv3dZeroInit(nn.Conv3d):
    """Placeholder subclass of ``nn.Conv3d`` intended for residual projections.

    NOTE: despite the name, no zero-initialization is implemented; the class
    currently behaves exactly like ``nn.Conv3d``.
    """


class ResidualBlock3D(nn.Module):
    """Pre-activation 3D residual block with optional timestep conditioning.

    Layout: GroupNorm -> SiLU -> Conv3d -> (+ time projection) -> GroupNorm ->
    SiLU -> Conv3d, added to the (optionally 1x1-projected) input.
    GroupNorm uses 8 groups, so ``in_ch`` and ``out_ch`` must be multiples of 8.

    Args:
        in_ch: input channels.
        out_ch: output channels.
        time_emb_dim: size of the timestep embedding, or None to build the
            block without a time projection.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim=None):
        super().__init__()
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.conv1 = nn.Conv3d(in_ch, out_ch, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.conv2 = nn.Conv3d(out_ch, out_ch, kernel_size=3, padding=1)
        # 1x1 projection on the skip path only when the channel count changes.
        self.nin_shortcut = (
            nn.Conv3d(in_ch, out_ch, kernel_size=1) if in_ch != out_ch else None
        )
        # Maps the timestep embedding to one bias per output channel.
        self.time_mlp = (
            nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, out_ch))
            if time_emb_dim is not None
            else None
        )

    def forward(self, x, t_emb=None):
        """Apply the block to ``x`` (B, in_ch, D, H, W); ``t_emb`` is (B, time_emb_dim) or None."""
        h = self.conv1(F.silu(self.norm1(x)))

        if self.time_mlp is not None and t_emb is not None:
            # Broadcast the per-channel bias over the three spatial axes.
            h = h + self.time_mlp(t_emb)[..., None, None, None]

        h = self.conv2(F.silu(self.norm2(h)))

        shortcut = x if self.nin_shortcut is None else self.nin_shortcut(x)
        return shortcut + h


# ---------------------------
# Multi-head attention over 3D tokens (self-attention + cross-attention)
# ---------------------------
class MultiHeadAttention3D(nn.Module):
    """Multi-head scaled dot-product attention over a token sequence.

    Works as self-attention when no context is given and as cross-attention
    when keys/values come from a separate ``context`` tensor.
    """

    def __init__(self, dim, num_heads, head_dim, cross_dim=None):
        """
        Args:
            dim: feature size of the query tokens (and of the output).
            num_heads: number of attention heads.
            head_dim: feature size per head.
            cross_dim: feature size of the context tokens; None means keys and
                values are projected from ``dim``-sized inputs.
        """
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inner_dim = num_heads * head_dim
        self.scale = head_dim ** -0.5

        kv_dim = dim if cross_dim is None else cross_dim
        self.to_q = nn.Linear(dim, self.inner_dim, bias=False)
        self.to_k = nn.Linear(kv_dim, self.inner_dim, bias=False)
        self.to_v = nn.Linear(kv_dim, self.inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(self.inner_dim, dim),
        )

    def forward(self, x, context: Optional[torch.Tensor] = None):
        """
        Args:
            x: (B, N, dim) query tokens.
            context: (B, M, cross_dim) key/value tokens, or None to attend
                over ``x`` itself.

        Returns:
            (B, N, dim) tensor.
        """
        b, n, _ = x.shape
        source = x if context is None else context

        def split_heads(tokens, length):
            # (B, L, heads*head_dim) -> (B, heads, L, head_dim)
            return tokens.view(b, length, self.num_heads, self.head_dim).transpose(1, 2)

        q = split_heads(self.to_q(x), n)
        k = split_heads(self.to_k(source), -1)
        v = split_heads(self.to_v(source), -1)

        # Attention weights over the key axis: (B, heads, N, M).
        weights = torch.softmax((q @ k.transpose(-2, -1)) * self.scale, dim=-1)

        # Weighted values, heads merged back: (B, N, heads*head_dim).
        merged = (weights @ v).transpose(1, 2).contiguous().view(b, n, self.inner_dim)
        return self.to_out(merged)


class AttentionBlock3D(nn.Module):
    """Residual attention block operating on 3D feature maps.

    The feature map is normalized, projected with a 1x1x1 conv, flattened to a
    token sequence, passed through ``MultiHeadAttention3D`` and projected back;
    the result is added to the input.
    """

    def __init__(self, channels, num_heads, head_dim, cross_attention_dim: Optional[int] = None):
        super().__init__()
        self.norm = nn.GroupNorm(8, channels)
        self.proj_in = nn.Conv3d(channels, channels, kernel_size=1)
        self.proj_out = nn.Conv3d(channels, channels, kernel_size=1)
        self.mha = MultiHeadAttention3D(
            dim=channels,
            num_heads=num_heads,
            head_dim=head_dim,
            cross_dim=cross_attention_dim,
        )

        self.cross_attention_dim = cross_attention_dim

    def forward(self, x: torch.Tensor, context: Optional[torch.Tensor] = None):
        """
        Args:
            x: (B, C, D, H, W) feature map.
            context: (B, M, cross_attention_dim) tokens, or None for
                self-attention over the spatial tokens.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, d, h, w = x.shape
        feats = self.proj_in(F.silu(self.norm(x)))  # (B, C, D, H, W)

        # One token per voxel: (B, N, C) with N = D*H*W.
        tokens = feats.view(b, c, d * h * w).permute(0, 2, 1)

        # ``context=None`` makes the attention module fall back to self-attention.
        attended = self.mha(tokens, context)  # (B, N, C)

        attended = attended.permute(0, 2, 1).view(b, c, d, h, w)
        return x + self.proj_out(attended)


# ---------------------------
# Down / Up blocks
# ---------------------------
class Downsample3D(nn.Module):
    """Learned 2x spatial downsampling: 3x3x3 conv with stride 2, same channel count."""

    def __init__(self, channels):
        super().__init__()
        self.op = nn.Conv3d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=3,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Return the strided-convolution output for ``x`` (B, C, D, H, W)."""
        downsampled = self.op(x)
        return downsampled


class Upsample3D(nn.Module):
    """Learned 2x spatial upsampling: 4x4x4 transposed conv with stride 2, same channel count."""

    def __init__(self, channels):
        super().__init__()
        self.op = nn.ConvTranspose3d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=4,
            stride=2,
            padding=1,
        )

    def forward(self, x):
        """Return the transposed-convolution output for ``x`` (B, C, D, H, W)."""
        upsampled = self.op(x)
        return upsampled


class DownBlock3D(nn.Module):
    """Encoder stage: residual block, optional attention, residual block.

    Args:
        in_ch / out_ch: channels before and after the stage.
        time_emb_dim: timestep-embedding size passed to both residual blocks.
        use_attn: insert an ``AttentionBlock3D`` between the residual blocks.
        num_heads, head_dim, cross_attention_dim: attention settings; only
            used when ``use_attn`` is True.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim=None, use_attn=False, num_heads=1, head_dim=32, cross_attention_dim=None):
        super().__init__()
        self.res1 = ResidualBlock3D(in_ch, out_ch, time_emb_dim=time_emb_dim)
        if use_attn:
            self.attn = AttentionBlock3D(out_ch, num_heads, head_dim, cross_attention_dim)
        else:
            self.attn = None
        self.res2 = ResidualBlock3D(out_ch, out_ch, time_emb_dim=time_emb_dim)

    def forward(self, x, t_emb=None, context=None):
        """Run the stage; ``context`` is only consumed by the attention block."""
        h = self.res1(x, t_emb)
        if self.attn is not None:
            h = self.attn(h, context)
        return self.res2(h, t_emb)


class UpBlock3D(nn.Module):
    """Decoder stage: concat skip, residual block, optional attention, residual block.

    Args:
        in_ch: channel count AFTER concatenating the decoder features with the
            skip connection.
        out_ch: channels produced by the stage.
        time_emb_dim: timestep-embedding size passed to both residual blocks.
        use_attn: insert an ``AttentionBlock3D`` between the residual blocks.
        num_heads, head_dim, cross_attention_dim: attention settings; only
            used when ``use_attn`` is True.
    """

    def __init__(self, in_ch, out_ch, time_emb_dim=None, use_attn=False, num_heads=1, head_dim=32, cross_attention_dim=None):
        super().__init__()
        self.res1 = ResidualBlock3D(in_ch, out_ch, time_emb_dim=time_emb_dim)
        if use_attn:
            self.attn = AttentionBlock3D(out_ch, num_heads, head_dim, cross_attention_dim)
        else:
            self.attn = None
        self.res2 = ResidualBlock3D(out_ch, out_ch, time_emb_dim=time_emb_dim)

    def forward(self, x, skip, t_emb=None, context=None):
        """Fuse ``x`` with ``skip`` along the channel axis and run the stage."""
        h = self.res1(torch.cat([x, skip], dim=1), t_emb)
        if self.attn is not None:
            h = self.attn(h, context)
        return self.res2(h, t_emb)


# ---------------------------
# The main UNet3DConditionModel
# ---------------------------
class UNet3DConditionModel_maison(nn.Module):
    """Home-made 3D conditional U-Net for diffusion models.

    Encoder stages (``DownBlock3D`` + strided-conv downsampling), a bottleneck
    with attention, and decoder stages (``UpBlock3D`` + transposed-conv
    upsampling) linked by skip connections. Every residual block receives the
    timestep embedding; stages whose type name contains ``"Attn"`` also get an
    attention block fed with ``encoder_hidden_states``.

    Args:
        sample_size: spatial size of the input; stored but not used by the
            layers themselves.
        in_channels / out_channels: channels of the input / output volumes.
        down_block_types / up_block_types: one type name per stage; only the
            presence of ``"Attn"`` in the name matters. ``up_block_types`` is
            indexed in decoder order (deepest stage first).
        block_out_channels: channels of each encoder stage (multiples of 8,
            because of the 8-group GroupNorm layers).
        cross_attention_dim: feature size of ``encoder_hidden_states``.
        attention_head_dim: per-head size; the number of heads of a stage is
            ``max(1, channels // attention_head_dim)``.
        time_embedding_dim: size of the sinusoidal timestep embedding.
    """

    def __init__(
        self,
        sample_size: Tuple[int, int, int],
        in_channels: int = 1,
        out_channels: int = 1,
        down_block_types: Sequence[str] = ("DownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D"),
        up_block_types: Sequence[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "UpBlock3D"),
        block_out_channels: Sequence[int] = (32, 64, 128, 256),
        cross_attention_dim: int = 512,
        attention_head_dim: int = 64,
        time_embedding_dim: int = 512,
    ):
        super().__init__()

        assert len(down_block_types) == len(up_block_types) == len(block_out_channels), \
            "down_block_types, up_block_types and block_out_channels must have same length"

        self.sample_size = sample_size
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.block_out_channels = block_out_channels
        self.cross_attention_dim = cross_attention_dim
        self.attention_head_dim = attention_head_dim
        self.time_embedding_dim = time_embedding_dim

        # Stem: lift the input to the first stage width.
        self.conv_in = nn.Conv3d(in_channels, block_out_channels[0], kernel_size=3, padding=1)

        # MLP applied on top of the sinusoidal timestep embedding.
        self.time_mlp = nn.Sequential(
            nn.Linear(time_embedding_dim, time_embedding_dim * 4),
            nn.SiLU(),
            nn.Linear(time_embedding_dim * 4, time_embedding_dim)
        )

        # ----- encoder -----
        self.down_blocks = nn.ModuleList()
        self.downsamplers = nn.ModuleList()
        last_level = len(block_out_channels) - 1
        ch_in = block_out_channels[0]
        for level, ch_out in enumerate(block_out_channels):
            # "CrossAttn..." names contain "Attn", so one substring test suffices.
            with_attn = "Attn" in down_block_types[level]
            self.down_blocks.append(
                DownBlock3D(
                    ch_in, ch_out,
                    time_emb_dim=time_embedding_dim,
                    use_attn=with_attn,
                    num_heads=max(1, ch_out // attention_head_dim),
                    head_dim=attention_head_dim,
                    cross_attention_dim=cross_attention_dim if with_attn else None,
                )
            )
            # Every stage but the deepest is followed by a 2x downsampling.
            if level != last_level:
                self.downsamplers.append(Downsample3D(ch_out))
            ch_in = ch_out

        # ----- bottleneck -----
        mid_ch = block_out_channels[-1]
        self.mid_block1 = ResidualBlock3D(mid_ch, mid_ch, time_emb_dim=time_embedding_dim)
        self.mid_attn = AttentionBlock3D(
            mid_ch,
            num_heads=max(1, mid_ch // attention_head_dim),
            head_dim=attention_head_dim,
            cross_attention_dim=cross_attention_dim,
        )
        self.mid_block2 = ResidualBlock3D(mid_ch, mid_ch, time_emb_dim=time_embedding_dim)

        # ----- decoder -----
        self.upsamplers = nn.ModuleList()
        self.up_blocks = nn.ModuleList()
        decoder_channels = tuple(reversed(block_out_channels))
        last_stage = len(decoder_channels) - 1
        ch_dec = decoder_channels[0]
        for stage, ch_out in enumerate(decoder_channels):
            with_attn = "Attn" in up_block_types[stage]
            self.up_blocks.append(
                UpBlock3D(
                    # Decoder features concatenated with the skip (ch_out channels).
                    ch_dec + ch_out, ch_out,
                    time_emb_dim=time_embedding_dim,
                    use_attn=with_attn,
                    num_heads=max(1, ch_out // attention_head_dim),
                    head_dim=attention_head_dim,
                    cross_attention_dim=cross_attention_dim if with_attn else None,
                )
            )
            # Every stage but the shallowest is followed by a 2x upsampling.
            if stage != last_stage:
                self.upsamplers.append(Upsample3D(ch_out))
            ch_dec = ch_out

        # Output head.
        self.norm_out = nn.GroupNorm(8, block_out_channels[0])
        self.conv_out = nn.Conv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def forward(self, sample: torch.Tensor, timestep: torch.Tensor, encoder_hidden_states: Optional[torch.Tensor] = None):
        """Run the U-Net.

        Args:
            sample: (B, in_channels, D, H, W) input.
            timestep: 0-dim tensor (broadcast to the batch) or (B,) tensor.
            encoder_hidden_states: (B, M, cross_attention_dim) conditioning
                tokens passed to every attention block.

        Returns:
            (B, out_channels, D, H, W) tensor.
        """
        if timestep.dim() == 0:
            timestep = timestep.view(1).expand(sample.shape[0])
        t_emb = self.time_mlp(
            timestep_embedding(timestep, self.time_embedding_dim, max_period=10000)
        )

        x = self.conv_in(sample)

        # Encoder: keep each stage output (before downsampling) as a skip.
        skips = []
        n_down = len(self.downsamplers)
        for level, down_block in enumerate(self.down_blocks):
            x = down_block(x, t_emb, encoder_hidden_states)
            skips.append(x)
            if level < n_down:
                x = self.downsamplers[level](x)

        # Bottleneck.
        x = self.mid_block1(x, t_emb)
        x = self.mid_attn(x, encoder_hidden_states)
        x = self.mid_block2(x, t_emb)

        # Decoder: consume the skips deepest-first.
        n_up = len(self.upsamplers)
        for stage, up_block in enumerate(self.up_blocks):
            x = up_block(x, skips.pop(), t_emb, encoder_hidden_states)
            if stage < n_up:
                x = self.upsamplers[stage](x)

        return self.conv_out(F.silu(self.norm_out(x)))



# Spatial size (D, H, W) of the patches fed to the network.
patch_size = (80, 96, 80)

# 3 input channels: noisy patch + fine anatomy patch + coarse anatomy map
# (see the channel concatenation in the training loop); 1 output channel.
model_diffusion = UNet3DConditionModel_maison(
    sample_size=patch_size,
    in_channels=3,
    out_channels=1,
    down_block_types=(
        "DownBlock3D",
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "DownBlock3D",
    ),
    up_block_types=(
        "UpBlock3D",
        "CrossAttnUpBlock3D",
        "CrossAttnUpBlock3D",
        "UpBlock3D",
    ),
    block_out_channels=(64, 128, 256, 256),
    cross_attention_dim=512,
    attention_head_dim=64,
    time_embedding_dim=512,
)

## Training

In [None]:
# ---------- HYPERPARAMS ----------
num_epochs = 3000                               # upper bound of the epoch loop
lr = 1e-4                                       # Adam learning rate
grad_accum_steps = 1                            # NOTE(review): not referenced in the visible cells
eval_every_epoch = save_every_epoch = 10        # evaluation / checkpoint period, in epochs
checkpoint_dir = "folder_to_save/checkpoint"    # where checkpoints are written
resume_from = None                              # checkpoint path to resume from, or None

In [None]:
# FP16 mixed precision for speed / memory.
accelerator = Accelerator(mixed_precision="fp16")

n_classes = len(all_labels)
noise_scheduler = DDIMScheduler(num_train_timesteps=1000)

# Class-conditioning embedding. ds2id assigns ids 1..n_classes, so the table
# has n_classes + 1 rows; row 0 is the "unconditional" entry used when labels
# are masked for classifier-free guidance.
embedder = nn.Embedding(n_classes + 1, 512)

# One optimizer over both the U-Net and the label embedding.
optimizer = Adam(
    list(model_diffusion.parameters()) + list(embedder.parameters()),
    lr=lr,
    eps=1e-8
)

# Wrap all components for distributed / mixed-precision execution
model_diffusion, embedder, optimizer, train_loader, test_loader = accelerator.prepare(
    model_diffusion, embedder, optimizer, train_loader, test_loader
)

# Initialize DDIM sampling steps. Scheduler settings are read through
# `.config`; direct attribute access on the scheduler object is deprecated in
# diffusers.
noise_scheduler.set_timesteps(noise_scheduler.config.num_train_timesteps)

mse_loss = nn.MSELoss()

# Load checkpoint after accelerator.prepare (important for wrapped models)
start_epoch, global_step = load_checkpoint_if_exists(
    resume_from, model_diffusion, embedder, optimizer, accelerator
)


In [None]:
# TensorBoard logger: created on the main process only, None everywhere else.
# -------------------------
if accelerator.is_main_process:
    tb_log_dir = "path_to/log_dir"
    os.makedirs(tb_log_dir, exist_ok=True)
    writer = SummaryWriter(tb_log_dir)
    print(f"[logger] TensorBoard writer created at {tb_log_dir}")
else:
    writer = None

In [None]:
# -------------------------
# TRAINING
# -------------------------

for epoch in range(start_epoch, num_epochs):
    model_diffusion.train()
    epoch_loss = 0.0

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader),
                        desc=f"Epoch {epoch}", ncols=120)

    for step_in_epoch, batch in progress_bar:

        global_step += 1
        volumes, labels = batch
        volumes = volumes.to(accelerator.device, non_blocking=True).float()

        # Gamma transformation to prevent anatomical map from encoding gamma information
        volumes_transform = volumes.clone()
        volumes_transform_gamma = torch.stack([
            tio.RandomGamma(log_gamma=(-0.5, 0.5), p=1.0)(
                tio.ScalarImage(tensor=img)
            ).data
            for img in volumes_transform
        ])

        batch_size = volumes.shape[0]

        # Random sigma jittering for anatomical map
        g1_sigma = random.uniform(0.1, 0.7)
        g2_sigma = random.uniform(0.7, 1.4)
        hf_sigma = random.uniform(0.1, 1.0)

        volume_anat_map = make_structural_anatomy_map(
            volumes_transform_gamma,
            grad_sigmas=(g1_sigma, g2_sigma),
            hf_sigma=hf_sigma,
            smooth_sigma=0.0,
            normalize_percentiles=(0.5, 99.5)
        )

        anat_coarse = F.interpolate(
            volume_anat_map.float(),
            size=patch_size,
            mode='trilinear',
            align_corners=False
        )

        # Paired anatomical + image patch extraction
        patch_volume, patch_anat_map = random_crop_3d_pair(
            volumes, volume_anat_map, patch_size=patch_size
        )

        timesteps = torch.randint(
            0,
            noise_scheduler.num_train_timesteps,
            (batch_size,),
            device=accelerator.device,
            dtype=torch.long,
        )

        noise = torch.randn_like(patch_volume)
        noisy_latents = noise_scheduler.add_noise(patch_volume, noise, timesteps)

        # Class-conditioning (with classifier-free masking)
        label_ids = torch.tensor(
            [ds2id[l] for l in labels],
            device=accelerator.device,
            dtype=torch.long
        )

        p_uncond = 0.15
        mask = torch.rand(label_ids.shape, device=accelerator.device) < p_uncond
        label_ids_masked = label_ids.clone()
        if mask.any():
            label_ids_masked[mask] = 0

        label_embedding = embedder(label_ids_masked)
        label_embedding = label_embedding.unsqueeze(1)  # (B, 1, embed_dim)

        # Model input: noisy patch + fine anatomy + coarse anatomy
        model_input = torch.cat([noisy_latents, patch_anat_map, anat_coarse], dim=1)
        model_output = model_diffusion(
            model_input, timesteps, encoder_hidden_states=label_embedding
        )

        noise_pred = model_output
        loss = mse_loss(noise_pred.float(), noise.float())
        epoch_loss += loss.item()

        accelerator.backward(loss)
        if accelerator.sync_gradients:
            optimizer.step()
            optimizer.zero_grad()

        avg_loss = epoch_loss / float(step_in_epoch + 1)
        progress_bar.set_postfix({"loss": f"{avg_loss:.6f}"})

    # TensorBoard logging (main process only)
    if accelerator.is_main_process and writer is not None:
        writer.add_scalar("train/loss_epoch", epoch_loss / len(train_loader), epoch)

    # -------------------------
    # EVALUATION
    # -------------------------

    if (epoch + 1) % eval_every_epoch == 0 or epoch == start_epoch:
        # Qualitative evaluation: take one test batch, crop a patch, generate a
        # sample for it with a deterministic DDIM loop plus classifier-free
        # guidance, then render a 3x3 figure (original / anatomy map / generated).
        was_training = model_diffusion.training  # restored at the end of the block
        model_diffusion.eval()
        with torch.no_grad():

            # Take a single batch for eval. Fail with an explicit message rather
            # than a NameError further down when the loader yields nothing.
            eval_batch = next(iter(test_loader), None)
            if eval_batch is None:
                raise RuntimeError("test_loader yielded no batches; cannot run evaluation")
            eval_volumes, eval_labels = eval_batch
            eval_volumes = eval_volumes.to(accelerator.device, non_blocking=True).float()

            # Build eval anatomical maps
            eval_volume_anat_map = make_structural_anatomy_map(
                eval_volumes,
                grad_sigmas=(0.3, 1.0),
                hf_sigma=1.0,
                smooth_sigma=0.0,
                normalize_percentiles=(0.5, 99.5)
            )

            # Coarse context: the whole-volume anatomy map resized to patch_size
            eval_anat_coarse = F.interpolate(
                eval_volume_anat_map.float(),
                size=patch_size,
                mode='nearest'
            )

            # Fine context: a patch of the volume and of its anatomy map
            patch_volume, patch_anat_map = random_crop_3d_pair(
                eval_volumes, eval_volume_anat_map, patch_size=patch_size
            )

            # Embeddings for conditional & unconditional sampling
            label_ids_eval = torch.tensor(
                [ds2id[name] for name in eval_labels],
                device=accelerator.device,
                dtype=torch.long
            )
            cond_label_embedding = embedder(label_ids_eval).unsqueeze(1)

            # Id 0 is the id substituted for dropped labels during training,
            # so it serves as the unconditional class here.
            uncond_ids = torch.zeros_like(label_ids_eval)
            uncond_label_embedding = embedder(uncond_ids).unsqueeze(1)

            batch_size_eval = eval_volumes.shape[0]
            num_inference_steps = 50
            noise_scheduler.set_timesteps(num_inference_steps)
            timesteps_iter = list(noise_scheduler.timesteps)

            # Running sample x_t; starts as pure Gaussian noise shaped like the patch
            sample = torch.randn_like(patch_volume).to(accelerator.device)

            num_train_timesteps = noise_scheduler.config.num_train_timesteps
            num_inference_steps = noise_scheduler.num_inference_steps
            step_offset = num_train_timesteps // num_inference_steps

            alphas_cumprod = noise_scheduler.alphas_cumprod.to(accelerator.device)
            final_alpha_cumprod = noise_scheduler.final_alpha_cumprod.to(accelerator.device)

            # guidance_scale == 1.0 makes the guided prediction equal to the
            # conditional one; raise it to strengthen the label conditioning.
            guidance_scale = 1.0

            # Loop-invariant inputs, duplicated along the batch axis so a single
            # forward pass yields the [unconditional, conditional] predictions.
            anatomy_in = torch.cat([patch_anat_map, patch_anat_map], dim=0)
            coarse_in = torch.cat([eval_anat_coarse, eval_anat_coarse], dim=0)
            conditioning_in = torch.cat([anatomy_in, coarse_in], dim=1)
            emb_in = torch.cat([uncond_label_embedding, cond_label_embedding], dim=0)

            # DDIM-like inference loop (deterministic: no extra noise is added)
            for t in timesteps_iter:
                t_int = int(t)
                t_in = torch.full(
                    (2 * batch_size_eval,),
                    t_int,
                    device=accelerator.device,
                    dtype=torch.long
                )

                # Same channel order as in training: sample, fine anatomy, coarse anatomy
                volume_in = torch.cat([sample, sample], dim=0)
                model_input_eval = torch.cat([volume_in, conditioning_in], dim=1)

                # Model prediction
                model_output_eval = model_diffusion(
                    model_input_eval, t_in, encoder_hidden_states=emb_in
                )
                uncond_noise_pred, cond_noise_pred = model_output_eval.chunk(2, dim=0)

                # Classifier-free guidance
                pred = uncond_noise_pred + guidance_scale * (cond_noise_pred - uncond_noise_pred)

                # Cumulative alphas at the current and the previous timestep
                prev_t = t_int - step_offset
                alpha_t = alphas_cumprod[t_int]
                alpha_prev = alphas_cumprod[prev_t] if prev_t >= 0 else final_alpha_cumprod
                beta_t = 1 - alpha_t

                # DDIM prediction of x0
                x0_pred = (sample - beta_t**0.5 * pred) / alpha_t**0.5

                # Direction pointing to x_t. It must use the same guided noise
                # estimate as x0_pred; using the unconditional prediction here
                # would mix two different epsilons in one update.
                coeff_dir = (1 - alpha_prev)**0.5
                pred_sample_direction = coeff_dir * pred

                # Sample at next timestep
                sample = alpha_prev**0.5 * x0_pred + pred_sample_direction

            diffused_latents = sample

            # -------------------------
            # VISUALIZATION
            # -------------------------
            # First batch element, first channel; cast to float32 before numpy
            patch_volume_np = patch_volume[0, 0].float().cpu().numpy()
            patch_anat_map_np = patch_anat_map[0, 0].float().cpu().numpy()
            diffused_latents_np = diffused_latents[0, 0].float().cpu().numpy()

            row_titles = [
                "Original (3 slices)",
                "Anat map (3 slices)",
                "Reconstruction (3 slices)",
            ]

            slice_sets = [
                get_3_slices(patch_volume_np),
                get_3_slices(patch_anat_map_np),
                get_3_slices(diffused_latents_np),
            ]

            fig, axes = plt.subplots(3, 3, figsize=(12, 12))
            fig.subplots_adjust(wspace=0.05, hspace=0.2)

            for r, (row_title, row_slices) in enumerate(zip(row_titles, slice_sets)):
                for c, im in enumerate(row_slices):
                    ax = axes[r, c]
                    im_show = np.rot90(im)

                    # Dynamic contrast normalization
                    p1, p99 = np.percentile(im_show, [1, 99])
                    if p1 == p99:
                        # Flat image: nothing to stretch, show as-is
                        ax.imshow(im_show, cmap="gray")
                    else:
                        im_clipped = np.clip(im_show, p1, p99)
                        im_norm = (im_clipped - p1) / (p99 - p1 + 1e-5)
                        ax.imshow(im_norm, cmap="gray", vmin=0, vmax=1)

                    ax.axis("off")
                    if c == 1:
                        ax.set_title(row_title, fontsize=10)

            vis_fname = os.path.join(checkpoint_dir, f"vis_epoch{epoch:04d}_full_inference.png")

            if accelerator.is_main_process:
                # Write the figure to disk before handing it to TensorBoard,
                # since add_figure closes the figure by default.
                os.makedirs(checkpoint_dir, exist_ok=True)
                fig.savefig(vis_fname, bbox_inches="tight")
                if writer is not None:
                    writer.add_figure("eval/full_inference_visualization", fig, epoch)
            plt.close(fig)

        # Put the model back into whatever mode it was in before evaluation
        model_diffusion.train(was_training)

    # -------------------------
    # CHECKPOINTING
    # -------------------------
    # Save on every `save_every_epoch`-th epoch and always on the last epoch.
    # The checkpoint is tagged with the 1-based epoch number (epoch + 1).
    hit_save_interval = (epoch + 1) % save_every_epoch == 0
    is_last_epoch = epoch == num_epochs - 1
    if hit_save_interval or is_last_epoch:
        save_checkpoint(
            epoch + 1,
            global_step,
            model_diffusion,
            embedder,
            optimizer,
            checkpoint_dir,
            accelerator,
        )

# Training is finished: close the TensorBoard writer, if this is the main
# process and a writer exists.
if writer is not None and accelerator.is_main_process:
    writer.close()
