# 1. Environment Setup
Installing dependencies and importing required libraries.

In [None]:
# Install nnU-Net v2.
# NOTE(review): the install is unpinned, while the custom trainer generated later
# states it is written for nnUNet v2.6+; consider pinning the version that was tested.
!pip install nnunetv2

import os
import json
import shutil
from pathlib import Path
import numpy as np
from typing import List, Dict

print("✓ Dependencies installed and imported")


# 2. Experiment Configuration
Setting up paths, environment variables, and defining rare subjects.

In [None]:
# ----------------------------------------
# Training configuration - Fold Selection
# ----------------------------------------

# Fold to train. The custom splits file written later defines folds 0-4;
# nnUNetv2_train additionally accepts "all".
FOLD = 0  # Select fold (0, 1, 2, 3, 4)
_VALID_FOLDS = (0, 1, 2, 3, 4, "all")
if FOLD not in _VALID_FOLDS:
    raise ValueError(f"FOLD must be one of {_VALID_FOLDS}, got {FOLD!r}")

# Checkpoint FILE (e.g. checkpoint_final.pth) or fold DIRECTORY from a previous run.
# Leave the placeholder unchanged to train from scratch (the path will not exist).
CHECKPOINT_PATH = "Enter your checkpoint path here"

# Wall-clock training budget in minutes (11 h 30 min), enforced by the custom
# trainer's end-of-epoch time check; presumably chosen to stay under the Kaggle
# session limit.
TRAINING_TIME_MINUTES = 11 * 60 + 30


PREPROCESSED_PATH = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"

# nnUNet environment variables
os.environ['nnUNet_raw'] = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_raw_data_base/nnUNet_raw"
os.environ['nnUNet_preprocessed'] = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_preprocessed"
os.environ['nnUNet_results'] = "/kaggle/working/nnUNet_results"
os.environ['nnUNet_compile'] = 'false'  # Disable for P100 GPU compatibility

# Subjects to oversample (rare abnormalities - non-cortical thickening)
RARE_SUBJECTS = [
    'sub-00001', 'sub-00003', 'sub-00014', 'sub-00015', 'sub-00016', 'sub-00018',
    'sub-00024', 'sub-00027', 'sub-00033', 'sub-00038', 'sub-00040', 'sub-00044',
    'sub-00050', 'sub-00053', 'sub-00055', 'sub-00058', 'sub-00060', 'sub-00063',
    'sub-00065', 'sub-00073', 'sub-00077', 'sub-00078', 'sub-00080', 'sub-00081',
    'sub-00083', 'sub-00087', 'sub-00089', 'sub-00097', 'sub-00098', 'sub-00101',
    'sub-00105', 'sub-00109', 'sub-00112', 'sub-00115', 'sub-00116', 'sub-00122',
    'sub-00123', 'sub-00126', 'sub-00130', 'sub-00132', 'sub-00133', 'sub-00138',
    'sub-00146',
]

# Sampling weight of a rare subject relative to any other subject.
OVERSAMPLE_FACTOR = 3.0

print("✓ Configuration set")
print(f"  - Number of rare subjects: {len(RARE_SUBJECTS)}")
print(f"  - Oversample factor: {OVERSAMPLE_FACTOR}x")
print(f"  - Checkpoint path: {CHECKPOINT_PATH}")


# 3. Custom nnU-Net Components
Generating the source files for a custom DataLoader and Trainer that add rare-subject oversampling and a wall-clock training time limit.

In [None]:
def create_custom_dataloader_file(output_dir="/kaggle/working/custom_nnunet"):
    """
    Write ``custom_dataloader.py`` defining ``CustomOversamplingDataLoader``.

    The generated class subclasses nnU-Net's ``nnUNetDataLoader`` and replaces
    its sampling probabilities so that cases whose subject ID is listed in
    ``rare_subjects`` are drawn ``oversample_factor`` times as often as the rest.

    Args:
        output_dir: Directory to write the module into. The default is the
            directory the generated trainer prepends to ``sys.path``.

    Returns:
        Path (str) of the written module file.
    """
    custom_loader_code = '''
from nnunetv2.training.dataloading.data_loader import nnUNetDataLoader
import numpy as np


class CustomOversamplingDataLoader(nnUNetDataLoader):
    """nnUNetDataLoader that samples rare subjects ``oversample_factor`` times as often."""

    def __init__(self, data, batch_size, patch_size, final_patch_size=None,
                 label_manager=None, oversample_foreground_percent=0.33,
                 sampling_probabilities=None, *args,
                 rare_subjects=None, oversample_factor=3.0, **kwargs):
        self.rare_subjects = rare_subjects if rare_subjects else []
        self.oversample_factor = oversample_factor

        # Forward every remaining argument unchanged. The parent's later
        # parameters differ between nnU-Net versions (NOTE(review): in v2.6 they
        # appear to be pad_sides, probabilistic_oversampling, transforms), so
        # hard-coding them positionally silently mis-assigns values and can drop
        # the augmentation transforms.
        super().__init__(data, batch_size, patch_size, final_patch_size,
                         label_manager, oversample_foreground_percent,
                         sampling_probabilities, *args, **kwargs)

        if self.rare_subjects:
            self._modify_sampling_probabilities()

    def _get_case_ids(self):
        """Return case identifiers in the loader's sampling order, or None if unknown."""
        # Prefer the loader's own index list: batchgenerators' DataLoader draws
        # batches from self.indices, so the weights must line up with it.
        indices = getattr(self, 'indices', None)
        if indices is not None:
            return list(indices)

        data = getattr(self, '_data', None)
        if data is None:
            return None
        for attr in ('identifiers', 'indices'):
            ids = getattr(data, attr, None)
            if ids is not None:
                return list(ids)
        keys = getattr(data, 'keys', None)
        if callable(keys):
            try:
                return list(keys())
            except (TypeError, AttributeError) as err:
                print(f"  ⚠ Warning: dataset.keys() failed ({err!r})")
        return None

    def _modify_sampling_probabilities(self):
        """Up-weight rare subjects in self.sampling_probabilities (normalised to sum to 1)."""
        case_ids = self._get_case_ids()
        if case_ids is None:
            print("  ⚠ Warning: Could not extract case IDs from dataset, oversampling disabled")
            return
        num_cases = len(case_ids)
        if num_cases == 0:
            print("  ⚠ Warning: Dataset is empty, oversampling disabled")
            return

        # Refine caller-supplied probabilities instead of discarding them.
        existing = getattr(self, 'sampling_probabilities', None)
        if existing is not None and len(existing) == num_cases:
            weights = np.asarray(existing, dtype=float).copy()
        else:
            weights = np.ones(num_cases)

        rare = set(self.rare_subjects)
        # Case IDs look like 'sub-00001' or 'sub-00001_0000'; match on the subject part.
        is_rare = np.array([str(case_id).split('_')[0] in rare for case_id in case_ids], dtype=bool)
        weights[is_rare] *= self.oversample_factor

        total = weights.sum()
        if not total > 0:
            print("  ⚠ Warning: Sampling weights sum to zero, oversampling disabled")
            return
        self.sampling_probabilities = weights / total

        rare_count = int(is_rare.sum())
        print(f"  ✓ Modified sampling: {rare_count}/{num_cases} cases are rare subjects ({self.oversample_factor}x weight)")
'''

    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, "custom_dataloader.py")
    # Explicit UTF-8: the generated code contains non-ASCII status symbols.
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(custom_loader_code)

    print("✓ Custom DataLoader file created")
    return file_path

# Write custom_dataloader.py now; the trainer module generated in the next cell imports it.
create_custom_dataloader_file()


In [None]:
import os
from pathlib import Path

def build_trainer_code(training_time_minutes, rare_subjects, oversample_factor,
                       custom_module_dir="/kaggle/working/custom_nnunet"):
    """
    Return the source code of the ``nnUNetTrainerOversampling`` module.

    Args:
        training_time_minutes: Wall-clock budget after which training stops.
        rare_subjects: Subject IDs whose cases are oversampled.
        oversample_factor: Sampling weight of rare subjects relative to others.
        custom_module_dir: Directory holding ``custom_dataloader.py``; the
            generated module prepends it to ``sys.path``.
    """
    # This is an f-string: literal braces in the generated code are doubled.
    return f'''
import sys
import time
import torch
sys.path.insert(0, {custom_module_dir!r})

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer
from custom_dataloader import CustomOversamplingDataLoader

# Module whose global name ``nnUNetDataLoader`` nnUNetTrainer's methods resolve at call time.
_BASE_TRAINER_MODULE = sys.modules[nnUNetTrainer.__module__]


class nnUNetTrainerOversampling(nnUNetTrainer):
    """nnUNetTrainer with rare-subject oversampling and a wall-clock time limit."""

    # Baked in by create_custom_trainer() from the notebook configuration.
    rare_subjects = {list(rare_subjects)!r}
    oversample_factor = {oversample_factor!r}

    def __init__(self, plans: dict, configuration: str, fold: int, dataset_json: dict,
                 device: torch.device = torch.device('cuda')):
        # nnUNet v2.6+: unpack_dataset is no longer an __init__ argument.
        super().__init__(plans, configuration, fold, dataset_json, device=device)

        # Checkpoint interval in epochs (attribute used by nnUNetTrainer's epoch-end checkpointing).
        self.save_every = 20

        # Stop training once {training_time_minutes!r} minutes of wall-clock time have passed.
        self.max_time_seconds = {training_time_minutes!r} * 60
        self.start_time = time.monotonic()

    def get_tr_and_val_datasets(self):
        """Remember the training dataset so get_dataloaders() can recognise its loader."""
        dataset_tr, dataset_val = super().get_tr_and_val_datasets()
        self._oversampling_dataset_tr = dataset_tr
        return dataset_tr, dataset_val

    def get_dataloaders(self):
        """Build nnU-Net's dataloaders, using CustomOversamplingDataLoader for training.

        NOTE(review): recent nnU-Net versions (v2.5+/v2.6) construct both loaders
        inline in nnUNetTrainer.get_dataloaders() through the module-level name
        nnUNetDataLoader and do not call a get_plain_dataloaders() hook, so
        overriding that hook has no effect. Here the name is rebound for the
        duration of the call and the loader built for this fold's training
        dataset is replaced by the oversampling loader.
        """
        original_loader_cls = getattr(_BASE_TRAINER_MODULE, 'nnUNetDataLoader', None)
        if original_loader_cls is None:
            self.print_to_log_file(
                "⚠ WARNING: nnUNetDataLoader not found in the base trainer module - "
                "rare-subject oversampling is NOT active.")
            return super().get_dataloaders()

        trainer = self
        replaced = []

        def _loader_factory(*args, **kwargs):
            data = args[0] if args else kwargs.get('data')
            if data is not None and data is getattr(trainer, '_oversampling_dataset_tr', None):
                replaced.append(True)
                return CustomOversamplingDataLoader(
                    *args,
                    rare_subjects=trainer.rare_subjects,
                    oversample_factor=trainer.oversample_factor,
                    **kwargs)
            return original_loader_cls(*args, **kwargs)

        _BASE_TRAINER_MODULE.nnUNetDataLoader = _loader_factory
        try:
            loaders = super().get_dataloaders()
        finally:
            # Always restore the real class, even if loader construction failed.
            _BASE_TRAINER_MODULE.nnUNetDataLoader = original_loader_cls

        if not replaced:
            self.print_to_log_file(
                "⚠ WARNING: training DataLoader was not replaced - "
                "rare-subject oversampling is NOT active.")
        return loaders

    def on_epoch_end(self):
        """Run nnU-Net's epoch-end logic, then stop cleanly if the time budget is spent."""
        super().on_epoch_end()

        elapsed = time.monotonic() - self.start_time
        if elapsed > self.max_time_seconds:
            self.print_to_log_file(f"\\n⏰ TIME LIMIT REACHED ({{elapsed/3600:.2f}} hours).")
            self.print_to_log_file("Stopping training gracefully to save checkpoints safely.")

            # on_train_end() writes checkpoint_final.pth, which `nnUNetv2_train --c`
            # can resume from. No extra manual save here: super().on_epoch_end() has
            # already advanced current_epoch, so saving now would record an epoch
            # counter one too high (on_train_end compensates for that).
            # NOTE(review): verify against the installed nnU-Net version.
            self.on_train_end()
            sys.exit(0)  # Exit the script successfully so Kaggle saves the output
'''


def create_custom_trainer(training_time_minutes=None, rare_subjects=None, oversample_factor=None):
    """
    Generate ``nnUNetTrainerOversampling`` and install it into nnU-Net.

    The trainer oversamples ``rare_subjects`` by ``oversample_factor`` and stops
    after ``training_time_minutes`` of wall-clock time, writing its final
    checkpoint first. Compatible with nnUNet v2.6+ (``unpack_dataset`` removed).

    Each argument defaults to the notebook configuration (``TRAINING_TIME_MINUTES``,
    ``RARE_SUBJECTS``, ``OVERSAMPLE_FACTOR``). The trainer therefore always uses
    the values printed by the configuration cell rather than a hard-coded copy.

    The module is written into the installed ``nnunetv2`` trainer package
    (presumably so ``nnUNetv2_train -tr`` can locate it by class name). A copy
    is also kept in ``/kaggle/working/custom_nnunet``.
    """
    if training_time_minutes is None:
        training_time_minutes = TRAINING_TIME_MINUTES
    if rare_subjects is None:
        rare_subjects = RARE_SUBJECTS
    if oversample_factor is None:
        oversample_factor = OVERSAMPLE_FACTOR

    trainer_code = build_trainer_code(training_time_minutes, rare_subjects, oversample_factor)

    # Install into nnU-Net's own trainer package.
    import nnunetv2
    nnunet_trainers_dir = Path(nnunetv2.__file__).parent / "training" / "nnUNetTrainer"
    trainer_file = nnunet_trainers_dir / "nnUNetTrainerOversampling.py"
    trainer_file.write_text(trainer_code, encoding="utf-8")

    print("✔️ Custom Trainer updated (Fixed for nnUNet v2.6+)")

    # Also keep a copy next to the custom dataloader.
    os.makedirs("/kaggle/working/custom_nnunet", exist_ok=True)
    Path("/kaggle/working/custom_nnunet/nnUNetTrainerOversampling.py").write_text(trainer_code, encoding="utf-8")

# Generate nnUNetTrainerOversampling and write it into the installed nnunetv2 package
# (plus a copy under /kaggle/working/custom_nnunet).
create_custom_trainer()

# 4. Data Preparation
Copying the preprocessed dataset from the read-only input directory to the writable working directory, since later cells modify the plans and splits files.

In [None]:
# /kaggle/input is read-only, and later cells rewrite nnUNetPlans.json and
# splits_final.json, so training runs from a writable copy.
WORKING_PREPROCESSED_ROOT = "/kaggle/working/nnUNet_preprocessed"
source = PREPROCESSED_PATH  # single source of truth from the configuration cell
dest = os.path.join(WORKING_PREPROCESSED_ROOT, Path(source).name)

print("Copying preprocessed data to writable location...")
print(f"From: {source}")
print(f"To: {dest}")

Path(dest).parent.mkdir(parents=True, exist_ok=True)
shutil.copytree(source, dest, dirs_exist_ok=True)

print("✓ Copied successfully")
# Point nnU-Net at the writable copy (same root the copy was written under).
os.environ['nnUNet_preprocessed'] = WORKING_PREPROCESSED_ROOT
print("✓ Updated nnUNet_preprocessed path")


# 5. Advanced Customization (Augmentation & Splits)
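Writing doubled data-augmentation parameters into `nnUNetPlans.json` and enforcing fixed cross-validation splits. Note: the custom trainer defined above does not read the `data_augmentation` plan keys, so they do not change training unless a trainer is written to consume them.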

In [None]:
import json
import math
import os

plans_path = "/kaggle/working/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR/nnUNetPlans.json"

# Load plans
with open(plans_path, "r", encoding="utf-8") as f:
    plans = json.load(f)

# NOTE(review): nnUNetTrainerOversampling (generated earlier in this notebook)
# never reads plans["data_augmentation"]. Stock nnU-Net v2 also appears to build
# its augmentation pipeline in trainer code rather than from the plans file.
# The values below are therefore only recorded here. They affect training only
# once a trainer is written to consume them.

# ----------------------------------------------------
# Ensure the augmentation sections exist
# ----------------------------------------------------
augmentation = plans.setdefault("data_augmentation", {})
spatial = augmentation.setdefault("spatial", {})
intensity = augmentation.setdefault("intensity", {})

# ----------------------------------------------------
# 🔥 Doubled spatial augmentations
# ----------------------------------------------------
# Rotation: ±60 degrees (Default was 30), stored in radians
spatial["rotation"] = {axis: math.radians(60) for axis in ("x", "y", "z")}

# Scale: range [0.70, 1.50] (Default was [0.85, 1.25]);
# deviation from 1.0 doubled (15% -> 30% down, 25% -> 50% up)
spatial["scale_range"] = [0.70, 1.50]

# Elastic deformation stays disabled (written as JSON null); store a settings
# dictionary here instead to enable it.
spatial["elastic_deform"] = None

# ----------------------------------------------------
# 🔥 Doubled intensity augmentations
# ----------------------------------------------------
# Brightness: range [0.5, 1.5] (Default was [0.75, 1.25])
intensity["brightness"] = [0.5, 1.5]

# Contrast: range [0.5, 1.5] (Default was [0.75, 1.25])
intensity["contrast"] = [0.5, 1.5]

# Gaussian noise: 0.2 (Default was 0.1). The key name says "std"; make sure the
# consumer interprets it as a standard deviation rather than a variance.
intensity["gaussian_noise_std"] = 0.2

# Gamma: range [0.4, 2.0] (Default was [0.7, 1.5]); deviation from 1.0 doubled
intensity["gamma_range"] = [0.4, 2.0]

# Save back
with open(plans_path, "w", encoding="utf-8") as f:
    json.dump(plans, f, indent=4)

print("🔥 Augmentation parameters DOUBLED successfully!")
print("- Rotation: ±60°")
print("- Scale: 0.70 - 1.50")
print("- Noise/Contrast/Gamma: Deviation doubled")
print("⚠ Note: nnUNetTrainerOversampling does not read plans['data_augmentation']; "
      "these values do not change training by themselves.")


In [None]:
import json
import os
from collections import Counter

# ----------------------------------------------------
# 🧪 Custom 5-Fold Cross-Validation Splits
# ----------------------------------------------------
# Only each fold's VALIDATION subjects are listed. A fold's training set is
# every other subject in sorted order. Train and val therefore cannot overlap,
# and the union of the validation folds is the complete subject list.
VAL_FOLDS = [
    [  # fold 0
        "sub-00018", "sub-00065", "sub-00068", "sub-00081", "sub-00087", "sub-00105",
        "sub-00109", "sub-00115", "sub-00116", "sub-00126", "sub-00131", "sub-00138",
    ],
    [  # fold 1
        "sub-00010", "sub-00027", "sub-00044", "sub-00060", "sub-00063", "sub-00072",
        "sub-00073", "sub-00076", "sub-00101", "sub-00122", "sub-00123", "sub-00132",
    ],
    [  # fold 2
        "sub-00001", "sub-00014", "sub-00015", "sub-00016", "sub-00053", "sub-00078",
        "sub-00083", "sub-00089", "sub-00091", "sub-00120", "sub-00130",
    ],
    [  # fold 3
        "sub-00024", "sub-00033", "sub-00038", "sub-00040", "sub-00047", "sub-00055",
        "sub-00058", "sub-00059", "sub-00071", "sub-00077", "sub-00133",
    ],
    [  # fold 4
        "sub-00003", "sub-00050", "sub-00080", "sub-00090", "sub-00097", "sub-00098",
        "sub-00112", "sub-00139", "sub-00140", "sub-00141", "sub-00146",
    ],
]

# Every subject must be validated in exactly one fold.
_val_counts = Counter(subject for fold_val in VAL_FOLDS for subject in fold_val)
_duplicated = sorted(subject for subject, count in _val_counts.items() if count > 1)
if _duplicated:
    raise ValueError(f"Subjects listed in more than one validation fold: {_duplicated}")

# Zero-padded IDs, so lexicographic order equals numeric order.
ALL_SUBJECTS = sorted(_val_counts)

custom_splits = []
for fold_val in VAL_FOLDS:
    val_set = set(fold_val)
    custom_splits.append({
        "train": [subject for subject in ALL_SUBJECTS if subject not in val_set],
        "val": list(fold_val),
    })

preprocessed_dir = "/kaggle/working/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"
splits_file = os.path.join(preprocessed_dir, "splits_final.json")

os.makedirs(preprocessed_dir, exist_ok=True)
with open(splits_file, "w") as f:
    json.dump(custom_splits, f, indent=4)

print("✅ Custom splits enforced successfully!")
print(f"  - Location: {splits_file}")


# 6. Checkpoint Restoration
Copying a checkpoint from a previous run into the results folder so that training can resume.

In [None]:
print("\n" + "=" * 70)
print("COPYING PREVIOUS CHECKPOINT")
print("=" * 70)

# Create results directory structure
results_dir = Path(f"/kaggle/working/nnUNet_results/Dataset002_BonnFCD_FLAIR/nnUNetTrainerOversampling__nnUNetPlans__3d_fullres/fold_{FOLD}")
results_dir.mkdir(parents=True, exist_ok=True)

# A blank CHECKPOINT_PATH must count as "not provided". Path("") is the current
# directory, which exists and is a directory, so it would otherwise be copied
# wholesale (including nnUNet_results itself) into results_dir.
checkpoint_source = Path(CHECKPOINT_PATH) if str(CHECKPOINT_PATH).strip() else None

if checkpoint_source is not None and checkpoint_source.exists():
    print(f"\n✓ Found checkpoint source: {checkpoint_source}")

    if checkpoint_source.is_file():
        # A single checkpoint file (e.g. checkpoint_final.pth) is stored under the
        # name nnU-Net resumes from.
        dest_name = "checkpoint_latest.pth"
        shutil.copy2(checkpoint_source, results_dir / dest_name)
        print(f"  ✓ Copied file: {checkpoint_source.name} -> {dest_name}")

    elif checkpoint_source.is_dir():
        # A directory (e.g. a previous run's fold folder): copy its contents as-is.
        for item in checkpoint_source.iterdir():
            dest_item = results_dir / item.name
            if item.is_file():
                shutil.copy2(item, dest_item)
                print(f"  ✓ Copied: {item.name}")
            elif item.is_dir():
                shutil.copytree(item, dest_item, dirs_exist_ok=True)
                print(f"  ✓ Copied directory: {item.name}")

    # Check what we have now
    if any((results_dir / name).exists() for name in ("checkpoint_latest.pth", "checkpoint_final.pth")):
        print("\n✓ Checkpoint restoration successful.")
    else:
        print("\n⚠ Warning: No valid checkpoint file (latest/final) found in destination!")

else:
    print(f"\n✗ Checkpoint path not found: {CHECKPOINT_PATH}")
    print("Training will start from scratch")

print("=" * 70)

# 7. Training Execution
Launching the nnU-Net training command with the custom trainer and fold configuration.

In [None]:
print("\n" + "="*70)
print("CONTINUING TRAINING WITH CUSTOM OVERSAMPLING")
print("="*70)
print(f"\nConfiguration:")
print(f"  - Dataset: 002 (BonnFCD_FLAIR)")
print(f"  - Configuration: 3d_fullres")
print(f"  - Fold: {FOLD}")
print(f"  - Trainer: nnUNetTrainerOversampling")
print(f"  - Rare subjects: {len(RARE_SUBJECTS)}")
print(f"  - Oversample factor: {OVERSAMPLE_FACTOR}x")
# NOTE(review): printed unconditionally, even when the restoration cell found no checkpoint.
print(f"  - Continuing from previous checkpoint")
print("="*70 + "\n")

# `!` runs a shell command; IPython substitutes {FOLD} from the notebook namespace.
# -tr selects the custom trainer and -p the plans identifier. --c asks nnU-Net to
# continue from a checkpoint in the fold's output folder, if one exists.
# NOTE(review): --npz presumably exports validation softmax predictions; confirm with `nnUNetv2_train -h`.
# NOTE(review): when the trainer's time limit triggers it calls sys.exit(0), so any
# post-training steps of this command are presumably skipped in that case.
!nnUNetv2_train 002 3d_fullres {FOLD} -tr nnUNetTrainerOversampling -p nnUNetPlans --npz --c