# Medical Image Denoising with DDPM

## Device Details

In [None]:
import torch
import tensorflow as tf

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)
# TensorFlow's own view of the physical GPUs (reported independently of PyTorch).
print("Num GPUs Available:", len(tf.config.list_physical_devices("GPU")))

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from datasets import load_dataset, concatenate_datasets
from sklearn.model_selection import train_test_split
import torch
from torchvision.transforms import Compose, Resize, Lambda
from PIL import Image
from torch.utils.data import DataLoader,TensorDataset
import math
from diffusers import DDPMPipeline
import os
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from accelerate import Accelerator
from skimage.metrics import structural_similarity as ssim
from skimage.restoration import denoise_nl_means, estimate_sigma
from tabulate import tabulate

In [None]:
!pip install diffusers transformers accelerate datasets

## Dataset Loading and Visualization

In [None]:
def load_and_split_dataset(dataset_name="iamkzntsv/IXI2D", train_ratio=0.8, random_seed=42):
    """
    Merge every available split of a Hugging Face dataset, then re-split it by index.

    Whichever of the 'train', 'validation' and 'test' splits exist are concatenated
    in that order; the combined rows are then partitioned into shuffled train/test
    index lists with scikit-learn's ``train_test_split``.

    Args:
        dataset_name: Hugging Face dataset identifier.
        train_ratio: Fraction of rows assigned to training (0.8 = 80% train, 20% test).
        random_seed: Seed for the shuffle; the same seed reproduces the same split.

    Returns:
        full_dataset: The concatenated dataset.
        train_indices: Row indices for training.
        test_indices: Row indices for testing.

    Raises:
        ValueError: If none of the three splits is present.
    """
    print(f"Loading dataset: {dataset_name}...")
    dataset = load_dataset(dataset_name)

    # (split key in the DatasetDict, wording used in the printed count)
    split_specs = (("train", "training"), ("validation", "validation"), ("test", "test"))

    present = []
    for split_key, label in split_specs:
        part = dataset.get(split_key, None)
        size = len(part) if part is not None else 0
        print(f"Number of images in {label} set: {size}")
        if part is not None:
            present.append(part)

    # Fold the present splits together pairwise, keeping train -> validation -> test order.
    full_dataset = None
    for part in present:
        full_dataset = part if full_dataset is None else concatenate_datasets([full_dataset, part])

    if full_dataset is None:
        raise ValueError("No data found in any of the splits!")

    total_samples = len(full_dataset)
    print(f"Total number of images: {total_samples}")

    # shuffle=True with a fixed random_state gives the identical split on every run;
    # changing the seed changes which rows land on each side.
    train_indices, test_indices = train_test_split(
        list(range(total_samples)),
        train_size=train_ratio,
        random_state=random_seed,
        shuffle=True,
    )

    print(f"Final split - Train: {len(train_indices)}, Test: {len(test_indices)}")
    return full_dataset, train_indices, test_indices

In [None]:
def display_images(dataset, indices, num_images=5, cols=5):
    """
    Plot dataset images selected by ``indices`` on a grid of ``cols`` columns.

    Each image is read from ``dataset[indices[i]]['image']`` and converted with
    ``np.array``. 2-D arrays, and 3-D arrays whose first three channels are
    identical, are drawn with a gray colormap; any other 3-D array is drawn as-is.
    Slots past the end of ``indices`` are left blank.

    Args:
        dataset: Hugging Face dataset whose rows contain an 'image' field.
        indices: Dataset row indices to show, in order.
        num_images: Number of grid slots to fill.
        cols: Number of columns in the grid.
    """
    n_rows = -(-num_images // cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, cols, figsize=(15, 3 * n_rows))
    flat_axes = np.array(axes).reshape(-1)

    for slot in range(num_images):
        ax = flat_axes[slot]
        if slot >= len(indices):
            ax.axis('off')
            continue

        row_index = indices[slot]
        pixels = np.array(dataset[row_index]['image'])

        if pixels.ndim == 2:
            # Genuine single-channel image.
            ax.imshow(pixels, cmap='gray')
        elif (
            pixels.ndim == 3
            and np.all(pixels[..., 0] == pixels[..., 1])
            and np.all(pixels[..., 1] == pixels[..., 2])
        ):
            # Stored as RGB but every channel is the same: show one channel in gray.
            ax.imshow(pixels[..., 0], cmap='gray')
        else:
            ax.imshow(pixels)

        ax.set_title(f'Image {row_index}')
        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
def show_image_statistics(dataset, index):
    """
    Print raw pixel statistics for one dataset image and display it.

    The image object is taken from ``dataset[index]['image']`` (its ``.mode``
    attribute is read, as for a PIL image) and converted with ``np.array``. No
    resizing or normalization is applied, so the reported dtype and value range
    are those of the stored data.

    Args:
        dataset: Hugging Face dataset object whose rows contain an 'image' field.
        index: Row index of the image to inspect.
    """
    pil_img = dataset[index]['image']
    # np.array keeps the image's native dtype and range, so the statistics reflect the stored pixels.
    pixels = np.array(pil_img)

    print(f"--- Image Statistics (Index: {index}) ---")
    print(f"Mode          : {pil_img.mode}")  # e.g., 'RGB', 'L'
    print(f"Shape         : {pixels.shape}")
    print(f"Data type     : {pixels.dtype}")
    print(f"Min value     : {pixels.min()}")
    print(f"Max value     : {pixels.max()}")
    print(f"Mean value    : {pixels.mean():.4f}")
    print(f"Std deviation : {pixels.std():.4f}")
    print(f"Median value  : {np.median(pixels):.4f}")
    print(f"Unique values : {len(np.unique(pixels))}")

    # RGB images keep their colours; every other mode gets a gray colormap.
    plt.imshow(pixels, cmap=None if pil_img.mode == 'RGB' else 'gray')
    plt.title(f"Image {index}")
    plt.axis('off')
    plt.show()


## Pre-processing

In [None]:
def preprocess_and_save_dataset(dataset, indices, target_size=(128, 128), normalize_type="unsigned", save_path="/content/preprocessed_dataset.pt"):
    """
    Preprocess the selected dataset images into one tensor and save it to disk.

    Each image is resized, converted to single-channel grayscale (PIL mode "L")
    and scaled to floats, producing a float32 tensor of shape (N, 1, H, W).

    Args:
        dataset: HuggingFace dataset object whose rows contain an 'image' (PIL image).
        indices: Non-empty list of row indices to preprocess.
        target_size: Output size passed to torchvision ``Resize``. A pair is
            interpreted as (height, width), not (width, height).
        normalize_type:
            "unsigned" -> normalize to [0, 1] range
            "signed"   -> normalize to [-1, 1] range
        save_path: Path of the ``.pt`` file to write. Missing parent
            directories are created.

    Returns:
        The preprocessed tensor of shape (N, 1, H, W), dtype float32.

    Raises:
        ValueError: If ``normalize_type`` is unknown or ``indices`` is empty.
    """
    if normalize_type not in ["unsigned", "signed"]:
        raise ValueError("normalize_type must be 'unsigned' (0 –> 1 range) or 'signed' (-1 -> 1 range)")
    if len(indices) == 0:
        # torch.stack([]) would otherwise fail later with an unhelpful message.
        raise ValueError("indices must contain at least one index to preprocess")

    print(f"\n[INFO] Starting preprocessing of {len(indices)} images...")
    print(f"[INFO] Target size: {target_size}, Normalization: '{normalize_type}'")

    # Map 8-bit grayscale pixels to floats; unsqueeze(0) adds the channel axis -> (1, H, W).
    if normalize_type == "unsigned":
        norm_fn = Lambda(lambda img: torch.tensor(np.array(img)).unsqueeze(0).float() / 255.0)
    else:  # signed (-1, 1)
        norm_fn = Lambda(lambda img: torch.tensor(np.array(img)).unsqueeze(0).float() / 127.5 - 1.0)

    transforms = Compose([
        Resize(target_size, interpolation=Image.Resampling.BILINEAR),
        Lambda(lambda img: img.convert("L")),  # Convert to grayscale
        norm_fn
    ])

    preprocessed_images = [transforms(dataset[idx]["image"]) for idx in indices]
    preprocessed_tensor = torch.stack(preprocessed_images)

    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    torch.save(preprocessed_tensor, save_path)

    print(f"[INFO] Finished preprocessing. Dataset saved to: {save_path}")
    print(f"[INFO] Tensor shape: {preprocessed_tensor.shape}, dtype: {preprocessed_tensor.dtype}\n")
    return preprocessed_tensor



In [None]:
def display_preprocessed_images(tensor, num_images=5, cols=None):
    """
    Show the first images of a preprocessed (N, 1, H, W) tensor in a grid.

    Args:
        tensor: Preprocessed tensor of shape (N, 1, H, W); only channel 0 is drawn.
        num_images: Maximum number of images to display (capped at N).
        cols: Number of grid columns; when None, up to 5 are used.
    """
    # matplotlib needs NumPy arrays, and .numpy() only works on CPU tensors, so move first.
    images = tensor.cpu()
    shown = min(num_images, len(images))

    n_cols = cols if cols is not None else min(shown, 5)
    n_rows = -(-shown // n_cols)  # ceiling division

    fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows))

    for position, ax in enumerate(np.atleast_1d(axes).flatten()):
        if position < shown:
            ax.imshow(images[position, 0].numpy(), cmap='gray')
            ax.set_title(f"Image {position}")
        # Hide axes on every cell, including unused trailing cells.
        ax.axis('off')

    plt.tight_layout()
    plt.show()


def show_preprocessed_statistics(tensor, index):
    """
    Print statistics of one preprocessed image and display it in gray.

    Args:
        tensor: Preprocessed tensor indexed as ``tensor[index, 0]`` (i.e. (N, C, H, W)).
        index: Which image of the batch to inspect.
    """
    plane = tensor[index, 0].cpu().numpy()

    print(f"--- Preprocessed Image Statistics (Index: {index}) ---")
    print(f"Shape         : {plane.shape}")
    print(f"Data type     : {plane.dtype}")
    print(f"Min value     : {plane.min():.4f}")
    print(f"Max value     : {plane.max():.4f}")
    print(f"Mean value    : {plane.mean():.4f}")
    print(f"Std deviation : {plane.std():.4f}")
    print(f"Median value  : {np.median(plane):.4f}")
    print(f"Unique values : {len(np.unique(plane))}")

    plt.imshow(plane, cmap='gray')
    plt.title(f"Preprocessed Image {index}")
    plt.axis('off')
    plt.show()

## Data Loaders

In [None]:
def get_data_loaders(
    train_path="/content/train_preprocessed.pt",
    test_path="/content/test_preprocessed.pt",
    batch_size=16,
    shuffle_train=True,
    num_workers=2,
    pin_memory=True
):
    """
    Load preprocessed image tensors from disk and wrap them in train/test DataLoaders.

    Args:
        train_path: Path to the saved preprocessed training tensor.
        test_path: Path to the saved preprocessed testing tensor.
        batch_size: Batch size for both loaders.
        shuffle_train: Whether to reshuffle the training data each epoch. The
            test loader is never shuffled, so evaluation order is reproducible.
        num_workers: DataLoader num_workers.
        pin_memory: Whether to use pinned host memory (speeds up GPU transfer).
            Only honoured when CUDA is available; pinning without a GPU is useless.

    Returns:
        train_loader, test_loader. Each batch is a 1-element sequence ``[images]``
        because the tensors are wrapped in a label-free TensorDataset.
    """
    train_tensor = torch.load(train_path)
    test_tensor = torch.load(test_path)

    print(f"[INFO] Loaded train tensor: {train_tensor.shape}")
    print(f"[INFO] Loaded test tensor: {test_tensor.shape}")

    # Wrap in TensorDataset (no labels for diffusion models)
    train_dataset = TensorDataset(train_tensor)
    test_dataset = TensorDataset(test_tensor)

    use_pinned = pin_memory and torch.cuda.is_available()

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=shuffle_train,  # previously hard-coded to False, silently ignoring this argument
        num_workers=num_workers,
        pin_memory=use_pinned,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=use_pinned,
    )

    print(f"[INFO] Train batches: {len(train_loader)}, Test batches: {len(test_loader)}")
    return train_loader, test_loader


## Forward Diffusion

In [None]:
def make_beta_schedule(T=1000, beta_start=1e-4, beta_end=0.02, device=None):
    """
    Build a linear DDPM beta schedule together with its derived quantities.

    Args:
        T: Number of diffusion timesteps.
        beta_start: First beta value.
        beta_end: Last beta value.
        device: Device for all tensors; defaults to CUDA when available, else CPU.

    Returns:
        dict with keys "T", "betas", "alphas", "alphas_cumprod" (\\bar{\\alpha}_t),
        "sqrt_alphas_cumprod", "sqrt_one_minus_alphas_cumprod" (float32 tensors of
        length T) and "device".
    """
    target = device if device is not None else torch.device(
        "cuda" if torch.cuda.is_available() else "cpu"
    )

    # T linearly spaced betas from beta_start to beta_end (inclusive).
    betas = torch.linspace(beta_start, beta_end, T, device=target, dtype=torch.float32)
    alphas = 1.0 - betas
    alpha_bar = alphas.cumprod(dim=0)  # running product \bar{\alpha}_t

    schedule = {"T": T, "betas": betas, "alphas": alphas, "alphas_cumprod": alpha_bar}
    schedule["sqrt_alphas_cumprod"] = alpha_bar.sqrt()
    schedule["sqrt_one_minus_alphas_cumprod"] = (1.0 - alpha_bar).sqrt()
    schedule["device"] = target
    return schedule


In [None]:
def inspect_beta_schedule(schedule, num_values=10):
    """
    Print the leading values of each tensor in a beta schedule.

    Args:
        schedule: dict returned by make_beta_schedule.
        num_values: How many leading entries of each tensor to print.
    """
    header_rule = "------------------------------------------------------------------------------------------------"
    item_rule = "------------------------------------------------------------------------------------------------"

    print(f"Total timesteps (T): {schedule['T']}")
    print(f"Device: {schedule['device']}")
    print(header_rule)

    for key in ("betas", "alphas", "alphas_cumprod", "sqrt_alphas_cumprod", "sqrt_one_minus_alphas_cumprod"):
        # Detach and move to CPU so .numpy() works for GPU or grad-tracking tensors.
        leading = schedule[key].detach().cpu()[:num_values]
        print(f"{key}:")
        print("Values:", leading.numpy())
        print(item_rule)


In [None]:
def forward_diffusion_batch(x_start, schedule, timesteps=None):
    """
    Vectorized DDPM forward diffusion q(x_t | x_0) for a batch.

    x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,  eps ~ N(0, I)

    Args:
      - x_start: tensor [B, C, H, W] (float).
      - schedule: dict returned by make_beta_schedule.
      - timesteps: optional timesteps (0..T-1), one per batch element, as a
        LongTensor, list or array on any device. If None, sampled uniformly
        from the full range 0..T-1.
    Returns:
      - x_t: noisy images tensor [B, C, H, W]
      - noise: the Gaussian noise added [B, C, H, W]
      - timesteps: LongTensor [B] on the schedule's device
    Notes:
      - x_start can be normalized in form of [0,1] or [-1,1]; formula is invariant.
        Make sure model and loss use same range.
    """
    device = schedule["device"]
    T = schedule["T"]
    sqrt_alphas_cumprod = schedule["sqrt_alphas_cumprod"]  # shape [T]
    sqrt_one_minus_alphas_cumprod = schedule["sqrt_one_minus_alphas_cumprod"]  # shape [T]

    x_start = x_start.to(device)
    B = x_start.shape[0]

    if timesteps is None:
        # randint's upper bound is exclusive: this draws from 0..T-1 inclusive, so the
        # least-noisy step t=0 is trained too (the old lower bound of 1 skipped it).
        timesteps = torch.randint(0, T, (B,), device=device, dtype=torch.long)
    else:
        # Accept lists / CPU tensors and place them next to the schedule tensors for indexing.
        timesteps = torch.as_tensor(timesteps, dtype=torch.long, device=device)

    # Gather per-sample scalars and reshape to [B,1,1,1] for broadcasting over C, H, W.
    a_t = sqrt_alphas_cumprod[timesteps].view(B, 1, 1, 1)
    b_t = sqrt_one_minus_alphas_cumprod[timesteps].view(B, 1, 1, 1)

    noise = torch.randn_like(x_start)  # x_start is already on `device`
    x_t = a_t * x_start + b_t * noise

    return x_t, noise, timesteps

In [None]:
def show_sample_pair(x_clean, x_noisy, idx=0, title_clean="clean", title_noisy="noisy"):
    """
    Plot one clean / noisy sample side by side in grayscale.

    Args:
        x_clean: Batch of clean images; item ``idx`` should be (1, H, W) or (H, W).
        x_noisy: Batch of noisy images with the same layout.
        idx: Which item of each batch to show.
        title_clean: Title of the left panel.
        title_noisy: Title of the right panel.

    Tensors may be in [0, 1] or [-1, 1]. The range is decided from the clean
    image only (negative minimum -> treated as [-1, 1]). The same mapping is
    then applied to both images, and both are clipped to [0, 1]. Deciding
    per image would wrongly shift a [0, 1] noisy image whose added noise
    pushes some pixels below zero.
    """
    def to_hw(img_tensor):
        # convert to CPU numpy; channel-first (1,H,W) -> (H,W)
        img_np = img_tensor.detach().cpu().numpy()
        if img_np.ndim == 3 and img_np.shape[0] == 1:
            img_np = img_np[0]
        return img_np

    clean = to_hw(x_clean[idx])
    noisy = to_hw(x_noisy[idx])

    if clean.min() < 0:
        # Signed data: map [-1, 1] -> [0, 1] for both panels.
        clean = (clean + 1.0) / 2.0
        noisy = (noisy + 1.0) / 2.0
    clean = clean.clip(0.0, 1.0)
    noisy = noisy.clip(0.0, 1.0)

    plt.figure(figsize=(6, 3))
    plt.subplot(1, 2, 1)
    plt.imshow(clean, cmap="gray")
    plt.title(title_clean)
    plt.axis("off")
    plt.subplot(1, 2, 2)
    plt.imshow(noisy, cmap="gray")
    plt.title(title_noisy)
    plt.axis("off")
    plt.show()

In [None]:
def _extract_images_from_batch(batch):
    """Return the first tensor in `batch` that looks like images with shape [B, C, H, W]."""
    if isinstance(batch, torch.Tensor):
        return batch
    if isinstance(batch, (list, tuple)):
        for item in batch:
            if isinstance(item, torch.Tensor) and item.ndim == 4:
                return item
    if isinstance(batch, dict):
        # common HF style: batch['image']
        if "image" in batch and isinstance(batch["image"], torch.Tensor):
            return batch["image"]
        # otherwise return first tensor-like value with ndim==4
        for v in batch.values():
            if isinstance(v, torch.Tensor) and v.ndim == 4:
                return v
    raise ValueError("Could not find image tensor in batch. Expected shape [B,C,H,W].")

def _to_display_numpy(img_tensor):
    """
    Convert a single image tensor [C,H,W] -> numpy [H,W] or [H,W,3] in range [0,1] for imshow.
    Handles grayscale (C==1) and signed/unsigned ranges ([-1,1] or [0,1]).
    """
    img = img_tensor.detach().cpu().numpy()
    # if channel-first grayscale (1,H,W) -> (H,W)
    if img.ndim == 3 and img.shape[0] == 1:
        img = img[0]
    elif img.ndim == 3 and img.shape[0] == 3:
        img = np.transpose(img, (1,2,0))
    # detect range: if values < -0.1 assume [-1,1], else assume [0,1]
    if img.min() < -0.1:
        img = (img + 1.0) / 2.0
    # clip to [0,1]
    img = np.clip(img, 0.0, 1.0)
    return img

def show_original_vs_noisy(
    batch,
    schedule,
    forward_fn,
    device=None,
    num_images=4,
    timesteps=None,
    random_seed=None
):
    """
    Plot each of the first ``num_images`` batch images next to its noised version.

    Args:
      - batch: output from DataLoader (tensor, tuple/list, or dict)
      - schedule: dict from make_beta_schedule(...)
      - forward_fn: function(x_start, schedule, timesteps=None) -> (x_t, noise, timesteps)
      - device: torch.device or None (schedule['device'] is used when None)
      - num_images: how many items from the batch to show
      - timesteps: optional int, or list/1-D tensor of length B, forcing the timesteps
        (otherwise forward_fn samples them)
      - random_seed: optional seed passed to torch.manual_seed before noising
    Returns:
      - (x_clean, x_noisy, timesteps_used) as returned/used by forward_fn
    Raises:
      - ValueError: if an explicit timesteps sequence does not have length B
    """
    target_device = device
    if target_device is None:
        target_device = schedule.get("device", torch.device("cuda" if torch.cuda.is_available() else "cpu"))

    images = _extract_images_from_batch(batch).to(target_device)
    batch_size = images.shape[0]
    n_show = min(num_images, batch_size)

    # Normalise the optional timesteps argument to a LongTensor of length B.
    forced_t = None
    if isinstance(timesteps, int):
        forced_t = torch.full((batch_size,), timesteps, dtype=torch.long, device=target_device)
    elif timesteps is not None:
        forced_t = torch.as_tensor(timesteps, dtype=torch.long, device=target_device)
        if forced_t.ndim == 0:
            forced_t = forced_t.repeat(batch_size)
        if forced_t.numel() != batch_size:
            raise ValueError("timesteps length must equal batch size if provided.")

    if random_seed is not None:
        torch.manual_seed(random_seed)

    x_t, _, used_t = forward_fn(images, schedule, forced_t)

    clean_cpu = images[:n_show].detach().cpu()
    noisy_cpu = x_t[:n_show].detach().cpu()
    t_cpu = used_t[:n_show].detach().cpu()

    # squeeze=False keeps axes 2-D (n_show, 2) even when only one row is drawn.
    fig, axes = plt.subplots(n_show, 2, figsize=(6, 3 * n_show), squeeze=False)

    for row in range(n_show):
        clean_np = _to_display_numpy(clean_cpu[row])
        noisy_np = _to_display_numpy(noisy_cpu[row])
        left, right = axes[row]

        cmap = "gray" if clean_np.ndim == 2 else None
        left.imshow(clean_np, cmap=cmap)
        right.imshow(noisy_np, cmap=cmap)

        left.set_title(f"Original (idx {row})")
        right.set_title(f"Noisy (t={int(t_cpu[row].item())})")
        left.axis("off")
        right.axis("off")

    plt.tight_layout()
    plt.show()

    return images, x_t, used_t

## Training the Model: `benetraco/brain_ddpm_128` from Hugging Face


In [None]:
def show_training_step_visuals(x_start, x_t, pred_noise, t, idx=0, clip_range=(0, 1), schedule=None):
    """
    Visualize clean, noisy, and one-step denoised images for one sample of a batch.

    The "denoised" panel is the closed-form estimate
    x0_hat = (x_t - sqrt(1 - alpha_bar_t) * pred_noise) / sqrt(alpha_bar_t).

    Args:
        x_start: Clean images, indexed as ``x_start[idx]`` -> [C, H, W].
            Only C == 1 is squeezed to [H, W]; multi-channel images are passed
            to imshow channel-first (presumably unsupported — TODO confirm).
        x_t: Noisy images with the same layout.
        pred_noise: Model noise prediction with the same layout.
        t: Timesteps — a tensor (indexed with ``idx`` unless 0-dim) or an int.
        idx: Which batch element to plot.
        clip_range: Range the images are clipped to before being rescaled to [0, 1].
        schedule: Diffusion schedule dict providing "alphas_cumprod". If None,
            the module-level (notebook-global) ``schedule`` is used, which is what
            this function previously always did implicitly.

    Raises:
        ValueError: If no schedule is passed and no global ``schedule`` exists.
    """
    import numpy as np
    import matplotlib.pyplot as plt
    import torch

    if schedule is None:
        # Backward compatibility with callers that rely on the notebook-global schedule.
        schedule = globals().get("schedule")
    if schedule is None:
        raise ValueError("show_training_step_visuals needs a diffusion schedule; pass schedule=...")

    def to_numpy_img(tensor_img):
        img = tensor_img.detach().cpu().numpy()
        # [C,H,W] -> [H,W] for grayscale
        if img.shape[0] == 1:
            img = img[0]
        img = np.clip(img, clip_range[0], clip_range[1])
        # Rescale any non-[0, 1] display range to [0, 1] for imshow.
        if clip_range != (0, 1):
            img = (img - clip_range[0]) / (clip_range[1] - clip_range[0])
        return img

    clean_img = to_numpy_img(x_start[idx])
    noisy_img = to_numpy_img(x_t[idx])

    if isinstance(t, torch.Tensor):
        t_val = t[idx] if t.ndim > 0 else t
    else:
        t_val = torch.tensor(t)

    alphas_cumprod = schedule["alphas_cumprod"]
    # Index on the schedule's own device so CPU schedules work with GPU timesteps and vice versa.
    alpha_bar = alphas_cumprod[t_val.to(alphas_cumprod.device)]
    a_t = torch.sqrt(alpha_bar).to(x_t.device)
    b_t = torch.sqrt(1 - alpha_bar).to(x_t.device)

    # Correct denoising formula: x0 = (x_t - b_t * pred_noise) / a_t
    denoised_tensor = (x_t[idx] - b_t * pred_noise[idx]) / a_t
    denoised_img = to_numpy_img(denoised_tensor)

    timestep_value = t_val.item()

    titles = [
        "Original Clean Image",
        f"Noisy Image (t={timestep_value})",
        "Denoised Image"
    ]

    images = [clean_img, noisy_img, denoised_img]

    plt.figure(figsize=(12, 4))
    for i, (img, title) in enumerate(zip(images, titles)):
        plt.subplot(1, 3, i + 1)
        plt.imshow(img, cmap='gray')
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
def train_step(model, batch, schedule, optimizer, device, *, visualize=True):
    """
    Compute the DDPM noise-prediction loss for one batch (forward pass only).

    This does NOT call ``backward()`` or ``optimizer.step()``. It returns the loss
    tensor so the caller (e.g. an Accelerator-driven loop) can run
    ``accelerator.backward(loss)`` and step the optimizer itself.

    Args:
        model: Noise-prediction network called as ``model(x_t, timesteps)``; its
            output must expose a ``.sample`` attribute.
        batch: Clean images, tensor [B, C, H, W], normalized to [0, 1] or [-1, 1].
        schedule: Diffusion beta schedule dict.
        optimizer: Optimizer whose gradients are zeroed at the start of the step.
        device: Device the batch is moved to before noising.
        visualize: If True (the default, matching previous behaviour), plot the
            clean / noisy / denoised triplet of the first sample on every call.
            Pass False in real training loops: plotting every step is slow.

    Returns:
        Scalar MSE loss tensor between predicted and true noise, still attached
        to the autograd graph.
    """
    model.train()
    optimizer.zero_grad()

    batch = batch.to(device)

    # Add noise at random timesteps (forward diffusion)
    x_t, noise, timesteps = forward_diffusion_batch(batch, schedule)

    # Note: The model expects inputs normalized the same way as training
    pred_noise = model(x_t, timesteps).sample

    if visualize:
        # clip_range=(-1, 1) assumes [-1, 1]-normalized inputs — TODO confirm it matches preprocessing.
        show_training_step_visuals(batch, x_t, pred_noise, timesteps, idx=0, clip_range=(-1, 1))

    return F.mse_loss(pred_noise, noise)

In [None]:
def train_model(
    model,
    train_loader,
    schedule,
    device,
    epochs=10,
    lr=1e-4,
    checkpoint_dir="./checkpoints",
    save_every=2,
    log_every=10,
    resume_checkpoint=None
):
    """
    Train the diffusion model using HuggingFace Accelerator for mixed precision and device handling.

    Args:
        model: PyTorch model (e.g., brain_ddpm_128).
        train_loader: DataLoader whose batches' first element is the image tensor.
        schedule: diffusion beta schedule dictionary.
        device: kept for backward compatibility; batches are placed on
            ``accelerator.device``, which Accelerate chooses.
        epochs: total number of training epochs (when resuming, training runs up to this epoch).
        lr: peak learning rate; it is cosine-annealed over all optimizer steps.
        checkpoint_dir: directory to save ``model_epoch_<N>.pth`` checkpoints.
        save_every: save model every N epochs.
        log_every: log the running average loss every M batches.
        resume_checkpoint: path to a checkpoint file to resume from (ignored if it does not exist).
    """
    # Accelerate rejects fp16 mixed precision on a CPU-only machine, so fall back to full precision there.
    accelerator = Accelerator(
        mixed_precision="fp16" if torch.cuda.is_available() else "no",
        gradient_accumulation_steps=1  # Change if you want accumulation
    )

    os.makedirs(checkpoint_dir, exist_ok=True)

    optimizer = AdamW(model.parameters(), lr=lr, betas=(0.9, 0.999), weight_decay=1e-4)
    # scheduler.step() is called once per batch below, so T_max must count optimizer
    # steps rather than epochs. With T_max=epochs the cosine finished after `epochs`
    # batches and then oscillated for the rest of training instead of annealing once.
    # (Uses the un-sharded loader length; Accelerate's scheduler wrapper steps once per
    # process — TODO confirm the count for multi-GPU runs.)
    total_steps = max(1, epochs * len(train_loader))
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)

    # Resume from checkpoint if provided. Note: the loaded scheduler state carries the
    # T_max of the run that wrote it.
    start_epoch = 0
    if resume_checkpoint is not None and os.path.isfile(resume_checkpoint):
        accelerator.print(f"Resuming from checkpoint: {resume_checkpoint}")
        checkpoint = torch.load(resume_checkpoint, map_location="cpu")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
        scheduler.load_state_dict(checkpoint["scheduler_state_dict"])
        start_epoch = checkpoint["epoch"] + 1

    model, optimizer, train_loader, scheduler = accelerator.prepare(
        model, optimizer, train_loader, scheduler
    )

    for epoch in range(start_epoch, epochs):
        model.train()
        epoch_loss = 0.0
        num_batches = len(train_loader)

        for batch_idx, batch in enumerate(train_loader):
            images = batch[0]  # take only images from (images, labels)

            with accelerator.accumulate(model):
                # train_step returns the loss tensor; backward/step happen here.
                loss = train_step(model, images, schedule, optimizer, accelerator.device)

                accelerator.backward(loss)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            epoch_loss += loss.item()

            if (batch_idx + 1) % log_every == 0 and accelerator.is_main_process:
                accelerator.print(
                    f"[Epoch {epoch+1}/{epochs}] Step {batch_idx+1}/{len(train_loader)}, "
                    f"Avg Loss: {epoch_loss/(batch_idx+1):.6f}"
                )

        # "epoch" is stored 0-based; the file name uses the 1-based epoch number.
        if accelerator.is_main_process and ((epoch + 1) % save_every == 0):
            checkpoint_path = os.path.join(checkpoint_dir, f"model_epoch_{epoch+1}.pth")
            torch.save({
                "epoch": epoch,
                "model_state_dict": accelerator.unwrap_model(model).state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict()
            }, checkpoint_path)
            accelerator.print(f"Model saved to {checkpoint_path}")

        if accelerator.is_main_process:
            # max(1, ...) avoids ZeroDivisionError on an empty loader.
            accelerator.print(f"Epoch {epoch+1} finished. Average Loss: {epoch_loss/max(1, num_batches):.6f}")


## Evaluation

In [None]:
def load_latest_checkpoint(model, checkpoint_dir, device):
    """
    Load the highest-numbered ``*_<N>.pth`` checkpoint in ``checkpoint_dir`` into ``model``.

    Files whose name does not end in ``_<digits>.pth`` (e.g. ``best.pth``) are
    ignored instead of crashing the epoch parse. Epoch numbers are compared
    numerically, so ``model_epoch_10.pth`` wins over ``model_epoch_2.pth``.

    Args:
        model: The PyTorch model instance; its weights are replaced in place.
        checkpoint_dir: Directory containing checkpoint files.
        device: 'cuda' or 'cpu' (used as ``map_location``).

    Returns:
        (model, epoch): the model with the checkpoint's "model_state_dict" loaded,
        and the value stored under the checkpoint's "epoch" key.

    Raises:
        FileNotFoundError: If the directory is missing or holds no numbered checkpoint.
    """
    import re

    if not os.path.isdir(checkpoint_dir):
        raise FileNotFoundError(f"Checkpoint directory '{checkpoint_dir}' does not exist.")

    epoch_pattern = re.compile(r"_(\d+)\.pth$")
    numbered = []
    for name in os.listdir(checkpoint_dir):
        match = epoch_pattern.search(name)
        if match:
            numbered.append((int(match.group(1)), name))
    if not numbered:
        raise FileNotFoundError("No checkpoint files found in the directory.")

    # Highest epoch number; ties are broken by file name so the choice is deterministic.
    _, latest_checkpoint = max(numbered)
    checkpoint_path = os.path.join(checkpoint_dir, latest_checkpoint)

    print(f"Loading latest checkpoint: {checkpoint_path}")

    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["model_state_dict"])

    return model, checkpoint["epoch"]


In [None]:
def compute_snr(clean, denoised):
    """
    Compute the Signal-to-Noise Ratio in dB: 10 * log10(sum(clean^2) / sum((clean - denoised)^2)).

    Args:
        clean: Tensor or numpy array, ground truth.
        denoised: Tensor or numpy array, reconstructed output (same shape).

    Returns:
        snr_value: ``np.float64`` SNR in decibels; ``inf`` for a perfect
        reconstruction. A NumPy scalar is returned in every case, so callers can
        use either ``float(...)`` or ``.item()`` on it. Previously the perfect case
        returned a plain float, which has no ``.item()``.
    """
    # Convert via float64: fp16 sums of squares overflow quickly, and bf16 tensors
    # cannot be converted to NumPy directly.
    if isinstance(clean, torch.Tensor):
        clean = clean.detach().cpu().double().numpy()
    if isinstance(denoised, torch.Tensor):
        denoised = denoised.detach().cpu().double().numpy()
    clean = np.asarray(clean, dtype=np.float64)
    denoised = np.asarray(denoised, dtype=np.float64)

    noise = clean - denoised
    signal_power = np.sum(clean ** 2)
    noise_power = np.sum(noise ** 2)

    if noise_power == 0:
        return np.float64(np.inf)  # Perfect reconstruction
    return np.float64(10 * np.log10(signal_power / noise_power))

In [None]:
def compute_ssim(clean, recon, data_range=None):
    """
    Compute the mean SSIM between two batches of single-channel images.

    Args:
        clean: [B, 1, H, W] tensor of reference images.
        recon: [B, 1, H, W] tensor of reconstructed images (same shape as ``clean``).
        data_range: Optional fixed dynamic range for skimage's SSIM (e.g. 1.0 for
            [0, 1] data, 2.0 for [-1, 1]). When None, each reference image's own
            ``max - min`` is used, as before.

    Returns:
        0-dim tensor holding the mean SSIM over the batch.

    Raises:
        ValueError: If the images are not single-channel or the shapes differ.
    """
    B, C, H, W = clean.shape
    if C != 1:
        raise ValueError("SSIM is implemented here for grayscale only.")
    if tuple(recon.shape) != tuple(clean.shape):
        raise ValueError(f"clean and recon shapes differ: {tuple(clean.shape)} vs {tuple(recon.shape)}")

    ssim_vals = []
    for i in range(B):
        ref = clean[i, 0].detach().cpu().numpy()
        test = recon[i, 0].detach().cpu().numpy()
        rng = float(ref.max() - ref.min()) if data_range is None else float(data_range)
        if rng <= 0:
            # A constant reference has zero range, which zeroes SSIM's stabilising
            # constants and can yield NaN; fall back to a unit range.
            rng = 1.0
        ssim_vals.append(ssim(ref, test, data_range=rng))

    return torch.tensor(np.mean(ssim_vals))  # return tensor

## Non-Local Means (NLM) Baseline

In [None]:
def fast_nlm_batch(noisy_batch, patch_size=3, patch_distance=5, h=0.8, fast_mode=True, device="cpu"):
    """
    Denoise each image of a single-channel batch with scikit-image Non-Local Means.

    For every image the noise level is estimated with ``estimate_sigma``, and the
    filter strength passed to ``denoise_nl_means`` is ``h * sigma_est``.

    Args:
        noisy_batch: Tensor [B, 1, H, W].
        patch_size: Patch size forwarded to ``denoise_nl_means``.
        patch_distance: Search distance forwarded to ``denoise_nl_means``.
        h: Multiplier applied to the estimated sigma to obtain the filter strength.
        fast_mode: Forwarded to ``denoise_nl_means``.
        device: Device of the returned tensor.

    Returns:
        float32 tensor [B, 1, H, W] on ``device``.
    """
    _, channels, _, _ = noisy_batch.shape
    assert channels == 1, "NLM is best suited for grayscale images (C=1)."

    def _nlm_single(frame):
        # np.mean collapses whatever estimate_sigma returns into one scalar sigma.
        sigma = np.mean(estimate_sigma(frame, channel_axis=None))
        return denoise_nl_means(
            frame,
            h=h * sigma,
            patch_size=patch_size,
            patch_distance=patch_distance,
            fast_mode=fast_mode,
            channel_axis=None,
        )

    # Iterate over the [H, W] planes of channel 0; each result gets its channel axis back.
    outputs = [
        torch.tensor(_nlm_single(plane.detach().cpu().numpy()), dtype=torch.float32).unsqueeze(0)
        for plane in noisy_batch[:, 0]
    ]
    return torch.stack(outputs, dim=0).to(device)

In [None]:
# def evaluate_model(model, test_loader, schedule, device="cpu", num_batches=1):
#     """
#     Evaluate model performance on noisy test data using SNR.

#     Args:
#         model: trained PyTorch model.
#         test_loader: DataLoader for test dataset.
#         schedule: diffusion beta schedule dict.
#         device: device to run evaluation on.
#         num_batches: number of batches to evaluate (can be 1 for single image batch).

#     Returns:
#         avg_snr: average SNR over evaluated batches.
#     """
#     model.eval()
#     snr_list = []
#     ssim_list =[]

#     with torch.no_grad():
#         for batch_idx, batch in enumerate(test_loader):
#             if batch_idx >= num_batches:
#                 break

#             # Support DataLoader with or without labels
#             images = batch[0] if isinstance(batch, (list, tuple)) else batch
#             images = images.to(device)

#             # Add noise using forward diffusion
#             x_t, noise, timesteps = forward_diffusion_batch(images, schedule)

#             # Model predicts noise
#             pred_noise = model(x_t, timesteps).sample

#             # Reconstruct clean image estimate: x0_hat
#             sqrt_alphas_cumprod = schedule["sqrt_alphas_cumprod"]
#             sqrt_one_minus_alphas_cumprod = schedule["sqrt_one_minus_alphas_cumprod"]

#             a_t = sqrt_alphas_cumprod[timesteps].view(-1, 1, 1, 1)
#             b_t = sqrt_one_minus_alphas_cumprod[timesteps].view(-1, 1, 1, 1)
#             reconstructed_images = (x_t - b_t * pred_noise) / a_t

#             # Compute SNR
#             snr_value = compute_snr(images, reconstructed_images)
#             snr_list.append(float(snr_value))

#             # Compute SNR
#             ssim_value = compute_ssim(images, reconstructed_images)
#             ssim_list.append(float(ssim_value))



#     avg_snr = sum(snr_list) / len(snr_list) if snr_list else 0.0
#     print(f"Average SNR over {len(snr_list)} batch(es): {avg_snr:.4f} dB")

#     avg_ssim = sum(ssim_list) / len(ssim_list) if ssim_list else 0.0
#     print(f"Average SSIM over {len(ssim_list)} batch(es): {avg_ssim:.4f}")

#     return avg_snr, avg_ssim

In [None]:
def evaluate_model(model, test_loader, schedule, device="cuda", num_batches=1, use_nlm=False):
    """
    Evaluate DDPM denoising performance (and optionally compare with NLM).

    For each of the first ``num_batches`` batches of ``test_loader``:
      1. Noise the clean images with ``forward_diffusion_batch``.
      2. Predict the noise with ``model`` and form the one-step clean estimate
         x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t).
      3. Record SNR and SSIM of x0_hat against the clean images.
    When ``use_nlm`` is True, the same noisy batch is also denoised with
    ``fast_nlm_batch``, so both methods are scored on identical inputs.

    Args:
        model: noise-prediction network, called as ``model(x_t, timesteps).sample``.
        test_loader: iterable of batches; each batch is either a tensor or a
            list/tuple whose first element is the image tensor.
        schedule: dict providing "sqrt_alphas_cumprod" and
            "sqrt_one_minus_alphas_cumprod", indexed by timestep.
        device: device the images are moved to (also passed to the NLM helper).
        num_batches: maximum number of batches to evaluate.
        use_nlm: also evaluate the NLM baseline.

    Returns:
        dict mapping method name ("ddpm", plus "nlm" if enabled) to
        {"snr": mean SNR, "ssim": mean SSIM}. Each metric is averaged over the
        evaluated batches, and is NaN when no batch was evaluated.
    """
    model.eval()
    methods = ["ddpm"] + (["nlm"] if use_nlm else [])
    per_batch = {method: {"snr": [], "ssim": []} for method in methods}

    # The schedule tables are identical for every batch, so look them up once.
    sqrt_alphas_cumprod = schedule["sqrt_alphas_cumprod"]
    sqrt_one_minus_alphas_cumprod = schedule["sqrt_one_minus_alphas_cumprod"]

    def _record(method, clean, recon):
        # float() accepts Python floats, NumPy scalars and 0-d tensors alike.
        # .item() would fail if the metric helper returns a plain float.
        per_batch[method]["snr"].append(float(compute_snr(clean, recon)))
        per_batch[method]["ssim"].append(float(compute_ssim(clean, recon)))

    with torch.no_grad():
        for batch_idx, batch in enumerate(test_loader):
            if batch_idx >= num_batches:
                break

            # Handle datasets with or without labels.
            images = batch[0] if isinstance(batch, (list, tuple)) else batch
            images = images.to(device)

            # Forward diffusion, then the model's noise prediction.
            x_t, _noise, timesteps = forward_diffusion_batch(images, schedule)
            pred_noise = model(x_t, timesteps).sample

            # Reconstruct the clean-image estimate x0_hat.
            a_t = sqrt_alphas_cumprod[timesteps].view(-1, 1, 1, 1)
            b_t = sqrt_one_minus_alphas_cumprod[timesteps].view(-1, 1, 1, 1)
            _record("ddpm", images, (x_t - b_t * pred_noise) / a_t)

            # The NLM baseline denoises the very same noisy batch.
            if use_nlm:
                _record("nlm", images, fast_nlm_batch(x_t, device=device))

    # Average per method. An explicit NaN for "nothing evaluated" avoids
    # np.mean's empty-slice RuntimeWarning.
    return {
        method: {
            name: (float(np.mean(values)) if values else float("nan"))
            for name, values in metrics.items()
        }
        for method, metrics in per_batch.items()
    }


In [None]:
# def visualize_denoising(images, x_t, recon_ddpm, recon_nlm=None, num_images=4):
#     """
#     Visualize original, noisy, DDPM denoised, and optionally NLM denoised images.
#     """
#     num_images = min(num_images, images.shape[0])
#     cols = 3 if recon_nlm is None else 4
#     titles = ["Original", "Noisy", "DDPM Denoised"] + (["NLM Denoised"] if recon_nlm is not None else [])

#     fig, axs = plt.subplots(num_images, cols, figsize=(3*cols, 3*num_images))

#     if num_images == 1:
#         axs = [axs]  # ensure iterable

#     for i in range(num_images):
#         imgs_to_show = [images[i, 0], x_t[i, 0], recon_ddpm[i, 0]]
#         if recon_nlm is not None:
#             imgs_to_show.append(recon_nlm[i, 0])

#         for j, img in enumerate(imgs_to_show):
#             ax = axs[i, j] if num_images > 1 else axs[j]
#             ax.imshow(img.detach().cpu().numpy(), cmap="gray")
#             ax.axis("off")
#             if i == 0:  # top row titles
#                 ax.set_title(titles[j])

#     plt.tight_layout()
#     plt.show()


In [None]:
def visualize_denoising(images, x_t, recon_ddpm, recon_nlm=None, timesteps=None, num_images=4):
    """
    Visualize original, noisy, DDPM denoised, and optionally NLM denoised images.

    One row is drawn per image (at most ``num_images``, capped by the batch
    size) and one column per image type.
    - Above the DDPM column, and the NLM column when present, the per-image SNR
      and SSIM against the original are shown (via ``compute_snr`` / ``compute_ssim``).
    - Above the noisy column, the timestep t is shown when ``timesteps`` is given.

    Args:
        images: clean images; assumed (N, C, H, W) — only channel 0 is drawn.
        x_t: noisy inputs, same layout as ``images``.
        recon_ddpm: DDPM reconstructions, same layout as ``images``.
        recon_nlm: optional NLM reconstructions, same layout as ``images``.
        timesteps: optional per-image timesteps, one entry per image.
        num_images: maximum number of rows to draw.
    """
    num_images = min(num_images, images.shape[0])
    cols = 3 if recon_nlm is None else 4
    titles = ["Original", "Noisy", "DDPM Denoised"] + (["NLM Denoised"] if recon_nlm is not None else [])

    # squeeze=False always yields a 2-D (rows, cols) Axes array, so axs[i, j]
    # is valid even for a single row. With the default squeeze, one row gives a
    # 1-D array, and wrapping it as `[axs]` made axs[j] index the wrapping list.
    # That broke num_images == 1.
    fig, axs = plt.subplots(num_images, cols, figsize=(3 * cols, 3 * num_images), squeeze=False)

    for i in range(num_images):
        # Keep a leading batch dimension of 1 so the metric helpers see a batch.
        clean = images[i:i + 1]
        noisy = x_t[i:i + 1]
        ddpm = recon_ddpm[i:i + 1]
        nlm = recon_nlm[i:i + 1] if recon_nlm is not None else None

        # Per-image metrics against the clean reference.
        snr_ddpm = compute_snr(clean, ddpm)
        ssim_ddpm = compute_ssim(clean, ddpm)

        if nlm is not None:
            snr_nlm = compute_snr(clean, nlm)
            ssim_nlm = compute_ssim(clean, nlm)

        imgs_to_show = [clean[0, 0], noisy[0, 0], ddpm[0, 0]]
        if nlm is not None:
            imgs_to_show.append(nlm[0, 0])

        for j, img in enumerate(imgs_to_show):
            ax = axs[i, j]
            ax.imshow(img.detach().cpu().numpy(), cmap="gray")
            ax.axis("off")

            # Top row: fixed column titles.
            if i == 0:
                ax.set_title(titles[j], fontsize=11, pad=6)

            # Per-image info. This overrides the fixed title where it applies.
            if j == 1 and timesteps is not None:  # noisy image
                t_val = int(timesteps[i])
                ax.set_title(f"{titles[j]}\n(t={t_val})", fontsize=10, pad=2)

            if j == 2:  # DDPM
                ax.set_title(f"{titles[j]}\nSNR={snr_ddpm:.2f}, SSIM={ssim_ddpm:.3f}", fontsize=10, pad=2)

            if j == 3 and nlm is not None:  # NLM
                ax.set_title(f"{titles[j]}\nSNR={snr_nlm:.2f}, SSIM={ssim_nlm:.3f}", fontsize=10, pad=2)

    plt.tight_layout()
    plt.show()

In [None]:
def display_metrics_table(metrics_dict):
    """
    Print SNR and SSIM for every evaluated method as a grid table.

    Args:
        metrics_dict: mapping of method name -> {"snr": value, "ssim": value},
            e.g. the dict returned by ``evaluate_model``. Method names are
            shown upper-cased; values are formatted with four decimals.
    """
    column_names = ["Method", "SNR (dB)", "SSIM"]
    rows = [
        [name.upper(), f"{scores['snr']:.4f}", f"{scores['ssim']:.4f}"]
        for name, scores in metrics_dict.items()
    ]

    print("\n=== Evaluation Metrics ===")
    print(tabulate(rows, headers=column_names, tablefmt="fancy_grid"))


## Main Code

In [None]:
##################### Dataset loading and visualization ###################################
full_dataset, train_indices, test_indices = load_and_split_dataset()

# Preview: five training images with the default layout,
# then three test images laid out on a single row.
train_preview_idx = train_indices[:5]
test_preview_idx = test_indices[:3]
display_images(full_dataset, train_preview_idx)
display_images(full_dataset, test_preview_idx, num_images=3, cols=3)

# Statistics for the first training image.
show_image_statistics(full_dataset, train_indices[0])

In [None]:
################################### Training / testing data preprocessing ###################################


def _preprocess_split(indices, save_path, preview_count):
    """Preprocess dataset rows `indices`, save them to `save_path`, reload the saved tensor and preview it.

    The tensor is reloaded from disk so that later sessions can skip the
    preprocessing step and just call torch.load on the same path.
    """
    preprocess_and_save_dataset(full_dataset, indices, target_size=(128, 128), normalize_type="signed", save_path=save_path)
    split_tensor = torch.load(save_path)
    display_preprocessed_images(split_tensor, num_images=preview_count)
    show_preprocessed_statistics(split_tensor, index=0)
    return split_tensor


# First 128 training rows (preview 16 images), then first 32 test rows (preview 2).
train_tensor = _preprocess_split(train_indices[:128], "/content/train_preprocessed.pt", preview_count=16)
test_tensor = _preprocess_split(test_indices[:32], "/content/test_preprocessed.pt", preview_count=2)

In [None]:
################################### Model definition ###################################
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Fine-tune the UNet of a pretrained DDPM pipeline; only the UNet is kept as `model`.
model_id = "benetraco/brain_ddpm_128"
pipeline = DDPMPipeline.from_pretrained(model_id)
model = pipeline.unet.to(device)
model.train()  # switch to training mode before fine-tuning


################################### Model training ###################################
train_loader, test_loader = get_data_loaders(
    train_path="/content/train_preprocessed.pt",
    test_path="/content/test_preprocessed.pt",
    batch_size=8,
)

# Training noise schedule: T=100 steps, betas from 1e-4 up to 0.02.
schedule = make_beta_schedule(T=100, beta_start=1e-4, beta_end=0.02, device=device)

train_config = dict(
    epochs=10,
    lr=2e-4,
    checkpoint_dir="./ddpm_checkpoints",
    save_every=8,
    log_every=32,
    resume_checkpoint=None,  # or "ddpm_checkpoints/model_epoch_20.pth" to resume
)
train_model(model=model, train_loader=train_loader, schedule=schedule, device=device, **train_config)


In [None]:
################################### Evaluation and comparison with NLM ###################################

# Restore weights via load_latest_checkpoint from the directory train_model
# writes to, and make sure the model lives on `device`.
# (nn.Module.to moves the module in place and returns the same module.)
model, checkpoint_epoch = load_latest_checkpoint(model, "./ddpm_checkpoints", device)
model = model.to(device)

In [None]:
# Evaluate on test dataset.
# NOTE(review): test_schedule uses beta_end=0.001, while training used
# `schedule` (beta_end=0.02). The model is called with the timestep t
# (model(x_t, timesteps)); if it learned t -> noise level from the training
# schedule, x_t built here carries less noise than the model assumes at that t.
# Confirm this mismatch is intended.
test_schedule = make_beta_schedule(T=100, beta_start=1e-4, beta_end=0.001, device=device)

# # Only DDPM
# evaluate_model_1(model, test_loader, schedule, device=device, num_batches=5)
# Compare DDPM vs NLM on the same noisy inputs. Pass the notebook's `device`
# instead of a hard-coded "cuda" so the cell also runs on CPU-only machines.
results = evaluate_model(model, test_loader, test_schedule, device=device, num_batches=20, use_nlm=True)

In [None]:
# Print the averaged SNR/SSIM of every evaluated method as a table.
display_metrics_table(metrics_dict=results)

In [None]:
# # Show some images
# with torch.no_grad():
#     for batch in test_loader:
#         # Handle both possible cases: (images,) or images
#         if isinstance(batch, (list, tuple)):
#             images = batch[0].to("cuda")
#         else:
#             images = batch.to("cuda")

#         x_t, noise, timesteps = forward_diffusion_batch(images, schedule)
#         pred_noise = model(x_t, timesteps).sample
#         recon_ddpm = x_t - pred_noise
#         recon_nlm = fast_nlm_batch(x_t, device="cuda")

#         visualize_denoising(images, x_t, recon_ddpm, recon_nlm, num_images=3)
#         break

In [None]:
# Visualize the first test batch: original, noisy, DDPM and NLM reconstructions.
# Uses the notebook's `device` rather than a hard-coded "cuda" so the cell also
# runs on CPU-only machines.
with torch.no_grad():
    batch = next(iter(test_loader), None)  # only the first batch is shown
    if batch is not None:
        images = batch[0] if isinstance(batch, (list, tuple)) else batch
        images = images.to(device)

        x_t, noise, timesteps = forward_diffusion_batch(images, test_schedule)
        pred_noise = model(x_t, timesteps).sample

        # One-step clean-image estimate:
        # x0_hat = (x_t - sqrt(1 - alpha_bar_t) * eps_hat) / sqrt(alpha_bar_t)
        sqrt_alphas_cumprod = test_schedule["sqrt_alphas_cumprod"]
        sqrt_one_minus_alphas_cumprod = test_schedule["sqrt_one_minus_alphas_cumprod"]
        a_t = sqrt_alphas_cumprod[timesteps].view(-1, 1, 1, 1)
        b_t = sqrt_one_minus_alphas_cumprod[timesteps].view(-1, 1, 1, 1)
        recon_ddpm = (x_t - b_t * pred_noise) / a_t

        # The NLM baseline denoises the very same noisy batch.
        recon_nlm = fast_nlm_batch(x_t, device=device)

        visualize_denoising(images, x_t, recon_ddpm, recon_nlm, timesteps=timesteps, num_images=7)