# Prototype-Based 3D Brain Tumor Segmentation - Training

Three-phase training for PrototypeSegNet3D:
1. **Phase 1**: Warm-up (frozen backbone)
2. **Phase 2**: Joint fine-tuning
3. **Phase 3**: Prototype projection & refinement

## Step 1: Clone GitHub Repository

In [None]:
import os

# ============================================================================
# UPDATE THIS WITH YOUR GITHUB REPOSITORY URL
# ============================================================================
GITHUB_REPO_URL = "https://github.com/dariamarc/brainTumorSurvival.git"
# ============================================================================

repo_name = GITHUB_REPO_URL.split('/')[-1].replace('.git', '')

os.chdir('/content')

if os.path.exists(f'/content/{repo_name}'):
    !rm -rf /content/{repo_name}

!git clone {GITHUB_REPO_URL}

os.chdir(f'/content/{repo_name}')
print(f"Working directory: {os.getcwd()}")

## Step 2: Download Data from AWS S3

Downloads the preprocessed data as a single zip file for faster, more reliable transfer.

In [None]:
from google.colab import userdata
import os

# Install AWS CLI
!pip install -q awscli

# ============================================================================
# UPDATE THESE WITH YOUR S3 DETAILS
# ============================================================================
S3_BUCKET = 'your-brats2020-data'
S3_ZIP_FILE = 'preprocessed_data_cropped.zip'
AWS_REGION = 'eu-central-1'
# ============================================================================

LOCAL_PATH = '/content/brainTumorData_preprocessed_cropped'
ZIP_PATH = '/content/preprocessed_data_cropped.zip'

# Load AWS credentials from Colab Secrets
os.environ['AWS_ACCESS_KEY_ID'] = userdata.get('AWS_ACCESS_KEY_ID')
os.environ['AWS_SECRET_ACCESS_KEY'] = userdata.get('AWS_SECRET_ACCESS_KEY')
os.environ['AWS_DEFAULT_REGION'] = AWS_REGION

# Download zip file from S3 (single file transfer - much faster)
print("Downloading zip file from S3...")
!aws s3 cp s3://{S3_BUCKET}/{S3_ZIP_FILE} {ZIP_PATH}

# Unzip the data
print("Extracting data...")
!mkdir -p {LOCAL_PATH}
!unzip -q {ZIP_PATH} -d {LOCAL_PATH}

# Clean up zip file to save disk space
!rm {ZIP_PATH}

# Check if files are in a nested folder (depends on how zip was created)
h5_files = [f for f in os.listdir(LOCAL_PATH) if f.endswith('.h5')]
if len(h5_files) == 0:
    # Files might be in a subdirectory
    subdirs = [d for d in os.listdir(LOCAL_PATH) if os.path.isdir(os.path.join(LOCAL_PATH, d))]
    if subdirs:
        nested_path = os.path.join(LOCAL_PATH, subdirs[0])
        h5_files = [f for f in os.listdir(nested_path) if f.endswith('.h5')]
        if h5_files:
            print(f"Found files in nested folder: {subdirs[0]}")
            LOCAL_PATH = nested_path

# Verify file count
file_count = len([f for f in os.listdir(LOCAL_PATH) if f.endswith('.h5')])
expected_count = 369 * 128
print(f"Found {file_count} files (expected {expected_count})")

DATA_PATH = LOCAL_PATH
print(f"Data ready at: {DATA_PATH}")

## Step 3: Import Modules

In [None]:
import sys
import tensorflow as tf

# Locations inside the cloned repository that hold the project modules.
repo_dir = f'/content/{repo_name}'
resnet_dir = f'{repo_dir}/ResNet_architecture'

# Each path is pushed to the front of sys.path in turn, so resnet_dir ends up
# ahead of repo_dir in the module search order.
for extra_path in (repo_dir, resnet_dir):
    sys.path.insert(0, extra_path)

# Project-local modules; they resolve through the sys.path entries added above.
from prototype_segnet3d import create_prototype_segnet3d
from trainer import PrototypeTrainer
from data_processing.data_generator import MRIDataGenerator

print("Modules imported.")

## Step 4: Configuration

In [None]:
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

# --- Dataset and train/validation split ------------------------------------
NUM_VOLUMES = 369
NUM_SLICES = 128   # Preprocessed data has 128 slices per volume
BATCH_SIZE = 1
SPLIT_RATIO = 0.2  # forwarded to MRIDataGenerator as split_ratio
RANDOM_STATE = 42

# --- Volume geometry (must match preprocessed data) -------------------------
# Preprocessing crops from (155, 240, 240) to (128, 160, 192)
D = NUM_SLICES  # Depth equals the number of slices per volume (128)
H = 160         # Height (cropped from 240)
W = 192         # Width (cropped from 240)
C = 4           # Channels (FLAIR, T1, T1ce, T2)
INPUT_SHAPE = (D, H, W, C)

# --- Segmentation head ------------------------------------------------------
NUM_CLASSES = 4
N_PROTOTYPES = 3

# --- Model ------------------------------------------------------------------
BACKBONE_CHANNELS = 64
ASPP_OUT_CHANNELS = 256
DILATION_RATES = (2, 4, 8)

# --- Training epochs per phase ----------------------------------------------
PHASE1_EPOCHS = 50
PHASE2_EPOCHS = 150
PHASE3_EPOCHS = 30

# ============================================================================

print(f"Input shape: {INPUT_SHAPE}")
print(f"Classes: {NUM_CLASSES}, Prototypes: {N_PROTOTYPES}")

## Step 5: Create Data Generators

In [None]:
# Settings that are identical for the training and validation generators;
# keeping them in one place guarantees both are built from the same
# split parameters and random seed.
shared_generator_kwargs = {
    'batch_size': BATCH_SIZE,
    'num_slices': NUM_SLICES,
    'num_volumes': NUM_VOLUMES,
    'split_ratio': SPLIT_RATIO,
    'random_state': RANDOM_STATE,
}

# Training subset, with shuffling enabled.
train_generator = MRIDataGenerator(
    DATA_PATH, subset='train', shuffle=True, **shared_generator_kwargs
)

# Validation subset, with shuffling disabled.
val_generator = MRIDataGenerator(
    DATA_PATH, subset='val', shuffle=False, **shared_generator_kwargs
)

print(f"Training batches: {len(train_generator)}")
print(f"Validation batches: {len(val_generator)}")

## Step 6: Build Model

In [None]:
# Collect the architecture settings from the configuration cell plus the
# fixed distance/activation choices, then build the network from them.
model_settings = {
    'input_shape': INPUT_SHAPE,
    'num_classes': NUM_CLASSES,
    'n_prototypes': N_PROTOTYPES,
    'backbone_channels': BACKBONE_CHANNELS,
    'aspp_out_channels': ASPP_OUT_CHANNELS,
    'dilation_rates': DILATION_RATES,
    'distance_type': 'l2',
    'activation_function': 'log',
}
model = create_prototype_segnet3d(**model_settings)

# Initialize weights: push one all-zero batch of size 1 through the model in
# inference mode before asking for the summary.
dummy_input = tf.zeros((1, *INPUT_SHAPE))
_ = model(dummy_input, training=False)

print("Model built.")
model.summary()

## Step 7: Setup Google Drive for Checkpoints

In [None]:
from google.colab import drive
import datetime

drive.mount('/content/drive')

timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
DRIVE_CHECKPOINT_DIR = f'/content/drive/MyDrive/prototype_segnet_checkpoints/{timestamp}'
LOCAL_CHECKPOINT_DIR = '/content/checkpoints'

!mkdir -p {DRIVE_CHECKPOINT_DIR}
!mkdir -p {LOCAL_CHECKPOINT_DIR}

print(f"Drive checkpoint dir: {DRIVE_CHECKPOINT_DIR}")

## Step 8: Create Trainer

In [None]:
# Wire the model and both data generators into the trainer. The trainer is
# given the local (non-Drive) checkpoint directory.
trainer_settings = {
    'model': model,
    'train_generator': train_generator,
    'val_generator': val_generator,
    'checkpoint_dir': LOCAL_CHECKPOINT_DIR,
}
trainer = PrototypeTrainer(**trainer_settings)

print("Trainer created.")

## Step 9: Phase 1 - Warm-up Training

In [None]:
# Run phase 1 for the configured number of epochs.
trainer.train_phase1(epochs=PHASE1_EPOCHS)

# Save to Google Drive: snapshot the model as it stands after this phase.
phase1_model_path = f'{DRIVE_CHECKPOINT_DIR}/model_after_phase1.keras'
model.save(phase1_model_path)
print("Phase 1 model saved to Google Drive.")

## Step 10: Phase 2 - Joint Fine-tuning

In [None]:
# Run phase 2 for the configured number of epochs.
trainer.train_phase2(epochs=PHASE2_EPOCHS)

# Save to Google Drive: snapshot the model as it stands after this phase.
phase2_model_path = f'{DRIVE_CHECKPOINT_DIR}/model_after_phase2.keras'
model.save(phase2_model_path)
print("Phase 2 model saved to Google Drive.")

## Step 11: Phase 3 - Prototype Projection & Refinement

In [None]:
# Run phase 3 for the configured number of epochs.
trainer.train_phase3(epochs=PHASE3_EPOCHS)

# Save final model to Google Drive.
final_model_path = f'{DRIVE_CHECKPOINT_DIR}/model_final.keras'
model.save(final_model_path)
print("Final model saved to Google Drive.")

## Step 12: Save Training History

In [None]:
import json

history = trainer.get_full_history()

with open(f'{DRIVE_CHECKPOINT_DIR}/training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print(f"Training complete. All files saved to: {DRIVE_CHECKPOINT_DIR}")
!ls -la {DRIVE_CHECKPOINT_DIR}

## Step 13: Plot Training Metrics

Visualize Dice scores, purity ratios, and loss curves across all training phases.

In [None]:
import matplotlib.pyplot as plt

# Get plottable history
history = trainer.get_full_history()

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Line styles shared by several curves.
_thin = {'alpha': 0.8}
_mean = {'linewidth': 2, 'color': 'black'}

# One entry per subplot: (grid position, y-axis label, title, curves), where
# each curve is (history key, keyword arguments passed to Axes.plot).
panel_specs = [
    ((0, 0), 'Dice Score', 'Dice Scores by Class', [
        ('dice_gd_enhancing', {'label': 'GD-Enhancing', **_thin}),
        ('dice_edema', {'label': 'Edema', **_thin}),
        ('dice_necrotic', {'label': 'Necrotic', **_thin}),
        ('dice_mean', {'label': 'Mean', **_mean}),
    ]),
    ((0, 1), 'Purity Ratio', 'Prototype Purity Ratios', [
        ('purity_proto_0', {'label': 'Proto 0 (GD-Enh)', **_thin}),
        ('purity_proto_1', {'label': 'Proto 1 (Edema)', **_thin}),
        ('purity_proto_2', {'label': 'Proto 2 (Necrotic)', **_thin}),
        ('purity_mean', {'label': 'Mean', **_mean}),
    ]),
    ((1, 0), 'Loss', 'Training and Validation Loss', [
        ('train_loss', {'label': 'Train Loss', **_thin}),
        ('val_loss', {'label': 'Val Loss', **_thin}),
    ]),
    ((1, 1), 'Dice Score', 'Whole Tumor Dice Score', [
        ('dice_whole_tumor', {'label': 'Whole Tumor Dice', 'color': 'green', 'linewidth': 2}),
    ]),
]


def _draw_panel(ax, hist, curves, ylabel, title):
    """Plot the given history curves on *ax* and apply the common decoration."""
    for key, plot_kwargs in curves:
        ax.plot(hist['epochs'], hist[key], **plot_kwargs)
    # Dashed vertical line at every phase boundary.
    for boundary in hist['phase_boundaries']:
        ax.axvline(x=boundary, color='gray', linestyle='--', alpha=0.5)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


def _label_phase(ax, x_center, text):
    """Write a phase name at *x_center*, at 0.95 times the current upper y-limit."""
    ax.text(x_center, ax.get_ylim()[1] * 0.95,
            text, ha='center', fontsize=9, alpha=0.7)


for (row, col), ylabel, title, curves in panel_specs:
    _draw_panel(axes[row, col], history, curves, ylabel, title)

# Add phase labels, centred horizontally within each phase's epoch range.
phase_boundaries = history['phase_boundaries']
if phase_boundaries:
    for ax in axes.flat:
        _label_phase(ax, phase_boundaries[0] / 2, 'Phase 1')
        # Phases 2 and 3 are labelled only when a second boundary exists.
        if len(phase_boundaries) > 1:
            _label_phase(ax, (phase_boundaries[0] + phase_boundaries[1]) / 2, 'Phase 2')
            _label_phase(ax, (phase_boundaries[1] + max(history['epochs'])) / 2, 'Phase 3')

plt.tight_layout()
plt.savefig(f'{DRIVE_CHECKPOINT_DIR}/training_metrics.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Plot saved to: {DRIVE_CHECKPOINT_DIR}/training_metrics.png")