# Prototype-Based 3D Brain Tumor Segmentation - Training

Three-phase training for PrototypeSegNet3D:
1. **Phase 1**: Warm-up (frozen backbone)
2. **Phase 2**: Joint fine-tuning
3. **Phase 3**: Prototype projection & refinement

## Step 1: Clone GitHub Repository

In [None]:
import os

# ============================================================================
# UPDATE THIS WITH YOUR GITHUB REPOSITORY URL
# ============================================================================
GITHUB_REPO_URL = "https://github.com/dariamarc/brainTumorSurvival.git"
# ============================================================================

repo_name = GITHUB_REPO_URL.split('/')[-1].replace('.git', '')

os.chdir('/content')

if os.path.exists(f'/content/{repo_name}'):
    !rm -rf /content/{repo_name}

!git clone {GITHUB_REPO_URL}

os.chdir(f'/content/{repo_name}')
print(f"Working directory: {os.getcwd()}")

## Step 2: Download Data from AWS S3

In [None]:
from google.colab import userdata
import os

# ============================================================================
# UPDATE THESE WITH YOUR S3 DETAILS
# ============================================================================
S3_BUCKET = 'your-brats2020-data'
S3_PATH = 'preprocessed_data'
AWS_REGION = 'eu-central-1'
# ============================================================================

LOCAL_PATH = '/content/brainTumorData_preprocessed'

# Load AWS credentials from Colab Secrets
os.environ['AWS_ACCESS_KEY_ID'] = userdata.get('AWS_ACCESS_KEY_ID')
os.environ['AWS_SECRET_ACCESS_KEY'] = userdata.get('AWS_SECRET_ACCESS_KEY')
os.environ['AWS_DEFAULT_REGION'] = AWS_REGION

!mkdir -p {LOCAL_PATH}
!aws s3 sync s3://{S3_BUCKET}/{S3_PATH} {LOCAL_PATH}

DATA_PATH = LOCAL_PATH
print(f"Data downloaded to: {DATA_PATH}")

## Step 3: Import Modules

In [None]:
import sys
import tensorflow as tf

# Directories inside the cloned repository that are added to the import path.
repo_dir = f'/content/{repo_name}'
resnet_dir = f'{repo_dir}/ResNet_architecture'

# Prepend both entries in one step; resnet_dir ends up ahead of repo_dir,
# and both are searched before anything already on sys.path.
sys.path[:0] = [resnet_dir, repo_dir]

# Project modules, resolved through the paths added above.
from data_processing.data_generator import MRIDataGenerator
from prototype_segnet3d import create_prototype_segnet3d
from trainer import PrototypeTrainer

print("Modules imported.")

## Step 4: Configuration

In [None]:
# ============================================================================
# TRAINING CONFIGURATION
# ============================================================================

# --- Dataset / loader settings ---
NUM_VOLUMES = 369
NUM_SLICES = 128  # Preprocessed data has 128 slices per volume
BATCH_SIZE = 1
SPLIT_RATIO = 0.2
RANDOM_STATE = 42

# --- Volume geometry (must match preprocessed data) ---
# Preprocessing crops from (155, 240, 240) to (128, 160, 192).
# Axis order: depth (number of slices), height, width, channels.
# The four channels are FLAIR, T1, T1ce, T2.
D, H, W, C = 128, 160, 192, 4
INPUT_SHAPE = (D, H, W, C)

# --- Segmentation targets ---
NUM_CLASSES, N_PROTOTYPES = 4, 3

# --- Network hyperparameters ---
BACKBONE_CHANNELS = 64
ASPP_OUT_CHANNELS = 256
DILATION_RATES = (2, 4, 8)

# --- Epoch budget for each of the three training phases ---
PHASE1_EPOCHS, PHASE2_EPOCHS, PHASE3_EPOCHS = 50, 150, 30

# ============================================================================

print(f"Input shape: {INPUT_SHAPE}")
print(f"Classes: {NUM_CLASSES}, Prototypes: {N_PROTOTYPES}")

## Step 5: Create Data Generators

In [None]:
# Settings shared by both generators; only `subset` and `shuffle` differ.
generator_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_slices=NUM_SLICES,
    num_volumes=NUM_VOLUMES,
    split_ratio=SPLIT_RATIO,
    random_state=RANDOM_STATE,
)

# Training subset, shuffled.
train_generator = MRIDataGenerator(
    DATA_PATH, subset='train', shuffle=True, **generator_kwargs
)

# Validation subset, not shuffled.
val_generator = MRIDataGenerator(
    DATA_PATH, subset='val', shuffle=False, **generator_kwargs
)

print(f"Training batches: {len(train_generator)}")
print(f"Validation batches: {len(val_generator)}")

## Step 6: Build Model

In [None]:
# Hyperparameters forwarded to the model factory.
model_config = {
    'input_shape': INPUT_SHAPE,
    'num_classes': NUM_CLASSES,
    'n_prototypes': N_PROTOTYPES,
    'backbone_channels': BACKBONE_CHANNELS,
    'aspp_out_channels': ASPP_OUT_CHANNELS,
    'dilation_rates': DILATION_RATES,
    'distance_type': 'l2',
    'activation_function': 'log',
}
model = create_prototype_segnet3d(**model_config)

# Initialize weights by passing a single all-zeros batch through the model
# in inference mode; the output itself is discarded.
dummy_input = tf.zeros((1, *INPUT_SHAPE))
_ = model(dummy_input, training=False)

print("Model built.")
model.summary()

## Step 7: Setup Google Drive for Checkpoints

In [None]:
from google.colab import drive
import datetime

drive.mount('/content/drive')

timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
DRIVE_CHECKPOINT_DIR = f'/content/drive/MyDrive/prototype_segnet_checkpoints/{timestamp}'
LOCAL_CHECKPOINT_DIR = '/content/checkpoints'

!mkdir -p {DRIVE_CHECKPOINT_DIR}
!mkdir -p {LOCAL_CHECKPOINT_DIR}

print(f"Drive checkpoint dir: {DRIVE_CHECKPOINT_DIR}")

## Step 8: Create Trainer

In [None]:
# Wire the model and both data generators into the trainer; the trainer is
# given the local checkpoint directory (not the Drive one).
trainer_config = dict(
    model=model,
    train_generator=train_generator,
    val_generator=val_generator,
    checkpoint_dir=LOCAL_CHECKPOINT_DIR,
)
trainer = PrototypeTrainer(**trainer_config)

print("Trainer created.")

## Step 9: Phase 1 - Warm-up Training

In [None]:
# Phase 1: run the warm-up stage for the configured number of epochs.
trainer.train_phase1(epochs=PHASE1_EPOCHS)

# Save the model as it stands after phase 1 into the Drive run folder.
phase1_model_path = f'{DRIVE_CHECKPOINT_DIR}/model_after_phase1.keras'
model.save(phase1_model_path)
print("Phase 1 model saved to Google Drive.")

## Step 10: Phase 2 - Joint Fine-tuning

In [None]:
# Phase 2: run the joint fine-tuning stage for the configured number of epochs.
trainer.train_phase2(epochs=PHASE2_EPOCHS)

# Save the model as it stands after phase 2 into the Drive run folder.
phase2_model_path = f'{DRIVE_CHECKPOINT_DIR}/model_after_phase2.keras'
model.save(phase2_model_path)
print("Phase 2 model saved to Google Drive.")

## Step 11: Phase 3 - Prototype Projection & Refinement

In [None]:
# Phase 3: run the prototype projection & refinement stage.
trainer.train_phase3(epochs=PHASE3_EPOCHS)

# Save the final model into the Drive run folder.
final_model_path = f'{DRIVE_CHECKPOINT_DIR}/model_final.keras'
model.save(final_model_path)
print("Final model saved to Google Drive.")

## Step 12: Save Training History

In [None]:
import json

history = trainer.get_full_history()

with open(f'{DRIVE_CHECKPOINT_DIR}/training_history.json', 'w') as f:
    json.dump(history, f, indent=2)

print(f"Training complete. All files saved to: {DRIVE_CHECKPOINT_DIR}")
!ls -la {DRIVE_CHECKPOINT_DIR}