# Prototype-Based 3D Brain Tumor Segmentation - Paperspace Gradient

Three-phase training for PrototypeSegNet3D on Paperspace Gradient.

1. **Phase 1**: Warm-up (frozen backbone)
2. **Phase 2**: Joint fine-tuning
3. **Phase 3**: Prototype projection & refinement

**Note**: Use `/storage` for persistent data that survives notebook restarts.

## Step 1: Clone Repository

In [None]:
import os

# Gradient notebooks start in /notebooks
os.chdir('/notebooks')

# Clone repo if not exists
REPO_DIR = '/notebooks/brainTumorSurvival'

if not os.path.exists(REPO_DIR):
    !git clone https://github.com/dariamarc/brainTumorSurvival.git
else:
    # Pull latest changes
    os.chdir(REPO_DIR)
    !git pull

os.chdir(REPO_DIR)
print(f"Working directory: {os.getcwd()}")

## Step 2: Setup Paths

In [None]:
import sys

# Locations inside the cloned repository.
REPO_DIR = '/notebooks/brainTumorSurvival'
RESNET_DIR = f'{REPO_DIR}/ResNet_architecture'

# Every insert goes to the front, so RESNET_DIR ends up ahead of REPO_DIR.
for import_root in (REPO_DIR, RESNET_DIR):
    sys.path.insert(0, import_root)

# Data and checkpoints are kept under /storage so they survive restarts.
STORAGE_DIR = '/storage'
DATA_DIR = f'{STORAGE_DIR}/brainTumorData_preprocessed_cropped'
CHECKPOINT_BASE = f'{STORAGE_DIR}/checkpoints'

# DATA_DIR is deliberately not created here; only the two folders below are.
for required_dir in (STORAGE_DIR, CHECKPOINT_BASE):
    os.makedirs(required_dir, exist_ok=True)

print(f"Repo: {REPO_DIR}")
print(f"Data will be stored at: {DATA_DIR}")
print(f"Checkpoints will be stored at: {CHECKPOINT_BASE}")

## Step 3: Download Data from S3

Data is stored in `/storage` so it persists across notebook sessions.

In [None]:
# Install boto3 for S3 access
!pip install -q boto3

# ============================================================================
# AWS CREDENTIALS AND S3 DETAILS
#
# Credentials are read from the standard AWS environment variables
# (AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY) when they are set, so real keys
# never have to be typed into - or committed with - this notebook.
# The placeholder strings below are only used when a variable is missing;
# replace them as a last resort and never commit the edited notebook.
# ============================================================================
import os

AWS_ACCESS_KEY_ID = os.environ.get('AWS_ACCESS_KEY_ID', 'YOUR_ACCESS_KEY')
AWS_SECRET_ACCESS_KEY = os.environ.get('AWS_SECRET_ACCESS_KEY', 'YOUR_SECRET_KEY')
AWS_REGION = 'eu-central-1'
S3_BUCKET = 'your-brats2020-data'              # Replace with your bucket
S3_ZIP_FILE = 'preprocessed_data_cropped.zip'
# ============================================================================

In [None]:
import os
import boto3
from botocore.config import Config

ZIP_PATH = f'{STORAGE_DIR}/preprocessed_data_cropped.zip'

# Check if data already exists
def count_h5_files(path):
    """Return how many entries directly inside *path* have a name ending in '.h5'.

    A path that does not exist counts as zero files. Sub-folders are not
    searched; only the names returned by os.listdir(path) are inspected.
    """
    if os.path.exists(path):
        return sum(1 for entry in os.listdir(path) if entry.endswith('.h5'))
    return 0

# Reuse a dataset that is already extracted in DATA_DIR; otherwise download the
# archive from S3, unpack it, and flatten a wrapping folder if there is one.
existing_files = count_h5_files(DATA_DIR)

# NOTE(review): 40000 looks like a "dataset is essentially complete" heuristic
# (NUM_VOLUMES * NUM_SLICES would be 47232) - confirm the intended threshold.
if existing_files > 40000:
    print(f"Data already exists: {existing_files} files")
    print("Skipping download.")
    DATA_PATH = DATA_DIR
else:
    print("Downloading data from S3...")
    print(f"  Bucket: {S3_BUCKET}")
    print(f"  File: {S3_ZIP_FILE}")

    # Create S3 client with credentials
    s3_client = boto3.client(
        's3',
        aws_access_key_id=AWS_ACCESS_KEY_ID,
        aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
        region_name=AWS_REGION,
        config=Config(signature_version='s3v4')
    )

    # Download with progress
    import sys
    import threading

    def download_progress(bytes_transferred):
        """Rewrite the current output line with the running total in GB."""
        sys.stdout.write(f"\r  Downloaded: {bytes_transferred / (1024**3):.2f} GB")
        sys.stdout.flush()

    class ProgressCallback:
        """Accumulate the byte counts reported by boto3 and display the total."""

        def __init__(self):
            self.bytes_transferred = 0
            # boto3 may invoke the callback from several transfer threads, so
            # the counter update and the print are serialised with a lock.
            self._lock = threading.Lock()

        def __call__(self, bytes_amount):
            with self._lock:
                self.bytes_transferred += bytes_amount
                download_progress(self.bytes_transferred)

    progress = ProgressCallback()
    s3_client.download_file(
        S3_BUCKET,
        S3_ZIP_FILE,
        ZIP_PATH,
        Callback=progress
    )
    print("\n  Download complete!")

    # Extract
    print("Extracting...")
    os.makedirs(DATA_DIR, exist_ok=True)

    import zipfile
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(DATA_DIR)
    print("  Extraction complete!")

    # Remove zip to save space
    os.remove(ZIP_PATH)
    print("  Cleaned up zip file.")

    # Handle nested folder if present
    h5_files = count_h5_files(DATA_DIR)
    if h5_files == 0:
        subdirs = [d for d in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, d))]
        if subdirs:
            # Prefer the sub-folder that actually holds .h5 files: an archive can
            # carry unrelated folders (e.g. macOS's __MACOSX) and os.listdir order
            # is arbitrary. Fall back to the first entry, as before, otherwise.
            with_data = [d for d in subdirs if count_h5_files(os.path.join(DATA_DIR, d)) > 0]
            nested_name = with_data[0] if with_data else subdirs[0]
            nested = os.path.join(DATA_DIR, nested_name)
            print(f"  Moving files from nested folder: {nested_name}")
            for f in os.listdir(nested):
                os.rename(os.path.join(nested, f), os.path.join(DATA_DIR, f))
            os.rmdir(nested)

    DATA_PATH = DATA_DIR

# Verify
file_count = count_h5_files(DATA_PATH)
print(f"\nData ready: {file_count} files at {DATA_PATH}")
if file_count == 0:
    # Later cells read from DATA_PATH, so make an empty result obvious here.
    print("WARNING: no .h5 files found - check the S3 settings and the archive layout.")

## Step 4: Check GPU and Import Modules

In [None]:
import tensorflow as tf

# Check GPU
gpus = tf.config.list_physical_devices('GPU')
print(f"GPUs available: {len(gpus)}")
for gpu in gpus:
    print(f"  {gpu}")

# Show GPU memory
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv

In [None]:
from prototype_segnet3d import create_prototype_segnet3d
from trainer import PrototypeTrainer
from data_processing.data_generator import MRIDataGenerator

print("Modules imported.")

## Step 5: Configuration

In [None]:
# ----------------------------------------------------------------------------
# Training configuration - edit the values in this cell only.
# ----------------------------------------------------------------------------

# Data
BATCH_SIZE, SPLIT_RATIO, RANDOM_STATE = 1, 0.2, 42
NUM_VOLUMES, NUM_SLICES = 369, 128

# Volume dimensions; combined into INPUT_SHAPE below in the order (D, H, W, C)
D, H, W, C = 128, 160, 192, 4

NUM_CLASSES, N_PROTOTYPES = 4, 3

# Model
BACKBONE_CHANNELS, ASPP_OUT_CHANNELS = 64, 256
DILATION_RATES = (2, 4, 8)

# Training epochs per phase
PHASE1_EPOCHS, PHASE2_EPOCHS, PHASE3_EPOCHS = 50, 150, 30

# ----------------------------------------------------------------------------

INPUT_SHAPE = (D, H, W, C)
print(f"Input shape: {INPUT_SHAPE}")
print(f"Classes: {NUM_CLASSES}, Prototypes: {N_PROTOTYPES}")

## Step 6: Setup Checkpoint Directory

In [None]:
import datetime

# Each run gets its own folder under CHECKPOINT_BASE, named after the start time.
run_started = datetime.datetime.now()
timestamp = f"{run_started:%Y%m%d-%H%M%S}"
CHECKPOINT_DIR = f'{CHECKPOINT_BASE}/{timestamp}'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Checkpoint directory: {CHECKPOINT_DIR}")
print("(This is in /storage, so it persists across sessions)")

## Step 7: Create Data Generators

In [None]:
# Settings shared by both generators; only the subset and shuffling differ.
generator_settings = dict(
    batch_size=BATCH_SIZE,
    num_slices=NUM_SLICES,
    num_volumes=NUM_VOLUMES,
    split_ratio=SPLIT_RATIO,
    random_state=RANDOM_STATE,
)

# Training subset, shuffled.
train_generator = MRIDataGenerator(
    DATA_PATH, subset='train', shuffle=True, **generator_settings
)

# Validation subset, kept in a fixed order.
val_generator = MRIDataGenerator(
    DATA_PATH, subset='val', shuffle=False, **generator_settings
)

print(f"Training batches: {len(train_generator)}")
print(f"Validation batches: {len(val_generator)}")

## Step 8: Build Model

In [None]:
# Collect the constructor arguments first, then build the network from them.
model_settings = {
    'input_shape': INPUT_SHAPE,
    'num_classes': NUM_CLASSES,
    'n_prototypes': N_PROTOTYPES,
    'backbone_channels': BACKBONE_CHANNELS,
    'aspp_out_channels': ASPP_OUT_CHANNELS,
    'dilation_rates': DILATION_RATES,
    'distance_type': 'l2',
    'activation_function': 'log',
}
model = create_prototype_segnet3d(**model_settings)

# Initialize weights: one inference-mode forward pass over an all-zero batch of size 1.
dummy_input = tf.zeros((1, *INPUT_SHAPE))
_ = model(dummy_input, training=False)

print("Model built.")
model.summary()

## Step 9: Create Trainer

In [None]:
# The trainer receives the model, both generators and this run's checkpoint folder.
trainer_settings = {
    'model': model,
    'train_generator': train_generator,
    'val_generator': val_generator,
    'checkpoint_dir': CHECKPOINT_DIR,
}
trainer = PrototypeTrainer(**trainer_settings)

print("Trainer created.")

## Step 10: Phase 1 - Warm-up Training

In [None]:
# Run phase 1 for the configured number of epochs, then snapshot the model.
trainer.train_phase1(epochs=PHASE1_EPOCHS)

phase1_model_path = f'{CHECKPOINT_DIR}/model_after_phase1.keras'
model.save(phase1_model_path)
print("Phase 1 complete and saved.")

## Step 11: Phase 2 - Joint Fine-tuning

In [None]:
# Run phase 2 for the configured number of epochs, then snapshot the model.
trainer.train_phase2(epochs=PHASE2_EPOCHS)

phase2_model_path = f'{CHECKPOINT_DIR}/model_after_phase2.keras'
model.save(phase2_model_path)
print("Phase 2 complete and saved.")

## Step 12: Phase 3 - Prototype Projection & Refinement

In [None]:
# Run phase 3 for the configured number of epochs, then save the final model.
trainer.train_phase3(epochs=PHASE3_EPOCHS)

final_model_path = f'{CHECKPOINT_DIR}/model_final.keras'
model.save(final_model_path)
print("Phase 3 complete and saved.")

## Step 13: Save Training History

In [None]:
import json

history = trainer.get_full_history()

with open(f'{CHECKPOINT_DIR}/training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print("Training history saved.")
print(f"\nAll files saved to: {CHECKPOINT_DIR}")
!ls -la {CHECKPOINT_DIR}

## Step 14: Plot Training Metrics

In [None]:
import matplotlib.pyplot as plt

history = trainer.get_full_history()
epochs = history['epochs']

# 2x2 grid: Dice per class | prototype purity / loss | whole-tumor Dice
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
(ax1, ax2), (ax3, ax4) = axes


def _plot_series(ax, series):
    """Draw one line per (history key, legend label) pair against the epochs."""
    for key, label in series:
        ax.plot(epochs, history[key], label=label, alpha=0.8)


def _plot_mean(ax, key):
    """Draw the thicker black 'Mean' line for the given history key."""
    ax.plot(epochs, history[key], label='Mean', linewidth=2, color='black')


def _finish_panel(ax, ylabel, title, legend=True):
    """Mark the phase boundaries, then add labels, title, legend and grid."""
    for boundary in history['phase_boundaries']:
        ax.axvline(x=boundary, color='gray', linestyle='--', alpha=0.5)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    if legend:
        ax.legend()
    ax.grid(True, alpha=0.3)


# Dice Scores
_plot_series(ax1, [
    ('dice_gd_enhancing', 'GD-Enhancing'),
    ('dice_edema', 'Edema'),
    ('dice_necrotic', 'Necrotic'),
])
_plot_mean(ax1, 'dice_mean')
_finish_panel(ax1, 'Dice Score', 'Dice Scores by Class')

# Purity Ratios
_plot_series(ax2, [
    ('purity_proto_0', 'Proto 0'),
    ('purity_proto_1', 'Proto 1'),
    ('purity_proto_2', 'Proto 2'),
])
_plot_mean(ax2, 'purity_mean')
_finish_panel(ax2, 'Purity Ratio', 'Prototype Purity')

# Loss
_plot_series(ax3, [
    ('train_loss', 'Train'),
    ('val_loss', 'Val'),
])
_finish_panel(ax3, 'Loss', 'Training and Validation Loss')

# Whole Tumor Dice (single unlabeled curve, so no legend)
ax4.plot(epochs, history['dice_whole_tumor'], color='green', linewidth=2)
_finish_panel(ax4, 'Dice Score', 'Whole Tumor Dice', legend=False)

plt.tight_layout()
plt.savefig(f'{CHECKPOINT_DIR}/training_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

## Step 15: Visual Comparison - Predictions vs Ground Truth

In [None]:
import numpy as np
from matplotlib.colors import ListedColormap

# Colormap for class indices 0..3 (index 0 is drawn black).
colors = ['black', 'red', 'limegreen', 'dodgerblue']
seg_cmap = ListedColormap(colors)

def visualize_prediction(model, val_generator, sample_idx=0, slices_to_show=5):
    """Plot input, ground truth, prediction and overlay for one validation batch.

    Takes the first volume of batch ``val_generator[sample_idx]``, shows
    ``slices_to_show`` evenly spaced slices between 1/4 and 3/4 of its first
    axis, saves the figure as ``prediction_comparison.png`` in CHECKPOINT_DIR
    and prints the Dice scores computed by ``SegmentationMetrics.compute_all``.

    Args:
        model: Callable returning ``(logits, similarities)`` for a batch of images.
        val_generator: Indexable object; ``val_generator[i]`` yields ``(images, masks)``.
        sample_idx: Index of the batch to visualise.
        slices_to_show: Number of slices, i.e. rows in the figure (1 or more).
    """
    images, masks = val_generator[sample_idx]

    # Only the logits are needed here; the similarity output is ignored.
    logits, _ = model(images, training=False)
    predictions = tf.nn.softmax(logits, axis=-1)
    pred_classes = tf.argmax(predictions, axis=-1).numpy()[0]
    gt_classes = tf.argmax(masks, axis=-1).numpy()[0]
    # Channel 0 of the first volume; the panel title below labels it FLAIR.
    input_image = images[0, :, :, :, 0]

    depth = input_image.shape[0]
    slice_indices = np.linspace(depth // 4, 3 * depth // 4, slices_to_show, dtype=int)

    # squeeze=False keeps `axes` two-dimensional even for a single row, so the
    # axes[row, col] indexing below also works when slices_to_show == 1.
    fig, axes = plt.subplots(
        slices_to_show, 4, figsize=(16, 4 * slices_to_show), squeeze=False
    )

    # RGBA per class index; class 0 is fully transparent in the overlay.
    overlay_colors = [(0, 0, 0, 0), (1, 0, 0, 0.4), (0, 1, 0, 0.4), (0, 0, 1, 0.4)]

    for row, slice_idx in enumerate(slice_indices):
        axes[row, 0].imshow(input_image[slice_idx], cmap='gray')
        axes[row, 0].set_title(f'FLAIR (Slice {slice_idx})')
        axes[row, 0].axis('off')

        axes[row, 1].imshow(gt_classes[slice_idx], cmap=seg_cmap, vmin=0, vmax=3)
        axes[row, 1].set_title('Ground Truth')
        axes[row, 1].axis('off')

        axes[row, 2].imshow(pred_classes[slice_idx], cmap=seg_cmap, vmin=0, vmax=3)
        axes[row, 2].set_title('Prediction')
        axes[row, 2].axis('off')

        # Overlay: grayscale input underneath, semi-transparent prediction on top.
        axes[row, 3].imshow(input_image[slice_idx], cmap='gray')
        pred_overlay = np.zeros((*pred_classes[slice_idx].shape, 4))
        for class_idx, color in enumerate(overlay_colors):
            mask = pred_classes[slice_idx] == class_idx
            pred_overlay[mask] = color
        axes[row, 3].imshow(pred_overlay)
        axes[row, 3].set_title('Overlay')
        axes[row, 3].axis('off')

    legend_elements = [
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='red', markersize=10, label='GD-Enhancing'),
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='limegreen', markersize=10, label='Edema'),
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='dodgerblue', markersize=10, label='Necrotic')
    ]
    fig.legend(handles=legend_elements, loc='upper center', ncol=3, bbox_to_anchor=(0.5, 1.02))

    plt.tight_layout()
    plt.savefig(f'{CHECKPOINT_DIR}/prediction_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()

    from metrics import SegmentationMetrics
    seg_metrics = SegmentationMetrics()
    dice_scores = seg_metrics.compute_all(masks, logits)

    # float() lets the ':.3f' format work for tensor/NumPy scalars as well as floats.
    print(f"\nDice scores for this sample:")
    print(f"  GD-Enhancing: {float(dice_scores['dice_gd_enhancing']):.3f}")
    print(f"  Edema:        {float(dice_scores['dice_edema']):.3f}")
    print(f"  Necrotic:     {float(dice_scores['dice_necrotic']):.3f}")
    print(f"  Mean:         {float(dice_scores['dice_mean']):.3f}")

visualize_prediction(model, val_generator, sample_idx=0, slices_to_show=5)

## Step 16: Backup to S3 (Optional)

In [None]:
# Backup checkpoints to S3
import boto3
from botocore.config import Config

# Client built from the credentials configured earlier in the notebook.
s3_client = boto3.client(
    's3',
    aws_access_key_id=AWS_ACCESS_KEY_ID,
    aws_secret_access_key=AWS_SECRET_ACCESS_KEY,
    region_name=AWS_REGION,
    config=Config(signature_version='s3v4')
)

S3_BACKUP_PREFIX = f'gradient_checkpoints/{timestamp}'

print(f"Backing up to s3://{S3_BUCKET}/{S3_BACKUP_PREFIX}/")

# Upload every file below the checkpoint directory. The tree is walked so that
# files inside sub-folders are backed up too; the S3 key mirrors the path
# relative to CHECKPOINT_DIR (top-level files keep the same key as before).
for dirpath, dirnames, filenames in os.walk(CHECKPOINT_DIR):
    dirnames.sort()  # deterministic traversal order
    for filename in sorted(filenames):
        filepath = os.path.join(dirpath, filename)
        if not os.path.isfile(filepath):
            # e.g. a dangling symlink: nothing to upload
            continue
        rel_path = os.path.relpath(filepath, CHECKPOINT_DIR).replace(os.sep, '/')
        s3_key = f"{S3_BACKUP_PREFIX}/{rel_path}"
        print(f"  Uploading: {rel_path}")
        s3_client.upload_file(filepath, S3_BUCKET, s3_key)

print(f"\nBackup complete: s3://{S3_BUCKET}/{S3_BACKUP_PREFIX}/")

---

## Resume Training (If Session Restarts)

Since data and checkpoints are in `/storage`, they persist across sessions.

In [None]:
# # Uncomment to resume from a checkpoint
# 
# RESUME_CHECKPOINT_DIR = '/storage/checkpoints/YYYYMMDD-HHMMSS'  # Update this
# 
# # Rebuild model
# model = create_prototype_segnet3d(
#     input_shape=INPUT_SHAPE,
#     num_classes=NUM_CLASSES,
#     n_prototypes=N_PROTOTYPES,
#     backbone_channels=BACKBONE_CHANNELS,
#     aspp_out_channels=ASPP_OUT_CHANNELS,
#     dilation_rates=DILATION_RATES,
#     distance_type='l2',
#     activation_function='log'
# )
# _ = model(tf.zeros((1,) + INPUT_SHAPE), training=False)
# 
# # Load weights
# checkpoint_file = f'{RESUME_CHECKPOINT_DIR}/model_after_phase2.keras'
# loaded = tf.keras.models.load_model(checkpoint_file)
# model.set_weights(loaded.get_weights())
# print(f"Loaded weights from: {checkpoint_file}")
# 
# # Create trainer and continue
# trainer = PrototypeTrainer(
#     model=model,
#     train_generator=train_generator,
#     val_generator=val_generator,
#     checkpoint_dir=RESUME_CHECKPOINT_DIR
# )
# trainer.train_phase3(epochs=PHASE3_EPOCHS)

## Download Checkpoints Locally

To download your trained model from Gradient:

In [None]:
# Create a zip of the checkpoint directory for easy download
import shutil

# zip_name is the archive path without its extension; the ".zip" suffix is
# added by make_archive, which packs the contents of CHECKPOINT_DIR.
zip_name = f'/notebooks/checkpoint_{timestamp}'
shutil.make_archive(base_name=zip_name, format='zip', root_dir=CHECKPOINT_DIR)

print(f"Created: {zip_name}.zip")
print("You can download this file from the Gradient file browser.")