This is a specialized Inference Notebook designed to work with your specific file structure and the Custom Trainer (`nnUNetTrainerOversampling`) you used during training.

Since nnU-Net is very strict about folder structures (`nnUNet_results`), this notebook **artificially reconstructs the expected folder hierarchy** so the inference command can find your checkpoint without errors.

### **Inference Notebook: BonnFCD Segmentation**

#### **Cell 1: Environment Setup**

Load the project configuration (`config` from `../src`), install nnU-Net and MedPy (for metrics), and import the required libraries.
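Some steps below depend on library versions. For example, `torch.load` switched its default to `weights_only=True` in PyTorch 2.6. It can therefore help to print the installed versions first. The sketch below uses only the standard library's `importlib.metadata`. The helper names are hypothetical, and the version parsing assumes `major.minor...` strings such as `2.6.0+cu124`.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


def torch_load_needs_weights_only_false(torch_version: str) -> bool:
    """True if torch.load defaults to weights_only=True (PyTorch 2.6 and later)."""
    # Drop a local build tag such as "+cu124", then compare (major, minor)
    major, minor = (int(part) for part in torch_version.split("+")[0].split(".")[:2])
    return (major, minor) >= (2, 6)


for dist in ("nnunetv2", "torch", "medpy"):
    print(f"{dist}: {installed_version(dist) or 'not installed'}")
```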

In [None]:
import sys
import os
# Add the src directory to sys.path so we can import config
sys.path.append(os.path.abspath(os.path.join(os.getcwd(), "..", "src")))

from config import *
setup_env()


In [None]:
# Install nnU-Net V2 and MedPy for evaluation
# !pip install nnunetv2 medpy pandas

import os
import sys
import shutil
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import nibabel as nib
import matplotlib.pyplot as plt

from medpy.metric.binary import dc, hd95, asd
from medpy.metric.binary import precision as prec
from medpy.metric.binary import recall as rec

print("✓ Environment ready")

#### **Cell 2: Define Paths & Environment Variables**

Here we map the paths based on the file tree you provided.
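A quick existence check on the configured paths catches typos before the heavier cells run. This is a minimal sketch, and `missing_paths` is a hypothetical helper, not part of nnU-Net:

```python
from pathlib import Path
from typing import Dict, List


def missing_paths(paths: Dict[str, str]) -> List[str]:
    """Return the labels of all entries whose path does not exist on disk."""
    missing = [label for label, path in paths.items() if not Path(path).exists()]
    for label in missing:
        print(f"❌ Missing {label}: {paths[label]}")
    return missing
```

After Cell 2 has run, it could be called as `missing_paths({"checkpoint": CHECKPOINT_SOURCE, "imagesTs": INPUT_IMAGES_FOLDER})`. An empty list means every path was found.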

In [None]:
# ---------------------------------------------------------
# NNUNET RESULTS INPUT (User Input)
# ---------------------------------------------------------
NNUNET_MODEL_OUTPUT_DATA = nnUNet_results

FOLD = 4  # Select fold (0, 1, 2, 3, 4)


# ---------------------------------------------------------
# CHECKPOINT FILE NAME (User Choice)
# ---------------------------------------------------------
CHECKPOINT_NAME = "checkpoint_best.pth"
# or: "checkpoint_final.pth"


# ---------------------------------------------------------
# FULL CHECKPOINT PATH (Derived)
# ---------------------------------------------------------
CHECKPOINT_SOURCE = os.path.join(
    NNUNET_MODEL_OUTPUT_DATA,
    "nnUNet_results",
    "Dataset002_BonnFCD_FLAIR",
    "nnUNetTrainerOversampling__nnUNetPlans__3d_fullres",
    f"fold_{FOLD}",
    CHECKPOINT_NAME,
)

# ---------------------------------------------------------
# 2. DATA PATHS (Based on your provided file tree)
# ---------------------------------------------------------
# Path where the Test images are located (imagesTs)
INPUT_IMAGES_FOLDER = os.path.join(nnUNet_raw, 'Dataset002_BonnFCD_FLAIR', 'imagesTs')

# Path to the preprocessed data (needed for plans.json)
PREPROCESSED_BASE = nnUNet_preprocessed

# Where to save predictions
OUTPUT_FOLDER = "../results/inference_output"
# Alternative: OUTPUT_FOLDER = "../data/inference_output"
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

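# ---------------------------------------------------------
# 3. NNUNET ENVIRONMENT VARIABLES
# ---------------------------------------------------------
# nnUNetv2_predict reads nnUNet_raw, nnUNet_preprocessed and nnUNet_results from
# the environment; setup_env() in Cell 1 is assumed to export them.
print(f"✓ nnUNet_results (env): {os.environ.get('nnUNet_results', 'NOT SET')}")
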
print(f"✓ Input Folder: {INPUT_IMAGES_FOLDER}")
print(f"✓ Output Folder: {OUTPUT_FOLDER}")

#### **Cell 2b: Inspect Checkpoint Training History (Optional)**

Use this cell to visualize the training history (loss and pseudo Dice) stored in the checkpoint file. It helps confirm whether training converged as expected.
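As far as we can tell from its logger, nnU-Net v2 does not select `checkpoint_best.pth` using the raw per-epoch pseudo Dice. Instead, it uses an exponential moving average (EMA) of that value, weighting the previous average by 0.9 and the new value by 0.1. The sketch below recomputes that EMA from the `mean_fg_dice` list to find the epoch the best checkpoint most likely comes from. If the checkpoint's `logging` already contains an `ema_fg_dice` list, you can use that list directly. The function names are hypothetical, and the smoothing constants are an assumption about nnU-Net's internals.

```python
from typing import List, Optional, Tuple


def ema_pseudo_dice(mean_fg_dice: List[float], decay: float = 0.9) -> List[float]:
    """Exponential moving average of the per-epoch pseudo Dice.

    Assumed smoothing: ema[0] = dice[0]; ema[t] = decay * ema[t-1] + (1 - decay) * dice[t].
    """
    ema: List[float] = []
    for value in mean_fg_dice:
        ema.append(value if not ema else decay * ema[-1] + (1 - decay) * value)
    return ema


def best_epoch(mean_fg_dice: List[float]) -> Optional[Tuple[int, float]]:
    """1-based epoch with the highest EMA pseudo Dice, plus that EMA value."""
    ema = ema_pseudo_dice(mean_fg_dice)
    if not ema:
        return None
    # max() returns the first maximum, matching a "save only on strict improvement" rule
    idx = max(range(len(ema)), key=ema.__getitem__)
    return idx + 1, ema[idx]
```

For example, call `best_epoch(checkpoint['logging']['mean_fg_dice'])` after the cell below has loaded the checkpoint. The EMA lags behind the raw curve, so the best EMA epoch can be later than the epoch with the single highest pseudo Dice.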

In [None]:
print(f"Loading checkpoint: {CHECKPOINT_SOURCE} ...")

try:
    # weights_only=False is needed on PyTorch >= 2.6, where torch.load defaults to
    # weights_only=True; nnU-Net checkpoints contain more than plain tensors
    # (e.g. logging lists and trainer init arguments). Only load trusted files this way.
    checkpoint = torch.load(CHECKPOINT_SOURCE, map_location='cpu', weights_only=False)

    # 1. Basic Info (Epoch)
    # nnU-Net v2 stores 'current_epoch' as the number of completed epochs
    # (internal 0-based epoch index + 1), so no extra +1 is needed here.
    completed_epochs = checkpoint.get('current_epoch', 'Unknown')
    print(f"\n✅ Training Status:")
    print(f"  - Epochs completed when this checkpoint was saved: {completed_epochs}")

    # 2. Extract Logging History
    if 'logging' in checkpoint:
        log = checkpoint['logging']
        
        # Extract lists
        train_losses = log.get('train_losses', [])
        val_losses = log.get('val_losses', [])
        mean_dice = log.get('mean_fg_dice', [])  # Pseudo dice on validation set
        
        if len(train_losses) > 0:
            latest_epoch = len(train_losses)  # human-readable epoch number

            print(f"\n📊 Latest Metrics (Epoch {latest_epoch}):")
            print(f"  - Training Loss:   {train_losses[-1]:.4f}")
            print(f"  - Validation Loss: {val_losses[-1]:.4f}")
            print(f"  - Validation Dice: {mean_dice[-1]:.4f} (Mean Foreground)")
            
            # --- Save Training History to CSV ---
            # Lengths might differ slightly if training crashed
            min_len = min(len(train_losses), len(val_losses), len(mean_dice))
            history_data = {
                "Epoch": list(range(1, min_len + 1)),  # human-readable
                "Train_Loss": train_losses[:min_len],
                "Val_Loss": val_losses[:min_len],
                "Val_Dice": mean_dice[:min_len]
            }
            df_history = pd.DataFrame(history_data)
            
            history_csv_path = Path(OUTPUT_FOLDER) / "training_history.csv"
            df_history.to_csv(history_csv_path, index=False)
            print(f"✓ Training history saved to: {history_csv_path}")

            # --- Save Latest Metrics to CSV ---
            latest_metrics = {
                "Epoch": [latest_epoch],
                "Train_Loss": [train_losses[-1]],
                "Val_Loss": [val_losses[-1]],
                "Val_Dice": [mean_dice[-1]]
            }
            df_latest = pd.DataFrame(latest_metrics)

            latest_csv_path = Path(OUTPUT_FOLDER) / "latest_metrics.csv"
            df_latest.to_csv(latest_csv_path, index=False)
            print(f"✓ Latest metrics saved to: {latest_csv_path}")

            # --- Plotting and Saving ---
            plt.figure(figsize=(12, 5))
            
            # Loss Plot
            plt.subplot(1, 2, 1)
            plt.plot(train_losses, label='Train Loss', alpha=0.7)
            plt.plot(val_losses, label='Val Loss', alpha=0.7)
            plt.title("Training vs Validation Loss")
            plt.xlabel("Epoch")
            plt.ylabel("Loss")
            plt.legend()
            plt.grid(True, alpha=0.3)
            
            # Dice Plot
            plt.subplot(1, 2, 2)
            plt.plot(mean_dice, label='Mean Foreground Dice')
            plt.title("Validation Dice Score (Pseudo)")
            plt.xlabel("Epoch")
            plt.ylabel("Dice Score")
            plt.legend()
            plt.grid(True, alpha=0.3)
            
            plt.tight_layout()
            
            # Save Plot
            history_plot_path = Path(OUTPUT_FOLDER) / "training_history.png"
            plt.savefig(history_plot_path)
            print(f"✓ Training history plot saved to: {history_plot_path}")
            
            plt.show()
        else:
            print("⚠ 'logging' key found but lists are empty.")
            
    else:
        print("⚠ 'logging' key not found in checkpoint. Available keys:", checkpoint.keys())

except FileNotFoundError:
    print(f"❌ Error: File not found at {CHECKPOINT_SOURCE}")
except Exception as e:
    print(f"❌ Error reading checkpoint: {e}")

#### **Cell 3: Register Custom Trainer (Crucial)**

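You trained with `nnUNetTrainerOversampling`, so `nnUNetv2_predict` must be able to find a class with that name, even though we are only predicting. Otherwise it stops with an error. We therefore create a minimal stand-in subclass of `nnUNetTrainer` and place it where nnU-Net searches for trainers. This works as long as the original trainer only changed how training patches are sampled and did not change the network architecture.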
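To see why the stand-in class must exist, here is a simplified sketch that looks up a class by name inside a package and its submodules. nnU-Net performs a similar recursive search over `nnunetv2.training.nnUNetTrainer` for the `-tr` name; to our knowledge, its own helper is `recursive_find_python_class`. The function below is a hypothetical illustration that uses only the standard library. It is not nnU-Net's code.

```python
import importlib
import inspect
import pkgutil
from typing import Optional


def find_class_by_name(package_name: str, class_name: str) -> Optional[type]:
    """Search a package and its submodules for a class called `class_name`."""
    package = importlib.import_module(package_name)
    candidate = getattr(package, class_name, None)
    if inspect.isclass(candidate):
        return candidate
    # Only packages (objects with __path__) have submodules to walk
    for info in pkgutil.walk_packages(getattr(package, "__path__", []), prefix=package_name + "."):
        if info.name.endswith(".__main__"):
            continue  # importing __main__ would run the package's command-line entry point
        try:
            module = importlib.import_module(info.name)
        except Exception:
            continue  # skip submodules whose optional dependencies are missing
        candidate = getattr(module, class_name, None)
        if inspect.isclass(candidate):
            return candidate
    return None
```

In a fresh Python process, `find_class_by_name("nnunetv2.training.nnUNetTrainer", "nnUNetTrainerOversampling")` should return the class only once the stand-in file has been copied into place.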

In [None]:
# Create a dummy custom trainer file so nnU-Net can find the class definition
custom_dir = "../data/custom_nnunet"
os.makedirs(custom_dir, exist_ok=True)

trainer_code = '''
from nnunetv2.training.nnUNetTrainer.nnUNetTrainer import nnUNetTrainer

class nnUNetTrainerOversampling(nnUNetTrainer):
    """
    Minimal stand-in used only for inference.
    nnUNetv2_predict looks up the trainer class by name, so the oversampling
    logic itself is not needed here. The constructor is inherited unchanged,
    which keeps this class compatible with the installed nnU-Net version.
    """
'''

with open(f"{custom_dir}/nnUNetTrainerOversampling.py", "w") as f:
    f.write(trainer_code)

# Copy the file into nnU-Net's own trainer package, which is where nnUNetv2_predict
# searches for trainer classes. Note: this modifies the installed nnunetv2 package.
import nnunetv2
nnunet_trainers_dir = Path(nnunetv2.__file__).parent / "training" / "nnUNetTrainer"
shutil.copy(f"{custom_dir}/nnUNetTrainerOversampling.py", nnunet_trainers_dir / "nnUNetTrainerOversampling.py")

print("✓ Custom Trainer registered for Inference")

#### **Cell 4: Reconstruct Results Folder Structure**

nnU-Net looks for trained models under the `nnUNet_results` folder, using the path format `DatasetNAME / Trainer__Plans__Config / fold_X`. We create this structure manually and copy your checkpoint into it.
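The naming convention can be expressed as two small helpers. One builds the fold folder from its parts, and the other splits a `Trainer__Plans__Config` folder name back into its components. This is a hypothetical sketch of the convention described above, not an nnU-Net API. Splitting on `__` assumes that none of the three names contains a double underscore itself.

```python
from pathlib import Path
from typing import NamedTuple, Union


class ModelFolder(NamedTuple):
    trainer: str
    plans: str
    configuration: str


def build_fold_folder(results_root: Union[str, Path], dataset_name: str,
                      trainer: str, plans: str, configuration: str, fold: int) -> Path:
    """Return <results_root>/<dataset>/<trainer>__<plans>__<configuration>/fold_<fold>."""
    return Path(results_root) / dataset_name / f"{trainer}__{plans}__{configuration}" / f"fold_{fold}"


def parse_model_folder_name(name: str) -> ModelFolder:
    """Split 'Trainer__Plans__Config' back into its three parts."""
    parts = name.split("__")
    if len(parts) != 3:
        raise ValueError(f"Expected 'Trainer__Plans__Config', got {name!r}")
    return ModelFolder(*parts)
```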

In [None]:
# Define the strict structure nnU-Net expects
# Format: DatasetXXX_Name / TrainerName__PlansName__ConfigName / fold_X
dataset_name = "Dataset002_BonnFCD_FLAIR"
trainer_name = "nnUNetTrainerOversampling" # MUST match what you used in training
plans_name = "nnUNetPlans"
config_name = "3d_fullres"

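# Root folder in which nnUNetv2_predict looks for trained models: the nnUNet_results
# environment variable (falls back to the nnUNet_results value from config)
results_base = Path(os.environ.get("nnUNet_results", nnUNet_results))
expected_folder = results_base / dataset_name / f"{trainer_name}__{plans_name}__{config_name}" / f"fold_{FOLD}"
expected_folder.mkdir(parents=True, exist_ok=True)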

# Copy the checkpoint
src_ckpt = Path(CHECKPOINT_SOURCE)
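# nnUNetv2_predict loads checkpoint_final.pth by default, so the selected
# checkpoint (CHECKPOINT_NAME) is copied under that name
dest_ckpt = expected_folder / "checkpoint_final.pth"

if src_ckpt.exists():
    shutil.copy2(src_ckpt, dest_ckpt)
    print(f"✓ Checkpoint copied to:\n  {dest_ckpt}")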
    
    # Copy metadata (plans.json, dataset.json) from the trained model folder;
    # nnUNetv2_predict needs both to preprocess the test images like the training data.
    # Destination: expected_folder.parent, i.e. .../Trainer__Plans__Config/
    trained_model_folder = Path(NNUNET_MODEL_OUTPUT_DATA) / "nnUNet_results" / dataset_name / f"{trainer_name}__{plans_name}__{config_name}"
    
    if trained_model_folder.exists():
        print(f"✓ Found trained model folder: {trained_model_folder}")
        
        # Copy plans.json
        plans_src = trained_model_folder / "plans.json"
        plans_dest = expected_folder.parent / "plans.json" 
        if plans_src.exists():
            shutil.copy2(plans_src, plans_dest)
            print(f"✓ plans.json copied to: {plans_dest}")
        else:
            print(f"⚠ plans.json not found in {trained_model_folder}")

        # Copy dataset.json
        dataset_src = trained_model_folder / "dataset.json"
        dataset_dest = expected_folder.parent / "dataset.json"
        if dataset_src.exists():
            shutil.copy2(dataset_src, dataset_dest)
            print(f"✓ dataset.json copied to: {dataset_dest}")
        else:
            print(f"⚠ dataset.json not found in {trained_model_folder}")
            
    else:
        print(f"⚠ Trained model folder not found at {trained_model_folder}")
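        # Fallback: take dataset.json and the plans file from the preprocessed folder
        preprocessed_dataset = Path(PREPROCESSED_BASE) / dataset_name
        dataset_json_pre = preprocessed_dataset / "dataset.json"
        if dataset_json_pre.exists():
            dataset_dest = expected_folder.parent / "dataset.json"
            shutil.copy2(dataset_json_pre, dataset_dest)
            print(f"✓ dataset.json (fallback: from preprocessed) copied to: {dataset_dest}")
        plans_json_pre = preprocessed_dataset / f"{plans_name}.json"
        if plans_json_pre.exists():
            plans_dest = expected_folder.parent / "plans.json"
            shutil.copy2(plans_json_pre, plans_dest)
            print(f"✓ plans.json (fallback: from {plans_json_pre.name}) copied to: {plans_dest}")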
    
else:
    print(f"❌ ERROR: Checkpoint not found at {CHECKPOINT_SOURCE}")

#### **Cell 5: Run Inference**

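This runs the prediction with the following arguments:

* `-i`: Input folder (imagesTs)
* `-o`: Output folder
* `-d`: Dataset ID (002)
* `-c`: Configuration (3d_fullres)
* `-f`: Fold (`FOLD`, set in Cell 2)
* `-tr`: Your custom trainer name
* `-p`: Plans identifier (nnUNetPlans)
* `--save_probabilities`: Also saves the predicted probabilities (optional)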
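The same call can also be made from Python with `subprocess` instead of the `!` shell magic. This avoids quoting problems when paths contain spaces. The sketch below builds the argument list from the flags listed above, and `build_predict_command` and `to_shell` are hypothetical helpers. The optional `-chk` flag selects a checkpoint other than `checkpoint_final.pth` and is taken from `nnUNetv2_predict -h` in recent nnU-Net v2 releases. Check it against your installed version.

```python
import shlex
from typing import List, Optional


def build_predict_command(input_dir: str, output_dir: str, dataset_id: str,
                          configuration: str, fold: int, trainer: str,
                          plans: str = "nnUNetPlans",
                          checkpoint: Optional[str] = None,
                          save_probabilities: bool = False) -> List[str]:
    """Assemble the nnUNetv2_predict argument list (one list item per argument)."""
    cmd = ["nnUNetv2_predict", "-i", input_dir, "-o", output_dir,
           "-d", dataset_id, "-c", configuration, "-f", str(fold),
           "-tr", trainer, "-p", plans]
    if checkpoint is not None:
        cmd += ["-chk", checkpoint]  # otherwise nnU-Net uses checkpoint_final.pth
    if save_probabilities:
        cmd.append("--save_probabilities")
    return cmd


def to_shell(cmd: List[str]) -> str:
    """Quote each argument so the command can be pasted into a shell."""
    return shlex.join(cmd)
```

For example, `subprocess.run(build_predict_command(INPUT_IMAGES_FOLDER, OUTPUT_FOLDER, "002", "3d_fullres", FOLD, "nnUNetTrainerOversampling", save_probabilities=True), check=True)` raises an error if prediction fails. The `!` magic, by contrast, lets the notebook continue silently to the next cell.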

In [None]:
print("Starting Inference... (This may take a while depending on GPU)")
print("-------------------------------------------------------------")

!nnUNetv2_predict \
    -i {INPUT_IMAGES_FOLDER} \
    -o {OUTPUT_FOLDER} \
    -d 002 \
    -c 3d_fullres \
    -f {FOLD} \
    -tr nnUNetTrainerOversampling \
    -p nnUNetPlans \
    --save_probabilities # Optional: Remove if you only want the segmentation map

print("\n✓ Inference Complete!")
print(f"Results saved to: {OUTPUT_FOLDER}")

#### **Cell 6: Visualize Results (Sanity Check)**

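For the first three predicted cases, display one slice of the input with the prediction overlaid to verify that the model is working. The slice is the mean position of the predicted voxels along the third axis, or the middle slice if the prediction is empty.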
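The cell below chooses the slice at the mean index of all predicted voxels along the third axis. With several separate lesions, that mean can land on a slice that contains no prediction at all. A hedged alternative is to pick the slice with the largest predicted area, shown here as the hypothetical helper `pick_slice`:

```python
import numpy as np


def pick_slice(pred: np.ndarray, axis: int = 2) -> int:
    """Index of the slice along `axis` with the most foreground voxels; middle slice if empty."""
    other_axes = tuple(i for i in range(pred.ndim) if i != axis)
    area = (pred > 0).sum(axis=other_axes)  # foreground voxel count per slice
    if area.max() == 0:
        return pred.shape[axis] // 2
    return int(np.argmax(area))
```

In `show_prediction`, `slice_idx = pick_slice(pred)` would replace the mean-index computation.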

In [None]:
def show_prediction(case_id):
    # Paths
    img_path = Path(INPUT_IMAGES_FOLDER) / f"{case_id}_0000.nii.gz"
    # If .nii.gz doesn't exist, try .nii
    if not img_path.exists():
        img_path = Path(INPUT_IMAGES_FOLDER) / f"{case_id}_0000.nii"
        
    pred_path = Path(OUTPUT_FOLDER) / f"{case_id}.nii.gz"
    if not pred_path.exists():
        pred_path = Path(OUTPUT_FOLDER) / f"{case_id}.nii"

    if not img_path.exists() or not pred_path.exists():
        print(f"Could not find files for {case_id}")
        return

    # Load
    img = nib.load(img_path).get_fdata()
    pred = nib.load(pred_path).get_fdata()

    # Find a slice with segmentation
    if np.sum(pred) > 0:
        # Use the mean slice index of all predicted voxels along axis 2
        coords = np.where(pred > 0)
        slice_idx = int(np.mean(coords[2]))
        print(f"Visualizing slice {slice_idx} (contains prediction)")
    else:
        slice_idx = img.shape[2] // 2
        print(f"Visualizing middle slice {slice_idx} (Empty prediction)")

    # Plot
    plt.figure(figsize=(12, 6))
    
    plt.subplot(1, 2, 1)
    plt.imshow(np.rot90(img[:, :, slice_idx]), cmap='gray')
    plt.title(f"Input: {case_id}")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(np.rot90(img[:, :, slice_idx]), cmap='gray')
    plt.imshow(np.rot90(pred[:, :, slice_idx]), cmap='jet', alpha=0.5)
    plt.title("Prediction Overlay")
    plt.axis('off')
    
    plt.tight_layout()
    
    # Save the plot
    save_path = Path(OUTPUT_FOLDER) / f"{case_id}_prediction.png"
    plt.savefig(save_path)
    print(f"✓ Saved visualization to: {save_path}")
    
    plt.show()
    plt.close()

# List processed files
predicted_files = [f.name.split('.')[0] for f in Path(OUTPUT_FOLDER).glob("*.nii*") if not f.name.endswith('.json')]

if len(predicted_files) > 0:
    # Show first 3 cases
    for case in predicted_files[:3]:
        show_prediction(case)
else:
    print("No predictions found to visualize.")

#### **Cell 7: Quantitative Evaluation (Dice, HD95, etc.)**

This section compares the model's predictions against the ground truth labels to calculate rigorous performance metrics.

**Metrics Calculated:**
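- **Dice Score:** Overlap between prediction and ground truth (higher is better, max 1.0).
- **IoU (Jaccard):** Intersection over Union (higher is better, max 1.0), derived here from Dice as IoU = Dice / (2 - Dice).
- **Precision & Recall:** Precision is the fraction of predicted lesion voxels that are correct; recall is the fraction of ground-truth lesion voxels that are detected.
- **HD95 (95th-percentile Hausdorff Distance):** Boundary distance error in mm, using the voxel spacing from the NIfTI header (lower is better).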
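The four overlap metrics can all be written in terms of voxel counts: true positives TP, false positives FP and false negatives FN. Dice = 2TP / (2TP + FP + FN) and IoU = TP / (TP + FP + FN). Substituting S = TP + FP + FN gives Dice / (2 - Dice) = 2TP / (2S) = IoU, which is the relation the evaluation code below relies on. The following is a minimal numpy sketch using a hypothetical helper. It assumes both masks are non-empty; the notebook handles empty masks separately.

```python
import numpy as np


def overlap_metrics(pred: np.ndarray, gt: np.ndarray) -> dict:
    """Dice, IoU, precision and recall of two non-empty binary masks from voxel counts."""
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    tp = np.count_nonzero(pred & gt)   # predicted and true
    fp = np.count_nonzero(pred & ~gt)  # predicted but not true
    fn = np.count_nonzero(~pred & gt)  # true but missed
    return {
        "Dice": 2 * tp / (2 * tp + fp + fn),
        "IoU": tp / (tp + fp + fn),
        "Precision": tp / (tp + fp),
        "Recall": tp / (tp + fn),
    }
```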

In [None]:
# ---------------------------------------------------------
# 1. Define Ground Truth Path
# ---------------------------------------------------------
# Assuming standard nnU-Net structure where labelsTs is parallel to imagesTs
# If your path is different, please update LABELS_FOLDER manually.
LABELS_FOLDER = INPUT_IMAGES_FOLDER.replace("imagesTs", "labelsTs")

print(f"✓ Looking for Ground Truth in: {LABELS_FOLDER}")

# ---------------------------------------------------------
# 2. Define Metric Functions
# ---------------------------------------------------------
def calculate_metrics(pred_path, gt_path):
    try:
        # Load NIfTI files
        pred_nii = nib.load(pred_path)
        gt_nii = nib.load(gt_path)
        
        pred_data = pred_nii.get_fdata().astype(bool)
        gt_data = gt_nii.get_fdata().astype(bool)
        
        # Voxel spacing for HD95 (read from header)
        voxel_spacing = pred_nii.header.get_zooms()
        
        # Handle empty masks
        if np.sum(pred_data) == 0 and np.sum(gt_data) == 0:
            return {
                "Dice": 1.0,
                "IoU": 1.0,
                "Precision": 1.0,
                "Recall": 1.0,
                "HD95": 0.0 # Perfect match (both empty)
            }
        elif np.sum(pred_data) == 0 or np.sum(gt_data) == 0:
            return {
                "Dice": 0.0,
                "IoU": 0.0,
                "Precision": 0.0,
                "Recall": 0.0,
                "HD95": np.nan # Undefined distance if one is empty
            }
            
        # Calculate Metrics using MedPy
        dice_score = dc(pred_data, gt_data)
        iou_score = dice_score / (2 - dice_score) # Mathematical relation
        precision_score = prec(pred_data, gt_data)
        recall_score = rec(pred_data, gt_data)
        hd95_score = hd95(pred_data, gt_data, voxelspacing=voxel_spacing)
        
        return {
            "Dice": dice_score,
            "IoU": iou_score,
            "Precision": precision_score,
            "Recall": recall_score,
            "HD95": hd95_score
        }
        
    except Exception as e:
        print(f"Error calculating metrics for {pred_path.name}: {e}")
        return None

# ---------------------------------------------------------
# 3. Run Evaluation Loop
# ---------------------------------------------------------
results = []

# Search for both .nii.gz and .nii
pred_files = list(Path(OUTPUT_FOLDER).glob("*.nii.gz")) + list(Path(OUTPUT_FOLDER).glob("*.nii"))

if not Path(LABELS_FOLDER).exists():
    print(f"❌ Error: Labels folder not found at {LABELS_FOLDER}. Cannot run evaluation.")
elif len(pred_files) == 0:
    print(f"❌ No predictions found in {OUTPUT_FOLDER}")
    print("Debug: Listing ALL files in output folder to help diagnosis:")
    try:
        all_files = os.listdir(OUTPUT_FOLDER)
        print(all_files if all_files else "  (Folder is empty)")
    except Exception as e:
        print(f"  Error reading folder: {e}")
        
    print("\nPossible reasons:")
    print("1. Inference (Cell 5) failed or didn't run.")
    print("2. Output path mismatch.")
else:
    print(f"Starting evaluation on {len(pred_files)} cases...")
    
    for pred_file in pred_files:
        # Match Prediction filename to Ground Truth filename
        case_id = pred_file.name
        gt_file = Path(LABELS_FOLDER) / case_id
        
        # Fallback: sometimes GT is .nii while pred is .nii.gz or vice versa
        if not gt_file.exists():
            if case_id.endswith('.nii.gz'):
                gt_file = Path(LABELS_FOLDER) / case_id.replace('.nii.gz', '.nii')
            elif case_id.endswith('.nii'):
                gt_file = Path(LABELS_FOLDER) / case_id.replace('.nii', '.nii.gz')
        
        if gt_file.exists():
            metrics = calculate_metrics(pred_file, gt_file)
            if metrics:
                metrics['Case_ID'] = case_id.replace('.nii.gz', '').replace('.nii', '')
                results.append(metrics)
        else:
            print(f"⚠ Missing ground truth for {case_id} (Looked for: {gt_file.name})")

    # ---------------------------------------------------------
    # 4. Display & Save Results
    # ---------------------------------------------------------
    if len(results) > 0:
        df = pd.DataFrame(results)
        
        # Reorder columns
        cols = ['Case_ID', 'Dice', 'IoU', 'Precision', 'Recall', 'HD95']
        df = df[cols]
        
        print("\n🏆 Evaluation Results:")
        print(df.to_string(index=False))
        
        # Summary Statistics (describe() skips NaN, so cases with an undefined HD95 are left out of the HD95 statistics)
        print("\n📊 Summary Statistics:")
        summary = df.describe().loc[['mean', 'std', 'min', 'max']]
        print(summary)
        
        # Save Summary
        save_summary_path = Path(OUTPUT_FOLDER) / "evaluation_summary.csv"
        summary.to_csv(save_summary_path)
        print(f"\n✓ Summary Statistics saved to: {save_summary_path}")
        
        # Save Metrics
        save_csv_path = Path(OUTPUT_FOLDER) / "evaluation_metrics.csv"
        df.to_csv(save_csv_path, index=False)
        print(f"✓ Metrics saved to: {save_csv_path}")
    else:
        print("No matched results to evaluate.")
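The HD95 values above come from MedPy's `hd95`. As we understand its implementation, it works in four steps:

1. Take the boundary voxels of each mask.
2. For every boundary voxel, measure the distance to the nearest boundary voxel of the other mask, scaled by the voxel spacing.
3. Pool the distances from both directions.
4. Report the 95th percentile.

The brute-force sketch below reproduces only steps 2 to 4, on explicit point lists. It leaves out boundary extraction, and the helper name is hypothetical. It builds an N×M distance matrix, so it is meant for small examples only.

```python
import numpy as np


def hd95_from_points(a_pts, b_pts, spacing=(1.0, 1.0, 1.0)) -> float:
    """95th percentile of pooled nearest-neighbour distances between two point sets.

    a_pts, b_pts: (N, 3) and (M, 3) voxel indices (e.g. boundary voxels).
    spacing: voxel size per axis, so distances come out in physical units (mm).
    """
    scale = np.asarray(spacing, dtype=float)
    a = np.asarray(a_pts, dtype=float) * scale
    b = np.asarray(b_pts, dtype=float) * scale
    dist = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)  # (N, M) pairwise distances
    a_to_b = dist.min(axis=1)  # each point in a to its nearest point in b
    b_to_a = dist.min(axis=0)  # each point in b to its nearest point in a
    return float(np.percentile(np.hstack((a_to_b, b_to_a)), 95))
```

Unlike the plain Hausdorff distance (the maximum), the 95th percentile is less sensitive to a single outlying boundary voxel.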

### **Instructions for Use:**

1. **Set the checkpoint:** In **Cell 2**, set `NNUNET_MODEL_OUTPUT_DATA` (the folder containing your training output), `FOLD`, and `CHECKPOINT_NAME`. `CHECKPOINT_SOURCE` is derived from these values, so it does not need to be edited directly.
2. **Dataset Location:** Ensure that `INPUT_IMAGES_FOLDER` in Cell 2 points exactly to your test images (`imagesTs`). The path is derived from your file tree, but it may change slightly if you re-upload the dataset.
3. **Run All:** Execute the cells in order. Cell 4 ("Reconstruct Results Folder Structure") is what makes inference work without re-training.