# Proposed Model Training (Oversampling + Extensive Augmentation)

This notebook trains the **Proposed nnU-Net-based approach**. It incorporates:

1. **Radiological Feature-Based Oversampling**: Addressing class imbalance by 3x oversampling of underrepresented FCD Type II radiological features (transmantle sign, gray-white matter blurring).
2. **Extensive Data Augmentation**: Utilizing wider ranges for rotation ($\pm 60^{\circ}$), scaling [0.70, 1.50], and brightness/contrast to improve generalization.

---
**Paper Reference:** *Evaluation of nnU-Net for FCD II Lesion Segmentation in FLAIR MRI*

# 1. Environment Setup
Installing necessary dependencies and importing libraries.

In [None]:
import sys
import os
# Make the repository's ``src`` directory (two levels above the working
# directory) importable, so that ``config`` can be resolved below.
_config_src_dir = os.path.abspath(os.path.join(os.getcwd(), os.pardir, os.pardir, "src"))
sys.path.append(_config_src_dir)

# The star import pulls every public name from config into the notebook
# namespace; setup_env and add_custom_modules_to_path are not defined in this
# notebook, so they are expected to arrive through it.
from config import *

setup_env()
add_custom_modules_to_path()


In [None]:
# Optional one-time install, left commented out; enable it only when nnunetv2
# is not already available in the kernel's environment.
# !pip install nnunetv2

import os
import json
import shutil
from pathlib import Path
import numpy as np
from typing import List, Dict

# Reaching this line only shows that the imports above succeeded; nothing is
# installed by this cell while the pip line stays commented out.
print("✓ Dependencies installed and imported")


# 2. Experiment Configuration
Defining experiment parameters, selecting the specific **Cross-Validation Fold**, and configuring rare subject oversampling.

In [None]:
# Training configuration - Fold Selection
FOLD = 0  # Select fold (0, 1, 2, 3, 4)


# Define your dataset paths - MODIFY THESE TO MATCH YOUR KAGGLE PATHS
# NOTE(review): nnUNet_preprocessed / nnUNet_raw are not defined in this
# notebook; presumably they are exported by `from config import *` - confirm.
PREPROCESSED_PATH = os.path.join(nnUNet_preprocessed, 'Dataset002_BonnFCD_FLAIR')
RAW_PATH = os.path.join(nnUNet_raw, 'Dataset002_BonnFCD_FLAIR')

# Subjects to oversample (rare abnormalities - non-cortical thickening) and the
# oversampling factor are not redefined in this cell. An unterminated
# `RARE_SUBJECTS = [` literal makes the whole cell a SyntaxError, so the cell
# checks that both names are already available and otherwise stops with an
# actionable message instead.
# NOTE(review): presumably both names are exported by config - confirm.
_missing_settings = [
    name for name in ("RARE_SUBJECTS", "OVERSAMPLE_FACTOR") if name not in globals()
]
if _missing_settings:
    raise NameError(
        f"{', '.join(_missing_settings)} not defined; "
        "define them in config.py or in this cell before continuing"
    )

print("✓ Configuration set")
print(f"  - Number of rare subjects: {len(RARE_SUBJECTS)}")
print(f"  - Oversample factor: {OVERSAMPLE_FACTOR}x")



# 3. Custom nnU-Net Components
Installing the custom `nnUNetTrainerOversampling` trainer (which imports the `CustomOversamplingDataLoader`) into the nnU-Net package so that class imbalance is handled during training.

### Custom DataLoader
The `CustomOversamplingDataLoader` has been moved to `nnunet_extensions/custom_dataloader.py`.
It is imported automatically by the trainer.

In [None]:
import nnunetv2
from pathlib import Path
import shutil
import sys
import os

# Define paths (normalised, so they compare equal to entries added elsewhere)
current_dir = os.getcwd()
src_dir = os.path.abspath(os.path.join(current_dir, '..', '..', 'src'))
extensions_dir = os.path.join(src_dir, 'nnunet_extensions')
custom_trainer_path = os.path.join(extensions_dir, 'nnUNetTrainerOversampling.py')

nnunet_trainers_dir = Path(nnunetv2.__file__).parent / "training" / "nnUNetTrainer"
target_path = nnunet_trainers_dir / "nnUNetTrainerOversampling.py"

# Fail early with the resolved path when the notebook is started from an
# unexpected working directory.
if not os.path.isfile(custom_trainer_path):
    raise FileNotFoundError(f"Custom trainer not found at: {custom_trainer_path}")

print(f"Copying custom trainer to: {target_path}")
shutil.copy(custom_trainer_path, target_path)

# Directories the copied trainer needs on the import path:
#   - current_dir     (kept from the original cell)
#   - src_dir         (where config.py is imported from at the top of the notebook)
#   - extensions_dir  (custom_dataloader.py)
module_dirs = [current_dir, src_dir, extensions_dir]
for module_dir in module_dirs:
    if module_dir not in sys.path:
        sys.path.append(module_dir)

# Expose to environment for subprocesses (like nnUNetv2_train).
# os.pathsep keeps this portable, empty entries are dropped, and duplicates are
# removed so that re-running the cell does not keep growing PYTHONPATH.
existing_entries = [
    entry for entry in os.environ.get('PYTHONPATH', '').split(os.pathsep) if entry
]
os.environ['PYTHONPATH'] = os.pathsep.join(dict.fromkeys(module_dirs + existing_entries))
print("\u2714\ufe0f Environment configured for custom modules")


# 4. Data Preparation
Verifying the input dataset integrity and copying it to a writable directory (`../../data/nnUNet_preprocessed`) for processing.

In [None]:
def _subject_id_from_path(file_path):
    """Return the ``sub-...`` identifier encoded in a file name, or ``None``.

    The identifier is everything before the first underscore of the file name
    with its extension removed, e.g. ``sub-00001_0000.npz`` -> ``sub-00001``.
    Names that do not start with ``sub-`` yield ``None``.
    """
    name = Path(str(file_path)).name
    # Path.stem removes only the last suffix, which would turn
    # "sub-00001.nii.gz" into "sub-00001.nii"; strip the compound suffix first.
    if name.endswith('.nii.gz'):
        name = name[:-len('.nii.gz')]
    else:
        name = Path(name).stem
    if not name.startswith('sub-'):
        return None
    return name.split('_')[0]


def verify_dataset():
    """Verify the preprocessed dataset structure and subject identifiers.

    Reads the notebook-level ``PREPROCESSED_PATH`` and ``RARE_SUBJECTS`` and
    prints a report covering:

    * the ``nnUNetPlans*.json`` files found in the dataset folder,
    * the number of ``.npz`` / ``.npy`` / ``.pkl`` files per configuration
      folder, plus the subjects seen in the first 50 of those files,
    * when ``dataset.json`` has a ``training`` list, how many of the
      ``RARE_SUBJECTS`` appear in it (missing ones are listed, at most 10).

    Returns:
        bool: ``False`` when ``PREPROCESSED_PATH`` does not exist, ``True``
        otherwise. Missing rare subjects only produce a warning.
    """
    banner = "=" * 70
    print("\n" + banner)
    print("DATASET VERIFICATION")
    print(banner)

    # Nothing else can be checked without the dataset folder.
    if not os.path.exists(PREPROCESSED_PATH):
        print(f"\n✗ Preprocessed data NOT found at: {PREPROCESSED_PATH}")
        return False

    preprocessed_root = Path(PREPROCESSED_PATH)
    print(f"\n✓ Preprocessed data found at: {PREPROCESSED_PATH}")

    # List nnUNet plans
    plans_files = list(preprocessed_root.glob("nnUNetPlans*.json"))
    if plans_files:
        print(f"  - Found {len(plans_files)} plan file(s)")
        for plan in plans_files:
            print(f"    • {plan.name}")

    # Check for preprocessed data folders
    print("\n✓ Checking preprocessed data structure:")
    for config_name in ['nnUNetPlans_2d', 'nnUNetPlans_3d_fullres']:
        config_path = preprocessed_root / config_name
        if not config_path.exists():
            continue

        # Same order as before: all .npz, then all .npy, then all .pkl files.
        all_files = [
            found
            for pattern in ("**/*.npz", "**/*.npy", "**/*.pkl")
            for found in config_path.glob(pattern)
        ]
        print(f"  - {config_name}: {len(all_files)} preprocessed files")

        # Only the first 50 files are inspected, hence "(sampled)" below.
        sampled_subjects = {_subject_id_from_path(found) for found in all_files[:50]}
        sampled_subjects.discard(None)
        if sampled_subjects:
            print(f"    • Found {len(sampled_subjects)} unique subjects (sampled)")

            rare_in_sample = [s for s in RARE_SUBJECTS if s in sampled_subjects]
            if rare_in_sample:
                print(f"    • {len(rare_in_sample)} rare subjects found in sample")

    # Also check dataset.json for complete subject list
    dataset_json_path = preprocessed_root / "dataset.json"
    if dataset_json_path.exists():
        print("\n✓ Reading dataset.json for complete subject list...")
        with open(dataset_json_path, 'r') as f:
            dataset_info = json.load(f)

        if 'training' in dataset_info:
            subjects = set()
            for case in dataset_info['training']:
                # Entries are normally dicts with an 'image' key; tolerate a
                # bare path, and a list of channel paths (possibly empty).
                img_path = case.get('image', '') if isinstance(case, dict) else case
                if isinstance(img_path, list):
                    img_path = img_path[0] if img_path else ''
                subject_id = _subject_id_from_path(img_path)
                if subject_id is not None:
                    subjects.add(subject_id)

            print(f"  - Total training subjects: {len(subjects)}")

            # Check rare subjects
            rare_found = [s for s in RARE_SUBJECTS if s in subjects]
            print(f"  - Rare subjects found: {len(rare_found)}/{len(RARE_SUBJECTS)}")

            if len(rare_found) < len(RARE_SUBJECTS):
                missing = set(RARE_SUBJECTS) - subjects
                print(f"\n⚠ Warning: {len(missing)} rare subjects not in dataset:")
                for m in sorted(missing)[:10]:
                    print(f"    • {m}")
                if len(missing) > 10:
                    print(f"    ... and {len(missing)-10} more")

    print("\n" + banner)
    return True

# Run the check; as the last expression of the cell, the returned bool
# (False when PREPROCESSED_PATH is missing) is shown as the cell output.
verify_dataset()


In [None]:
# Mirror the (possibly read-only) preprocessed dataset into a writable folder,
# since later cells rewrite files inside it.
source = "../../data/preprocessed-bonnfcd-flair/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"
dest = "../../data/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"

for status_line in (
    "Copying preprocessed data to writable location...",
    f"From: {source}",
    f"To: {dest}",
):
    print(status_line)

# Make sure the parent of the destination exists before copying into it.
os.makedirs(os.path.dirname(dest), exist_ok=True)

# dirs_exist_ok=True lets the cell be re-run over an existing copy.
shutil.copytree(source, dest, dirs_exist_ok=True)

print("✓ Copied successfully")

# NOTE(review): no environment variable is updated in this cell; presumably
# nnUNet_preprocessed is already configured elsewhere to point at the copy -
# confirm.



In [None]:
import json
import math
import os

plans_path = "../../data/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR/nnUNetPlans.json"

# Read the current plans.
with open(plans_path, "r") as plans_in:
    plans = json.load(plans_in)

# Create the nested augmentation sections on first use; existing sections are
# reused so that unrelated keys inside them are preserved.
# NOTE(review): presumably nnUNetTrainerOversampling reads these keys from the
# plans file; the trainer is not visible here - confirm they are consumed.
augmentation = plans.setdefault("data_augmentation", {})
spatial = augmentation.setdefault("spatial", {})
intensity = augmentation.setdefault("intensity", {})

# Spatial augmentations, widened to twice the previous deviation.
max_rotation_rad = 60 * (math.pi / 180)  # ±60 degrees (was 30), in radians
spatial.update({
    # The same limit is applied around every axis.
    "rotation": {axis: max_rotation_rad for axis in ("x", "y", "z")},
    # Was [0.85, 1.25]: deviation from 1.0 doubled (15% -> 30% down, 25% -> 50% up).
    "scale_range": [0.70, 1.50],
    # Elastic deformation stays off; supply a dictionary here to enable it.
    "elastic_deform": None,
})

# Intensity augmentations, widened the same way.
intensity.update({
    "brightness": [0.5, 1.5],      # was [0.75, 1.25]
    "contrast": [0.5, 1.5],        # was [0.75, 1.25]
    # Was 0.1. NOTE(review): earlier notes called this a variance while the key
    # says "std" - confirm which one the consumer expects.
    "gaussian_noise_std": 0.2,
    "gamma_range": [0.4, 2.0],     # was [0.7, 1.5]: deviation from 1.0 doubled
})

# Write the patched plans back to the same file.
with open(plans_path, "w") as plans_out:
    json.dump(plans, plans_out, indent=4)

for summary_line in (
    "🔥 Augmentation parameters DOUBLED successfully!",
    "- Rotation: ±60°",
    "- Scale: 0.70 - 1.50",
    "- Noise/Contrast/Gamma: Deviation doubled",
):
    print(summary_line)


In [None]:
# Load custom splits from JSON
import json
# Read the project's fold definitions.
with open(SPLITS_FILE_PATH, 'r') as splits_in:
    custom_splits = json.load(splits_in)

print(f"\u2713 Loaded custom splits from: {SPLITS_FILE_PATH}")
print(f"  - Total folds: {len(custom_splits)}")

# Enforce these splits by saving them to the preprocessed dataset directory
# (this is what nnU-Net will actually use). nnUNet_preprocessed may already
# point at the dataset folder; otherwise the folder name is appended.
_dataset_dir_name = 'Dataset002_BonnFCD_FLAIR'
preprocessed_dir = (
    nnUNet_preprocessed
    if os.path.basename(nnUNet_preprocessed) == _dataset_dir_name
    else os.path.join(nnUNet_preprocessed, _dataset_dir_name)
)

splits_file = os.path.join(preprocessed_dir, "splits_final.json")
os.makedirs(preprocessed_dir, exist_ok=True)
with open(splits_file, "w") as splits_out:
    json.dump(custom_splits, splits_out, indent=4)

print(f"\u2705 Custom splits enforced successfully!")
print(f"  - Copied to: {splits_file}")


In [None]:
# Informational banner only: no preprocessing is run by this cell.
_rule = "=" * 70
print("\n".join([
    "\n" + _rule,
    "✓ Dataset is already preprocessed - skipping preprocessing step",
    _rule,
    "\nOversampling will be applied during training via custom DataLoader",
    "This modifies sampling probabilities AFTER preprocessing as intended",
    _rule,
]))

In [None]:
# CELL: Disable torch.compile for P100 GPU compatibility
import os

# Actually switch compilation off instead of only announcing it. nnU-Net's
# trainer consults the `nnUNet_compile` environment variable, and the
# `!nnUNetv2_train` subprocess started later inherits this process's
# environment, so the setting must be in place before training starts.
os.environ["nnUNet_compile"] = "false"

print("✓ Disabled torch.compile for GPU compatibility")
print("  (P100 GPU doesn't support Triton compiler)")

# 5. Training Execution
Final configuration checks and executing the nnU-Net training command.

In [None]:
print("\n" + "="*70)
print("STARTING TRAINING WITH CUSTOM OVERSAMPLING")
print("="*70)
print(f"\nConfiguration:")
print(f"  - Dataset: 002 (BonnFCD_FLAIR)")
print(f"  - Configuration: 3d_fullres")
print(f"  - Fold: {FOLD}")
print(f"  - Trainer: nnUNetTrainerOversampling")
print(f"  - Rare subjects: {len(RARE_SUBJECTS)}")
print(f"  - Oversample factor: {OVERSAMPLE_FACTOR}x")
print("="*70 + "\n")

!nnUNetv2_train 002 3d_fullres {FOLD} -tr nnUNetTrainerOversampling -p nnUNetPlans --npz