In [None]:
# Environment setup and Colab detection
# Ref: setup_colab.ipynb for Colab integration patterns
import sys
import os
from pathlib import Path
from datetime import datetime

# True when the kernel has already loaded the google.colab package.
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    print("🌐 Running in Google Colab")
    from google.colab import drive
    drive.mount('/content/drive')

    # Set up paths for Colab - adjust these paths as needed
    DRIVE_ROOT = '/content/drive/MyDrive'
    PROJECT_ROOT = f'{DRIVE_ROOT}/RL-CC-SAM'

    # Change to project directory
    os.chdir(PROJECT_ROOT)
    # Guarded so that re-running this cell does not keep growing sys.path.
    if PROJECT_ROOT not in sys.path:
        sys.path.append(PROJECT_ROOT)

    print(f"📁 Working directory: {os.getcwd()}")
else:
    print("💻 Running locally")
    # The notebook is expected to live in notebooks/. Only step up to the
    # parent while the kernel is still inside that folder; otherwise a second
    # run of this cell would climb one directory higher each time.
    _cwd = Path.cwd()
    PROJECT_ROOT = _cwd.parent if _cwd.name == "notebooks" else _cwd
    os.chdir(PROJECT_ROOT)

    print(f"📁 Working directory: {PROJECT_ROOT}")

# Create experiment directory with timestamp
# Following pattern from DATASET_SETUP_README.md
EXPERIMENT_ID = datetime.now().strftime("%Y%m%d_%H%M%S")
EXPERIMENT_DIR = Path("notebooks") / f"baseline_test_{EXPERIMENT_ID}"
# parents=True: the project root (e.g. a fresh Drive checkout) may not contain
# a notebooks/ folder yet.
EXPERIMENT_DIR.mkdir(parents=True, exist_ok=True)

print(f"🧪 Experiment ID: {EXPERIMENT_ID}")
print(f"📁 Experiment directory: {EXPERIMENT_DIR}")


In [None]:
# Install required dependencies
# Ref: download_requirements.txt for required packages
%pip install -q requests tqdm nibabel matplotlib seaborn plotly scipy ipywidgets

# Import essential libraries
import json
import time
import shutil
import tempfile
import subprocess
import warnings
warnings.filterwarnings('ignore')

# Data handling
import numpy as np
import nibabel as nib
from tqdm.auto import tqdm

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.patches import Rectangle

# Set style
# Matplotlib 3.6 renamed its bundled seaborn style sheets to "seaborn-v0_8*";
# older releases only know the un-prefixed name. Pick whichever is available
# instead of raising on an older installation.
PLOT_STYLE = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(PLOT_STYLE)
sns.set_palette("husl")

print("✅ Dependencies imported successfully")
print(f"📊 NumPy version: {np.__version__}")
print(f"🧠 NiBabel version: {nib.__version__}")


In [None]:
# Setup nnU-Net environment variables and paths
# Ref: setup_nnunet_env.sh and DATASET_SETUP_README.md lines 45-54

# All locations are derived from the current working directory.
PROJECT_ROOT = Path.cwd()
DATASETS_DIR = PROJECT_ROOT / "datasets"
NNUNET_DIR = PROJECT_ROOT / "nnunet"

print(f"📁 Project root: {PROJECT_ROOT}")
print(f"📁 Datasets directory: {DATASETS_DIR}")
print(f"📁 nnU-Net submodule: {NNUNET_DIR}")

# The three nnU-Net folders share their name with the environment variable
# that points at them
# (nnunet/documentation/installation_instructions.md section 3).
env_vars = {
    folder: str(DATASETS_DIR / folder)
    for folder in ('nnUNet_raw', 'nnUNet_preprocessed', 'nnUNet_results')
}
env_vars['EXPERIMENT_DIR'] = str(EXPERIMENT_DIR)

# Export every variable; only the nnUNet_* entries get their directory created.
for key in env_vars:
    value = env_vars[key]
    os.environ[key] = value
    if key.startswith('nnUNet'):
        Path(value).mkdir(parents=True, exist_ok=True)
    print(f"✅ {key}: {value}")

# Verify nnU-Net installation by asking the training entry point for its help
# text (following test_nnunet_setup.py pattern).
try:
    result = subprocess.run(
        ['nnUNetv2_train', '-h'],
        capture_output=True,
        text=True,
        timeout=10,
    )
except Exception as e:
    print(f"❌ nnU-Net verification failed: {e}")
    print("💡 Make sure nnU-Net is installed: pip install nnunetv2")
else:
    if result.returncode != 0:
        print("⚠️ nnU-Net commands may not be properly installed")
    else:
        print("✅ nnU-Net commands available")


In [None]:
# Dataset configuration
# Following nnunet/documentation/dataset_format.md for naming conventions
# The folder name is "Dataset" + zero-padded three-digit ID + "_" + a free-form
# name, and it lives directly under the directory named by $nnUNet_raw.
DATASET_ID = 999  # Using 999 for synthetic/test datasets
DATASET_NAME = f"Dataset{DATASET_ID:03d}_BaselineTest"
# Requires the nnUNet_raw environment variable to be set already; a missing
# variable raises KeyError here.
DATASET_PATH = Path(os.environ['nnUNet_raw']) / DATASET_NAME

print(f"🏷️ Dataset ID: {DATASET_ID}")
print(f"📁 Dataset path: {DATASET_PATH}")

def _ellipsoid_mask(shape, center, radii):
    """Return a boolean volume of `shape` that is True inside an ellipsoid.

    A voxel (x, y, z) is inside when
    ((x-cx)/a)**2 + ((y-cy)/b)**2 + ((z-cz)/c)**2 <= 1,
    with `center` = (cx, cy, cz) and `radii` = (a, b, c) in voxels.
    """
    # ogrid yields one broadcastable index array per axis, so the whole volume
    # is evaluated at once instead of voxel by voxel.
    axes = np.ogrid[tuple(slice(0, size) for size in shape)]
    dist_sq = sum(((axis - c) / r) ** 2 for axis, c, r in zip(axes, center, radii))
    return dist_sq <= 1


def create_synthetic_medical_dataset(dataset_path, num_train=15, num_test=3, seed=None):
    """
    Create a synthetic 3D medical dataset for baseline testing.

    The dataset is written in the nnU-Net raw-data layout:
    - imagesTr/: Training images (case_XXX_0000.nii.gz)
    - labelsTr/: Training labels (case_XXX.nii.gz)
    - imagesTs/: Test images (test_XXX_0000.nii.gz)
    - dataset.json: Dataset metadata

    Every training case contains a fixed ellipsoid "organ" (label 1) with one
    to three small "lesions" (label 2) seeded inside it; test cases contain an
    organ with a randomly shifted centre and randomly perturbed radii and have
    no label file.

    Args:
        dataset_path: Target dataset folder (str or Path); created if missing.
        num_train: Number of training image/label pairs to write.
        num_test: Number of unlabeled test images to write.
        seed: Optional integer. When given, all random draws come from a
            private ``np.random.RandomState(seed)`` so the dataset is
            reproducible; when None the global ``np.random`` state is used.

    Returns:
        Path: ``dataset_path`` converted to a ``Path``.

    Ref: nnunet/documentation/dataset_format.md
    """
    # Imported once at function scope. It was previously imported inside the
    # training loop, which left `ndimage` undefined for the test loop whenever
    # num_train == 0.
    from scipy import ndimage

    # np.random and RandomState expose the same normal()/randint() API.
    rng = np.random if seed is None else np.random.RandomState(seed)

    print(f"🔬 Creating synthetic dataset: {num_train} training + {num_test} test cases")

    # Create directory structure per nnU-Net specification
    dataset_path = Path(dataset_path)
    imagesTr = dataset_path / "imagesTr"
    labelsTr = dataset_path / "labelsTr"
    imagesTs = dataset_path / "imagesTs"

    for dir_path in [imagesTr, labelsTr, imagesTs]:
        dir_path.mkdir(parents=True, exist_ok=True)

    # Dataset parameters
    img_shape = (96, 96, 48)  # X, Y, Z dimensions
    spacing = [1.5, 1.5, 3.0]  # Voxel spacing (mm), stored on the affine diagonal
    affine = np.eye(4)
    np.fill_diagonal(affine[:3, :3], spacing)

    print(f"   📐 Image dimensions: {img_shape}")
    print(f"   📏 Voxel spacing: {spacing} mm")

    shape_arr = np.array(img_shape)

    # Voxel offsets of one lesion: a radius-3 sphere truncated to +/-2 slices
    # along Z. Computed once and reused for every lesion.
    lesion_offsets = np.array([
        (dx, dy, dz)
        for dx in range(-3, 4)
        for dy in range(-3, 4)
        for dz in range(-2, 3)
        if dx**2 + dy**2 + dz**2 <= 9
    ])

    # The training organ is identical for every case, so build it once.
    train_organ_mask = _ellipsoid_mask(img_shape, shape_arr // 2, (20, 25, 15))
    organ_voxels = np.argwhere(train_organ_mask)

    # Create training data
    for i in tqdm(range(num_train), desc="Creating training data"):
        case_id = f"case_{i:03d}"

        # Background: low intensity with noise
        image = rng.normal(100, 30, img_shape).astype(np.float32)

        # Organ: brighter, noisier tissue
        image[train_organ_mask] = rng.normal(300, 50, len(organ_voxels))

        # Add small lesions, each centred on a random organ voxel
        lesion_mask = np.zeros(img_shape, dtype=bool)
        for _ in range(rng.randint(1, 4)):
            lesion_center = organ_voxels[rng.randint(len(organ_voxels))]
            coords = lesion_center + lesion_offsets
            # Drop the part of the lesion that would fall outside the volume.
            inside = np.all((coords >= 0) & (coords < shape_arr), axis=1)
            xs, ys, zs = coords[inside].T
            lesion_mask[xs, ys, zs] = True
            image[xs, ys, zs] = rng.normal(500, 30, xs.size)

        # Segmentation labels: 0=background, 1=organ, 2=lesion
        seg_mask = np.zeros(img_shape, dtype=np.uint8)
        seg_mask[train_organ_mask] = 1
        seg_mask[lesion_mask] = 2  # lesions overwrite organ where they overlap

        # Light blur, then clamp to the intended intensity range
        image = ndimage.gaussian_filter(image, sigma=0.5)
        image = np.clip(image, 0, 1000)

        # Save as NIfTI files; the _0000 suffix is the channel index
        img_nifti = nib.Nifti1Image(image.astype(np.int16), affine)
        seg_nifti = nib.Nifti1Image(seg_mask, affine)

        nib.save(img_nifti, imagesTr / f"{case_id}_0000.nii.gz")
        nib.save(seg_nifti, labelsTr / f"{case_id}.nii.gz")

    # Create test data (images only, no labels)
    test_cases = []
    for i in tqdm(range(num_test), desc="Creating test data"):
        case_id = f"test_{i:03d}"
        test_cases.append(case_id)

        image = rng.normal(100, 30, img_shape).astype(np.float32)

        # Per-case variation: shift the centre by up to 5 voxels and perturb
        # each radius once. (The radii used to be re-sampled for every voxel,
        # which produced a ragged boundary and over a million RNG calls.)
        center = shape_arr // 2 + rng.randint(-5, 6, 3)
        radii = np.array([18.0, 23.0, 13.0]) + rng.normal(0, 2, 3)
        radii = np.maximum(radii, 1.0)  # keep the divisor strictly positive
        organ_mask = _ellipsoid_mask(img_shape, center, radii)

        image[organ_mask] = rng.normal(300, 50, int(organ_mask.sum()))

        # Add blur and save
        image = ndimage.gaussian_filter(image, sigma=0.5)
        image = np.clip(image, 0, 1000)

        img_nifti = nib.Nifti1Image(image.astype(np.int16), affine)
        nib.save(img_nifti, imagesTs / f"{case_id}_0000.nii.gz")

    # Create dataset.json with metadata.
    # nnU-Net v2 (the nnUNetv2_* commands used by this notebook) requires
    # "channel_names", "labels" mapping name -> integer with a "background"
    # entry equal to 0, "numTraining" and "file_ending". The remaining keys are
    # informational and kept for readers of the older layout.
    dataset_json = {
        "name": "BaselineTest",
        "description": "Synthetic dataset for nnU-Net baseline testing",
        "tensorImageSize": "3D",
        "reference": "Created for baseline testing",
        "licence": "For testing purposes only",
        "release": "1.0",
        "channel_names": {"0": "CT"},  # Single channel
        "modality": {"0": "CT"},  # legacy (v1) spelling of channel_names
        "labels": {
            "background": 0,
            "organ": 1,
            "lesion": 2
        },
        "numTraining": num_train,
        "numTest": num_test,
        "file_ending": ".nii.gz",
        "training": [{
            "image": f"./imagesTr/case_{i:03d}_0000.nii.gz",
            "label": f"./labelsTr/case_{i:03d}.nii.gz"
        } for i in range(num_train)],
        "test": [f"./imagesTs/{case_id}_0000.nii.gz" for case_id in test_cases]
    }

    with open(dataset_path / "dataset.json", 'w') as f:
        json.dump(dataset_json, f, indent=2)

    print(f"✅ Synthetic dataset created successfully!")
    print(f"   📊 Training cases: {num_train}")
    print(f"   🧪 Test cases: {num_test}")
    print(f"   📐 Image shape: {img_shape}")
    print(f"   🏷️ Labels: Background (0), Organ (1), Lesion (2)")

    return dataset_path

# Create the synthetic dataset
# Writes imagesTr/, labelsTr/, imagesTs/ and dataset.json under DATASET_PATH;
# the function returns the dataset folder as a Path.
created_dataset = create_synthetic_medical_dataset(DATASET_PATH, num_train=15, num_test=3)
print(f"\n📁 Dataset created at: {created_dataset}")


In [None]:
# Visualize sample data from the created dataset
def visualize_sample_data(dataset_path, case_idx=0):
    """Plot the middle axial, sagittal and coronal slices of one training case.

    Loads ``case_<idx>`` (image and label) from ``dataset_path``, prints basic
    shape/range/label information, shows a 2x3 figure (image row on top, label
    row below), saves it into the module-level ``EXPERIMENT_DIR`` and finally
    prints per-label voxel counts. Returns None; prints a message and returns
    early when either file is missing.
    """
    case_id = f"case_{case_idx:03d}"
    root = Path(dataset_path)

    img_path = root / "imagesTr" / f"{case_id}_0000.nii.gz"
    label_path = root / "labelsTr" / f"{case_id}.nii.gz"

    # Both files are required; bail out if either one is absent.
    if not (img_path.exists() and label_path.exists()):
        print(f"❌ Files not found for case {case_id}")
        return

    img_nii = nib.load(img_path)
    label_nii = nib.load(label_path)
    img_data = img_nii.get_fdata()
    label_data = label_nii.get_fdata()

    label_values, label_counts = np.unique(label_data, return_counts=True)

    print(f"📊 Sample Case: {case_id}")
    print(f"   Image shape: {img_data.shape}")
    print(f"   Image range: [{img_data.min():.1f}, {img_data.max():.1f}]")
    print(f"   Label shape: {label_data.shape}")
    print(f"   Unique labels: {label_values}")
    print(f"   Voxel spacing: {img_nii.header.get_zooms()[:3]}")

    fig, axes = plt.subplots(2, 3, figsize=(15, 10))

    mid_x, mid_y, mid_z = (extent // 2 for extent in img_data.shape)

    # One entry per column: plane name, slice position text, index into the volume.
    everything = slice(None)
    views = (
        ("Axial", f"Z={mid_z}", (everything, everything, mid_z)),
        ("Sagittal", f"X={mid_x}", (mid_x, everything, everything)),
        ("Coronal", f"Y={mid_y}", (everything, mid_y, everything)),
    )
    # One entry per row: title prefix, volume to slice, colormap.
    rows = (
        ("Image", img_data, 'gray'),
        ("Label", label_data, 'viridis'),
    )

    for col, (plane, position, index) in enumerate(views):
        for row, (kind, volume, cmap) in enumerate(rows):
            panel = axes[row, col]
            panel.imshow(volume[index].T, cmap=cmap, origin='lower')
            panel.set_title(f'{kind} - {plane} ({position})')
            panel.axis('off')

    plt.tight_layout()
    plt.savefig(EXPERIMENT_DIR / f'sample_data_{case_id}.png', dpi=150, bbox_inches='tight')
    plt.show()

    # Per-label voxel counts and their share of the whole volume.
    print(f"\n📈 Label Statistics:")
    label_names = {0: 'Background', 1: 'Organ', 2: 'Lesion'}
    for label_val, count in zip(label_values, label_counts):
        percentage = (count / label_data.size) * 100
        label_id = int(label_val)
        name = label_names.get(label_id, f'Unknown({label_id})')
        print(f"   {name}: {count:,} voxels ({percentage:.1f}%)")

# Visualize first case
# Shows training case_000 and saves the figure into EXPERIMENT_DIR.
visualize_sample_data(DATASET_PATH, case_idx=0)


In [None]:
# Run nnU-Net preprocessing
# Ref: nnunet/documentation/how_to_use_nnunet.md lines 100-120
print(f"🔄 Starting nnU-Net preprocessing for dataset {DATASET_ID}...")
print(f"📁 Raw data: {os.environ['nnUNet_raw']}")
print(f"📁 Preprocessed: {os.environ['nnUNet_preprocessed']}")

# Plan and preprocess the 2d and 3d_fullres configurations, verifying the raw
# dataset first (following DATASET_SETUP_README.md lines 80-90).
preprocessing_cmd = f"nnUNetv2_plan_and_preprocess -d {DATASET_ID} -c 2d 3d_fullres --verify_dataset_integrity"
print(f"\n🚀 Running command: {preprocessing_cmd}")

start_time = time.time()
try:
    result = subprocess.run(
        preprocessing_cmd.split(),
        capture_output=True,
        text=True,
        timeout=600,  # 10 min timeout
    )
except subprocess.TimeoutExpired:
    print("⏰ Preprocessing timed out after 10 minutes")
except Exception as e:
    print(f"❌ Error during preprocessing: {e}")
else:
    preprocessing_time = time.time() - start_time

    if result.returncode != 0:
        print(f"❌ Preprocessing failed with return code: {result.returncode}")
        print("\nError output:")
        print(result.stderr)
    else:
        print(f"✅ Preprocessing completed successfully in {preprocessing_time:.1f}s")
        print("\n📋 Preprocessing output (last 1000 chars):")
        print(result.stdout[-1000:])

# Independently of the exit status, look for the preprocessed dataset folder.
preprocessed_path = Path(os.environ['nnUNet_preprocessed']) / f"Dataset{DATASET_ID:03d}_BaselineTest"
if not preprocessed_path.exists():
    print(f"❌ Preprocessed dataset not found at: {preprocessed_path}")
else:
    print(f"\n✅ Preprocessed dataset found at: {preprocessed_path}")

    # List the top-level JSON files first, then the pickles, with size in KB.
    print("\n📁 Key preprocessed files:")
    key_files = (
        item
        for pattern in ("*.json", "*.pkl")
        for item in preprocessed_path.glob(pattern)
    )
    for item in key_files:
        size = item.stat().st_size / 1024  # KB
        print(f"   {item.name} ({size:.1f} KB)")


In [None]:
# Quick baseline training (2D, single fold)
# Ref: nnunet/documentation/training_a_new_model.md lines 1-50
# NOTE(review): the stock nnUNetTrainer schedule is, as far as I know, 1000
# epochs, so the 5-15 minute estimate printed below looks optimistic - confirm,
# or switch to a short-schedule trainer (and pass the same -tr to predict).
print(f"🏋️ Starting baseline training...")
print(f"   Dataset: {DATASET_ID}")
print(f"   Configuration: 2d (fast training)")
print(f"   Fold: 0 (single fold for quick testing)")
print(f"   Expected time: 5-15 minutes")

# Training command
# Following DATASET_SETUP_README.md lines 150-160
# --npz additionally stores the softmax outputs of the final validation (used
# by ensembling / best-configuration search); it does not reduce memory use.
training_cmd = f"nnUNetv2_train {DATASET_ID} 2d 0 --npz"
print(f"\n🚀 Training command: {training_cmd}")

# Create training log file
training_log = EXPERIMENT_DIR / "training_log.txt"
print(f"📝 Training log will be saved to: {training_log}")
print("\n⏳ Training started... This may take several minutes.")

training_start_time = time.time()

# Bound before the try so the interrupt handler can tell whether the child
# process was ever started.
process = None

try:
    output_lines = []
    with open(training_log, 'w') as log_file:
        # Popen as a context manager closes the stdout pipe and waits for the
        # child on exit, so no pipe or zombie process is left behind.
        with subprocess.Popen(training_cmd.split(),
                              stdout=subprocess.PIPE,
                              stderr=subprocess.STDOUT,  # merge stderr into the log
                              universal_newlines=True,
                              bufsize=1) as process:
            # Stream output to both console and file; iteration ends at EOF,
            # i.e. when the child closes its stdout.
            for output in process.stdout:
                output_lines.append(output.strip())
                # Echo only every 10th line to avoid overwhelming output
                if len(output_lines) % 10 == 0:
                    print(f"  ... {output.strip()}")
                log_file.write(output)
                log_file.flush()

    return_code = process.returncode
    training_time = time.time() - training_start_time

    if return_code == 0:
        print(f"\n✅ Training completed successfully in {training_time/60:.1f} minutes!")

        # Check if model files were created
        model_path = Path(os.environ['nnUNet_results']) / DATASET_NAME / "nnUNetTrainer__nnUNetPlans__2d" / "fold_0"

        if model_path.exists():
            print(f"\n🎯 Model saved at: {model_path}")

            # List model files
            model_files = list(model_path.glob("*.pth"))
            print(f"\n📦 Model files created:")
            for model_file in model_files:
                size = model_file.stat().st_size / (1024*1024)  # MB
                print(f"   {model_file.name} ({size:.1f} MB)")
        else:
            print(f"⚠️ Model directory not found at: {model_path}")

    else:
        print(f"❌ Training failed with return code: {return_code}")
        print(f"📝 Check training log for details: {training_log}")

except KeyboardInterrupt:
    print("\n⏹️ Training interrupted by user")
    # Only signal a child that exists and is still running.
    if process is not None and process.poll() is None:
        process.terminate()
except Exception as e:
    print(f"\n❌ Error during training: {e}")
    print(f"📝 Check training log for details: {training_log}")

print(f"\n📊 Training Summary:")
print(f"   Duration: {(time.time() - training_start_time)/60:.1f} minutes")
print(f"   Log file: {training_log}")
print(f"   Dataset: {DATASET_ID} (2D configuration)")


In [None]:
# Setup inference directories
INPUT_DIR = EXPERIMENT_DIR / "inference_input"
OUTPUT_DIR = EXPERIMENT_DIR / "inference_output"

# Create directories
INPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

print(f"🔮 Setting up inference...")
print(f"   Input directory: {INPUT_DIR}")
print(f"   Output directory: {OUTPUT_DIR}")

# Collect the unlabeled test images of the dataset
test_images_dir = DATASET_PATH / "imagesTs"
test_images = sorted(test_images_dir.glob("*.nii.gz"))

print(f"\n📁 Found {len(test_images)} test images")

# Copy test images to inference directory
for test_img in test_images:
    shutil.copy2(test_img, INPUT_DIR)
    print(f"   Copied: {test_img.name}")

# Run inference
# Ref: nnunet/documentation/how_to_use_nnunet.md lines 200-220
# The argument list is built explicitly (not by splitting a string) so that
# directories containing spaces reach the command intact.
inference_args = [
    "nnUNetv2_predict",
    "-i", str(INPUT_DIR),
    "-o", str(OUTPUT_DIR),
    "-d", str(DATASET_ID),
    "-c", "2d",
    "-f", "0",
]
inference_cmd = " ".join(inference_args)  # human-readable form for the log
print(f"\n🚀 Inference command: {inference_cmd}")

inference_start_time = time.time()

result = None
try:
    result = subprocess.run(inference_args,
                            capture_output=True, text=True, timeout=300)  # 5 min timeout
except subprocess.TimeoutExpired:
    print("⏰ Inference timed out after 5 minutes")
except Exception as e:
    print(f"❌ Error during inference: {e}")

# Measured on every path; it used to be assigned only when the command
# returned, so the summary below raised NameError after a timeout or error.
inference_time = time.time() - inference_start_time

if result is not None:
    if result.returncode == 0:
        print(f"✅ Inference completed successfully in {inference_time:.1f}s")

        # List output files
        output_files = list(OUTPUT_DIR.glob("*.nii.gz"))
        print(f"\n📊 Generated {len(output_files)} prediction files:")

        for output_file in output_files:
            size = output_file.stat().st_size / (1024*1024)  # MB
            print(f"   {output_file.name} ({size:.2f} MB)")

    else:
        print(f"❌ Inference failed with return code: {result.returncode}")
        print("\nError output:")
        print(result.stderr)

print(f"\n🔮 Inference Summary:")
print(f"   Duration: {inference_time:.1f}s")
print(f"   Input files: {len(test_images)}")
print(f"   Output directory: {OUTPUT_DIR}")


In [None]:
# Analyze and visualize inference results
def analyze_predictions(input_dir, output_dir, experiment_dir):
    """Summarise every prediction and write the summary to JSON.

    Input images and predictions are each sorted by file name and paired by
    position. For every pair the function prints, and records, the input
    shape and intensity range and the voxel count / percentage of each
    predicted label. Pairs that fail to load are reported and skipped.

    Args:
        input_dir: Folder holding the ``*.nii.gz`` images given to inference.
        output_dir: Folder holding the ``*.nii.gz`` predictions.
        experiment_dir: Folder that receives ``inference_results.json``.

    Returns:
        list[dict]: One entry per analysed case with the keys ``case``,
        ``input_shape``, ``pred_shape``, ``input_range`` and
        ``predicted_labels`` (label id -> ``{'count', 'percentage'}``).
    """
    input_files = sorted(Path(input_dir).glob("*.nii.gz"))
    output_files = sorted(Path(output_dir).glob("*.nii.gz"))

    print(f"📊 Analyzing {len(output_files)} predictions...")
    if len(input_files) != len(output_files):
        # Positional pairing is only trustworthy when both folders line up.
        print(f"⚠️ {len(input_files)} input files vs {len(output_files)} predictions; "
              f"pairing by sorted position")

    label_names = {0: 'Background', 1: 'Organ', 2: 'Lesion'}
    results_summary = []

    for i, (input_file, output_file) in enumerate(zip(input_files, output_files)):
        print(f"\n🔍 Analyzing case {i+1}: {input_file.stem}")

        # Load input image and prediction
        try:
            input_data = nib.load(input_file).get_fdata()
            pred_data = nib.load(output_file).get_fdata()

            # Voxel count and share of the volume for every predicted label.
            # Values are converted to built-in int/float: json.dump rejects
            # NumPy integer scalars with a TypeError.
            unique_labels, counts = np.unique(pred_data, return_counts=True)
            pred_stats = {
                int(label): {
                    'count': int(count),
                    'percentage': float(count / pred_data.size * 100),
                }
                for label, count in zip(unique_labels, counts)
            }

            # Store results
            results_summary.append({
                'case': input_file.stem,
                'input_shape': input_data.shape,
                'pred_shape': pred_data.shape,
                'input_range': [float(input_data.min()), float(input_data.max())],
                'predicted_labels': pred_stats
            })

            print(f"   Input shape: {input_data.shape}")
            print(f"   Input range: [{input_data.min():.1f}, {input_data.max():.1f}]")
            print(f"   Predicted labels: {list(unique_labels)}")

            for label, stats in pred_stats.items():
                name = label_names.get(label, f'Label_{label}')
                print(f"     {name}: {stats['count']:,} voxels ({stats['percentage']:.1f}%)")

        except Exception as e:
            print(f"   ❌ Error loading case: {e}")
            continue

    # Save results summary
    results_file = Path(experiment_dir) / "inference_results.json"
    with open(results_file, 'w') as f:
        json.dump(results_summary, f, indent=2)

    print(f"\n💾 Results summary saved to: {results_file}")
    return results_summary

# Analyze results
# Prints per-case label statistics and writes inference_results.json into
# EXPERIMENT_DIR; the returned list of per-case dicts is kept in `results`.
results = analyze_predictions(INPUT_DIR, OUTPUT_DIR, EXPERIMENT_DIR)


In [None]:
# Visualize prediction results
def visualize_predictions(input_dir, output_dir, experiment_dir, case_idx=0):
    """Plot one input volume next to its predicted segmentation.

    Input images and predictions are each sorted by file name and paired by
    position. The middle axial, sagittal and coronal slices are drawn in a 2x3
    figure (input on top, prediction below), the figure is saved as
    ``prediction_visualization_case_<idx>.png`` in ``experiment_dir`` and
    per-label statistics are printed.

    Args:
        input_dir: Folder holding the ``*.nii.gz`` images given to inference.
        output_dir: Folder holding the ``*.nii.gz`` predictions.
        experiment_dir: Folder that receives the PNG.
        case_idx: Zero-based index into the sorted file lists.

    Returns:
        None. Prints a message and returns early when ``case_idx`` has no
        matching input/prediction pair.
    """
    from matplotlib.colors import ListedColormap

    input_files = sorted(Path(input_dir).glob("*.nii.gz"))
    output_files = sorted(Path(output_dir).glob("*.nii.gz"))

    # Both lists are indexed below, so the shorter one bounds the valid range.
    num_cases = min(len(input_files), len(output_files))
    if case_idx >= num_cases:
        print(f"❌ Case index {case_idx} out of range. Available: 0-{num_cases-1}")
        return

    input_file = input_files[case_idx]
    output_file = output_files[case_idx]

    print(f"📊 Visualizing case {case_idx}: {input_file.stem}")

    # Load data
    input_data = nib.load(input_file).get_fdata()
    pred_data = nib.load(output_file).get_fdata()

    fig, axes = plt.subplots(2, 3, figsize=(18, 12))

    # Get middle slices
    mid_x, mid_y, mid_z = [s // 2 for s in input_data.shape]

    # One colour per label id, with a fixed vmin/vmax on every panel. Letting
    # each imshow auto-scale would map the same label to different colours
    # depending on which labels happen to be present in that slice.
    colors = ['black', 'red', 'yellow', 'cyan']
    pred_cmap = ListedColormap(colors)
    max_label = len(colors) - 1

    everything = slice(None)
    views = (
        ("Axial", f"Z={mid_z}", (everything, everything, mid_z)),
        ("Sagittal", f"X={mid_x}", (mid_x, everything, everything)),
        ("Coronal", f"Y={mid_y}", (everything, mid_y, everything)),
    )

    pred_image = None
    for col, (plane, position, index) in enumerate(views):
        axes[0, col].imshow(input_data[index].T, cmap='gray', origin='lower')
        axes[0, col].set_title(f'Input - {plane} ({position})')
        axes[0, col].axis('off')

        pred_image = axes[1, col].imshow(pred_data[index].T, cmap=pred_cmap,
                                         origin='lower', vmin=0, vmax=max_label)
        axes[1, col].set_title(f'Prediction - {plane} ({position})')
        axes[1, col].axis('off')

    # Shared colorbar for the prediction row (all panels use the same scale)
    cbar = plt.colorbar(pred_image, ax=axes[1, :], orientation='horizontal', fraction=0.02, pad=0.05)
    cbar.set_ticks(range(len(colors)))
    cbar.set_label('Predicted Labels')

    plt.tight_layout()

    # Save visualization
    vis_path = Path(experiment_dir) / f'prediction_visualization_case_{case_idx}.png'
    plt.savefig(vis_path, dpi=150, bbox_inches='tight')
    plt.show()

    print(f"💾 Visualization saved to: {vis_path}")

    # Print detailed statistics
    unique_labels, counts = np.unique(pred_data, return_counts=True)
    print(f"\n📈 Detailed Statistics:")
    print(f"   Input shape: {input_data.shape}")
    print(f"   Input range: [{input_data.min():.1f}, {input_data.max():.1f}]")
    print(f"   Prediction shape: {pred_data.shape}")
    print(f"   Unique predictions: {unique_labels}")

    label_names = {0: 'Background', 1: 'Organ', 2: 'Lesion'}
    for label, count in zip(unique_labels, counts):
        percentage = (count / pred_data.size) * 100
        name = label_names.get(int(label), f'Label_{int(label)}')
        print(f"   {name}: {count:,} voxels ({percentage:.2f}%)")

# Visualize first prediction
# Only attempt the plot when OUTPUT_DIR contains at least one .nii.gz file.
if len(list(OUTPUT_DIR.glob("*.nii.gz"))) > 0:
    visualize_predictions(INPUT_DIR, OUTPUT_DIR, EXPERIMENT_DIR, case_idx=0)
else:
    print("❌ No prediction files found for visualization")


In [None]:
# Experiment Summary and Cleanup
print("🎯 Baseline Test Summary")
print("=" * 50)

# Calculate total experiment time.
# EXPERIMENT_ID is the "%Y%m%d_%H%M%S" timestamp taken when the experiment
# directory was created, so it doubles as the start time. (The previous
# expression subtracted time.time() from itself and was always ~0.)
experiment_end_time = time.time()
experiment_start = datetime.strptime(EXPERIMENT_ID, "%Y%m%d_%H%M%S")
total_time = (datetime.now() - experiment_start).total_seconds() / 60
print(f"⏱️ Total experiment time: {total_time:.1f} minutes")

# Count what is actually on disk instead of repeating the numbers that were
# passed to the dataset generator.
num_training_cases = len(list((DATASET_PATH / "labelsTr").glob("*.nii.gz")))
num_test_cases = len(list((DATASET_PATH / "imagesTs").glob("*.nii.gz")))
prediction_files = sorted(OUTPUT_DIR.glob("*.nii.gz"))

print(f"📊 Dataset Information:")
print(f"   Dataset ID: {DATASET_ID}")
print(f"   Dataset Name: {DATASET_NAME}")
print(f"   Training Cases: {num_training_cases}")
print(f"   Test Cases: {num_test_cases}")
print(f"   Image Dimensions: 96x96x48")

print(f"\n🏋️ Training Information:")
print(f"   Configuration: 2D")
print(f"   Fold: 0 (single fold)")
print(f"   Model Type: nnUNetTrainer")

print(f"\n🔮 Inference Information:")
print(f"   Test Cases Processed: {len(prediction_files)}")
print(f"   Prediction Files Generated: {len(prediction_files)}")

print(f"\n📁 Generated Files:")
print(f"   Experiment Directory: {EXPERIMENT_DIR}")
print(f"   Training Log: {EXPERIMENT_DIR}/training_log.txt")
print(f"   Results Summary: {EXPERIMENT_DIR}/inference_results.json")
print(f"   Visualizations: {len(list(EXPERIMENT_DIR.glob('*.png')))} PNG files")

# List all generated files
print(f"\n📋 All Experiment Files:")
for file_path in sorted(EXPERIMENT_DIR.rglob("*")):
    if file_path.is_file():
        size = file_path.stat().st_size / (1024*1024)  # MB
        print(f"   {file_path.relative_to(EXPERIMENT_DIR)} ({size:.2f} MB)")

print(f"\n💡 Next Steps:")
print(f"   1. Review training log: {EXPERIMENT_DIR}/training_log.txt")
print(f"   2. Analyze predictions in: {OUTPUT_DIR}")
print(f"   3. For better results, try 3D training: nnUNetv2_train {DATASET_ID} 3d_fullres 0")
print(f"   4. For cross-validation, train all 5 folds")
print(f"   5. Use real medical data for production experiments")

print(f"\n🔧 Cleanup Options:")
print(f"   - Keep experiment files: {EXPERIMENT_DIR}")
print(f"   - Remove large training files to save space")
print(f"   - Archive results for future reference")

# Optional cleanup (commented out by default)
# print(f"\n🧹 Cleaning up temporary files...")
# if EXPERIMENT_DIR.exists():
#     for temp_file in EXPERIMENT_DIR.glob("*.tmp"):
#         temp_file.unlink()
#     print("✅ Temporary files cleaned up")

# The run only counts as completed when inference produced predictions;
# earlier cells catch and print their errors instead of stopping the notebook.
experiment_status = "completed" if prediction_files else "incomplete"
if prediction_files:
    print(f"\n🎉 Baseline test completed successfully!")
else:
    print(f"\n⚠️ Baseline test finished without predictions - review the logs above")
print(f"📁 All results saved in: {EXPERIMENT_DIR}")

# Save experiment metadata
experiment_metadata = {
    "experiment_id": EXPERIMENT_ID,
    "dataset_id": DATASET_ID,
    "dataset_name": DATASET_NAME,
    "configuration": "2d",
    "fold": 0,
    "num_training_cases": num_training_cases,
    "num_test_cases": num_test_cases,
    "num_predictions": len(prediction_files),
    "image_dimensions": [96, 96, 48],
    "experiment_directory": str(EXPERIMENT_DIR),
    "total_time_minutes": round(total_time, 2),
    "timestamp": datetime.now().isoformat(),
    "status": experiment_status
}

metadata_file = EXPERIMENT_DIR / "experiment_metadata.json"
with open(metadata_file, 'w') as f:
    json.dump(experiment_metadata, f, indent=2)

print(f"💾 Experiment metadata saved to: {metadata_file}")
