# 3D Brain Tumor Segmentation with ProtoSeg3D - AWS S3 Deployment

This notebook trains your **interpretable** 3D brain tumor segmentation network (ProtoSeg3D) on Google Colab, using code cloned from GitHub and preprocessed data downloaded from AWS S3.

**What is ProtoSeg3D?**
- Prototype-based segmentation model
- Interpretable: Decisions based on learned prototypes
- Uses a 3D ASPP (atrous spatial pyramid pooling) module on top of an encoder with isotropic pooling for multi-scale context
- Diversity loss based on the Jeffreys divergence (graph-compatible)
- Multi-step training protocol (+3-5% mIoU improvement)
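The model is built with `f_dist='l2'` and `prototype_activation_function='log'` (see Step 12). Here is a minimal NumPy sketch of how such a layer can turn squared L2 distances into prototype similarities. It assumes the ProtoPNet-style log activation `log((d + 1) / (d + eps))`. The helper name `l2_log_similarity` is hypothetical, and `model_protoseg.py` may differ in details such as `eps`.

```python
import numpy as np

def l2_log_similarity(features, prototypes, eps=1e-4):
    """Map per-voxel features to prototype similarities (hypothetical helper).

    features:   (N, K) feature vectors, one row per voxel
    prototypes: (P, K) learned prototype vectors
    returns:    (N, P) similarities; larger means closer to the prototype
    """
    # Squared L2 distance via ||f||^2 - 2 f.p + ||p||^2, clipped at 0 against rounding error
    f_sq = np.sum(features ** 2, axis=1, keepdims=True)    # (N, 1)
    p_sq = np.sum(prototypes ** 2, axis=1)[np.newaxis, :]  # (1, P)
    dist = np.maximum(f_sq - 2.0 * features @ prototypes.T + p_sq, 0.0)
    # ProtoPNet-style "log" activation: positive and decreasing in distance
    return np.log((dist + 1.0) / (dist + eps))

rng = np.random.default_rng(0)
protos = rng.normal(size=(28, 128))               # 7 prototypes x 4 classes, 128-dim as in this notebook
feats = np.vstack([protos[3], protos[3] + 0.5])   # one voxel exactly on prototype 3, one nearby
sims = l2_log_similarity(feats, protos)
print(sims.shape, sims[0].argmax())
```

The log activation is large (about `log(1/eps)`) only when a voxel sits almost exactly on a prototype and decays quickly with distance, which is what makes high-similarity regions easy to read as "this looks like prototype i".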

**Quick Setup:**
1. Upload this notebook to Google Colab
2. Go to Runtime → Change runtime type → Select **GPU (T4 or better)**
3. Add AWS credentials to Colab Secrets (🔑 icon on the left)
4. Update GitHub repository URL and S3 bucket details
5. Choose training mode: Single-phase (faster) or Multi-step (better results)
6. Run all cells

---

## Step 1: Check GPU and System Info

In [None]:
import tensorflow as tf
import sys

print("=" * 60)
print("SYSTEM INFORMATION")
print("=" * 60)
print(f"Python version: {sys.version}")
print(f"TensorFlow version: {tf.__version__}")
print(f"\nGPU Devices: {tf.config.list_physical_devices('GPU')}")

# Enable memory growth to prevent OOM errors
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"\n✓ Memory growth enabled for {len(gpus)} GPU(s)")
        print(f"✓ GPU: {gpus[0].name}")
    except RuntimeError as e:
        print(f"Error enabling memory growth: {e}")
else:
    print("\n⚠️ WARNING: No GPU detected! Please enable GPU in Runtime → Change runtime type")

print("=" * 60)

## Step 2: Check Available Resources

In [None]:
print("=" * 60)
print("AVAILABLE RESOURCES")
print("=" * 60)

print("\n📦 Disk Space:")
!df -h /content | grep -E 'Filesystem|/content'

print("\n🧠 RAM:")
!free -h | grep -E 'total|Mem'

print("\n🎮 GPU Memory:")
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv

print("\n" + "=" * 60)
print("Note: Your ~4 GB preprocessed dataset will use disk space, not RAM")
print("Only one batch (~200 MB) is loaded in RAM at a time")
print("=" * 60)
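As a back-of-the-envelope check of the per-batch figure, here is a small sketch. It assumes float32 inputs and float32 one-hot labels with 4 classes; the actual dtypes depend on `data_generator.py`, and `batch_megabytes` is a hypothetical helper.

```python
def batch_megabytes(batch=2, depth=128, height=160, width=192,
                    in_channels=4, num_classes=4, bytes_per_value=4):
    """Approximate host memory (MiB) for one (inputs, one-hot labels) batch."""
    voxels = batch * depth * height * width
    input_bytes = voxels * in_channels * bytes_per_value
    label_bytes = voxels * num_classes * bytes_per_value
    return input_bytes / 2**20, label_bytes / 2**20

inputs_mib, labels_mib = batch_megabytes()
print(f"inputs: {inputs_mib:.0f} MiB, labels: {labels_mib:.0f} MiB, "
      f"total: {inputs_mib + labels_mib:.0f} MiB")
```

Under these assumptions it prints `inputs: 120 MiB, labels: 120 MiB, total: 240 MiB`, in line with the ~200 MB note above. Storing labels as single-channel uint8 class ids would shrink the label part to 7.5 MiB.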

## Step 3: Install Dependencies

In [None]:
%%capture
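# Silent installation - remove %%capture to see output.
# h5py, numpy, tensorflow, keras and matplotlib are preinstalled on Colab; reinstalling
# TensorFlow after Step 1 imported it can leave mismatched package versions loaded.
!pip install -q awscli boto3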

## Step 4: Clone GitHub Repository

**⚠️ IMPORTANT: Update the repository URL below with your GitHub repository!**

In [None]:
import os

# ============================================================================
# UPDATE THIS WITH YOUR GITHUB REPOSITORY URL
# ============================================================================
GITHUB_REPO_URL = "https://github.com/dariamarc/brainTumorSurvival.git"
# ============================================================================

# Repository name (extracted from URL)
repo_name = GITHUB_REPO_URL.split('/')[-1].replace('.git', '')

print(f"Cloning repository: {GITHUB_REPO_URL}")
print(f"Repository name: {repo_name}")
print("-" * 60)

# IMPORTANT: Change to /content first to avoid directory issues
os.chdir('/content')
print("Changed to /content directory")

# Remove if exists (for re-running)
if os.path.exists(f'/content/{repo_name}'):
    !rm -rf /content/{repo_name}
    print(f"Removed existing directory: {repo_name}")

# Clone the repository
!git clone {GITHUB_REPO_URL}

# Change to repository directory
os.chdir(f'/content/{repo_name}')
print(f"\n✓ Changed to directory: {os.getcwd()}")

# List files to verify
print("\nRepository contents:")
!ls -la

## Step 5: Verify Required Files

In [None]:
required_files = ['model_protoseg.py', 'data_generator.py', 'losses_protoseg.py']
optional_files = ['main_protoseg.py', 'main_protoseg_multistep.py']

print("Checking required files for ProtoSeg3D...")
print("=" * 60)

all_present = True
for file in required_files:
    if os.path.exists(file):
        print(f"✓ {file} - Found")
    else:
        print(f"✗ {file} - MISSING")
        all_present = False

print("\nChecking optional files...")
for file in optional_files:
    if os.path.exists(file):
        print(f"✓ {file} - Found")
    else:
        print(f"- {file} - Not present (optional)")

print("=" * 60)
if all_present:
    print("✓ All required files present! Ready to proceed.")
else:
    print("⚠️ WARNING: Some required files are missing!")
    print("Please check your repository structure.")

## Step 6: Configure AWS Credentials

**IMPORTANT SECURITY STEPS:**

1. Click the **🔑 Secrets** icon in the left sidebar
2. Add these secrets:
   - Name: `AWS_ACCESS_KEY_ID`, Value: Your AWS access key
   - Name: `AWS_SECRET_ACCESS_KEY`, Value: Your AWS secret key
3. Enable "Notebook access" for both secrets

**Never hardcode credentials in notebooks!**

In [None]:
from google.colab import userdata
import os

print("Configuring AWS credentials...")
print("-" * 60)

try:
    # Get credentials from Colab Secrets
    os.environ['AWS_ACCESS_KEY_ID'] = userdata.get('AWS_ACCESS_KEY_ID')
    os.environ['AWS_SECRET_ACCESS_KEY'] = userdata.get('AWS_SECRET_ACCESS_KEY')

    print("✓ AWS credentials loaded from Colab Secrets")
    print("✓ Access Key ID: " + os.environ['AWS_ACCESS_KEY_ID'][:8] + "...")

except Exception as e:
    print("✗ Error loading AWS credentials from Colab Secrets")
    print(f"Error: {e}")
    print("\nPlease add AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to Colab Secrets (🔑 icon)")
    raise

## Step 7: Download Preprocessed Dataset from AWS S3

**⚠️ UPDATE YOUR S3 BUCKET DETAILS BELOW**

This will download your preprocessed dataset (128×160×192) to Colab's local storage.  
**Benefits:**
- Faster download: ~4 GB instead of ~8 GB
- No preprocessing needed: Ready to train immediately
- Center-cropped only (no resizing): Preserves original resolution
- Saves 15-20 minutes of preprocessing time

**Estimated download time: 5-8 minutes**
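Before starting a multi-minute `aws s3 sync`, it can help to confirm that the bucket and prefix are reachable with your credentials. Below is a minimal sketch built on boto3's S3 `list_objects_v2` call. The helper name `s3_prefix_has_objects` is hypothetical, and the client is passed in so the function can be tried with a stub.

```python
def s3_prefix_has_objects(client, bucket, prefix):
    """Return True if at least one object exists under s3://bucket/prefix.

    `client` is expected to behave like boto3.client("s3"); only
    list_objects_v2 is used. Permission or missing-bucket errors propagate
    as exceptions so they are not mistaken for an empty prefix.
    """
    # Treat the prefix as a folder so "preprocessed_data" does not match "preprocessed_data_old"
    if prefix and not prefix.endswith("/"):
        prefix += "/"
    resp = client.list_objects_v2(Bucket=bucket, Prefix=prefix, MaxKeys=1)
    return resp.get("KeyCount", 0) > 0

# Usage in Colab (after Step 6 and after setting the S3 variables in the next cell):
# import boto3
# s3 = boto3.client("s3", region_name=AWS_REGION)
# assert s3_prefix_has_objects(s3, S3_BUCKET, S3_PATH), "Nothing found at that S3 path"
```

Asking for a single key (`MaxKeys=1`) keeps the check cheap regardless of how many slice files the prefix holds.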

In [None]:
# ============================================================================
# UPDATE THESE WITH YOUR S3 DETAILS
# ============================================================================
S3_BUCKET = 'your-brats2020-data'           # Your S3 bucket name
S3_PATH = 'preprocessed_data'               # Path to PREPROCESSED data in S3
AWS_REGION = 'eu-central-1'                 # Your bucket's region
# ============================================================================

LOCAL_PATH = '/content/brainTumorData_preprocessed'

print("=" * 60)
print("DOWNLOADING PREPROCESSED DATASET FROM AWS S3")
print("=" * 60)

# Set AWS region
os.environ['AWS_DEFAULT_REGION'] = AWS_REGION

print(f"\nSource: s3://{S3_BUCKET}/{S3_PATH}")
print(f"Destination: {LOCAL_PATH}")
print(f"Region: {AWS_REGION}")
print(f"Dataset: Preprocessed (128×160×192, center-cropped)")
print(f"Dataset size: ~4 GB")
print(f"Estimated time: 5-8 minutes")
print("-" * 60)
print("Starting download...\n")

# Create local directory
!mkdir -p {LOCAL_PATH}

# Download preprocessed data using AWS CLI sync (shows progress)
!aws s3 sync s3://{S3_BUCKET}/{S3_PATH} {LOCAL_PATH}

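# Verify download (mkdir -p above always creates LOCAL_PATH, so also require it to contain files)
print("\n" + "=" * 60)
if os.path.isdir(LOCAL_PATH) and os.listdir(LOCAL_PATH):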
    # Count files
    file_count = sum([len(files) for r, d, files in os.walk(LOCAL_PATH)])

    # Calculate size
    total_size = sum(
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, dirnames, filenames in os.walk(LOCAL_PATH)
        for filename in filenames
    ) / (1024**3)  # Convert to GB

    print("✓ DOWNLOAD COMPLETE!")
    print("=" * 60)
    print(f"Location: {LOCAL_PATH}")
    print(f"Files downloaded: {file_count:,}")
    print(f"Total size: {total_size:.2f} GB")
    print(f"Expected files: 47,232 (369 volumes × 128 slices)")

    # Show sample files
    print("\nSample files:")
    !ls {LOCAL_PATH} | head -10

    # Check disk space after download
    print("\n📦 Disk Usage After Download:")
    !df -h /content | grep -E 'Filesystem|/content'

    # Set data path to preprocessed directory
    DATA_PATH = LOCAL_PATH
    print(f"\n✓ DATA_PATH set to: {DATA_PATH}")
    print("✓ Data is already preprocessed - ready to train!")
else:
    print("✗ DOWNLOAD FAILED!")
    print("Please check:")
    print("  1. S3 bucket name is correct")
    print("  2. S3 path is correct (should be 'preprocessed_data')")
    print("  3. AWS credentials have read permissions")
    print("  4. AWS region is correct")
    raise FileNotFoundError(f"Data not found at {LOCAL_PATH}")

print("=" * 60)

## Step 8: Import Modules

In [None]:
import sys

# Ensure repository is in Python path
repo_dir = f'/content/{repo_name}'
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)

print(f"Python path includes: {repo_dir}")
print(f"Working directory: {os.getcwd()}")
print("-" * 60)

# Import your modules
try:
    from model_protoseg import ProtoSeg3D
    from data_generator import MRIDataGenerator
    from losses_protoseg import compute_diversity_loss
    from tensorflow import keras
    import numpy as np

    print("✓ All ProtoSeg3D modules imported successfully!")
    print("✓ Diversity loss is graph-compatible (@tf.function)")
    print("✓ Supports similarities from distances (via activation functions)")
except ImportError as e:
    print(f"✗ Import error: {e}")
    print("\nDebugging info:")
    print("Files in repository:")
    !ls -la
    raise

## Step 9: Training Mode Selection

**Choose your training mode:**

### Option 1: Single-Phase Training (Faster)
- Trains all components together from start
- Estimated time: ~10-12 hours
- Expected mIoU: 0.58-0.63

### Option 2: Multi-Step Training (Better Results) ⭐ RECOMMENDED
- Phase 1: Warmup (freeze encoder, train ASPP + prototypes)
- Phase 2: Joint training (train all except FC layer)
- Phase 3-4: Fine-tuning (train FC layer only)
- Estimated time: ~14 hours
- Expected mIoU: 0.62-0.68 (+3-5% improvement)
- Better prototype quality and interpretability

**Set the TRAINING_MODE below:**

In [None]:
# ============================================================================
# CHOOSE YOUR TRAINING MODE
# ============================================================================
TRAINING_MODE = "multi-step"  # Options: "single-phase" or "multi-step"
# ============================================================================

print("=" * 60)
print("TRAINING MODE SELECTION")
print("=" * 60)
print(f"Selected mode: {TRAINING_MODE.upper()}")

if TRAINING_MODE == "single-phase":
    print("\n✓ Single-phase training selected")
    print("  - Trains all components together")
    print("  - Estimated time: ~10-12 hours")
    print("  - Expected mIoU: 0.58-0.63")
elif TRAINING_MODE == "multi-step":
    print("\n✓ Multi-step training selected (RECOMMENDED)")
    print("  - Phase 1: Warmup (30k steps)")
    print("  - Phase 2: Joint training (30k steps)")
    print("  - Phase 3-4: Fine-tuning (2k steps each)")
    print("  - Estimated time: ~14 hours")
    print("  - Expected mIoU: 0.62-0.68 (+3-5% improvement)")
    print("  - Better prototype quality")
else:
    raise ValueError(f"Invalid TRAINING_MODE: {TRAINING_MODE}. Use 'single-phase' or 'multi-step'")

print("=" * 60)

## Step 10: Training Configuration

Adjust these parameters based on your needs.

In [None]:
# ============================================================================
# TRAINING CONFIGURATION - ADJUST AS NEEDED
# ============================================================================

# Data configuration
BATCH_SIZE = 2          # Batch size (2 recommended for ProtoSeg3D)
SPLIT_RATIO = 0.2       # 20% for validation
RANDOM_STATE = 42
NUM_VOLUMES = 369       # Total number of volumes

# Volume dimensions (PREPROCESSED DATA - center-cropped only)
D = 128                 # Depth (number of slices)
H = 160                 # Height
W = 192                 # Width
C = 4                   # Channels (FLAIR, T1, T1ce, T2)

# Model configuration
NUM_CLASSES = 4         # Background + 3 tumor classes
NUM_PROTOTYPES_PER_CLASS = 7
PROTOTYPE_DIM = 128
ASPP_OUT_CHANNELS = 128  # Updated to match current architecture

# Diversity loss configuration
USE_DIVERSITY_LOSS = True   # Now graph-compatible with @tf.function
LAMBDA_J = 0.25             # Diversity loss weight

# Single-phase training settings
SINGLE_PHASE_EPOCHS = 100
SINGLE_PHASE_LR = 0.0001

# Multi-step training settings
WARMUP_STEPS = 30000
JOINT_STEPS = 30000
FINETUNE_STEPS = 2000
WARMUP_LR = 2.5e-4
JOINT_LR = 1.3e-4       # ≈ mean of backbone (2.5e-5) and other-layer (2.5e-4) LRs (exact mean: 1.375e-4); a single optimizer uses one LR
FINETUNE_LR = 1e-5

# ============================================================================

INPUT_SHAPE = (D, H, W, C)

print("=" * 60)
print("TRAINING CONFIGURATION")
print("=" * 60)
print(f"Data path: {DATA_PATH}")
print(f"Input shape: {INPUT_SHAPE}")
print(f"Number of classes: {NUM_CLASSES}")
print(f"Prototypes per class: {NUM_PROTOTYPES_PER_CLASS}")
print(f"Total prototypes: {NUM_PROTOTYPES_PER_CLASS * NUM_CLASSES}")
print(f"Prototype dimension: {PROTOTYPE_DIM}")
print(f"ASPP output channels: {ASPP_OUT_CHANNELS}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Total volumes: {NUM_VOLUMES}")
print(f"Train/Val split: {int((1-SPLIT_RATIO)*100)}% / {int(SPLIT_RATIO*100)}%")

if USE_DIVERSITY_LOSS:
    print(f"\n✓ Diversity loss ENABLED (λ_J = {LAMBDA_J})")
    print("  ✓ Graph-compatible implementation with @tf.function")
else:
    print(f"\n✗ Diversity loss DISABLED")

if TRAINING_MODE == "single-phase":
    print(f"\nSingle-phase settings:")
    print(f"  - Epochs: {SINGLE_PHASE_EPOCHS}")
    print(f"  - Learning rate: {SINGLE_PHASE_LR}")
else:
    # Convert steps to epochs
    samples_per_epoch = int(NUM_VOLUMES * (1 - SPLIT_RATIO))
    warmup_epochs = max(1, WARMUP_STEPS * BATCH_SIZE // samples_per_epoch)
    joint_epochs = max(1, JOINT_STEPS * BATCH_SIZE // samples_per_epoch)
    finetune_epochs = max(1, FINETUNE_STEPS * BATCH_SIZE // samples_per_epoch)
    
    print(f"\nMulti-step settings:")
    print(f"  - Warmup: {WARMUP_STEPS} steps (~{warmup_epochs} epochs), LR={WARMUP_LR}")
    print(f"  - Joint: {JOINT_STEPS} steps (~{joint_epochs} epochs), LR={JOINT_LR}")
    print(f"  - Fine-tune: {FINETUNE_STEPS} steps (~{finetune_epochs} epochs), LR={FINETUNE_LR}")

print("=" * 60)
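You can check the step-to-epoch conversion in the cell above by hand. The sketch below mirrors the cell's integer arithmetic (295 training volumes at batch size 2). It also reports the total optimizer steps each mode implies, assuming 148 batches per epoch. That assumption means the generator keeps the last partial batch, which depends on `MRIDataGenerator`. Both helper names are hypothetical.

```python
import math

def phase_epochs(steps, batch_size=2, num_volumes=369, split_ratio=0.2):
    """Mirror of the notebook's conversion: epochs = steps * batch_size // train_volumes."""
    train_volumes = int(num_volumes * (1 - split_ratio))
    return max(1, steps * batch_size // train_volumes)

def total_steps(epochs, batch_size=2, num_volumes=369, split_ratio=0.2):
    """Optimizer steps for `epochs`, assuming the last partial batch is kept (ceil)."""
    train_volumes = int(num_volumes * (1 - split_ratio))
    return epochs * math.ceil(train_volumes / batch_size)

warmup, joint, finetune = phase_epochs(30000), phase_epochs(30000), phase_epochs(2000)
multi_step = total_steps(warmup + joint + 2 * finetune)   # two fine-tuning phases
single_phase = total_steps(100)
print(warmup, joint, finetune, multi_step, single_phase)
```

With the defaults this prints `203 203 13 63936 14800`. The multi-step schedule therefore runs about 64k optimizer steps versus about 14.8k for 100 single-phase epochs. The joint phase alone also trains the full network for roughly twice as many epochs as the single-phase run, so treat the wall-clock estimates in Step 9 as rough.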

## Step 11: Create Data Generators

In [None]:
print("Creating data generators...")
print("-" * 60)

train_generator = MRIDataGenerator(
    DATA_PATH,
    batch_size=BATCH_SIZE,
    num_slices=D,
    num_volumes=NUM_VOLUMES,
    split_ratio=SPLIT_RATIO,
    subset='train',
    shuffle=True,
    random_state=RANDOM_STATE
)

validation_generator = MRIDataGenerator(
    DATA_PATH,
    batch_size=BATCH_SIZE,
    num_slices=D,
    num_volumes=NUM_VOLUMES,
    split_ratio=SPLIT_RATIO,
    subset='val',
    shuffle=False,
    random_state=RANDOM_STATE
)

print(f"\n✓ Training batches: {len(train_generator)}")
print(f"✓ Validation batches: {len(validation_generator)}")

## Step 12: Build and Compile Model

In [None]:
print("Building ProtoSeg3D model...")
print("-" * 60)

# Build the ProtoSeg3D model
model = ProtoSeg3D(
    in_size=INPUT_SHAPE,
    num_classes=NUM_CLASSES,
    num_prototypes_per_class=NUM_PROTOTYPES_PER_CLASS,
    prototype_dim=PROTOTYPE_DIM,
    features='resnet50_ri',
    f_dist='l2',
    prototype_activation_function='log',
    aspp_out_channels=ASPP_OUT_CHANNELS
)

print("✓ Model architecture created!")
print(f"  - Encoder: Isotropic pooling (2×2×2)")
print(f"  - ASPP: {ASPP_OUT_CHANNELS} channels, rates [1, 2, 4, 8]")
print(f"  - Prototype dimension: {PROTOTYPE_DIM}")

# Setup optimizer and loss
if TRAINING_MODE == "single-phase":
    optimizer = keras.optimizers.Adam(learning_rate=SINGLE_PHASE_LR)
else:
    optimizer = keras.optimizers.Adam(learning_rate=WARMUP_LR)

loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)

# Compile model
model.compile(
    optimizer=optimizer,
    loss=loss_fn,
    metrics=[
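        # OneHotMeanIoU argmaxes both the one-hot labels and the logits before computing IoU;
        # plain MeanIoU expects integer class ids and gives wrong values on one-hot/logit tensors.
        keras.metrics.OneHotMeanIoU(num_classes=NUM_CLASSES, name='mean_iou'),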
        keras.metrics.CategoricalAccuracy(name='accuracy')
    ]
)

print("✓ Model compiled successfully!")

# Enable diversity loss if requested
if USE_DIVERSITY_LOSS:
    model.enable_diversity_loss(lambda_j=LAMBDA_J)
    print(f"\n✓ Diversity loss enabled (λ_J = {LAMBDA_J})")
    print("  ✓ Graph-compatible with @tf.function")
    print("  ✓ Works with prototype similarities")

print("\nLoss function: Categorical Cross-Entropy (from logits)")
if USE_DIVERSITY_LOSS:
    print(f"              + Jeffreys Divergence Diversity Loss (weight={LAMBDA_J})")
print("\nMetrics tracked:")
print("  - Mean IoU (primary metric for segmentation)")
print("  - Categorical Accuracy (overall voxel correctness)")

# Display prototype information
proto_info = model.get_prototype_info()
print("\n" + "=" * 60)
print("PROTOTYPE INFORMATION")
print("=" * 60)
print(f"Total prototypes: {proto_info['num_prototypes']}")
print(f"Prototypes per class: {proto_info['prototypes_per_class']}")
print(f"Prototype dimension: {proto_info['prototype_dim']}")
print("\nPrototype assignments:")
for c in range(NUM_CLASSES):
    class_protos = [i for i in range(proto_info['num_prototypes'])
                   if proto_info['prototype_class_identity'][i, c] == 1]
    print(f"  Class {c}: Prototypes {class_protos}")
print("=" * 60)
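The assignments printed above come from a one-hot prototype-to-class identity matrix. The NumPy sketch below builds such a matrix along with a ProtoPNet-style initialisation of the final FC layer: +1 to the prototype's own class and -0.5 to the others. It assumes prototypes are assigned to classes in contiguous blocks of `NUM_PROTOTYPES_PER_CLASS`, which may not match `model.get_prototype_info()` exactly. The -0.5 value is the ProtoPNet convention, not something this repository is confirmed to use.

```python
import numpy as np

def class_identity(num_classes=4, per_class=7):
    """(P, C) one-hot matrix: row i marks the class that prototype i belongs to."""
    num_prototypes = num_classes * per_class
    identity = np.zeros((num_prototypes, num_classes), dtype=np.float32)
    identity[np.arange(num_prototypes), np.arange(num_prototypes) // per_class] = 1.0
    return identity

def init_last_layer(identity, incorrect_strength=-0.5):
    """(P, C) FC weights: +1 for a prototype's own class, `incorrect_strength` elsewhere."""
    return identity + incorrect_strength * (1.0 - identity)

ident = class_identity()
w = init_last_layer(ident)
# Logits for one voxel = similarities (P,) @ weights (P, C)
sims = np.zeros(28, dtype=np.float32)
sims[9] = 5.0   # voxel strongly matches prototype 9 (class 1 under contiguous blocks)
print(np.where(ident[:, 1] == 1)[0], (sims @ w).argmax())
```

The negative off-class weights mean that a strong match to one class's prototype actively lowers the other classes' logits. Phases 3-4 then only have to adjust this layer.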

## Step 13: Setup Callbacks and Checkpointing

**IMPORTANT:** We'll save checkpoints to Google Drive for persistence across sessions.

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard, CSVLogger
import datetime

# Mount Google Drive for checkpoint storage
from google.colab import drive
drive.mount('/content/drive')

# Create directories
checkpoint_dir = '/content/checkpoints'
logs_dir = '/content/logs'
!mkdir -p {checkpoint_dir}
!mkdir -p {logs_dir}

# Google Drive checkpoint directory (for persistence)
drive_checkpoint_dir = f'/content/drive/MyDrive/protoseg_{TRAINING_MODE}_checkpoints'
!mkdir -p {drive_checkpoint_dir}

timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

def create_callbacks(phase_name=""):
    """Create callbacks for training."""
    phase_suffix = f"_{phase_name}" if phase_name else ""
    
    return [
        # Save best model to Google Drive
        ModelCheckpoint(
            filepath=f'{drive_checkpoint_dir}/best_model{phase_suffix}_{timestamp}.keras',
            monitor='val_mean_iou',
            mode='max',
            save_best_only=True,
            verbose=1
        ),
        
        # TensorBoard logging
        TensorBoard(
            log_dir=f'{logs_dir}/{timestamp}{phase_suffix}',
            histogram_freq=1,
            write_graph=True
        ),
        
        # CSV Logger
        CSVLogger(
            filename=f'{drive_checkpoint_dir}/training_log{phase_suffix}_{timestamp}.csv',
            append=True
        )
    ]

# Create callbacks based on training mode
if TRAINING_MODE == "single-phase":
    callbacks = create_callbacks()
    
    # Add early stopping and LR reduction for single-phase
    callbacks.extend([
        EarlyStopping(
            monitor='val_mean_iou',
            patience=10,
            mode='max',
            restore_best_weights=True,
            verbose=1
        ),
        ReduceLROnPlateau(
            monitor='val_mean_iou',
            factor=0.5,
            patience=5,
            mode='max',
            min_lr=1e-7,
            verbose=1
        )
    ])
    
    print("✓ Single-phase callbacks configured!")
    print(f"  - Drive backups: {drive_checkpoint_dir}")
    print(f"  - TensorBoard logs: {logs_dir}")
    print("  - EarlyStopping: patience=10 epochs")
    print("  - ReduceLROnPlateau: patience=5 epochs")
else:
    print("✓ Multi-step callbacks will be created per phase")
    print(f"  - Drive backups: {drive_checkpoint_dir}")
    print(f"  - TensorBoard logs: {logs_dir}")

## Step 14: Train the Model

**This will take several hours. Keep the browser tab active to prevent disconnection!**

In [None]:
print("=" * 60)
print(f"STARTING {TRAINING_MODE.upper()} TRAINING")
print("=" * 60)
print("\n⚠️ IMPORTANT: Keep this browser tab active to prevent disconnection!")
print("⚠️ Models are being saved to Google Drive automatically")
print("=" * 60)
print()

if TRAINING_MODE == "single-phase":
    # ========== SINGLE-PHASE TRAINING ==========
    print(f"Training for {SINGLE_PHASE_EPOCHS} epochs")
    print(f"Learning rate: {SINGLE_PHASE_LR}")
    print()
    
    history = model.fit(
        train_generator,
        epochs=SINGLE_PHASE_EPOCHS,
        validation_data=validation_generator,
        callbacks=callbacks,
        verbose=1
    )
    
    # Save final model
    model.save(f'{drive_checkpoint_dir}/final_model_{timestamp}.keras')
    print(f"\n✓ Final model saved: {drive_checkpoint_dir}/final_model_{timestamp}.keras")

else:
    # ========== MULTI-STEP TRAINING ==========
    
    # Convert step budgets to epochs using the generator's actual number of batches per epoch
    steps_per_epoch = len(train_generator)
    warmup_epochs = max(1, WARMUP_STEPS // steps_per_epoch)
    joint_epochs = max(1, JOINT_STEPS // steps_per_epoch)
    finetune_epochs = max(1, FINETUNE_STEPS // steps_per_epoch)
    
    histories = {}
    
    # ===== PHASE 1: WARMUP =====
    print("\n" + "=" * 60)
    print("PHASE 1: WARMUP")
    print("=" * 60)
    model.setup_warmup_phase()
    model.print_trainable_status()
    
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=WARMUP_LR),
        loss=loss_fn,
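        metrics=[  # fresh metric objects: model.metrics holds internal trackers, not reusable metric specs
            keras.metrics.OneHotMeanIoU(num_classes=NUM_CLASSES, name='mean_iou'),
            keras.metrics.CategoricalAccuracy(name='accuracy')
        ]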
    )
    if USE_DIVERSITY_LOSS:
        model.use_diversity_loss = True
    
    histories['warmup'] = model.fit(
        train_generator,
        epochs=warmup_epochs,
        validation_data=validation_generator,
        callbacks=create_callbacks("warmup"),
        verbose=1
    )
    model.save(f'{drive_checkpoint_dir}/after_warmup_{timestamp}.keras')
    
    # ===== PHASE 2: JOINT TRAINING =====
    print("\n" + "=" * 60)
    print("PHASE 2: JOINT TRAINING")
    print("=" * 60)
    model.setup_joint_training_phase()
    model.print_trainable_status()
    
    # Polynomial LR decay
    lr_schedule = keras.optimizers.schedules.PolynomialDecay(
        initial_learning_rate=JOINT_LR,
        decay_steps=JOINT_STEPS,
        end_learning_rate=JOINT_LR * 0.01,
        power=0.9
    )
    
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=lr_schedule),
        loss=loss_fn,
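        metrics=[  # fresh metric objects: model.metrics holds internal trackers, not reusable metric specs
            keras.metrics.OneHotMeanIoU(num_classes=NUM_CLASSES, name='mean_iou'),
            keras.metrics.CategoricalAccuracy(name='accuracy')
        ]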
    )
    if USE_DIVERSITY_LOSS:
        model.use_diversity_loss = True
    
    histories['joint'] = model.fit(
        train_generator,
        epochs=joint_epochs,
        validation_data=validation_generator,
        callbacks=create_callbacks("joint"),
        verbose=1
    )
    model.save(f'{drive_checkpoint_dir}/after_joint_{timestamp}.keras')
    
    # ===== PHASE 3: FINE-TUNING 1 =====
    print("\n" + "=" * 60)
    print("PHASE 3: FINE-TUNING 1")
    print("=" * 60)
    model.setup_finetuning_phase()
    model.print_trainable_status()
    
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate=FINETUNE_LR),
        loss=loss_fn,
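        metrics=[  # fresh metric objects: model.metrics holds internal trackers, not reusable metric specs
            keras.metrics.OneHotMeanIoU(num_classes=NUM_CLASSES, name='mean_iou'),
            keras.metrics.CategoricalAccuracy(name='accuracy')
        ]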
    )
    
    histories['finetune1'] = model.fit(
        train_generator,
        epochs=finetune_epochs,
        validation_data=validation_generator,
        callbacks=create_callbacks("finetune1"),
        verbose=1
    )
    model.save(f'{drive_checkpoint_dir}/after_finetune1_{timestamp}.keras')
    
    # ===== PHASE 4: FINE-TUNING 2 =====
    print("\n" + "=" * 60)
    print("PHASE 4: FINE-TUNING 2 (FINAL)")
    print("=" * 60)
    
    histories['finetune2'] = model.fit(
        train_generator,
        epochs=finetune_epochs,
        validation_data=validation_generator,
        callbacks=create_callbacks("finetune2"),
        verbose=1
    )
    model.save(f'{drive_checkpoint_dir}/final_model_{timestamp}.keras')

print("\n" + "=" * 60)
print("✓ TRAINING COMPLETED!")
print("=" * 60)
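The joint phase uses `keras.optimizers.schedules.PolynomialDecay` with `power=0.9`. The pure-Python sketch below implements the documented formula. With the default `cycle=False`, the step is clamped at `decay_steps`. The sketch is handy for checking which learning rate the joint phase uses at a given step; `poly_decay` is a hypothetical helper.

```python
def poly_decay(step, initial_lr=1.3e-4, decay_steps=30000, end_lr=1.3e-6, power=0.9):
    """Learning rate of PolynomialDecay(cycle=False) at a given optimizer step."""
    step = min(step, decay_steps)  # after decay_steps the rate stays at end_lr
    return (initial_lr - end_lr) * (1 - step / decay_steps) ** power + end_lr

for s in (0, 15000, 30000, 40000):
    print(s, f"{poly_decay(s):.3e}")
```

The joint phase compiles a fresh Adam optimizer, so its step counter (and therefore the schedule) starts from 0 at the beginning of Phase 2, regardless of how many steps the warmup ran.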

## Step 15: Visualize Training History

In [None]:
import matplotlib.pyplot as plt

if TRAINING_MODE == "single-phase":
    # Plot single-phase training
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))
    fig.suptitle('ProtoSeg3D Training Metrics (Single-Phase)', fontsize=16, fontweight='bold')
    
    # Loss
    axes[0].plot(history.history['loss'], label='Training Loss', linewidth=2)
    axes[0].plot(history.history['val_loss'], label='Validation Loss', linewidth=2)
    axes[0].set_title('Loss', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].legend(fontsize=10)
    axes[0].grid(True, alpha=0.3)
    
    # Accuracy
    axes[1].plot(history.history['accuracy'], label='Training Accuracy', linewidth=2)
    axes[1].plot(history.history['val_accuracy'], label='Validation Accuracy', linewidth=2)
    axes[1].set_title('Accuracy', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Accuracy', fontsize=12)
    axes[1].legend(fontsize=10)
    axes[1].grid(True, alpha=0.3)
    
    # Mean IoU
    axes[2].plot(history.history['mean_iou'], label='Training Mean IoU', linewidth=2)
    axes[2].plot(history.history['val_mean_iou'], label='Validation Mean IoU', linewidth=2)
    axes[2].set_title('Mean IoU', fontsize=14, fontweight='bold')
    axes[2].set_xlabel('Epoch', fontsize=12)
    axes[2].set_ylabel('Mean IoU', fontsize=12)
    axes[2].legend(fontsize=10)
    axes[2].grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(f'{drive_checkpoint_dir}/training_history_{timestamp}.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print metrics
    print("\n" + "=" * 60)
    print("FINAL METRICS")
    print("=" * 60)
    print(f"Best validation Mean IoU: {max(history.history['val_mean_iou']):.4f}")
    print(f"Best validation Accuracy: {max(history.history['val_accuracy']):.4f}")
    print(f"Best validation Loss: {min(history.history['val_loss']):.4f}")

else:
    # Plot multi-step training
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    fig.suptitle('ProtoSeg3D Multi-Step Training Metrics', fontsize=16, fontweight='bold')
    
    phases = ['warmup', 'joint', 'finetune1', 'finetune2']
    colors = ['blue', 'green', 'orange', 'red']
    
    for idx, (phase, color) in enumerate(zip(phases, colors)):
        if phase in histories:
            row = idx // 2
            col = idx % 2
            
            ax = axes[row, col]
            ax.plot(histories[phase].history['mean_iou'], label='Train IoU', color=color, linewidth=2)
            ax.plot(histories[phase].history['val_mean_iou'], label='Val IoU', color=color, linestyle='--', linewidth=2)
            ax.set_title(f'Phase {idx+1}: {phase.title()}', fontsize=14, fontweight='bold')
            ax.set_xlabel('Epoch', fontsize=12)
            ax.set_ylabel('Mean IoU', fontsize=12)
            ax.legend(fontsize=10)
            ax.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.savefig(f'{drive_checkpoint_dir}/multistep_history_{timestamp}.png', dpi=300, bbox_inches='tight')
    plt.show()
    
    # Print metrics per phase
    print("\n" + "=" * 60)
    print("FINAL METRICS PER PHASE")
    print("=" * 60)
    for phase in phases:
        if phase in histories:
            best_iou = max(histories[phase].history['val_mean_iou'])
            print(f"{phase.title():15s} - Best val Mean IoU: {best_iou:.4f}")

print(f"\n✓ Training history saved to Google Drive")
print("=" * 60)
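The per-phase plots above restart the x-axis at 0 for every phase. To see the whole multi-step run as one curve, you can stitch the `History.history` dicts together and record the phase boundaries. Here is a minimal sketch; the helper name `concat_histories` is hypothetical.

```python
def concat_histories(histories, key, phases=("warmup", "joint", "finetune1", "finetune2")):
    """Concatenate one metric across phases.

    histories: dict phase -> object with a `.history` dict (e.g. keras History),
               or a plain dict of metric lists
    returns:   (values, boundaries) where boundaries[i] is the epoch index at
               which the i-th present phase starts in the concatenated list
    """
    values, boundaries = [], []
    for phase in phases:
        if phase not in histories:
            continue
        h = histories[phase]
        record = h.history if hasattr(h, "history") else h
        boundaries.append(len(values))
        values.extend(record.get(key, []))
    return values, boundaries

# Usage in the notebook (multi-step mode):
# vals, starts = concat_histories(histories, "val_mean_iou")
# plt.plot(vals)
# for s in starts[1:]:
#     plt.axvline(s, color="gray", linestyle=":")  # mark phase transitions
```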

## Step 16: Test Prediction and Visualization

In [None]:
# Get a sample from validation set
print("Loading sample for prediction...")
sample_x, sample_y = validation_generator[0]

print(f"Input shape: {sample_x.shape}")
print(f"Label shape: {sample_y.shape}")

# Make prediction
print("\nGenerating prediction...")
prediction = model.predict(sample_x, verbose=0)
print(f"Prediction shape: {prediction.shape}")

# Visualize middle slice
slice_idx = D // 2  # Middle slice

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle(f'ProtoSeg3D Brain Tumor Segmentation - Slice {slice_idx}', fontsize=16, fontweight='bold')

# Input modalities
axes[0, 0].imshow(sample_x[0, slice_idx, :, :, 0], cmap='gray')
axes[0, 0].set_title('FLAIR', fontsize=12)
axes[0, 0].axis('off')

axes[0, 1].imshow(sample_x[0, slice_idx, :, :, 1], cmap='gray')
axes[0, 1].set_title('T1', fontsize=12)
axes[0, 1].axis('off')

axes[0, 2].imshow(sample_x[0, slice_idx, :, :, 2], cmap='gray')
axes[0, 2].set_title('T1ce', fontsize=12)
axes[0, 2].axis('off')

# T2 modality (bottom-left), then ground truth and prediction
axes[1, 0].imshow(sample_x[0, slice_idx, :, :, 3], cmap='gray')
axes[1, 0].set_title('T2', fontsize=12)
axes[1, 0].axis('off')

axes[1, 1].imshow(np.argmax(sample_y[0, slice_idx], axis=-1), cmap='jet', vmin=0, vmax=NUM_CLASSES-1)
axes[1, 1].set_title('Ground Truth', fontsize=12)
axes[1, 1].axis('off')

axes[1, 2].imshow(np.argmax(prediction[0, slice_idx], axis=-1), cmap='jet', vmin=0, vmax=NUM_CLASSES-1)
axes[1, 2].set_title('ProtoSeg Prediction', fontsize=12)
axes[1, 2].axis('off')

plt.tight_layout()
plt.savefig(f'{drive_checkpoint_dir}/prediction_visualization_{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✓ Prediction visualization saved to Google Drive")
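Keras's mean IoU is aggregated over the whole validation set. For a quick look at a single prediction, you can compute per-class IoU and Dice directly from argmax label maps. The NumPy sketch below assumes `sample_y` is one-hot and `prediction` holds logits, as in the cell above; `per_class_scores` is a hypothetical helper.

```python
import numpy as np

def per_class_scores(y_true_onehot, y_pred_logits, num_classes=4):
    """Per-class IoU and Dice from one-hot labels and logits (argmax over the last axis).

    Classes absent from both maps get NaN so they can be skipped with np.nanmean.
    """
    t = np.argmax(y_true_onehot, axis=-1)
    p = np.argmax(y_pred_logits, axis=-1)
    iou, dice = np.full(num_classes, np.nan), np.full(num_classes, np.nan)
    for c in range(num_classes):
        tc, pc = (t == c), (p == c)
        inter = np.logical_and(tc, pc).sum()
        union = np.logical_or(tc, pc).sum()
        if union > 0:
            iou[c] = inter / union
            dice[c] = 2 * inter / (tc.sum() + pc.sum())
    return iou, dice

# Usage with the batch from the cell above:
# iou, dice = per_class_scores(sample_y[0], prediction[0], NUM_CLASSES)
# print("IoU per class:", np.round(iou, 3), "mean:", np.nanmean(iou))
```

Reporting tumor classes separately matters here because the background class dominates the voxel count and inflates any averaged score.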

## Step 17: Summary and Results

In [None]:
print("=" * 70)
print(" " * 20 + "TRAINING COMPLETE!" + " " * 20)
print("=" * 70)
print(f"\n✓ Training mode: {TRAINING_MODE.upper()}")
print(f"✓ Your trained models are safely stored in Google Drive:")
print(f"  📁 {drive_checkpoint_dir}/")

if TRAINING_MODE == "single-phase":
    print("\n✓ Files saved:")
    print(f"  - best_model_{timestamp}.keras")
    print(f"  - final_model_{timestamp}.keras")
    print(f"  - training_log_{timestamp}.csv")
    print(f"  - training_history_{timestamp}.png")
else:
    print("\n✓ Files saved:")
    print(f"  - best_model_warmup_{timestamp}.keras")
    print(f"  - best_model_joint_{timestamp}.keras")
    print(f"  - best_model_finetune1_{timestamp}.keras")
    print(f"  - best_model_finetune2_{timestamp}.keras")
    print(f"  - after_warmup_{timestamp}.keras")
    print(f"  - after_joint_{timestamp}.keras")
    print(f"  - after_finetune1_{timestamp}.keras")
    print(f"  - final_model_{timestamp}.keras")
    print(f"  - training_log_*_{timestamp}.csv (per phase)")
    print(f"  - multistep_history_{timestamp}.png")

print(f"  - prediction_visualization_{timestamp}.png")

print("\n📊 Model Configuration:")
print(f"  - Architecture: ProtoSeg3D")
print(f"  - Prototypes per class: {NUM_PROTOTYPES_PER_CLASS}")
print(f"  - Total prototypes: {NUM_PROTOTYPES_PER_CLASS * NUM_CLASSES}")
print(f"  - Diversity loss: {'Enabled' if USE_DIVERSITY_LOSS else 'Disabled'}")
print(f"  - Training mode: {TRAINING_MODE}")

print("\n🎯 Next Steps:")
print("  1. Evaluate model on test set")
print("  2. Visualize learned prototypes")
print("  3. Analyze prototype activations for interpretability")
print("  4. Compare with baseline MProtoNet3D")

print("\n💾 All important files are backed up to Google Drive!")
print("=" * 70)

# List all saved files
print("\nFiles in Google Drive checkpoint directory:")
!ls -lh {drive_checkpoint_dir}

## Optional: Launch TensorBoard

In [None]:
# Uncomment to launch TensorBoard
# %load_ext tensorboard
# %tensorboard --logdir {logs_dir}

## Optional: Download Files Locally

In [None]:
# Uncomment to download files to your computer
# from google.colab import files
# files.download(f'{drive_checkpoint_dir}/final_model_{timestamp}.keras')
# files.download(f'{drive_checkpoint_dir}/prediction_visualization_{timestamp}.png')

## Notes and Tips

### ProtoSeg3D Model Architecture
- **Interpretable**: Uses prototype-based learning
- **Encoder**: Custom 3D CNN with isotropic pooling (2×2×2)
- **ASPP 3D**: Multi-scale context with atrous convolutions (rates: 1, 2, 4, 8)
- **Prototypes**: 128-dimensional learned representations for each class
- **Diversity Loss**: Based on the Jeffreys divergence (graph-compatible via `@tf.function`)
- **Multi-step Training**: Improved performance (+3-5% mIoU)

### Preprocessing Details
- **Method**: Center cropping only (no resizing)
- **Original**: 155 × 240 × 240 (D×H×W)
- **Preprocessed**: 128 × 160 × 192 (D×H×W)
- **Benefits**: Preserves original resolution, avoids blurring small tumor labels
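To make the crop arithmetic concrete, here is a minimal NumPy sketch of center cropping a BraTS volume from 155 × 240 × 240 to 128 × 160 × 192. The function name `center_crop_3d` is hypothetical, and it assumes that when a margin is odd (27 slices in depth) the extra slice is dropped from the end of the axis; the repository's preprocessing may split it differently.

```python
import numpy as np


def center_crop_3d(volume, target_shape=(128, 160, 192)):
    """Center-crop the first three axes (D, H, W) of a volume; extra axes are kept.

    No resampling is done, so voxel spacing and small label regions are preserved.
    Assumption: when a margin is odd, the extra voxel is dropped from the end of the axis.
    """
    slices = []
    for size, target in zip(volume.shape[:3], target_shape):
        if target > size:
            raise ValueError(f"target size {target} exceeds axis size {size}")
        start = (size - target) // 2
        slices.append(slice(start, start + target))
    return volume[tuple(slices)]


# Dummy BraTS-sized inputs: 155 x 240 x 240 (D x H x W), 4 MRI modalities
image = np.zeros((155, 240, 240, 4), dtype=np.float32)
label = np.zeros((155, 240, 240), dtype=np.uint8)

print(center_crop_3d(image).shape)  # (128, 160, 192, 4)
print(center_crop_3d(label).shape)  # (128, 160, 192)
```

Under these assumptions the crop keeps depth slices 13 to 140, rows 40 to 199 and columns 24 to 215, so the image and its label map stay aligned voxel for voxel.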

### Training Modes

**Single-Phase:**
- Faster: ~10-12 hours
- Simpler: One training phase
- Expected mIoU: 0.58-0.63

**Multi-Step (Recommended):**
- Phase 1: Warmup - freeze encoder, train ASPP + prototypes
- Phase 2: Joint training - train all except FC layer
- Phase 3-4: Fine-tuning - train FC layer only
- Better results: +3-5% mIoU improvement
- Better prototype quality and interpretability
- Longer: ~14 hours
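The multi-step protocol amounts to toggling which parts of the network are trainable in each phase. A minimal sketch, assuming the model can be split into four hypothetical groups named `encoder`, `aspp`, `prototypes` and `fc` (where the feature projection belongs in each phase is not specified here, so it is left out); the notebook's own phase setup methods may group layers differently:

```python
from types import SimpleNamespace

# Hypothetical schedule following the phase list above:
# the set of component groups that stay trainable in each phase.
PHASE_TRAINABLE = {
    "warmup": {"aspp", "prototypes"},                # encoder frozen
    "joint": {"encoder", "aspp", "prototypes"},      # everything except the FC layer
    "finetune1": {"fc"},                             # FC layer only
    "finetune2": {"fc"},
}


def apply_phase(components, phase):
    """Set `.trainable` on each component group according to PHASE_TRAINABLE.

    `components` maps a group name to any object with a `trainable` attribute
    (Keras layers expose one). Returns the sorted names left trainable.
    """
    allowed = PHASE_TRAINABLE[phase]
    for name, component in components.items():
        component.trainable = name in allowed
    return sorted(name for name, c in components.items() if c.trainable)


# Stand-in objects; in the notebook these would be the model's sub-layers.
parts = {n: SimpleNamespace(trainable=True) for n in ("encoder", "aspp", "prototypes", "fc")}
for phase in PHASE_TRAINABLE:
    print(phase, apply_phase(parts, phase))
```

With real Keras layers, a change to `trainable` only affects `fit()` after the model is compiled again, so each phase needs its own `compile()` call.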

### Technical Details

**Diversity Loss Implementation:**
- Uses `@tf.function` decorator for TensorFlow graph mode compatibility
- Works with prototype similarities (converted from distances)
- Automatically downsamples ground truth labels to match activation resolution
- Uses `tf.while_loop` and `tf.cond` for graph-compatible control flow
- Class-specific: Only considers locations where each class appears
- No Python control flow issues during training
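For sanity-checking the TensorFlow loss, a NumPy reference can help. The sketch below assumes a ProtoSeg-style formulation: for each class present in the downsampled labels, every prototype of that class gets a softmax distribution of its similarities over that class's locations, and the loss averages exp(-D_J) over all pairs of same-class prototypes, where D_J is the Jeffreys divergence. Lower values mean more diverse prototypes. The function names, the stride-8 nearest-neighbour downsampling, and the split of 28 prototypes into 7 per class are assumptions; compare against `losses_protoseg.py` before relying on it.

```python
import numpy as np


def downsample_labels(labels, factor=8):
    """Assumed nearest-neighbour subsampling of a (D, H, W) label volume by striding."""
    return labels[::factor, ::factor, ::factor]


def softmax(x, axis=0):
    """Numerically stable softmax along `axis`."""
    shifted = x - x.max(axis=axis, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


def jeffreys(p, q, eps=1e-8):
    """Jeffreys divergence KL(p||q) + KL(q||p) = sum((p - q) * log(p / q))."""
    p = p + eps
    q = q + eps
    return float(np.sum((p - q) * np.log(p / q)))


def diversity_loss(similarities, labels, prototype_class):
    """Mean over classes of the mean exp(-D_J) over same-class prototype pairs.

    similarities:    (N, P) prototype similarities at N feature-map locations
    labels:          (N,) class id per location, at feature-map resolution
    prototype_class: (P,) class id that owns each prototype
    """
    per_class = []
    for c in np.unique(labels):
        protos = np.flatnonzero(prototype_class == c)
        if protos.size < 2:
            continue  # no prototype pairs for this class
        # One distribution per prototype over the locations labelled c
        dists = softmax(similarities[labels == c][:, protos], axis=0)
        pair_terms = [
            np.exp(-jeffreys(dists[:, i], dists[:, j]))
            for i in range(protos.size)
            for j in range(i + 1, protos.size)
        ]
        per_class.append(np.mean(pair_terms))
    return float(np.mean(per_class)) if per_class else 0.0


rng = np.random.default_rng(0)
labels_full = rng.integers(0, 4, size=(128, 160, 192), dtype=np.uint8)
labels = downsample_labels(labels_full).reshape(-1)   # 16 * 20 * 24 = 7680 locations
prototype_class = np.repeat(np.arange(4), 7)          # 28 prototypes, 7 per class
similarities = rng.normal(size=(labels.size, 28))
print(diversity_loss(similarities, labels, prototype_class))
```

If all prototypes of a class produce identical distributions, every pair contributes exp(0) = 1, which is the worst (least diverse) value this loss can take.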

**Architecture Flow:**
```
Input: (B, 128, 160, 192, 4)
  ↓ Encoder (isotropic pooling)
Features: (B, 16, 20, 24, 128)
  ↓ ASPP 3D
Features: (B, 16, 20, 24, 128)
  ↓ Feature Projection
Features: (B, 16, 20, 24, 128)
  ↓ Prototype Layer
Similarities: (B, 16, 20, 24, 28)
  ↓ FC Layer
Logits: (B, 16, 20, 24, 4)
  ↓ Trilinear Upsample
Output: (B, 128, 160, 192, 4)
```
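The prototype and FC stages can be illustrated with a NumPy sketch that reproduces the shapes in the flow above. It assumes squared L2 distances mapped to similarities with the ProtoPNet-style activation log((d + 1) / (d + eps)), 7 prototypes for each of the 4 classes, and a ProtoPNet-style class-connection initialization of the FC weights (+1 to the prototype's own class, -0.5 elsewhere). None of these details are confirmed by the notebook, and the trilinear upsampling step is omitted.

```python
import numpy as np


def prototype_similarities(features, prototypes, eps=1e-4):
    """(B, D, H, W, C) features and (P, C) prototypes -> (B, D, H, W, P) similarities.

    Assumed activation: log((d + 1) / (d + eps)) on squared L2 distance d,
    which is largest when a feature vector coincides with a prototype.
    """
    # ||z - p||^2 = ||z||^2 - 2 z.p + ||p||^2 for every location/prototype pair
    z2 = np.sum(features ** 2, axis=-1, keepdims=True)           # (B, D, H, W, 1)
    p2 = np.sum(prototypes ** 2, axis=-1)                         # (P,)
    cross = np.einsum("bdhwc,pc->bdhwp", features, prototypes)    # (B, D, H, W, P)
    dist = np.maximum(z2 - 2.0 * cross + p2, 0.0)                 # clip rounding error
    return np.log((dist + 1.0) / (dist + eps))


def class_connection_weights(prototype_class, num_classes, negative=-0.5):
    """(P, K) FC weights: +1 from each prototype to its own class, `negative` elsewhere."""
    weights = np.full((len(prototype_class), num_classes), negative)
    weights[np.arange(len(prototype_class)), prototype_class] = 1.0
    return weights


rng = np.random.default_rng(0)
features = rng.normal(size=(1, 16, 20, 24, 128))   # projected features for one volume
prototypes = rng.normal(size=(28, 128))            # 28 prototypes of dimension 128
prototype_class = np.repeat(np.arange(4), 7)       # assumed: 7 prototypes per class

similarities = prototype_similarities(features, prototypes)            # (1, 16, 20, 24, 28)
logits = similarities @ class_connection_weights(prototype_class, 4)   # (1, 16, 20, 24, 4)
print(similarities.shape, logits.shape)
```

Because the FC layer is a per-location linear map from 28 similarities to 4 logits, each class score can be traced back to the prototypes that activated at that location, which is where the model's interpretability comes from.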

**Why Graph Mode Matters:**
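- `model.fit()` runs the training step as a compiled graph (unless `run_eagerly=True`) for better performance
- In graph mode tensors are symbolic, so Python `if`/`while` statements cannot branch on their values directly
- `@tf.function` traces Python code into a graph, and AutoGraph rewrites tensor-dependent `if`/`while` statements into `tf.cond`/`tf.while_loop`
- This enables faster training and better GPU utilization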
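As a small illustration of graph-compatible control flow, the hypothetical helper below averages values under a mask and returns 0 when the mask is empty, the kind of case a class-specific loss meets when a class is missing from a batch. It uses `tf.cond` explicitly instead of relying on AutoGraph to convert a Python `if` on a symbolic tensor; it is not taken from `losses_protoseg.py`.

```python
import tensorflow as tf


@tf.function
def masked_mean(values, mask):
    """Mean of `values` where `mask` is True, or 0.0 when the mask is empty."""
    count = tf.reduce_sum(tf.cast(mask, tf.float32))
    # Graph-side branch: tf.cond picks a branch at run time, instead of Python
    # trying to evaluate a symbolic tensor as a bool during tracing.
    return tf.cond(
        count > 0,
        lambda: tf.reduce_sum(tf.boolean_mask(values, mask)) / count,
        lambda: tf.constant(0.0),
    )


print(float(masked_mean(tf.constant([1.0, 2.0, 3.0]), tf.constant([True, False, True]))))
print(float(masked_mean(tf.constant([1.0, 2.0]), tf.constant([False, False]))))
```

Both branches must return tensors of the same dtype and shape (here scalar `float32`), which is a common source of errors when converting Python branches by hand.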

### Loading Trained Models

```python
from model_protoseg import ProtoSeg3D
import tensorflow as tf

# Load model
model = tf.keras.models.load_model(
    'path/to/model.keras',
    custom_objects={'ProtoSeg3D': ProtoSeg3D}
)

# Get prototype information
proto_info = model.get_prototype_info()
print(f"Total prototypes: {proto_info['num_prototypes']}")
```

### Resuming Multi-Step Training

If your session disconnects during multi-step training, you can resume from the last saved checkpoint:

```python
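import tensorflow as tf
from model_protoseg import ProtoSeg3D

# drive_checkpoint_dir comes from the setup cells. After a disconnect, set
# timestamp to the value used by the interrupted run (it appears in the
# checkpoint file names); a fresh session will not have it.
# Load checkpoint from specific phase
checkpoint_path = f'{drive_checkpoint_dir}/after_warmup_{timestamp}.keras'
model = tf.keras.models.load_model(
    checkpoint_path,
    custom_objects={'ProtoSeg3D': ProtoSeg3D}
)

# Continue with next phase
model.setup_joint_training_phase()
# ... continue training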
```
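If you are not sure which phase finished last, a small helper can pick the most advanced checkpoint. This sketch assumes the `after_*` and `final_model_*` file names listed at the end of training, written in the order shown, and the same `timestamp`; `latest_phase_checkpoint` is a hypothetical helper, not part of the repository:

```python
import os

# Assumed order in which the multi-step run writes its phase checkpoints
PHASE_ORDER = ["after_warmup", "after_joint", "after_finetune1", "final_model"]


def latest_phase_checkpoint(checkpoint_dir, timestamp):
    """Return (phase, path) for the most advanced checkpoint of a run, or None."""
    for phase in reversed(PHASE_ORDER):
        path = os.path.join(checkpoint_dir, f"{phase}_{timestamp}.keras")
        if os.path.exists(path):
            return phase, path
    return None
```

For example, if `latest_phase_checkpoint(drive_checkpoint_dir, timestamp)` returns `after_warmup`, the next step is the joint phase shown above.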

### Troubleshooting

**Issue: "OperatorNotAllowedInGraphError"**
- **Solution**: This has been fixed in `losses_protoseg.py` with the `@tf.function` decorator
- If you still see this error, make sure you have pulled the latest version from GitHub
- The diversity loss function is now fully graph-compatible

**Issue: "Out of memory" during training**
- **Solution**: Reduce batch size from 2 to 1
```python
BATCH_SIZE = 1  # Instead of 2
```
- Note: The new dimensions (128×160×192 = 3,932,160 voxels) have 1.6× as many voxels as the old ones (96×160×160 = 2,457,600), so memory usage is correspondingly higher
- Consider using Colab Pro for more GPU memory

**Issue: Training stuck or very slow**
- Check GPU is enabled: Runtime → Change runtime type → GPU
- Monitor GPU usage: `!nvidia-smi`
- Check if disk is full: `!df -h /content`

**Issue: Session disconnected**
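- Colab Free: 12-hour limit, Colab Pro: 24-hour limit. A multi-step run (~14 hours) does not fit in one Colab Free session, so plan to resume at least once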
- Keep browser tab active
- Models auto-saved to Google Drive every epoch
- Resume from last checkpoint (see above)

### Performance Expectations

From ProtoSeg paper:
- Multi-step training improves mIoU by ~3-5% over single-phase
- Better prototype diversity with diversity loss enabled
- More interpretable predictions

Expected for BraTS 2020:
- ProtoSeg3D (single-phase): mIoU 0.58-0.63
- ProtoSeg3D (multi-step): mIoU 0.62-0.68
- Trade-off: ~5% lower than MProtoNet3D, but much more interpretable
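The mIoU figures above depend on how the mean is taken. As a reference, here is a minimal NumPy sketch that averages IoU over the classes present in either the prediction or the ground truth, with background counted as a class. Whether the quoted numbers include background or skip absent classes is not stated here, so treat this as one possible convention rather than the notebook's evaluation code.

```python
import numpy as np


def mean_iou(pred, target, num_classes=4):
    """Mean IoU over classes that occur in the prediction or the ground truth.

    Background (class 0) is counted like any other class; classes absent from
    both volumes are skipped rather than scored as 0 or 1.
    """
    ious = []
    for c in range(num_classes):
        p, t = pred == c, target == c
        union = np.logical_or(p, t).sum()
        if union == 0:
            continue
        ious.append(np.logical_and(p, t).sum() / union)
    return float(np.mean(ious)) if ious else float("nan")


target = np.array([0, 0, 1, 1])
pred = np.array([0, 1, 1, 1])
print(mean_iou(pred, target))  # (1/2 + 2/3) / 2, about 0.583
```

Skipping absent classes avoids rewarding or penalizing a volume for a tumor sub-region that simply does not occur in it, but it makes per-volume scores depend on which classes are present.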

### References
- ProtoSeg paper: "ProtoSeg: Interpretable Semantic Segmentation with Prototypical Parts" (WACV 2023)
- Jeffreys divergence: the symmetrized KL divergence, KL(P‖Q) + KL(Q‖P), used to measure how similar two distributions are
- Documentation: See PROTOSEG_ADAPTATION.md and MULTISTEP_TRAINING.md in the repository
- GitHub Issues: Report problems at your repository's issues page