# Brain-to-Text Training on AWS SageMaker

This notebook trains the RNN baseline model for brain-to-text decoding using data from AWS S3.

## Setup and Dependencies


In [None]:
# Install required packages
print("Installing required packages...")

# First, check if we need to install PyTorch with CUDA support
import sys
try:
    import torch
    if not torch.cuda.is_available():
        print("⚠️  PyTorch CUDA not available. Installing PyTorch with CUDA support...")
        # Install PyTorch with CUDA support for CUDA 11.8 (common on SageMaker)
        %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
    else:
        print("✅ PyTorch with CUDA is already available")
except ImportError:
    print("Installing PyTorch with CUDA support...")
    %pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118

# Install other required packages
%pip install omegaconf h5py boto3 s3fs

print("✅ Package installation complete!")


In [None]:
# Import required libraries
import json
import logging
import math
import os
import pickle
import random
import sys
import time
from pathlib import Path

import boto3
import h5py
import numpy as np
import s3fs
import torch
import torch.nn as nn
import torchaudio.functional as F
from botocore.exceptions import ClientError
from omegaconf import OmegaConf
from torch.nn.utils.rnn import pad_sequence
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import Dataset, DataLoader

# Check GPU availability
print("🔍 GPU Availability Check:")
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA version (PyTorch): {torch.version.cuda}")

if torch.cuda.is_available():
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs: {gpu_count}")
    for i in range(gpu_count):
        # Fetch the properties once per device instead of once per printed field
        props = torch.cuda.get_device_properties(i)
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
        print(f"    Memory: {props.total_memory / 1024**3:.1f} GB")
        print(f"    Compute Capability: {props.major}.{props.minor}")
else:
    print("❌ No CUDA GPUs available")
    print("\n🔧 Troubleshooting steps:")
    print("1. Check if NVIDIA drivers are installed:")
    print("   !nvidia-smi")
    print("2. Check if CUDA is properly installed:")
    print("   !nvcc --version")
    print("3. Verify PyTorch CUDA installation:")
    print("   !python -c 'import torch; print(torch.cuda.is_available())'")

# Additional system checks
# Use the sys module directly rather than the undocumented os.sys alias
print("\n🖥️  System Information:")
print(f"Python version: {sys.version}")
print(f"Platform: {sys.platform}")

# Set up device: the first CUDA device when available, otherwise the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'\nUsing device: {device}')


## GPU Diagnostics


In [None]:
# Run GPU diagnostics
import os
import subprocess


def run_cli_check(command, ok_message, fail_message, timeout=10):
    """Run an external diagnostic command and print a pass/fail report.

    Args:
        command: Argument list handed to subprocess.run; its first element is
            used as the tool name in error messages.
        ok_message: Line printed (followed by the tool's stdout) on exit code 0.
        fail_message: Line printed (followed by the tool's stderr) on any other
            exit code.
        timeout: Seconds to wait before the run is abandoned.

    Returns:
        True when the command ran and exited with code 0, otherwise False.
        A missing executable or a timeout is reported and yields False.
    """
    tool_name = command[0]
    try:
        result = subprocess.run(command, capture_output=True, text=True, timeout=timeout)
    except Exception as e:
        # Best-effort diagnostics: a missing tool or timeout must not stop the cell
        print(f"❌ Error running {tool_name}: {e}")
        return False

    if result.returncode == 0:
        print(ok_message)
        print(result.stdout)
        return True

    print(fail_message)
    print("Error:", result.stderr)
    return False


print("🔍 Running GPU Diagnostics...")
print("="*50)

# Check NVIDIA drivers
print("1. NVIDIA Driver Check:")
run_cli_check(['nvidia-smi'], "✅ NVIDIA drivers are installed", "❌ nvidia-smi command failed")

print("\n" + "="*50)

# Check CUDA installation
print("2. CUDA Installation Check:")
run_cli_check(['nvcc', '--version'], "✅ CUDA is installed", "❌ nvcc command failed")

print("\n" + "="*50)

# Check PyTorch CUDA
print("3. PyTorch CUDA Check:")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version (PyTorch): {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    print("❌ PyTorch cannot access CUDA")

print("\n" + "="*50)

# Check environment variables
print("4. Environment Variables:")
cuda_vars = ['CUDA_HOME', 'CUDA_PATH', 'LD_LIBRARY_PATH', 'PATH']
for var in cuda_vars:
    value = os.environ.get(var, 'Not set')
    print(f"  {var}: {value}")

print("\n" + "="*50)
print("Diagnostics complete!")


## AWS S3 Configuration


In [None]:
# AWS S3 Configuration
# Bucket and key prefix holding the training data, plus a local scratch directory
S3_BUCKET_NAME, S3_DATA_PREFIX = '4k-woody-btt', '4k/data/'
LOCAL_DATA_DIR = '/tmp/brain_to_text_data'

# Two S3 handles: a boto3 client for API calls and an s3fs filesystem object
s3_fs = s3fs.S3FileSystem()
s3_client = boto3.client('s3')

# Echo the settings so the cell output records what this run points at
for setting_label, setting_value in (
    ("S3 Bucket", S3_BUCKET_NAME),
    ("Data prefix", S3_DATA_PREFIX),
    ("Local data directory", LOCAL_DATA_DIR),
):
    print(f"{setting_label}: {setting_value}")


In [None]:
# Test S3 path construction for sessions
print("🔍 Testing S3 Path Construction...")
print("="*50)

# Load the sessions from config
from omegaconf import OmegaConf
import os

# Report where the kernel is running so a missing config is easy to diagnose
print(f"Current working directory: {os.getcwd()}")
print(f"Files in current directory: {os.listdir('.')}")

# Candidate config locations, probed in order; the first one that exists wins
config_paths = [
    'rnn_args.yaml',
    '../rnn_args.yaml',
    '/home/ec2-user/SageMaker/btt_training/rnn_args.yaml',
    '/home/ec2-user/SageMaker/btt_training/model_training/rnn_args.yaml',
]

config = None
for config_path in config_paths:
    if not os.path.exists(config_path):
        print(f"❌ Config file not found at: {config_path}")
        continue
    print(f"✅ Found config file at: {config_path}")
    config = OmegaConf.load(config_path)
    break

if config is not None:
    sessions = config.dataset.sessions
    print(f"📋 Sessions from config: {sessions}")
    print(f"📦 S3 Bucket: {S3_BUCKET_NAME}")
    print(f"📁 S3 Prefix: {S3_DATA_PREFIX}")

    print("\n🔍 Testing session paths:")
    # Only the first three sessions are probed to keep the check quick
    for session in sessions[:3]:
        session_prefix = f"{S3_DATA_PREFIX}{session}"
        s3_path = f"{S3_BUCKET_NAME}/{session_prefix}"
        print(f"\nSession: {session}")
        print(f"  Constructed path: {s3_path}")

        try:
            # A session "directory" exists when at least one key sits under its prefix
            response = s3_client.list_objects_v2(
                Bucket=S3_BUCKET_NAME, Prefix=f"{session_prefix}/", MaxKeys=5
            )
            if 'Contents' not in response:
                print("  ❌ Session directory not found")
            else:
                print("  ✅ Session directory exists")
                print("  📁 Files in session:")
                for obj in response['Contents']:
                    print(f"    - {obj['Key']}")
        except Exception as e:
            print(f"  ❌ Error checking session: {e}")
else:
    print("❌ Could not find rnn_args.yaml file!")
    print("Please check the file location and update the path.")

print("\n" + "="*50)
print("Path construction test complete!")


In [None]:
# Find the rnn_args.yaml file
import os
import glob

print("🔍 Searching for rnn_args.yaml file...")
print(f"Current working directory: {os.getcwd()}")

# Search for the config file in various locations
search_paths = [
    '.',
    '..',
    '/home/ec2-user/SageMaker/btt_training/',
    '/home/ec2-user/SageMaker/btt_training/model_training/',
    '/home/ec2-user/SageMaker/'
]

# The search roots overlap (the current directory lies under '..', and the
# SageMaker root contains btt_training/), so one file can match several times.
# Matches are collapsed on their resolved absolute path, keeping first-seen
# order. Absolute paths also stay valid after a later os.chdir().
found_files = []
seen_paths = set()
for search_path in search_paths:
    if not os.path.isdir(search_path):
        continue
    pattern = os.path.join(search_path, '**', 'rnn_args.yaml')
    for match in glob.glob(pattern, recursive=True):
        resolved = os.path.realpath(match)
        if resolved not in seen_paths:
            seen_paths.add(resolved)
            found_files.append(resolved)

if found_files:
    print(f"✅ Found {len(found_files)} rnn_args.yaml file(s):")
    for file_path in found_files:
        print(f"  - {file_path}")

    # Use the first found file
    CONFIG_FILE_PATH = found_files[0]
    print(f"\n📁 Using config file: {CONFIG_FILE_PATH}")
else:
    print("❌ No rnn_args.yaml file found!")
    print("Please ensure the file exists in the repository.")
    CONFIG_FILE_PATH = None


## S3 Direct Access Setup


In [None]:
def list_s3_data_files(bucket_name, prefix):
    """
    Collect the keys of every `.hdf5` object stored under an S3 prefix.

    Uses the notebook-level `s3_client` and walks every result page of
    `list_objects_v2` for `prefix` in `bucket_name`.

    Returns:
        The matching keys in listing order. When the listing raises a
        ClientError, the error is printed and an empty list is returned.
    """
    try:
        page_iterator = s3_client.get_paginator('list_objects_v2').paginate(
            Bucket=bucket_name, Prefix=prefix
        )
        # Pages without a 'Contents' entry contribute nothing
        h5_files = [
            entry['Key']
            for page in page_iterator
            for entry in page.get('Contents', [])
            if entry['Key'].endswith('.hdf5')
        ]
    except ClientError as e:
        print(f"❌ Error listing S3 files: {e}")
        return []

    print(f"✅ Found {len(h5_files)} HDF5 files in S3")
    return h5_files

# List available training data files
print("Scanning S3 for training data files...")
s3_files = list_s3_data_files(S3_BUCKET_NAME, S3_DATA_PREFIX)

# Preview the first five keys and summarise whatever is left over
if not s3_files:
    print("❌ No HDF5 files found in S3. Please check your bucket and prefix configuration.")
else:
    print("\nExample S3 files:")
    for file_path in s3_files[:5]:
        print(f"  - s3://{S3_BUCKET_NAME}/{file_path}")
    remaining_files = len(s3_files) - 5
    if remaining_files > 0:
        print(f"  ... and {remaining_files} more files")


## Clone Repository and Setup


In [None]:
# Clone the repository
# NOTE(review): no target directory is given, so git presumably clones into
# ./nejm-brain-to-text, while the following cells use
# /home/ec2-user/SageMaker/btt_training/model_training. Confirm the clone
# location matches the path the notebook switches into.
!git clone https://github.com/Neuroprosthetics-Lab/nejm-brain-to-text.git


In [None]:
# Change to the model_training directory
import sys
import os

# Single definition of the directory, used for both sys.path and the cwd
MODEL_TRAINING_DIR = '/home/ec2-user/SageMaker/btt_training/model_training'

# Add the model_training directory to Python path, but only once, so re-running
# this cell does not stack duplicate entries on sys.path
if MODEL_TRAINING_DIR not in sys.path:
    sys.path.append(MODEL_TRAINING_DIR)

# Change to the model_training directory (raises if the directory is missing)
os.chdir(MODEL_TRAINING_DIR)

print("Current working directory:", os.getcwd())
print("Python path updated for model_training module")


## Configure Training Parameters


In [None]:
# Load and modify configuration for SageMaker with S3 direct access
from omegaconf import OmegaConf

print("Loading original configuration...")
args = OmegaConf.load('rnn_args.yaml')

# Point the dataset section at S3 and switch on direct S3 access
args.dataset.s3_bucket = S3_BUCKET_NAME
args.dataset.s3_prefix = S3_DATA_PREFIX
args.dataset.use_s3_direct = True

# Auto-configure GPU number: keep the requested index when that device exists,
# otherwise fall back to GPU 0; without CUDA the setting is left untouched
if not torch.cuda.is_available():
    print("⚠️  No CUDA GPUs available. Training will use CPU.")
else:
    num_gpus = torch.cuda.device_count()
    requested_gpu = int(args.gpu_number)
    if requested_gpu < num_gpus:
        print(f"✅ Using requested GPU {requested_gpu}")
    else:
        args.gpu_number = '0'
        print(f"⚠️  Requested GPU {requested_gpu} not available. Using GPU 0 instead.")

# Output locations on the SageMaker instance; checkpoints live below the output directory
sagemaker_output_dir = '/home/ec2-user/SageMaker/trained_models/baseline_rnn'
args.output_dir = sagemaker_output_dir
args.checkpoint_dir = f'{sagemaker_output_dir}/checkpoint'

# Make sure both directories exist before training starts
for required_dir in (args.output_dir, args.checkpoint_dir):
    os.makedirs(required_dir, exist_ok=True)

print("\nConfiguration updated for SageMaker with S3 direct access:")
config_summary = [
    ("S3 Bucket", args.dataset.s3_bucket),
    ("S3 Prefix", args.dataset.s3_prefix),
    ("Use S3 Direct", args.dataset.use_s3_direct),
    ("GPU Number", args.gpu_number),
    ("Output directory", args.output_dir),
    ("Checkpoint directory", args.checkpoint_dir),
]
for summary_label, summary_value in config_summary:
    print(f"  {summary_label}: {summary_value}")

# S3 Checkpoint Configuration
S3_CHECKPOINT_PREFIX = 'training_results/baseline_rnn/checkpoints/'
S3_BEST_CHECKPOINT_KEY = S3_CHECKPOINT_PREFIX + 'best_checkpoint'

print(f"  S3 Checkpoint prefix: s3://{S3_BUCKET_NAME}/{S3_CHECKPOINT_PREFIX}")
print(f"  S3 Best checkpoint: s3://{S3_BUCKET_NAME}/{S3_BEST_CHECKPOINT_KEY}")

# Optional: Resume from S3 checkpoint if available
# RESUME_FROM_S3 = True
# S3_CHECKPOINT_TO_RESUME = f'{S3_CHECKPOINT_PREFIX}checkpoint_step_10000'  # Example checkpoint


## Start Training


In [None]:
# Import and run the training with S3 direct access
from s3_rnn_trainer import S3BrainToTextDecoder_Trainer

print("Starting training with S3 direct access...")
print(f"S3 Bucket: {args.dataset.s3_bucket}")
print(f"S3 Prefix: {args.dataset.s3_prefix}")
print(f"Output directory: {args.output_dir}")
print(f"Number of training batches: {args.num_training_batches}")
print(f"Batch size: {args.dataset.batch_size}")
print(f"Learning rate: {args.lr_max}")

# Create the S3 trainer
trainer = S3BrainToTextDecoder_Trainer(args)

# Optional: Resume from S3 checkpoint
# Uncomment the following lines to resume from a specific S3 checkpoint:
# if 'RESUME_FROM_S3' in locals() and RESUME_FROM_S3:
#     print(f"Resuming from S3 checkpoint: {S3_CHECKPOINT_TO_RESUME}")
#     start_step = trainer.resume_from_s3_checkpoint(S3_CHECKPOINT_TO_RESUME)
#     print(f"Resumed from step {start_step}")

# Run training
try:
    metrics = trainer.train()

    print("\n✅ Training completed!")
    print(f"Final metrics: {metrics}")
finally:
    # Clean up cached files even when training fails or is interrupted, so an
    # aborted run does not leave them behind
    trainer.cleanup()


## S3 Checkpoint Management


In [None]:
def list_s3_checkpoints(bucket_name, prefix):
    """List checkpoint objects stored under an S3 prefix, newest first.

    An object counts as a checkpoint when the part of its key below `prefix`
    ends with '.pt' or contains 'checkpoint'. The prefix itself is excluded
    from the match: a prefix such as '.../checkpoints/' contains the word
    'checkpoint', which would otherwise make every object under it qualify.
    Keys ending in '/' (folder placeholders) are skipped.

    Args:
        bucket_name: Name of the S3 bucket to list.
        prefix: Key prefix to list under.

    Returns:
        A list of dicts with 'key', 'size' and 'last_modified', sorted by
        'last_modified' descending. When the listing raises a ClientError,
        the error is printed and an empty list is returned.
    """
    try:
        paginator = s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(Bucket=bucket_name, Prefix=prefix)

        checkpoints = []
        for page in pages:
            for obj in page.get('Contents', []):
                key = obj['Key']
                if key.endswith('/'):
                    # Folder placeholder object, not a checkpoint file
                    continue

                # Judge only the part of the key below the listing prefix
                relative_key = key[len(prefix):] if key.startswith(prefix) else key
                if relative_key.endswith('.pt') or 'checkpoint' in relative_key:
                    checkpoints.append({
                        'key': key,
                        'size': obj['Size'],
                        'last_modified': obj['LastModified']
                    })

        return sorted(checkpoints, key=lambda x: x['last_modified'], reverse=True)

    except ClientError as e:
        print(f"Error listing S3 checkpoints: {e}")
        return []

def download_checkpoint_from_s3(bucket_name, s3_key, local_path):
    """Download a specific checkpoint from S3 to a local file.

    The destination's parent directory is created when it does not exist yet,
    so the download is not aborted by a missing local folder.

    Args:
        bucket_name: Name of the S3 bucket holding the checkpoint.
        s3_key: Key of the checkpoint object.
        local_path: Local file path to write the checkpoint to.

    Returns:
        True when the download succeeded, False when it raised a ClientError
        (the error is printed).
    """
    # Create the destination directory up front; a missing directory would
    # surface as an OS error, which the ClientError handler below does not catch.
    parent_dir = os.path.dirname(local_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    try:
        s3_client.download_file(bucket_name, s3_key, local_path)
        print(f"Downloaded checkpoint to: {local_path}")
        return True
    except ClientError as e:
        print(f"Error downloading checkpoint: {e}")
        return False

# List available checkpoints in S3
print("Available checkpoints in S3:")
checkpoints = list_s3_checkpoints(S3_BUCKET_NAME, S3_CHECKPOINT_PREFIX)

# Show the ten most recent entries and summarise anything beyond that
if not checkpoints:
    print("  No checkpoints found in S3 yet.")
    print("  Checkpoints will be saved during training.")
else:
    for position, checkpoint in enumerate(checkpoints[:10], start=1):
        print(f"  {position}. {checkpoint['key']}")
        print(f"     Size: {checkpoint['size']} bytes")
        print(f"     Modified: {checkpoint['last_modified']}")
        print()
    hidden_checkpoints = len(checkpoints) - 10
    if hidden_checkpoints > 0:
        print(f"  ... and {hidden_checkpoints} more checkpoints")


## Upload Results to S3 (Optional)


In [None]:
def upload_to_s3(local_dir, bucket_name, s3_prefix):
    """
    Upload local directory contents to S3.

    Every file below `local_dir` is uploaded with the key
    `s3_prefix + <path relative to local_dir>`, using '/' as the separator.
    `s3_prefix` is prepended verbatim, so it should end with '/' when it is
    meant to act as a folder.

    Args:
        local_dir: Directory whose files are uploaded recursively.
        bucket_name: Destination S3 bucket.
        s3_prefix: String prepended to each relative path to form the key.

    Returns:
        None. Progress and errors are printed. A missing `local_dir` is
        reported instead of being announced as a successful upload, and the
        first failed upload stops the run with an error message.
    """
    if not os.path.isdir(local_dir):
        # os.walk() on a missing directory yields nothing, which would
        # otherwise be reported as success
        print(f"❌ Local directory not found: {local_dir}")
        return

    try:
        for root, _dirs, files in os.walk(local_dir):
            for file in files:
                local_file_path = os.path.join(root, file)
                relative_path = os.path.relpath(local_file_path, local_dir)
                # S3 keys use '/' regardless of the local path separator
                s3_key = f"{s3_prefix}{relative_path.replace(os.sep, '/')}"

                print(f"Uploading: {local_file_path} -> s3://{bucket_name}/{s3_key}")
                s3_client.upload_file(local_file_path, bucket_name, s3_key)

        print("\n✅ Successfully uploaded results to S3")

    # boto3's upload_file re-raises client errors as S3UploadFailedError, which
    # is not a ClientError subclass, so both have to be handled here
    except (ClientError, boto3.exceptions.S3UploadFailedError) as e:
        print(f"❌ Error uploading to S3: {e}")

# Push everything under the training output directory to the results prefix in S3
print("Uploading training results to S3...")
upload_to_s3(
    local_dir=args.output_dir,
    bucket_name=S3_BUCKET_NAME,
    s3_prefix='training_results/baseline_rnn/',
)


## Training Summary


In [None]:
# Recap of the run: data source, output location, key hyperparameters and final metrics
summary_rule = "=" * 50

print(f"\n{summary_rule}")
print("TRAINING SUMMARY")
print(summary_rule)
print(f"Dataset: S3 bucket '{S3_BUCKET_NAME}' at '{S3_DATA_PREFIX}'")
print("Data access: Direct S3 access (no local download)")
print(f"Output directory: {args.output_dir}")
print(f"Training batches: {args.num_training_batches}")
print(f"Batch size: {args.dataset.batch_size}")
print(f"Model architecture: {args.model.n_layers} GRU layers with {args.model.n_units} units each")
print(f"Final metrics: {metrics}")
print(summary_rule)
print("✅ Training completed successfully with S3 direct access!")
print("💡 Benefits: No need to download entire dataset, faster startup, less storage usage")
print(summary_rule)
