# SageMaker RNN Training for Brain-to-Text Model

This notebook trains the brain-to-text RNN decoder on AWS SageMaker, reading training data from S3 and saving checkpoints back to S3.

**Uses the existing training infrastructure from the `model_training` folder.**

## Setup and Configuration


In [None]:
# Install required packages
# `%pip` is an IPython magic: it installs into the environment of the running
# kernel, so this line only works inside Jupyter/IPython, not in a plain .py file.
# torch itself is not listed here; it is presumably provided by the SageMaker
# image — TODO confirm for the kernel in use.
%pip install omegaconf h5py torchaudio boto3 sagemaker


In [None]:
import torch
import torch.nn as nn
import h5py
import numpy as np
import os
import time
import logging
import json
import pickle
import math
import random
import boto3
import tempfile
import shutil
from pathlib import Path
import sys
from omegaconf import OmegaConf
import sagemaker
from sagemaker.session import Session

# Add model_training directory to Python path
sys.path.append('model_training')

# Import existing training components
from rnn_trainer import BrainToTextDecoder_Trainer
from dataset import BrainToTextDataset, train_test_split_indicies
from rnn_model import GRUDecoder

# Set up device
# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

# Set up logging
# force=True (Python 3.8+): without it, basicConfig is a silent no-op whenever the
# root logger already has handlers (which can happen in notebooks, where the kernel
# or an imported library may have installed one). The root level would then stay
# at WARNING and every logger.info(...) in this notebook would be dropped.
logging.basicConfig(level=logging.INFO, force=True)
logger = logging.getLogger(__name__)


## S3 Configuration and Setup


In [None]:
# S3 Configuration
# Bucket and key prefixes used by all later cells. The prefixes end with '/'
# because keys are built by plain concatenation, e.g. f"{prefix}{name}".
S3_BUCKET_NAME = '4k-woody-btt'
S3_DATA_PREFIX = '4k/data/hdf5_data_final/'  # Base path for training data
# NOTE(review): a fixed checkpoint prefix means every run writes the same keys
# (best_checkpoint.pth, training_history.json, ...) and overwrites the last run.
S3_CHECKPOINT_PREFIX = 'checkpoints/'

# Initialize S3 client (later cells call it directly for uploads)
s3_client = boto3.client('s3')

# Get SageMaker session for additional utilities
# NOTE(review): sagemaker_session is not referenced anywhere else in this notebook section.
sagemaker_session = sagemaker.Session()

print(f"S3 Bucket: {S3_BUCKET_NAME}")
print(f"Data prefix: {S3_DATA_PREFIX}")
print(f"Checkpoint prefix: {S3_CHECKPOINT_PREFIX}")
# `{{date}}` is an escaped brace: it prints a literal "{date}" placeholder.
print(f"Expected data structure: s3://{S3_BUCKET_NAME}/{S3_DATA_PREFIX}{{date}}/data_train.hdf5")


## S3 Data Loading and Checkpoint Management


In [None]:
# S3 Data Loading Utilities
class S3DataLoader:
    """Discover and download per-session HDF5 files stored in S3.

    Objects are expected at ``<data_prefix><session>/<file name>``, e.g.
    ``4k/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5``.
    """

    def __init__(self, bucket_name, data_prefix):
        """
        Args:
            bucket_name: S3 bucket holding the data.
            data_prefix: Key prefix of the session "folders". A trailing '/' is
                appended when missing, because keys are built as
                f"{data_prefix}{session}/..." and the listing uses '/' as delimiter.
        """
        self.bucket_name = bucket_name
        if data_prefix and not data_prefix.endswith('/'):
            data_prefix = data_prefix + '/'
        self.data_prefix = data_prefix
        self.s3_client = boto3.client('s3')

    def list_available_dates(self):
        """Return the sorted session names (sub-prefixes starting with 't') under data_prefix.

        The listing is paginated. A single ``list_objects_v2`` call returns at most
        1,000 entries, so an unpaginated call can silently miss sessions.
        """
        paginator = self.s3_client.get_paginator('list_objects_v2')
        pages = paginator.paginate(
            Bucket=self.bucket_name,
            Prefix=self.data_prefix,
            Delimiter='/',
        )

        dates = []
        for page in pages:
            for prefix in page.get('CommonPrefixes', []):
                # "4k/data/hdf5_data_final/t15.2023.08.11/" -> "t15.2023.08.11"
                date = os.path.basename(prefix['Prefix'].rstrip('/'))
                if date.startswith('t'):  # keep only session folders
                    dates.append(date)

        return sorted(dates)

    def download_file(self, s3_key, local_path):
        """Download one S3 object to ``local_path``.

        This is best-effort: any error is logged and reported as ``False``
        instead of being raised.

        Returns:
            True on success, False on failure.
        """
        try:
            self.s3_client.download_file(self.bucket_name, s3_key, local_path)
        except Exception as e:
            logger.error(f"Failed to download {s3_key}: {str(e)}")
            return False
        logger.info(f"Downloaded {s3_key} to {local_path}")
        return True

    def download_data_to_local(self, local_data_dir, specific_dates=None,
                               file_names=('data_train.hdf5',)):
        """Download session files from S3 into ``local_data_dir/<session>/<file>``.

        Args:
            local_data_dir: Local root directory; created if missing.
            specific_dates: Session names to fetch (e.g. ['t15.2023.08.11']).
                If None, every session returned by list_available_dates() is fetched.
            file_names: File names to fetch from each session folder. The default
                keeps the original behaviour (training file only).
                NOTE(review): if the trainer also reads per-session validation
                files (e.g. 'data_val.hdf5'), pass them here — TODO confirm.

        Returns:
            (local_data_dir, downloaded_files), where downloaded_files lists the
            local paths that were downloaded successfully. Failures are logged
            and skipped, not raised.
        """
        if isinstance(file_names, str):  # guard against iterating a bare name char by char
            file_names = (file_names,)

        dates = self.list_available_dates() if specific_dates is None else list(specific_dates)

        os.makedirs(local_data_dir, exist_ok=True)
        downloaded_files = []

        logger.info(f"Found {len(dates)} training dates: {dates}")

        for date in dates:
            date_dir = os.path.join(local_data_dir, date)
            os.makedirs(date_dir, exist_ok=True)

            for file_name in file_names:
                s3_key = f"{self.data_prefix}{date}/{file_name}"
                local_path = os.path.join(date_dir, file_name)

                if self.download_file(s3_key, local_path):
                    downloaded_files.append(local_path)
                else:
                    logger.warning(f"Failed to download {file_name} for date: {date}")

        logger.info(f"Downloaded {len(downloaded_files)} files to {local_data_dir}")
        return local_data_dir, downloaded_files


In [None]:
# S3 Checkpoint Manager
class S3CheckpointManager:
    """Serialize model/optimizer state with ``torch.save`` and store it in S3, and load it back.

    Keys are ``<checkpoint_prefix>best_checkpoint.pth`` or
    ``<checkpoint_prefix>checkpoint_epoch_<epoch>.pth``.
    """

    def __init__(self, bucket_name, checkpoint_prefix):
        self.bucket_name = bucket_name
        self.checkpoint_prefix = checkpoint_prefix
        self.s3_client = boto3.client('s3')

    @staticmethod
    def _to_builtin(value):
        """Recursively convert numpy/torch scalars to plain Python numbers.

        Since PyTorch 2.6, ``torch.load`` defaults to ``weights_only=True``, which
        refuses to unpickle numpy scalar objects. The trainer's loss/metric values
        may be numpy scalars (TODO confirm), so they are normalised before saving.
        Dicts, lists and tuples are walked. Multi-element arrays/tensors and other
        objects are returned unchanged.
        """
        if isinstance(value, dict):
            return {k: S3CheckpointManager._to_builtin(v) for k, v in value.items()}
        if isinstance(value, list):
            return [S3CheckpointManager._to_builtin(v) for v in value]
        if isinstance(value, tuple):
            return tuple(S3CheckpointManager._to_builtin(v) for v in value)
        item = getattr(value, 'item', None)
        if callable(item):
            try:
                return item()
            except (ValueError, RuntimeError):
                # numpy (ValueError) / torch (RuntimeError) refuse .item() on multi-element data
                return value
        return value

    @staticmethod
    def _make_temp_path(suffix='.pth'):
        """Create an empty temp file, close its OS handle, and return the path.

        The handle is closed up front so torch.save / boto3 can reopen the path by
        name, which is not possible on every platform while it is still open.
        """
        fd, path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        return path

    def save_checkpoint(self, model, optimizer, epoch, loss, metrics=None, is_best=False):
        """Save model checkpoint to S3.

        Args:
            model / optimizer: Objects whose ``state_dict()`` is saved.
            epoch: Stored in the checkpoint and used in the key when not is_best.
            loss: Stored under 'loss'.
            metrics: Optional dict stored under 'metrics' (``{}`` if None).
            is_best: If True, write ``best_checkpoint.pth`` (overwriting any previous
                one); otherwise write ``checkpoint_epoch_<epoch>.pth``.

        Returns:
            True if the upload succeeded, False if it failed (the failure is logged).
            Errors from ``torch.save`` propagate, as before.
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': self._to_builtin(loss),
            'metrics': self._to_builtin(metrics or {}),
        }

        if is_best:
            s3_key = f"{self.checkpoint_prefix}best_checkpoint.pth"
        else:
            s3_key = f"{self.checkpoint_prefix}checkpoint_epoch_{epoch}.pth"

        tmp_path = self._make_temp_path()
        try:
            torch.save(checkpoint, tmp_path)
            try:
                self.s3_client.upload_file(tmp_path, self.bucket_name, s3_key)
            except Exception as e:
                logger.error(f"Failed to upload checkpoint: {str(e)}")
                return False
        finally:
            # Always remove the local copy, including when torch.save raised.
            os.unlink(tmp_path)

        logger.info(f"Saved checkpoint to s3://{self.bucket_name}/{s3_key}")
        return True

    def load_checkpoint(self, s3_key, model, optimizer=None):
        """Load model checkpoint from S3 into ``model`` (and ``optimizer`` if given).

        Tensors are mapped onto the module-level ``device``.

        Returns:
            The full checkpoint dict on success, or None on any failure (download,
            unpickling, or state-dict mismatch). Failures are logged, not raised.
        """
        tmp_path = self._make_temp_path()
        try:
            self.s3_client.download_file(self.bucket_name, s3_key, tmp_path)
            checkpoint = torch.load(tmp_path, map_location=device)

            model.load_state_dict(checkpoint['model_state_dict'])

            if optimizer is not None and 'optimizer_state_dict' in checkpoint:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        except Exception as e:
            logger.error(f"Failed to load checkpoint: {str(e)}")
            return None
        finally:
            os.unlink(tmp_path)

        logger.info(f"Loaded checkpoint from s3://{self.bucket_name}/{s3_key}")
        return checkpoint


In [None]:
# Enhanced Trainer with S3 Checkpoint Support
class S3EnhancedTrainer(BrainToTextDecoder_Trainer):
    """BrainToTextDecoder_Trainer that can push model/optimizer checkpoints to S3.

    NOTE(review): ``save_checkpoint`` is a method added by this subclass. Confirm
    that the parent trainer's loop actually calls a method with this name. If it
    uses a different checkpoint hook, only explicit calls (such as the one after
    ``train()``) reach S3.
    """

    def __init__(self, args, s3_checkpoint_manager):
        super().__init__(args)
        # Uploader used by save_checkpoint (an S3CheckpointManager-like object)
        self.s3_checkpoint_manager = s3_checkpoint_manager

    def save_checkpoint(self, epoch, is_best=False):
        """Upload the current model/optimizer state to S3.

        Args:
            epoch: Progress counter stored in the checkpoint and used in its key.
            is_best: Forwarded to the manager; selects the "best" key.

        Returns:
            The manager's success flag, or False when the model/optimizer
            attributes do not exist yet.
        """
        # Guard clause: nothing to save before the parent has built model + optimizer
        if not (hasattr(self, 'model') and hasattr(self, 'optimizer')):
            self.logger.warning("Model or optimizer not initialized, cannot save checkpoint")
            return False

        summary = dict(
            epoch=epoch,
            best_val_loss=self.best_val_loss,
            best_val_PER=self.best_val_PER,
        )

        uploaded = self.s3_checkpoint_manager.save_checkpoint(
            self.model,
            self.optimizer,
            epoch,
            self.best_val_loss,  # stored as the checkpoint's 'loss'
            summary,
            is_best,
        )

        if not uploaded:
            self.logger.error("Failed to save checkpoint to S3")
        else:
            self.logger.info("Checkpoint saved to S3 successfully")

        return uploaded


## Load Training Configuration


In [None]:
# Load the existing training configuration (OmegaConf YAML shipped in model_training/)
config_path = 'model_training/rnn_args.yaml'
args = OmegaConf.load(config_path)

# Redirect every path to local /tmp locations for this environment.
# dataset_dir is overwritten again once the S3 download cell has run.
args.dataset.dataset_dir = '/tmp/hdf5_data_final'
_model_root = '/tmp/trained_models/baseline_rnn'
args.output_dir = _model_root
args.checkpoint_dir = f"{_model_root}/checkpoint"  # nested inside output_dir

# Create both local output directories (no-op when they already exist)
for _local_dir in (args.output_dir, args.checkpoint_dir):
    os.makedirs(_local_dir, exist_ok=True)

print("Training Configuration:")
print(f"Dataset directory: {args.dataset.dataset_dir}")
print(f"Output directory: {args.output_dir}")
print(f"Checkpoint directory: {args.checkpoint_dir}")
print(f"Number of training batches: {args.num_training_batches}")
print(f"Number of sessions: {len(args.dataset.sessions)}")
print(f"Sessions: {args.dataset.sessions[:5]}...")  # preview of the first 5 sessions


## Download Data from S3


In [None]:
# Initialize S3 data loader
s3_data_loader = S3DataLoader(S3_BUCKET_NAME, S3_DATA_PREFIX)

# List available training dates
available_dates = s3_data_loader.list_available_dates()
print(f"Found {len(available_dates)} training dates in S3:")
for date in available_dates[:10]:  # Show first 10
    print(f"  - {date}")
if len(available_dates) > 10:
    print(f"  ... and {len(available_dates) - 10} more")

# Filter dates to match the sessions in the config
config_sessions = set(args.dataset.sessions)
available_sessions = [date for date in available_dates if date in config_sessions]

print(f"\nMatching sessions in S3: {len(available_sessions)}")
print(f"Config sessions: {len(config_sessions)}")

# Sessions listed in the config but absent from S3 were previously dropped
# silently. args.dataset.sessions is left unchanged, so the trainer will
# presumably still try to open them (TODO confirm). Surface them here instead
# of failing later inside training.
available_set = set(available_sessions)
missing_sessions = [s for s in args.dataset.sessions if s not in available_set]
if missing_sessions:
    logger.warning(
        f"{len(missing_sessions)} configured session(s) not found under "
        f"s3://{S3_BUCKET_NAME}/{S3_DATA_PREFIX}: {missing_sessions}"
    )

# Download data to local directory structure
local_data_dir = '/tmp/hdf5_data_final'
downloaded_dir, downloaded_files = s3_data_loader.download_data_to_local(
    local_data_dir,
    available_sessions
)

# download_data_to_local logs individual failures but never raises, so compare
# counts to catch a partial download before training starts.
if len(downloaded_files) < len(available_sessions):
    logger.warning(
        f"Only {len(downloaded_files)} of {len(available_sessions)} session files "
        f"were downloaded; see the errors above"
    )

# Update the dataset directory in config
args.dataset.dataset_dir = local_data_dir

print(f"\nDownloaded {len(downloaded_files)} files to: {downloaded_dir}")
print(f"Updated dataset directory: {args.dataset.dataset_dir}")


## Initialize Training


In [None]:
# Initialize S3 checkpoint manager
s3_checkpoint_manager = S3CheckpointManager(S3_BUCKET_NAME, S3_CHECKPOINT_PREFIX)

# Initialize the enhanced trainer with S3 checkpoint support.
# This runs the parent trainer's __init__ with the modified args. It presumably
# reads args.dataset.dataset_dir at this point, so the download cell must run
# first (TODO confirm against rnn_trainer).
trainer = S3EnhancedTrainer(args, s3_checkpoint_manager)

print("Trainer initialized successfully!")
print(f"Model device: {trainer.device}")
# Total element count over all parameter tensors (trainable or not)
print(f"Number of model parameters: {sum(p.numel() for p in trainer.model.parameters())}")
print(f"Training batches: {args.num_training_batches}")
print(f"Validation frequency: every {args.batches_per_val_step} batches")


In [None]:
# Start training
logger.info("Starting training...")
start_time = time.time()

try:
    # Only the training call is guarded, so "Training failed" is logged only
    # for failures that actually happen during training.
    metrics = trainer.train()
except Exception as e:
    # logger.exception also records the traceback, which logger.error dropped.
    logger.exception(f"Training failed: {str(e)}")
    # A bare `raise` re-raises the active exception with its traceback untouched.
    raise

training_time = time.time() - start_time
logger.info(f"Training completed in {training_time:.2f} seconds ({training_time/3600:.2f} hours)")

# Save final checkpoint to S3 (S3CheckpointManager key: checkpoint_epoch_<num_training_batches>.pth)
final_uploaded = trainer.save_checkpoint(args.num_training_batches, is_best=False)
if not final_uploaded:
    # Previously ignored: the cell reported success even if nothing reached S3.
    logger.warning("Final checkpoint was NOT uploaded to S3 - check the errors above before deleting local files")

print("✅ Training completed successfully!")
print(f"Final metrics: {metrics}")


## Save Training History and Cleanup


In [None]:
# Save training history to S3
training_history = {
    'training_time': training_time,
    'final_metrics': metrics,
    'config': OmegaConf.to_yaml(args),
    'sessions_used': available_sessions,
    'checkpoints_saved': True
}


def _json_default(obj):
    """Fallback for json.dumps: numpy/torch values -> Python lists/scalars, anything else -> str.

    The structure of `metrics` (returned by trainer.train()) is not visible in this
    notebook. It presumably may hold numpy arrays or scalars (TODO confirm), which
    made json.dump raise TypeError.
    """
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    return str(obj)


try:
    # Serialize in memory and upload in one call. The previous version had two defects:
    # - it uploaded the temp file while its handle was still open and unflushed,
    #   so S3 could receive an empty or truncated file;
    # - it leaked that file whenever json.dump raised.
    s3_client.put_object(
        Bucket=S3_BUCKET_NAME,
        Key=f"{S3_CHECKPOINT_PREFIX}training_history.json",
        Body=json.dumps(training_history, indent=2, default=_json_default).encode('utf-8'),
    )
    logger.info("Training history saved to S3")
except Exception as e:
    logger.error(f"Failed to save training history: {str(e)}")

# Clean up local data directory (the data itself stays in S3)
try:
    shutil.rmtree(local_data_dir)
    logger.info(f"Cleaned up local data directory: {local_data_dir}")
except Exception as e:
    logger.warning(f"Failed to clean up local data directory: {str(e)}")


def _upload_dir_to_s3(local_dir, s3_prefix):
    """Upload every file under local_dir to S3_BUCKET_NAME at s3_prefix + relative path.

    Returns the number of files uploaded. Upload errors propagate to the caller.
    """
    root = Path(local_dir)
    count = 0
    for path in sorted(root.rglob('*')):
        if path.is_file():
            key = f"{s3_prefix}{path.relative_to(root).as_posix()}"
            s3_client.upload_file(str(path), S3_BUCKET_NAME, key)
            count += 1
    return count


# Clean up local model directories.
# Only explicit S3EnhancedTrainer.save_checkpoint calls reach S3. Anything else the
# parent trainer wrote locally (e.g. into args.checkpoint_dir) is backed up before
# deletion; a directory is kept if its backup fails.
# args.checkpoint_dir is nested inside args.output_dir in this config, so it is
# normally already gone when the loop reaches it. The original unconditional
# second rmtree therefore always raised and logged a spurious warning.
for model_dir in (args.output_dir, args.checkpoint_dir):
    if not os.path.isdir(model_dir):
        continue
    backup_prefix = f"{S3_CHECKPOINT_PREFIX}local_outputs/{os.path.basename(os.path.normpath(model_dir))}/"
    try:
        n_files = _upload_dir_to_s3(model_dir, backup_prefix)
        shutil.rmtree(model_dir)
        logger.info(
            f"Backed up {n_files} file(s) to s3://{S3_BUCKET_NAME}/{backup_prefix} "
            f"and removed {model_dir}"
        )
    except Exception as e:
        logger.warning(f"Failed to back up/clean up {model_dir}: {str(e)}; leaving it in place")


## Training Summary


In [None]:
# Display training summary
# Reads values produced by earlier cells: training_time (training cell),
# available_sessions (download cell), and the trainer's best_val_* attributes.
# NOTE(review): the :.4f formats require best_val_loss / best_val_PER to be
# numeric. They are attributes of the parent trainer not shown here — confirm.
print("\n" + "="*60)
print("TRAINING SUMMARY")
print("="*60)
print(f"Total training time: {training_time:.2f} seconds ({training_time/3600:.2f} hours)")
print(f"Training batches completed: {args.num_training_batches}")
print(f"Sessions used: {len(available_sessions)}")
print(f"Best validation loss: {trainer.best_val_loss:.4f}")
print(f"Best validation PER: {trainer.best_val_PER:.4f}")
print(f"\nS3 Storage:")
print(f"  Data source: s3://{S3_BUCKET_NAME}/{S3_DATA_PREFIX}")
print(f"  Checkpoints saved to: s3://{S3_BUCKET_NAME}/{S3_CHECKPOINT_PREFIX}")
print(f"  Training history: s3://{S3_BUCKET_NAME}/{S3_CHECKPOINT_PREFIX}training_history.json")
print("="*60)


## Optional: Load and Test a Checkpoint


In [None]:
# Example: Load the best checkpoint (falling back to the final one).
# NOTE: this overwrites trainer.model / trainer.optimizer with the loaded state.
# Nothing in this notebook writes best_checkpoint.pth unless the parent trainer
# itself calls save_checkpoint(..., is_best=True). The training cell only uploads
# checkpoint_epoch_<num_training_batches>.pth, so that key is the fallback.
candidate_keys = [
    f"{S3_CHECKPOINT_PREFIX}best_checkpoint.pth",
    f"{S3_CHECKPOINT_PREFIX}checkpoint_epoch_{args.num_training_batches}.pth",
]


def _s3_key_exists(key):
    """Return True if `key` exists in S3_BUCKET_NAME (HEAD request, no download)."""
    try:
        s3_client.head_object(Bucket=S3_BUCKET_NAME, Key=key)
        return True
    except s3_client.exceptions.ClientError:
        # 404 for a missing key (or 403 without list permission): treat as unavailable
        return False


try:
    checkpoint, loaded_key = None, None
    for key in candidate_keys:
        if not _s3_key_exists(key):
            continue
        checkpoint = s3_checkpoint_manager.load_checkpoint(
            key,
            trainer.model,
            trainer.optimizer
        )
        if checkpoint is not None:
            loaded_key = key
            break

    if checkpoint is not None:
        loss = checkpoint['loss']
        try:
            loss_text = f"{loss:.4f}"
        except (TypeError, ValueError):
            # e.g. loss stored as None or a non-scalar
            loss_text = str(loss)
        print(f"Successfully loaded checkpoint {loaded_key}:")
        print(f"  Epoch: {checkpoint['epoch']}")
        print(f"  Loss: {loss_text}")
        print(f"  Metrics: {checkpoint['metrics']}")
    else:
        print(f"Failed to load checkpoint (tried: {candidate_keys})")

except Exception as e:
    print(f"Error loading checkpoint: {str(e)}")


## Save Final Model and Training History


In [None]:
# Save final model and training history.
# NOTE(review): the original version of this cell referenced names never defined
# in this notebook (config, model, optimizer, checkpoint_manager, best_val_loss,
# total_time, temp_data_dir). It also used history keys ('train_loss', 'val_loss')
# that the training_history dict built earlier does not contain, so it always
# failed with NameError/KeyError. It is rewritten against the objects the earlier
# cells create, and it is safe to re-run.

final_checkpoint_key = f"{S3_CHECKPOINT_PREFIX}checkpoint_epoch_{args.num_training_batches}.pth"

try:
    # The training cell already uploads this key. If it is present, do not
    # overwrite it: the optional load cell may have replaced trainer.model with
    # different (e.g. best) weights.
    s3_client.head_object(Bucket=S3_BUCKET_NAME, Key=final_checkpoint_key)
    logger.info(f"Final model already in S3: s3://{S3_BUCKET_NAME}/{final_checkpoint_key}")
except s3_client.exceptions.ClientError:
    final_metrics = {
        'epoch': args.num_training_batches,
        'best_val_loss': trainer.best_val_loss,
        'best_val_PER': trainer.best_val_PER,
        'total_training_time': training_time,
    }

    success = s3_checkpoint_manager.save_checkpoint(
        trainer.model, trainer.optimizer, args.num_training_batches,
        trainer.best_val_loss, final_metrics, False
    )

    if success:
        logger.info("Final model saved successfully")
    else:
        logger.error("Failed to save final model")


def _history_json_default(obj):
    """Fallback for json.dumps: numpy/torch values -> Python lists/scalars, anything else -> str."""
    if isinstance(obj, (np.ndarray, np.generic)):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        return obj.detach().cpu().tolist()
    return str(obj)


# Save training history to S3. It is serialized in memory, so nothing is uploaded
# from an unflushed temp file and no temp file can leak.
try:
    s3_client.put_object(
        Bucket=S3_BUCKET_NAME,
        Key=f"{S3_CHECKPOINT_PREFIX}training_history.json",
        Body=json.dumps(training_history, indent=2, default=_history_json_default).encode('utf-8'),
    )
    logger.info("Training history saved to S3")
except Exception as e:
    logger.error(f"Failed to save training history: {str(e)}")

# Clean up the local data directory if the earlier cleanup cell has not already done so
if os.path.isdir(local_data_dir):
    shutil.rmtree(local_data_dir)
    logger.info(f"Cleaned up temporary directory: {local_data_dir}")


## Training Summary
