# Brain-to-Text '25 - Pretrained Baseline Inference

This notebook runs inference with the **pretrained RNN baseline** (6-12% phoneme error rate, PER).

**Steps:**
1. Upload data + checkpoint to Drive
2. Run inference on test set
3. Generate submission CSV

In [None]:
# Install dependencies (Colab shell escape; -q keeps pip's output quiet).
# NOTE(review): neither jiwer nor yaml is imported in the cells visible here -
# confirm these packages are still needed.
!pip install -q jiwer pyyaml

In [None]:
# Mount Google Drive so the dataset and checkpoint are reachable under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

# UPDATE THESE PATHS to match where the files live in your Drive.
# DATA_DIR: directory scanned by the inference cell; every sub-folder is treated as
#   one session and is only used if it contains a 'data_test.hdf5' file.
# CHECKPOINT_PATH: file handed to torch.load() when the pretrained model is loaded.
DATA_DIR = '/content/drive/MyDrive/hdf5_data_final'
CHECKPOINT_PATH = '/content/drive/MyDrive/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/best_checkpoint'

In [None]:
import os
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from scipy.ndimage import gaussian_filter1d

# Prefer the GPU when one is visible to PyTorch; otherwise run on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

In [None]:
# Model Configuration (from args.yaml)
class CFG:
    """Hyper-parameters shared by the model definition and the preprocessing code.

    The values are read as plain class attributes (``CFG.n_units`` etc.); the
    class is never instantiated.
    """

    n_input_features = 512      # feature dimension of each input time step (also the size of the per-day adapter)
    n_units = 768               # hidden size of the GRU
    n_layers = 5                # number of stacked GRU layers
    n_classes = 41              # size of the output layer; equals len(VOCAB)
    rnn_dropout = 0.4           # dropout passed to nn.GRU (between stacked layers)
    input_layer_dropout = 0.2   # dropout applied after the day-specific input layer
    patch_size = 14             # number of consecutive time steps concatenated into one GRU input
    patch_stride = 4            # step (in time steps) between the starts of consecutive patches
    smooth_kernel_std = 2       # sigma of the Gaussian smoothing applied to the inputs before the model

# Recording sessions in the order used during training. A session's position in
# this list is the index of its day-specific layer in the model, so the order
# must not be changed.
SESSIONS = [
    # 2023
    't15.2023.08.11',
    't15.2023.08.13',
    't15.2023.08.18',
    't15.2023.08.20',
    't15.2023.08.25',
    't15.2023.08.27',
    't15.2023.09.01',
    't15.2023.09.03',
    't15.2023.09.24',
    't15.2023.09.29',
    't15.2023.10.01',
    't15.2023.10.06',
    't15.2023.10.08',
    't15.2023.10.13',
    't15.2023.10.15',
    't15.2023.10.20',
    't15.2023.10.22',
    't15.2023.11.03',
    't15.2023.11.04',
    't15.2023.11.17',
    't15.2023.11.19',
    't15.2023.11.26',
    't15.2023.12.03',
    't15.2023.12.08',
    't15.2023.12.10',
    't15.2023.12.17',
    't15.2023.12.29',
    # 2024
    't15.2024.02.25',
    't15.2024.03.03',
    't15.2024.03.08',
    't15.2024.03.15',
    't15.2024.03.17',
    't15.2024.04.25',
    't15.2024.04.28',
    't15.2024.05.10',
    't15.2024.06.14',
    't15.2024.07.19',
    't15.2024.07.21',
    't15.2024.07.28',
    # 2025
    't15.2025.01.10',
    't15.2025.01.12',
    't15.2025.03.14',
    't15.2025.03.16',
    't15.2025.03.30',
    't15.2025.04.13',
]

# Session name -> position in SESSIONS (the day index passed to the model).
SESSION_TO_ID = dict(zip(SESSIONS, range(len(SESSIONS))))

In [None]:
# GRU decoder (architecture reconstructed from the training log)
class GRUDecoder(nn.Module):
    """Day-adapted, patch-based GRU that outputs per-frame class log-probabilities.

    Pipeline: a per-day affine layer followed by Softsign and dropout, then
    overlapping windows of ``CFG.patch_size`` time steps (stride
    ``CFG.patch_stride``) are flattened and fed to a multi-layer GRU whose
    outputs are projected to ``CFG.n_classes`` and log-softmaxed.

    The attribute names (``day_weights``, ``day_biases``, ``gru``, ``out``)
    determine the state-dict keys, so they must stay as they are for the
    checkpoint to load.

    Args:
        n_days: number of day-specific layers to allocate.
    """

    def __init__(self, n_days=45):
        super().__init__()

        feat_dim = CFG.n_input_features

        # One (weight, bias) pair per recording day, initialised to the identity map.
        day_weight_params = [nn.Parameter(torch.eye(feat_dim)) for _ in range(n_days)]
        day_bias_params = [nn.Parameter(torch.zeros(1, feat_dim)) for _ in range(n_days)]
        self.day_weights = nn.ParameterList(day_weight_params)
        self.day_biases = nn.ParameterList(day_bias_params)
        self.day_layer_activation = nn.Softsign()
        self.day_layer_dropout = nn.Dropout(CFG.input_layer_dropout)

        # Each GRU step sees a whole patch flattened into one vector:
        # patch_size * n_input_features (14 * 512 = 7168 with the config above).
        self.gru = nn.GRU(
            input_size=feat_dim * CFG.patch_size,
            hidden_size=CFG.n_units,
            num_layers=CFG.n_layers,
            batch_first=True,
            dropout=CFG.rnn_dropout,
        )
        self.out = nn.Linear(CFG.n_units, CFG.n_classes)

    def forward(self, x, day_idx):
        """Return log-probabilities of shape [B, num_patches, n_classes].

        Args:
            x: input of shape [B, T, n_input_features].
            day_idx: index selecting which day-specific layer to apply.
        """
        batch_size, _, _ = x.shape

        # Day-specific affine map, then non-linearity and dropout.
        adapted = x @ self.day_weights[day_idx] + self.day_biases[day_idx]
        adapted = self.day_layer_dropout(self.day_layer_activation(adapted))

        # Sliding windows over time: unfold yields [B, num_patches, D, patch_size];
        # swap the last two axes so each window is laid out time-major
        # before being flattened to [B, num_patches, patch_size * D].
        windows = adapted.unfold(1, CFG.patch_size, CFG.patch_stride).transpose(2, 3)
        gru_input = windows.reshape(batch_size, windows.size(1), -1)

        hidden, _ = self.gru(gru_input)
        return F.log_softmax(self.out(hidden), dim=-1)

In [None]:
# Phoneme vocabulary
VOCAB = ['', 'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER',
         'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW',
         'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y', 'Z', 'ZH', '|']

def greedy_decode(logits):
    """CTC greedy decoding"""
    pred = logits.argmax(dim=-1)  # [T]
    result = []
    prev = -1
    for p in pred:
        if p != prev and p != 0:  # Skip blanks and repeats
            result.append(VOCAB[p])
        prev = p
    return ' '.join(result)

def preprocess(x, sigma=2):
    """Smooth ``x`` along axis 0 with a 1-D Gaussian filter.

    Args:
        x: array whose first axis is the one to smooth over
            (presumably time - confirm against the data layout).
        sigma: standard deviation of the Gaussian; when it is not
            positive the input is returned untouched.

    Returns:
        The smoothed array, or ``x`` itself when no smoothing is applied.
    """
    if not sigma > 0:
        return x
    return gaussian_filter1d(x, sigma=sigma, axis=0)

In [None]:
# Load model
print('Loading pretrained model...')
model = GRUDecoder(n_days=len(SESSIONS)).to(device)

# Load checkpoint.
# NOTE: torch.load unpickles the file, which can execute arbitrary code -
# only load checkpoints obtained from a source you trust.
ckpt = torch.load(CHECKPOINT_PATH, map_location=device)

# The weights may be nested under 'model_state_dict' (a full training
# checkpoint) or the loaded object may already be the state dict itself.
state_dict = ckpt['model_state_dict'] if 'model_state_dict' in ckpt else ckpt

# Models saved through torch.compile carry an '_orig_mod.' prefix on every
# key; strip it so the names line up with the plain module defined here.
COMPILE_PREFIX = '_orig_mod.'
new_state_dict = {
    (k[len(COMPILE_PREFIX):] if k.startswith(COMPILE_PREFIX) else k): v
    for k, v in state_dict.items()
}

# strict=False tolerates key mismatches, which also means a wrong
# architecture or key naming would load "successfully" while leaving layers at
# their random initial values. Report every mismatch so that cannot go unnoticed.
load_result = model.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys:
    print(f'WARNING: {len(load_result.missing_keys)} model parameter(s) not found in '
          f'checkpoint (left at initial values): {load_result.missing_keys}')
if load_result.unexpected_keys:
    print(f'WARNING: {len(load_result.unexpected_keys)} checkpoint entry(ies) not used '
          f'by the model: {load_result.unexpected_keys}')

model.eval()  # disable dropout for inference
print(f'Model loaded! {sum(p.numel() for p in model.parameters()):,} parameters')

In [None]:
# Run inference on test set: walk every session folder under DATA_DIR, decode
# each trial of its 'data_test.hdf5', and collect one result row per trial.
print('Running inference on test data...')
results = []

for session_folder in tqdm(sorted(os.listdir(DATA_DIR))):
    session_path = os.path.join(DATA_DIR, session_folder)
    if not os.path.isdir(session_path):
        continue

    test_file = os.path.join(session_path, 'data_test.hdf5')
    if not os.path.exists(test_file):
        continue  # session without a test split

    # The session's position in SESSIONS selects its day-specific layer.
    # Unknown sessions still fall back to day 0 so the run completes, but the
    # fallback is reported because predictions made with another day's
    # adapter are likely to be poor.
    session_id = SESSION_TO_ID.get(session_folder)
    if session_id is None:
        tqdm.write(f"WARNING: session '{session_folder}' is not in SESSIONS; "
                   f"falling back to day index 0")
        session_id = 0

    with h5py.File(test_file, 'r') as f:
        for trial_key in sorted(f.keys()):
            # Load and preprocess
            x = f[trial_key]['input_features'][:]

            if x.shape[0] < CFG.patch_size:
                # Too short to form even one patch: the model's unfold would
                # raise and abort the whole run. Emit an empty prediction so
                # the trial still gets a row in the output.
                tqdm.write(f"WARNING: {session_folder}/{trial_key} has only "
                           f"{x.shape[0]} time steps (< patch_size={CFG.patch_size}); "
                           f"emitting empty prediction")
                pred_text = ''
            else:
                x = preprocess(x, sigma=CFG.smooth_kernel_std)
                x = torch.tensor(x, dtype=torch.float32).unsqueeze(0).to(device)

                # Inference (batch of one; gradients are not needed)
                with torch.no_grad():
                    logits = model(x, session_id)

                # Decode the single sequence in the batch
                pred_text = greedy_decode(logits[0])

            results.append({
                'id': f'{session_folder}_{trial_key}',
                'transcription': pred_text
            })

print(f'Generated {len(results)} predictions')

In [None]:
# Create submission: one row per entry of `results`, with the columns
# 'id' and 'transcription' taken from the result dicts.
df = pd.DataFrame(results)
print(f'Submission shape: {df.shape}')
print(df.head(10))

# Save without the DataFrame index so the CSV contains only the two columns.
df.to_csv('/content/submission.csv', index=False)
print('\nSaved to /content/submission.csv')
print('\nDownload this file and submit to the competition!')

In [None]:
# Optional: preview up to five predictions, each transcription cut to 80 characters
print('\nSample predictions:')
for row in results[:5]:
    print(f"  {row['id']}: {row['transcription'][:80]}...")