# 1. Environment Setup
Configure the environment and import the required libraries (uncomment the `pip install` line below if nnU-Net v2 is not yet installed).

In [None]:
import sys
import os

# Make ../src importable so the project-level `config` module can be found.
# Only append once, so re-running this cell does not keep growing sys.path
# (same guard style as the custom-trainer cell).
_src_dir = os.path.abspath(os.path.join(os.getcwd(), "..", "src"))
if _src_dir not in sys.path:
    sys.path.append(_src_dir)

# `config` provides setup_env() and add_custom_modules_to_path(). Later cells also
# use names such as SPLITS_FILE_PATH and nnUNet_preprocessed that are not defined in
# this notebook — presumably they come from this star import as well.
from config import *
setup_env()
add_custom_modules_to_path()


In [None]:
# !pip install nnunetv2  # uncomment to install nnU-Net v2 into this kernel

# Standard library
import json
import os
import shutil
from pathlib import Path
from typing import Dict, List

# Third-party
import numpy as np

print("✓ Dependencies installed and imported")


# 2. Experiment Configuration
Select the fold, checkpoint path, training-time budget, and preprocessed-data path, and define the rare subjects to oversample.

In [None]:
# ----------------------------------------
# Training configuration - fold selection
# ----------------------------------------

FOLD = 0  # Cross-validation fold to train: one of 0, 1, 2, 3, 4
CHECKPOINT_PATH = "Enter your checkpoint path here"  # Replace with the checkpoint file or folder to resume from

# 11 hours 30 minutes, expressed in minutes.
TRAINING_TIME_MINUTES = 11 * 60 + 30

# Source copy of the preprocessed dataset (it is copied to a writable location later).
PREPROCESSED_PATH = "../data/preprocessed-bonnfcd-flair/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"

# Subjects to oversample (rare abnormalities - non-cortical thickening).
# Stored as numeric subject IDs and expanded to the 'sub-XXXXX' naming scheme.
_RARE_SUBJECT_NUMBERS = (
    1, 3, 14, 15, 16, 18,
    24, 27, 33, 38, 40, 44,
    50, 53, 55, 58, 60, 63,
    65, 73, 77, 78, 80, 81,
    83, 87, 89, 97, 98, 101,
    105, 109, 112, 115, 116, 122,
    123, 126, 130, 132, 133, 138,
    146,
)
RARE_SUBJECTS = [f"sub-{number:05d}" for number in _RARE_SUBJECT_NUMBERS]

# Oversampling factor for rare subjects.
OVERSAMPLE_FACTOR = 3.0

print("✓ Configuration set")
print(f"  - Number of rare subjects: {len(RARE_SUBJECTS)}")
print(f"  - Oversample factor: {OVERSAMPLE_FACTOR}x")
print(f"  - Checkpoint path: {CHECKPOINT_PATH}")


# 3. Custom nnU-Net Components
Install the custom trainer into the nnU-Net package and configure import paths so that the custom DataLoader and Trainer (which handle oversampling and safe training limits) can be found.

### Custom DataLoader
The `CustomOversamplingDataLoader` has been moved to `custom_nnunet/custom_dataloader.py`.
It is imported automatically by the trainer.

In [None]:
import nnunetv2
from pathlib import Path
import shutil
import sys
import os

# Resolve project paths once. config.py lives in ../src (see the first cell);
# the custom trainer and custom_dataloader.py live in ../src/custom_nnunet.
current_dir = os.getcwd()
src_dir = os.path.abspath(os.path.join(current_dir, '..', 'src'))
custom_modules_dir = os.path.join(src_dir, 'custom_nnunet')
custom_trainer_path = os.path.join(custom_modules_dir, 'nnUNetTrainerOversampling.py')

# Place the trainer inside nnunetv2's trainer package so nnUNetv2_train can find it by name.
nnunet_trainers_dir = Path(nnunetv2.__file__).parent / "training" / "nnUNetTrainer"
target_path = nnunet_trainers_dir / "nnUNetTrainerOversampling.py"

print(f"Copying custom trainer to: {target_path}")
shutil.copy(custom_trainer_path, target_path)

# The copied trainer needs config.py (in ../src) and custom_dataloader.py (in
# ../src/custom_nnunet). The notebook directory is kept for backward compatibility.
extra_paths = [current_dir, src_dir, custom_modules_dir]
for path_entry in extra_paths:
    if path_entry not in sys.path:
        sys.path.append(path_entry)

# Expose to environment for subprocesses (like nnUNetv2_train). Use os.pathsep for
# portability, drop empty entries (an empty entry means "current directory"), and
# de-duplicate so re-running this cell does not keep growing PYTHONPATH.
existing_entries = [p for p in os.environ.get('PYTHONPATH', '').split(os.pathsep) if p]
os.environ['PYTHONPATH'] = os.pathsep.join(dict.fromkeys(extra_paths + existing_entries))
print("\u2714\ufe0f Environment configured for custom modules")


# 4. Data Preparation
Copy the preprocessed dataset to a writable location (`../data/nnUNet_preprocessed`) so the plans and splits can be modified.

In [None]:
# Copy the preprocessed dataset (PREPROCESSED_PATH, set in the configuration cell)
# to a writable location; the augmentation cell edits nnUNetPlans.json in this copy.
source = PREPROCESSED_PATH
dest = "../data/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"

print("Copying preprocessed data to writable location...")
print(f"From: {source}")
print(f"To: {dest}")

Path(dest).parent.mkdir(parents=True, exist_ok=True)
# dirs_exist_ok=True makes the cell re-runnable; files already in dest are overwritten.
shutil.copytree(source, dest, dirs_exist_ok=True)

print("✓ Copied successfully")

# This cell only copies files; it does not change where nnU-Net looks for data.
# nnU-Net reads the copy only if the nnUNet_preprocessed environment variable
# points at dest's parent, so check that instead of claiming it was updated.
expected_root = os.path.abspath(Path(dest).parent)
configured_root = os.environ.get("nnUNet_preprocessed")
if configured_root and os.path.abspath(configured_root) == expected_root:
    print(f"✓ nnUNet_preprocessed points at the copied data: {configured_root}")
else:
    print(f"⚠ Warning: nnUNet_preprocessed is {configured_root!r}; "
          f"set it to {expected_root} so nnU-Net uses the copied data")


# 5. Advanced Customization (Augmentation & Splits)
Applying doubled data augmentation parameters and enforcing fixed cross-validation splits.

In [None]:
import json
import math
import os

plans_path = "../data/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR/nnUNetPlans.json"

# NOTE(review): stock nnU-Net v2 configures augmentation inside the trainer
# (get_training_transforms / configure_rotation_dummyDA_mirroring_and_inital_patch_size)
# and does not read a "data_augmentation" section from the plans file. These values
# only take effect if nnUNetTrainerOversampling reads them — confirm in src/custom_nnunet.

with open(plans_path, "r") as plans_file:
    plans = json.load(plans_file)

# Create the nested augmentation sections when the plans file lacks them.
augmentation_cfg = plans.setdefault("data_augmentation", {})
spatial_cfg = augmentation_cfg.setdefault("spatial", {})
intensity_cfg = augmentation_cfg.setdefault("intensity", {})

# ---- Spatial augmentations (doubled) ----
max_rotation_rad = math.radians(60)  # ±60° about each axis (previously 30°)
spatial_cfg.update({
    "rotation": {axis: max_rotation_rad for axis in ("x", "y", "z")},
    # Twice the deviation from 1.0 of [0.85, 1.25]: 15% -> 30% down, 25% -> 50% up.
    "scale_range": [0.70, 1.50],
    # Elastic deformation stays disabled (written as JSON null); provide a dict to enable it.
    "elastic_deform": None,
})

# ---- Intensity augmentations (doubled) ----
intensity_cfg.update({
    "brightness": [0.5, 1.5],    # previously [0.75, 1.25]
    "contrast": [0.5, 1.5],      # previously [0.75, 1.25]
    # Previously documented as "variance 0.1", but the key is named *_std.
    # NOTE(review): confirm whether the trainer treats this as std or variance.
    "gaussian_noise_std": 0.2,
    "gamma_range": [0.4, 2.0],   # twice the deviation from 1.0 of [0.7, 1.5]
})

# Persist the modified plans in place.
with open(plans_path, "w") as plans_file:
    json.dump(plans, plans_file, indent=4)

print("🔥 Augmentation parameters DOUBLED successfully!")
print("- Rotation: ±60°")
print("- Scale: 0.70 - 1.50")
print("- Noise/Contrast/Gamma: Deviation doubled")


In [None]:
# Load the fixed cross-validation splits from JSON.
# SPLITS_FILE_PATH and nnUNet_preprocessed are not defined in this notebook;
# presumably they come from `from config import *` in the first cell.
import json
with open(SPLITS_FILE_PATH, 'r') as f:
    custom_splits = json.load(f)

print(f"\u2713 Loaded custom splits from: {SPLITS_FILE_PATH}")
print(f"  - Total folds: {len(custom_splits)}")

# nnU-Net v2 does not fail when the requested fold is missing from splits_final.json.
# Instead it logs a message and trains on a random (seeded) 80:20 split, which would
# silently bypass these fixed splits, so fail early. Non-integer folds such as 'all'
# are left to nnU-Net.
if isinstance(FOLD, int) and not 0 <= FOLD < len(custom_splits):
    raise ValueError(
        f"FOLD={FOLD} is out of range: {SPLITS_FILE_PATH} defines only "
        f"{len(custom_splits)} fold(s)"
    )

# Enforce these splits by saving them to the preprocessed directory
# (This is what nnU-Net will actually use).
# normpath removes a trailing separator. Without it, basename() would return ''
# and the dataset folder would be appended a second time.
preprocessed_dir = os.path.normpath(nnUNet_preprocessed)
if os.path.basename(preprocessed_dir) != 'Dataset002_BonnFCD_FLAIR':
    # Ensure we are pointing to the dataset folder
    preprocessed_dir = os.path.join(preprocessed_dir, 'Dataset002_BonnFCD_FLAIR')

splits_file = os.path.join(preprocessed_dir, "splits_final.json")
os.makedirs(preprocessed_dir, exist_ok=True)
with open(splits_file, "w") as f:
    json.dump(custom_splits, f, indent=4)

print(f"\u2705 Custom splits enforced successfully!")
print(f"  - Copied to: {splits_file}")


# 6. Checkpoint Restoration
Copy a previous checkpoint into the results folder so that training can resume from it.

In [None]:
print("\n" + "="*70)
print("COPYING PREVIOUS CHECKPOINT")
print("="*70)

# Results folder for this trainer / plans / configuration / fold (the layout nnUNetv2_train uses).
results_dir = Path(f"../data/nnUNet_results/Dataset002_BonnFCD_FLAIR/nnUNetTrainerOversampling__nnUNetPlans__3d_fullres/fold_{FOLD}")
results_dir.mkdir(parents=True, exist_ok=True)

# Treat an unset or blank CHECKPOINT_PATH as "no checkpoint". Without this guard,
# Path("") becomes Path("."), which exists and is a directory, so the whole working
# directory would be copied into the results folder.
checkpoint_source = (
    Path(CHECKPOINT_PATH) if CHECKPOINT_PATH and str(CHECKPOINT_PATH).strip() else None
)

if checkpoint_source is not None and checkpoint_source.exists():
    print(f"\n✓ Found checkpoint source: {checkpoint_source}")

    # CASE 1: user provided a specific FILE (e.g., checkpoint_final.pth)
    if checkpoint_source.is_file():
        # Rename it to what nnU-Net expects for resuming: 'checkpoint_latest.pth'
        dest_name = "checkpoint_latest.pth"
        dest_item = results_dir / dest_name
        shutil.copy2(checkpoint_source, dest_item)
        print(f"  ✓ Copied file: {checkpoint_source.name} -> {dest_name}")

        # nnU-Net v2's --c looks for checkpoint_final.pth before checkpoint_latest.pth,
        # so a leftover final checkpoint here would be resumed instead of the copied file.
        stale_final = results_dir / "checkpoint_final.pth"
        if stale_final.exists():
            print(f"  ⚠ Warning: {stale_final} already exists and will be loaded by --c "
                  f"instead of {dest_name}; remove it to resume from the copied file.")

    # CASE 2: user provided a DIRECTORY - copy its files and subfolders as-is
    elif checkpoint_source.is_dir():
        for item in checkpoint_source.iterdir():
            dest_item = results_dir / item.name
            if item.is_file():
                shutil.copy2(item, dest_item)
                print(f"  ✓ Copied: {item.name}")
            elif item.is_dir():
                shutil.copytree(item, dest_item, dirs_exist_ok=True)
                print(f"  ✓ Copied directory: {item.name}")

    # Check what we have now
    if (results_dir / "checkpoint_latest.pth").exists() or (results_dir / "checkpoint_final.pth").exists():
        print(f"\n✓ Checkpoint restoration successful.")
    else:
        print("\n⚠ Warning: No valid checkpoint file (latest/final) found in destination!")

else:
    print(f"\n✗ Checkpoint path not found: {CHECKPOINT_PATH}")
    print("Training will start from scratch")

print("="*70)

# 7. Training Execution
Launching the nnU-Net training command with the custom trainer and fold configuration.

In [None]:
# Print a run summary, then launch nnU-Net v2 training through an IPython shell
# escape (`!`); IPython substitutes {FOLD} with the notebook variable.
# NOTE(review): RARE_SUBJECTS and OVERSAMPLE_FACTOR are only printed here.
# nnUNetv2_train runs as a separate process and cannot see notebook variables, so
# the trainer presumably reads its own values (e.g. via config.py). Confirm they
# match what is shown. TRAINING_TIME_MINUTES is likewise not passed on this command line.
print("\n" + "="*70)
print("CONTINUING TRAINING WITH CUSTOM OVERSAMPLING")
print("="*70)
print(f"\nConfiguration:")
print(f"  - Dataset: 002 (BonnFCD_FLAIR)")
print(f"  - Configuration: 3d_fullres")
print(f"  - Fold: {FOLD}")
print(f"  - Trainer: nnUNetTrainerOversampling")
print(f"  - Rare subjects: {len(RARE_SUBJECTS)}")
print(f"  - Oversample factor: {OVERSAMPLE_FACTOR}x")
print(f"  - Continuing from previous checkpoint")  # NOTE(review): printed even if no checkpoint was restored
print("="*70 + "\n")

# Continue training - nnUNet will automatically detect and load the latest checkpoint
# --c  : resume from the fold's results folder. nnU-Net v2 tries checkpoint_final.pth,
#        then checkpoint_latest.pth, then checkpoint_best.pth, and starts a fresh run
#        (with a warning) if none exists.
# --npz: additionally save softmax predictions from the final validation as .npz files.
# -tr / -p select the custom trainer class and the plans identifier.
!nnUNetv2_train 002 3d_fullres {FOLD} -tr nnUNetTrainerOversampling -p nnUNetPlans --npz --c