# Prototype-Based 3D Brain Tumor Segmentation - SageMaker Training

Three-phase training for PrototypeSegNet3D on AWS SageMaker.

1. **Phase 1**: Warm-up (frozen backbone)
2. **Phase 2**: Joint fine-tuning
3. **Phase 3**: Prototype projection & refinement

## Step 1: Install Dependencies

In [None]:
# Install required packages
!pip install -q tensorflow==2.15.0 h5py tqdm matplotlib numpy

# Verify TensorFlow installation
import tensorflow as tf
print(f"TensorFlow version: {tf.__version__}")

# Check GPU availability
gpus = tf.config.list_physical_devices('GPU')
print(f"GPUs available: {len(gpus)}")
for gpu in gpus:
    print(f"  {gpu}")

# Enable memory growth to prevent TF from grabbing all GPU memory
if gpus:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    print("Memory growth enabled.")

In [None]:
import os
import sys

# Run from the repository root so repo-relative paths resolve.
REPO_DIR = '/home/ec2-user/SageMaker/brainTumorSurvival'
os.chdir(REPO_DIR)

# Put the repo root and its ResNet_architecture folder at the front of the
# import path. Each is inserted at index 0 in turn, so ResNet_architecture
# ends up searched first.
for _import_dir in (REPO_DIR, os.path.join(REPO_DIR, 'ResNet_architecture')):
    sys.path.insert(0, _import_dir)

print(f"Working directory: {os.getcwd()}")

## Step 2: Download Data from S3

In [None]:
import boto3
import zipfile

# ============================================================================
# UPDATE THESE WITH YOUR S3 DETAILS
# ============================================================================
S3_BUCKET = 'your-brats2020-data'
S3_ZIP_FILE = 'preprocessed_data_cropped.zip'
# ============================================================================

LOCAL_DATA_DIR = '/home/ec2-user/SageMaker/data'
ZIP_PATH = os.path.join(LOCAL_DATA_DIR, 'preprocessed_data_cropped.zip')
# Directory the archive is extracted into. DATA_PATH (resolved below) may be a
# nested folder inside it when the archive wraps its files in a top-level folder.
DATA_ROOT = os.path.join(LOCAL_DATA_DIR, 'preprocessed_data_cropped')
# Heuristic completeness check: at or below this many .h5 files, re-download.
MIN_H5_FILES = 40000

os.makedirs(LOCAL_DATA_DIR, exist_ok=True)


def _count_h5(path):
    """Return the number of ``.h5`` files directly inside *path* (0 if it is not a directory)."""
    if not os.path.isdir(path):
        return 0
    return sum(1 for name in os.listdir(path) if name.endswith('.h5'))


def _resolve_data_path(root):
    """Return the directory that actually holds the ``.h5`` files.

    That is *root* itself when it contains any ``.h5`` file, otherwise the
    first subdirectory (in sorted order) that does. This skips folders such
    as ``__MACOSX`` that hold no data. Falls back to *root* if nothing is found.
    The check runs on every execution of this cell, not only after a fresh
    extraction, so an archive extracted into a nested folder is found again
    after a restart instead of being re-downloaded.
    """
    if _count_h5(root) > 0 or not os.path.isdir(root):
        return root
    for name in sorted(os.listdir(root)):
        candidate = os.path.join(root, name)
        if os.path.isdir(candidate) and _count_h5(candidate) > 0:
            return candidate
    return root


# Data persists across notebook restarts. Decide freshly on every run whether
# to download, so a stale flag from an earlier execution cannot force a
# re-download.
DATA_PATH = _resolve_data_path(DATA_ROOT)
h5_count = _count_h5(DATA_PATH)
need_download = h5_count <= MIN_H5_FILES

if not need_download:
    print(f"Data already exists: {h5_count} files found")
    print("Skipping download.")
elif os.path.isdir(DATA_ROOT) and os.listdir(DATA_ROOT):
    print(f"Incomplete data found ({h5_count} files). Re-downloading...")

if need_download:
    print(f"Downloading data from s3://{S3_BUCKET}/{S3_ZIP_FILE}...")

    s3 = boto3.client('s3')
    s3.download_file(S3_BUCKET, S3_ZIP_FILE, ZIP_PATH)
    print("Download complete.")

    # Extract into the fixed root (never into a previously resolved nested folder).
    print("Extracting...")
    os.makedirs(DATA_ROOT, exist_ok=True)
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(DATA_ROOT)

    # Remove zip to save space
    os.remove(ZIP_PATH)
    print("Extraction complete.")

    # The archive may have extracted into a nested folder.
    DATA_PATH = _resolve_data_path(DATA_ROOT)

# Verify
file_count = _count_h5(DATA_PATH)
print(f"Data ready: {file_count} files at {DATA_PATH}")
if file_count == 0:
    print("WARNING: no .h5 files found; check S3_ZIP_FILE and the archive layout.")

## Step 3: Import Modules

In [None]:
import tensorflow as tf

# Re-check which GPUs TensorFlow can see (same listing as in Step 1).
gpus = tf.config.list_physical_devices('GPU')
print(f"GPUs available: {len(gpus)}")
if gpus:
    print('\n'.join(f"  {device}" for device in gpus))

# Project modules — presumably resolved through the sys.path entries added
# in Step 1's working-directory cell.
from prototype_segnet3d import create_prototype_segnet3d
from trainer import PrototypeTrainer
from data_processing.data_generator import MRIDataGenerator

print("Modules imported.")

## Step 4: Configuration

In [None]:
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

# Data loading and train/validation split
BATCH_SIZE = 1
SPLIT_RATIO = 0.2        # presumably the validation fraction — confirm in MRIDataGenerator
RANDOM_STATE = 42
NUM_VOLUMES = 369
NUM_SLICES = 128

# Input volume dimensions: depth first, image channels last
D, H, W, C = 128, 160, 192, 4
INPUT_SHAPE = (D, H, W, C)

# Segmentation classes and learned prototypes
NUM_CLASSES = 4
N_PROTOTYPES = 3

# Model architecture
BACKBONE_CHANNELS = 64
ASPP_OUT_CHANNELS = 256
DILATION_RATES = (2, 4, 8)

# Training epochs per phase (warm-up, joint fine-tuning, projection/refinement)
PHASE1_EPOCHS, PHASE2_EPOCHS, PHASE3_EPOCHS = 50, 150, 30

# ============================================================================

print(f"Input shape: {INPUT_SHAPE}")
print(f"Classes: {NUM_CLASSES}, Prototypes: {N_PROTOTYPES}")

## Step 5: Setup Checkpoint Directory

In [None]:
import datetime

# Each run gets its own timestamped checkpoint folder on the notebook volume.
timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
CHECKPOINT_DIR = os.path.join('/home/ec2-user/SageMaker/checkpoints', timestamp)

# Matching S3 key prefix for backups (optional but recommended).
S3_CHECKPOINT_PREFIX = 'checkpoints/' + timestamp

os.makedirs(CHECKPOINT_DIR, exist_ok=True)

print(f"Checkpoint directory: {CHECKPOINT_DIR}")
print(f"S3 backup path: s3://{S3_BUCKET}/{S3_CHECKPOINT_PREFIX}/")

## Step 6: Create Data Generators

In [None]:
# Settings shared by both splits. Using the same split_ratio and random_state
# presumably yields complementary train/val subsets — confirm in MRIDataGenerator.
_generator_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_slices=NUM_SLICES,
    num_volumes=NUM_VOLUMES,
    split_ratio=SPLIT_RATIO,
    random_state=RANDOM_STATE,
)

# Only the training split is shuffled.
train_generator = MRIDataGenerator(DATA_PATH, subset='train', shuffle=True, **_generator_kwargs)
val_generator = MRIDataGenerator(DATA_PATH, subset='val', shuffle=False, **_generator_kwargs)

for _split_name, _split_gen in (("Training", train_generator), ("Validation", val_generator)):
    print(f"{_split_name} batches: {len(_split_gen)}")

## Step 7: Build Model

In [None]:
# Architecture hyper-parameters, taken from the Step 4 configuration.
_architecture = {
    'input_shape': INPUT_SHAPE,
    'num_classes': NUM_CLASSES,
    'n_prototypes': N_PROTOTYPES,
    'backbone_channels': BACKBONE_CHANNELS,
    'aspp_out_channels': ASPP_OUT_CHANNELS,
    'dilation_rates': DILATION_RATES,
    'distance_type': 'l2',
    'activation_function': 'log',
}
model = create_prototype_segnet3d(**_architecture)

# One forward pass on an all-zero batch of size 1 creates the model's weights,
# so that model.summary() below can report them.
dummy_input = tf.zeros((1, *INPUT_SHAPE))
_ = model(dummy_input, training=False)

print("Model built.")
model.summary()

## Step 8: Create Trainer

In [None]:
# Wire the model, both data splits and the checkpoint folder into the trainer
# that runs the three training phases below.
trainer = PrototypeTrainer(
    checkpoint_dir=CHECKPOINT_DIR,
    model=model,
    train_generator=train_generator,
    val_generator=val_generator,
)

print("Trainer created.")

## Step 9: Phase 1 - Warm-up Training

In [None]:
# Phase 1 (warm-up), then snapshot the whole model to the checkpoint folder.
trainer.train_phase1(epochs=PHASE1_EPOCHS)
model.save(os.path.join(CHECKPOINT_DIR, 'model_after_phase1.keras'))

print("Phase 1 complete.")

## Step 10: Back Up to S3 (After Phase 1)

In [None]:
import subprocess

# Back up checkpoints to S3. The AWS CLI is run through subprocess rather than
# the `!` shell magic so its exit status can be checked, and a failed sync is
# reported instead of being announced as a success.
_s3_uri = f's3://{S3_BUCKET}/{S3_CHECKPOINT_PREFIX}/'
try:
    _sync = subprocess.run(
        ['aws', 's3', 'sync', CHECKPOINT_DIR, _s3_uri, '--quiet'],
        capture_output=True,
        text=True,
    )
except OSError as exc:  # e.g. the aws CLI is not on PATH
    _sync_error = str(exc)
else:
    _sync_error = None if _sync.returncode == 0 else (
        f"exit code {_sync.returncode}: {_sync.stderr.strip()}"
    )

if _sync_error is None:
    print("Phase 1 backed up to S3.")
else:
    # The backup is best-effort: warn, but do not stop the notebook.
    print(f"WARNING: S3 backup after Phase 1 failed ({_sync_error})")

## Step 11: Phase 2 - Joint Fine-tuning

In [None]:
# Phase 2 (joint fine-tuning), then snapshot the whole model.
trainer.train_phase2(epochs=PHASE2_EPOCHS)
model.save(os.path.join(CHECKPOINT_DIR, 'model_after_phase2.keras'))

print("Phase 2 complete.")

## Step 12: Back Up to S3 (After Phase 2)

In [None]:
import subprocess

# Back up checkpoints to S3. The AWS CLI is run through subprocess rather than
# the `!` shell magic so its exit status can be checked, and a failed sync is
# reported instead of being announced as a success.
_s3_uri = f's3://{S3_BUCKET}/{S3_CHECKPOINT_PREFIX}/'
try:
    _sync = subprocess.run(
        ['aws', 's3', 'sync', CHECKPOINT_DIR, _s3_uri, '--quiet'],
        capture_output=True,
        text=True,
    )
except OSError as exc:  # e.g. the aws CLI is not on PATH
    _sync_error = str(exc)
else:
    _sync_error = None if _sync.returncode == 0 else (
        f"exit code {_sync.returncode}: {_sync.stderr.strip()}"
    )

if _sync_error is None:
    print("Phase 2 backed up to S3.")
else:
    # The backup is best-effort: warn, but do not stop the notebook.
    print(f"WARNING: S3 backup after Phase 2 failed ({_sync_error})")

## Step 13: Phase 3 - Prototype Projection & Refinement

In [None]:
# Phase 3 (prototype projection & refinement), then save the final model.
trainer.train_phase3(epochs=PHASE3_EPOCHS)
model.save(os.path.join(CHECKPOINT_DIR, 'model_final.keras'))

print("Phase 3 complete.")

## Step 14: Save Training History

In [None]:
import json

history = trainer.get_full_history()


def _to_json_value(obj):
    """Fallback for json encoding of values it cannot handle natively.

    Objects exposing ``.numpy()`` (eager TensorFlow tensors) are converted to
    NumPy first. Anything exposing ``.tolist()`` (NumPy scalars and arrays) is
    then turned into plain Python values. Anything else raises ``TypeError``,
    as ``json`` itself would.
    """
    if hasattr(obj, 'numpy'):
        obj = obj.numpy()
    if hasattr(obj, 'tolist'):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Serialize fully before opening the file, so an encoding error cannot leave a
# truncated training_history.json behind. The fallback covers metric values
# recorded as e.g. NumPy float32, which plain json rejects.
# NOTE(review): confirm which value types get_full_history() returns.
history_json = json.dumps(history, indent=2, default=_to_json_value)
with open(f'{CHECKPOINT_DIR}/training_history.json', 'w') as f:
    f.write(history_json)

print("Training history saved.")

## Step 15: Final Backup to S3

In [None]:
!aws s3 sync {CHECKPOINT_DIR} s3://{S3_BUCKET}/{S3_CHECKPOINT_PREFIX}/ --quiet
print(f"All files backed up to: s3://{S3_BUCKET}/{S3_CHECKPOINT_PREFIX}/")

# List files
!ls -la {CHECKPOINT_DIR}

## Step 16: Plot Training Metrics

In [None]:
import matplotlib.pyplot as plt

history = trainer.get_full_history()


def _plot_history_panel(ax, hist, series, ylabel, title):
    """Plot each ``(history key, line kwargs)`` pair in *series* against epoch on *ax*.

    Also marks every phase boundary with a dashed vertical line and applies
    the axis labels, title, legend and grid shared by all four panels.
    """
    for key, line_kwargs in series:
        ax.plot(hist['epochs'], hist[key], **line_kwargs)
    for phase_start in hist['phase_boundaries']:
        ax.axvline(x=phase_start, color='gray', linestyle='--', alpha=0.5)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(2, 2, figsize=(14, 10))
ax1, ax2, ax3, ax4 = axes.flat  # row-major: top-left, top-right, bottom-left, bottom-right

# Top-left: per-class Dice, with the mean drawn bold
_plot_history_panel(ax1, history, [
    ('dice_gd_enhancing', {'label': 'GD-Enhancing', 'alpha': 0.8}),
    ('dice_edema', {'label': 'Edema', 'alpha': 0.8}),
    ('dice_necrotic', {'label': 'Necrotic', 'alpha': 0.8}),
    ('dice_mean', {'label': 'Mean', 'linewidth': 2, 'color': 'black'}),
], 'Dice Score', 'Dice Scores by Class')

# Top-right: purity of each prototype, with the mean drawn bold
_plot_history_panel(ax2, history, [
    ('purity_proto_0', {'label': 'Proto 0 (GD-Enh)', 'alpha': 0.8}),
    ('purity_proto_1', {'label': 'Proto 1 (Edema)', 'alpha': 0.8}),
    ('purity_proto_2', {'label': 'Proto 2 (Necrotic)', 'alpha': 0.8}),
    ('purity_mean', {'label': 'Mean', 'linewidth': 2, 'color': 'black'}),
], 'Purity Ratio', 'Prototype Purity Ratios')

# Bottom-left: training vs validation loss
_plot_history_panel(ax3, history, [
    ('train_loss', {'label': 'Train Loss', 'alpha': 0.8}),
    ('val_loss', {'label': 'Val Loss', 'alpha': 0.8}),
], 'Loss', 'Training and Validation Loss')

# Bottom-right: whole-tumor Dice
_plot_history_panel(ax4, history, [
    ('dice_whole_tumor', {'label': 'Whole Tumor', 'color': 'green', 'linewidth': 2}),
], 'Dice Score', 'Whole Tumor Dice Score')

plt.tight_layout()
plt.savefig(f'{CHECKPOINT_DIR}/training_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

print("Plot saved.")

## Step 17: Visual Comparison - Predictions vs Ground Truth

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Display colour for each class id, listed in id order. Id 0 is drawn black and
# left out of the prediction legend; ids 1-3 are the red / limegreen /
# dodgerblue entries that the legend labels GD-Enhancing / Edema / Necrotic.
_class_colour_by_id = {0: 'black', 1: 'red', 2: 'limegreen', 3: 'dodgerblue'}
colors = [_class_colour_by_id[class_id] for class_id in sorted(_class_colour_by_id)]
seg_cmap = ListedColormap(colors)

def visualize_prediction(model, val_generator, sample_idx=0, slices_to_show=5, save_path=None):
    """Plot input, ground truth, prediction and overlay for one validation batch.

    Runs ``model`` on ``val_generator[sample_idx]`` and plots the first item of
    that batch. For ``slices_to_show`` depth slices spaced evenly between 1/4
    and 3/4 of the volume depth, each row shows four panels: input channel 0,
    ground-truth classes, predicted classes, and the prediction overlaid on the
    input. The figure is saved and shown, then Dice scores for the batch are
    printed.

    Args:
        model: Callable returning ``(logits, similarities)``; only ``logits`` is used.
        val_generator: Indexable returning ``(images, masks)`` batches.
        sample_idx: Index of the batch to visualise.
        slices_to_show: Number of depth slices (figure rows); must be >= 1.
        save_path: File the figure is written to. Defaults to
            ``{CHECKPOINT_DIR}/prediction_comparison.png``.

    Raises:
        ValueError: If ``slices_to_show`` is less than 1.
    """
    if slices_to_show < 1:
        raise ValueError(f"slices_to_show must be at least 1, got {slices_to_show}")
    if save_path is None:
        save_path = f'{CHECKPOINT_DIR}/prediction_comparison.png'

    images, masks = val_generator[sample_idx]

    # Only the logits are needed here; the prototype similarities are ignored.
    logits, _ = model(images, training=False)
    predictions = tf.nn.softmax(logits, axis=-1)
    pred_classes = tf.argmax(predictions, axis=-1).numpy()[0]
    # assumes masks are one-hot over the last axis — TODO confirm with MRIDataGenerator
    gt_classes = tf.argmax(masks, axis=-1).numpy()[0]
    # NOTE(review): channel 0 is titled "FLAIR" below; confirm the channel order.
    input_image = images[0, :, :, :, 0]

    depth = input_image.shape[0]
    slice_indices = np.linspace(depth // 4, 3 * depth // 4, slices_to_show, dtype=int)

    # Per-class RGBA overlay colours, taken from seg_cmap so the overlay matches
    # the GT/prediction panels and the legend. Class 0 is fully transparent and
    # the other classes use alpha 0.4.
    overlay_palette = np.array(seg_cmap(np.arange(seg_cmap.N)), dtype=float)
    overlay_palette[:, 3] = 0.4
    overlay_palette[0, 3] = 0.0

    # squeeze=False keeps `axes` 2-D even when only one row is requested.
    fig, axes = plt.subplots(slices_to_show, 4, figsize=(16, 4 * slices_to_show), squeeze=False)

    for row, slice_idx in enumerate(slice_indices):
        image_slice = input_image[slice_idx]
        pred_slice = pred_classes[slice_idx]

        axes[row, 0].imshow(image_slice, cmap='gray')
        axes[row, 0].set_title(f'FLAIR (Slice {slice_idx})')
        axes[row, 0].axis('off')

        axes[row, 1].imshow(gt_classes[slice_idx], cmap=seg_cmap, vmin=0, vmax=3)
        axes[row, 1].set_title('Ground Truth')
        axes[row, 1].axis('off')

        axes[row, 2].imshow(pred_slice, cmap=seg_cmap, vmin=0, vmax=3)
        axes[row, 2].set_title('Prediction')
        axes[row, 2].axis('off')

        axes[row, 3].imshow(image_slice, cmap='gray')
        # Look up each voxel's overlay colour; class ids beyond the palette stay transparent.
        pred_overlay = np.zeros((*pred_slice.shape, 4))
        in_palette = pred_slice < len(overlay_palette)
        pred_overlay[in_palette] = overlay_palette[pred_slice[in_palette]]
        axes[row, 3].imshow(pred_overlay)
        axes[row, 3].set_title('Overlay')
        axes[row, 3].axis('off')

    legend_elements = [
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='red', markersize=10, label='GD-Enhancing'),
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='limegreen', markersize=10, label='Edema'),
        plt.Line2D([0], [0], marker='s', color='w', markerfacecolor='dodgerblue', markersize=10, label='Necrotic')
    ]
    fig.legend(handles=legend_elements, loc='upper center', ncol=3, bbox_to_anchor=(0.5, 1.02))

    plt.tight_layout()
    plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()

    from metrics import SegmentationMetrics
    seg_metrics = SegmentationMetrics()
    dice_scores = seg_metrics.compute_all(masks, logits)

    print(f"\nDice scores for this sample:")
    print(f"  GD-Enhancing: {dice_scores['dice_gd_enhancing']:.3f}")
    print(f"  Edema:        {dice_scores['dice_edema']:.3f}")
    print(f"  Necrotic:     {dice_scores['dice_necrotic']:.3f}")
    print(f"  Mean:         {dice_scores['dice_mean']:.3f}")

# Visualise validation batch 0 at five depth slices.
visualize_prediction(model, val_generator, slices_to_show=5, sample_idx=0)

## Step 18: Final S3 Sync (Including Plots)

In [None]:
import subprocess

# Final sync, including plots, without --quiet so the transferred files are
# listed. The AWS CLI is run through subprocess rather than the `!` shell
# magic so its exit status can be checked. Its output is captured and echoed
# into the cell.
_s3_uri = f's3://{S3_BUCKET}/{S3_CHECKPOINT_PREFIX}/'
try:
    _sync = subprocess.run(
        ['aws', 's3', 'sync', CHECKPOINT_DIR, _s3_uri],
        capture_output=True,
        text=True,
    )
except OSError as exc:  # e.g. the aws CLI is not on PATH
    _sync_error = str(exc)
else:
    print(_sync.stdout, end='')
    _sync_error = None if _sync.returncode == 0 else (
        f"exit code {_sync.returncode}: {_sync.stderr.strip()}"
    )

print(f"\nTraining complete!")
if _sync_error is None:
    print(f"All outputs saved to: {_s3_uri}")
else:
    print(f"WARNING: final S3 sync failed ({_sync_error}); outputs remain in {CHECKPOINT_DIR}")

---

## Resume Training (Optional)

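If the notebook disconnects, first re-run Steps 1–4 and 6 (dependencies, data, imports, configuration and data generators), because the resume cell relies on them. Then uncomment and run the cell below to resume training from the Phase 2 checkpoint.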

In [None]:
# # Uncomment and run this cell to resume from Phase 2 checkpoint
# 
# RESUME_CHECKPOINT_DIR = '/home/ec2-user/SageMaker/checkpoints/YYYYMMDD-HHMMSS'  # Update this
# 
# # Rebuild model
# model = create_prototype_segnet3d(
#     input_shape=INPUT_SHAPE,
#     num_classes=NUM_CLASSES,
#     n_prototypes=N_PROTOTYPES,
#     backbone_channels=BACKBONE_CHANNELS,
#     aspp_out_channels=ASPP_OUT_CHANNELS,
#     dilation_rates=DILATION_RATES,
#     distance_type='l2',
#     activation_function='log'
# )
# _ = model(tf.zeros((1,) + INPUT_SHAPE), training=False)
# 
# # Load weights from checkpoint
# checkpoint_path = f'{RESUME_CHECKPOINT_DIR}/model_after_phase2.keras'
# loaded_model = tf.keras.models.load_model(checkpoint_path)
# model.set_weights(loaded_model.get_weights())
# print(f"Weights loaded from {checkpoint_path}")
# 
# # Create new trainer and run Phase 3
# trainer = PrototypeTrainer(
#     model=model,
#     train_generator=train_generator,
#     val_generator=val_generator,
#     checkpoint_dir=RESUME_CHECKPOINT_DIR
# )
# trainer.train_phase3(epochs=PHASE3_EPOCHS)