In [1]:
# Cell 1: All necessary imports
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torchvision import transforms
                                               
import matplotlib.pyplot as plt
import nibabel as nib
import numpy as np
from pathlib import Path
from tqdm.notebook import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
import random

## Dataset and Dataloader

In [2]:
class MRIT1T2Dataset(Dataset):
    """Dataset yielding one 2D T1 slice and one 2D T2 slice per item.

    Volumes are read from ``.nii.gz`` files, validated, min-max normalised per
    volume, sliced along the third axis, resized to 128x128 and returned as
    single-channel float tensors. Up to ``cache_size`` loaded volumes are kept
    in memory.
    """

    def __init__(self, t1_dir, t2_dir, slice_mode='middle', paired=True, transform=None, cache_size=0):
        """
        Args:
            t1_dir (str): Directory containing T1 scans
            t2_dir (str): Directory containing T2 scans
            slice_mode (str): 'middle' selects the central slice; any other
                value selects a random slice
            paired (bool): If True, each item uses the T1 and T2 files that share
                a subject ID; otherwise a random T2 file is drawn per item
            transform: Optional callable applied to each resized slice tensor
                (T1 and T2 are transformed independently)
            cache_size (int): Number of volumes to cache in memory (0 for no caching)
        """
        super().__init__()

        # Fixed-size resize applied to every slice before any user transform
        self.resize_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((128, 128)),
            transforms.ToTensor()
        ])

        # Store initialization parameters
        self.t1_dir = t1_dir
        self.t2_dir = t2_dir
        self.transform = transform
        self.slice_mode = slice_mode
        self.paired = paired
        self.cache_size = cache_size

        # Sorted lists of all NIfTI files (.nii.gz) in each directory
        self.t1_files = sorted(f for f in os.listdir(t1_dir) if f.endswith('.nii.gz'))
        self.t2_files = sorted(f for f in os.listdir(t2_dir) if f.endswith('.nii.gz'))

        if self.paired:
            # Index the T2 files by subject ID once instead of rescanning the
            # whole list for every T1 file. setdefault keeps the first T2 file
            # (in sorted order) when a subject has several.
            t2_by_subject = {}
            for t2f in self.t2_files:
                t2_by_subject.setdefault(self._subject_id(t2f), t2f)

            self.paired_files = [
                (t1f, t2_by_subject[self._subject_id(t1f)])
                for t1f in self.t1_files
                if self._subject_id(t1f) in t2_by_subject
            ]
            print(f"Found {len(self.paired_files)} paired T1/T2 datasets")
            self.data_files = self.paired_files
        else:
            # For unpaired, just use T1 files and randomly select T2 later
            self.data_files = [(t1f, None) for t1f in self.t1_files]

        # Volume cache, keyed by full file path
        self.cache = {}

    @staticmethod
    def _subject_id(filename):
        """Return the subject ID: the text before the first '-' minus its first 3 characters."""
        return filename.split('-')[0][3:]

    def _load_and_validate_volume(self, filename, is_t1=True):
        """Load a single MRI volume (from cache if present) and validate it.

        Args:
            filename (str): Name of the NIfTI file
            is_t1 (bool): Whether this is a T1 volume (determines directory)

        Returns:
            tuple: (volume data array, (min_value, max_value))

        Raises:
            ValueError: If a freshly loaded volume fails ``is_valid_volume``
        """
        dir_path = self.t1_dir if is_t1 else self.t2_dir
        filepath = os.path.join(dir_path, filename)

        if filepath in self.cache:
            vol = self.cache[filepath]
        else:
            # Load volume using NiBabel
            vol = nib.load(filepath).get_fdata()

            if not self.is_valid_volume(vol):
                raise ValueError(f"Invalid volume: {filename}")

            # Only cache while there is room; nothing is ever evicted
            if len(self.cache) < self.cache_size:
                self.cache[filepath] = vol

        # Whole-volume intensity range, used to normalise the extracted slice
        stats = (float(vol.min()), float(vol.max()))

        return vol, stats

    def is_valid_volume(self, vol):
        """Check if volume meets quality criteria.

        A volume is valid when it is 3-dimensional, its first two axes are at
        least 64 long, it has at least one slice on the third axis and it
        contains no NaN values.

        Args:
            vol (np.ndarray): Volume data

        Returns:
            bool: True if volume is valid
        """
        min_size = 64  # Minimum in-plane dimension size
        # Check the rank first so malformed (non-3D) arrays are rejected
        # instead of raising IndexError on the shape lookups below.
        if vol.ndim != 3:
            return False
        return (vol.shape[0] >= min_size and
                vol.shape[1] >= min_size and
                vol.shape[2] >= 1 and
                not np.any(np.isnan(vol)))

    def get_slice_idx(self, volume):
        """Get slice index based on slice_mode setting.

        Args:
            volume (np.ndarray): Volume data

        Returns:
            int: Index along the third axis of the slice to extract
        """
        if self.slice_mode == 'middle':
            return volume.shape[2] // 2
        else:  # random
            return random.randint(0, volume.shape[2] - 1)

    def __len__(self):
        """Return the number of items (one per entry in ``data_files``)."""
        return len(self.data_files)

    def __getitem__(self, idx):
        """Get a T1-T2 pair of slices.

        Args:
            idx (int): Index of the pair

        Returns:
            dict: Contains 'T1' and 'T2' tensor slices
        """
        t1_file, t2_file = self.data_files[idx]

        # Load T1 volume and get stats
        t1_vol, t1_stats = self._load_and_validate_volume(t1_file, is_t1=True)

        if self.paired:
            # Load matching T2 volume
            t2_vol, t2_stats = self._load_and_validate_volume(t2_file, is_t1=False)
        else:
            # Random T2 volume for unpaired
            random_t2_idx = random.randint(0, len(self.t2_files) - 1)
            t2_vol, t2_stats = self._load_and_validate_volume(
                self.t2_files[random_t2_idx], is_t1=False
            )

        # Get slice indices
        t1_slice_idx = self.get_slice_idx(t1_vol)
        if self.paired:
            # Paired mode reuses the T1 index. The two volumes may differ in
            # depth, so clamp to the last T2 slice rather than index past it.
            t2_slice_idx = min(t1_slice_idx, t2_vol.shape[2] - 1)
        else:
            t2_slice_idx = self.get_slice_idx(t2_vol)

        # Extract and normalize slices
        t1_slice = self.normalize_slice(t1_vol[:, :, t1_slice_idx], t1_stats)
        t2_slice = self.normalize_slice(t2_vol[:, :, t2_slice_idx], t2_stats)

        # Convert to tensors and add channel dimension
        t1_tensor = torch.from_numpy(t1_slice).float().unsqueeze(0)
        t2_tensor = torch.from_numpy(t2_slice).float().unsqueeze(0)

        # Apply resize transform
        t1_tensor = self.resize_transform(t1_tensor)
        t2_tensor = self.resize_transform(t2_tensor)

        # Apply any additional transforms
        if self.transform:
            t1_tensor = self.transform(t1_tensor)
            t2_tensor = self.transform(t2_tensor)

        # Uncached volumes are released when the local references go out of
        # scope on return, so no explicit cleanup is needed here.
        return {'T1': t1_tensor, 'T2': t2_tensor}

    @staticmethod
    def normalize_slice(slice_data, stats):
        """Normalize slice to [0,1] range using pre-computed statistics.

        Args:
            slice_data (np.ndarray): Raw slice data
            stats (tuple): (min_value, max_value) for normalization

        Returns:
            np.ndarray: Normalized slice
        """
        min_val, max_val = stats
        # The small epsilon avoids division by zero for constant volumes
        return (slice_data - min_val) / (max_val - min_val + 1e-8)

## Model (U-Net with Time Embedding and Cross-Attention)

In [3]:
class TimeEmbedding(nn.Module):
    """Maps a scalar diffusion timestep to a learned feature vector.

    The timestep is cast to float and passed through a two-layer MLP
    (Linear -> SiLU -> Linear); the result conditions the network on the
    current noise level.
    """

    def __init__(self, n_channels):
        """
        Args:
            n_channels (int): Width of the produced embedding
        """
        super().__init__()
        self.n_channels = n_channels
        # Scalar timestep -> n_channels, with one non-linearity in between
        mlp_layers = [
            nn.Linear(1, n_channels),
            nn.SiLU(),
            nn.Linear(n_channels, n_channels),
        ]
        self.time_proj = nn.Sequential(*mlp_layers)

    def forward(self, t):
        """Embed a batch of timesteps.

        Args:
            t (torch.Tensor): Batch of timesteps [batch_size]

        Returns:
            torch.Tensor: Time embeddings [batch_size, n_channels]
        """
        # Linear layers need a trailing feature axis of size 1
        timestep_column = t.float().unsqueeze(-1)
        return self.time_proj(timestep_column)

class ConvBlock(nn.Module):
    """Two 3x3 conv + GroupNorm layers with a time bias and a residual shortcut.

    The time embedding is projected to ``out_channels`` and added after the
    first normalisation. When the channel count changes, the shortcut goes
    through a 1x1 convolution so it can be added to the output.
    """

    def __init__(self, in_channels, out_channels, time_channels):
        """
        Args:
            in_channels (int): Number of input channels
            out_channels (int): Number of output channels (GroupNorm uses 8 groups)
            time_channels (int): Number of channels in time embedding
        """
        super().__init__()
        # Main convolution path
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.norm1 = nn.GroupNorm(num_groups=8, num_channels=out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=8, num_channels=out_channels)

        # Projects the time embedding to one bias value per output channel
        self.time_mlp = nn.Linear(time_channels, out_channels)

        # Identity shortcut when shapes already match, 1x1 conv otherwise
        self.use_residual = in_channels == out_channels
        if not self.use_residual:
            self.residual_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x, t):
        """
        Args:
            x (torch.Tensor): Input feature maps [B, C, H, W]
            t (torch.Tensor): Time embeddings [B, time_channels]

        Returns:
            torch.Tensor: Processed feature maps [B, out_channels, H, W]
        """
        if self.use_residual:
            shortcut = x
        else:
            shortcut = self.residual_conv(x)

        # Per-channel time bias, broadcast over the spatial axes
        time_bias = self.time_mlp(t).unsqueeze(-1).unsqueeze(-1)

        hidden = self.norm1(self.conv1(x)) + time_bias
        hidden = F.silu(hidden)
        hidden = F.silu(self.norm2(self.conv2(hidden)))

        return hidden + shortcut

class SelfAttention(nn.Module):
    """4-head self-attention over all spatial positions of a feature map."""

    def __init__(self, channels):
        """
        Args:
            channels (int): Number of input/output channels (embedding size of
                the 4-head attention)
        """
        super().__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        feed_forward_layers = [
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.ff_self = nn.Sequential(*feed_forward_layers)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input feature maps [B, C, H, W]

        Returns:
            torch.Tensor: Self-attended feature maps [B, C, H, W]
        """
        height, width = x.shape[-2:]

        # One token per pixel: [B, C, H, W] -> [B, H*W, C], then layer-normalise
        tokens = self.ln(x.flatten(2).transpose(1, 2))

        attended, _ = self.mha(tokens, tokens, tokens)
        # The skip connection adds the normalised tokens, not the raw input
        attended = attended + tokens
        # Feed-forward network with its own skip connection
        attended = self.ff_self(attended) + attended

        # Back to image layout
        return attended.transpose(1, 2).view(-1, self.channels, height, width)

class CrossAttention(nn.Module):
    """4-head cross-attention: queries from ``x``, keys and values from ``context``."""

    def __init__(self, channels):
        """
        Args:
            channels (int): Number of input/output channels (embedding size of
                the 4-head attention)
        """
        super().__init__()
        self.channels = channels
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        feed_forward_layers = [
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        ]
        self.ff_cross = nn.Sequential(*feed_forward_layers)

    def forward(self, x, context):
        """
        Args:
            x (torch.Tensor): Query feature maps [B, C, H, W]
            context (torch.Tensor): Key/Value feature maps [B, C, H', W'];
                its spatial size may differ from that of ``x``

        Returns:
            torch.Tensor: Cross-attended feature maps [B, C, H, W]
        """
        height, width = x.shape[-2:]

        # Flatten both maps to token sequences; only the queries are normalised
        queries = self.ln(x.flatten(2).transpose(1, 2))  # [B, HW, C]
        keys_values = context.flatten(2).transpose(1, 2)

        attended, _ = self.mha(queries, keys_values, keys_values)
        # The skip connection adds the normalised queries
        attended = attended + queries
        # Feed-forward network with its own skip connection
        attended = self.ff_cross(attended) + attended

        # Back to the spatial layout of the query map
        return attended.transpose(1, 2).view(-1, self.channels, height, width)

# UNet Architecture
class UNet(nn.Module):
    """Time-conditioned U-Net with self-attention and bottleneck cross-attention.

    The encoder halves the resolution three times while widening the channels
    to ``n_channels * 8``. The bottleneck optionally attends to a projected
    single-channel context image. The decoder concatenates encoder features
    as skip connections and a 1x1 convolution produces one output channel.
    """

    def __init__(self, in_channels=3, time_channels=256, n_channels=64):
        """
        Args:
            in_channels (int): Channels seen by the first conv block, i.e. the
                noisy image plus the concatenated condition channel when a
                condition is passed to ``forward`` (default: 3)
            time_channels (int): Dimension of time embedding (default: 256)
            n_channels (int): Base number of channels (default: 64)
        """
        super().__init__()
        c1, c2, c4, c8 = n_channels, n_channels * 2, n_channels * 4, n_channels * 8

        # Lifts the 1-channel context image to the bottleneck width
        self.context_proj = nn.Conv2d(1, c8, kernel_size=1)

        # Time embedding
        self.time_embed = TimeEmbedding(time_channels)

        # Encoder: each stage is (pool, conv block, self-attention)
        self.inc = ConvBlock(in_channels, c1, time_channels)
        self.down1 = nn.ModuleList([
            nn.MaxPool2d(2),
            ConvBlock(c1, c2, time_channels),
            SelfAttention(c2),
        ])
        self.down2 = nn.ModuleList([
            nn.MaxPool2d(2),
            ConvBlock(c2, c4, time_channels),
            SelfAttention(c4),
        ])
        self.down3 = nn.ModuleList([
            nn.MaxPool2d(2),
            ConvBlock(c4, c8, time_channels),
            SelfAttention(c8),
        ])

        # Bottleneck with self- and cross-attention
        self.bot1 = ConvBlock(c8, c8, time_channels)
        self.bot_attn = SelfAttention(c8)
        self.cross_attn = CrossAttention(c8)
        self.bot2 = ConvBlock(c8, c8, time_channels)
        self.bot3 = ConvBlock(c8, c8, time_channels)

        # Decoder: each conv block's input width is the sum of the two
        # concatenated tensors (running features + encoder skip).
        self.up1 = nn.ModuleList([
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ConvBlock(c8 + c4, c4, time_channels),  # bottleneck (x8) + x3 skip (x4)
            SelfAttention(c4),
        ])
        self.up2 = nn.ModuleList([
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            ConvBlock(c4 + c2, c2, time_channels),  # up1 output (x4) + x2 skip (x2)
            SelfAttention(c2),
        ])
        self.up3 = nn.ModuleList([
            ConvBlock(c2 + c1, c1, time_channels),  # up2 output (x2) + x1 skip (x1)
        ])

        # Output convolution: always a single channel
        self.outc = nn.Conv2d(c1, 1, 1)

    @staticmethod
    def _run_stage(stage, x, t):
        """Run a (resample, conv block, self-attention) stage on ``x``."""
        resample, conv_block, attention = stage
        return attention(conv_block(resample(x), t))

    def forward(self, x, t, condition=None, context=None):
        """
        Args:
            x (torch.Tensor): Input tensor [B, C, H, W]
            t (torch.Tensor): Timesteps [B]
            condition (torch.Tensor, optional): Conditioning image; expanded to
                [B, 1, H, W] and concatenated to ``x`` as an extra channel
            context (torch.Tensor, optional): Context image for cross-attention [B, 1, H, W]

        Returns:
            torch.Tensor: Network output [B, 1, H, W]
        """
        # Append the conditioning image as an additional input channel
        if condition is not None:
            height, width = x.shape[2], x.shape[3]
            x = torch.cat([x, condition.expand(-1, 1, height, width)], dim=1)

        # Time embedding shared by every conv block
        t = self.time_embed(t)

        # Encoder; x1..x3 are kept for the skip connections
        x1 = self.inc(x, t)
        x2 = self._run_stage(self.down1, x1, t)
        x3 = self._run_stage(self.down2, x2, t)
        x4 = self._run_stage(self.down3, x3, t)

        # Bottleneck
        x4 = self.bot_attn(self.bot1(x4, t))
        if context is not None:
            # Project context to the bottleneck width; its resolution is unchanged
            context = self.context_proj(context)
            x4 = self.cross_attn(x4, context)
        x4 = self.bot3(self.bot2(x4, t), t)

        # Decoder. The bottleneck output is first resized to x3's resolution so
        # the two can be concatenated; each up stage then doubles the
        # resolution before its conv block.
        x4 = F.interpolate(x4, size=x3.shape[-2:], mode='bilinear', align_corners=True)
        x = self._run_stage(self.up1, torch.cat([x4, x3], dim=1), t)
        x = self._run_stage(self.up2, torch.cat([x, x2], dim=1), t)

        # Last stage has no upsampling and no attention: a single conv block
        (final_block,) = self.up3
        x = final_block(torch.cat([x, x1], dim=1), t)

        return self.outc(x)

## Diffusion Trainer Class

In [4]:
class DDPMTrainer:
    """Denoising Diffusion Probabilistic Models (DDPM) trainer.

    Holds the model, a linear beta noise schedule and an Adam optimizer, and
    provides:
    - the forward (noising) process,
    - a single optimisation step on the noise-prediction MSE loss,
    - ancestral sampling through the reverse process.
    """

    def __init__(
        self, model, n_timesteps=1000, beta_start=1e-4, beta_end=0.02,
        lr=1e-4, device="cuda"
    ):
        """
        Args:
            model: UNet model instance
            n_timesteps (int): Number of diffusion timesteps
            beta_start (float): Starting noise schedule value
            beta_end (float): Ending noise schedule value
            lr (float): Learning rate for Adam optimizer
            device (str): Device to run on ("cuda" or "cpu")
        """
        self.model = model.to(device)
        self.device = device
        self.n_timesteps = n_timesteps

        # Linear noise schedule and its cumulative products
        self.betas = torch.linspace(beta_start, beta_end, n_timesteps).to(device)
        self.alphas = 1 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

        # Pre-computed coefficients of the closed-form forward process
        self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
        self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - self.alphas_cumprod)

        # Setup optimizer
        self.optimizer = Adam(self.model.parameters(), lr=lr)

    def diffuse_step(self, x_0, t):
        """Forward diffusion step: adds noise to image according to timestep.

        Args:
            x_0 (torch.Tensor): Original clean image
            t (torch.Tensor): Timesteps for batch

        Returns:
            tuple: (noisy image, noise added)
        """
        noise = torch.randn_like(x_0)

        # Per-sample scaling factors, shaped to broadcast over [B, C, H, W]
        sqrt_alpha_t = self.sqrt_alphas_cumprod[t].view(-1, 1, 1, 1)
        sqrt_one_minus_alpha_t = self.sqrt_one_minus_alphas_cumprod[t].view(-1, 1, 1, 1)

        # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
        x_t = sqrt_alpha_t * x_0 + sqrt_one_minus_alpha_t * noise
        return x_t, noise

    def train_one_batch(self, x_0, condition=None, context=None):
        """Trains model on a single batch.

        Args:
            x_0 (torch.Tensor): Clean images [B, C, H, W]
            condition (torch.Tensor, optional): Conditioning information
            context (torch.Tensor, optional): Context for cross-attention

        Returns:
            float: Batch loss value
        """
        batch_size = x_0.shape[0]
        # One uniformly sampled timestep per batch element
        t = torch.randint(0, self.n_timesteps, (batch_size,), device=self.device)

        # Apply forward diffusion
        x_t, noise = self.diffuse_step(x_0, t)

        # Full-precision forward pass (no autocast)
        noise_pred = self.model(x_t, t, condition=condition, context=context)
        # Loss between predicted and actual noise
        loss = F.mse_loss(noise_pred, noise)

        # Optimizer step
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        return loss.item()

    @torch.no_grad()
    def sample(self, condition=None, context=None, shape=None, n_steps=None):
        """Generates samples using the reverse diffusion process.

        Args:
            condition (torch.Tensor, optional): Conditioning information
            context (torch.Tensor, optional): Context for cross-attention
            shape (tuple, optional): Shape of samples to generate. When omitted,
                the shape of ``condition`` is used.
            n_steps (int, optional): Number of sampling steps (defaults to
                ``n_timesteps``). With a smaller value only timesteps
                ``n_steps - 1 .. 0`` are run, still starting from pure noise.

        Returns:
            torch.Tensor: Generated samples

        Raises:
            ValueError: If neither ``shape`` nor ``condition`` is given, or if
                ``n_steps`` exceeds the length of the noise schedule.
        """
        if shape is None:
            # Without a shape the initial noise cannot be drawn; fall back to
            # the conditioning tensor's shape or fail with a clear message.
            if condition is None:
                raise ValueError(
                    "sample() requires `shape`, or a `condition` tensor to infer it from"
                )
            shape = tuple(condition.shape)

        if n_steps is None:
            n_steps = self.n_timesteps
        elif n_steps > self.n_timesteps:
            # Fail before running the model rather than on an out-of-range
            # schedule lookup inside the loop.
            raise ValueError(
                f"n_steps ({n_steps}) exceeds the schedule length ({self.n_timesteps})"
            )

        # Start from pure noise
        x_t = torch.randn(shape, device=self.device)

        # Gradually denoise the sample
        for t in reversed(range(n_steps)):
            t_batch = torch.full((shape[0],), t, device=self.device, dtype=torch.long)

            # Predict noise in current sample
            noise_pred = self.model(x_t, t_batch, condition=condition, context=context)

            # Diffusion parameters for timestep t
            alpha_t = self.alphas[t]
            alpha_t_cumprod = self.alphas_cumprod[t]
            beta_t = self.betas[t]

            # No fresh noise is injected on the final step (t == 0)
            if t > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0.

            # Posterior mean from the predicted noise, plus sqrt(beta_t)-scaled noise
            x_t = (1 / torch.sqrt(alpha_t)) * (
                x_t - beta_t / torch.sqrt(1 - alpha_t_cumprod) * noise_pred
            ) + torch.sqrt(beta_t) * noise

        return x_t


class MetricsTracker:
    """Collects per-batch losses and PSNR/SSIM scores and reports their means."""

    def __init__(self, device):
        """
        Args:
            device: Device the PSNR and SSIM metric modules are moved to
        """
        self.psnr = PeakSignalNoiseRatio().to(device)
        self.ssim = StructuralSimilarityIndexMeasure().to(device)
        self.reset()

    def reset(self):
        """Clears the recorded losses and scores (called at the start of each epoch)."""
        self.train_losses, self.psnr_scores, self.ssim_scores = [], [], []

    def update(self, pred, target, loss=None):
        """Records a loss value and/or the PSNR and SSIM of ``pred`` against ``target``.

        Args:
            pred: Predicted images, or None to skip the image metrics
            target: Reference images, or None to skip the image metrics
            loss: Optional loss value to record
        """
        if loss is not None:
            self.train_losses.append(loss)

        # Image metrics need both tensors
        if pred is None or target is None:
            return
        self.psnr_scores.append(self.psnr(pred, target).item())
        self.ssim_scores.append(self.ssim(pred, target).item())

    def get_metrics(self):
        """Returns the mean of every recorded series; an empty series yields 0."""
        recorded = {
            'loss': self.train_losses,
            'psnr': self.psnr_scores,
            'ssim': self.ssim_scores,
        }
        return {
            name: (np.mean(values) if values else 0)
            for name, values in recorded.items()
        }

def visualize_samples(t1_real, t2_real, t1_gen, t2_gen, epoch, step, save=True,
                      save_dir='visualizations/diffusion'):
    """Creates a 1x4 grid of real and generated images.

    Only the first channel of the first batch element of each tensor is shown.

    Args:
        t1_real, t2_real: Real T1 and T2 images
        t1_gen, t2_gen: Generated T1 and T2 images
        epoch (int): Current epoch
        step (int): Current step
        save (bool): If True, write the figure to disk and close it; otherwise
            show it interactively
        save_dir (str or Path): Directory the PNG is written to; created if it
            does not exist
    """
    # Panel order (left to right) with the matching titles
    panels = [
        (t1_real, 'Real T1'),
        (t2_real, 'Real T2'),
        (t2_gen, 'T1→T2'),  # T2 generated from T1
        (t1_gen, 'T2→T1'),  # T1 generated from T2
    ]

    fig, axes = plt.subplots(1, 4, figsize=(15, 4))
    for ax, (images, title) in zip(axes, panels):
        # detach() keeps this usable for tensors that still track gradients
        ax.imshow(images[0, 0].detach().cpu().numpy().T, cmap='gray', origin='lower')
        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()

    if save:
        # Create the target directory on demand; savefig raises
        # FileNotFoundError when it is missing.
        out_dir = Path(save_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        fig.savefig(out_dir / f'samples_epoch{epoch}_step{step}.png')
        plt.close(fig)
    else:
        plt.show()

## Training Loop Script

In [None]:
class Config:
    """Training configuration: data paths, model/schedule sizes, output folders, device.

    Instantiating the class also creates the checkpoint, visualisation and log
    directories.
    """
    def __init__(self):
        # Data paths, relative to the notebooks folder (one level below the root)
        self.t1_dir = "../data/IXI_T1"
        self.t2_dir = "../data/IXI_T2"

        # Model parameters
        self.in_channels = 2  # Image + condition channel
        self.time_channels = 256
        self.n_channels = 64

        # Diffusion schedule
        self.n_timesteps = 1000
        self.beta_start = 1e-4
        self.beta_end = 0.02

        # Training parameters
        self.batch_size = 1
        self.num_epochs = 100
        self.lr = 1e-4
        self.save_interval = 100  # Save checkpoints every N steps

        # Output directories, also one level up from the notebooks folder
        self.checkpoint_dir = Path("../checkpoints/diffusion")
        self.vis_dir = Path("../visualizations/diffusion")
        self.log_dir = Path("../logs/diffusion")

        # Create directories if they don't exist
        for output_dir in (self.checkpoint_dir, self.vis_dir, self.log_dir):
            output_dir.mkdir(parents=True, exist_ok=True)

        # Device preference: MPS, then CUDA, then CPU
        if torch.backends.mps.is_available():
            backend = "mps"
        elif torch.cuda.is_available():
            backend = "cuda"
        else:
            backend = "cpu"
        self.device = torch.device(backend)

# Module-level configuration instance. save_checkpoint() and train_diffusion()
# read it as a global; constructing it also creates the output directories.
config = Config()

# Training functions: checkpoint loading and saving

def load_checkpoint(trainer, checkpoint_path):
    """Loads model and optimizer state from checkpoint.

    Tensors are remapped onto ``trainer.device`` while loading, so a
    checkpoint written on one device type can be restored on another.

    Args:
        trainer: DDPMTrainer instance
        checkpoint_path: Path to checkpoint file

    Returns:
        tuple: (epoch, global_step)
    """
    # Falls back to torch.load's default placement if the trainer has no device
    target_device = getattr(trainer, 'device', None)
    checkpoint = torch.load(checkpoint_path, map_location=target_device)
    trainer.model.load_state_dict(checkpoint['model_state_dict'])
    trainer.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch'], checkpoint['global_step']

def save_checkpoint(trainer, epoch, global_step):
    """Writes the model and optimizer state, plus progress counters, to disk.

    The file is placed in the module-level ``config.checkpoint_dir`` and named
    after the epoch and global step.

    Args:
        trainer: DDPMTrainer instance
        epoch: Current epoch
        global_step: Current global step
    """
    target_path = config.checkpoint_dir / f'model_epoch{epoch}_step{global_step}.pt'
    state = dict(
        model_state_dict=trainer.model.state_dict(),
        optimizer_state_dict=trainer.optimizer.state_dict(),
        epoch=epoch,
        global_step=global_step,
    )
    torch.save(state, target_path)

# Main Training Loop

def train_diffusion():
    """Main training function that handles the complete training pipeline.

    Builds the dataset, model and trainer from the module-level ``config``,
    resumes from the most recently created checkpoint if one exists, and then
    trains both translation directions (T1->T2 and T2->T1) on every batch.
    Every ``config.save_interval`` steps it samples from the model, records
    PSNR/SSIM, writes a visualisation, logs to TensorBoard and saves a
    checkpoint.
    """

    # Initialize tensorboard writer
    writer = SummaryWriter(config.log_dir)

    # try/finally guarantees the writer is flushed and closed even if
    # training is interrupted or raises.
    try:
        # Initialize dataset and dataloader
        dataset = MRIT1T2Dataset(
            t1_dir=config.t1_dir,
            t2_dir=config.t2_dir,
            slice_mode='middle',
            paired=True
        )

        dataloader = DataLoader(
            dataset,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=False
        )

        # Initialize model and trainer
        model = UNet(
            in_channels=config.in_channels,
            time_channels=config.time_channels,
            n_channels=config.n_channels
        )

        trainer = DDPMTrainer(
            model=model,
            n_timesteps=config.n_timesteps,
            beta_start=config.beta_start,
            beta_end=config.beta_end,
            lr=config.lr,
            device=config.device
        )

        # Initialize metrics tracker
        metrics = MetricsTracker(config.device)

        # Check for existing checkpoint
        start_epoch = 0
        global_step = 0
        if config.checkpoint_dir.exists():
            checkpoints = list(config.checkpoint_dir.glob('model_epoch*_step*.pt'))
            if checkpoints:
                latest_checkpoint = max(checkpoints, key=os.path.getctime)
                start_epoch, global_step = load_checkpoint(trainer, latest_checkpoint)
                # The checkpoint is written after the batch of `global_step`
                # has been trained, so continue with the following step.
                # Reusing the saved value would repeat that step number and
                # immediately re-run sampling and overwrite the loaded file.
                global_step += 1
                print(f"Resuming from epoch {start_epoch}, step {global_step}")

        # Training loop (a resumed run restarts the saved epoch from its first batch)
        for epoch in range(start_epoch, config.num_epochs):
            metrics.reset()
            epoch_pbar = tqdm(dataloader, desc=f"Epoch {epoch}")

            for batch_idx, batch in enumerate(epoch_pbar):
                # Move data to device
                t1 = batch['T1'].to(config.device)
                t2 = batch['T2'].to(config.device)

                # Train T1 -> T2
                loss_t1_t2 = trainer.train_one_batch(
                    x_0=t2,            # Target is T2
                    condition=t1,      # Condition on T1
                    context=t1         # Cross-attention sees T1
                )

                # Train T2 -> T1
                loss_t2_t1 = trainer.train_one_batch(
                    x_0=t1,            # Target is T1
                    condition=t2,      # Condition on T2
                    context=t2         # Cross-attention sees T2
                )

                # Update metrics with the mean loss of the two directions
                avg_loss = (loss_t1_t2 + loss_t2_t1) / 2
                metrics.update(None, None, avg_loss)

                # Update progress bar
                epoch_pbar.set_postfix({
                    'loss': f"{avg_loss:.4f}",
                    'step': global_step
                })

                # Checkpoint and visualization
                if global_step % config.save_interval == 0:
                    trainer.model.eval()
                    with torch.no_grad():
                        # Generate samples for the current batch in both directions
                        t2_gen = trainer.sample(
                            condition=t1,
                            context=t1,
                            shape=t1.shape
                        )
                        t1_gen = trainer.sample(
                            condition=t2,
                            context=t2,
                            shape=t2.shape
                        )

                        # Calculate metrics for generated images
                        metrics.update(t2_gen, t2)
                        metrics.update(t1_gen, t1)

                        # Visualize samples
                        visualize_samples(
                            t1, t2,
                            t1_gen, t2_gen,
                            epoch, global_step
                        )

                        # Log to tensorboard (values are running means for this epoch)
                        current_metrics = metrics.get_metrics()
                        writer.add_scalar('Loss/train', current_metrics['loss'], global_step)
                        writer.add_scalar('Metrics/PSNR', current_metrics['psnr'], global_step)
                        writer.add_scalar('Metrics/SSIM', current_metrics['ssim'], global_step)
                        writer.add_images('Samples/T1', t1, global_step)
                        writer.add_images('Samples/T2', t2, global_step)
                        writer.add_images('Samples/T1_generated', t1_gen, global_step)
                        writer.add_images('Samples/T2_generated', t2_gen, global_step)

                        # Save checkpoint
                        save_checkpoint(trainer, epoch, global_step)

                    trainer.model.train()

                global_step += 1

            # End of epoch
            epoch_metrics = metrics.get_metrics()
            print(f"\nEpoch {epoch} Summary:")
            print(f"Average Loss: {epoch_metrics['loss']:.4f}")
            print(f"Average PSNR: {epoch_metrics['psnr']:.2f}")
            print(f"Average SSIM: {epoch_metrics['ssim']:.4f}")
    finally:
        writer.close()


# Entry point: seed the random number generators, then run the training pipeline.
if __name__ == "__main__":
    # Set random seeds for reproducibility: PyTorch, NumPy's global RNG and
    # Python's `random` module (the dataset draws from `random`).
    torch.manual_seed(42)
    np.random.seed(42)
    random.seed(42)
    
    # Start training
    train_diffusion()

Found 577 paired T1/T2 datasets


Epoch 0:   0%|          | 0/577 [00:00<?, ?it/s]