# 1. Environment Setup
Installing necessary dependencies and importing libraries.

In [None]:
!pip install nnunetv2

import os
import json
import shutil
from pathlib import Path
import numpy as np
from typing import List, Dict

print("✓ Dependencies installed and imported")


# 2. Experiment Configuration
Defining experiment parameters, selecting the specific **Cross-Validation Fold**, and configuring rare subject oversampling.

In [None]:
# Cross-validation fold trained in this run (valid values: 0, 1, 2, 3, 4).
FOLD = 4

# ------------------------------------------
# User-tunable parameters
# ------------------------------------------
# Both values are written to trainer_config.json further down.
TRAINING_TIME_MINUTES = 11 * 60 + 45  # 11 hours 45 minutes
OVERSAMPLE_FACTOR = 3.0

# Dataset locations - MODIFY THESE TO MATCH YOUR KAGGLE PATHS
PREPROCESSED_PATH = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"
RAW_PATH = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_raw_data_base/nnUNet_raw/Dataset002_BonnFCD_FLAIR"

# Environment variables consumed by nnU-Net (raw data, preprocessed data, results).
os.environ.update({
    'nnUNet_raw': "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_raw_data_base/nnUNet_raw",
    'nnUNet_preprocessed': "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_preprocessed",
    'nnUNet_results': "/kaggle/working/nnUNet_results",
})
# Only provide a default: a value already present in the environment wins.
os.environ.setdefault('nnUNet_compile', 'false')

# ------------------------------------------
# Custom module locations
# ------------------------------------------
# Folder (relative to the current working directory) holding the custom
# trainer, dataloader, rare-subject list and split definition.
CUSTOM_MODULES_PATH = "custom_modules"
RARE_SUBJECTS_PATH = os.path.join(CUSTOM_MODULES_PATH, "rare_subjects.json")
SPLITS_PATH = os.path.join(CUSTOM_MODULES_PATH, "splits_final.json")

# Rare subject list, used later by the dataset verification and the summary.
RARE_SUBJECTS = json.loads(Path(RARE_SUBJECTS_PATH).read_text())

# Persist the tunable parameters as trainer_config.json inside the custom
# modules folder.
trainer_config = {
    "oversample_factor": OVERSAMPLE_FACTOR,
    "max_time_minutes": TRAINING_TIME_MINUTES,
}
config_path = os.path.join(CUSTOM_MODULES_PATH, "trainer_config.json")
with open(config_path, "w") as config_file:
    config_file.write(json.dumps(trainer_config, indent=4))

print("✓ Configuration set")
print(f"  - Custom modules path: {CUSTOM_MODULES_PATH}")
print(f"  - Trainer config created at: {config_path}")
print(f"  - Oversample factor: {OVERSAMPLE_FACTOR}x")


# 3. Custom nnU-Net Components
Copying the `CustomOversamplingDataLoader` and `nnUNetTrainerOversampling` modules from `custom_modules` into the working directory and the installed nnU-Net package, so they can handle class imbalance during training.

In [None]:
def create_custom_dataloader_file():
    """
    Deploy custom_dataloader.py from CUSTOM_MODULES_PATH into
    /kaggle/working/custom_nnunet, creating that folder when needed.

    The copy uses shutil.copy2; a missing source file raises from that call.
    """
    filename = "custom_dataloader.py"
    destination_dir = "/kaggle/working/custom_nnunet"
    os.makedirs(destination_dir, exist_ok=True)

    source = os.path.join(CUSTOM_MODULES_PATH, filename)
    target = os.path.join(destination_dir, filename)
    shutil.copy2(source, target)

    print(f"✓ Custom DataLoader copied from {source} to {target}")


# Perform the deployment as soon as the cell runs.
create_custom_dataloader_file()


In [None]:
import nnunetv2
from pathlib import Path

def create_custom_trainer():
    """
    Deploy the custom trainer and its JSON side files from CUSTOM_MODULES_PATH.

    Every file is copied twice:
    1. into /kaggle/working/custom_nnunet (for reference/import), and
    2. into the installed nnunetv2 package under training/nnUNetTrainer
       (so the trainer can be selected with the -tr argument).

    A missing source file produces a warning during the first pass and is
    skipped silently during the second one.
    """
    working_copy_dir = "/kaggle/working/custom_nnunet"
    os.makedirs(working_copy_dir, exist_ok=True)

    trainer_assets = (
        "nnUNetTrainerOversampling.py",
        "rare_subjects.json",
        "trainer_config.json",
    )

    # Pass 1: working directory copy, warning about anything that is missing.
    for asset in trainer_assets:
        origin = os.path.join(CUSTOM_MODULES_PATH, asset)
        if not os.path.exists(origin):
            print(f"⚠ Warning: {origin} not found!")
            continue
        shutil.copy2(origin, os.path.join(working_copy_dir, asset))

    # Pass 2: copy next to the trainers shipped with the nnunetv2 package.
    package_trainer_dir = Path(nnunetv2.__file__).parent.joinpath("training", "nnUNetTrainer")
    for asset in trainer_assets:
        origin = os.path.join(CUSTOM_MODULES_PATH, asset)
        if os.path.exists(origin):
            shutil.copy2(origin, package_trainer_dir / asset)

    print("✓ Custom Trainer deployed to:")
    for location in (working_copy_dir, package_trainer_dir):
        print(f"  - {location}")


# Perform the deployment as soon as the cell runs.
create_custom_trainer()


# 4. Data Preparation
Verifying the input dataset integrity and copying it to a writable directory (`/kaggle/working`) for processing.

In [None]:
def _subject_id_from_filename(filename):
    """Return the ``sub-...`` identifier encoded in *filename*, or None.

    The identifier is the part of the extension-less file name before the
    first underscore (``sub-00001_0000.npz`` -> ``sub-00001``).
    """
    name = Path(filename).name
    # ".nii.gz" is a double extension: Path.stem would only drop ".gz" and
    # leave "sub-00001.nii", which never matches a subject identifier.
    if name.endswith('.nii.gz'):
        stem = name[:-len('.nii.gz')]
    else:
        stem = Path(name).stem
    if not stem.startswith('sub-'):
        return None
    return stem.split('_')[0]


def verify_dataset(preprocessed_path=None, rare_subjects=None):
    """Print a summary of the preprocessed dataset and its rare-subject coverage.

    Args:
        preprocessed_path: Dataset folder to inspect. Defaults to the
            module-level ``PREPROCESSED_PATH``.
        rare_subjects: Collection of rare subject identifiers to look for.
            Defaults to the module-level ``RARE_SUBJECTS``.

    Returns:
        False when the dataset folder does not exist, True otherwise. Missing
        plan files, configuration folders or rare subjects are only reported;
        they do not change the return value.
    """
    if preprocessed_path is None:
        preprocessed_path = PREPROCESSED_PATH
    if rare_subjects is None:
        rare_subjects = RARE_SUBJECTS
    rare_subjects = list(rare_subjects)
    root = Path(preprocessed_path)

    print("\n" + "="*70)
    print("DATASET VERIFICATION")
    print("="*70)

    if not os.path.exists(preprocessed_path):
        print(f"\n✗ Preprocessed data NOT found at: {preprocessed_path}")
        return False

    print(f"\n✓ Preprocessed data found at: {preprocessed_path}")

    # List nnUNet plans
    plans_files = list(root.glob("nnUNetPlans*.json"))
    if plans_files:
        print(f"  - Found {len(plans_files)} plan file(s)")
        for plan in plans_files:
            print(f"    • {plan.name}")

    # Check for preprocessed data folders
    print("\n✓ Checking preprocessed data structure:")
    for config in ['nnUNetPlans_2d', 'nnUNetPlans_3d_fullres']:
        config_dir = root / config
        if not config_dir.exists():
            continue

        # Same order as before: all .npz, then all .npy, then all .pkl files.
        data_files = [
            path
            for pattern in ("**/*.npz", "**/*.npy", "**/*.pkl")
            for path in config_dir.glob(pattern)
        ]
        print(f"  - {config}: {len(data_files)} preprocessed files")

        # Only the first 50 files are inspected, so this is a sample, not a
        # complete subject list.
        sampled = {
            subject_id
            for subject_id in map(_subject_id_from_filename, data_files[:50])
            if subject_id
        }
        if sampled:
            print(f"    • Found {len(sampled)} unique subjects (sampled)")
            rare_in_sample = [s for s in rare_subjects if s in sampled]
            if rare_in_sample:
                print(f"    • {len(rare_in_sample)} rare subjects found in sample")

    # Also check dataset.json for complete subject list
    dataset_json_path = root / "dataset.json"
    if dataset_json_path.exists():
        print("\n✓ Reading dataset.json for complete subject list...")
        with open(dataset_json_path, 'r') as f:
            dataset_info = json.load(f)

        if 'training' in dataset_info:
            subjects = set()
            for case in dataset_info['training']:
                # 'image' may be a single path or a list of paths; only the
                # first entry is used, and an empty list is treated as "no path".
                img_path = case.get('image', '')
                if isinstance(img_path, list):
                    img_path = img_path[0] if img_path else ''
                subject_id = _subject_id_from_filename(img_path)
                if subject_id:
                    subjects.add(subject_id)

            print(f"  - Total training subjects: {len(subjects)}")

            # Check rare subjects
            rare_found = [s for s in rare_subjects if s in subjects]
            print(f"  - Rare subjects found: {len(rare_found)}/{len(rare_subjects)}")

            if len(rare_found) < len(rare_subjects):
                missing = set(rare_subjects) - subjects
                print(f"\n⚠ Warning: {len(missing)} rare subjects not in dataset:")
                for m in sorted(missing)[:10]:
                    print(f"    • {m}")
                if len(missing) > 10:
                    print(f"    ... and {len(missing)-10} more")

    print("\n" + "="*70)
    return True

# Run the verification now; as the last expression of the cell, the returned
# True/False value is what the notebook displays.
verify_dataset()


In [None]:
# The Kaggle input folder is copied into /kaggle/working so that later cells
# can modify files inside the dataset folder.
writable_root = "/kaggle/working/nnUNet_preprocessed"
source = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"
dest = os.path.join(writable_root, "Dataset002_BonnFCD_FLAIR")

print("Copying preprocessed data to writable location...")
print(f"From: {source}")
print(f"To: {dest}")

# Make sure the parent of the destination exists before copying into it.
os.makedirs(os.path.dirname(dest), exist_ok=True)

# Recursive copy; dirs_exist_ok lets the cell be re-run over an existing copy.
shutil.copytree(source, dest, dirs_exist_ok=True)

print("✓ Copied successfully")

# Point nnU-Net at the writable copy from now on.
os.environ['nnUNet_preprocessed'] = writable_root

print(f"✓ Updated nnUNet_preprocessed path to: {os.environ['nnUNet_preprocessed']}")


In [None]:
import json
import math
import os

plans_path = "/kaggle/working/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR/nnUNetPlans.json"

# Read the current plans
with open(plans_path, "r") as f:
    plans = json.load(f)

# NOTE(review): this cell only writes a "data_augmentation" section into the
# plans file. Stock nnU-Net v2 trainers appear to configure augmentation in
# trainer code rather than from nnUNetPlans.json - confirm that
# nnUNetTrainerOversampling actually reads these keys, otherwise the values
# below have no effect on training.

# ----------------------------------------------------
# Create the augmentation sections when they are absent
# ----------------------------------------------------
augmentation = plans.setdefault("data_augmentation", {})
spatial = augmentation.setdefault("spatial", {})
intensity = augmentation.setdefault("intensity", {})

# ----------------------------------------------------
# 🔥 Doubled spatial augmentations
# ----------------------------------------------------
# Rotation: ±60 degrees on every axis, stored in radians (stated default: 30)
max_rotation = math.radians(60)
spatial["rotation"] = {axis: max_rotation for axis in ("x", "y", "z")}

# Scale: range [0.70, 1.50] (stated default: [0.85, 1.25]); the deviation
# from 1.0 is doubled (15% -> 30% down, 25% -> 50% up)
spatial["scale_range"] = [0.70, 1.50]

# Elastic deformation stays disabled; a dictionary would have to be stored
# here to switch it on.
spatial["elastic_deform"] = None

# ----------------------------------------------------
# 🔥 Doubled intensity augmentations
# ----------------------------------------------------
intensity.update(
    # Brightness: range [0.5, 1.5] (stated default: [0.75, 1.25])
    brightness=[0.5, 1.5],
    # Contrast: range [0.5, 1.5] (stated default: [0.75, 1.25])
    contrast=[0.5, 1.5],
    # Gaussian noise: 0.2 (stated default: 0.1). The key name says "std",
    # so this is presumably a standard deviation, not a variance - confirm.
    gaussian_noise_std=0.2,
    # Gamma: range [0.4, 2.0] (stated default: [0.7, 1.5]); deviation from
    # 1.0 doubled
    gamma_range=[0.4, 2.0],
)

# Write the modified plans back to the same file
with open(plans_path, "w") as f:
    json.dump(plans, f, indent=4)

for message in (
    "🔥 Augmentation parameters DOUBLED successfully!",
    "- Rotation: ±60°",
    "- Scale: 0.70 - 1.50",
    "- Noise/Contrast/Gamma: Deviation doubled",
):
    print(message)


In [None]:
import json
import os

# ----------------------------------------------------
# 🧪 Custom 5-Fold Cross-Validation Splits
# ----------------------------------------------------
# Load the split definition shipped with the custom modules
with open(SPLITS_PATH, "r") as f:
    custom_splits = json.load(f)

# Target location: splits_final.json inside the preprocessed folder of this
# dataset (the writable copy under /kaggle/working).
preprocessed_dir = "/kaggle/working/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"
splits_file = os.path.join(preprocessed_dir, "splits_final.json")

# Write the splits to the file (re-serialised with 4-space indentation)
os.makedirs(preprocessed_dir, exist_ok=True)
with open(splits_file, "w") as f:
    json.dump(custom_splits, f, indent=4)

print(f"✅ Custom splits enforced successfully!")
print(f"  - Loaded from: {SPLITS_PATH}")
print(f"  - Written to: {splits_file}")
print(f"  - Total folds: {len(custom_splits)}")

# Surface a fold/split mismatch now instead of after training has started.
# NOTE(review): nnU-Net is understood to fall back to a random split when the
# requested fold is not present in splits_final.json - confirm for the
# installed version. The check is a warning only, so the cell never aborts.
if isinstance(FOLD, int) and not 0 <= FOLD < len(custom_splits):
    print(f"⚠ Warning: FOLD={FOLD} is not defined in {SPLITS_PATH} "
          f"(it contains {len(custom_splits)} fold(s))")


In [None]:
# Informational output only: no preprocessing is run in this notebook.
banner = "=" * 70
print(f"\n{banner}")
print("✓ Dataset is already preprocessed - skipping preprocessing step")
print(banner)
for note in (
    "\nOversampling will be applied during training via custom DataLoader",
    "This modifies sampling probabilities AFTER preprocessing as intended",
):
    print(note)
print(banner)

In [None]:
# CELL: force nnUNet_compile to 'false' for P100 GPU compatibility.
# Unlike the default set in the configuration cell, this assignment is
# unconditional and overrides any value already present in the environment.
import os

os.environ.update(nnUNet_compile='false')

for message in (
    "✓ Disabled torch.compile for GPU compatibility",
    "  (P100 GPU doesn't support Triton compiler)",
):
    print(message)

# 5. Training Execution
Final configuration checks and executing the nnU-Net training command.

In [None]:
print("\n" + "="*70)
print("STARTING TRAINING WITH CUSTOM OVERSAMPLING")
print("="*70)
print(f"\nConfiguration:")
print(f"  - Dataset: 002 (BonnFCD_FLAIR)")
print(f"  - Configuration: 3d_fullres")
print(f"  - Fold: {FOLD}")
print(f"  - Trainer: nnUNetTrainerOversampling")
print(f"  - Rare subjects: {len(RARE_SUBJECTS)}")
print(f"  - Oversample factor: {OVERSAMPLE_FACTOR}x")
print("="*70 + "\n")

!nnUNetv2_train 002 3d_fullres {FOLD} -tr nnUNetTrainerOversampling -p nnUNetPlans --npz