# Brain-to-Text '25 - GPU Training
This notebook trains a Conformer-based CTC phoneme decoder with k-fold cross-validation over the pooled `data_train.hdf5` and `data_val.hdf5` files of every session.

**Instructions:**
1. Upload your data to Google Drive
2. Mount Google Drive and set `CFG.DATA_DIR` to the folder that contains one sub-folder per session
3. Run all cells


In [1]:
# Install packages that may be missing from the runtime:
#   jiwer -> word error rate in validate(); scipy -> gaussian_filter1d smoothing.
# NOTE(review): h5py, scikit-learn, torch, tqdm and matplotlib are imported in the
# next cell but not installed here -- presumably preinstalled (e.g. on Colab); confirm.
!pip install -q jiwer scipy

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
import os
import gc
import h5py
import torch
import numpy as np
import jiwer
from torch import nn
from torch.utils.data import Dataset, DataLoader, Subset
from torch.nn import functional as F
import torch.nn.utils.rnn as rnn_utils
from torch.utils.checkpoint import checkpoint
from torch.amp import autocast, GradScaler
from tqdm.auto import tqdm
from scipy.ndimage import gaussian_filter1d
from sklearn.model_selection import KFold
from collections import defaultdict
import matplotlib.pyplot as plt

# Let the CUDA caching allocator use expandable segments to limit fragmentation.
# Set before any CUDA allocation happens in this process.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

class CFG:
    """Notebook-wide settings, read as class attributes (``CFG.LR`` etc.)."""

    # Root folder holding one sub-folder per session with the HDF5 files.
    DATA_DIR = '/content/drive/MyDrive/hdf5_data_final'

    # Optimisation schedule
    N_FOLDS = 4            # k for the KFold split over HDF5 files
    EPOCHS = 10            # training epochs per fold
    LR = 0.0008            # AdamW lr and OneCycleLR max_lr
    BATCH_SIZE = 8         # micro-batch size for each forward/backward
    GRAD_ACCUM_STEPS = 4   # micro-batches per optimizer step -> 32 samples/step

    # Network shape
    INPUT_DIM = 512        # feature channels per time step
    ENCODER_DIM = 256      # width of the Conformer stack
    N_LAYERS = 6           # Conformer blocks (activation-checkpointed)
    N_HEAD = 4             # attention heads per block
    OUTPUT_DIM = 41        # CTC classes: blank (id 0) + 40 tokens

    # Per-trial preprocessing
    SMOOTHING_SIGMA = 20   # Gaussian sigma along time; 0 disables smoothing
    ROBUST_SCALING = True  # median / IQR normalisation per channel

    # Compute device
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Running on {CFG.DEVICE}")

Running on cpu


In [None]:
# --- Helper Functions ---
def gaussian_smoothing(data, sigma=20):
    """Smooth ``data`` along its first axis with a 1-D Gaussian kernel.

    Args:
        data: array whose axis 0 is treated as time; every other axis is
            filtered independently.
        sigma: kernel standard deviation, in samples along axis 0.

    Returns:
        Array of the same shape as ``data`` (scipy's default ``reflect``
        boundary handling is used).
    """
    smoothed = gaussian_filter1d(data, sigma, axis=0)
    return smoothed

def robust_scale_trial(data):
    """
    Scale data per channel using the median and interquartile range (IQR).

    Median/IQR statistics are far less sensitive to high-amplitude artifacts
    and spikes than mean/std.

    Args:
        data: array with time on axis 0; statistics are computed over axis 0,
            so a 2-D input is scaled per column and a 1-D input as one channel.

    Returns:
        ``(data - median) / iqr`` with the same shape as ``data``. Channels
        whose IQR is zero (e.g. constant channels) are divided by 1 instead.
    """
    # Per-channel statistics across time.
    median = np.median(data, axis=0)
    q75, q25 = np.percentile(data, [75, 25], axis=0)
    iqr = q75 - q25
    # Avoid division by zero. np.where also works when iqr is a numpy scalar
    # (1-D input), where in-place boolean-mask assignment would raise.
    iqr = np.where(iqr == 0, 1.0, iqr)
    return (data - median) / iqr

# --- Dataset Class ---
class BrainDataset(Dataset):
    """Trial-level dataset over one or more HDF5 files.

    Every top-level key of each file is indexed as one trial. ``__getitem__``
    returns ``(features, targets, session_id)``; features are smoothed and/or
    robust-scaled according to ``CFG`` and returned as a float32 tensor.

    Args:
        hdf5_files: paths to HDF5 files. The session name of a file is the
            name of its parent directory.
        session_mapping: dict from session name to integer session id.
            Sessions missing from the mapping fall back to id 0.
        is_test: when True, trials without a ``seq_class_ids`` dataset yield
            an empty LongTensor as targets instead of raising.
        train_mode: stored on the instance; not read by this class.
    """

    def __init__(self, hdf5_files, session_mapping, is_test=False, train_mode=True):
        self.files = []
        self.session_ids = []
        self.keys = []

        # Index every trial once up front. Each file is closed right after
        # scanning, so the object holds no open handle when workers start.
        for f_path in hdf5_files:
            sess_name = os.path.basename(os.path.dirname(f_path))
            # NOTE: unknown sessions silently share id 0 with the first session.
            sid = session_mapping.get(sess_name, 0)
            with h5py.File(f_path, 'r') as f:
                for k in sorted(f.keys()):
                    self.files.append(f_path)
                    self.session_ids.append(sid)
                    self.keys.append(k)

        self.is_test = is_test
        self.train_mode = train_mode
        # Per-process cache of open h5py.File handles, filled lazily.
        self.opened_files = {}

    def __len__(self):
        return len(self.keys)

    def __getstate__(self):
        # Open h5py handles cannot be pickled (e.g. when DataLoader workers
        # are spawned); ship only the index and let each process reopen files.
        state = self.__dict__.copy()
        state['opened_files'] = {}
        return state

    def _open(self, f_path):
        """Return a cached read-only handle for ``f_path``, opening it on first use."""
        handle = self.opened_files.get(f_path)
        if handle is None:
            handle = h5py.File(f_path, 'r')
            self.opened_files[f_path] = handle
        return handle

    def __getitem__(self, idx):
        f_path = self.files[idx]
        key = self.keys[idx]
        grp = self._open(f_path)[key]

        # 1. Load input features.
        x = grp['input_features'][:]

        # 2. Signal processing, controlled by CFG.
        if CFG.SMOOTHING_SIGMA > 0:
            x = gaussian_smoothing(x, CFG.SMOOTHING_SIGMA)
        if CFG.ROBUST_SCALING:
            x = robust_scale_trial(x)

        x = torch.tensor(x, dtype=torch.float32)

        # 3. Targets. Unlabeled trials are only tolerated in test mode.
        if 'seq_class_ids' in grp:
            y = torch.tensor(grp['seq_class_ids'][:], dtype=torch.long)
        elif self.is_test:
            y = torch.empty(0, dtype=torch.long)
        else:
            raise KeyError(f"trial {key!r} in {f_path} has no 'seq_class_ids'")

        return x, y, self.session_ids[idx]

    def close(self):
        """Close every HDF5 handle opened by this process."""
        for handle in self.opened_files.values():
            handle.close()
        self.opened_files.clear()

def collate_fn(batch):
    """Pad a list of ``(features, targets, session_id)`` samples into tensors.

    Targets longer than their input are clipped to the input length, because
    CTC needs ``target_len <= input_len``. Both features and targets are
    right-padded with 0 (``pad_sequence`` default); only the first
    ``x_lens[i]`` / ``y_lens[i]`` entries of each row are meaningful.

    Returns:
        ``(padded_x, padded_y, x_lens, y_lens, sids)``, with the batch on dim 0
        and lengths and session ids as int64 tensors.
    """
    inputs, targets, session_ids = zip(*batch)

    clipped_targets = [
        tgt[:len(inp)] if len(tgt) > len(inp) else tgt
        for inp, tgt in zip(inputs, targets)
    ]
    inputs = list(inputs)

    input_lengths = torch.tensor(list(map(len, inputs)), dtype=torch.long)
    target_lengths = torch.tensor(list(map(len, clipped_targets)), dtype=torch.long)
    session_tensor = torch.tensor(session_ids, dtype=torch.long)

    return (
        rnn_utils.pad_sequence(inputs, batch_first=True),
        rnn_utils.pad_sequence(clipped_targets, batch_first=True),
        input_lengths,
        target_lengths,
        session_tensor,
    )

# --- Loader Utility ---
def get_all_filepaths(data_dir):
    """Collect training HDF5 files and the session index under ``data_dir``.

    Args:
        data_dir: root folder; each immediate sub-directory is one session.

    Returns:
        ``(files, sess_map, n_sessions)`` where ``files`` is the sorted list of
        every ``data_train.hdf5`` / ``data_val.hdf5`` found anywhere below
        ``data_dir`` (both are used for cross-validation), ``sess_map`` maps
        each immediate sub-directory name to its index in sorted order, and
        ``n_sessions`` is the number of such sub-directories.
    """
    wanted = {'data_train.hdf5', 'data_val.hdf5'}
    sessions = sorted(
        d for d in os.listdir(data_dir) if os.path.isdir(os.path.join(data_dir, d))
    )
    sess_map = {name: i for i, name in enumerate(sessions)}

    train_files = [
        os.path.join(root, f)
        for root, _dirs, files in os.walk(data_dir)
        for f in files
        if f in wanted
    ]
    # os.walk/listdir order is filesystem-dependent; sort so that KFold with a
    # fixed random_state yields the same folds on every machine and run.
    train_files.sort()

    return train_files, sess_map, len(sessions)

RecursionError: maximum recursion depth exceeded

In [None]:
# --- Model Components ---
class SubjectAdapter(nn.Module):
    """Per-session affine transform applied to the input features.

    Every session index owns a ``dim x dim`` matrix (initialised to identity)
    and a ``dim`` bias (initialised to zero), so the adapter starts as a no-op.

    Args:
        dim: feature size of the last input dimension.
        n_sess: number of sessions (rows of ``w`` and ``b``).
    """

    def __init__(self, dim, n_sess):
        super().__init__()
        identities = [torch.eye(dim) for _session in range(n_sess)]
        self.w = nn.Parameter(torch.stack(identities))
        self.b = nn.Parameter(torch.zeros(n_sess, dim))

    def forward(self, x, sids):
        # x: [B, T, D]; sids: one integer session index per batch element.
        weight = self.w[sids]            # [B, D, D]
        bias = self.b[sids][:, None, :]  # [B, 1, D], broadcast over time
        return torch.bmm(x, weight) + bias

class ConformerBlock(nn.Module):
    """Encoder block: half-step feed-forward, self-attention, then convolution.

    Each sub-module is pre-LayerNorm with a residual connection. Unlike the
    block in the Conformer paper, there is a single (half-weighted)
    feed-forward module and no final LayerNorm.

    Args:
        dim: model width (last dimension of the input).
        n_head: number of attention heads (must divide ``dim``).
        drop: dropout probability used throughout the block.
    """

    def __init__(self, dim, n_head, drop=0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(dim)
        # Standard multi-head self-attention. With need_weights=False,
        # PyTorch 2.x may dispatch to its fused scaled-dot-product kernels.
        self.attn = nn.MultiheadAttention(dim, n_head, dropout=drop, batch_first=True)
        self.ln2 = nn.LayerNorm(dim)
        self.ff1 = nn.Sequential(
            nn.Linear(dim, dim*4), nn.SiLU(), nn.Dropout(drop), nn.Linear(dim*4, dim)
        )
        self.ln3 = nn.LayerNorm(dim)
        # Conv module: pointwise expand + GLU, depthwise conv (kernel 31,
        # length-preserving), BatchNorm, SiLU, pointwise projection.
        self.conv = nn.Sequential(
            nn.Conv1d(dim, dim*2, 1), nn.GLU(dim=1),
            nn.Conv1d(dim, dim, 31, padding=15, groups=dim), # Depthwise
            nn.BatchNorm1d(dim), nn.SiLU(),
            nn.Conv1d(dim, dim, 1), nn.Dropout(drop)
        )
        self.drop = nn.Dropout(drop)

    def forward(self, x, key_padding_mask=None):
        """Apply the block.

        Args:
            x: tensor of shape [B, T, dim].
            key_padding_mask: optional bool tensor [B, T], True at padded
                frames (same convention as ``nn.MultiheadAttention``). When
                given, padded frames are ignored as attention keys and are
                zeroed before the depthwise conv, so they cannot leak into
                valid frames. NOTE: in training mode, BatchNorm batch
                statistics still include padded positions.

        Returns:
            Tensor of shape [B, T, dim].
        """
        # 1. Feed-forward (Macaron-style half step).
        x = x + 0.5 * self.ff1(self.ln1(x))

        # 2. Self-attention.
        res = x
        x_norm = self.ln2(x)
        att_out, _ = self.attn(
            x_norm, x_norm, x_norm,
            key_padding_mask=key_padding_mask, need_weights=False,
        )
        x = res + self.drop(att_out)

        # 3. Convolution over time (channels-first for Conv1d).
        res = x
        h = self.ln3(x).transpose(1, 2)  # [B, D, T]
        if key_padding_mask is None:
            h = self.conv(h)
        else:
            # Pointwise conv + GLU, then blank out padded frames so the
            # depthwise conv sees zeros there, as it would past a sequence end.
            h = self.conv[1](self.conv[0](h))
            h = h.masked_fill(key_padding_mask.unsqueeze(1), 0.0)
            h = self.conv[2:](h)
        x = res + h.transpose(1, 2)  # back to [B, T, D]
        return x

class BrainToTextModel(nn.Module):
    """Session adapter -> linear projection -> Conformer stack -> CTC head.

    Args:
        num_sessions: number of per-session adapters to allocate; the ``sids``
            given to ``forward`` index into these.
    """

    def __init__(self, num_sessions):
        super().__init__()
        # Submodules are created in a fixed order (adapter, inp_proj, layers,
        # out_proj); the attribute names are the state_dict keys of saved models.
        self.adapter = SubjectAdapter(CFG.INPUT_DIM, num_sessions)
        self.inp_proj = nn.Linear(CFG.INPUT_DIM, CFG.ENCODER_DIM)

        blocks = [ConformerBlock(CFG.ENCODER_DIM, CFG.N_HEAD) for _layer in range(CFG.N_LAYERS)]
        self.layers = nn.ModuleList(blocks)

        self.out_proj = nn.Linear(CFG.ENCODER_DIM, CFG.OUTPUT_DIM)

    def forward(self, x, sids):
        """Map input features to per-frame log-probabilities.

        Args:
            x: input features, presumably [B, T, CFG.INPUT_DIM]; the adapter
                uses a batched matmul, so it must be 3-D.
            sids: integer session ids, one per batch element.

        Returns:
            log_softmax over dim 2 (size CFG.OUTPUT_DIM) for every frame.
        """
        hidden = self.inp_proj(self.adapter(x, sids))
        # Activation checkpointing: block activations are recomputed during
        # backward instead of being stored, trading compute for memory.
        # NOTE(review): recomputation re-runs each block's BatchNorm in train
        # mode; confirm whether its running statistics get updated twice per step.
        for block in self.layers:
            hidden = checkpoint(block, hidden, use_reentrant=False)
        logits = self.out_proj(hidden)
        return F.log_softmax(logits, dim=2)

In [None]:
import math

VOCAB = ['AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER',
         'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW',
         'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y', 'Z', 'ZH', '|']
# Class id -> token. Id 0 is the CTC blank and decodes to the empty string.
TOKEN_MAP = {i + 1: phoneme for i, phoneme in enumerate(VOCAB)}
TOKEN_MAP[0] = ""


def _logaddexp(a, b):
    """Return log(exp(a) + exp(b)) for Python floats, handling -inf."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    hi, lo = (a, b) if a >= b else (b, a)
    return hi + math.log1p(math.exp(lo - hi))


def beam_search_decode(log_probs, beam_width=5, prune_margin=20.0):
    """
    Pure Python CTC prefix beam search for validation.

    Args:
        log_probs: array of shape [T, C] with per-frame log-probabilities;
            class 0 is the CTC blank.
        beam_width: number of prefixes kept after each frame.
        prune_margin: prefixes whose total log-probability is more than this
            below the best prefix of the previous frame are not extended. The
            threshold is relative, so long or low-confidence inputs can never
            prune every beam (an absolute cut-off can, leaving no beam at all).

    Returns:
        Space-separated tokens of the most probable prefix ("" if it is empty,
        e.g. when T == 0).
    """
    neg_inf = -math.inf
    # float64 Python scalars: faster per-element access than numpy scalars and
    # no float32 accumulation of long log-probability sums.
    frames = np.asarray(log_probs, dtype=np.float64)
    T, C = frames.shape
    # Beam: {prefix_tuple: (log P ending in blank, log P ending in non-blank)}
    beams = {(): (0.0, neg_inf)}

    for t in range(T):
        frame = frames[t].tolist()
        p_blank = frame[0]
        best_total = max(_logaddexp(pb, pnb) for pb, pnb in beams.values())
        next_beams = defaultdict(lambda: (neg_inf, neg_inf))

        for prefix, (p_b, p_nb) in beams.items():
            p_total = _logaddexp(p_b, p_nb)
            if p_total < best_total - prune_margin:
                continue  # far below the current best path

            # Blank: prefix unchanged, now ends in blank.
            n_pb, n_pnb = next_beams[prefix]
            next_beams[prefix] = (_logaddexp(n_pb, p_total + p_blank), n_pnb)

            end_t = prefix[-1] if prefix else None
            for c in range(1, C):
                p_c = frame[c]
                new_prefix = prefix + (c,)
                n_pb, n_pnb = next_beams[new_prefix]

                if c == end_t:
                    # Repeated token: without a blank in between it collapses
                    # into the current prefix ...
                    prev_pb, prev_pnb = next_beams[prefix]
                    next_beams[prefix] = (prev_pb, _logaddexp(prev_pnb, p_nb + p_c))
                    # ... and only extends after a blank ("a-a" -> "aa").
                    n_pnb = _logaddexp(n_pnb, p_b + p_c)
                else:
                    n_pnb = _logaddexp(n_pnb, p_total + p_c)

                next_beams[new_prefix] = (n_pb, n_pnb)

        # Keep the top-K prefixes (never empty: the best beam is never pruned).
        ranked = sorted(
            next_beams.items(),
            key=lambda item: _logaddexp(*item[1]),
            reverse=True,
        )
        beams = dict(ranked[:beam_width])

    best_prefix = max(beams, key=lambda k: _logaddexp(*beams[k]))
    return " ".join(TOKEN_MAP.get(i, "") for i in best_prefix)

In [None]:
def validate(model, loader):
    """Evaluate ``model`` on ``loader`` and return ``(mean_loss, wer)``.

    The loss is the unweighted mean of per-batch CTC losses. Every utterance
    is beam-search decoded (beam width 3) and scored against its reference
    with ``jiwer.wer``. The "words" here are space-separated tokens from
    ``TOKEN_MAP``. Returns ``(nan, nan)`` for an empty loader, which jiwer
    would otherwise reject.

    Leaves the model in eval mode.
    """
    model.eval()
    losses = []
    preds, trues = [], []
    crit = nn.CTCLoss(blank=0, zero_infinity=True)
    # bf16 autocast is requested only on CUDA; CPU runs stay in fp32.
    amp_device = 'cuda' if str(CFG.DEVICE).startswith('cuda') else 'cpu'

    with torch.no_grad():
        for x, y, xl, yl, sids in tqdm(loader, desc="Validating", leave=False):
            x, y = x.to(CFG.DEVICE), y.to(CFG.DEVICE)
            sids = sids.to(CFG.DEVICE)

            with autocast(amp_device, dtype=torch.bfloat16, enabled=amp_device == 'cuda'):
                logits = model(x, sids)
                loss = crit(logits.permute(1, 0, 2), y, xl, yl)
            losses.append(loss.item())

            # Decode every utterance in the batch. Targets and lengths are
            # moved to Python once instead of one device sync per token.
            cpu_logits = logits.float().cpu().numpy()
            targets = y.cpu().tolist()
            for i, (in_len, tgt_len) in enumerate(zip(xl.tolist(), yl.tolist())):
                p_text = beam_search_decode(cpu_logits[i, :in_len], beam_width=3)
                t_text = " ".join(TOKEN_MAP.get(k, "") for k in targets[i][:tgt_len])
                preds.append(p_text)
                trues.append(t_text)

    if not trues:
        return float('nan'), float('nan')
    return float(np.mean(losses)), jiwer.wer(trues, preds)

# --- Main Execution ---
import math

files, sess_map, n_sess = get_all_filepaths(CFG.DATA_DIR)
print(f"Found {len(files)} files across {n_sess} sessions.")

# bf16 autocast / loss scaling only on CUDA; CPU runs stay in plain fp32.
AMP_DEVICE = 'cuda' if str(CFG.DEVICE).startswith('cuda') else 'cpu'
USE_AMP = AMP_DEVICE == 'cuda'

# Folds are drawn over HDF5 files, so all trials of a file share a fold.
kfold = KFold(n_splits=CFG.N_FOLDS, shuffle=True, random_state=42)
fold_results = []

for fold, (train_idx, val_idx) in enumerate(kfold.split(files)):
    print(f"\n{'='*20} FOLD {fold+1}/{CFG.N_FOLDS} {'='*20}")

    # 1. Setup Data
    train_ds = BrainDataset([files[i] for i in train_idx], sess_map)
    val_ds = BrainDataset([files[i] for i in val_idx], sess_map)

    train_loader = DataLoader(train_ds, batch_size=CFG.BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=2)
    val_loader = DataLoader(val_ds, batch_size=CFG.BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2)

    # 2. Setup Model & Optim
    model = BrainToTextModel(n_sess).to(CFG.DEVICE)
    opt = torch.optim.AdamW(model.parameters(), lr=CFG.LR, weight_decay=1e-2)
    crit = nn.CTCLoss(blank=0, zero_infinity=True)
    scaler = GradScaler(enabled=USE_AMP)
    n_batches = len(train_loader)
    # The scheduler steps once per optimizer step, i.e. once per
    # GRAD_ACCUM_STEPS micro-batches plus once for a trailing partial group.
    # Sized with len(train_loader) instead, OneCycleLR (default pct_start=0.3)
    # would only see ~1/GRAD_ACCUM_STEPS of its steps and never leave warm-up.
    steps_per_epoch = math.ceil(n_batches / CFG.GRAD_ACCUM_STEPS)
    sched = torch.optim.lr_scheduler.OneCycleLR(
        opt, max_lr=CFG.LR, steps_per_epoch=steps_per_epoch, epochs=CFG.EPOCHS
    )

    # 3. Training Loop
    # WER can exceed 1.0 (insertions), so start at +inf to save the first result.
    best_wer = float('inf')
    for epoch in range(1, CFG.EPOCHS + 1):
        model.train()
        opt.zero_grad()
        total_loss = 0.0
        pbar = tqdm(train_loader, desc=f"Ep {epoch}", leave=False)

        for i, (x, y, xl, yl, sids) in enumerate(pbar):
            x, y = x.to(CFG.DEVICE), y.to(CFG.DEVICE)
            sids = sids.to(CFG.DEVICE)

            with autocast(AMP_DEVICE, dtype=torch.bfloat16, enabled=USE_AMP):
                logits = model(x, sids)
                loss = crit(logits.permute(1, 0, 2), y, xl, yl) / CFG.GRAD_ACCUM_STEPS

            scaler.scale(loss).backward()

            # Step at the end of every accumulation group and on the last
            # batch, so leftover gradients never leak into the next epoch.
            # (The trailing partial group keeps the 1/GRAD_ACCUM_STEPS scale.)
            if (i + 1) % CFG.GRAD_ACCUM_STEPS == 0 or (i + 1) == n_batches:
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(opt)
                scaler.update()
                opt.zero_grad()
                sched.step()

            total_loss += loss.item() * CFG.GRAD_ACCUM_STEPS
            pbar.set_postfix(loss=total_loss/(i+1))

        # Validation
        v_loss, v_wer = validate(model, val_loader)
        print(f"Epoch {epoch}: Train Loss={total_loss/len(train_loader):.3f} | Val Loss={v_loss:.3f} | WER={v_wer:.4f}")

        if v_wer < best_wer:
            best_wer = v_wer
            torch.save(model.state_dict(), f"best_model_fold_{fold}.pt")
            print(f"--> Saved Best WER: {best_wer:.4f}")

    fold_results.append(best_wer)

    # Cleanup: drop every reference to this fold's model and optimizer state
    # (the scheduler also holds the optimizer) before the next fold allocates.
    del model, opt, sched, scaler, crit, train_loader, val_loader, train_ds, val_ds
    gc.collect()
    torch.cuda.empty_cache()

print(f"\nFinal Results: {fold_results}")
print(f"Average WER: {np.mean(fold_results):.4f}")