# 1. Environment Setup
Installing necessary dependencies and importing libraries.

In [None]:
# Install nnU-Net v2 through an IPython shell escape. No version is pinned, so
# the release pip resolves (or the one already installed) is used.
# NOTE(review): the custom trainer generated later targets the nnU-Net v2.6+
# trainer signature (no `unpack_dataset`); consider pinning the version so
# reruns of this notebook stay reproducible.
!pip install nnunetv2

import os
import json
import shutil
from pathlib import Path
import numpy as np
from typing import List, Dict

print("✓ Dependencies installed and imported")


# 2. Experiment Configuration
Defining experiment parameters, selecting the specific **Cross-Validation Fold**, and configuring rare subject oversampling.

In [None]:
# Training configuration - Fold Selection
FOLD = 4  # Select fold (0, 1, 2, 3, 4)

# Wall-clock training budget handed to the custom trainer (checked at epoch
# end). Kaggle sessions are capped at ~12 h and the trainer measures time from
# its own construction, so leave headroom for install/copy/setup.
TRAINING_TIME_MINUTES = (11 * 60) + 45  # 11 hours 45 minutes


# Define your dataset paths - MODIFY THESE TO MATCH YOUR KAGGLE PATHS
PREPROCESSED_PATH = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"
RAW_PATH = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_raw_data_base/nnUNet_raw/Dataset002_BonnFCD_FLAIR"

# nnUNet environment variables. nnUNet_preprocessed initially points at the
# read-only input mount; the data-preparation cell re-points it to a writable
# copy under /kaggle/working.
os.environ['nnUNet_raw'] = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_raw_data_base/nnUNet_raw"
os.environ['nnUNet_preprocessed'] = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_preprocessed"
os.environ['nnUNet_results'] = "/kaggle/working/nnUNet_results"

# Subjects to oversample (rare abnormalities - non-cortical thickening)
RARE_SUBJECTS = [
    'sub-00001', 'sub-00003', 'sub-00014', 'sub-00015', 'sub-00016', 'sub-00018', 
    'sub-00024', 'sub-00027', 'sub-00033', 'sub-00038', 'sub-00040', 'sub-00044', 
    'sub-00050', 'sub-00053', 'sub-00055', 'sub-00058', 'sub-00060', 'sub-00063', 
    'sub-00065', 'sub-00073', 'sub-00077', 'sub-00078', 'sub-00080', 'sub-00081', 
    'sub-00083', 'sub-00087', 'sub-00089', 'sub-00097', 'sub-00098', 'sub-00101', 
    'sub-00105', 'sub-00109', 'sub-00112', 'sub-00115', 'sub-00116', 'sub-00122', 
    'sub-00123', 'sub-00126', 'sub-00130', 'sub-00132', 'sub-00133', 'sub-00138', 
    'sub-00146'
]

# Oversampling factor for rare subjects
OVERSAMPLE_FACTOR = 3.0  # Adjust this value as needed (2.0-5.0 recommended)

# Fail fast on bad settings instead of hours into training. In particular, a
# fold index that is not present in splits_final.json makes nnU-Net v2 fall
# back to a random (seeded) 80:20 split instead of the custom folds
# (nnUNetTrainer.do_split -- verify for your installed version).
_VALID_FOLDS = (0, 1, 2, 3, 4, "all")
if FOLD not in _VALID_FOLDS:
    raise ValueError(f"FOLD must be one of {_VALID_FOLDS}, got {FOLD!r}")
# A non-positive weight would make the normalised sampling probabilities
# invalid (zero/negative) and only fail inside the data loader.
if OVERSAMPLE_FACTOR <= 0:
    raise ValueError(f"OVERSAMPLE_FACTOR must be positive, got {OVERSAMPLE_FACTOR!r}")

print("✓ Configuration set")
print(f"  - Number of rare subjects: {len(RARE_SUBJECTS)}")
print(f"  - Oversample factor: {OVERSAMPLE_FACTOR}x")



# 3. Custom nnU-Net Components
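Defining and generating the `CustomOversamplingDataLoader` and `nnUNetTrainerOversampling`, which make training draw subjects with rare (non-cortical-thickening) abnormalities more often, i.e. they counter subject-level imbalance rather than voxel-level class imbalance.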

In [None]:
def create_custom_dataloader_file(output_dir="/kaggle/working/custom_nnunet"):
    """
    Write ``custom_dataloader.py`` containing ``CustomOversamplingDataLoader``.

    The generated class subclasses nnU-Net's ``nnUNetDataLoader``. After the
    parent constructor has run it overwrites ``self.sampling_probabilities`` so
    that cases whose subject ID is in ``rare_subjects`` are drawn
    ``oversample_factor`` times more often than the other cases.

    All constructor arguments other than ``rare_subjects`` and
    ``oversample_factor`` are forwarded to ``nnUNetDataLoader`` unchanged (by
    position and by keyword). The class therefore follows whatever signature the
    installed nnU-Net has (``pad_sides``, ``probabilistic_oversampling``,
    ``transforms``, ...). Earlier versions passed ``pad_kwargs_data``/``pad_mode``
    positionally into the parent's ``pad_sides``/``probabilistic_oversampling``
    slots; those two legacy keywords are now accepted and ignored.

    Args:
        output_dir: Directory the module is written into; created if missing.
            Defaults to the Kaggle working location used by this notebook.
    """
    custom_loader_code = '''
import numpy as np
from nnunetv2.training.dataloading.data_loader import nnUNetDataLoader


def compute_sampling_probabilities(case_ids, rare_subjects, oversample_factor):
    """Return sampling probabilities aligned with case_ids, or None if empty.

    A case's subject ID is the part of its identifier before the first "_"
    (e.g. "sub-00001_0000" -> "sub-00001"). Rare subjects get weight
    oversample_factor, all other cases weight 1; the weights are normalised to
    sum to 1.
    """
    case_ids = list(case_ids)
    if not case_ids:
        return None
    rare = set(rare_subjects)
    weights = np.array(
        [oversample_factor if str(c).split("_")[0] in rare else 1.0 for c in case_ids],
        dtype=float,
    )
    return weights / weights.sum()


class CustomOversamplingDataLoader(nnUNetDataLoader):
    """nnUNetDataLoader that draws rare subjects more often (2D and 3D)."""

    def __init__(self, data, batch_size, patch_size, final_patch_size=None,
                 label_manager=None, oversample_foreground_percent=0.33,
                 *args, rare_subjects=None, oversample_factor=3.0, **kwargs):
        self.rare_subjects = list(rare_subjects) if rare_subjects else []
        self.oversample_factor = oversample_factor

        # Legacy keywords of this class that are not nnUNetDataLoader
        # parameters; drop them instead of mis-binding them.
        kwargs.pop("pad_kwargs_data", None)
        kwargs.pop("pad_mode", None)

        # Everything else goes to the parent unchanged, so this class works with
        # the installed nnU-Net signature (including transforms=...).
        super().__init__(data, batch_size, patch_size, final_patch_size,
                         label_manager, oversample_foreground_percent,
                         *args, **kwargs)

        # Modify sampling probabilities after initialization
        if self.rare_subjects:
            self._modify_sampling_probabilities()

    def _case_ids(self):
        """Return case IDs in the order the loader samples them, or None."""
        # batchgenerators' DataLoader draws batches from self.indices using
        # self.sampling_probabilities, so the probabilities must follow exactly
        # that order; prefer it over the dataset's own listing.
        indices = getattr(self, "indices", None)
        if indices is not None:
            return list(indices)

        data = getattr(self, "_data", None)
        if data is None:
            return None
        for attr in ("identifiers", "indices"):
            value = getattr(data, attr, None)
            if value is not None:
                return list(value)
        keys = getattr(data, "keys", None)
        if callable(keys):
            try:
                return list(keys())
            except Exception as exc:
                print(f"  ⚠ Warning: dataset.keys() failed: {exc!r}")
        return None

    def _modify_sampling_probabilities(self):
        """Modify sampling probabilities to favor rare subjects"""
        case_ids = self._case_ids()
        if case_ids is None:
            print("  ⚠ Warning: Could not extract case IDs from dataset, oversampling disabled")
            return

        probabilities = compute_sampling_probabilities(
            case_ids, self.rare_subjects, self.oversample_factor)
        if probabilities is None:
            print("  ⚠ Warning: dataset contains no cases, oversampling disabled")
            return
        self.sampling_probabilities = probabilities

        rare = set(self.rare_subjects)
        rare_count = sum(str(c).split("_")[0] in rare for c in case_ids)
        print(f"  ✓ Modified sampling: {rare_count}/{len(case_ids)} cases are rare subjects ({self.oversample_factor}x weight)")
'''

    # Save the custom dataloader
    os.makedirs(output_dir, exist_ok=True)
    target_file = os.path.join(output_dir, "custom_dataloader.py")
    # Explicit UTF-8: the generated code contains non-ASCII status symbols.
    with open(target_file, "w", encoding="utf-8") as f:
        f.write(custom_loader_code)

    print("✓ Custom DataLoader file created")

# Write custom_dataloader.py into /kaggle/working/custom_nnunet.
create_custom_dataloader_file()


In [None]:
import time
from batchgenerators.utilities.file_and_folder_operations import join

def create_custom_trainer(rare_subjects=None, oversample_factor=None,
                          training_time_minutes=None, trainer_dir=None):
    """
    Generate ``nnUNetTrainerOversampling.py``, a custom nnU-Net trainer that:

    1. Oversamples rare subjects. It injects per-case ``sampling_probabilities``
       into the stock ``nnUNetDataLoader`` that ``nnUNetTrainer.get_dataloaders``
       builds for the training split. On nnU-Net v2.6 the base trainer builds
       its loaders inside ``get_dataloaders``, so the previous
       ``get_plain_dataloaders`` override was never called and oversampling
       never took effect.
    2. Saves ``checkpoint_latest.pth`` every 20 epochs.
    3. Stops before a fixed wall-clock budget is exceeded (Kaggle-safe). It
       stops once one more epoch of the last epoch's length would overrun the
       budget.
    4. Uses the v2.6+ trainer signature (no ``unpack_dataset``) to prevent a
       KeyError in nnU-Net v2.6+.

    Args:
        rare_subjects: Subject IDs to oversample; defaults to ``RARE_SUBJECTS``.
        oversample_factor: Relative weight of rare subjects; defaults to
            ``OVERSAMPLE_FACTOR``.
        training_time_minutes: Wall-clock budget; defaults to
            ``TRAINING_TIME_MINUTES``.
        trainer_dir: Where to write the file; defaults to the ``nnUNetTrainer``
            folder of the installed ``nnunetv2`` package.
    """
    from pathlib import Path

    if rare_subjects is None:
        rare_subjects = RARE_SUBJECTS
    if oversample_factor is None:
        oversample_factor = OVERSAMPLE_FACTOR
    if training_time_minutes is None:
        training_time_minutes = TRAINING_TIME_MINUTES

    # Raw, non-f-string template: generated braces and escapes are written
    # verbatim. Configuration values are substituted below via the
    # __PLACEHOLDER__ tokens.
    trainer_template = r'''import sys
import time

import numpy as np
import torch

from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer


def compute_sampling_probabilities(case_ids, rare_subjects, oversample_factor):
    """Return sampling probabilities aligned with case_ids, or None if empty.

    A case's subject ID is the part of its identifier before the first "_"
    (e.g. "sub-00001_0000" -> "sub-00001"). Rare subjects get weight
    oversample_factor, all other cases weight 1; the weights sum to 1.
    """
    case_ids = list(case_ids)
    if not case_ids:
        return None
    rare = set(rare_subjects)
    weights = np.array(
        [oversample_factor if str(c).split("_")[0] in rare else 1.0 for c in case_ids],
        dtype=float,
    )
    return weights / weights.sum()


class nnUNetTrainerOversampling(nnUNetTrainer):
    """
    Custom nnU-Net trainer with:
    - Oversampling for rare subjects (training loader only)
    - Frequent checkpoint saving
    - Time-limit safety for Kaggle
    """

    # -------------------------------------------------------
    # 1. Configuration
    # -------------------------------------------------------
    rare_subjects = __RARE_SUBJECTS__
    oversample_factor = __OVERSAMPLE_FACTOR__

    def __init__(
        self,
        plans: dict,
        configuration: str,
        fold: int,
        dataset_json: dict,
        device: torch.device = torch.device("cuda"),
    ):
        # nnU-Net v2.6+ no longer takes unpack_dataset
        super().__init__(plans, configuration, fold, dataset_json, device=device)

        # SAVE FREQUENCY: save checkpoint_latest.pth every 20 epochs
        self.save_every = 20

        # TIME LIMIT: wall-clock budget measured from trainer construction
        self.max_time_seconds = __TRAINING_TIME_MINUTES__ * 60
        self.start_time = time.time()
        self._last_epoch_end_time = self.start_time

        # Set by get_tr_and_val_datasets(); identifies the training split.
        self._oversampling_dataset_tr = None

    # -------------------------------------------------------
    # 2. Oversampling Logic
    # -------------------------------------------------------
    def get_tr_and_val_datasets(self):
        dataset_tr, dataset_val = super().get_tr_and_val_datasets()
        # Remember the training split so get_dataloaders() can tell the
        # training loader apart from the validation loader.
        self._oversampling_dataset_tr = dataset_tr
        return dataset_tr, dataset_val

    def _training_sampling_probabilities(self, dataset):
        """Probabilities aligned with dataset.identifiers, or None."""
        # Assumes the loader samples from dataset.identifiers in this order
        # (nnU-Net v2.6 nnUNetDataLoader sets self.indices = data.identifiers).
        case_ids = getattr(dataset, "identifiers", None)
        if case_ids is None:
            self.print_to_log_file(
                "WARNING: training dataset exposes no 'identifiers'; oversampling disabled."
            )
            return None
        case_ids = list(case_ids)
        probs = compute_sampling_probabilities(case_ids, self.rare_subjects, self.oversample_factor)
        if probs is None:
            return None
        rare = set(self.rare_subjects)
        rare_count = sum(str(c).split("_")[0] in rare for c in case_ids)
        self.print_to_log_file(
            f"Oversampling: {rare_count}/{len(case_ids)} training cases are rare subjects "
            f"({self.oversample_factor}x weight)"
        )
        return probs

    def get_dataloaders(self):
        # The base get_dataloaders() constructs nnUNetDataLoader (looked up as a
        # global of its own module) with sampling_probabilities=None. Swap that
        # global for a thin wrapper while the base method runs, so the stock
        # training pipeline (transforms, deep supervision, augmenters) stays
        # intact and only the training loader gets the probabilities.
        base_module = sys.modules[nnUNetTrainer.__module__]
        original_loader_cls = getattr(base_module, "nnUNetDataLoader", None)
        if original_loader_cls is None:
            self.print_to_log_file(
                "WARNING: nnUNetTrainer module has no nnUNetDataLoader; oversampling disabled."
            )
            return super().get_dataloaders()

        trainer = self

        def oversampling_loader(data, *args, **kwargs):
            if data is trainer._oversampling_dataset_tr:
                probs = trainer._training_sampling_probabilities(data)
                if probs is not None:
                    # sampling_probabilities is the 7th constructor parameter
                    # (6th after data) if it was passed positionally.
                    if len(args) > 5:
                        args = args[:5] + (probs,) + args[6:]
                    else:
                        kwargs["sampling_probabilities"] = probs
            return original_loader_cls(data, *args, **kwargs)

        base_module.nnUNetDataLoader = oversampling_loader
        try:
            return super().get_dataloaders()
        finally:
            base_module.nnUNetDataLoader = original_loader_cls

    # -------------------------------------------------------
    # 3. Time-Limit Safety
    # -------------------------------------------------------
    def on_epoch_end(self):
        super().on_epoch_end()

        now = time.time()
        last_epoch_seconds = now - self._last_epoch_end_time
        self._last_epoch_end_time = now
        elapsed = now - self.start_time

        # Stop when one more epoch as long as the last one would overrun the
        # budget, rather than only after the budget is already exceeded.
        if elapsed + last_epoch_seconds > self.max_time_seconds:
            self.print_to_log_file(
                f"\nTIME LIMIT REACHED ({elapsed / 3600:.2f} hours)."
            )
            self.print_to_log_file(
                "Stopping training gracefully to save checkpoints."
            )

            # on_train_end() writes checkpoint_final.pth (compensating for the
            # epoch counter that on_epoch_end already advanced) and removes
            # checkpoint_latest.pth in nnU-Net v2.6, so no extra save here.
            self.on_train_end()
            sys.exit(0)
'''

    trainer_code = (
        trainer_template
        .replace("__RARE_SUBJECTS__", repr(list(rare_subjects)))
        .replace("__OVERSAMPLE_FACTOR__", repr(oversample_factor))
        .replace("__TRAINING_TIME_MINUTES__", repr(training_time_minutes))
    )

    # -------------------------------------------------------
    # Save trainer file
    # -------------------------------------------------------
    if trainer_dir is None:
        import nnunetv2
        trainer_dir = Path(nnunetv2.__file__).parent / "training" / "nnUNetTrainer"
    trainer_file = Path(trainer_dir) / "nnUNetTrainerOversampling.py"

    with open(trainer_file, "w", encoding="utf-8") as f:
        f.write(trainer_code)

    print("✔️ Custom Trainer updated (unpack_dataset removed completely)")


# Write nnUNetTrainerOversampling.py into the installed nnunetv2 package's
# training/nnUNetTrainer folder, so the `-tr nnUNetTrainerOversampling` option
# in the training cell can presumably resolve it by class name (verify).
create_custom_trainer()

# 4. Data Preparation
Verifying the input dataset and copying it to a writable directory (`/kaggle/working`). Then, in that copy, editing the plans file and writing the fixed cross-validation splits (`splits_final.json`).

In [None]:
def verify_dataset():
    """Verify the preprocessed dataset structure and subject identifiers"""
    
    print("\n" + "="*70)
    print("DATASET VERIFICATION")
    print("="*70)
    
    # Check preprocessed data
    if os.path.exists(PREPROCESSED_PATH):
        print(f"\n✓ Preprocessed data found at: {PREPROCESSED_PATH}")
        
        # List nnUNet plans
        plans_files = list(Path(PREPROCESSED_PATH).glob("nnUNetPlans*.json"))
        if plans_files:
            print(f"  - Found {len(plans_files)} plan file(s)")
            for plan in plans_files:
                print(f"    • {plan.name}")
        
        # Check for preprocessed data folders
        print("\n✓ Checking preprocessed data structure:")
        for config in ['nnUNetPlans_2d', 'nnUNetPlans_3d_fullres']:
            config_path = Path(PREPROCESSED_PATH) / config
            if config_path.exists():
                # Count .npz or .npy files in the configuration folder
                npz_files = list(config_path.glob("**/*.npz"))
                npy_files = list(config_path.glob("**/*.npy"))
                pkl_files = list(config_path.glob("**/*.pkl"))
                
                total_files = len(npz_files) + len(npy_files) + len(pkl_files)
                print(f"  - {config}: {total_files} preprocessed files")
                
                # Try to extract subject IDs from filenames
                if npz_files or npy_files or pkl_files:
                    all_files = npz_files + npy_files + pkl_files
                    subjects = set()
                    for f in all_files[:50]:  # Sample first 50 files
                        # Extract subject ID (format might be: sub-00001.npz or sub-00001_0000.npz)
                        fname = f.stem
                        if fname.startswith('sub-'):
                            subject_id = fname.split('_')[0]
                            subjects.add(subject_id)
                    
                    if subjects:
                        print(f"    • Found {len(subjects)} unique subjects (sampled)")
                        
                        # Check rare subjects
                        rare_found = [s for s in RARE_SUBJECTS if s in subjects]
                        if rare_found:
                            print(f"    • {len(rare_found)} rare subjects found in sample")
    else:
        print(f"\n✗ Preprocessed data NOT found at: {PREPROCESSED_PATH}")
        return False
    
    # Also check dataset.json for complete subject list
    dataset_json_path = Path(PREPROCESSED_PATH) / "dataset.json"
    if dataset_json_path.exists():
        print("\n✓ Reading dataset.json for complete subject list...")
        with open(dataset_json_path, 'r') as f:
            dataset_info = json.load(f)
            if 'training' in dataset_info:
                training_cases = dataset_info['training']
                subjects = set()
                for case in training_cases:
                    # Extract subject ID from image path
                    img_path = case.get('image', '')
                    if isinstance(img_path, list):
                        img_path = img_path[0]
                    fname = Path(img_path).stem
                    if fname.startswith('sub-'):
                        subject_id = fname.split('_')[0]
                        subjects.add(subject_id)
                
                print(f"  - Total training subjects: {len(subjects)}")
                
                # Check rare subjects
                rare_found = [s for s in RARE_SUBJECTS if s in subjects]
                print(f"  - Rare subjects found: {len(rare_found)}/{len(RARE_SUBJECTS)}")
                
                if len(rare_found) < len(RARE_SUBJECTS):
                    missing = set(RARE_SUBJECTS) - subjects
                    print(f"\n⚠ Warning: {len(missing)} rare subjects not in dataset:")
                    for m in sorted(missing)[:10]:
                        print(f"    • {m}")
                    if len(missing) > 10:
                        print(f"    ... and {len(missing)-10} more")
    
    print("\n" + "="*70)
    return True

# Sanity-check the read-only input dataset before copying it. This prints a
# report and returns False if the preprocessed folder is missing.
verify_dataset()


In [None]:
# Copy the preprocessed dataset out of the read-only /kaggle/input mount into
# /kaggle/working, so later cells can modify it (plans file, splits_final.json).
source = "/kaggle/input/preprocessed-bonnfcd-flair/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"
dest = "/kaggle/working/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"

print("\n".join([
    "Copying preprocessed data to writable location...",
    f"From: {source}",
    f"To: {dest}",
]))

# Ensure the destination's parent exists, then copy. dirs_exist_ok merges into
# a destination left over from an earlier run of this cell.
os.makedirs(os.path.dirname(dest), exist_ok=True)
shutil.copytree(src=source, dst=dest, dirs_exist_ok=True)

print("✓ Copied successfully")

# Point nnU-Net at the writable copy for every later cell and subprocess.
os.environ.update(nnUNet_preprocessed="/kaggle/working/nnUNet_preprocessed")

print(f"✓ Updated nnUNet_preprocessed path to: {os.environ['nnUNet_preprocessed']}")


In [None]:
import json
import math
import os

plans_path = "/kaggle/working/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR/nnUNetPlans.json"

# NOTE(review): nnU-Net v2's stock nnUNetTrainer hard-codes its augmentation
# parameters (rotation/scale in its DA configuration, intensity transforms in
# get_training_transforms). It does not appear to read a "data_augmentation"
# section from the plans file. The values below are therefore only recorded in
# the plans; they affect training only through a trainer that reads them.
# Verify against the installed version.

# Load plans
with open(plans_path, "r") as f:
    plans = json.load(f)

# ----------------------------------------------------
# Ensure the augmentation keys exist
# ----------------------------------------------------
augmentation = plans.setdefault("data_augmentation", {})
spatial = augmentation.setdefault("spatial", {})
intensity = augmentation.setdefault("intensity", {})

# ----------------------------------------------------
# 🔥 Doubled spatial augmentations
# ----------------------------------------------------
# Rotation: ±60 degrees per axis, stored in radians (original note: default 30)
spatial["rotation"] = {
    "x": 60 * (math.pi / 180),
    "y": 60 * (math.pi / 180),
    "z": 60 * (math.pi / 180)
}

# Scale: Range [0.70, 1.50]. Original note: default [0.85, 1.25], deviation from
# 1.0 doubled. NOTE(review): nnU-Net v2's trainer uses (0.7, 1.4) -- confirm.
spatial["scale_range"] = [0.70, 1.50]

# Elastic Deformation: still disabled. To enable it, a parameter dictionary
# would be needed here instead of None.
spatial["elastic_deform"] = None

# ----------------------------------------------------
# 🔥 Doubled intensity augmentations
# ----------------------------------------------------
# Brightness: Range [0.5, 1.5] (original note: default [0.75, 1.25])
intensity["brightness"] = [0.5, 1.5]

# Contrast: Range [0.5, 1.5] (original note: default [0.75, 1.25])
intensity["contrast"] = [0.5, 1.5]

# Gaussian noise 0.2 (original note: default 0.1). NOTE(review): the original
# comment called this a variance while the key says "std"; confirm which one a
# consuming trainer expects.
intensity["gaussian_noise_std"] = 0.2

# Gamma: Range [0.4, 2.0] (original note: default [0.7, 1.5], deviation doubled)
intensity["gamma_range"] = [0.4, 2.0]

# Save back
with open(plans_path, "w") as f:
    json.dump(plans, f, indent=4)

print("✓ Augmentation overrides written to plans['data_augmentation']:")
print("- Rotation: ±60°")
print("- Scale: 0.70 - 1.50")
print("- Noise/Contrast/Gamma: Deviation doubled")
print("⚠ Stock nnU-Net v2 trainers do not read plans['data_augmentation']; "
      "these values only take effect through a trainer that reads them.")


In [None]:
import json
import os

# ----------------------------------------------------
# 🧪 Custom 5-Fold Cross-Validation Splits
# ----------------------------------------------------
# Hardcoded splits to ensure consistency across Kaggle sessions.
#
# Only each fold's validation subjects are listed (as sub-XXXXX numbers). A
# fold's training set is every other subject in the union of all validation
# folds. Train and val therefore cannot overlap within a fold, and every
# subject is validated exactly once. This reproduces the previously
# hand-written folds exactly (train and val lists sorted ascending).
_VAL_FOLDS = [
    [18, 65, 68, 81, 87, 105, 109, 115, 116, 126, 131, 138],  # fold 0
    [10, 27, 44, 60, 63, 72, 73, 76, 101, 122, 123, 132],     # fold 1
    [1, 14, 15, 16, 53, 78, 83, 89, 91, 120, 130],            # fold 2
    [24, 33, 38, 40, 47, 55, 58, 59, 71, 77, 133],            # fold 3
    [3, 50, 80, 90, 97, 98, 112, 139, 140, 141, 146],         # fold 4
]


def _build_cv_splits(val_folds):
    """Build nnU-Net ``splits_final.json`` entries from validation folds.

    Args:
        val_folds: One list of subject numbers per fold (the validation set).

    Returns:
        A list of ``{"train": [...], "val": [...]}`` dicts with sorted
        ``sub-XXXXX`` names, one per fold.

    Raises:
        ValueError: if a subject appears in more than one validation fold.
    """
    first_fold_of = {}
    for fold_idx, fold in enumerate(val_folds):
        for number in fold:
            if number in first_fold_of:
                raise ValueError(
                    f"sub-{number:05d} is in the validation set of folds "
                    f"{first_fold_of[number]} and {fold_idx}"
                )
            first_fold_of[number] = fold_idx

    all_subjects = sorted(first_fold_of)
    splits = []
    for fold in val_folds:
        val_set = set(fold)
        splits.append({
            "train": [f"sub-{n:05d}" for n in all_subjects if n not in val_set],
            "val": [f"sub-{n:05d}" for n in sorted(val_set)],
        })
    return splits


custom_splits = _build_cv_splits(_VAL_FOLDS)

# Define the target path where nnU-Net expects the splits file
# It must be in the preprocessed directory of the specific task
preprocessed_dir = "/kaggle/working/nnUNet_preprocessed/Dataset002_BonnFCD_FLAIR"
splits_file = os.path.join(preprocessed_dir, "splits_final.json")

# Write the splits to the file
os.makedirs(preprocessed_dir, exist_ok=True)
with open(splits_file, "w") as f:
    json.dump(custom_splits, f, indent=4)

print("✅ Custom splits enforced successfully!")
print(f"  - Location: {splits_file}")
print(f"  - Total folds: {len(custom_splits)}")
print("  - nnU-Net will now use these fixed splits instead of random ones.")


In [None]:
# Informational banner: the copied dataset is already preprocessed, so this
# notebook runs no separate preprocessing step.
_rule = "=" * 70
print(
    "\n" + _rule,
    "✓ Dataset is already preprocessed - skipping preprocessing step",
    _rule,
    "\nOversampling will be applied during training via custom DataLoader",
    "This modifies sampling probabilities AFTER preprocessing as intended",
    _rule,
    sep="\n",
)

In [None]:
# CELL: Disable torch.compile for P100 GPU compatibility
import os

# Set in os.environ so that the `!nnUNetv2_train` subprocess launched by the
# training cell inherits the flag.
os.environ.update(nnUNet_compile="false")

print(
    "✓ Disabled torch.compile for GPU compatibility",
    "  (P100 GPU doesn't support Triton compiler)",
    sep="\n",
)

# 5. Training Execution
Printing a summary of the final configuration and running the nnU-Net training command.

In [None]:
# Print a summary of the run configuration, then launch nnU-Net training.
print("\n" + "="*70)
print("STARTING TRAINING WITH CUSTOM OVERSAMPLING")
print("="*70)
print(f"\nConfiguration:")
print(f"  - Dataset: 002 (BonnFCD_FLAIR)")
print(f"  - Configuration: 3d_fullres")
print(f"  - Fold: {FOLD}")
print(f"  - Trainer: nnUNetTrainerOversampling")
print(f"  - Rare subjects: {len(RARE_SUBJECTS)}")
print(f"  - Oversample factor: {OVERSAMPLE_FACTOR}x")
print("="*70 + "\n")

# IPython shell escape: IPython substitutes {FOLD} from the notebook namespace
# before running the command. The subprocess inherits os.environ, including the
# nnUNet_* paths and the nnUNet_compile flag set in earlier cells.
# NOTE(review): the custom trainer's time-limit path calls sys.exit(0) from
# on_epoch_end. That presumably skips nnU-Net's post-training validation and
# therefore the --npz export; confirm if those outputs are needed.
!nnUNetv2_train 002 3d_fullres {FOLD} -tr nnUNetTrainerOversampling -p nnUNetPlans --npz