# 3D Brain Tumor Segmentation with MProtoNet - AWS S3 Deployment

This notebook sets up and trains your 3D brain tumor segmentation network in Google Colab, using code from GitHub and data from AWS S3.

**Quick Setup:**
1. Upload this notebook to Google Colab
2. Go to Runtime → Change runtime type → Select **GPU (T4 or better)**
3. Add AWS credentials to Colab Secrets (🔑 icon on left)
4. Update GitHub repository URL and S3 bucket details
5. Run all cells

---

## Step 1: Check GPU and System Info

In [None]:
import tensorflow as tf
import sys

print("=" * 60)
print("SYSTEM INFORMATION")
print("=" * 60)
print(f"Python version: {sys.version}")
print(f"TensorFlow version: {tf.__version__}")
print(f"\nGPU Devices: {tf.config.list_physical_devices('GPU')}")

# Enable memory growth to prevent OOM errors
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        print(f"\n✓ Memory growth enabled for {len(gpus)} GPU(s)")
        print(f"✓ GPU: {gpus[0].name}")
    except RuntimeError as e:
        print(f"Error enabling memory growth: {e}")
else:
    print("\n⚠️ WARNING: No GPU detected! Please enable GPU in Runtime → Change runtime type")

print("=" * 60)

SYSTEM INFORMATION
Python version: 3.12.12 (main, Oct 10 2025, 08:52:57) [GCC 11.4.0]
TensorFlow version: 2.19.0

GPU Devices: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

✓ Memory growth enabled for 1 GPU(s)
✓ GPU: /physical_device:GPU:0


## Step 2: Check Available Resources

In [None]:
print("=" * 60)
print("AVAILABLE RESOURCES")
print("=" * 60)

print("\n📦 Disk Space:")
!df -h /content | grep -E 'Filesystem|/content'

print("\n🧠 RAM:")
!free -h | grep -E 'total|Mem'

print("\n🎮 GPU Memory:")
!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv

print("\n" + "=" * 60)
print("Note: Your 7.95 GB dataset will use disk space, not RAM")
print("Only one batch (~200 MB) is loaded in RAM at a time")
print("=" * 60)

AVAILABLE RESOURCES

📦 Disk Space:
Filesystem      Size  Used Avail Use% Mounted on

🧠 RAM:
               total        used        free      shared  buff/cache   available
Mem:            52Gi       1.4Gi        47Gi       1.0Mi       3.7Gi        50Gi

🎮 GPU Memory:
name, memory.total [MiB], memory.free [MiB]
NVIDIA L4, 23034 MiB, 22689 MiB

Note: Your 7.95 GB dataset will use disk space, not RAM
Only one batch (~200 MB) is loaded in RAM at a time
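
The per-batch figure depends on the array dtype and on whether the masks are counted. The following is a back-of-the-envelope sketch, assuming float32 arrays and the dimensions configured later in this notebook (155 × 240 × 240 voxels, 4 image channels, 3 mask channels). The dtype is an assumption; the data generator may store values differently.

```python
# Rough host-RAM footprint of one batch, assuming float32 arrays.
# Dimensions follow the notebook's configuration; the dtype is an assumption.
BYTES_PER_FLOAT32 = 4

def volume_bytes(depth, height, width, channels, bytes_per_value=BYTES_PER_FLOAT32):
    """Size in bytes of one dense (depth, height, width, channels) array."""
    return depth * height * width * channels * bytes_per_value

def batch_megabytes(batch_size, image_channels=4, mask_channels=3,
                    depth=155, height=240, width=240):
    """Approximate size in MB (10^6 bytes) of one batch of images plus masks."""
    per_sample = (volume_bytes(depth, height, width, image_channels)
                  + volume_bytes(depth, height, width, mask_channels))
    return batch_size * per_sample / 1e6

for bs in (1, 2):
    print(f"batch_size={bs}: ~{batch_megabytes(bs):.0f} MB")
```

Under these assumptions a single sample is about 250 MB and a batch of two about 500 MB, which is still small next to the host RAM reported above.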


## Step 3: Install Dependencies

In [None]:
%%capture
# Silent installation - remove %%capture to see output
!pip install h5py numpy tensorflow keras matplotlib awscli boto3 -q
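
Because the install runs silently, it can help to confirm afterwards which versions are actually present. A minimal sketch using only the standard library; the helper name `installed_versions` is hypothetical, and the package list simply mirrors the `pip install` line above.

```python
from importlib.metadata import PackageNotFoundError, version

def installed_versions(packages):
    """Map each distribution name to its installed version, or None if absent."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found

packages = ["h5py", "numpy", "tensorflow", "keras", "matplotlib", "awscli", "boto3"]
for name, ver in installed_versions(packages).items():
    print(f"{name:12s} {ver or 'NOT INSTALLED'}")
```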

## Step 4: Clone GitHub Repository

**⚠️ IMPORTANT: Update the repository URL below with your GitHub repository!**

In [None]:
import os

# ============================================================================
# UPDATE THIS WITH YOUR GITHUB REPOSITORY URL
# ============================================================================
GITHUB_REPO_URL = "https://github.com/dariamarc/brainTumorSurvival.git"
# ============================================================================

# Repository name (extracted from URL)
repo_name = GITHUB_REPO_URL.split('/')[-1].replace('.git', '')

print(f"Cloning repository: {GITHUB_REPO_URL}")
print(f"Repository name: {repo_name}")
print("-" * 60)

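# Remove if exists (for re-running). Move to /content first so the working
# directory stays valid and the clone always lands in /content.
os.chdir('/content')
if os.path.exists(f'/content/{repo_name}'):
    !rm -rf /content/{repo_name}
    print(f"Removed existing directory: {repo_name}")

# Clone the repository
!git clone {GITHUB_REPO_URL}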

# Change to repository directory
os.chdir(f'/content/{repo_name}')
print(f"\n✓ Changed to directory: {os.getcwd()}")

# List files to verify
print("\nRepository contents:")
!ls -la

Cloning repository: https://github.com/dariamarc/brainTumorSurvival.git
Repository name: brainTumorSurvival
------------------------------------------------------------
Cloning into 'brainTumorSurvival'...
remote: Enumerating objects: 45, done.
remote: Counting objects: 100% (45/45), done.
remote: Compressing objects: 100% (34/34), done.
remote: Total 45 (delta 19), reused 36 (delta 10), pack-reused 0 (from 0)
Receiving objects: 100% (45/45), 238.36 KiB | 1.34 MiB/s, done.
Resolving deltas: 100% (19/19), done.

✓ Changed to directory: /content/brainTumorSurvival

Repository contents:
total 388
drwxr-xr-x 4 root root   4096 Nov  5 17:47 .
drwxr-xr-x 1 root root   4096 Nov  5 17:47 ..
-rw-r--r-- 1 root root   7316 Nov  5 17:47 data_generator.py
-rw-r--r-- 1 root root 332985 Nov  5 17:47 data_processing.ipynb
drwxr-xr-x 8 root root   4096 Nov  5 17:47 .git
-rw-r--r-- 1 root root     82 Nov  5 17:47 .gitignore
drwxr-xr-x 3 root root   4096 Nov  5 17:47 .idea
-rw-r--r-- 1 root
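
The `split('/')[-1].replace('.git', '')` expression in the cell above works for the URL shown, but it returns an empty string when the URL ends with a slash, and `replace` would also strip a `.git` that appears in the middle of a name. A slightly more defensive sketch follows; the helper name `repo_name_from_url` is hypothetical and not part of the repository.

```python
from urllib.parse import urlparse

def repo_name_from_url(url):
    """Extract the repository directory name that `git clone` would create."""
    path = urlparse(url).path.rstrip("/")   # tolerate a trailing slash
    name = path.rsplit("/", 1)[-1]
    if name.endswith(".git"):
        name = name[:-len(".git")]          # strip only a trailing .git
    if not name:
        raise ValueError(f"Could not extract a repository name from {url!r}")
    return name

print(repo_name_from_url("https://github.com/dariamarc/brainTumorSurvival.git"))
```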

## Step 5: Verify Required Files

In [None]:
required_files = ['model.py', 'data_generator.py', 'losses.py']
optional_files = ['main.py', 'train.py']

print("Checking required files...")
print("=" * 60)

all_present = True
for file in required_files:
    if os.path.exists(file):
        print(f"✓ {file} - Found")
    else:
        print(f"✗ {file} - MISSING")
        all_present = False

print("\nChecking optional files...")
for file in optional_files:
    if os.path.exists(file):
        print(f"✓ {file} - Found")
    else:
        print(f"- {file} - Not present (optional)")

print("=" * 60)
if all_present:
    print("✓ All required files present! Ready to proceed.")
else:
    print("⚠️ WARNING: Some required files are missing!")
    print("Please check your repository structure.")

Checking required files...
✓ model.py - Found
✓ data_generator.py - Found
✓ losses.py - Found

Checking optional files...
✓ main.py - Found
- train.py - Not present (optional)
✓ All required files present! Ready to proceed.


## Step 6: Configure AWS Credentials

**IMPORTANT SECURITY STEPS:**

1. Click the **🔑 Secrets** icon in the left sidebar
2. Add these secrets:
   - Name: `AWS_ACCESS_KEY_ID`, Value: Your AWS access key
   - Name: `AWS_SECRET_ACCESS_KEY`, Value: Your AWS secret key
3. Enable "Notebook access" for both secrets

**Never hardcode credentials in notebooks!**

In [None]:
from google.colab import userdata
import os

print("Configuring AWS credentials...")
print("-" * 60)

try:
    # Get credentials from Colab Secrets
    os.environ['AWS_ACCESS_KEY_ID'] = userdata.get('AWS_ACCESS_KEY_ID')
    os.environ['AWS_SECRET_ACCESS_KEY'] = userdata.get('AWS_SECRET_ACCESS_KEY')

    print("✓ AWS credentials loaded from Colab Secrets")
    print("✓ Access Key ID: " + os.environ['AWS_ACCESS_KEY_ID'][:8] + "...")

except Exception as e:
    print("✗ Error loading AWS credentials from Colab Secrets")
    print(f"Error: {e}")
    print("\nPlease add AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY to Colab Secrets (🔑 icon)")
    raise

Configuring AWS credentials...
------------------------------------------------------------
✓ AWS credentials loaded from Colab Secrets
✓ Access Key ID: AKIAYAYR...
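
Echoing the first eight characters of the access key ID is more than is needed to confirm which key was loaded, and notebook outputs are often shared. A minimal sketch of a stricter check, assuming the two environment variables set in the cell above; `missing_aws_vars` and `mask_secret` are hypothetical helper names.

```python
import os

REQUIRED_VARS = ("AWS_ACCESS_KEY_ID", "AWS_SECRET_ACCESS_KEY")

def missing_aws_vars(env=os.environ):
    """Return the required variable names that are unset or empty."""
    return [name for name in REQUIRED_VARS if not env.get(name)]

def mask_secret(value, visible=4):
    """Show only the last `visible` characters of a credential."""
    if len(value) <= visible:
        return "*" * len(value)
    return "*" * (len(value) - visible) + value[-visible:]

missing = missing_aws_vars()
if missing:
    print("Missing credentials:", ", ".join(missing))
else:
    print("Access Key ID:", mask_secret(os.environ["AWS_ACCESS_KEY_ID"]))
```

This only confirms that the variables are populated; it does not verify that AWS accepts them.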


## Step 7: Download Dataset from AWS S3

**⚠️ UPDATE YOUR S3 BUCKET DETAILS BELOW**

This will download your 7.95 GB dataset to Colab's local storage.  
Estimated time: 10-15 minutes

In [None]:
# ============================================================================
# UPDATE THESE WITH YOUR S3 DETAILS
# ============================================================================
S3_BUCKET = 'your-brats2020-data'           # Your S3 bucket name
S3_PATH = 'archive/BraTS2020_training_data/content/data'               # Path to data in S3 (no leading/trailing slashes)
AWS_REGION = 'eu-central-1'                 # Your bucket's region
# ============================================================================

LOCAL_PATH = '/content/brainTumorData'

print("=" * 60)
print("DOWNLOADING DATASET FROM AWS S3")
print("=" * 60)

# Set AWS region
os.environ['AWS_DEFAULT_REGION'] = AWS_REGION

print(f"\nSource: s3://{S3_BUCKET}/{S3_PATH}")
print(f"Destination: {LOCAL_PATH}")
print(f"Region: {AWS_REGION}")
print(f"Dataset size: 7.95 GB")
print(f"Estimated time: 10-15 minutes")
print("-" * 60)
print("Starting download...\n")

# Create local directory
!mkdir -p {LOCAL_PATH}

# Download using AWS CLI sync (shows progress)
!aws s3 sync s3://{S3_BUCKET}/{S3_PATH} {LOCAL_PATH}

# Verify download. LOCAL_PATH always exists after `mkdir -p`, so test for
# downloaded files rather than for the directory itself.
print("\n" + "=" * 60)
file_count = sum(len(files) for _, _, files in os.walk(LOCAL_PATH))
if file_count > 0:

    # Calculate size
    total_size = sum(
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, dirnames, filenames in os.walk(LOCAL_PATH)
        for filename in filenames
    ) / (1024**3)  # Convert to GB

    print("✓ DOWNLOAD COMPLETE!")
    print("=" * 60)
    print(f"Location: {LOCAL_PATH}")
    print(f"Files downloaded: {file_count:,}")
    print(f"Total size: {total_size:.2f} GB")

    # Show sample files
    print("\nSample files:")
    !ls {LOCAL_PATH} | head -10

    # Check disk space after download
    print("\n📦 Disk Usage After Download:")
    !df -h /content | grep -E 'Filesystem|/content'

    # Set data path
    DATA_PATH = LOCAL_PATH
    print(f"\n✓ DATA_PATH set to: {DATA_PATH}")
else:
    print("✗ DOWNLOAD FAILED!")
    print("Please check:")
    print("  1. S3 bucket name is correct")
    print("  2. S3 path is correct")
    print("  3. AWS credentials have read permissions")
    print("  4. AWS region is correct")
    raise FileNotFoundError(f"Data not found at {LOCAL_PATH}")

print("=" * 60)

Streaming output truncated to the last 5000 lines.
download: s3://your-brats2020-data/archive/BraTS2020_training_data/content/data/volume_70_slice_87.h5 to ../brainTumorData/volume_70_slice_87.h5
download: s3://your-brats2020-data/archive/BraTS2020_training_data/content/data/volume_70_slice_83.h5 to ../brainTumorData/volume_70_slice_83.h5
download: s3://your-brats2020-data/archive/BraTS2020_training_data/content/data/volume_70_slice_89.h5 to ../brainTumorData/volume_70_slice_89.h5
download: s3://your-brats2020-data/archive/BraTS2020_training_data/content/data/volume_70_slice_9.h5 to ../brainTumorData/volume_70_slice_9.h5
download: s3://your-brats2020-data/archive/BraTS2020_training_data/content/data/volume_70_slice_90.h5 to ../brainTumorData/volume_70_slice_90.h5
download: s3://your-brats2020-data/archive/BraTS2020_training_data/content/data/volume_70_slice_85.h5 to ../brainTumorData/volume_70_slice_85.h5
download: s3://your-brats2020-data/archive/BraTS2020_training_data/
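
A file count alone does not show whether every volume arrived complete. The sketch below groups the downloaded files by volume and reports any volume whose slice count differs from the expected 155. It assumes the `volume_<id>_slice_<n>.h5` naming visible in the download log; the helper name `incomplete_volumes` is hypothetical.

```python
import re
from collections import defaultdict

# File naming taken from the download log, e.g. volume_70_slice_87.h5
SLICE_FILE = re.compile(r"^volume_(\d+)_slice_(\d+)\.h5$")

def incomplete_volumes(filenames, slices_per_volume=155):
    """Return {volume_id: slice_count} for volumes without the expected slice count."""
    slices = defaultdict(set)
    for name in filenames:
        match = SLICE_FILE.match(name)
        if match:  # ignore files that do not follow the slice naming scheme
            slices[int(match.group(1))].add(int(match.group(2)))
    return {vol: len(found) for vol, found in sorted(slices.items())
            if len(found) != slices_per_volume}

# Tiny self-contained demonstration with made-up file names
demo = [f"volume_1_slice_{i}.h5" for i in range(155)]
demo += [f"volume_2_slice_{i}.h5" for i in range(10)]
print(incomplete_volumes(demo))
```

In the notebook this could be called as `incomplete_volumes(os.listdir(LOCAL_PATH))`; an empty result means every volume found has 155 slices.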

## Step 8: Import Modules

In [None]:
import sys

# Ensure repository is in Python path
repo_dir = f'/content/{repo_name}'
if repo_dir not in sys.path:
    sys.path.insert(0, repo_dir)

print(f"Python path includes: {repo_dir}")
print(f"Working directory: {os.getcwd()}")
print("-" * 60)

# Import your modules
try:
    from model import MProtoNet3D_Segmentation_Keras
    from data_generator import MRIDataGenerator
    from losses import FocalLoss, CombinedLoss
    from tensorflow import keras
    import numpy as np

    print("✓ All modules imported successfully!")
except ImportError as e:
    print(f"✗ Import error: {e}")
    print("\nDebugging info:")
    print("Files in repository:")
    !ls -la
    raise

## Step 9: Training Configuration

Adjust these parameters based on your needs.

In [None]:
# ============================================================================
# TRAINING CONFIGURATION - ADJUST AS NEEDED
# ============================================================================

# Data configuration
BATCH_SIZE = 2          # Batch size of 2 for better batch normalization stability
SPLIT_RATIO = 0.2       # 20% for validation
RANDOM_STATE = 42
NUM_VOLUMES = 369       # Total number of volumes (adjust if different)

# Volume dimensions
D = 155                 # Depth (number of slices)
H = 240                 # Height
W = 240                 # Width
C = 4                   # Channels (FLAIR, T1, T1ce, T2)

# Model configuration
NUM_CLASSES = 3         # GD enhancing tumor, peritumoral edema, non-enhancing tumor core
PROTOTYPE_SHAPE = (21, 128, 1, 1, 1)  # 21/3 = 7 prototypes per class

# Training settings
EPOCHS = 100            # Number of training epochs (early stopping will handle when to stop)
LEARNING_RATE = 0.0001  # Initial learning rate

# ============================================================================

INPUT_SHAPE = (D, H, W, C)

print("=" * 60)
print("TRAINING CONFIGURATION")
print("=" * 60)
print(f"Data path: {DATA_PATH}")
print(f"Input shape: {INPUT_SHAPE}")
print(f"Number of classes: {NUM_CLASSES}")
print(f"Prototypes per class: {PROTOTYPE_SHAPE[0] // NUM_CLASSES}")
print(f"Batch size: {BATCH_SIZE}")
print(f"Epochs: {EPOCHS}")
print(f"Learning rate: {LEARNING_RATE}")
print(f"Total volumes: {NUM_VOLUMES}")
print(f"Train/Val split: {int((1-SPLIT_RATIO)*100)}% / {int(SPLIT_RATIO*100)}%")
print("=" * 60)

## Step 10: Create Data Generators

In [12]:
print("Creating data generators...")
print("-" * 60)

train_generator = MRIDataGenerator(
    DATA_PATH,
    batch_size=BATCH_SIZE,
    num_slices=D,
    num_volumes=NUM_VOLUMES,
    split_ratio=SPLIT_RATIO,
    subset='train',
    shuffle=True,
    random_state=RANDOM_STATE
)

validation_generator = MRIDataGenerator(
    DATA_PATH,
    batch_size=BATCH_SIZE,
    num_slices=D,
    num_volumes=NUM_VOLUMES,
    split_ratio=SPLIT_RATIO,
    subset='val',
    shuffle=False,
    random_state=RANDOM_STATE
)

print(f"\n✓ Training batches: {len(train_generator)}")
print(f"✓ Validation batches: {len(validation_generator)}")
# Rough estimate only: assumes ~2.5 minutes per batch on a T4; other GPUs will differ
print(f"\nEstimated time per epoch (T4 GPU): ~{len(train_generator) * 2.5:.0f} minutes")
print(f"Total estimated training time: ~{len(train_generator) * 2.5 * EPOCHS / 60:.1f} hours")

Creating data generators...
------------------------------------------------------------
MRIDataGenerator: Initializing for H5 files from: /content/brainTumorData
MRIDataGenerator: Found 369 unique volume IDs (0 to 368).
MRIDataGenerator: train subset has 295 volumes (each containing 155 slices).
MRIDataGenerator: Initializing for H5 files from: /content/brainTumorData
MRIDataGenerator: Found 369 unique volume IDs (0 to 368).
MRIDataGenerator: val subset has 74 volumes (each containing 155 slices).

✓ Training batches: 147
✓ Validation batches: 37

Estimated time per epoch (T4 GPU): ~368 minutes
Total estimated training time: ~61.2 hours
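
The logged numbers can be reproduced by hand. The sketch below assumes the validation count is rounded up and that incomplete final batches are dropped; these rounding rules were chosen because they match the output above and are an assumption about `MRIDataGenerator`, not a reading of its source.

```python
import math

def split_and_batches(num_volumes, split_ratio, batch_size):
    """Volume and batch counts for a train/validation split (assumed rounding)."""
    n_val = math.ceil(num_volumes * split_ratio)   # assumed: validation rounds up
    n_train = num_volumes - n_val
    return {
        "train_volumes": n_train,
        "val_volumes": n_val,
        "train_batches": n_train // batch_size,    # assumed: partial batch dropped
        "val_batches": n_val // batch_size,
    }

print(split_and_batches(num_volumes=369, split_ratio=0.2, batch_size=2))
```

If the partial-batch assumption holds, 295 training volumes with a batch size of 2 leave one volume unused in each epoch.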


## Step 11: Build and Compile Model

In [None]:
print("Building model...")
print("-" * 60)

# Build the MProtoNet3D model
model = MProtoNet3D_Segmentation_Keras(
    in_size=INPUT_SHAPE,
    num_classes=NUM_CLASSES,
    prototype_shape=PROTOTYPE_SHAPE,
    features='resnet50_ri',
    f_dist='l2'
)

print("✓ Model architecture created!")

# Build the model explicitly to enable parameter counting
print("Initializing model layers...")
model.build(input_shape=(None,) + INPUT_SHAPE)
print("✓ Model built successfully!")

# Setup optimizer and loss
optimizer = keras.optimizers.Adam(learning_rate=LEARNING_RATE)

# Use Combined Focal + Dice Loss (RECOMMENDED for medical segmentation)
loss_fn = CombinedLoss(focal_weight=0.5, dice_weight=0.5, gamma=1.0, alpha=0.25)

# Compile model with comprehensive metrics
model.compile(
    optimizer=optimizer,
    loss=loss_fn,
    metrics=[
        'accuracy',
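        # Masks and predictions are one-hot / per-class along the last axis,
        # so use the one-hot variant (plain MeanIoU expects integer class labels)
        keras.metrics.OneHotMeanIoU(num_classes=NUM_CLASSES, name='mean_iou'),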
        keras.metrics.Precision(name='precision'),
        keras.metrics.Recall(name='recall')
    ]
)

print("✓ Model compiled successfully!")
print("\nLoss function: Combined Focal + Dice Loss")
print("  - Focal loss gamma: 1.0")
print("  - Focal loss alpha: 0.25")
print("  - Loss weights: 50% Focal + 50% Dice")
print("\nMetrics tracked:")
print("  - Accuracy (overall voxel accuracy)")
print("  - Mean IoU (Intersection over Union per class)")
print("  - Precision & Recall")

# Count parameters
try:
    # count_params() includes trainable and non-trainable weights
    total_params = model.count_params()
    print(f"\nTotal parameters: {total_params:,}")
except Exception:
    print("\nNote: Parameter count will be available after first training step")
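
`CombinedLoss` lives in the repository's `losses.py`, which is not reproduced in this notebook. Purely as an illustration of the idea, here is a binary, pure-Python sketch of how a weighted focal + Dice loss is commonly formed. The formulas and the `smooth` term are standard textbook choices and an assumption, not a copy of the repository's implementation.

```python
import math

def dice_loss(y_true, y_pred, smooth=1.0):
    """Soft Dice loss over flat sequences of 0/1 targets and probabilities."""
    intersection = sum(t * p for t, p in zip(y_true, y_pred))
    total = sum(y_true) + sum(y_pred)
    return 1.0 - (2.0 * intersection + smooth) / (total + smooth)

def focal_loss(y_true, y_pred, gamma=1.0, alpha=0.25, eps=1e-7):
    """Mean binary focal loss; alpha weights positives, (1 - alpha) negatives."""
    losses = []
    for t, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)            # avoid log(0)
        p_t = p if t == 1 else 1.0 - p             # probability given to the true label
        alpha_t = alpha if t == 1 else 1.0 - alpha
        losses.append(-alpha_t * (1.0 - p_t) ** gamma * math.log(p_t))
    return sum(losses) / len(losses)

def combined_loss(y_true, y_pred, focal_weight=0.5, dice_weight=0.5,
                  gamma=1.0, alpha=0.25):
    """Weighted sum of the focal and Dice terms."""
    return (focal_weight * focal_loss(y_true, y_pred, gamma, alpha)
            + dice_weight * dice_loss(y_true, y_pred))

target = [1, 0, 1, 0]
print(combined_loss(target, [0.9, 0.1, 0.8, 0.2]))  # mostly right
print(combined_loss(target, [0.1, 0.9, 0.2, 0.8]))  # mostly wrong
```

The focal term down-weights voxels the model already classifies confidently, while the Dice term rewards overlap regardless of how rare the foreground class is.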

## Step 12: Setup Callbacks and Checkpointing

**IMPORTANT:** We'll save checkpoints to Google Drive for persistence across sessions.

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau, TensorBoard, CSVLogger
import datetime

# Mount Google Drive for checkpoint storage
from google.colab import drive
drive.mount('/content/drive')

# Create directories
checkpoint_dir = '/content/checkpoints'
logs_dir = '/content/logs'
!mkdir -p {checkpoint_dir}
!mkdir -p {logs_dir}

# Google Drive checkpoint directory (for persistence)
drive_checkpoint_dir = '/content/drive/MyDrive/brain_tumor_checkpoints'
!mkdir -p {drive_checkpoint_dir}

timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

callbacks = [
    # Save best model locally
    ModelCheckpoint(
        filepath=f'{checkpoint_dir}/best_model.keras',
        monitor='val_loss',
        save_best_only=True,
        mode='min',
        verbose=1
    ),

    # Save best model to Google Drive (IMPORTANT for persistence)
    ModelCheckpoint(
        filepath=f'{drive_checkpoint_dir}/best_model_{timestamp}.keras',
        monitor='val_loss',
        save_best_only=True,
        mode='min',
        verbose=1
    ),

    # Save periodic checkpoints to Google Drive (every epoch)
    ModelCheckpoint(
        filepath=f'{drive_checkpoint_dir}/checkpoint_epoch_{{epoch:02d}}_{timestamp}.keras',
        save_freq='epoch',
        save_best_only=False,
        verbose=1
    ),

    # Early stopping - wait 10 epochs before stopping
    EarlyStopping(
        monitor='val_loss',
        patience=10,
        restore_best_weights=True,
        verbose=1
    ),

    # Reduce learning rate on plateau - wait 5 epochs before reducing
    ReduceLROnPlateau(
        monitor='val_loss',
        factor=0.5,
        patience=5,
        min_lr=1e-7,
        verbose=1
    ),

    # TensorBoard logging
    TensorBoard(
        log_dir=f'{logs_dir}/{timestamp}',
        histogram_freq=1,
        write_graph=True
    ),

    # CSV Logger
    CSVLogger(
        filename=f'{drive_checkpoint_dir}/training_log_{timestamp}.csv',
        append=True
    )
]

print("✓ Callbacks configured!")
print(f"  - Local checkpoints: {checkpoint_dir}")
print(f"  - Drive backups: {drive_checkpoint_dir}")
print(f"  - TensorBoard logs: {logs_dir}")
print(f"  - Timestamp: {timestamp}")
print("\nCallback settings:")
print("  - EarlyStopping: patience=10 epochs")
print("  - ReduceLROnPlateau: patience=5 epochs, factor=0.5")

## Step 13: Train the Model

**This will take several hours. Keep the browser tab active to prevent disconnection!**

In [15]:
print("=" * 60)
print("STARTING TRAINING")
print("=" * 60)
print(f"Training for {EPOCHS} epochs")
print(f"Estimated total time: ~{len(train_generator) * 2.5 * EPOCHS / 60:.1f} hours (T4 GPU)")
print("\n⚠️ IMPORTANT: Keep this browser tab active to prevent disconnection!")
print("⚠️ Models are being saved to Google Drive automatically")
print("=" * 60)
print()

history = model.fit(
    train_generator,
    epochs=EPOCHS,
    validation_data=validation_generator,
    callbacks=callbacks,
    verbose=1
)

print("\n" + "=" * 60)
print("✓ TRAINING COMPLETED!")
print("=" * 60)

STARTING TRAINING
Training for 10 epochs
Estimated total time: ~61.2 hours (T4 GPU)

⚠️ IMPORTANT: Keep this browser tab active to prevent disconnection!
⚠️ Models are being saved to Google Drive automatically



  self._warn_if_super_not_called()


Padding volume: 155 -> 160 (pad_before=2, pad_after=3)
Padded volume shape: (160, 240, 240, 4)
Padding volume: 155 -> 160 (pad_before=2, pad_after=3)
Padded volume shape: (160, 240, 240, 3)
After padding - Image shape: (160, 240, 240, 4), Mask shape: (160, 240, 240, 3)
Padding volume: 155 -> 160 (pad_before=2, pad_after=3)
Padded volume shape: (160, 240, 240, 4)
Padding volume: 155 -> 160 (pad_before=2, pad_after=3)
Padded volume shape: (160, 240, 240, 3)
After padding - Image shape: (160, 240, 240, 4), Mask shape: (160, 240, 240, 3)
Final batch shapes - Images: (2, 160, 240, 240, 4), Masks: (2, 160, 240, 240, 3)
Padding volume: 155 -> 160 (pad_before=2, pad_after=3)
Padded volume shape: (160, 240, 240, 4)
Padding volume: 155 -> 160 (pad_before=2, pad_after=3)
Padded volume shape: (160, 240, 240, 3)
After padding - Image shape: (160, 240, 240, 4), Mask shape: (160, 240, 240, 3)
Padding volume: 155 -> 160 (pad_before=2, pad_after=3)
Padded volume shape: (160, 240, 240, 4)
Padding volume

UnknownError: Graph execution error:

Detected at node StatefulPartitionedCall defined at (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main

  File "<frozen runpy>", line 88, in _run_code

  File "/usr/local/lib/python3.12/dist-packages/colab_kernel_launcher.py", line 37, in <module>

  File "/usr/local/lib/python3.12/dist-packages/traitlets/config/application.py", line 992, in launch_instance

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelapp.py", line 712, in start

  File "/usr/local/lib/python3.12/dist-packages/tornado/platform/asyncio.py", line 211, in start

  File "/usr/lib/python3.12/asyncio/base_events.py", line 645, in run_forever

  File "/usr/lib/python3.12/asyncio/base_events.py", line 1999, in _run_once

  File "/usr/lib/python3.12/asyncio/events.py", line 88, in _run

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 510, in dispatch_queue

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 499, in process_one

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 406, in dispatch_shell

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/kernelbase.py", line 730, in execute_request

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/ipkernel.py", line 383, in do_execute

  File "/usr/local/lib/python3.12/dist-packages/ipykernel/zmqshell.py", line 528, in run_cell

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 2975, in run_cell

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3030, in _run_cell

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/async_helpers.py", line 78, in _pseudo_sync_runner

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3257, in run_cell_async

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3473, in run_ast_nodes

  File "/usr/local/lib/python3.12/dist-packages/IPython/core/interactiveshell.py", line 3553, in run_code

  File "/tmp/ipython-input-2317829315.py", line 11, in <cell line: 0>

  File "/usr/local/lib/python3.12/dist-packages/keras/src/utils/traceback_utils.py", line 117, in error_handler

  File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/trainer.py", line 377, in fit

  File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/trainer.py", line 220, in function

  File "/usr/local/lib/python3.12/dist-packages/keras/src/backend/tensorflow/trainer.py", line 133, in multi_step_on_iterator

Failed to determine best cudnn convolution algorithm for:
%cudnn-conv-bw-input.16 = (f32[2,64,160,240,240]{4,3,2,1,0}, u8[0]{0}) custom-call(f32[2,16,160,240,240]{4,3,2,1,0} %bitcast.16789, f32[16,64,3,3,3]{4,3,2,1,0} %bitcast.16726), window={size=3x3x3 pad=1_1x1_1x1_1}, dim_labels=bf012_oi012->bf012, custom_call_target="__cudnn$convBackwardInput", metadata={op_type="Conv3DBackpropInputV2" op_name="gradient_tape/m_proto_net3d__segmentation__keras_1/final_processing_1/convolution/Conv3DBackpropInputV2" source_file="/usr/local/lib/python3.12/dist-packages/tensorflow/python/framework/ops.py" source_line=1200}, backend_config={"operation_queue_id":"0","wait_on_operation_queues":[],"cudnn_conv_backend_config":{"conv_result_scale":1,"activation_mode":"kNone","side_input_scale":0,"leakyrelu_alpha":0},"force_earliest_schedule":false}

Original error: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 4735369216 bytes. [tf-allocator-allocation-error='']

To ignore this failure and try to use a fallback algorithm (which may have suboptimal performance), use XLA_FLAGS=--xla_gpu_strict_conv_algorithm_picker=false.  Please also file a bug for the root cause of failing autotuning.
	 [[{{node StatefulPartitionedCall}}]] [Op:__inference_multi_step_on_iterator_12651]
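
The failing allocation can be checked by hand. The tensor named in the error, `f32[2,64,160,240,240]`, holds 2 × 64 × 160 × 240 × 240 float32 values. The sketch below computes its size and shows how it shrinks with a batch size of 1 or with 16-bit activations. The remaining 16 MiB of the 4,735,369,216-byte request is presumably scratch space; that is an inference from the arithmetic, not something the log states.

```python
from math import prod

def tensor_bytes(shape, bytes_per_value=4):
    """Size in bytes of a dense tensor with the given shape."""
    return prod(shape) * bytes_per_value

failing_shape = (2, 64, 160, 240, 240)   # shape from the error message
requested = 4_735_369_216                # bytes from the error message

size = tensor_bytes(failing_shape)
print(f"float32, batch 2: {size / 1024**3:.2f} GiB")
print(f"unexplained remainder: {(requested - size) / 1024**2:.0f} MiB")
print(f"float32, batch 1: {tensor_bytes((1, 64, 160, 240, 240)) / 1024**3:.2f} GiB")
print(f"float16, batch 2: {tensor_bytes(failing_shape, 2) / 1024**3:.2f} GiB")
```

A single intermediate tensor of about 4.4 GiB, on top of everything else held for backpropagation, leaves little headroom on a 23 GB GPU. Options that follow from the arithmetic include setting `BATCH_SIZE = 1`, training on cropped sub-volumes, or enabling mixed precision with `keras.mixed_precision.set_global_policy('mixed_float16')`. Whether this model and loss train stably under mixed precision has not been verified here.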

## Step 14: Visualize Training History

In [None]:
import matplotlib.pyplot as plt

fig, axes = plt.subplots(2, 3, figsize=(18, 10))
fig.suptitle('Training Metrics Over Time', fontsize=16, fontweight='bold')

# Loss plot
axes[0, 0].plot(history.history['loss'], label='Training Loss', linewidth=2)
axes[0, 0].plot(history.history['val_loss'], label='Validation Loss', linewidth=2)
axes[0, 0].set_title('Model Loss Over Time', fontsize=14, fontweight='bold')
axes[0, 0].set_xlabel('Epoch', fontsize=12)
axes[0, 0].set_ylabel('Loss', fontsize=12)
axes[0, 0].legend(fontsize=10)
axes[0, 0].grid(True, alpha=0.3)

# Accuracy plot
axes[0, 1].plot(history.history['accuracy'], label='Training Accuracy', linewidth=2)
axes[0, 1].plot(history.history['val_accuracy'], label='Validation Accuracy', linewidth=2)
axes[0, 1].set_title('Model Accuracy Over Time', fontsize=14, fontweight='bold')
axes[0, 1].set_xlabel('Epoch', fontsize=12)
axes[0, 1].set_ylabel('Accuracy', fontsize=12)
axes[0, 1].legend(fontsize=10)
axes[0, 1].grid(True, alpha=0.3)

# Mean IoU plot
axes[0, 2].plot(history.history['mean_iou'], label='Training Mean IoU', linewidth=2)
axes[0, 2].plot(history.history['val_mean_iou'], label='Validation Mean IoU', linewidth=2)
axes[0, 2].set_title('Mean IoU Over Time', fontsize=14, fontweight='bold')
axes[0, 2].set_xlabel('Epoch', fontsize=12)
axes[0, 2].set_ylabel('Mean IoU', fontsize=12)
axes[0, 2].legend(fontsize=10)
axes[0, 2].grid(True, alpha=0.3)

# Precision plot
axes[1, 0].plot(history.history['precision'], label='Training Precision', linewidth=2)
axes[1, 0].plot(history.history['val_precision'], label='Validation Precision', linewidth=2)
axes[1, 0].set_title('Precision Over Time', fontsize=14, fontweight='bold')
axes[1, 0].set_xlabel('Epoch', fontsize=12)
axes[1, 0].set_ylabel('Precision', fontsize=12)
axes[1, 0].legend(fontsize=10)
axes[1, 0].grid(True, alpha=0.3)

# Recall plot
axes[1, 1].plot(history.history['recall'], label='Training Recall', linewidth=2)
axes[1, 1].plot(history.history['val_recall'], label='Validation Recall', linewidth=2)
axes[1, 1].set_title('Recall Over Time', fontsize=14, fontweight='bold')
axes[1, 1].set_xlabel('Epoch', fontsize=12)
axes[1, 1].set_ylabel('Recall', fontsize=12)
axes[1, 1].legend(fontsize=10)
axes[1, 1].grid(True, alpha=0.3)

# Hide the last subplot (we have 5 metrics, 6 subplots)
axes[1, 2].axis('off')

plt.tight_layout()
plt.savefig(f'{drive_checkpoint_dir}/training_history_{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()

print("✓ Training history saved to Google Drive")

# Print final metrics
print("\n" + "=" * 60)
print("FINAL TRAINING METRICS")
print("=" * 60)
print(f"Final training loss: {history.history['loss'][-1]:.4f}")
print(f"Final training accuracy: {history.history['accuracy'][-1]:.4f}")
print(f"Final training mean IoU: {history.history['mean_iou'][-1]:.4f}")
print(f"Final training precision: {history.history['precision'][-1]:.4f}")
print(f"Final training recall: {history.history['recall'][-1]:.4f}")
print("\nFINAL VALIDATION METRICS")
print("=" * 60)
print(f"Final validation loss: {history.history['val_loss'][-1]:.4f}")
print(f"Final validation accuracy: {history.history['val_accuracy'][-1]:.4f}")
print(f"Final validation mean IoU: {history.history['val_mean_iou'][-1]:.4f}")
print(f"Final validation precision: {history.history['val_precision'][-1]:.4f}")
print(f"Final validation recall: {history.history['val_recall'][-1]:.4f}")
print("\nBEST VALIDATION METRICS")
print("=" * 60)
print(f"Best validation loss: {min(history.history['val_loss']):.4f}")
print(f"Best validation accuracy: {max(history.history['val_accuracy']):.4f}")
print(f"Best validation mean IoU: {max(history.history['val_mean_iou']):.4f}")
print(f"Best validation precision: {max(history.history['val_precision']):.4f}")
print(f"Best validation recall: {max(history.history['val_recall']):.4f}")
print("=" * 60)
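
The "best" values printed above are taken independently for each metric, so they may come from different epochs. The sketch below reports all metrics from the single epoch with the lowest validation loss. It assumes the usual `history.history` dictionary of equal-length per-epoch lists; the helper name `best_epoch_row` is hypothetical.

```python
def best_epoch_row(history, monitor="val_loss", mode="min"):
    """Return (1-based epoch, {metric: value}) for the best epoch by `monitor`."""
    values = history[monitor]
    if not values:
        raise ValueError(f"No values recorded for {monitor!r}")
    pick = min if mode == "min" else max
    index = pick(range(len(values)), key=values.__getitem__)
    row = {name: series[index] for name, series in history.items()}
    return index + 1, row

# Self-contained demonstration with made-up numbers
demo_history = {"val_loss": [0.9, 0.5, 0.7], "val_accuracy": [0.60, 0.80, 0.85]}
epoch, row = best_epoch_row(demo_history)
print(f"Best epoch: {epoch}", row)
```

In the notebook this would be called as `best_epoch_row(history.history)` once training has finished.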

## Step 15: Save Final Model

In [None]:
# Save final model to Google Drive
final_model_path = f'{drive_checkpoint_dir}/final_model_{timestamp}.keras'
model.save(final_model_path)
print(f"✓ Final model saved to: {final_model_path}")

# Also save locally
model.save('/content/final_model.keras')
print("✓ Final model also saved locally to: /content/final_model.keras")

print("\n✓ All models safely stored in Google Drive!")
print("\n" + "=" * 60)
print("To load the best model:")
print("=" * 60)
print("from losses import CombinedLoss")
print("model = keras.models.load_model('best_model.keras',")
print("                                 custom_objects={'CombinedLoss': CombinedLoss})")
print("=" * 60)

## Step 16: Test Prediction and Visualization

In [None]:
# Get a sample from validation set
print("Loading sample for prediction...")
sample_x, sample_y = validation_generator[0]

print(f"Input shape: {sample_x.shape}")
print(f"Label shape: {sample_y.shape}")

# Make prediction
print("\nGenerating prediction...")
prediction = model.predict(sample_x, verbose=0)
print(f"Prediction shape: {prediction.shape}")

# Visualize middle slice
slice_idx = D // 2  # Middle slice

fig, axes = plt.subplots(2, 3, figsize=(15, 10))
fig.suptitle(f'Brain Tumor Segmentation - Slice {slice_idx}', fontsize=16, fontweight='bold')

# Input modalities
axes[0, 0].imshow(sample_x[0, slice_idx, :, :, 0], cmap='gray')
axes[0, 0].set_title('FLAIR', fontsize=12)
axes[0, 0].axis('off')

axes[0, 1].imshow(sample_x[0, slice_idx, :, :, 1], cmap='gray')
axes[0, 1].set_title('T1', fontsize=12)
axes[0, 1].axis('off')

axes[0, 2].imshow(sample_x[0, slice_idx, :, :, 2], cmap='gray')
axes[0, 2].set_title('T1ce', fontsize=12)
axes[0, 2].axis('off')

# Fourth modality (T2), followed by ground truth and prediction
axes[1, 0].imshow(sample_x[0, slice_idx, :, :, 3], cmap='gray')
axes[1, 0].set_title('T2', fontsize=12)
axes[1, 0].axis('off')

axes[1, 1].imshow(np.argmax(sample_y[0, slice_idx], axis=-1), cmap='jet', vmin=0, vmax=NUM_CLASSES-1)
axes[1, 1].set_title('Ground Truth', fontsize=12)
axes[1, 1].axis('off')

axes[1, 2].imshow(np.argmax(prediction[0, slice_idx], axis=-1), cmap='jet', vmin=0, vmax=NUM_CLASSES-1)
axes[1, 2].set_title('Prediction', fontsize=12)
axes[1, 2].axis('off')

plt.tight_layout()
plt.savefig(f'{drive_checkpoint_dir}/prediction_visualization_{timestamp}.png', dpi=300, bbox_inches='tight')
plt.show()

print(f"\n✓ Prediction visualization saved to Google Drive")
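
A visual check of one slice can be complemented with a per-class overlap score. The following sketch computes the Dice coefficient per class from two flat sequences of integer labels. It is an illustration in pure Python with a hypothetical helper name; for full volumes the same counts are better computed with NumPy boolean masks.

```python
def per_class_dice(true_labels, pred_labels, num_classes):
    """Dice coefficient per class; None when a class is absent from both inputs."""
    scores = {}
    for c in range(num_classes):
        overlap = sum(1 for t, p in zip(true_labels, pred_labels) if t == c and p == c)
        n_true = sum(1 for t in true_labels if t == c)
        n_pred = sum(1 for p in pred_labels if p == c)
        denom = n_true + n_pred
        scores[c] = None if denom == 0 else 2 * overlap / denom
    return scores

# Self-contained demonstration with made-up labels
print(per_class_dice([0, 0, 1, 1, 2], [0, 1, 1, 1, 0], num_classes=3))
```

In the notebook, the label sequences could be obtained from one slice, for example `np.argmax(sample_y[0, slice_idx], axis=-1).ravel().tolist()` and the same expression applied to `prediction`.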

## Step 17: Summary and Results

In [None]:
print("=" * 70)
print(" " * 20 + "TRAINING COMPLETE!" + " " * 20)
print("=" * 70)
print("\n✓ Your trained models are safely stored in Google Drive:")
print(f"  📁 {drive_checkpoint_dir}/")
print("\n✓ Files saved:")
print(f"  - best_model_{timestamp}.keras")
print(f"  - final_model_{timestamp}.keras")
print(f"  - checkpoint_epoch_*.keras (periodic checkpoints)")
print(f"  - training_log_{timestamp}.csv")
print(f"  - training_history_{timestamp}.png")
print(f"  - prediction_visualization_{timestamp}.png")
print("\n📊 Training Configuration Summary:")
print(f"  - Epochs: {EPOCHS} (early stopping enabled)")
print(f"  - Batch size: {BATCH_SIZE}")
print(f"  - Learning rate: {LEARNING_RATE}")
print(f"  - Loss: Combined Focal + Dice Loss (50%/50%)")
print(f"  - Metrics: Accuracy, Mean IoU, Precision, Recall")
print(f"  - Early stopping patience: 10 epochs")
print(f"  - LR reduction patience: 5 epochs")
print("\n📊 View TensorBoard logs:")
print(f"  Run: %load_ext tensorboard")
print(f"       %tensorboard --logdir {logs_dir}")
print("\n🎯 Next Steps:")
print("  1. Evaluate model on test set")
print("  2. Perform hyperparameter tuning if needed")
print("  3. Generate more visualizations")
print("  4. Export model for deployment")
print("\n💾 All important files are backed up to Google Drive!")
print("=" * 70)

# List all saved files
print("\nFiles in Google Drive checkpoint directory:")
!ls -lh {drive_checkpoint_dir}

## Optional: Launch TensorBoard

In [None]:
# Uncomment to launch TensorBoard
# %load_ext tensorboard
# %tensorboard --logdir {logs_dir}

## Optional: Download Files Locally

In [None]:
# Uncomment to download files to your computer
# from google.colab import files
# files.download('/content/final_model.keras')
# files.download(f'{drive_checkpoint_dir}/training_history_{timestamp}.png')
# files.download(f'{drive_checkpoint_dir}/prediction_visualization_{timestamp}.png')

## Notes and Tips

### Dataset Storage
- Your 7.95 GB dataset is stored on **disk** (local SSD)
- Only **one batch** (a few hundred MB with batch_size=2) is held in host **RAM** at a time
- This keeps host RAM usage low; **GPU memory** is the tighter limit for full-size volumes, as the out-of-memory error in Step 13 shows

### Training Configuration
- **Epochs**: 100 (early stopping will stop training if no improvement for 10 epochs)
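- **Batch Size**: 2 (gives batch normalization more than one sample per step, at the cost of roughly twice the activation memory of batch size 1)
- **Loss Function**: Combined Focal + Dice Loss (a common choice for class-imbalanced medical segmentation)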
  - Focal Loss: Handles class imbalance
  - Dice Loss: Optimizes overlap between predicted and actual segmentation
- **Metrics**: Accuracy, Mean IoU, Precision, Recall for comprehensive evaluation

### Session Management
- **Colab Free**: sessions last up to about 12 hours and can end sooner
- **Colab Pro**: longer sessions (up to about 24 hours); actual limits depend on plan and availability
- Keep browser tab active to prevent disconnection
- All checkpoints are saved to Google Drive automatically every epoch
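
Session limits matter because of how long one epoch takes. Using the notebook's own rough estimate of 2.5 minutes per batch and the 147 training batches reported in Step 10, the sketch below counts how many whole epochs fit in one session. Validation time is ignored and the per-batch figure is only an estimate, so treat the result as an upper bound.

```python
def epochs_per_session(session_hours, batches_per_epoch, minutes_per_batch):
    """Whole training epochs that fit in one session (validation time ignored)."""
    minutes_per_epoch = batches_per_epoch * minutes_per_batch
    return int(session_hours * 60 // minutes_per_epoch)

for hours in (12, 24):
    n = epochs_per_session(hours, batches_per_epoch=147, minutes_per_batch=2.5)
    print(f"{hours}-hour session: {n} full epoch(s)")
```

At that pace a 100-epoch run spans many sessions, which is why the per-epoch checkpoints on Google Drive and the resume procedure below are essential rather than optional.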

### Resuming Training
If your session disconnects, you can resume:
```python
from losses import CombinedLoss

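# Load the last checkpoint. `timestamp` is regenerated after a restart, so set it
# (and the epoch number) to match the checkpoint file of the interrupted run.
checkpoint_path = f'{drive_checkpoint_dir}/checkpoint_epoch_05_{timestamp}.keras'
model = keras.models.load_model(checkpoint_path, custom_objects={'CombinedLoss': CombinedLoss})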

# Continue training
history = model.fit(
    train_generator,
    epochs=100,
    initial_epoch=5,  # Start from where you left off
    validation_data=validation_generator,
    callbacks=callbacks
)
```
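
Rather than typing the epoch and timestamp by hand, the most recent checkpoint can be located from the file names. This sketch assumes the `checkpoint_epoch_<NN>_<YYYYMMDD-HHMMSS>.keras` pattern configured in Step 12; the helper name `latest_checkpoint` is hypothetical.

```python
import re

# Pattern follows the ModelCheckpoint filepath used in Step 12
CHECKPOINT = re.compile(r"^checkpoint_epoch_(\d+)_(\d{8}-\d{6})\.keras$")

def latest_checkpoint(filenames):
    """Return (epoch, filename) for the newest run's highest epoch, or None."""
    candidates = []
    for name in filenames:
        match = CHECKPOINT.match(name)
        if match:
            # Timestamps in YYYYMMDD-HHMMSS form sort correctly as strings
            candidates.append((match.group(2), int(match.group(1)), name))
    if not candidates:
        return None
    _, epoch, name = max(candidates)
    return epoch, name

# Self-contained demonstration with made-up file names
demo = [
    "checkpoint_epoch_01_20251105-180000.keras",
    "checkpoint_epoch_05_20251105-180000.keras",
    "best_model_20251105-180000.keras",
]
print(latest_checkpoint(demo))
```

In the notebook, `latest_checkpoint(os.listdir(drive_checkpoint_dir))` would give the file to load, and the returned epoch number is the value to pass as `initial_epoch`.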

### Loading Best Model for Inference
```python
from losses import CombinedLoss

# Load best model
best_model_path = f'{drive_checkpoint_dir}/best_model_{timestamp}.keras'
model = keras.models.load_model(best_model_path, custom_objects={'CombinedLoss': CombinedLoss})

# Make predictions
predictions = model.predict(validation_generator)
```

### Performance Tips
- Use **Colab Pro** for better GPUs (V100/A100) and longer sessions
- Monitor GPU usage with: `!nvidia-smi`
- Check disk usage with: `!df -h /content`
- Monitor RAM with: `!free -h`
- All metrics are logged to CSV for offline analysis