**Installations and Imports**

In [1]:
# Install required packages quietly (suppress output)
!pip -q install imageio scikit-image pytorch-msssim lpips diffusers matplotlib --no-input >/dev/null

# Import necessary libraries
import os, math, random, glob, json, re
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw, ImageFont
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split, Subset
from torchvision import transforms
import cv2
from skimage.metrics import structural_similarity as skimage_ssim
from pytorch_msssim import ms_ssim as torch_ms_ssim
import lpips
from diffusers import UNet2DModel
import matplotlib.pyplot as plt

2025-10-02 16:19:10.311346: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1759421950.510253      36 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1759421950.568648      36 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


**Configuration Settings**

In [2]:
# Configuration class to store all hyperparameters and settings
class CFG:
    SEED = 42  # Random seed for reproducibility
    IMG_SIZE = 256  # Size of input images (256x256)
    BATCH = 8  # Batch size for training
    EPOCHS = 10  # Number of training epochs (reduced for GPU limits)
    LR = 1e-4  # Learning rate
    MAX_SAMPLES = None  # Maximum number of samples to use (None for full dataset)
    NUM_WORKERS = 2  # Number of workers for data loading
    TARGET_AGE = 216  # Target age in months (18 years)
    T_STEPS = 1000  # Number of timesteps for diffusion
    DDIM_STEPS = 30  # Number of steps for DDIM sampling
    CFG_DROP = 0.1  # Dropout rate for classifier-free guidance
    CFG_WEIGHT = 2.0  # Weight for classifier-free guidance
    AMP = True  # Use automatic mixed precision
    ACCUM_STEPS = 2  # Gradient accumulation steps
    CLIP_NORM = 1.0  # Gradient clipping norm
    SAVE_DIR = "/kaggle/working"  # Directory to save models
    OUT_DIR = "/kaggle/working/out_216"  # Directory to save outputs
    COMPILE = False  # Disable model compilation (avoid Dynamo issues)
    DATASET_HINTS = [  # Possible paths for the dataset
        "/kaggle/input/rsna-bone-age",
        "/kaggle/input/rsna-bone-age/rsna-bone-age",
        "/kaggle/input"
    ]
    MEAS_XLSX = "/kaggle/working/measurements_export.xlsx"  # Path to save Excel measurements
    AREA_LAMBDA = 3.0  # Weight for area loss term
    SOFT_K = 7.0  # Softmax scaling factor for area calculation
    AREA_MIN = 0.02  # Minimum area fraction
    AREA_MAX = 0.75  # Maximum area fraction
    ATTN_HEADS = 4  # Number of attention heads
    ATTN_DIM = 32  # Dimension per attention head

# Set random seed for reproducibility
def seed_everything(seed: int = 42):
    random.seed(seed)  # Set Python random seed
    np.random.seed(seed)  # Set NumPy random seed
    torch.manual_seed(seed)  # Set PyTorch random seed
    torch.cuda.manual_seed_all(seed)  # Set PyTorch CUDA random seed
    torch.backends.cudnn.deterministic = True  # Ensure deterministic behavior
    torch.backends.cudnn.benchmark = False  # Disable benchmarking for reproducibility

# Apply seeding
seed_everything(CFG.SEED)

# Set device (GPU if available, else CPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

**Data Discovery Functions**

In [3]:
# Function to find the root directory of the RSNA Bone Age dataset
def find_rsna_boneage_root() -> Tuple[str, str]:
    candidates = []  # List to store potential dataset locations
    
    # Search through possible dataset paths
    for base in CFG.DATASET_HINTS:
        if not os.path.exists(base):
            continue  # Skip if path doesn't exist
            
        # Look for the CSV file
        for csv_path in glob.glob(os.path.join(base, "**", "boneage-training-dataset.csv"), recursive=True):
            root = os.path.dirname(csv_path)  # Get directory containing CSV
            preferred = os.path.join(root, "boneage-training-dataset")  # Preferred image directory
            
            # Create list of potential image directories
            dirs = [preferred] if os.path.isdir(preferred) else []
            dirs += [d for d in glob.glob(os.path.join(root, "**"), recursive=True) if os.path.isdir(d)]
            
            # Find directory with most images
            best_dir, best_cnt = "", -1
            for d in dirs:
                cnt = (len(glob.glob(os.path.join(d, "*.png"))) +  # Count PNG files
                       len(glob.glob(os.path.join(d, "*.jpg"))) +  # Count JPG files
                       len(glob.glob(os.path.join(d, "*.jpeg"))))  # Count JPEG files
                if cnt > best_cnt:
                    best_dir, best_cnt = d, cnt
            
            candidates.append((csv_path, best_dir, best_cnt))
    
    # Handle case where no dataset is found
    if not candidates:
        raise FileNotFoundError("Attach Kaggle dataset 'kmader/rsna-bone-age' via Add data.")
    
    # Select the candidate with the most images
    csv_path, img_dir, cnt = sorted(candidates, key=lambda x: -x[2])[0]
    print(f"[data] CSV: {csv_path}")
    print(f"[data] IMG_DIR: {img_dir} (files: {cnt})")
    return csv_path, img_dir

# Regular expression to extract ID from image filenames
_ID_RE = re.compile(r"(\d+)\.(png|jpg|jpeg)$", re.IGNORECASE)

# Function to index images in the directory
def index_image_dir(img_dir: str) -> Dict[int, str]:
    idx: Dict[int, str] = {}  # Dictionary to map image IDs to paths
    
    # Recursively search for image files
    for path in glob.glob(os.path.join(img_dir, "**", "*.*"), recursive=True):
        if not os.path.isfile(path): 
            continue  # Skip if not a file
            
        # Extract ID from filename
        m = _ID_RE.search(os.path.basename(path))
        if not m: 
            continue  # Skip if no match
            
        try:
            img_id = int(m.group(1))  # Convert ID to integer
            # Update path if this is the shortest path (most direct)
            if img_id not in idx or path.count(os.sep) < idx[img_id].count(os.sep):
                idx[img_id] = path
        except Exception:
            continue  # Skip if ID conversion fails
    
    print(f"[data] Indexed {len(idx)} images.")
    return idx

**Image Processing Utilities**

In [4]:
# Function to load grayscale PNG/JPG images and resize if needed
def load_gray_png_any(path: str, size: int = CFG.IMG_SIZE) -> Image.Image:
    im = Image.open(path).convert("L")  # Open image and convert to grayscale
    if im.size != (size, size):
        im = im.resize((size, size), Image.LANCZOS)  # Resize if not target size
    return im

# Function to convert float array [0,1] to uint8 [0,255]
def to_uint8(x: np.ndarray) -> np.ndarray:
    x = np.clip(x, 0.0, 1.0)  # Clip values to [0,1]
    return (x * 255.0 + 0.5).astype(np.uint8)  # Scale and convert to uint8

# Function to convert uint8 [0,255] to float [0,1]
def from_uint8(x: np.ndarray) -> np.ndarray:
    return x.astype(np.float32) / 255.0  # Scale to [0,1]

# Function to convert PIL image to PyTorch tensor (normalized to [-1,1])
def pil_to_tensor_gray(im: Image.Image) -> torch.Tensor:
    arr = np.array(im, dtype=np.float32) / 255.0  # Convert to numpy and normalize to [0,1]
    arr = (arr - 0.5) / 0.5  # Normalize to [-1,1]
    return torch.from_numpy(arr)[None, None, :, :].contiguous()  # Convert to tensor and add batch/channel dims

# Function to convert PyTorch tensor to PIL image
def tensor_to_pil_gray(t: torch.Tensor) -> Image.Image:
    t = t[0, 0].detach().cpu().clamp(-1, 1).numpy()  # Remove batch/channel, clamp to [-1,1], convert to numpy
    t = (t + 1) * 0.5  # Normalize to [0,1]
    return Image.fromarray(to_uint8(t))  # Convert to PIL image

**Robust Measurements and Mask Generation**

In [5]:
# Function to find the largest connected component in a binary image
def _largest_cc(bin_img: np.ndarray) -> np.ndarray:
    n, labels = cv2.connectedComponents(bin_img.astype(np.uint8))  # Get connected components
    if n <= 1:
        return bin_img  # Return original if only one component (background)
    
    # Find area of each component
    areas = [(labels == i).sum() for i in range(1, n)]
    idx = int(np.argmax(areas)) + 1  # Index of largest component
    return (labels == idx).astype(np.uint8)  # Return mask of largest component

# Function to calculate major span using PCA
def _pca_major_span(coords_full: np.ndarray) -> float:
    if coords_full.shape[0] < 5:
        return 0.0  # Return 0 if too few points
    
    x = coords_full[:, [1,0]].astype(np.float32)  # Swap x,y for PCA
    x -= x.mean(0, keepdims=True)  # Center the data
    cov = (x.T @ x) / max(1, len(x)-1)  # Compute covariance matrix
    vals, vecs = np.linalg.eigh(cov)  # Eigen decomposition
    v = vecs[:, np.argmax(vals)]  # Principal component
    proj = x @ v  # Project data onto principal component
    return float(proj.max() - proj.min())  # Return span

# Function to compute robust measurements and generate mask
def robust_measurements_and_mask(img_u8: np.ndarray) -> Tuple[dict, np.ndarray]:
    try:
        h, w = img_u8.shape  # Get image dimensions
        b = max(1, int(0.03 * min(h, w)))  # Border size (3% of min dimension)
        roi = img_u8[b:h-b, b:w-b]  # Define region of interest (excluding border)
        blur = cv2.GaussianBlur(roi, (5,5), 0)  # Apply Gaussian blur
        thr, _ = cv2.threshold(blur, 0, 255, cv2.THRESH_OTSU)  # Otsu thresholding
        m = (blur > thr).astype(np.uint8)  # Create binary mask
        
        # Define morphological kernels
        k3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3,3))
        k5 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
        
        # Apply morphological operations
        m = cv2.morphologyEx(m, cv2.MORPH_OPEN, k3)  # Opening to remove small noise
        m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, k5)  # Closing to fill small holes
        m = _largest_cc(m)  # Keep only largest connected component
        
        # Create full-size mask
        mask_full = np.zeros_like(img_u8, dtype=np.uint8)
        mask_full[b:h-b, b:w-b] = m
        
        # Get coordinates of mask pixels
        ys, xs = np.where(mask_full > 0)
        area = float(mask_full.sum())  # Calculate area
        
        # Handle case where mask is empty
        if ys.size == 0:
            dens = float(img_u8.mean())
            return dict(area=0.0, height=0.0, width=0.0, aspect=0.0, density=dens, width_length_ratio=0.0), mask_full
        
        # Calculate bounding box
        ymin, ymax = ys.min(), ys.max()
        xmin, xmax = xs.min(), xs.max()
        bbox_h = float((ymax - ymin + 1))
        bbox_w = float((xmax - xmin + 1))
        span = _pca_major_span(np.stack([ys, xs], 1))  # Calculate PCA span
        
        # Calculate height and width
        height = max(bbox_h, span)
        width = bbox_w
        aspect = height / max(1.0, width)
        dens = float(img_u8[mask_full > 0].mean())  # Mean density in mask
        frac = area / float(h * w)  # Fraction of image covered by mask
        
        # Handle case where mask covers most of image
        if frac > 0.85 and height >= (h * 0.98):
            p = np.percentile(roi, 85)  # Get 85th percentile
            m2 = (roi > p).astype(np.uint8)  # Create new mask with higher threshold
            m2 = cv2.morphologyEx(m2, cv2.MORPH_OPEN, k3)  # Apply opening
            m2 = _largest_cc(m2)  # Keep largest component
            
            # Update full-size mask
            mask_full = np.zeros_like(img_u8, dtype=np.uint8)
            mask_full[b:h-b, b:w-b] = m2
            
            # Recalculate measurements with new mask
            ys, xs = np.where(mask_full > 0)
            if ys.size >= 5:
                ymin, ymax = ys.min(), ys.max()
                xmin, xmax = xs.min(), xs.max()
                bbox_h = float((ymax - ymin + 1))
                bbox_w = float((xmax - xmin + 1))
                span = _pca_major_span(np.stack([ys, xs], 1))
                height = max(bbox_h, span)
                width = bbox_w
                aspect = height / max(1.0, width)
                dens = float(img_u8[mask_full > 0].mean())
                area = float(mask_full.sum())
                frac = area / float(h * w)
        
        # Return measurements and mask
        return dict(
            area=area,
            height=float(np.clip(height, 0, h)),
            width=float(np.clip(width, 0, w)),
            aspect=float(aspect),
            density=dens,
            width_length_ratio=float(width / max(1.0, height)),
            area_frac=frac
        ), mask_full
    except Exception as e:
        print(f"Warning: Mask generation failed: {e}")
        return dict(area=0.0, height=0.0, width=0.0, aspect=0.0, density=0.0, width_length_ratio=0.0), np.zeros_like(img_u8)

# Function to compute robust measurements only (without mask)
def robust_measurements(img_u8: np.ndarray) -> dict:
    d, _ = robust_measurements_and_mask(img_u8)
    return d

# Function to generate pediatric bone mask
def make_pediatric_mask(img_u8: np.ndarray) -> np.ndarray:
    _, m = robust_measurements_and_mask(img_u8)
    return m

**Clinical Metrics Calculation**

In [6]:
# Function to compute clinical metrics from bone image and mask
def compute_clinical_metrics(img_u8: np.ndarray, mask: np.ndarray) -> Dict[str, float]:
    # Detect edges using Canny edge detector
    edges = cv2.Canny(img_u8, 100, 200)
    mask_bool = mask > 0  # Convert mask to boolean
    
    # Calculate mean edge value within bone mask (cortical thickness)
    edge_in_bone = edges[mask_bool].mean() if mask_bool.sum() > 0 else 0.0
    cortical_thickness = float(edge_in_bone / 255.0)
    
    # Define morphological kernel
    k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5,5))
    
    # Calculate epiphyseal plate width by erosion
    eroded = cv2.erode(mask, k)
    plate_region = mask - eroded
    plate_width = float(plate_region.sum() / mask.sum()) if mask.sum() > 0 else 0.0
    
    # Return clinical metrics
    return {
        "cortical_thickness": cortical_thickness,
        "epiphyseal_plate_width": plate_width
    }

**Greulich-Pyle Bone Age Estimation**

In [7]:
# Function to estimate bone age using Greulich-Pyle method
def estimate_gp_bone_age(img_u8: np.ndarray, mask: np.ndarray, sex: int) -> float:
    try:
        # Compute clinical metrics
        clinical = compute_clinical_metrics(img_u8, mask)
        cortical_thickness = clinical["cortical_thickness"]
        epiphyseal_plate_width = clinical["epiphyseal_plate_width"]
        
        # Set target age based on sex
        target_age_m = 216 if sex == 1 else 204  # 18 years for males, 17 years for females
        
        # Define normal values for 18-year-old
        cortical_norm = 0.7
        plate_norm = 0.05
        
        # Calculate scores based on deviation from normal
        cortical_score = 1.0 - abs(cortical_thickness - cortical_norm) / cortical_norm
        plate_score = 1.0 - abs(epiphyseal_plate_width - plate_norm) / max(plate_norm, 1e-3)
        
        # Combine scores (weighted average)
        combined_score = 0.6 * cortical_score + 0.4 * plate_score
        
        # Map score to age estimate
        min_age, max_age = 144, 228  # 12 to 19 years in months
        estimated_age = min_age + (max_age - min_age) * combined_score
        estimated_age = max(min_age, min(max_age, estimated_age))
        
        # Adjust for sex differences
        if sex == 0:  # Female
            estimated_age = min(estimated_age, 204)
            
        return round(float(estimated_age), 1)  # Return rounded estimate
    except Exception as e:
        print(f"Warning: GP bone age estimation failed: {e}")
        return float(target_age_m)  # Return target age if estimation fails

**Growth Schedule and Morphological Transformations**

In [8]:
# Dataclass to store growth schedule parameters
@dataclass
class GrowthSchedule:
    length_gain: float  # Proportional increase in length
    width_gain: float   # Proportional increase in width
    density_gain: float # Proportional increase in density
    closure: float      # Degree of growth plate closure

# Function to calculate growth schedule parameters based on age and sex
def schedule_params_for_age(age_m: int, sex: int, base: Dict[str, float]) -> GrowthSchedule:
    a = max(0, min(216, int(age_m)))  # Clamp age to [0, 216] months
    
    # Set peak growth time based on sex
    t_peak = 144 if sex == 0 else 164  # 12 years for females, 13.67 years for males
    
    # Calculate length and width gains using logistic functions
    k_len, k_wid = 0.035, 0.028  # Growth rate constants
    length_gain = 0.35 * (1 / (1 + math.exp(-k_len * (a - t_peak))))
    width_gain = 0.18 * (1 / (1 + math.exp(-k_wid * (a - (t_peak - 10)))))
    
    # Calculate density gain using logistic function with sex-specific parameters
    if sex == 0:  # Female parameters
        L, x0, k, b = 0.564, 12.02, 0.591, 0.540
    else:  # Male parameters
        L, x0, k, b = 0.633, 13.65, 0.453, 0.514
    
    age_y = a / 12.0  # Convert to years
    dens = b + L / (1 + math.exp(-k * (age_y - x0)))  # Logistic function
    density_gain = 0.25 * (dens - b) / L  # Scale density gain
    
    # Calculate growth plate closure using logistic function
    t_close_mid = 198 if sex == 0 else 222  # Midpoint of closure
    k_close = 0.06  # Closure rate
    closure = 1 / (1 + math.exp(-k_close * (a - t_close_mid)))
    
    # Add random variation to parameters
    length_gain += random.normalvariate(0, 0.05)
    length_gain = max(0.0, min(1.0, length_gain))
    width_gain += random.normalvariate(0, 0.03)
    width_gain = max(0.0, min(1.0, width_gain))
    density_gain += random.normalvariate(0, 0.02)
    density_gain = max(0.0, min(0.25, density_gain))
    closure += random.normalvariate(0, 0.05)
    closure = max(0.0, min(1.0, closure))
    
    return GrowthSchedule(length_gain, width_gain, density_gain, closure)

# Function to morph pediatric bone image to adult-like appearance
def morph_adultize_at_age(base_img: Image.Image, sex: int, age_m: int) -> Image.Image:
    # Convert to numpy array
    u8 = np.array(base_img, dtype=np.uint8)
    
    # Get measurements and growth schedule
    meas = robust_measurements(u8)
    sch = schedule_params_for_age(age_m, sex, meas)
    
    # Convert to float [0,1]
    f = from_uint8(u8)
    
    # Calculate scaling factors
    sy = 1.0 + sch.length_gain  # Height scaling
    sx = 1.0 + sch.width_gain * 0.6  # Width scaling (less than height)
    
    # Apply affine transformation for scaling
    h, w = f.shape
    M = np.array([[sx, 0, (1 - sx) * w / 2],  # Scaling matrix
                  [0, sy, (1 - sy) * h / 2]], dtype=np.float32)
    scaled = cv2.warpAffine(f, M, (w, h), flags=cv2.INTER_LINEAR, borderMode=cv2.BORDER_REFLECT)
    
    # Apply morphological closing to simulate growth plate closure
    k = max(1, int(1 + 6 * sch.closure))
    closed = cv2.morphologyEx(scaled, cv2.MORPH_CLOSE, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k)))
    
    # Enhance cortical bone using Laplacian edges
    edges = cv2.Laplacian(closed, cv2.CV_32F, ksize=3)
    cortical = np.clip(closed + 0.15 * (sch.density_gain + sch.closure * 0.3) * edges, 0.0, 1.0)
    
    # Increase overall bone density
    dense = np.clip(cortical * (1.0 + 0.35 * sch.density_gain), 0.0, 1.0)
    
    # Convert back to uint8 and apply bilateral filter for smoothing
    u8a = to_uint8(dense)
    u8a = cv2.bilateralFilter(u8a, 5, 30, 30)
    
    return Image.fromarray(u8a)  # Return as PIL image

# Function to build parameter vector for conditioning
def build_param_vec(age_m: int, sex: int, base: Dict[str, float]) -> torch.Tensor:
    # Get growth schedule
    sch = schedule_params_for_age(age_m, sex, base)
    
    # Normalize base measurements
    bl = base.get("height", 0.0) / CFG.IMG_SIZE
    bw = base.get("width", 0.0) / CFG.IMG_SIZE
    ar = base.get("aspect", 0.0)
    area_frac = base.get("area_frac", base.get("area", 0.0) / (CFG.IMG_SIZE * CFG.IMG_SIZE))
    dens = (base.get("density", 128.0) / 255.0)
    
    # Create parameter vector
    vec = [
        sch.length_gain, sch.width_gain, sch.density_gain, sch.closure,  # Growth parameters
        bl, bw, ar, area_frac, dens,  # Normalized measurements
        float(sex), age_m / 216.0,  # Sex and normalized age
        0.0, 0.0, 0.0  # Padding
    ]
    
    return torch.tensor(vec[:14], dtype=torch.float32)  # Return as tensor

**Dataset Class**

In [9]:
# Dataset class for bone age conditional diffusion
class BoneAgeCondDiffusionDS(Dataset):
    def __init__(self, csv_path: str, img_dir: str, img_size: int = CFG.IMG_SIZE, max_samples: Optional[int] = CFG.MAX_SAMPLES):
        # Load CSV data
        self.df = pd.read_csv(csv_path).copy()
        self.df["male"] = self.df["male"].astype(int)  # Convert male column to int
        
        # Index image directory
        self.idx = index_image_dir(img_dir)
        
        # Filter dataframe to only include rows with images
        keep = self.df["id"].astype(int).isin(self.idx.keys())
        self.df = self.df[keep].reset_index(drop=True)
        
        # Add image paths to dataframe
        self.df["img_path"] = self.df["id"].astype(int).map(self.idx.get)
        
        # Limit number of samples if specified
        if max_samples is not None and len(self.df) > max_samples:
            self.df = self.df.sample(max_samples, random_state=CFG.SEED).reset_index(drop=True)
        
        self.size = img_size  # Image size
        self.training = False  # Training mode flag
        
        # Define augmentation transforms
        self.augment = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),  # Random horizontal flip
            transforms.RandomRotation(degrees=10),   # Random rotation
            transforms.RandomAffine(degrees=0, scale=(0.9, 1.1), translate=(0.05, 0.05)),  # Random scale and translation
            transforms.ColorJitter(brightness=0.2, contrast=0.2)  # Random brightness and contrast
        ])
        
        # Check if dataframe is empty
        if len(self.df) == 0:
            raise RuntimeError("No images matched IDs.")
    
    def __len__(self): 
        return len(self.df)  # Return number of samples
    
    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor | int]:
        # Get row data
        row = self.df.iloc[idx]
        sex = int(row["male"])
        path = row["img_path"]
        
        # Load and process base image
        base_im = load_gray_png_any(path, self.size)
        if self.training:
            base_im = self.augment(base_im)  # Apply augmentations if in training mode
        
        # Convert to numpy array
        base_u8 = np.array(base_im, dtype=np.uint8)
        
        # Get measurements and mask
        base_meas, mask_full = robust_measurements_and_mask(base_u8)
        
        # Generate target image (adult-like)
        tgt_im = morph_adultize_at_age(base_im, sex, CFG.TARGET_AGE)
        if self.training:
            tgt_im = self.augment(tgt_im)  # Apply augmentations if in training mode
        
        # Convert images to tensors
        pedi_t = pil_to_tensor_gray(base_im).squeeze(0)  # Pediatric image tensor
        mask_t = torch.from_numpy((mask_full.astype(np.float32)/255.0))[None, :, :]  # Mask tensor
        mask_t = mask_t * 2.0 - 1.0  # Normalize to [-1,1]
        adult_t = pil_to_tensor_gray(tgt_im).squeeze(0)  # Adult image tensor
        
        # Build parameter vector
        params = build_param_vec(CFG.TARGET_AGE, sex, base_meas)
        
        # Return sample dictionary
        return {
            "pediatric_img": pedi_t,
            "pediatric_mask": mask_t,
            "adult_img": adult_t,
            "params": params,
            "sex": sex,
            "id": int(row["id"])
        }
    
    def train(self):
        self.training = True  # Enable training mode
    
    def eval(self):
        self.training = False  # Disable training mode

**Model Architecture with Attention**

In [10]:
# Cross-Attention module for incorporating condition information
class CrossAttention(nn.Module):
    def __init__(self, in_ch: int, cdim: int, heads: int = CFG.ATTN_HEADS, dim_head: int = CFG.ATTN_DIM):
        super().__init__()
        self.heads = heads  # Number of attention heads
        self.scale = dim_head ** -0.5  # Scaling factor for attention
        
        # Linear layers for queries, keys, and values
        self.to_q = nn.Linear(in_ch, heads * dim_head, bias=False)
        self.to_k = nn.Linear(cdim, heads * dim_head, bias=False)
        self.to_v = nn.Linear(cdim, heads * dim_head, bias=False)
        
        # Output projection
        self.to_out = nn.Linear(heads * dim_head, in_ch)
        
        # Normalization layer
        self.norm = nn.GroupNorm(8, in_ch)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape  # Get dimensions
        
        # Normalize input
        x_norm = self.norm(x)
        
        # Flatten spatial dimensions
        x_flat = x_norm.permute(0, 2, 3, 1).reshape(B, H*W, C)  # [B, HW, C]
        
        # Compute queries, keys, and values
        q = self.to_q(x_flat).view(B, H*W, self.heads, -1).permute(0, 2, 1, 3)  # [B, heads, HW, dim_head]
        k = self.to_k(context).view(B, 1, self.heads, -1).permute(0, 2, 1, 3)   # [B, heads, 1, dim_head]
        v = self.to_v(context).view(B, 1, self.heads, -1).permute(0, 2, 1, 3)   # [B, heads, 1, dim_head]
        
        # Compute attention scores
        attn = (q @ k.transpose(-2, -1)) * self.scale  # [B, heads, HW, 1]
        attn = F.softmax(attn, dim=-1)  # Normalize attention scores
        
        # Apply attention to values
        out = (attn @ v).permute(0, 2, 1, 3).reshape(B, H*W, -1)  # [B, HW, heads*dim_head]
        
        # Project output and reshape
        out = self.to_out(out).reshape(B, C, H, W)  # [B, C, H, W]
        
        return x + out  # Residual connection

# Sinusoidal time embedding
class SinusoidalTime(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim  # Embedding dimension
    
    def forward(self, t: torch.Tensor) -> torch.Tensor:
        half = self.dim // 2  # Half dimension for sin and cos
        
        # Compute frequencies
        freqs = torch.exp(torch.arange(half, device=t.device) * -(math.log(10000.0) / (half - 1)))
        
        # Compute arguments for sin and cos
        args = t[:, None] * freqs[None, :]
        
        # Compute sin and cos embeddings
        emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        
        # Pad if dimension is odd
        if self.dim % 2 == 1: 
            emb = F.pad(emb, (0, 1))
            
        return emb

# FiLM (Feature-wise Linear Modulation) layer
class FiLM(nn.Module):
    def __init__(self, in_dim: int, ch: int):
        super().__init__()
        # Network to compute scale and shift
        self.net = nn.Sequential(nn.Linear(in_dim, 128), nn.SiLU(), nn.Linear(128, 2 * ch))
    
    def forward(self, x: torch.Tensor, c: torch.Tensor):
        # Compute scale and shift
        gb = self.net(c)
        g, b = gb.chunk(2, dim=-1)
        
        # Reshape for broadcasting
        g = g.unsqueeze(-1).unsqueeze(-1)
        b = b.unsqueeze(-1).unsqueeze(-1)
        
        # Apply FiLM transformation
        return x * (1 + g) + b

# Residual block with time and condition embeddings
class ResBlock(nn.Module):
    def __init__(self, in_ch: int, out_ch: int, tdim: int, cdim: int):
        super().__init__()
        # Normalization and activation
        self.norm1 = nn.GroupNorm(8, in_ch)
        self.act1 = nn.SiLU()
        
        # First convolution
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, padding=1)
        
        # Time embedding projection
        self.time = nn.Sequential(nn.SiLU(), nn.Linear(tdim, out_ch))
        
        # Condition embedding projection (FiLM)
        self.cond = FiLM(cdim, out_ch)
        
        # Second normalization and activation
        self.norm2 = nn.GroupNorm(8, out_ch)
        self.act2 = nn.SiLU()
        
        # Second convolution
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, padding=1)
        
        # Skip connection if channels change
        self.skip = nn.Conv2d(in_ch, out_ch, 1) if in_ch != out_ch else nn.Identity()
    
    def forward(self, x, t, c):
        # First block
        h = self.conv1(self.act1(self.norm1(x)))
        
        # Add time embedding
        h = h + self.time(t)[:, :, None, None]
        
        # Apply condition (FiLM)
        h = self.cond(h, c)
        
        # Second block
        h = self.conv2(self.act2(self.norm2(h)))
        
        # Skip connection
        return h + self.skip(x)

# Downsampling block
class Down(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, cdim):
        super().__init__()
        # Two residual blocks
        self.b1 = ResBlock(in_ch, out_ch, tdim, cdim)
        self.b2 = ResBlock(out_ch, out_ch, tdim, cdim)
        
        # Downsampling convolution
        self.pool = nn.Conv2d(out_ch, out_ch, 3, stride=2, padding=1)
    
    def forward(self, x, t, c):
        # Apply residual blocks
        x = self.b1(x, t, c)
        x = self.b2(x, t, c)
        
        # Save skip connection
        skip = x
        
        # Downsample
        x = self.pool(x)
        
        return x, skip

# Upsampling block
class Up(nn.Module):
    def __init__(self, in_ch, skip_ch, out_ch, tdim, cdim, use_attn=False):
        super().__init__()
        # Upsampling convolution
        self.up = nn.ConvTranspose2d(in_ch, out_ch, 2, stride=2)
        
        # Two residual blocks
        self.b1 = ResBlock(out_ch + skip_ch, out_ch, tdim, cdim)
        self.b2 = ResBlock(out_ch, out_ch, tdim, cdim)
        
        # Optional attention
        self.attn = CrossAttention(out_ch, cdim, CFG.ATTN_HEADS, CFG.ATTN_DIM) if use_attn else None
    
    def forward(self, x, skip, t, c):
        # Upsample
        x = self.up(x)
        
        # Concatenate with skip connection
        x = torch.cat([x, skip], dim=1)
        
        # Apply residual blocks
        x = self.b1(x, t, c)
        x = self.b2(x, t, c)
        
        # Apply attention if enabled
        if self.attn is not None:
            x = self.attn(x, c)
        
        return x

# Conditional U-Net model
class CondUNet(nn.Module):
    def __init__(self, cvec_dim=14, base_ch=64):  # Reduced base channels for GPU limits
        super().__init__()
        
        # Time embedding dimension
        tdim = base_ch * 4
        
        # Time embedding network
        self.time = nn.Sequential(SinusoidalTime(tdim), nn.Linear(tdim, tdim), nn.SiLU(), nn.Linear(tdim, tdim))
        
        # Condition vector projection
        self.cproj = nn.Sequential(nn.Linear(cvec_dim, 128), nn.SiLU(), nn.Linear(128, 128))
        cdim = 128  # Condition dimension after projection
        
        # Input convolution
        self.inp = nn.Conv2d(3, base_ch, 3, padding=1)  # 3 channels: noisy + pediatric + mask
        
        # Downsampling blocks
        self.d1 = Down(base_ch, base_ch, tdim, cdim)
        self.d2 = Down(base_ch, base_ch*2, tdim, cdim)
        self.d3 = Down(base_ch*2, base_ch*4, tdim, cdim)
        self.d4 = Down(base_ch*4, base_ch*4, tdim, cdim)
        
        # Middle blocks
        self.mid1 = ResBlock(base_ch*4, base_ch*4, tdim, cdim)
        self.mid_attn = CrossAttention(base_ch*4, cdim, CFG.ATTN_HEADS, CFG.ATTN_DIM)
        self.mid2 = ResBlock(base_ch*4, base_ch*4, tdim, cdim)
        
        # Upsampling blocks
        self.u4 = Up(base_ch*4, base_ch*4, base_ch*4, tdim, cdim, use_attn=True)
        self.u3 = Up(base_ch*4, base_ch*4, base_ch*2, tdim, cdim, use_attn=True)
        self.u2 = Up(base_ch*2, base_ch*2, base_ch, tdim, cdim)
        self.u1 = Up(base_ch, base_ch, base_ch, tdim, cdim)
        
        # Output normalization and convolution
        self.out_norm = nn.GroupNorm(8, base_ch)
        self.out = nn.Conv2d(base_ch, 1, 3, padding=1)  # Single channel output

    def forward(self, x_noisy: torch.Tensor, cond_2ch: torch.Tensor, t_idx: torch.Tensor, cvec: torch.Tensor):
        # Compute time and condition embeddings
        t = self.time(t_idx.float())
        c = self.cproj(cvec)
        
        # Concatenate noisy image with condition (pediatric + mask)
        x = torch.cat([x_noisy, cond_2ch], dim=1)
        
        # Apply input convolution
        x = self.inp(x)
        
        # Apply downsampling blocks
        x, s1 = self.d1(x, t, c)
        x, s2 = self.d2(x, t, c)
        x, s3 = self.d3(x, t, c)
        x, s4 = self.d4(x, t, c)
        
        # Apply middle blocks
        x = self.mid1(x, t, c)
        x = self.mid_attn(x, c)
        x = self.mid2(x, t, c)
        
        # Apply upsampling blocks
        x = self.u4(x, s4, t, c)
        x = self.u3(x, s3, t, c)
        x = self.u2(x, s2, t, c)
        x = self.u1(x, s1, t, c)
        
        # Apply output normalization and activation
        x = F.silu(self.out_norm(x))
        
        # Return output
        return self.out(x)

# Function to load pretrained U-Net weights
def load_pretrained_unet(net: CondUNet, pretrained_model="google/ddpm-cifar10-32"):
    try:
        # Load pretrained model
        pretrained = UNet2DModel.from_pretrained(pretrained_model).to(DEVICE)
        state_dict = pretrained.state_dict()
        
        # Get model state dict
        net_dict = net.state_dict()
        
        # Copy matching weights
        for k, v in state_dict.items():
            if k in net_dict and v.shape == net_dict[k].shape:
                net_dict[k] = v
        
        # Load weights
        net.load_state_dict(net_dict, strict=False)
        print(f"Loaded pretrained weights from {pretrained_model}")
    except Exception as e:
        print(f"Failed to load pretrained weights: {e}")
    
    return net

**Diffusion Scheduler**

In [11]:
# Dataclass for diffusion scheduler parameters
@dataclass
class DiffSched:
    betas: torch.Tensor      # Noise schedule
    alphas: torch.Tensor      # 1 - betas
    ac: torch.Tensor         # Cumulative product of alphas
    sqrt_ac: torch.Tensor    # Square root of ac
    sqrt_om: torch.Tensor    # Square root of 1 - ac

# Function to create diffusion scheduler
def make_scheduler(T: int = CFG.T_STEPS, beta_start=1e-4, beta_end=0.02) -> DiffSched:
    # Linear noise schedule
    betas = torch.linspace(beta_start, beta_end, T, dtype=torch.float32, device=DEVICE)
    
    # Compute alphas and cumulative products
    alphas = 1.0 - betas
    ac = torch.cumprod(alphas, dim=0)
    
    # Return scheduler
    return DiffSched(
        betas=betas, 
        alphas=alphas, 
        ac=ac,
        sqrt_ac=torch.sqrt(ac), 
        sqrt_om=torch.sqrt(1.0 - ac),
    )

# Create global scheduler
SCHED = make_scheduler()

**Loss Function**

In [12]:
# Function to compute soft area fraction
def _soft_area_frac(x0_hat: torch.Tensor, k: float = CFG.SOFT_K) -> torch.Tensor:
    s = torch.sigmoid(k * x0_hat)  # Apply sigmoid with scaling
    return s.mean(dim=[1,2,3])  # Mean over spatial dimensions

# Function to compute expected area fraction from condition vector
def _expected_area_frac(cvec: torch.Tensor) -> torch.Tensor:
    # Extract gains from condition vector
    len_gain = cvec[:, 0]
    wid_gain = cvec[:, 1]
    base_area_frac = cvec[:, 7].clamp(0.0, 1.0)
    
    # Compute expected area with scaling
    scale = (1.0 + len_gain).clamp(0.5, 2.0) * (1.0 + 0.6 * wid_gain).clamp(0.5, 2.0)
    exp = (base_area_frac * scale).clamp(CFG.AREA_MIN, CFG.AREA_MAX)
    
    return exp

# DDPM loss function
def ddpm_loss_step(net: CondUNet, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    # Get batch data
    x0 = batch["adult_img"].to(DEVICE)  # Target image
    pedi = batch["pediatric_img"].to(DEVICE)  # Pediatric image
    mask = batch["pediatric_mask"].to(DEVICE)  # Mask
    cond = torch.cat([pedi, mask], dim=1)  # Concatenate condition
    cvec = batch["params"].to(DEVICE)  # Parameter vector
    
    B = x0.size(0)  # Batch size
    
    # Sample random timesteps
    t = torch.randint(0, CFG.T_STEPS, (B,), device=DEVICE)
    
    # Sample noise
    eps = torch.randn_like(x0)
    
    # Get scheduler values for timesteps
    sqrt_ac = SCHED.sqrt_ac[t].view(B,1,1,1)
    sqrt_om = SCHED.sqrt_om[t].view(B,1,1,1)
    
    # Forward diffusion: add noise to image
    x_t = sqrt_ac * x0 + sqrt_om * eps
    
    # Classifier-free guidance: randomly drop condition
    drop = (torch.rand(B, device=DEVICE) < CFG.CFG_DROP).float().view(B,1,1,1)
    cond_cf = (1 - drop) * cond
    cvec_cf = (1 - drop[:,0,0,0])[:,None] * cvec
    
    # Predict noise
    eps_pred = net(x_t, cond_cf, t, cvec_cf)
    
    # Compute simple MSE loss
    l_simple = F.mse_loss(eps_pred, eps)
    
    # Compute area loss
    with torch.no_grad():
        ac_t = SCHED.ac[t].view(B,1,1,1)
    
    # Predict x0 from noise
    x0_hat = (x_t - torch.sqrt(1 - ac_t) * eps_pred) / torch.sqrt(ac_t + 1e-8)
    
    # Compute predicted area
    area_pred = _soft_area_frac(x0_hat, CFG.SOFT_K)
    
    # Compute expected area
    area_exp = _expected_area_frac(cvec)
    
    # Compute area loss
    l_area = F.mse_loss(area_pred, area_exp)
    
    # Return combined loss
    return l_simple + CFG.AREA_LAMBDA * l_area

**DDIM Sampler**

In [13]:
# DDIM sampling function
@torch.no_grad()
def ddim_sample(net: CondUNet, pedi: torch.Tensor, mask: torch.Tensor, cvec: torch.Tensor, 
                steps: int = CFG.DDIM_STEPS, eta: float = 0.0, cfg_w: float = CFG.CFG_WEIGHT) -> torch.Tensor:
    # Set model to evaluation mode
    net.eval()
    
    # Get timestep parameters
    T = CFG.T_STEPS
    ts = torch.linspace(T-1, 0, steps, dtype=torch.long, device=DEVICE)  # Timesteps
    
    # Initialize with random noise
    x = torch.randn_like(pedi)
    
    # Concatenate pediatric image and mask
    cond = torch.cat([pedi, mask], dim=1)
    
    # Define function to predict noise with classifier-free guidance
    def eps_hat(x, t, c_in, p_in):
        # Predict noise with condition
        e_cond = net(x, c_in, t, p_in)
        
        # Predict noise without condition
        e_uncond = net(x, torch.zeros_like(c_in), t, torch.zeros_like(p_in))
        
        # Combine predictions with guidance weight
        return e_uncond + cfg_w * (e_cond - e_uncond)
    
    # Set up automatic mixed precision context
    scaler_ctx = torch.amp.autocast('cuda', enabled=CFG.AMP)
    
    # Iteratively denoise
    for i in range(steps):
        # Current timestep
        t = ts[i].repeat(pedi.size(0))
        
        # Get scheduler values
        ac_t = SCHED.ac[t].view(-1,1,1,1)
        ac_prev = SCHED.ac[ts[i+1]].view(-1,1,1,1) if i+1 < steps else torch.ones_like(ac_t)
        
        # Predict noise
        with scaler_ctx:
            eps = eps_hat(x, t, cond, cvec)
        
        # Predict x0
        x0_hat = (x - torch.sqrt(1 - ac_t) * eps) / torch.sqrt(ac_t + 1e-8)
        
        # DDIM update step
        if eta == 0.0:  # Deterministic
            x = torch.sqrt(ac_prev) * x0_hat + torch.sqrt(1 - ac_prev) * eps
        else:  # Stochastic
            sigma = eta * torch.sqrt((1 - ac_prev) / (1 - ac_t) * (1 - ac_t / ac_prev))
            noise = torch.randn_like(x)
            x = torch.sqrt(ac_prev) * x0_hat + torch.sqrt(torch.clamp(1 - ac_prev - sigma**2, 0.0)) * eps + sigma * noise
    
    # Clamp to valid range and return
    return x.clamp(-1, 1)

**Save/Load Helpers**

In [14]:
# Function to unwrap state dict for saving (handles compiled models)
def unwrap_state_dict_for_save(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    m = model
    
    # Handle compiled models
    if hasattr(m, "_orig_mod") and isinstance(getattr(m, "_orig_mod"), torch.nn.Module):
        m = m._orig_mod
    
    # Handle DataParallel models
    if hasattr(m, "module") and isinstance(getattr(m, "module"), torch.nn.Module):
        m = m.module
    
    return m.state_dict()

# Function to clean state dict prefixes (for loading)
def clean_state_dict_prefixes(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Handle compiled model prefixes
    if any(k.startswith("_orig_mod.") for k in state.keys()):
        state = {k.replace("_orig_mod.", "", 1): v for k, v in state.items()}
    
    # Handle DataParallel prefixes
    if any(k.startswith("module.") for k in state.keys()):
        state = {k.replace("module.", "", 1): v for k, v in state.items()}
    
    return state

**Metrics Calculation**

In [15]:
# Function to compute PSNR (Peak Signal-to-Noise Ratio)
def compute_psnr(pred_u8: np.ndarray, tgt_u8: np.ndarray) -> float:
    # Compute MSE
    mse = float(np.mean((pred_u8.astype(np.float32) - tgt_u8.astype(np.float32)) ** 2))
    
    # Handle perfect match case
    if mse <= 1e-12: 
        return 99.0
    
    # Compute PSNR
    return 20.0 * math.log10(255.0) - 10.0 * math.log10(mse)

# Function to compute SSIM (Structural Similarity Index)
def compute_ssim(pred_u8: np.ndarray, tgt_u8: np.ndarray) -> float:
    # Use skimage implementation if available
    if skimage_ssim is not None:
        return float(skimage_ssim(pred_u8, tgt_u8, data_range=255))
    
    # Fallback to manual implementation
    pred = pred_u8.astype(np.float32)
    tgt = tgt_u8.astype(np.float32)
    
    # Constants for SSIM calculation
    C1, C2 = (0.01 * 255)**2, (0.03 * 255)**2
    
    # Gaussian kernel
    k = 11
    g = cv2.getGaussianKernel(k, 1.5)
    w = g @ g.T
    
    # Compute means
    mu1 = cv2.filter2D(pred, -1, w)
    mu2 = cv2.filter2D(tgt, -1, w)
    
    # Compute squares of means
    mu1_sq, mu2_sq = mu1*mu1, mu2*mu2
    mu1_mu2 = mu1*mu2
    
    # Compute variances and covariance
    sigma1_sq = cv2.filter2D(pred*pred, -1, w) - mu1_sq
    sigma2_sq = cv2.filter2D(tgt*tgt, -1, w) - mu2_sq
    sigma12 = cv2.filter2D(pred*tgt, -1, w) - mu1_mu2
    
    # Compute SSIM
    ssim_map = ((2*mu1_mu2 + C1)*(2*sigma12 + C2))/((mu1_sq + mu2_sq + C1)*(sigma1_sq + sigma2_sq + C2) + 1e-12)
    
    return float(ssim_map.mean())

# Function to compute MS-SSIM (Multi-Scale SSIM)
def compute_msssim(pred_u8: np.ndarray, tgt_u8: np.ndarray) -> Optional[float]:
    # Return None if torch_ms_ssim is not available
    if torch_ms_ssim is None:
        return None
    
    # Convert to tensors
    pred = torch.from_numpy(pred_u8).float().unsqueeze(0).unsqueeze(0) / 255.0
    tgt = torch.from_numpy(tgt_u8).float().unsqueeze(0).unsqueeze(0) / 255.0
    
    # Compute MS-SSIM
    with torch.no_grad():
        return float(torch_ms_ssim(pred, tgt, data_range=1.0).item())

# Function to compute LPIPS (Learned Perceptual Image Patch Similarity)
def compute_lpips(pred_u8: np.ndarray, tgt_u8: np.ndarray, lpips_alex=None) -> Optional[float]:
    # Return None if lpips is not available
    if lpips is None:
        return None
    
    try:
        # Initialize LPIPS model if not provided
        if lpips_alex is None:
            lpips_alex = lpips.LPIPS(net='alex').to(DEVICE)
        
        # Prepare images for LPIPS
        def prep(x):
            t = torch.from_numpy(x).float()/255.0
            t = (t*2.0 - 1.0).unsqueeze(0).unsqueeze(0).repeat(1,3,1,1).to(DEVICE)
            return t
        
        # Compute LPIPS
        with torch.no_grad():
            d = lpips_alex(prep(pred_u8), prep(tgt_u8))
            return float(d.item())
    except Exception:
        return None

**Training Function**

In [16]:
# Training function with automatic mixed precision and gradient accumulation
def train_diffusion_amp_accum(
    net: CondUNet,
    train_loader: DataLoader,
    val_loader: Optional[DataLoader] = None,
    epochs: int = CFG.EPOCHS,
    save_path: str = os.path.join(CFG.SAVE_DIR, "cond_diffusion_mask_prior.pth")
):
    # Move model to device
    net = net.to(DEVICE)
    
    # Compile model if enabled
    if CFG.COMPILE and hasattr(torch, "compile"):
        try: 
            net = torch.compile(net)
        except Exception: 
            pass
    
    # Initialize optimizer and scheduler
    opt = torch.optim.AdamW(net.parameters(), lr=CFG.LR, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs, eta_min=1e-6)
    
    # Initialize gradient scaler for mixed precision
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.AMP and DEVICE.type == "cuda")
    
    # Track best validation loss
    best = float("inf")
    
    # Print training configuration
    print(f"AMP={CFG.AMP}, ACCUM_STEPS={CFG.ACCUM_STEPS}, eff_batch={CFG.BATCH*CFG.ACCUM_STEPS}")
    
    # Training loop
    for ep in range(1, epochs + 1):
        # Set model to training mode
        net.train()
        
        # Initialize loss and accumulation counter
        tot, acc = 0.0, 0
        
        # Create progress bar
        pbar = tqdm(train_loader, desc=f"Train Ep {ep}/{epochs}")
        
        # Zero gradients
        opt.zero_grad(set_to_none=True)
        
        # Iterate over batches
        for it, batch in enumerate(pbar, 1):
            # Forward pass with mixed precision
            with torch.amp.autocast('cuda', enabled=CFG.AMP):
                loss = ddpm_loss_step(net, batch) / CFG.ACCUM_STEPS
            
            # Backward pass
            scaler.scale(loss).backward()
            acc += 1
            
            # Update weights if accumulation steps reached or end of epoch
            if acc == CFG.ACCUM_STEPS or it == len(train_loader):
                # Unscale gradients
                scaler.unscale_(opt)
                
                # Clip gradients
                torch.nn.utils.clip_grad_norm_(net.parameters(), CFG.CLIP_NORM)
                
                # Update weights
                scaler.step(opt)
                scaler.update()
                
                # Zero gradients
                opt.zero_grad(set_to_none=True)
                acc = 0
            
            # Update total loss
            tot += float(loss.detach().cpu()) * CFG.ACCUM_STEPS
            
            # Update progress bar
            pbar.set_postfix({"loss": f"{(tot/it):.4f}", "lr": f"{opt.param_groups[0]['lr']:.6f}"})
        
        # Update learning rate
        scheduler.step()
        
        # Calculate average training loss
        tl = tot / max(1, len(train_loader))
        
        # Validation
        if val_loader is not None:
            # Set model to evaluation mode
            net.eval()
            
            # Initialize validation loss
            vtot = 0.0
            
            # Iterate over validation batches
            with torch.no_grad(), torch.amp.autocast('cuda', enabled=CFG.AMP):
                for vb in val_loader:
                    vtot += float(ddpm_loss_step(net, vb).detach().cpu())
            
            # Calculate average validation loss
            vl = vtot / max(1, len(val_loader))
        else:
            vl = tl
        
        # Print epoch results
        print(f"[Epoch {ep}] train={tl:.4f} val={vl:.4f} lr={opt.param_groups[0]['lr']:.6f}")
        
        # Save best model
        if vl < best:
            best = vl
            clean_sd = unwrap_state_dict_for_save(net)
            torch.save({"model": clean_sd}, save_path)
            print(f"Saved {save_path} (best {best:.4f})")
    
    return save_path

**Inference and Export Function**

In [17]:
# Function to run inference and export results
@torch.no_grad()
def infer_val_and_export(
    net_ckpt: str,
    ds: BoneAgeCondDiffusionDS,
    val_indices: List[int],
    out_dir: str = CFG.OUT_DIR,
    excel_path: str = CFG.MEAS_XLSX
):
    # Create output directory
    os.makedirs(out_dir, exist_ok=True)
    
    # Initialize model
    net = CondUNet(cvec_dim=14).to(DEVICE)
    
    # Load checkpoint
    sd = torch.load(net_ckpt, map_location=DEVICE)
    raw = sd["model"] if isinstance(sd, dict) and "model" in sd else sd
    raw = clean_state_dict_prefixes(raw)
    net.load_state_dict(raw, strict=False)
    
    # Initialize LPIPS model
    lpips_alex = None
    if lpips is not None:
        try:
            lpips_alex = lpips.LPIPS(net='alex').to(DEVICE).eval()
        except Exception:
            lpips_alex = None
    
    # Try to load font
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except Exception:
        font = ImageFont.load_default()
    
    # Initialize records and paths
    records, paths = [], []
    
    # Iterate over validation indices
    for i in tqdm(val_indices, desc="Infer+Export (val)"):
        # Get row data
        row = ds.df.iloc[i]
        path = row["img_path"]
        sex = int(row["male"])
        boneage = int(row.get("boneage", -1))
        
        # Extract image ID
        img_id = int(_ID_RE.search(os.path.basename(path)).group(1)) if _ID_RE.search(os.path.basename(path)) else i
        sex_str = "Male" if sex == 1 else "Female"
        
        # Load and process image
        base_im = load_gray_png_any(path, CFG.IMG_SIZE)
        base_u8 = np.array(base_im, dtype=np.uint8)
        
        # Get measurements and mask
        pedi_meas, mask_full = robust_measurements_and_mask(base_u8)
        
        # Convert to tensors
        pedi = pil_to_tensor_gray(base_im).to(DEVICE)
        mask = torch.from_numpy((mask_full.astype(np.float32)/255.0))[None,None,:,:].to(DEVICE) * 2.0 - 1.0
        
        # Build parameter vector
        cvec = build_param_vec(CFG.TARGET_AGE, sex, pedi_meas).unsqueeze(0).to(DEVICE)
        
        # Generate prediction
        pred = ddim_sample(net, pedi, mask, cvec, steps=CFG.DDIM_STEPS, eta=0.0, cfg_w=CFG.CFG_WEIGHT)
        pred_im = tensor_to_pil_gray(pred)
        pred_u8 = np.array(pred_im, dtype=np.uint8)
        
        # Get measurements for prediction
        pred_meas, pred_mask = robust_measurements_and_mask(pred_u8)
        
        # Generate target image
        tgt_im = morph_adultize_at_age(base_im, sex, CFG.TARGET_AGE)
        tgt_u8 = np.array(tgt_im, dtype=np.uint8)
        
        # Compute metrics
        psnr = compute_psnr(pred_u8, tgt_u8)
        ssim_val = compute_ssim(pred_u8, tgt_u8)
        msssim_val = compute_msssim(pred_u8, tgt_u8)
        lpips_val = compute_lpips(pred_u8, tgt_u8, lpips_alex)
        
        # Compute clinical metrics
        pedi_clinical = compute_clinical_metrics(base_u8, mask_full)
        pred_clinical = compute_clinical_metrics(pred_u8, pred_mask)
        
        # Estimate GP bone age
        pred_gp_bone_age = estimate_gp_bone_age(pred_u8, pred_mask, sex)
        
        # Create composite image
        composite = Image.new('L', (2 * CFG.IMG_SIZE, CFG.IMG_SIZE + 30))
        composite.paste(base_im, (0, 30))
        composite.paste(pred_im, (CFG.IMG_SIZE, 30))
        
        # Add text annotations
        draw = ImageDraw.Draw(composite)
        draw.text((10, 5), f"Pediatric: ID={img_id}, Age={boneage}m, {sex_str}", fill=255, font=font)
        draw.text((CFG.IMG_SIZE + 10, 5), f"Predicted: 216m (GP est: {pred_gp_bone_age}m)", fill=255, font=font)
        
        # Save composite image
        fp = os.path.join(out_dir, f"{img_id}_{boneage}m_{sex_str}_vs_predicted_216m.png")
        composite.save(fp)
        paths.append(fp)
        
        # Add record
        records.append({
            "id": img_id, "boneage": boneage, "sex": sex_str,
            "pediatric_area": pedi_meas["area"], "pediatric_height": pedi_meas["height"],
            "pediatric_width": pedi_meas["width"], "pediatric_aspect": pedi_meas["aspect"],
            "pediatric_density": pedi_meas["density"], "pediatric_width_length_ratio": pedi_meas["width_length_ratio"],
            "pediatric_cortical_thickness": pedi_clinical["cortical_thickness"],
            "pediatric_epiphyseal_plate_width": pedi_clinical["epiphyseal_plate_width"],
            "predicted_area": pred_meas["area"], "predicted_height": pred_meas["height"],
            "predicted_width": pred_meas["width"], "predicted_aspect": pred_meas["aspect"],
            "predicted_density": pred_meas["density"], "predicted_width_length_ratio": pred_meas["width_length_ratio"],
            "predicted_cortical_thickness": pred_clinical["cortical_thickness"],
            "predicted_epiphyseal_plate_width": pred_clinical["epiphyseal_plate_width"],
            "predicted_gp_bone_age_months": pred_gp_bone_age,
            "psnr_pred_vs_target": psnr, "ssim_pred_vs_target": ssim_val,
            "msssim_pred_vs_target": msssim_val, "lpips_pred_vs_target": lpips_val,
            "output_path": fp,
        })
    
    # Create DataFrame from records
    df = pd.DataFrame.from_records(records)
    
    # Define column order
    cols = ["id", "boneage", "sex",
            "pediatric_area", "pediatric_height", "pediatric_width", "pediatric_aspect", "pediatric_density",
            "pediatric_width_length_ratio", "pediatric_cortical_thickness", "pediatric_epiphyseal_plate_width",
            "predicted_area", "predicted_height", "predicted_width", "predicted_aspect", "predicted_density",
            "predicted_width_length_ratio", "predicted_cortical_thickness", "predicted_epiphyseal_plate_width",
            "predicted_gp_bone_age_months",
            "psnr_pred_vs_target", "ssim_pred_vs_target", "msssim_pred_vs_target", "lpips_pred_vs_target",
            "output_path"]
    
    # Reorder columns
    df = df[cols]
    
    # Save to Excel
    df.to_excel(excel_path, index=False)
    print(f"Saved {len(paths)} comparison images to {out_dir}")
    print(f"Excel: {excel_path}")

    # Visualize metrics
    metrics = ["area", "height", "width", "density", "cortical_thickness", "epiphyseal_plate_width"]
    plt.figure(figsize=(15, 10))
    
    # Create subplots for each metric
    for i, metric in enumerate(metrics, 1):
        plt.subplot(2, 3, i)
        plt.hist(df[f"pediatric_{metric}"], bins=30, alpha=0.5, label="Pediatric", color="blue")
        plt.hist(df[f"predicted_{metric}"], bins=30, alpha=0.5, label="Predicted (216m)", color="orange")
        plt.title(f"{metric.capitalize()} Distribution")
        plt.xlabel(metric.capitalize())
        plt.ylabel("Count")
        plt.legend()
    
    # Adjust layout and save
    plt.tight_layout()
    plot_path = os.path.join(out_dir, "metrics_histogram.png")
    plt.savefig(plot_path, dpi=300)
    plt.close()
    print(f"Metrics visualization saved at: {plot_path}")

    return paths, excel_path

**Inference and Export Function**

In [18]:
# Function to run inference and export results
@torch.no_grad()
def infer_val_and_export(
    net_ckpt: str,
    ds: BoneAgeCondDiffusionDS,
    val_indices: List[int],
    out_dir: str = CFG.OUT_DIR,
    excel_path: str = CFG.MEAS_XLSX
):
    # Create output directory
    os.makedirs(out_dir, exist_ok=True)
    
    # Initialize model
    net = CondUNet(cvec_dim=14).to(DEVICE)
    
    # Load checkpoint
    sd = torch.load(net_ckpt, map_location=DEVICE)
    raw = sd["model"] if isinstance(sd, dict) and "model" in sd else sd
    raw = clean_state_dict_prefixes(raw)
    net.load_state_dict(raw, strict=False)
    
    # Initialize LPIPS model
    lpips_alex = None
    if lpips is not None:
        try:
            lpips_alex = lpips.LPIPS(net='alex').to(DEVICE).eval()
        except Exception:
            lpips_alex = None
    
    # Try to load font
    try:
        font = ImageFont.truetype("arial.ttf", 16)
    except Exception:
        font = ImageFont.load_default()
    
    # Initialize records and paths
    records, paths = [], []
    
    # Iterate over validation indices
    for i in tqdm(val_indices, desc="Infer+Export (val)"):
        # Get row data
        row = ds.df.iloc[i]
        path = row["img_path"]
        sex = int(row["male"])
        boneage = int(row.get("boneage", -1))
        
        # Extract image ID
        img_id = int(_ID_RE.search(os.path.basename(path)).group(1)) if _ID_RE.search(os.path.basename(path)) else i
        sex_str = "Male" if sex == 1 else "Female"
        
        # Load and process image
        base_im = load_gray_png_any(path, CFG.IMG_SIZE)
        base_u8 = np.array(base_im, dtype=np.uint8)
        
        # Get measurements and mask
        pedi_meas, mask_full = robust_measurements_and_mask(base_u8)
        
        # Convert to tensors
        pedi = pil_to_tensor_gray(base_im).to(DEVICE)
        mask = torch.from_numpy((mask_full.astype(np.float32)/255.0))[None,None,:,:].to(DEVICE) * 2.0 - 1.0
        
        # Build parameter vector
        cvec = build_param_vec(CFG.TARGET_AGE, sex, pedi_meas).unsqueeze(0).to(DEVICE)
        
        # Generate prediction
        pred = ddim_sample(net, pedi, mask, cvec, steps=CFG.DDIM_STEPS, eta=0.0, cfg_w=CFG.CFG_WEIGHT)
        pred_im = tensor_to_pil_gray(pred)
        pred_u8 = np.array(pred_im, dtype=np.uint8)
        
        # Get measurements for prediction
        pred_meas, pred_mask = robust_measurements_and_mask(pred_u8)
        
        # Generate target image
        tgt_im = morph_adultize_at_age(base_im, sex, CFG.TARGET_AGE)
        tgt_u8 = np.array(tgt_im, dtype=np.uint8)
        
        # Compute metrics
        psnr = compute_psnr(pred_u8, tgt_u8)
        ssim_val = compute_ssim(pred_u8, tgt_u8)
        msssim_val = compute_msssim(pred_u8, tgt_u8)
        lpips_val = compute_lpips(pred_u8, tgt_u8, lpips_alex)
        
        # Compute clinical metrics
        pedi_clinical = compute_clinical_metrics(base_u8, mask_full)
        pred_clinical = compute_clinical_metrics(pred_u8, pred_mask)
        
        # Estimate GP bone age
        pred_gp_bone_age = estimate_gp_bone_age(pred_u8, pred_mask, sex)
        
        # Create composite image
        composite = Image.new('L', (2 * CFG.IMG_SIZE, CFG.IMG_SIZE + 30))
        composite.paste(base_im, (0, 30))
        composite.paste(pred_im, (CFG.IMG_SIZE, 30))
        
        # Add text annotations
        draw = ImageDraw.Draw(composite)
        draw.text((10, 5), f"Pediatric: ID={img_id}, Age={boneage}m, {sex_str}", fill=255, font=font)
        draw.text((CFG.IMG_SIZE + 10, 5), f"Predicted: 216m (GP est: {pred_gp_bone_age}m)", fill=255, font=font)
        
        # Save composite image
        fp = os.path.join(out_dir, f"{img_id}_{boneage}m_{sex_str}_vs_predicted_216m.png")
        composite.save(fp)
        paths.append(fp)
        
        # Add record
        records.append({
            "id": img_id, "boneage": boneage, "sex": sex_str,
            "pediatric_area": pedi_meas["area"], "pediatric_height": pedi_meas["height"],
            "pediatric_width": pedi_meas["width"], "pediatric_aspect": pedi_meas["aspect"],
            "pediatric_density": pedi_meas["density"], "pediatric_width_length_ratio": pedi_meas["width_length_ratio"],
            "pediatric_cortical_thickness": pedi_clinical["cortical_thickness"],
            "pediatric_epiphyseal_plate_width": pedi_clinical["epiphyseal_plate_width"],
            "predicted_area": pred_meas["area"], "predicted_height": pred_meas["height"],
            "predicted_width": pred_meas["width"], "predicted_aspect": pred_meas["aspect"],
            "predicted_density": pred_meas["density"], "predicted_width_length_ratio": pred_meas["width_length_ratio"],
            "predicted_cortical_thickness": pred_clinical["cortical_thickness"],
            "predicted_epiphyseal_plate_width": pred_clinical["epiphyseal_plate_width"],
            "predicted_gp_bone_age_months": pred_gp_bone_age,
            "psnr_pred_vs_target": psnr, "ssim_pred_vs_target": ssim_val,
            "msssim_pred_vs_target": msssim_val, "lpips_pred_vs_target": lpips_val,
            "output_path": fp,
        })
    
    # Create DataFrame from records
    df = pd.DataFrame.from_records(records)
    
    # Define column order
    cols = ["id", "boneage", "sex",
            "pediatric_area", "pediatric_height", "pediatric_width", "pediatric_aspect", "pediatric_density",
            "pediatric_width_length_ratio", "pediatric_cortical_thickness", "pediatric_epiphyseal_plate_width",
            "predicted_area", "predicted_height", "predicted_width", "predicted_aspect", "predicted_density",
            "predicted_width_length_ratio", "predicted_cortical_thickness", "predicted_epiphyseal_plate_width",
            "predicted_gp_bone_age_months",
            "psnr_pred_vs_target", "ssim_pred_vs_target", "msssim_pred_vs_target", "lpips_pred_vs_target",
            "output_path"]
    
    # Reorder columns
    df = df[cols]
    
    # Save to Excel
    df.to_excel(excel_path, index=False)
    print(f"Saved {len(paths)} comparison images to {out_dir}")
    print(f"Excel: {excel_path}")

    # Visualize metrics
    metrics = ["area", "height", "width", "density", "cortical_thickness", "epiphyseal_plate_width"]
    plt.figure(figsize=(15, 10))
    
    # Create subplots for each metric
    for i, metric in enumerate(metrics, 1):
        plt.subplot(2, 3, i)
        plt.hist(df[f"pediatric_{metric}"], bins=30, alpha=0.5, label="Pediatric", color="blue")
        plt.hist(df[f"predicted_{metric}"], bins=30, alpha=0.5, label="Predicted (216m)", color="orange")
        plt.title(f"{metric.capitalize()} Distribution")
        plt.xlabel(metric.capitalize())
        plt.ylabel("Count")
        plt.legend()
    
    # Adjust layout and save
    plt.tight_layout()
    plot_path = os.path.join(out_dir, "metrics_histogram.png")
    plt.savefig(plot_path, dpi=300)
    plt.close()
    print(f"Metrics visualization saved at: {plot_path}")

    return paths, excel_path

*******Data Loaders*******

In [19]:
# Function to create data loaders
def make_loaders() -> Tuple[DataLoader, DataLoader, 'BoneAgeCondDiffusionDS', List[int]]:
    # Find dataset
    csv_path, img_dir = find_rsna_boneage_root()
    
    # Create dataset
    ds = BoneAgeCondDiffusionDS(csv_path, img_dir, img_size=CFG.IMG_SIZE, max_samples=CFG.MAX_SAMPLES)
    n = len(ds)
    
    # Handle empty dataset
    if n == 0: 
        raise RuntimeError("Empty dataset.")
    
    # Handle single sample case
    if n == 1:
        tr = Subset(ds, [0])
        va = Subset(ds, [0])
        val_indices = [0]
        train_loader = DataLoader(tr, batch_size=1, shuffle=True, num_workers=0, pin_memory=True)
        val_loader   = DataLoader(va, batch_size=1, shuffle=False, num_workers=0, pin_memory=True)
        return train_loader, val_loader, ds, val_indices
    
    # Split dataset
    n_train = max(1, int(0.9 * n))
    n_val = max(1, n - n_train)
    if n_train == n: 
        n_train, n_val = n - 1, 1
    
    # Create random split
    gen = torch.Generator().manual_seed(CFG.SEED)
    tr_subset, va_subset = random_split(ds, [n_train, n_val], generator=gen)
    val_indices = va_subset.indices
    
    # Create data loaders
    train_loader = DataLoader(tr_subset, batch_size=CFG.BATCH, shuffle=True,
                              num_workers=CFG.NUM_WORKERS, pin_memory=True,
                              persistent_workers=(CFG.NUM_WORKERS > 0))
    val_loader   = DataLoader(va_subset, batch_size=CFG.BATCH, shuffle=False,
                              num_workers=CFG.NUM_WORKERS, pin_memory=True,
                              persistent_workers=(CFG.NUM_WORKERS > 0))
    
    return train_loader, val_loader, ds, val_indices

****Main Execution****

In [20]:
# Main execution block
if __name__ == "__main__":
    # Create data loaders
    print("Building loaders...")
    train_loader, val_loader, ds, val_indices = make_loaders()
    print(f"Dataset size: {len(ds)}")
    print(ds.df.head()[["id", "male", "boneage", "img_path"]])

    # Set dataset to training mode
    ds.train()
    
    # Initialize model
    net = CondUNet(cvec_dim=14).to(DEVICE)
    net = load_pretrained_unet(net)
    
    # Define checkpoint path
    ckpt_path = os.path.join(CFG.SAVE_DIR, "cond_diffusion_mask_prior.pth")

    # Train model
    print("Training...")
    ckpt_path = train_diffusion_amp_accum(net, train_loader, val_loader, epochs=CFG.EPOCHS, save_path=ckpt_path)

    # Generate predictions and export
    print("Generating 216-month predictions + side-by-side images + Excel...")
    ds.eval()
    os.makedirs(CFG.OUT_DIR, exist_ok=True)
    paths, xlsx_path = infer_val_and_export(ckpt_path, ds, val_indices, out_dir=CFG.OUT_DIR, excel_path=CFG.MEAS_XLSX)

    # Save file list
    with open(os.path.join(CFG.OUT_DIR, "files.json"), "w") as f:
        json.dump(sorted(paths), f, indent=2)
    
    print("Done.")
    print(f"Excel saved at: {xlsx_path}")

Building loaders...
[data] CSV: /kaggle/input/rsna-bone-age/boneage-training-dataset.csv
[data] IMG_DIR: /kaggle/input/rsna-bone-age/boneage-training-dataset/boneage-training-dataset (files: 12611)
[data] Indexed 12611 images.
Dataset size: 12611
     id  male  boneage                                           img_path
0  1377     0      180  /kaggle/input/rsna-bone-age/boneage-training-d...
1  1378     0       12  /kaggle/input/rsna-bone-age/boneage-training-d...
2  1379     0       94  /kaggle/input/rsna-bone-age/boneage-training-d...
3  1380     1      120  /kaggle/input/rsna-bone-age/boneage-training-d...
4  1381     0       82  /kaggle/input/rsna-bone-age/boneage-training-d...


config.json:   0%|          | 0.00/699 [00:00<?, ?B/s]

diffusion_pytorch_model.safetensors:   0%|          | 0.00/143M [00:00<?, ?B/s]

  scaler = torch.cuda.amp.GradScaler(enabled=CFG.AMP and DEVICE.type == "cuda")


Loaded pretrained weights from google/ddpm-cifar10-32
Training...
AMP=True, ACCUM_STEPS=2, eff_batch=16


Train Ep 1/10: 100%|██████████| 1419/1419 [11:56<00:00,  1.98it/s, loss=0.1439, lr=0.000100]


[Epoch 1] train=0.1439 val=0.1115 lr=0.000098
Saved /kaggle/working/cond_diffusion_mask_prior.pth (best 0.1115)


Train Ep 2/10: 100%|██████████| 1419/1419 [12:00<00:00,  1.97it/s, loss=0.1022, lr=0.000098]


[Epoch 2] train=0.1022 val=0.0927 lr=0.000091
Saved /kaggle/working/cond_diffusion_mask_prior.pth (best 0.0927)


Train Ep 3/10: 100%|██████████| 1419/1419 [11:58<00:00,  1.97it/s, loss=0.0905, lr=0.000091]


[Epoch 3] train=0.0905 val=0.0845 lr=0.000080
Saved /kaggle/working/cond_diffusion_mask_prior.pth (best 0.0845)


Train Ep 4/10: 100%|██████████| 1419/1419 [11:58<00:00,  1.97it/s, loss=0.0856, lr=0.000080]


[Epoch 4] train=0.0856 val=0.0798 lr=0.000066
Saved /kaggle/working/cond_diffusion_mask_prior.pth (best 0.0798)


Train Ep 5/10: 100%|██████████| 1419/1419 [11:57<00:00,  1.98it/s, loss=0.0819, lr=0.000066]


[Epoch 5] train=0.0819 val=0.0794 lr=0.000051
Saved /kaggle/working/cond_diffusion_mask_prior.pth (best 0.0794)


Train Ep 6/10: 100%|██████████| 1419/1419 [11:57<00:00,  1.98it/s, loss=0.0801, lr=0.000051]


[Epoch 6] train=0.0801 val=0.0770 lr=0.000035
Saved /kaggle/working/cond_diffusion_mask_prior.pth (best 0.0770)


Train Ep 7/10: 100%|██████████| 1419/1419 [11:57<00:00,  1.98it/s, loss=0.0781, lr=0.000035]


[Epoch 7] train=0.0781 val=0.0757 lr=0.000021
Saved /kaggle/working/cond_diffusion_mask_prior.pth (best 0.0757)


Train Ep 8/10: 100%|██████████| 1419/1419 [11:57<00:00,  1.98it/s, loss=0.0769, lr=0.000021]


[Epoch 8] train=0.0769 val=0.0756 lr=0.000010
Saved /kaggle/working/cond_diffusion_mask_prior.pth (best 0.0756)


Train Ep 9/10: 100%|██████████| 1419/1419 [11:57<00:00,  1.98it/s, loss=0.0750, lr=0.000010]


[Epoch 9] train=0.0750 val=0.0676 lr=0.000003
Saved /kaggle/working/cond_diffusion_mask_prior.pth (best 0.0676)


Train Ep 10/10: 100%|██████████| 1419/1419 [11:58<00:00,  1.98it/s, loss=0.0755, lr=0.000003]


[Epoch 10] train=0.0755 val=0.0683 lr=0.000001
Generating 216-month predictions + side-by-side images + Excel...
Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]


Downloading: "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth" to /root/.cache/torch/hub/checkpoints/alexnet-owt-7be5be79.pth
100%|██████████| 233M/233M [00:01<00:00, 210MB/s] 


Loading model from: /usr/local/lib/python3.11/dist-packages/lpips/weights/v0.1/alex.pth


Infer+Export (val): 100%|██████████| 1262/1262 [40:55<00:00,  1.95s/it]


Saved 1262 comparison images to /kaggle/working/out_216
Excel: /kaggle/working/measurements_export.xlsx
Metrics visualization saved at: /kaggle/working/out_216/metrics_histogram.png
Done.
Excel saved at: /kaggle/working/measurements_export.xlsx


In [1]:
# ============================================================
# RSNA Bone Age — Quick Plots (Kaggle-ready)
# Outputs -> /kaggle/working/rsna_plots
# ============================================================

import os, glob, re, random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image, ImageOps, ImageDraw, ImageFont

# ------------- config -------------
SEED = 42                          # why: reproducible sampling
IMG_SIZE = 256                     # why: consistent loading for intensity calc
STATS_SAMPLE = 400                 # why: limit IO for size/intensity stats
MONTAGE_N = 16                     # why: compact overview (4×4)
OUT_DIR = "/kaggle/working/rsna_plots"
DATASET_HINTS = [
    "/kaggle/input/rsna-bone-age",
    "/kaggle/input/rsna-bone-age/rsna-bone-age",
    "/kaggle/input"
]

random.seed(SEED); np.random.seed(SEED)
os.makedirs(OUT_DIR, exist_ok=True)

# ------------- data discovery -------------
def find_rsna_paths():
    """Find CSV + image root robustly across Kaggle mount patterns."""
    cands = []
    for base in DATASET_HINTS:
        if not os.path.exists(base):
            continue
        for csv_path in glob.glob(os.path.join(base, "**", "boneage-training-dataset.csv"), recursive=True):
            root = os.path.dirname(csv_path)
            # choose dir with most images
            dirs = [root] + [d for d in glob.glob(os.path.join(root, "**"), recursive=True) if os.path.isdir(d)]
            best, cnt = "", -1
            for d in dirs:
                n = (len(glob.glob(os.path.join(d, "*.png"))) +
                     len(glob.glob(os.path.join(d, "*.jpg"))) +
                     len(glob.glob(os.path.join(d, "*.jpeg"))))
                if n > cnt:
                    best, cnt = d, n
            cands.append((csv_path, best, cnt))
    if not cands:
        raise FileNotFoundError("Attach Kaggle dataset: kmader/rsna-bone-age (Add data → search).")
    csv_path, img_dir, cnt = sorted(cands, key=lambda x: -x[2])[0]
    print(f"[data] CSV: {csv_path}\n[data] IMG_DIR: {img_dir} (files: {cnt})")
    return csv_path, img_dir

_ID_RE = re.compile(r"(\d+)\.(png|jpg|jpeg)$", re.IGNORECASE)

def index_image_dir(img_dir: str):
    """Map image id → canonical path (prefer shallower path)."""
    idx = {}
    for p in glob.glob(os.path.join(img_dir, "**", "*.*"), recursive=True):
        if not os.path.isfile(p): 
            continue
        m = _ID_RE.search(os.path.basename(p))
        if not m:
            continue
        i = int(m.group(1))
        if i not in idx or p.count(os.sep) < idx[i].count(os.sep):
            idx[i] = p
    print(f"[data] Indexed {len(idx)} images")
    return idx

# ------------- io helpers -------------
def load_gray_any(path: str, size: int = None) -> Image.Image:
    """Open image as grayscale; optional resize for standardized stats."""
    im = Image.open(path).convert("L")
    if size is not None and im.size != (size, size):
        im = im.resize((size, size), Image.LANCZOS)
    return im

def save_fig(path: str):
    """Consistent figure export."""
    plt.tight_layout()
    plt.savefig(path, dpi=220, bbox_inches="tight")
    plt.close()

# ------------- plots -------------
def plot_age_hist(df: pd.DataFrame, path: str):
    """Bone age distribution (months)."""
    ages = df["boneage"].astype(float).values
    plt.figure(figsize=(7,4))
    plt.hist(ages, bins=40)
    plt.title("Bone Age Distribution (months)")
    plt.xlabel("Bone Age (months)"); plt.ylabel("Count")
    save_fig(path)

def plot_sex_bar(df: pd.DataFrame, path: str):
    """Sex counts."""
    counts = df["male"].astype(int).value_counts().sort_index()
    labels = ["Female","Male"]
    vals = [counts.get(0,0), counts.get(1,0)]
    plt.figure(figsize=(5,4))
    plt.bar(labels, vals)
    plt.title("Sex Distribution"); plt.ylabel("Count")
    save_fig(path)

def plot_age_by_sex_box(df: pd.DataFrame, path: str):
    """Age-by-sex boxplot."""
    m = df[df["male"]==1]["boneage"].astype(float).values
    f = df[df["male"]==0]["boneage"].astype(float).values
    plt.figure(figsize=(6,4))
    plt.boxplot([f, m], labels=["Female","Male"], showfliers=False)
    plt.title("Bone Age by Sex"); plt.ylabel("Bone Age (months)")
    save_fig(path)

def plot_age_hist_by_sex(df: pd.DataFrame, path: str):
    """Overlaid age histograms by sex (quick check of balance)."""
    m = df[df["male"]==1]["boneage"].astype(float).values
    f = df[df["male"]==0]["boneage"].astype(float).values
    plt.figure(figsize=(7,4))
    plt.hist(f, bins=40, alpha=0.6, label="Female")
    plt.hist(m, bins=40, alpha=0.6, label="Male")
    plt.title("Bone Age Histogram by Sex"); plt.xlabel("Bone Age (months)")
    plt.ylabel("Count"); plt.legend()
    save_fig(path)

def plot_image_size_hist(paths: list[str], path: str):
    """Distribution of raw image sizes (sample)."""
    wh = []
    for p in paths:
        try:
            with Image.open(p) as im:
                wh.append(im.size)
        except Exception:
            pass
    if not wh:
        return
    w = np.array([t[0] for t in wh]); h = np.array([t[1] for t in wh])
    plt.figure(figsize=(7,4))
    plt.hist2d(w, h, bins=30)
    plt.colorbar(label="Count")
    plt.title("Raw Image Size Distribution"); plt.xlabel("Width"); plt.ylabel("Height")
    save_fig(path)

def plot_mean_intensity_hist(paths: list[str], path: str):
    """Mean intensity distribution (after resize to IMG_SIZE)."""
    vals = []
    for p in paths:
        try:
            im = load_gray_any(p, IMG_SIZE)
            vals.append(np.array(im, dtype=np.uint8).mean())
        except Exception:
            pass
    if not vals:
        return
    plt.figure(figsize=(6,4))
    plt.hist(np.array(vals), bins=40)
    plt.title("Mean Image Intensity (grayscale)"); plt.xlabel("Mean Intensity (0–255)")
    plt.ylabel("Count")
    save_fig(path)

def make_montage(df: pd.DataFrame, idx_map: dict[int,str], n: int, out_path: str):
    """4×4 montage with ID/age/sex; fast sanity-check gallery."""
    n = min(n, len(df))
    samp = df.sample(n, random_state=SEED).reset_index(drop=True)
    grid = int(np.ceil(np.sqrt(n)))
    cell = IMG_SIZE // 2  # smaller tiles to fit on page
    pad = 8

    # pick a readable font; fallback to default
    try:
        font = ImageFont.truetype("arial.ttf", 14)
    except Exception:
        font = ImageFont.load_default()

    W = grid*cell + (grid+1)*pad
    H = grid*cell + (grid+1)*pad + 30
    canvas = Image.new("RGB", (W, H), (255,255,255))
    draw = ImageDraw.Draw(canvas)
    draw.text((pad, pad//2), "RSNA Sample Montage (ID | Age m | Sex)", fill=(30,41,59), font=font)

    for k, row in samp.iterrows():
        img_id = int(row["id"]); age = int(row.get("boneage",-1)); sex = "M" if int(row["male"])==1 else "F"
        path = idx_map.get(img_id, None); 
        if path is None: 
            continue
        try:
            im = load_gray_any(path)
            im = ImageOps.equalize(im)   # why: mild contrast normalization
            im = im.resize((cell, cell), Image.LANCZOS)
            # convert to RGB for labels
            im_rgb = Image.merge("RGB", (im, im, im))
            lbl = f"{img_id} | {age} | {sex}"
            ImageDraw.Draw(im_rgb).rectangle([0,0,im_rgb.width,20], fill=(255,255,255))
            ImageDraw.Draw(im_rgb).text((4,2), lbl, fill=(30,41,59), font=font)
            r, c = divmod(k, grid)
            x = pad + c*(cell+pad); y = pad + 24 + r*(cell+pad)
            canvas.paste(im_rgb, (x, y))
        except Exception:
            continue

    canvas.save(out_path)

# ------------- main -------------
def main():
    csv_path, img_dir = find_rsna_paths()
    df = pd.read_csv(csv_path)
    df["male"] = df["male"].astype(int)
    idx_map = index_image_dir(img_dir)
    df = df[df["id"].astype(int).isin(idx_map.keys())].reset_index(drop=True)
    if len(df) == 0:
        raise RuntimeError("No images matched IDs in CSV.")

    # Basic plots
    plot_age_hist(df, os.path.join(OUT_DIR, "age_hist.png"))
    plot_sex_bar(df, os.path.join(OUT_DIR, "sex_bar.png"))
    plot_age_by_sex_box(df, os.path.join(OUT_DIR, "age_by_sex_box.png"))
    plot_age_hist_by_sex(df, os.path.join(OUT_DIR, "age_hist_by_sex.png"))

    # Sample for IO-heavy stats
    ids = df["id"].astype(int).tolist()
    if len(ids) > STATS_SAMPLE:
        ids = random.sample(ids, STATS_SAMPLE)
    paths = [idx_map[i] for i in ids if i in idx_map]

    plot_image_size_hist(paths, os.path.join(OUT_DIR, "image_size_hist2d.png"))
    plot_mean_intensity_hist(paths, os.path.join(OUT_DIR, "mean_intensity_hist.png"))
    make_montage(df, idx_map, MONTAGE_N, os.path.join(OUT_DIR, "montage_4x4.png"))

    print("Saved plots:")
    for p in sorted(glob.glob(os.path.join(OUT_DIR, "*.png"))):
        print(" -", p)

if __name__ == "__main__":
    main()


[data] CSV: /kaggle/input/rsna-bone-age/boneage-training-dataset.csv
[data] IMG_DIR: /kaggle/input/rsna-bone-age/boneage-training-dataset/boneage-training-dataset (files: 12611)
[data] Indexed 12611 images
Saved plots:
 - /kaggle/working/rsna_plots/age_by_sex_box.png
 - /kaggle/working/rsna_plots/age_hist.png
 - /kaggle/working/rsna_plots/age_hist_by_sex.png
 - /kaggle/working/rsna_plots/image_size_hist2d.png
 - /kaggle/working/rsna_plots/mean_intensity_hist.png
 - /kaggle/working/rsna_plots/montage_4x4.png
 - /kaggle/working/rsna_plots/sex_bar.png
