# Brain-to-Text Training on AWS SageMaker

This notebook trains the RNN baseline model for brain-to-text decoding using data from AWS S3.

## Setup and Dependencies


In [None]:
# Install required packages
# `%pip` (rather than `!pip`) installs into the environment of the running kernel.
# NOTE(review): no versions are pinned here, so a fresh run may resolve
# different releases than the ones this notebook was developed with - TODO pin.
%pip install omegaconf h5py torchaudio boto3 s3fs


In [None]:
# Import required libraries
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import LambdaLR
import h5py
import numpy as np
import os
import time
import logging
import json
import pickle
import math
import random
from torch.nn.utils.rnn import pad_sequence
import torchaudio.functional as F
from omegaconf import OmegaConf
from pathlib import Path
import boto3
import s3fs
from botocore.exceptions import ClientError

# Pick the compute device: CUDA when it is usable, otherwise fall back to CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')


## AWS S3 Configuration


In [None]:
# AWS S3 Configuration: bucket, key prefix of the training data, and the
# local data directory
S3_BUCKET_NAME = '4k-woody-btt'
S3_DATA_PREFIX = '4k/data/'
LOCAL_DATA_DIR = '/tmp/brain_to_text_data'

# Two S3 handles: the low-level boto3 client and the s3fs filesystem object
s3_client = boto3.client('s3')
s3_fs = s3fs.S3FileSystem()

# Echo the settings so they are recorded in the cell output
print(
    f"S3 Bucket: {S3_BUCKET_NAME}",
    f"Data prefix: {S3_DATA_PREFIX}",
    f"Local data directory: {LOCAL_DATA_DIR}",
    sep="\n",
)


## S3 Direct Access Setup


In [None]:
def list_s3_data_files(bucket_name, prefix, client=None, suffix='.hdf5'):
    """
    List all HDF5 files in S3 bucket prefix

    Args:
        bucket_name: Name of the S3 bucket to scan.
        prefix: Key prefix passed to the listing call; only objects under it
            are considered.
        client: Optional object exposing ``get_paginator`` (e.g. a boto3 S3
            client). Defaults to the notebook-level ``s3_client``.
        suffix: Key ending that identifies a data file. Defaults to
            ``'.hdf5'`` (case-sensitive match).

    Returns:
        List of matching object keys, in the order the listing returned them.
        An empty list is returned both when nothing matches and when the
        listing fails with a ``ClientError`` (the error is printed), so an
        empty result alone does not prove the prefix is empty.
    """
    # Resolve the default lazily so it always tracks the current notebook global
    if client is None:
        client = s3_client

    try:
        # The paginator follows continuation tokens, so every object under
        # the prefix is visited rather than only the first page
        paginator = client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

        # A page may carry no 'Contents' entry at all; treat that as empty
        h5_files = [
            obj['Key']
            for page in pages
            for obj in page.get('Contents', [])
            if obj['Key'].endswith(suffix)
        ]

        print(f"✅ Found {len(h5_files)} HDF5 files in S3")
        return h5_files

    except ClientError as e:
        print(f"❌ Error listing S3 files: {e}")
        return []

# Discover the training data files available under the configured prefix
print("Scanning S3 for training data files...")
s3_files = list_s3_data_files(S3_BUCKET_NAME, S3_DATA_PREFIX)

# Either flag an empty result or preview a handful of the discovered keys
if not s3_files:
    print("❌ No HDF5 files found in S3. Please check your bucket and prefix configuration.")
else:
    print("\nExample S3 files:")
    for file_path in s3_files[:5]:  # the preview is capped at five entries
        print(f"  - s3://{S3_BUCKET_NAME}/{file_path}")
    if len(s3_files) > 5:
        print(f"  ... and {len(s3_files) - 5} more files")


## Clone Repository and Setup


In [None]:
# Clone the repository
!git clone https://github.com/Neuroprosthetics-Lab/nejm-brain-to-text.git


In [None]:
# Make the repository's model_training directory importable and make it the
# working directory (the training config is loaded later by relative path).
import sys
import os

MODEL_TRAINING_DIR = '/home/ec2-user/SageMaker/nejm-brain-to-text/model_training'

# Fail early, before touching sys.path, with a message that says what to do
if not os.path.isdir(MODEL_TRAINING_DIR):
    raise FileNotFoundError(
        f"{MODEL_TRAINING_DIR} does not exist - run the repository clone cell first"
    )

# Add the model_training directory to Python path (only once, so re-running
# this cell does not accumulate duplicate entries)
if MODEL_TRAINING_DIR not in sys.path:
    sys.path.append(MODEL_TRAINING_DIR)

# Change to the model_training directory
os.chdir(MODEL_TRAINING_DIR)

print("Current working directory:", os.getcwd())
print("Python path updated for model_training module")


## Configure Training Parameters


In [None]:
# Load the baseline configuration and adapt it for SageMaker with S3 direct access
from omegaconf import OmegaConf

print("Loading original configuration...")
# Relative path: resolved against the current working directory
args = OmegaConf.load('rnn_args.yaml')

# Send outputs and checkpoints to the SageMaker directories
args.output_dir = '/home/ec2-user/SageMaker/trained_models/baseline_rnn'
args.checkpoint_dir = '/home/ec2-user/SageMaker/trained_models/baseline_rnn/checkpoint'

# Point the dataset section at S3 and switch on the direct-access flag
args.dataset.s3_bucket = S3_BUCKET_NAME
args.dataset.s3_prefix = S3_DATA_PREFIX
args.dataset.use_s3_direct = True

# Both directories must exist before training starts; existing ones are kept
os.makedirs(args.output_dir, exist_ok=True)
os.makedirs(args.checkpoint_dir, exist_ok=True)

# Report the effective settings in one go
print(
    "Configuration updated for SageMaker with S3 direct access:",
    f"  S3 Bucket: {args.dataset.s3_bucket}",
    f"  S3 Prefix: {args.dataset.s3_prefix}",
    f"  Use S3 Direct: {args.dataset.use_s3_direct}",
    f"  Output directory: {args.output_dir}",
    f"  Checkpoint directory: {args.checkpoint_dir}",
    sep="\n",
)

# Optional: uncomment the two lines below to resume from a saved checkpoint
# args.init_from_checkpoint = True
# args.init_checkpoint_path = '/home/ec2-user/SageMaker/trained_models/baseline_rnn/checkpoint/best_checkpoint'


## Start Training


In [None]:
# Import and run the training with S3 direct access
# NOTE(review): s3_rnn_trainer is presumably found via the model_training
# directory added to sys.path earlier - confirm the module exists there.
from s3_rnn_trainer import S3BrainToTextDecoder_Trainer

print("Starting training with S3 direct access...")
print(f"S3 Bucket: {args.dataset.s3_bucket}")
print(f"S3 Prefix: {args.dataset.s3_prefix}")
print(f"Output directory: {args.output_dir}")
print(f"Number of training batches: {args.num_training_batches}")
print(f"Batch size: {args.dataset.batch_size}")
print(f"Learning rate: {args.lr_max}")

# Create the S3 trainer and run training
trainer = S3BrainToTextDecoder_Trainer(args)
try:
    metrics = trainer.train()

    print("\n✅ Training completed!")
    print(f"Final metrics: {metrics}")
finally:
    # Clean up cached files even when training raises or is interrupted, so a
    # failed run does not leave them behind
    trainer.cleanup()


## Upload Results to S3 (Optional)


In [None]:
def upload_to_s3(local_dir, bucket_name, s3_prefix, client=None):
    """
    Upload local directory contents to S3

    Every file found by walking ``local_dir`` recursively is uploaded under
    the key ``s3_prefix`` + the file's path relative to ``local_dir``.

    Args:
        local_dir: Directory whose files are uploaded.
        bucket_name: Destination S3 bucket.
        s3_prefix: String prepended verbatim to each relative path to form the
            object key; include a trailing '/' to get a folder-like layout.
        client: Optional object exposing ``upload_file(filename, bucket, key)``
            (e.g. a boto3 S3 client). Defaults to the notebook-level
            ``s3_client``.

    Returns:
        None. Progress and failures are reported with ``print``; S3 errors
        are caught, and the first one stops the remaining uploads.
    """
    # Resolve the default lazily so it always tracks the current notebook global
    if client is None:
        client = s3_client

    # os.walk yields nothing for a missing directory, which would otherwise
    # be reported as a successful upload
    if not os.path.isdir(local_dir):
        print(f"❌ Error uploading to S3: local directory not found: {local_dir}")
        return

    uploaded = 0
    try:
        for root, _dirs, files in os.walk(local_dir):
            for file in files:
                local_file_path = os.path.join(root, file)
                relative_path = os.path.relpath(local_file_path, local_dir)
                # S3 keys use '/' regardless of the local path separator
                s3_key = f"{s3_prefix}{relative_path.replace(os.sep, '/')}"

                print(f"Uploading: {local_file_path} -> s3://{bucket_name}/{s3_key}")
                client.upload_file(local_file_path, bucket_name, s3_key)
                uploaded += 1

        if uploaded:
            print("\n✅ Successfully uploaded results to S3")
        else:
            print(f"⚠️ No files found under {local_dir}; nothing was uploaded")

    # boto3's upload_file re-raises client errors as S3UploadFailedError,
    # which is not a ClientError subclass, so both must be caught
    except (ClientError, boto3.exceptions.S3UploadFailedError) as e:
        print(f"❌ Error uploading to S3: {e}")

# Push everything under the training output directory to the results prefix in S3
print("Uploading training results to S3...")
upload_to_s3(
    local_dir=args.output_dir,
    bucket_name=S3_BUCKET_NAME,
    s3_prefix='training_results/baseline_rnn/',
)


## Training Summary


In [None]:
# Print a framed recap of the run: data source, output location, key
# hyperparameters and the metrics returned by training
rule = "=" * 50

print(f"\n{rule}")
print("TRAINING SUMMARY")
print(rule)
print(f"Dataset: S3 bucket '{S3_BUCKET_NAME}' at '{S3_DATA_PREFIX}'")
print("Data access: Direct S3 access (no local download)")
print(f"Output directory: {args.output_dir}")
print(f"Training batches: {args.num_training_batches}")
print(f"Batch size: {args.dataset.batch_size}")
print(f"Model architecture: {args.model.n_layers} GRU layers with {args.model.n_units} units each")
print(f"Final metrics: {metrics}")
print(rule)
print("✅ Training completed successfully with S3 direct access!")
print("💡 Benefits: No need to download entire dataset, faster startup, less storage usage")
print(rule)
