In [None]:
# Install dependencies
!pip install omegaconf editdistance hydra-core bitarray
# Install other requirements if needed (Colab usually has torch, numpy, etc.)
# !pip install -r ../requirements.txt

In [None]:
# Download Data
import os
if not os.path.exists('../data/hdf5_data_final'):
    print("Downloading data...")
    !python ../download_data.py
else:
    print("Data already exists.")

In [None]:
# Imports
import sys
import os
import torch
import logging
from omegaconf import OmegaConf
from datetime import datetime

# Add the current directory to path so we can import modules
sys.path.append(os.getcwd())

# Import project modules
from rnn_trainer import BrainToTextDecoder_Trainer
from dataset import BrainToTextDataset, train_test_split_indicies
from unet_model import NeuralUNet
from train_ssl import get_mask

## 1. Train Supervised Model (RNN or Conformer)

Configure the training parameters below. You can switch between `rnn` and `conformer` architectures.

In [None]:
# Load default arguments
args = OmegaConf.load('rnn_args.yaml')

# --- Modify Configuration Here ---
# Architecture selector; the accepted values are 'rnn' and 'conformer'.
args['model']['type'] = 'rnn'
args['gpu_number'] = '0'
args['num_training_batches'] = 120000

# A larger batch makes fuller use of GPU memory. The default is 64; 128, 256
# or 512 are worth trying depending on the card (an 80GB GPU can likely take 256+).
args['dataset']['batch_size'] = 256

# Suggested starting point when switching to the Conformer:
# args.model.type = 'conformer'
# args.lr_max = 0.0005

# Echo the two headline settings, then the full resolved configuration.
print(f"Training Model: {args.model.type}")
print(f"Batch Size: {args.dataset.batch_size}")
config_yaml = OmegaConf.to_yaml(args)
print(config_yaml)

In [None]:
# Start Training
# Build the trainer from the `args` configuration assembled in the cell above.
trainer = BrainToTextDecoder_Trainer(args)
# Run training; whatever train() returns is kept in `metrics` for later
# inspection (its exact structure is defined in rnn_trainer, not shown here).
metrics = trainer.train()

## 2. Train Self-Supervised Model (U-Net)

Train a Masked Autoencoder (MAE) using a U-Net architecture for feature learning/denoising.

In [None]:
# SSL Hyperparameters
# Tunable constants for the U-Net self-supervised run.
EPOCHS = 50
LR = 1e-4
# The SSL batch size can be raised as well; 64 is the starting value,
# 128 or higher is worth trying.
BATCH_SIZE = 64
# Passed to get_mask() when the training loop builds each batch's mask.
MASK_RATIO = 0.5

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Every run writes checkpoints into its own timestamped directory.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
SAVE_DIR = f'trained_models/unet_ssl_{timestamp}'
os.makedirs(SAVE_DIR, exist_ok=True)

print(f"Training U-Net SSL on {DEVICE}")
print(f"Saving to {SAVE_DIR}")

In [None]:
# Prepare Dataset for SSL
# Load the default config under its own name. Re-binding `args` here would
# silently discard the overrides applied to `args` in the supervised section
# if that cell's trainer were re-created after this cell has run.
# NOTE(review): `ssl_args` is not read below (the data path is hard-coded);
# it is kept only as a reference to the default settings.
ssl_args = OmegaConf.load('rnn_args.yaml')

# Split the trials into train/test sets, holding out 10% for testing.
# NOTE(review): assumes train_test_split_indicies accepts this CSV path as
# its first argument -- confirm against dataset.py.
train_files, test_files = train_test_split_indicies(
    os.path.join('../data', 't15_copyTaskData_description.csv'),
    test_percentage=0.1
)

train_dataset = BrainToTextDataset(
    train_files,
    n_batches=200, # Smaller number for SSL iteration
    split='train',
    batch_size=BATCH_SIZE,
    days_per_batch=1
)

# batch_size=None turns off the DataLoader's automatic batching, so each item
# the dataset returns is passed through as-is (the dataset is given BATCH_SIZE
# above and presumably builds whole batches itself -- confirm in dataset.py).
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=None, num_workers=4)

# Initialize Model
model = NeuralUNet(n_channels=1, n_classes=1).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Default reduction is 'mean', i.e. calling this returns a single scalar.
criterion = torch.nn.MSELoss()

In [None]:
# Training Loop
# Masked-reconstruction training: part of each batch is zeroed out, the U-Net
# reconstructs the full input, and the loss is the mean squared error taken
# over the masked elements only.
from tqdm.notebook import tqdm

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    num_batches = 0  # batches actually consumed this epoch

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")
    for batch in pbar:
        # Data shape: (B, T, C) -- as assumed by the unsqueeze(1) below.
        neural_data = batch['neural_features'].to(DEVICE)

        # Generate mask (1 marks a position to hide, since the input is
        # multiplied by `1 - mask`).
        # NOTE(review): assumes the mask carries a singleton dim 1 and
        # broadcasts against (B, 1, T, C) -- confirm in train_ssl.get_mask.
        mask = get_mask(neural_data, MASK_RATIO, device=DEVICE)

        # Apply mask (simulate missing data).
        # In MAE the masked patches are usually replaced by a learnable token
        # or by 0; here they are simply zeroed.
        masked_input = neural_data * (1 - mask.squeeze(1))

        # Add channel dim for U-Net: (B, 1, T, C)
        masked_input = masked_input.unsqueeze(1)
        target = neural_data.unsqueeze(1)

        # Forward
        output = model(masked_input)

        # Compute loss only on masked regions.
        # The error must be kept per element (reduction='none'): an MSELoss
        # with the default 'mean' reduction returns a scalar, and
        # (scalar * mask).sum() / mask.sum() is just the unmasked MSE again.
        per_element_error = torch.nn.functional.mse_loss(
            output, target, reduction='none'
        )
        # Expand the mask to the error's shape so the normaliser counts
        # exactly the elements that contribute to the numerator.
        loss_mask = mask.to(per_element_error.dtype).expand_as(per_element_error)
        loss = (per_element_error * loss_mask).sum() / (loss_mask.sum() + 1e-6)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        pbar.set_postfix({'loss': batch_loss})

    # Average over the batches that were actually seen; this does not rely on
    # len(train_loader) and cannot divide by zero on an empty loader.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1} Average Loss: {avg_loss:.6f}")

    # Save checkpoint every 10 epochs, and always after the last epoch so the
    # final weights are kept even when EPOCHS is not a multiple of 10.
    if (epoch + 1) % 10 == 0 or (epoch + 1) == EPOCHS:
        torch.save(model.state_dict(), os.path.join(SAVE_DIR, f'unet_epoch_{epoch+1}.pth'))