In [1]:
!pip install jiwer

Collecting jiwer
  Downloading jiwer-4.0.0-py3-none-any.whl.metadata (3.3 kB)
Collecting click>=8.1.8 (from jiwer)
  Downloading click-8.3.1-py3-none-any.whl.metadata (2.6 kB)
Collecting rapidfuzz>=3.9.7 (from jiwer)
  Downloading rapidfuzz-3.14.3-cp312-cp312-macosx_11_0_arm64.whl.metadata (12 kB)
Downloading jiwer-4.0.0-py3-none-any.whl (23 kB)
Downloading click-8.3.1-py3-none-any.whl (108 kB)
Downloading rapidfuzz-3.14.3-cp312-cp312-macosx_11_0_arm64.whl (1.4 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/1.4 MB[0m [31m7.5 MB/s[0m  [33m0:00:00[0m eta [36m0:00:01[0m
[?25hInstalling collected packages: rapidfuzz, click, jiwer
[2K  Attempting uninstall: click
[2K    Found existing installation: click 8.1.7
[2K    Uninstalling click-8.1.7:
[2K      Successfully uninstalled click-8.1.7
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3/3[0m [jiwer]
[1A[2KSuccessfully installed click-8.3.1 jiwer-4.0.0 rapidfuzz-3.14.3


# Optimized Brain-to-Text '25
## Phase 1: Setup & Baseline Reproduction
This notebook implements the optimized pipeline, starting with a reproduction of the baseline and moving towards a Conformer architecture.

In [2]:
import os
import yaml
import h5py
import torch
import numpy as np
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.nn import functional as F
import torch.nn.utils.rnn as rnn_utils
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import jiwer

class CFG:
    """Run-wide configuration read by the data, model and training cells.

    ``DATA_DIR`` and ``CHECKPOINT_PATH`` can be overridden through the
    ``B2T_DATA_DIR`` / ``B2T_CHECKPOINT_PATH`` environment variables so the
    notebook runs on other machines; the defaults are the original local paths.
    """

    # --- Reproducibility ---
    SEED = 42

    # --- Model Hyperparameters ---
    # Number of attention heads; kept equal to the N_HEAD the Conformer is built with.
    N_HEAD = 4

    # --- Training ---
    EPOCHS = 5
    LR = 1e-3
    BATCH_SIZE = 32
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Paths ---
    DATA_DIR = os.environ.get(
        "B2T_DATA_DIR",
        "/Users/pswmi64/Desktop/Brain-to-Text-25/t15_copyTask_neuralData/hdf5_data_final",
    )
    CHECKPOINT_PATH = os.environ.get(
        "B2T_CHECKPOINT_PATH",
        "/Users/pswmi64/Desktop/Brain-to-Text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/best_checkpoint",
    )


# Seed torch (temporal masking via randperm, DataLoader shuffling, weight init)
# and numpy so a Restart & Run All reproduces the same run.
torch.manual_seed(CFG.SEED)
np.random.seed(CFG.SEED)

print(f"Running on device: {CFG.DEVICE}")

Running on device: cpu


In [3]:
# --- Data Loading Utilities ---

def temporal_mask(data, mask_percentage=0.05, mask_value=0.0):
    """
    Applies temporal masking to a 2D tensor [Sequence, Features].

    A random subset of ``int(seq_len * mask_percentage)`` time steps (rows) is
    overwritten with ``mask_value`` across all features.

    Args:
        data: 2-D tensor or array-like shaped [Sequence, Features]. Non-tensor
            input is converted to a float32 tensor.
        mask_percentage: fraction of time steps to mask (rounded down).
        mask_value: value written into every feature of a masked time step.

    Returns:
        A new tensor with the selected rows masked. The caller's input is never
        modified (tensor inputs are cloned first, keeping their dtype).
    """
    if torch.is_tensor(data):
        # Clone so augmentation never mutates the caller's tensor in place.
        data = data.clone()
    else:
        data = torch.tensor(data, dtype=torch.float32)

    seq_len, _ = data.shape
    num_to_mask = int(seq_len * mask_percentage)

    if num_to_mask > 0:
        mask_indices = torch.randperm(seq_len)[:num_to_mask]
        data[mask_indices, :] = mask_value

    return data

class BrainDataset(Dataset):
    """Per-trial dataset backed by a single HDF5 file.

    Every top-level key of the file is one trial. ``__getitem__`` reads
    ``input_key`` as a float32 tensor and, when present, ``target_key`` as an
    int64 tensor (otherwise an empty int64 tensor).

    Args:
        hdf5_file: path to the HDF5 file. A missing file yields an empty dataset.
        input_key: dataset name inside each trial group holding the inputs.
        target_key: dataset name holding the target token ids.
        is_test: when True, items are ``(x, y, trial_key)`` and no augmentation
            is applied; otherwise items are ``(x, y)``.
        use_augmentation: apply ``temporal_mask`` to the inputs (non-test only).
        mask_percentage: fraction of time steps masked when augmenting.
    """

    def __init__(self, hdf5_file, input_key="input_features", target_key="seq_class_ids",
                 is_test=False, use_augmentation=False, mask_percentage=0.1):
        self.file_path = hdf5_file
        self.input_key = input_key
        self.target_key = target_key
        self.is_test = is_test
        self.use_augmentation = use_augmentation
        self.mask_percentage = mask_percentage
        # Opened lazily on first __getitem__ (presumably so each DataLoader
        # worker opens its own handle); released by close().
        self.file = None

        try:
            with h5py.File(self.file_path, "r") as f:
                self.trial_keys = sorted(list(f.keys()))
        except FileNotFoundError:
            print(f"Warning: File not found {self.file_path}, creating empty dataset.")
            self.trial_keys = []

    def __len__(self):
        return len(self.trial_keys)

    def __getitem__(self, idx):
        if self.file is None:
            self.file = h5py.File(self.file_path, "r")

        trial_key = self.trial_keys[idx]
        trial_group = self.file[trial_key]

        x_data = trial_group[self.input_key][:]
        x = torch.tensor(x_data, dtype=torch.float32)

        if self.use_augmentation and not self.is_test:
            x = temporal_mask(x, mask_percentage=self.mask_percentage)

        if self.target_key in trial_group:
            y_data = trial_group[self.target_key][:]
            y = torch.tensor(y_data, dtype=torch.long)
        else:
            y = torch.tensor([], dtype=torch.long)

        if self.is_test:
            return x, y, trial_key
        else:
            return x, y

    def __getstate__(self):
        # An open h5py handle cannot be pickled, which breaks spawn-based
        # DataLoader workers if the dataset was indexed in the parent first.
        # Ship the dataset without the handle; it is reopened lazily.
        state = self.__dict__.copy()
        state["file"] = None
        return state

    def close(self):
        """Close the lazily-opened HDF5 handle, if any. Safe to call repeatedly."""
        handle = getattr(self, "file", None)
        if handle is not None:
            handle.close()
        self.file = None

    def __del__(self):
        try:
            self.close()
        except Exception:
            # Best-effort cleanup: a finalizer must never raise.
            pass

def custom_collate(batch):
    """Pad a list of dataset items into batch tensors.

    Items are ``(x, y)`` or, for test data, ``(x, y, trial_key)``. Inputs are
    zero-padded along time, targets are padded with 0, and the unpadded lengths
    are returned as int64 tensors. Test batches additionally return the tuple
    of trial keys as the last element.
    """
    def _lengths(seqs):
        return torch.tensor([len(seq) for seq in seqs], dtype=torch.long)

    columns = list(zip(*batch))
    inputs, targets = columns[0], columns[1]

    collated = (
        rnn_utils.pad_sequence(inputs, batch_first=True, padding_value=0.0),
        rnn_utils.pad_sequence(targets, batch_first=True, padding_value=0),
        _lengths(inputs),
        _lengths(targets),
    )
    if len(batch[0]) != 3:
        return collated
    return collated + (columns[2],)

def load_datasets(data_dir=None):
    """Build concatenated train/val/test datasets from per-session folders.

    Every immediate subdirectory of ``data_dir`` is treated as one session and
    may contain ``data_train.hdf5``, ``data_val.hdf5`` and ``data_test.hdf5``.
    Sessions are visited in sorted order so dataset indexing is deterministic
    across filesystems. Only the training split is augmented.

    Args:
        data_dir: root folder; defaults to ``CFG.DATA_DIR``.

    Returns:
        ``(train, val, test)`` as ``ConcatDataset`` objects.

    Raises:
        ValueError: if a split has no non-empty HDF5 file in any session
            (``ConcatDataset`` cannot be built from an empty list).
    """
    root = CFG.DATA_DIR if data_dir is None else data_dir
    with os.scandir(root) as entries:
        subfolders = sorted(entry.path for entry in entries if entry.is_dir())
    print(f"Found {len(subfolders)} session folders.")

    # Per-split BrainDataset options; dict order fixes the per-session load order.
    split_options = {
        "train": dict(is_test=False, use_augmentation=True),
        "val": dict(is_test=False, use_augmentation=False),
        "test": dict(is_test=True, use_augmentation=False),
    }
    collected = {split: [] for split in split_options}

    for subfolder_path in subfolders:
        for split, options in split_options.items():
            dataset = BrainDataset(
                os.path.join(subfolder_path, f"data_{split}.hdf5"),
                input_key="input_features",
                target_key="seq_class_ids",
                **options,
            )
            if len(dataset) > 0:
                collected[split].append(dataset)

    for split, datasets in collected.items():
        if not datasets:
            raise ValueError(
                f"No non-empty '{split}' split found under {root!r} "
                f"(expected <session>/data_{split}.hdf5 files)."
            )

    return (
        ConcatDataset(collected["train"]),
        ConcatDataset(collected["val"]),
        ConcatDataset(collected["test"]),
    )

print("Loading datasets...")
train_dataset, val_dataset, test_dataset = load_datasets()
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}")

# Settings shared by all three loaders; only the training split is shuffled.
_loader_kwargs = dict(batch_size=CFG.BATCH_SIZE, collate_fn=custom_collate)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

Loading datasets...
Found 45 session folders.
Train: 8072, Val: 1426, Test: 1450


In [4]:
# --- Conformer Architecture ---

class Swish(nn.Module):
    """Swish activation, ``x * sigmoid(x)`` (the same function as ``nn.SiLU``)."""

    def forward(self, x):
        return x * torch.sigmoid(x)


class FeedForwardModule(nn.Module):
    """Conformer feed-forward branch: LayerNorm -> Linear(expand) -> Swish -> Linear.

    Returns only the branch output; ``ConformerBlock`` adds the half-step
    residual. Operates frame by frame on [B, T, dim] inputs.
    """

    def __init__(self, dim, expansion_factor=4, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim)
        self.linear1 = nn.Linear(dim, dim * expansion_factor)
        self.swish = Swish()
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim * expansion_factor, dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        # x: [B, T, C]
        out = self.layer_norm(x)
        out = self.linear1(out)
        out = self.swish(out)
        out = self.dropout1(out)
        out = self.linear2(out)
        out = self.dropout2(out)
        return out


class ConvolutionModule(nn.Module):
    """Conformer convolution branch (pointwise -> GLU -> depthwise -> BN -> Swish -> pointwise).

    Returns only the branch output; the caller adds the residual.
    """

    def __init__(self, dim, kernel_size=31, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim)
        # Pointwise
        self.pointwise_conv1 = nn.Conv1d(dim, dim * 2, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        # Depthwise ("same" padding for odd kernel sizes)
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, groups=dim)
        self.batch_norm = nn.BatchNorm1d(dim)
        self.swish = Swish()
        # Pointwise
        self.pointwise_conv2 = nn.Conv1d(dim, dim, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, pad_mask=None):
        """Args:
            x: [B, T, C] input.
            pad_mask: optional bool [B, T], True at padded frames. Those frames are
                zeroed right before the depthwise conv so they behave like the
                conv's own zero padding and do not leak into valid frames.
        """
        out = self.layer_norm(x)
        out = out.transpose(1, 2)  # [B, C, T]

        out = self.pointwise_conv1(out)
        out = self.glu(out)
        if pad_mask is not None:
            out = out.masked_fill(pad_mask.unsqueeze(1), 0.0)
        out = self.depthwise_conv(out)
        # NOTE: in train mode BatchNorm statistics still include padded frames.
        out = self.batch_norm(out)
        out = self.swish(out)
        out = self.pointwise_conv2(out)
        out = self.dropout(out)

        out = out.transpose(1, 2)  # [B, T, C]
        return out


class ConformerBlock(nn.Module):
    """One Conformer block: FF/2 -> self-attention -> convolution -> FF/2 -> LayerNorm."""

    def __init__(self, dim, n_head, conv_kernel_size=31, dropout=0.1):
        super().__init__()
        self.ff1 = FeedForwardModule(dim, dropout=dropout)
        self.self_attn_layer_norm = nn.LayerNorm(dim)
        self.self_attn = nn.MultiheadAttention(dim, n_head, dropout=dropout, batch_first=True)
        self.conv_module = ConvolutionModule(dim, kernel_size=conv_kernel_size, dropout=dropout)
        self.ff2 = FeedForwardModule(dim, dropout=dropout)
        self.final_layer_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, pad_mask=None):
        """Args:
            x: [B, T, C] input.
            pad_mask: optional bool [B, T], True at padded frames. Padded frames are
                excluded as attention keys and zeroed before the depthwise conv.
                Every row needs at least one valid frame (a fully masked row makes
                attention produce NaN).
        """
        # FF1 (Half Step)
        x = x + 0.5 * self.ff1(x)

        # Self Attention
        residual = x
        x_norm = self.self_attn_layer_norm(x)
        attn_out, _ = self.self_attn(x_norm, x_norm, x_norm, key_padding_mask=pad_mask)
        x = residual + self.dropout(attn_out)

        # Convolution
        x = x + self.conv_module(x, pad_mask)

        # FF2 (Half Step)
        x = x + 0.5 * self.ff2(x)

        # Final Norm
        x = self.final_layer_norm(x)
        return x


class ConformerEncoder(nn.Module):
    """Linear input projection, a stack of Conformer blocks, and a per-frame log-softmax head."""

    def __init__(self, input_dim, encoder_dim, n_layers, n_head, output_dim,
                 conv_kernel_size=31, dropout=0.1):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, encoder_dim)
        self.layers = nn.ModuleList([
            ConformerBlock(encoder_dim, n_head, conv_kernel_size=conv_kernel_size, dropout=dropout)
            for _ in range(n_layers)
        ])
        self.output_proj = nn.Linear(encoder_dim, output_dim)

    def forward(self, x, lengths=None):
        """Args:
            x: [B, T, input_dim] padded input.
            lengths: optional [B] valid frame counts (each >= 1). When given,
                padded frames cannot influence valid ones; when None every frame
                is treated as valid (original behaviour).

        Returns:
            [B, T, output_dim] log-probabilities. Values at padded positions are
            meaningless and should be ignored using ``lengths``.
        """
        x = self.input_proj(x)

        pad_mask = None
        if lengths is not None:
            lengths = torch.as_tensor(lengths, device=x.device)
            steps = torch.arange(x.size(1), device=x.device)
            pad_mask = steps.unsqueeze(0) >= lengths.unsqueeze(1)  # [B, T], True = padding

        for layer in self.layers:
            x = layer(x, pad_mask)
        x = self.output_proj(x)
        return nn.functional.log_softmax(x, dim=2)

# Model Config
INPUT_DIM = 512     # feature channels per time step fed to input_proj (TODO confirm it matches input_features)
ENCODER_DIM = 256
N_LAYERS = 4
N_HEAD = 4
OUTPUT_DIM = 41     # 40 tokens in VOCAB (39 phonemes + '|') + 1 CTC blank at index 0

model = ConformerEncoder(
    input_dim=INPUT_DIM,
    encoder_dim=ENCODER_DIM,
    n_layers=N_LAYERS,
    n_head=N_HEAD,
    output_dim=OUTPUT_DIM,
).to(CFG.DEVICE)
n_params = sum(param.numel() for param in model.parameters())
print(f"Conformer Model Initialized with {n_params} parameters")

Conformer Model Initialized with 6233641 parameters


In [None]:
# --- Training & Evaluation ---

# Output token inventory: 39 ARPAbet phonemes followed by '|'
# (presumably the word-boundary token).
VOCAB = (
    "AA AE AH AO AW AY B CH D DH EH ER "
    "EY F G HH IH IY JH K L M N NG OW "
    "OY P R S SH T TH UH UW V W Y Z "
    "ZH |"
).split()
# Ids 1..40 map to VOCAB entries; id 0 is the CTC blank and decodes to "".
TOKEN_MAP = {**dict(enumerate(VOCAB, start=1)), 0: ""}

def greedy_decoder(logits, token_map):
    """Greedy CTC decoding of one utterance.

    Takes the best class per frame, merges consecutive repeats, drops the blank
    (id 0), and maps the remaining ids through ``token_map`` (unknown ids become
    "?"). Returns the tokens joined by single spaces.
    """
    best_path = logits.argmax(dim=-1)
    merged = torch.unique_consecutive(best_path).tolist()
    return " ".join(token_map.get(token, "?") for token in merged if token != 0)

def train_one_epoch(epoch, model, train_loader, criterion, optimizer, max_grad_norm=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        epoch: epoch number, used only for the progress-bar label.
        model: maps padded inputs to [B, T, C] log-probabilities.
        train_loader: yields ``(x, y, x_lengths, y_lengths)`` batches.
        criterion: CTC-style loss called with [T, B, C] log-probs and lengths.
        optimizer: optimizer over ``model``'s parameters.
        max_grad_norm: if not None, clip the total gradient L2 norm to this value
            before each optimizer step. ``None`` (default) disables clipping.

    Returns:
        Sample-weighted mean loss over the batches actually used for an update.
        Batches with a non-finite loss are skipped and excluded from the
        average; returns ``nan`` if no batch was used.
    """
    model.train()
    running_loss = 0.0
    samples_seen = 0
    for x, y, x_lengths, y_lengths in tqdm(train_loader, desc=f"Epoch {epoch} [Train]", leave=False):
        x, y, x_lengths, y_lengths = x.to(CFG.DEVICE), y.to(CFG.DEVICE), x_lengths.to(CFG.DEVICE), y_lengths.to(CFG.DEVICE)
        optimizer.zero_grad()
        y_pred = model(x)
        # CTC loss expects log-probs laid out as [T, B, C].
        y_pred_for_loss = y_pred.permute(1, 0, 2)
        loss = criterion(y_pred_for_loss, y, x_lengths, y_lengths)
        if not torch.isfinite(loss):
            # Skip the update and keep this batch out of the reported average.
            continue
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        running_loss += loss.item() * x.size(0)
        samples_seen += x.size(0)

    if samples_seen == 0:
        return float("nan")
    return running_loss / samples_seen

def validate_one_epoch(epoch, model, val_loader, criterion, token_map):
    """Evaluate ``model`` on ``val_loader`` without gradient tracking.

    Each batch is scored with ``criterion`` (called with [T, B, C] log-probs and
    lengths), and every sample is greedily decoded with ``greedy_decoder``.
    References and hypotheses are space-separated phoneme strings, so the
    ``jiwer.wer`` value returned here is effectively a phoneme error rate.

    Returns:
        ``(mean_loss, error_rate)``. ``mean_loss`` is the sum of the batch losses,
        each weighted by its batch size, divided by ``len(val_loader.dataset)``.
    """
    model.eval()
    loss_sum = 0.0
    hypotheses = []
    references = []
    with torch.no_grad():
        batches = tqdm(val_loader, desc=f"Epoch {epoch} [Val]", leave=False)
        for x, y, x_lengths, y_lengths in batches:
            x = x.to(CFG.DEVICE)
            y = y.to(CFG.DEVICE)
            x_lengths = x_lengths.to(CFG.DEVICE)
            y_lengths = y_lengths.to(CFG.DEVICE)

            log_probs = model(x)
            batch_loss = criterion(log_probs.permute(1, 0, 2), y, x_lengths, y_lengths)
            loss_sum += batch_loss.item() * x.size(0)

            for sample_log_probs, target, n_frames, n_tokens in zip(log_probs, y, x_lengths, y_lengths):
                hypotheses.append(greedy_decoder(sample_log_probs[:n_frames], token_map))
                references.append(" ".join(token_map.get(tok.item(), "?") for tok in target[:n_tokens]))

    error_rate = jiwer.wer(references, hypotheses)
    return loss_sum / len(val_loader.dataset), error_rate

# Token 0 is the CTC blank, matching TOKEN_MAP[0] == "".
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=CFG.LR)

print("Starting Training...")
for epoch in range(1, CFG.EPOCHS + 1):
    train_loss = train_one_epoch(
        epoch=epoch, model=model, train_loader=train_loader,
        criterion=criterion, optimizer=optimizer,
    )
    # "WER" is jiwer.wer over space-separated phoneme tokens (a phoneme error rate).
    val_loss, wer = validate_one_epoch(
        epoch=epoch, model=model, val_loader=val_loader,
        criterion=criterion, token_map=TOKEN_MAP,
    )
    print(f"Epoch {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | WER: {wer:.4f}")

Starting Training...


Epoch 1 [Train]:   0%|          | 0/253 [00:00<?, ?it/s]