In [None]:
# MONAI Dual-Task (Segmentation + Classification) training using nnU-Net-style pipeline
# - Dataset: derived/unified_dualtask (train/val/test CSVs)
# - Spacing standardization to (0.8, 0.8, 1.0) mm
# - Label-preserving resample via one-hot + optional dilate-then-erode
# - Sliding-window patch training (192x192x160)
# - Shared-encoder segmentation (DynUNet) + classification head
# - QC counters that flag label shrinkage post-resample

import os
import math
import time
import json
import random
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.amp import autocast, GradScaler

import nibabel as nib

from monai.config import print_config
from monai.data import CacheDataset, DataLoader, decollate_batch
from monai.data.utils import no_collation
from monai.inferers import SlidingWindowInferer
from monai.losses import DiceCELoss
from monai.metrics import DiceMetric, HausdorffDistanceMetric
from monai.networks.nets import DynUNet
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    EnsureTyped,
    Orientationd,
    Spacingd,
    ScaleIntensityRanged,
    RandSpatialCropd,
    RandFlipd,
    RandRotate90d,
    RandAffined,
    AsDiscreted,
    EnsureTyped,
    CastToTyped,
)
from monai.utils import set_determinism

print_config()


In [None]:
import os

# Export two CUDA/PyTorch-related environment flags for this kernel session.
# NOTE(review): these look like debugging aids (synchronous launches / device-side
# assertions) — confirm they are still wanted for full training runs.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': '1',
    'TORCH_USE_CUDA_DSA': '1',
})

In [None]:
# Reproducibility: seed every RNG that MONAI's set_determinism manages.
SEED = 42
set_determinism(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

# Paths.
# The project root defaults to the original workstation location, but can be redirected
# by exporting DHAI_PROJ_ROOT so the notebook also runs on other machines.
PROJ_ROOT = Path(os.environ.get('DHAI_PROJ_ROOT', '/home/qarc/projects/DHAI-Brain-Segmentation'))
DUALTASK_ROOT = PROJ_ROOT / 'derived' / 'unified_dualtask'
TRAIN_CSV = DUALTASK_ROOT / 'train_fixed.csv'
VAL_CSV = DUALTASK_ROOT / 'val_fixed.csv'
TEST_CSV = DUALTASK_ROOT / 'test_fixed.csv'

# Fail early, and say exactly which split file is absent.
_missing_csvs = [str(p) for p in (TRAIN_CSV, VAL_CSV, TEST_CSV) if not p.exists()]
assert not _missing_csvs, f'Split CSVs missing: {_missing_csvs}'

# Target spacing and patch params
TARGET_SPACING = (0.8, 0.8, 1.0)  # passed to Spacingd as pixdim
PATCH_SIZE = (192, 192, 160)      # training crop size and sliding-window ROI
PATCH_OVERLAP = 0.5               # sliding window overlap


In [None]:
# CSV -> dict list helpers

def read_unified_csv(path: Path) -> List[Dict]:
    """Read a split CSV and turn every row into a sample dictionary.

    The CSV must provide the columns ``case_id``, ``class_label``,
    ``image_path`` and ``label_path``.

    Args:
        path: location of the CSV file.

    Returns:
        One dict per CSV row with the keys ``case_id``, ``image``, ``label``
        and ``class_label`` (the latter converted with ``int``).
    """
    table = pd.read_csv(path)
    return [
        {
            'case_id': record['case_id'],
            'image': record['image_path'],
            'label': record['label_path'],
            'class_label': int(record['class_label']),
        }
        for _, record in table.iterrows()
    ]

# Parse the three split files into lists of sample dictionaries.
train_items, val_items, test_items = (
    read_unified_csv(csv_path) for csv_path in (TRAIN_CSV, VAL_CSV, TEST_CSV)
)

# Cell output: number of cases in (train, val, test).
tuple(len(split) for split in (train_items, val_items, test_items))


In [None]:
# Morphology utilities and QC counters
import scipy.ndimage as ndi

class LabelQC:
    """Running counter of cases whose foreground voxel count dropped too much.

    A case triggers a warning when ``after / before`` falls below
    ``1 - shrink_warn_threshold``. Cases with no foreground before processing
    are counted in ``total`` but can never warn.
    """

    def __init__(self, shrink_warn_threshold: float = 0.35):
        # Fraction of foreground that may be lost before a warning is raised.
        self.shrink_warn_threshold = shrink_warn_threshold
        # Number of cases seen / number of cases flagged.
        self.total = 0
        self.warn = 0

    def update(self, before_voxels: int, after_voxels: int, case_id: str):
        """Register one case and print a message if its label shrank too much."""
        self.total += 1
        if not before_voxels > 0:
            return
        # The small epsilon mirrors the numerator/denominator symmetrically.
        ratio = (after_voxels + 1e-6) / (before_voxels + 1e-6)
        retained_floor = 1.0 - self.shrink_warn_threshold
        if not ratio < retained_floor:
            return
        self.warn += 1
        print(f'[QC] label shrinkage: {case_id} before={before_voxels} after={after_voxels} ratio={ratio:.3f}')

    def summary(self):
        """Print how many of the registered cases were flagged."""
        print(f'[QC] shrinkage warnings: {self.warn}/{self.total}')


def binary_dilate_then_erode(mask: np.ndarray, radius_vox: int = 1) -> np.ndarray:
    """Morphological closing (dilate, then erode) that never removes foreground.

    The mask is zero-padded by ``radius_vox`` voxels before the morphology and
    cropped back afterwards. Without the padding, the erosion treats everything
    outside the array as background and strips foreground voxels that touch the
    volume boundary, so the "closing" would shrink labels near the edges.

    Args:
        mask: array whose non-zero entries are foreground (any dimensionality).
        radius_vox: number of dilation iterations followed by the same number
            of erosion iterations, using a face-connected structuring element.

    Returns:
        ``mask`` itself (unchanged object) when ``radius_vox <= 0``; otherwise a
        boolean array with the same shape as ``mask``.
    """
    if radius_vox <= 0:
        return mask

    foreground = np.asarray(mask).astype(bool)
    structure = ndi.generate_binary_structure(foreground.ndim, 1)

    # Pad with background so that dilation can grow past the original border
    # and erosion sees the true dilated neighbourhood of every original voxel.
    work = np.pad(foreground, radius_vox, mode='constant', constant_values=False)
    for _ in range(radius_vox):
        work = ndi.binary_dilation(work, structure=structure)
    for _ in range(radius_vox):
        work = ndi.binary_erosion(work, structure=structure)

    # Crop back to the original field of view.
    core = tuple(slice(radius_vox, -radius_vox) for _ in range(foreground.ndim))
    return np.ascontiguousarray(work[core])


In [None]:
# Transforms: spacing standardization and intensity scale
# Labels are handled with a custom post-transform step to preserve small lesions via one-hot resample and optional morph.

from monai.transforms import MapTransform

class OneHotResampleWithMorphology(MapTransform):
    """Re-discretize the label through a one-hot representation, with optional closing.

    The transform always reads ``data['label']`` and ``data['image']`` (the ``keys``
    argument is only forwarded to ``MapTransform``). It writes the processed label
    back and records foreground voxel counts under ``qc_before_vox`` and
    ``qc_after_vox``.

    NOTE(review): in this notebook ``Spacingd`` has already resampled the label with
    nearest-neighbour interpolation before this transform runs, so the "before" count
    is taken after that resampling — confirm this is the intended QC reference.
    """

    def __init__(self, keys, num_classes: int = 2, morph_radius: int = 0, allow_missing_keys: bool = False):
        super().__init__(keys, allow_missing_keys)
        self.num_classes = num_classes
        self.morph_radius = morph_radius

    def __call__(self, data):
        sample = dict(data)
        seg = sample['label']  # expected channel-first, single channel — TODO confirm
        n_before = int((seg > 0.5).sum().item())

        # Class index map -> one channel per class, channel-first, float.
        class_maps = F.one_hot(seg.long().squeeze(0), num_classes=self.num_classes)
        class_maps = class_maps.permute(3, 0, 1, 2).float()

        # Only interpolate when the label grid does not already match the image grid.
        target_shape = sample['image'].shape[1:]
        if class_maps.shape[1:] != target_shape:
            class_maps = F.interpolate(
                class_maps.unsqueeze(0),
                size=target_shape,
                mode='trilinear',
                align_corners=False,
            ).squeeze(0)

        # Back to a single-channel index map.
        discrete = class_maps.argmax(dim=0, keepdim=True)

        # Optional light closing, done in NumPy on the first (only) channel.
        if self.morph_radius > 0:
            as_np = discrete.detach().cpu().numpy().astype(np.uint8)
            closed = binary_dilate_then_erode(as_np[0], radius_vox=self.morph_radius)[None]
            discrete = torch.as_tensor(closed, dtype=torch.long, device=discrete.device)

        n_after = int((discrete > 0).sum().item())
        sample['label'] = discrete
        sample['qc_before_vox'] = n_before
        sample['qc_after_vox'] = n_after
        return sample

# NOTE(review): TARGET_MODE is not referenced anywhere else in the visible part of this notebook.
TARGET_MODE = 'bilinear'

# Shared loading stage: read both volumes, make them channel-first float32 tensors,
# reorient to RAS and resample to TARGET_SPACING (bilinear for the image, nearest for the label).
common_load = [
    LoadImaged(keys=['image', 'label']),
    EnsureChannelFirstd(keys=['image', 'label']),
    EnsureTyped(keys=['image', 'label'], dtype=torch.float32),
    Orientationd(keys=['image', 'label'], axcodes='RAS'),
    Spacingd(keys=['image', 'label'], pixdim=TARGET_SPACING, mode=('bilinear', 'nearest')),
]

from monai.transforms import RandCropByPosNegLabeld, SpatialPadd

# Training stage: intensity scaling, random spatial augmentation applied jointly to image and
# label, padding up to PATCH_SIZE and finally one label-guided random crop of PATCH_SIZE.
intensity_train = [
    # Map intensities from [0, 3000] to [0, 1] with clipping.
    ScaleIntensityRanged(keys=['image'], a_min=0, a_max=3000, b_min=0.0, b_max=1.0, clip=True),
    RandFlipd(keys=['image', 'label'], spatial_axis=[0, 1, 2], prob=0.2),
    RandRotate90d(keys=['image', 'label'], prob=0.2, max_k=3),
    # Small rotations (up to pi/36 per axis) and +-10% scaling; nearest for the label.
    RandAffined(keys=['image', 'label'], rotate_range=(math.pi/36, math.pi/36, math.pi/36),
                scale_range=(0.1, 0.1, 0.1), mode=('bilinear', 'nearest'), prob=0.2),
    # Pad first so that volumes smaller than the patch can still be cropped.
    SpatialPadd(keys=['image', 'label'], spatial_size=PATCH_SIZE),
    RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', spatial_size=PATCH_SIZE,
                           pos=1, neg=1, num_samples=1, image_key='image', allow_smaller=True),
]

# Validation stage: the same intensity scaling, no augmentation and no cropping.
intensity_val = [
    ScaleIntensityRanged(keys=['image'], a_min=0, a_max=3000, b_min=0.0, b_max=1.0, clip=True),
]

# Post label discretization with morphology + QC bookkeeping performed inline by OneHotResampleWithMorphology

post_label_preserve = [
    OneHotResampleWithMorphology(keys=['label'], num_classes=2, morph_radius=1),
]

class CastClassLabeld(MapTransform):
    """Convert the ``class_label`` entry of a sample to a float32 tensor.

    Only the ``class_label`` key is touched; samples without it pass through as a
    shallow copy. The ``keys`` argument is forwarded to ``MapTransform`` unchanged.
    """

    def __init__(self, keys, allow_missing_keys=False):
        super().__init__(keys, allow_missing_keys)

    def __call__(self, data):
        sample = dict(data)
        if 'class_label' not in sample:
            return sample
        sample['class_label'] = torch.as_tensor(sample['class_label'], dtype=torch.float32)
        return sample

# Full pipelines: loading -> stage-specific intensity/augmentation -> label post-processing
# -> class-label casting.
train_transforms = Compose([
    *common_load,
    *intensity_train,
    *post_label_preserve,
    CastClassLabeld(keys=['class_label']),
])
val_transforms = Compose([
    *common_load,
    *intensity_val,
    *post_label_preserve,
    CastClassLabeld(keys=['class_label']),
])


In [None]:
# Datasets and Loaders with QC hooks

# Independent shrinkage counters for the training and validation streams.
qc_train, qc_val = (LabelQC(shrink_warn_threshold=0.35) for _ in ('train', 'val'))

class QCAugmentWrapper(CacheDataset):
    """CacheDataset that can forward per-sample QC voxel counts to a counter.

    Args:
        data, transform, cache_rate, num_workers, copy_cache: forwarded to
            ``CacheDataset`` unchanged.
        qc: optional object exposing ``update(before_voxels, after_voxels, case_id)``
            (e.g. ``LabelQC``). When ``None`` (default) the wrapper only returns items.

    Items may be a single dict or a list of dicts (crop transforms with
    ``num_samples`` return lists); both are handled. Anything else is returned
    untouched.

    NOTE(review): when this dataset is read through a DataLoader with worker
    processes, updates presumably happen inside the workers and will not be visible
    on the counter held by the main process — verify before relying on it.
    """

    def __init__(self, data, transform, cache_rate=0.0, num_workers=4, copy_cache=True,
                 qc: Optional['LabelQC'] = None):
        super().__init__(data=data, transform=transform, cache_rate=cache_rate,
                         num_workers=num_workers, copy_cache=copy_cache)
        self.qc = qc

    def __getitem__(self, index):
        item = super().__getitem__(index)
        if self.qc is None:
            return item

        samples = item if isinstance(item, (list, tuple)) else [item]
        for sample in samples:
            if not isinstance(sample, dict):
                continue
            before_vox = int(sample.get('qc_before_vox', 0))
            after_vox = int(sample.get('qc_after_vox', 0))
            # Skip samples that carry no QC bookkeeping at all.
            if before_vox or after_vox:
                self.qc.update(before_vox, after_vox, str(sample.get('case_id', '?')))
        return item

# Datasets: cache_rate=0.0 means nothing is pre-cached; transforms run on access.
train_ds = CacheDataset(
    data=train_items,
    transform=train_transforms,
    cache_rate=0.0,
    num_workers=4,
)
val_ds = CacheDataset(
    data=val_items,
    transform=val_transforms,
    cache_rate=0.0,
    num_workers=2,
)

# Loaders: one case per batch; only the training stream is shuffled.
train_loader = DataLoader(
    train_ds,
    batch_size=1,
    shuffle=True,
    num_workers=1,
    pin_memory=False,
    persistent_workers=False,
)
val_loader = DataLoader(
    val_ds,
    batch_size=1,
    shuffle=False,
    num_workers=1,
    pin_memory=False,
    persistent_workers=False,
)


In [None]:
# Model: DynUNet backbone + classification head from encoder bottleneck

# Segmentation network
# Six resolution levels: 3x3x3 kernels everywhere, stride 1 at the top level and
# stride 2 for each of the five downsampling steps (mirrored by five 2x upsamplings).
N_LEVELS = 6
seg_net = DynUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=2,
    kernel_size=[3] * N_LEVELS,
    strides=[1] + [2] * (N_LEVELS - 1),
    upsample_kernel_size=[2] * (N_LEVELS - 1),
    norm_name='instance',
    deep_supervision=False,
)
seg_net = seg_net.to(device)

# Classification head
class ClassificationHead(nn.Module):
    """Global-average-pool a 3D feature map and project it to class logits.

    Args:
        in_channels: number of channels of the incoming feature map.
        num_classes: number of output logits.
    """

    def __init__(self, in_channels: int, num_classes: int = 1):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.fc = nn.Linear(in_channels, num_classes)

    def forward(self, feat):
        """Return logits of shape (batch, num_classes) for a (batch, C, D, H, W) input."""
        pooled = self.pool(feat)
        return self.fc(torch.flatten(pooled, start_dim=1))

# Bottleneck features from DynUNet:
# DynUNet returns decoder outputs (several of them when deep_supervision is enabled), so
# encoder features are captured separately via register_forward_hook.

# The classification head below is lazily initialized: its linear layer is created once
# the number of feature channels is known.
class LazyClassificationHead(nn.Module):
    """Classification head whose linear layer may be created on first use.

    Args:
        num_classes: number of output logits.
        in_channels: optional number of input feature channels. When given, the
            linear layer is built immediately, so its parameters exist before an
            optimizer is constructed. When ``None`` (default) the layer is created
            during the first ``forward`` call or when a checkpoint is loaded.

    Important: parameters created lazily *after* an optimizer was built are not
    known to that optimizer. Run one forward pass (or pass ``in_channels``)
    before collecting ``parameters()``.
    """

    def __init__(self, num_classes: int = 1, in_channels: Optional[int] = None):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool3d(1)
        self.num_classes = num_classes
        self.fc = nn.Linear(in_channels, num_classes) if in_channels is not None else None

    def forward(self, feat):
        """Pool ``feat`` globally and return logits of shape (batch, num_classes)."""
        x = self.pool(feat).flatten(1)
        if self.fc is None:
            # First call: size the layer from the observed channel count.
            self.fc = nn.Linear(x.shape[1], self.num_classes).to(x.device)
        return self.fc(x)

    def load_state_dict(self, state_dict, strict=True, **kwargs):
        """Load weights, creating the linear layer first if it does not exist yet.

        Values are copied into the module's own parameters by the standard
        ``nn.Module.load_state_dict`` machinery, so the loaded module never aliases
        the checkpoint tensors and an existing layer keeps its device. A checkpoint
        whose output size differs from ``num_classes`` raises the usual size-mismatch
        error instead of silently reshaping the layer.

        An empty ``state_dict`` (head saved before it was ever initialized) is
        accepted without raising, regardless of ``strict``.

        Returns:
            The ``missing_keys`` / ``unexpected_keys`` result of ``nn.Module.load_state_dict``.
        """
        if self.fc is None and 'fc.weight' in state_dict:
            saved_weight = state_dict['fc.weight']
            self.fc = nn.Linear(saved_weight.shape[1], self.num_classes).to(saved_weight.device)

        if not state_dict:
            strict = False
        return super().load_state_dict(state_dict, strict=strict, **kwargs)

cls_head = LazyClassificationHead(num_classes=1).to(device)

# Simple hook to capture bottleneck features
encoder_feat = {'x': None}

def hook_fn(module, input, output):
    """Forward hook: remember the latest output of the hooked layer."""
    encoder_feat['x'] = output

# Attach hook to bottleneck layer (seg_net.encoder4 or seg_net.bottleneck depending on version)
if hasattr(seg_net, 'bottleneck'):
    seg_net.bottleneck.register_forward_hook(hook_fn)
elif hasattr(seg_net, 'encoder4'):
    seg_net.encoder4.register_forward_hook(hook_fn)
else:
    print('[WARN] Could not attach hook, classification head may not receive features')

# Losses
seg_loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
cls_loss_fn = nn.BCEWithLogitsLoss()

# The head creates its linear layer lazily on the first forward pass. Run one dry
# forward pass *before* building the optimizer; otherwise cls_head.parameters() is
# empty here and the classifier weights would never be optimized.
# The feature choice mirrors the training loop: hooked bottleneck features when
# available, segmentation logits otherwise.
with torch.no_grad():
    _dry_logits = seg_net(torch.zeros(1, 1, 64, 64, 64, device=device))
    cls_head(encoder_feat['x'] if encoder_feat['x'] is not None else _dry_logits)
encoder_feat['x'] = None  # do not keep the dry-run features alive
del _dry_logits
assert cls_head.fc is not None, 'classification head was not initialized by the dry run'

# Optimizer over both the segmentation network and the (now materialized) head
params = list(seg_net.parameters()) + list(cls_head.parameters())
optimizer = torch.optim.AdamW(params, lr=2e-4, weight_decay=1e-5)
scaler = GradScaler(enabled=torch.cuda.is_available())

# Metrics
post_pred = AsDiscreted(keys=['pred'], argmax=True)
post_label = AsDiscreted(keys=['label'], to_onehot=2)
dice_metric = DiceMetric(include_background=False, reduction='mean')


In [None]:
# Helper: pad tensor to next multiple-of factor for each spatial dim
import torch.nn.functional as F

def pad_to_factor(x: torch.Tensor, factor: int = 32) -> torch.Tensor:
    """Zero-pad the three spatial dims of a 5-D tensor up to a multiple of ``factor``.

    Args:
        x: tensor of shape (B, C, D, H, W).
        factor: each of D, H and W is rounded up to the next multiple of this value.

    Returns:
        ``x`` itself when no padding is needed, otherwise a new tensor padded with
        zeros at the far end of each spatial dimension.
    """
    _, _, depth, height, width = x.shape

    def shortfall(size):
        # Distance from `size` to the next multiple of `factor`.
        return ((size + factor - 1) // factor) * factor - size

    pad_d, pad_h, pad_w = shortfall(depth), shortfall(height), shortfall(width)
    # F.pad lists the last dimension first: (W_left, W_right, H_left, H_right, D_left, D_right)
    padding = (0, pad_w, 0, pad_h, 0, pad_d)
    if max(padding) <= 0:
        return x
    return F.pad(x, padding, mode='constant', value=0.0)


In [None]:
# Inferer for validation/test
# Sliding-window inference over PATCH_SIZE windows, one window per forward pass,
# blended with Gaussian weighting.
inferer = SlidingWindowInferer(
    roi_size=PATCH_SIZE,
    sw_batch_size=1,
    overlap=PATCH_OVERLAP,
    mode='gaussian',
)

# Utils
from contextlib import nullcontext

def to_device(batch: Dict, device: torch.device) -> Dict:
    """Copy the tensors of a batch dict to ``device``, tolerating failures.

    Behaviour per entry:
      * non-tensor values are passed through unchanged;
      * tensors containing NaN or Inf are reported and left out of the result;
      * tensors are moved with ``non_blocking=True``; on ``RuntimeError`` a blocking
        transfer is attempted, and if that also fails the original tensor is kept.

    Any other exception is reported and the original ``batch`` object is returned.
    """
    moved = {}
    try:
        for name, value in batch.items():
            if not isinstance(value, torch.Tensor):
                moved[name] = value
                continue

            # Reject tensors with non-finite entries.
            if torch.isnan(value).any() or torch.isinf(value).any():
                print(f"Warning: Invalid tensor detected in {name}")
                continue

            try:
                moved[name] = value.to(device, non_blocking=True)
                continue
            except RuntimeError as e:
                print(f"Error moving {name} to device: {e}")

            # Asynchronous transfer failed: retry synchronously, else keep the source tensor.
            try:
                moved[name] = value.to(device, non_blocking=False)
            except RuntimeError as e2:
                print(f"Fallback transfer also failed for {name}: {e2}")
                moved[name] = value
    except Exception as e:
        print(f"Error in to_device: {e}")
        return batch

    return moved

# Directory that receives all checkpoints of this run (created on demand).
ckpt_dir = PROJ_ROOT / 'runs' / 'dualtask_monai'
ckpt_dir.mkdir(parents=True, exist_ok=True)
print('Checkpoint dir:', ckpt_dir)


In [None]:
# Enhanced metrics imports and utilities
from sklearn.metrics import roc_auc_score, accuracy_score, f1_score, confusion_matrix

# Enhanced metrics calculation functions
def calculate_classification_metrics(y_true: List[int], y_prob: List[float], y_pred: List[int]) -> Dict[str, float]:
    """Compute accuracy, binary F1 and ROC-AUC for a set of predictions.

    Args:
        y_true: ground-truth class labels.
        y_prob: predicted scores, used for the AUC.
        y_pred: thresholded predictions, used for accuracy and F1.

    Returns:
        Dict with the keys ``accuracy``, ``f1_score`` and ``auc``. The AUC is NaN
        when ``roc_auc_score`` raises ``ValueError``.
    """
    try:
        auc_value = roc_auc_score(y_true, y_prob)
    except ValueError:
        auc_value = float('nan')

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1_score': f1_score(y_true, y_pred, average='binary'),
        'auc': auc_value,
    }

def calculate_segmentation_metrics(pred: torch.Tensor, target: torch.Tensor) -> Dict[str, float]:
    """Compute the 95th-percentile Hausdorff distance between two masks.

    Both inputs are thresholded at 0.5. The distance is only evaluated when both
    masks contain foreground; otherwise, or when the computation raises, NaN is
    returned.

    Returns:
        Dict with the single key ``hausdorff_distance``.
    """
    distance = float('nan')
    try:
        pred_mask = (pred.detach().cpu().numpy() > 0.5).astype(np.uint8)
        target_mask = (target.detach().cpu().numpy() > 0.5).astype(np.uint8)

        if pred_mask.sum() > 0 and target_mask.sum() > 0:
            metric = HausdorffDistanceMetric(include_background=False, percentile=95.0)
            metric(y_pred=torch.from_numpy(pred_mask), y=torch.from_numpy(target_mask))
            distance = metric.aggregate().item()
    except Exception as e:
        print(f"Warning: Could not calculate Hausdorff distance: {e}")
        distance = float('nan')

    return {'hausdorff_distance': distance}

In [None]:
# Comprehensive checkpointing and resume functionality
import pickle
from datetime import datetime
# Function to initialize classification head with dummy data
def initialize_classification_head(cls_head, seg_net, device):
    """Trigger the lazy initialization of ``cls_head`` with one dummy forward pass.

    A small random volume is pushed through ``seg_net``. The head is then called on
    the bottleneck features captured by the module-level ``encoder_feat`` hook
    store, if one exists and was filled by this pass. Otherwise the segmentation
    output is used, which is the same fallback the training loop applies, so the
    head is always sized for the tensor it will actually receive.

    Args:
        cls_head: head module to initialize (called once, gradients disabled).
        seg_net: segmentation network providing the features.
        device: device on which the dummy input is created.
    """
    print("Initializing classification head...")

    # Small dummy volume; 64 per side keeps the bottleneck larger than one voxel.
    dummy_input = torch.randn(1, 1, 64, 64, 64).to(device)

    hook_store = globals().get('encoder_feat')
    if not isinstance(hook_store, dict):
        hook_store = None
    if hook_store is not None:
        # Drop stale features from an earlier forward pass.
        hook_store['x'] = None

    with torch.no_grad():
        dummy_logits = seg_net(dummy_input)
        hooked = hook_store.get('x') if hook_store is not None else None
        dummy_feat = hooked if hooked is not None else dummy_logits

        # This will initialize the fc layer
        _ = cls_head(dummy_feat)

    if hook_store is not None:
        hook_store['x'] = None  # do not keep the dummy features alive

    print("Classification head initialized")
class TrainingStateManager:
    """Save and restore the complete training state and keep the metric history.

    Checkpoints are written into ``checkpoint_dir`` as ``checkpoint_epoch_<n>.pt``,
    ``latest_checkpoint.pt`` and, for the best model, ``best_checkpoint.pt``.
    Each file contains model/optimizer/scaler state dicts, the metric history and
    the torch, NumPy and Python RNG states.
    """

    def __init__(self, checkpoint_dir: Path):
        self.checkpoint_dir = checkpoint_dir
        # One list per tracked quantity; entries are appended by update_history().
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'val_dice': [],
            'val_hausdorff': [],
            'val_accuracy': [],
            'val_f1': [],
            'val_auc': []
        }

    @staticmethod
    def _module_device(module) -> torch.device:
        """Return the device of the module's first parameter (CPU if it has none)."""
        first_param = next(module.parameters(), None)
        return first_param.device if first_param is not None else torch.device('cpu')

    def save_checkpoint(self, epoch: int, seg_net, cls_head, optimizer, scaler,
                       best_val_dice: float, current_metrics: Dict, is_best: bool = False):
        """Save comprehensive training checkpoint.

        Args:
            epoch: epoch number stored in the checkpoint and used in the file name.
            seg_net, cls_head, optimizer, scaler: objects whose ``state_dict()`` is saved.
            best_val_dice: best validation Dice observed so far.
            current_metrics: metrics of the current epoch, stored verbatim.
            is_best: additionally write ``best_checkpoint.pt`` when True.
        """
        checkpoint = {
            'epoch': epoch,
            'seg_net_state_dict': seg_net.state_dict(),
            'cls_head_state_dict': cls_head.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict(),
            'best_val_dice': best_val_dice,
            'current_metrics': current_metrics,
            'history': self.history,
            'random_state': torch.get_rng_state(),
            'numpy_random_state': np.random.get_state(),
            'python_random_state': random.getstate(),
            'timestamp': datetime.now().isoformat()
        }

        # Save regular checkpoint
        checkpoint_path = self.checkpoint_dir / f'checkpoint_epoch_{epoch}.pt'
        torch.save(checkpoint, checkpoint_path)

        # Save best checkpoint
        if is_best:
            best_path = self.checkpoint_dir / 'best_checkpoint.pt'
            torch.save(checkpoint, best_path)

        # Save latest checkpoint (for easy resuming)
        latest_path = self.checkpoint_dir / 'latest_checkpoint.pt'
        torch.save(checkpoint, latest_path)

        print(f"Checkpoint saved: {checkpoint_path}")
        if is_best:
            print(f"Best checkpoint updated: {best_path}")

    def load_checkpoint(self, seg_net, cls_head, optimizer, scaler, checkpoint_path: str = None):
        """Load checkpoint and restore training state.

        Args:
            seg_net, cls_head, optimizer, scaler: objects that receive the saved state.
                Tensors are mapped onto the device of ``seg_net``'s parameters.
            checkpoint_path: file to load; defaults to ``latest_checkpoint.pt`` inside
                the checkpoint directory.

        Returns:
            ``(epoch, best_val_dice)`` from the checkpoint, or ``(0, -1.0)`` when the
            file does not exist.

        Raises:
            Any exception raised while loading the model, optimizer or scaler state
            is reported and re-raised. Failures while restoring RNG states are not
            fatal: a fresh seed derived from the epoch is set instead.
        """
        if checkpoint_path is None:
            checkpoint_path = self.checkpoint_dir / 'latest_checkpoint.pt'

        if not Path(checkpoint_path).exists():
            print(f"No checkpoint found at {checkpoint_path}")
            return 0, -1.0

        print(f"Loading checkpoint from {checkpoint_path}")
        target_device = self._module_device(seg_net)

        try:
            # First try with weights_only=True (safe loading)
            checkpoint = torch.load(checkpoint_path, map_location=target_device, weights_only=True)
            print("Loaded checkpoint with safe loading")
        except Exception as e:
            print(f"Safe loading failed: {e}")
            print("Attempting to load with full state restoration...")

            # The checkpoint also stores NumPy/Python RNG state, which the restricted
            # loader may reject. Full unpickling is only acceptable because the file
            # was written by save_checkpoint() — never point this at untrusted files.
            checkpoint = torch.load(checkpoint_path, map_location=target_device, weights_only=False)
            print("Loaded checkpoint with full state restoration")

        # Initialize classification head if needed
        if cls_head.fc is None:
            print("Initializing classification head before loading weights...")
            initialize_classification_head(cls_head, seg_net, target_device)

        # Restore model / optimizer / scaler states; report which component failed.
        components = (
            (seg_net, 'seg_net_state_dict', "Segmentation network loaded successfully", "segmentation network"),
            (cls_head, 'cls_head_state_dict', "Classification head loaded successfully", "classification head"),
            (optimizer, 'optimizer_state_dict', "Optimizer state loaded successfully", "optimizer state"),
            (scaler, 'scaler_state_dict', "GradScaler state loaded successfully", "GradScaler state"),
        )
        for target, key, ok_message, error_label in components:
            try:
                target.load_state_dict(checkpoint[key])
                print(ok_message)
            except Exception as e:
                print(f"Error loading {error_label}: {e}")
                raise

        # Random state restoration
        try:
            if 'random_state' in checkpoint:
                random_state = checkpoint['random_state']
                if isinstance(random_state, torch.Tensor):
                    # torch.set_rng_state requires a CPU uint8 tensor. map_location may
                    # have moved the saved state onto an accelerator, so bring it back.
                    torch.set_rng_state(random_state.detach().to(device='cpu', dtype=torch.uint8))
                    print("PyTorch random state restored")
                else:
                    print("Warning: Invalid PyTorch random state format")

            if 'numpy_random_state' in checkpoint:
                np.random.set_state(checkpoint['numpy_random_state'])
                print("NumPy random state restored")

            if 'python_random_state' in checkpoint:
                random.setstate(checkpoint['python_random_state'])
                print("Python random state restored")

        except Exception as e:
            print(f"Warning: Could not restore random states: {e}")
            print("Setting new random seed for continued training...")
            # Set a new deterministic seed
            new_seed = 42 + checkpoint.get('epoch', 0)  # Different seed for each epoch
            set_determinism(new_seed)
            print(f"New random seed set: {new_seed}")

        # Restore training history
        if 'history' in checkpoint:
            self.history = checkpoint['history']
            print("Training history restored")
        else:
            print("No training history found in checkpoint")

        epoch = checkpoint['epoch']
        best_val_dice = checkpoint['best_val_dice']

        print(f"Resumed from epoch {epoch}, best val dice: {best_val_dice:.4f}")
        return epoch, best_val_dice

    def update_history(self, train_loss: float, val_loss: float = None,
                      val_dice: float = None, val_hausdorff: float = None,
                      val_accuracy: float = None, val_f1: float = None, val_auc: float = None):
        """Append the given values to the history.

        ``train_loss`` is always recorded; every validation quantity is recorded
        only when it is not ``None``.
        """
        self.history['train_loss'].append(train_loss)
        optional_entries = (
            ('val_loss', val_loss),
            ('val_dice', val_dice),
            ('val_hausdorff', val_hausdorff),
            ('val_accuracy', val_accuracy),
            ('val_f1', val_f1),
            ('val_auc', val_auc),
        )
        for key, value in optional_entries:
            if value is not None:
                self.history[key].append(value)

# Initialize training state manager
# Create the checkpoint manager and resume from the latest checkpoint if one exists.
state_manager = TrainingStateManager(ckpt_dir)

# load_checkpoint() reports epoch 0 when there is nothing to resume from.
start_epoch, best_val_dice = state_manager.load_checkpoint(seg_net, cls_head, optimizer, scaler)
if start_epoch != 0:
    print(f"Resuming training from epoch {start_epoch + 1}")
else:
    best_val_dice = -1.0
    print("Starting training from scratch")

In [None]:
import gc

# Enhanced Train / Val loops with comprehensive metrics, checkpointing, and memory management
# Last epoch number of the run (the loop below counts epochs from start_epoch + 1 up to EPOCHS).
EPOCHS = 50
# Run validation every `val_interval` epochs.
val_interval = 1
# Besides every new best model, a checkpoint is written every `save_checkpoint_interval` epochs.
save_checkpoint_interval = 5  # Save checkpoint every 5 epochs

# Hausdorff metric shared by the validation and test loops (95th percentile, background excluded).
hausdorff_metric = HausdorffDistanceMetric(include_background=False, percentile=95.0)

def monitor_gpu_memory():
    """Print allocated and reserved CUDA memory in GB; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    bytes_per_gb = 1024**3
    allocated = torch.cuda.memory_allocated() / bytes_per_gb
    cached = torch.cuda.memory_reserved() / bytes_per_gb
    print(f"GPU Memory - Allocated: {allocated:.2f}GB, Cached: {cached:.2f}GB")

def cleanup_batch_memory(images=None, labels=None, seg_logits=None, cls_logits=None, feat=None, images_p=None):
    """Release cached CUDA memory and run the garbage collector.

    The tensor arguments are accepted for call-site compatibility only: a function
    cannot drop references that its caller still holds, so the passed tensors are
    freed only once the caller rebinds or deletes its own variables.
    """
    torch.cuda.empty_cache()
    gc.collect()

# Main loop: one training pass per epoch, followed (every `val_interval` epochs) by a
# validation pass, history update and checkpointing.
for epoch in range(start_epoch + 1, EPOCHS + 1):
    seg_net.train(); cls_head.train()
    epoch_loss = 0.0
    num_steps = 0

    print(f'Epoch {epoch}/{EPOCHS}')
    monitor_gpu_memory()

    # Training loop with memory management
    for batch_idx, batch in enumerate(train_loader):
        # Memory cleanup every 10 batches
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()
            gc.collect()

        # QC update per-sample (uses the voxel counts recorded by the label transform)
        for b in decollate_batch(batch):
            qc_train.update(int(b.get('qc_before_vox', 0)), int(b.get('qc_after_vox', 0)), str(b.get('case_id', '?')))

        batch = to_device(batch, device)
        images = batch['image']
        labels = batch['label'].long()
        class_labels = batch['class_label'].view(-1, 1)

        optimizer.zero_grad(set_to_none=True)
        ctx = autocast(device_type='cuda', enabled=torch.cuda.is_available())
        with ctx:
            # Clear the hook store so stale features are never reused.
            encoder_feat['x'] = None
            seg_logits = seg_net(images)
            # seg_net is built with deep_supervision=False, so the output is used as is.
            # (The original list/tuple check selected the same object in both branches.)
            seg_logits_main = seg_logits
            # Classification: hooked bottleneck features, falling back to the logits
            feat = encoder_feat['x'] if encoder_feat['x'] is not None else seg_logits_main
            cls_logits = cls_head(feat)

            loss_seg = seg_loss_fn(seg_logits_main, labels)
            loss_cls = cls_loss_fn(cls_logits, class_labels)
            # Classification is an auxiliary task weighted at 0.3.
            loss = loss_seg + 0.3 * loss_cls

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_steps += 1

        # Cleanup training batch memory
        cleanup_batch_memory(images, labels, seg_logits_main, cls_logits, feat)

    epoch_loss /= max(1, num_steps)
    print(f'Epoch {epoch}/{EPOCHS} - train loss: {epoch_loss:.4f}')
    if epoch % val_interval == 0:
        qc_train.summary()

    # Validation loop with memory management
    if epoch % val_interval == 0:
        seg_net.eval(); cls_head.eval()
        dice_metric.reset()
        hausdorff_metric.reset()
        val_loss = 0.0
        steps = 0

        # For classification metrics
        val_y_true, val_y_prob = [], []
        val_hausdorff_distances = []

        with torch.no_grad():
            for val_batch_idx, batch in enumerate(val_loader):
                # Memory cleanup every 5 validation batches
                if val_batch_idx % 5 == 0:
                    torch.cuda.empty_cache()
                    gc.collect()

                for b in decollate_batch(batch):
                    qc_val.update(int(b.get('qc_before_vox', 0)), int(b.get('qc_after_vox', 0)), str(b.get('case_id', '?')))

                batch = to_device(batch, device)
                images = batch['image']
                labels = batch['label'].long()
                class_labels = batch['class_label'].view(-1, 1)

                ctx = autocast(device_type='cuda', enabled=torch.cuda.is_available())
                with ctx:
                    images_p = pad_to_factor(images, factor=32)  # (B,C,D,H,W) -> padded to mult of 32
                    seg_logits = inferer(inputs=images_p, network=seg_net)
                    encoder_feat['x'] = None
                    # The hook store after sliding-window inference would hold features of the
                    # last window only; run one whole-volume forward pass to populate it.
                    _ = seg_net(images_p)
                    # Crop the padded prediction back to the label grid.
                    if seg_logits.shape[-3:] != labels.shape[-3:]:
                        Dz, Hy, Wx = labels.shape[-3:]
                        seg_logits = seg_logits[..., :Dz, :Hy, :Wx]
                    feat = encoder_feat['x'] if encoder_feat['x'] is not None else seg_logits
                    cls_logits = cls_head(feat)

                    loss_seg = seg_loss_fn(seg_logits, labels)
                    loss_cls = cls_loss_fn(cls_logits, class_labels)
                    loss = loss_seg + 0.3 * loss_cls

                val_loss += loss.item()
                steps += 1

                # Segmentation metrics
                y_pred = torch.softmax(seg_logits, dim=1)
                y_pred_discrete = torch.argmax(y_pred, dim=1, keepdim=True)
                dice_metric(y_pred=y_pred_discrete, y=labels)

                # Hausdorff distance, computed per case. Failures are reported and recorded
                # as NaN; the metric buffer is always cleared so that a failed case cannot
                # leak into the next one.
                try:
                    hausdorff_metric(y_pred=y_pred_discrete, y=labels)
                    hausdorff_dist = hausdorff_metric.aggregate().item()
                    val_hausdorff_distances.append(hausdorff_dist)
                except Exception as hd_err:
                    print(f'  [WARN] Hausdorff distance failed for validation batch {val_batch_idx}: {hd_err}')
                    val_hausdorff_distances.append(float('nan'))
                finally:
                    hausdorff_metric.reset()

                # Classification metrics collection (assumes batch_size == 1, as configured above)
                val_y_true.append(int(class_labels.item()))
                val_y_prob.append(torch.sigmoid(cls_logits).item())

                # Cleanup validation batch memory
                cleanup_batch_memory(images, labels, seg_logits, cls_logits, feat, images_p)

        # Calculate metrics
        mean_dice = dice_metric.aggregate().item()
        val_loss /= max(1, steps)
        mean_hausdorff = np.nanmean(val_hausdorff_distances) if val_hausdorff_distances else float('nan')

        # Classification metrics
        val_y_pred = [1 if p >= 0.5 else 0 for p in val_y_prob]
        cls_metrics = calculate_classification_metrics(val_y_true, val_y_prob, val_y_pred)

        # Print comprehensive metrics
        print(f'  Val loss: {val_loss:.4f} | Val Dice: {mean_dice:.4f} | Val Hausdorff: {mean_hausdorff:.2f}')
        print(f'  Val Accuracy: {cls_metrics["accuracy"]:.4f} | Val F1: {cls_metrics["f1_score"]:.4f} | Val AUC: {cls_metrics["auc"]:.4f}')
        qc_val.summary()

        # Update history
        state_manager.update_history(
            train_loss=epoch_loss,
            val_loss=val_loss,
            val_dice=mean_dice,
            val_hausdorff=mean_hausdorff,
            val_accuracy=cls_metrics["accuracy"],
            val_f1=cls_metrics["f1_score"],
            val_auc=cls_metrics["auc"]
        )

        # Metrics stored alongside the checkpoint
        current_metrics = {
            'val_loss': val_loss,
            'val_dice': mean_dice,
            'val_hausdorff': mean_hausdorff,
            'val_accuracy': cls_metrics["accuracy"],
            'val_f1': cls_metrics["f1_score"],
            'val_auc': cls_metrics["auc"]
        }

        # Model selection is driven by validation Dice only.
        is_best = mean_dice > best_val_dice
        if is_best:
            best_val_dice = mean_dice
            print(f'  [NEW BEST] Dice: {best_val_dice:.4f}')

        # Save checkpoint (best + regular interval)
        if is_best or epoch % save_checkpoint_interval == 0:
            state_manager.save_checkpoint(
                epoch=epoch,
                seg_net=seg_net,
                cls_head=cls_head,
                optimizer=optimizer,
                scaler=scaler,
                best_val_dice=best_val_dice,
                current_metrics=current_metrics,
                is_best=is_best
            )

        # Final cleanup after validation
        torch.cuda.empty_cache()
        gc.collect()
        monitor_gpu_memory()

print("Training completed!")

In [None]:
# Enhanced test evaluation with comprehensive metrics
from sklearn.metrics import roc_auc_score, accuracy_score

# Test-time data pipeline: the validation transforms are reused, nothing is
# cached up front (cache_rate=0.0), and volumes are served one per batch in
# their original order (batch_size=1, shuffle=False).
test_ds = CacheDataset(
    data=test_items,
    transform=val_transforms,
    cache_rate=0.0,
    num_workers=2,
)
test_loader = DataLoader(
    test_ds,
    batch_size=1,
    shuffle=False,
    num_workers=2,
    pin_memory=True,
)

# Restore the best checkpoint when one exists on disk; otherwise evaluate
# whatever weights the networks currently hold.
# NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source.
best_ckpt_path = ckpt_dir / 'best_checkpoint.pt'
if not best_ckpt_path.exists():
    print("No best checkpoint found, using current model weights")
else:
    checkpoint = torch.load(best_ckpt_path, map_location=device)
    seg_net.load_state_dict(checkpoint['seg_net_state_dict'])
    cls_head.load_state_dict(checkpoint['cls_head_state_dict'])
    print(f"Loaded best checkpoint from epoch {checkpoint['epoch']}")

# Put both networks in evaluation mode.
seg_net.eval()
cls_head.eval()

# Start from clean metric buffers and empty per-case accumulators.
dice_metric.reset()
hausdorff_metric.reset()
test_hausdorff_distances = []
test_y_true, test_y_prob = [], []

# Run the test set through both heads without tracking gradients.
# Each batch holds exactly one case (the loader uses batch_size=1), which is
# what the scalar `.item()` calls below rely on.
with torch.no_grad():
    for batch in test_loader:
        batch = to_device(batch, device)
        images = batch['image']
        labels = batch['label'].long()
        class_labels = batch['class_label'].view(-1, 1)

        ctx = autocast(device_type='cuda', enabled=torch.cuda.is_available())
        with ctx:
            # Segmentation logits come from the (sliding-window) inferer.
            seg_logits = inferer(inputs=images, network=seg_net)
            # Extra forward pass on the whole volume; presumably this refreshes
            # encoder_feat['x'] through a hook registered elsewhere - TODO confirm.
            _ = seg_net(images)
            # Fall back to the segmentation logits when no encoder feature was captured.
            feat = encoder_feat['x'] if encoder_feat['x'] is not None else seg_logits
            cls_logits = cls_head(feat)

            # Segmentation metrics
            # NOTE(review): the metrics receive single-channel label maps (argmax output
            # vs. integer labels); confirm both metrics are configured for that format.
            y_pred = torch.softmax(seg_logits, dim=1)
            y_pred_discrete = torch.argmax(y_pred, dim=1, keepdim=True)
            dice_metric(y_pred=y_pred_discrete, y=labels)

            # Hausdorff distance, aggregated per case so one failing case cannot
            # break the whole evaluation.
            try:
                hausdorff_metric(y_pred=y_pred_discrete, y=labels)
                hausdorff_dist = hausdorff_metric.aggregate().item()
            except Exception as exc:
                # Catch only ordinary errors (not KeyboardInterrupt/SystemExit) and
                # report why the case was skipped instead of failing silently.
                print(f"  [WARN] Hausdorff distance failed for a test case: {exc}")
                hausdorff_dist = float('nan')
            finally:
                # Always clear the buffer so a partially recorded case never leaks
                # into the next case's aggregate.
                hausdorff_metric.reset()
            test_hausdorff_distances.append(hausdorff_dist)

            # Classification metrics collection
            test_y_true.append(int(class_labels.item()))
            test_y_prob.append(torch.sigmoid(cls_logits).item())

# Aggregate the segmentation metrics collected during the test loop.
test_dice = dice_metric.aggregate().item()
if test_hausdorff_distances:
    # NaN entries (cases whose Hausdorff computation failed) are ignored by nanmean.
    test_hausdorff = np.nanmean(test_hausdorff_distances)
else:
    test_hausdorff = float('nan')

# Threshold the predicted probabilities at 0.5 to obtain hard class predictions.
test_y_pred = [int(prob >= 0.5) for prob in test_y_prob]
test_cls_metrics = calculate_classification_metrics(test_y_true, test_y_prob, test_y_pred)

# Console report of the final test metrics.
print("\n" + "="*50)
print("FINAL TEST RESULTS")
print("="*50)
print(f"Segmentation Metrics:")
print(f"  - Dice Score: {test_dice:.4f}")
print(f"  - Hausdorff Distance: {test_hausdorff:.2f}")
print(f"\nClassification Metrics:")
print(f"  - Accuracy: {test_cls_metrics['accuracy']:.4f}")
print(f"  - F1 Score: {test_cls_metrics['f1_score']:.4f}")
print(f"  - AUC: {test_cls_metrics['auc']:.4f}")
print("="*50)

# Persist the test metrics together with the full training history as JSON
# next to the checkpoints.
results_path = ckpt_dir / 'final_test_results.json'
final_results = dict(
    test_dice=test_dice,
    test_hausdorff=test_hausdorff,
    test_accuracy=test_cls_metrics['accuracy'],
    test_f1=test_cls_metrics['f1_score'],
    test_auc=test_cls_metrics['auc'],
    training_history=state_manager.history,
)

with results_path.open('w') as f:
    json.dump(final_results, f, indent=2)
print(f"Results saved to: {results_path}")

In [None]:
# Comprehensive plotting and analysis for publication
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.gridspec import GridSpec

# Set style for publication-quality plots.
# Both calls change plotting defaults rather than a single figure: the matplotlib
# style sheet 'seaborn-v0_8-whitegrid' is activated and seaborn's "husl" palette
# is installed as the default colour palette for figures created afterwards.
# NOTE(review): the 'seaborn-v0_8-*' style names presumably need matplotlib >= 3.6 -
# confirm against the environment's pinned version.
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette("husl")

def _style_metric_axes(ax, xlabel, ylabel, title):
    """Apply the label, title, legend and grid styling shared by the line-chart panels."""
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel(ylabel, fontsize=12)
    ax.set_title(title, fontsize=14, fontweight='bold')
    ax.legend(fontsize=11)
    ax.grid(True, alpha=0.3)


def _plot_loss_panel(ax, epochs, history):
    """Draw the training and validation loss curves on ``ax``."""
    ax.plot(epochs, history['train_loss'], 'b-', linewidth=2, label='Training Loss', marker='o', markersize=4)
    ax.plot(epochs, history['val_loss'], 'r-', linewidth=2, label='Validation Loss', marker='s', markersize=4)
    _style_metric_axes(ax, 'Epoch', 'Loss', 'Training and Validation Loss')


def _plot_dice_panel(ax, epochs, history):
    """Draw the validation Dice curve on ``ax``."""
    ax.plot(epochs, history['val_dice'], 'g-', linewidth=2, label='Dice Score', marker='^', markersize=4)
    _style_metric_axes(ax, 'Epoch', 'Dice Score', 'Validation Dice Score')


def _plot_classification_panel(ax, epochs, history):
    """Draw the validation accuracy, F1 and AUC curves on ``ax``."""
    ax.plot(epochs, history['val_accuracy'], color='purple', linewidth=2, label='Accuracy', marker='o', markersize=4)
    ax.plot(epochs, history['val_f1'], color='orange', linewidth=2, label='F1 Score', marker='s', markersize=4)
    ax.plot(epochs, history['val_auc'], color='brown', linewidth=2, label='AUC', marker='^', markersize=4)
    _style_metric_axes(ax, 'Epoch', 'Score', 'Classification Metrics')


def _plot_hausdorff_panel(ax, epochs, history):
    """Draw the validation Hausdorff distance curve on ``ax``."""
    ax.plot(epochs, history['val_hausdorff'], color='red', linewidth=2, label='Hausdorff Distance', marker='o', markersize=4)
    # NOTE(review): the axis label says mm; confirm the metric is computed in
    # physical units rather than voxels.
    _style_metric_axes(ax, 'Epoch', 'Distance (mm)', 'Validation Hausdorff Distance')


def _plot_overview_panel(ax, epochs, history):
    """Draw a radar chart of the final-epoch validation metrics on a polar ``ax``.

    ``epochs`` is accepted only so every panel helper shares one signature.
    """
    categories = ['Dice', 'Accuracy', 'F1', 'AUC']
    values = [history[key][-1] for key in ('val_dice', 'val_accuracy', 'val_f1', 'val_auc')]

    angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
    # Repeat the first point so the outline closes into a polygon.
    values += values[:1]
    angles += angles[:1]

    ax.plot(angles, values, 'o-', linewidth=2, color='blue')
    ax.fill(angles, values, alpha=0.25, color='blue')
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories)
    ax.set_ylim(0, 1)
    ax.set_title('Final Performance Overview', fontsize=14, fontweight='bold')
    ax.grid(True)


def create_publication_plots(history: Dict, save_dir: Path):
    """Create publication-quality plots for all training metrics.

    Builds one combined 3x3-grid figure with five panels (loss, Dice,
    classification metrics, Hausdorff distance, final-epoch radar overview),
    saves it as ``training_metrics_publication.png`` and additionally saves
    every panel as its own PNG. Each individual file contains its own panel.

    Args:
        history: Mapping with per-epoch sequences under the keys
            'train_loss', 'val_loss', 'val_dice', 'val_accuracy', 'val_f1',
            'val_auc' and 'val_hausdorff'. All series are plotted against the
            epoch range derived from 'train_loss'.
        save_dir: Directory that receives the PNG files; created if missing.

    Returns:
        The combined matplotlib figure.

    Raises:
        ValueError: If ``history['train_loss']`` is empty (nothing to plot).
    """
    n_epochs = len(history['train_loss'])
    if n_epochs == 0:
        raise ValueError("history contains no epochs; run training before plotting")
    epochs = range(1, n_epochs + 1)

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    # Combined overview figure.
    fig = plt.figure(figsize=(20, 15))
    gs = GridSpec(3, 3, figure=fig, hspace=0.3, wspace=0.3)

    # (grid cell in the combined figure, axes projection, individual filename, drawer)
    panels = [
        (gs[0, :2], None, 'loss_comparison.png', _plot_loss_panel),
        (gs[0, 2], None, 'dice_score.png', _plot_dice_panel),
        (gs[1, :], None, 'classification_metrics.png', _plot_classification_panel),
        (gs[2, :2], None, 'hausdorff_distance.png', _plot_hausdorff_panel),
        # The radar chart needs polar axes; on Cartesian axes it degenerates
        # into an ordinary line plot of value against angle.
        (gs[2, 2], 'polar', 'performance_overview.png', _plot_overview_panel),
    ]

    for grid_cell, projection, _, draw_panel in panels:
        draw_panel(fig.add_subplot(grid_cell, projection=projection), epochs, history)

    fig.tight_layout()

    # Save high-resolution plots. Saving through the figure object avoids any
    # dependence on which figure pyplot currently considers active.
    plot_path = save_dir / 'training_metrics_publication.png'
    fig.savefig(plot_path, dpi=300, bbox_inches='tight', facecolor='white')
    print(f"Publication plots saved to: {plot_path}")

    # Save individual plots: re-draw each panel on its own figure.
    for _, projection, filename, draw_panel in panels:
        fig_ind = plt.figure(figsize=(8, 6))
        draw_panel(fig_ind.add_subplot(111, projection=projection), epochs, history)
        fig_ind.savefig(save_dir / filename, dpi=300, bbox_inches='tight', facecolor='white')
        # Close right away so the per-panel figures are not kept open or shown.
        plt.close(fig_ind)

    plt.show()
    return fig

def summarize_training_history(history: Dict) -> Dict:
    """Build the JSON-serialisable summary table for a finished training run.

    Args:
        history: Mapping with non-empty per-epoch sequences under the keys
            'train_loss', 'val_loss', 'val_dice', 'val_accuracy', 'val_f1',
            'val_auc' and 'val_hausdorff'.

    Returns:
        A dict with the number of recorded epochs, the best validation Dice and
        its 1-based epoch, the final-epoch validation metrics, and the first/last
        training and validation losses under 'training_convergence'.
    """
    val_dice = history['val_dice']
    return {
        'final_epoch': len(history['train_loss']),
        'best_val_dice': max(val_dice),
        # np.argmax returns a NumPy integer, which json cannot serialise;
        # convert to a built-in int before storing it.
        'best_val_dice_epoch': int(np.argmax(val_dice)) + 1,
        'final_val_accuracy': history['val_accuracy'][-1],
        'final_val_f1': history['val_f1'][-1],
        'final_val_auc': history['val_auc'][-1],
        'final_val_hausdorff': history['val_hausdorff'][-1],
        'training_convergence': {
            'train_loss_start': history['train_loss'][0],
            'train_loss_end': history['train_loss'][-1],
            'val_loss_start': history['val_loss'][0],
            'val_loss_end': history['val_loss'][-1]
        }
    }


# Generate publication plots
# Only proceed when training has recorded at least one epoch; an empty history
# has nothing to plot and no "best"/"final" values to summarise.
if (
    'state_manager' in locals()
    and hasattr(state_manager, 'history')
    and len(state_manager.history['train_loss']) > 0
):
    print("Generating publication-quality plots...")
    fig = create_publication_plots(state_manager.history, ckpt_dir)
    
    # Save plot data for external plotting tools (e.g., LaTeX, R)
    plot_data = {
        'epochs': list(range(1, len(state_manager.history['train_loss']) + 1)),
        'metrics': state_manager.history
    }
    
    plot_data_path = ckpt_dir / 'plot_data.json'
    with open(plot_data_path, 'w') as f:
        json.dump(plot_data, f, indent=2)
    print(f"Plot data saved to: {plot_data_path}")
    
    # Generate summary statistics table
    summary_stats = summarize_training_history(state_manager.history)
    final_epoch = summary_stats['final_epoch']
    
    summary_path = ckpt_dir / 'training_summary.json'
    with open(summary_path, 'w') as f:
        json.dump(summary_stats, f, indent=2)
    print(f"Training summary saved to: {summary_path}")
    
else:
    print("No training history found. Run training first to generate plots.")