In [1]:
import os
from pathlib import Path
import h5py
import pandas as pd

# ---------- 1. helper to load ONE .hdf5 file (same as Anna) ----------
def load_h5py_file(file_path):
    data = {
        'neural_features': [],
        'n_time_steps': [],
        'seq_class_ids': [],
        'seq_len': [],
        'transcriptions': [],
        'sentence_label': [],
        'session': [],
        'block_num': [],
        'trial_num': [],
    }
    with h5py.File(file_path, 'r') as f:
        for key in f.keys():
            g = f[key]

            neural_features = g['input_features'][:]
            n_time_steps = g.attrs['n_time_steps']
            seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None
            seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None
            transcription = g['transcription'][:] if 'transcription' in g else None
            sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None
            session = g.attrs['session']
            block_num = g.attrs['block_num']
            trial_num = g.attrs['trial_num']

            data['neural_features'].append(neural_features)
            data['n_time_steps'].append(n_time_steps)
            data['seq_class_ids'].append(seq_class_ids)
            data['seq_len'].append(seq_len)
            data['transcriptions'].append(transcription)
            data['sentence_label'].append(sentence_label)
            data['session'].append(session)
            data['block_num'].append(block_num)
            data['trial_num'].append(trial_num)
    return data

# ---------- 2. walk local /input like Kaggle walks /kaggle/input ----------
INPUT_ROOT = Path("..") / "input"       # this folder contains "brain-to-text-25"

data_dict = {}

for dirname, _, filenames in os.walk(INPUT_ROOT):
    for filename in filenames:
        if filename.endswith(".hdf5"):
            # exactly the same key logic as Anna’s file :contentReference[oaicite:0]{index=0}
            dict_key_util = "_".join([
                filename.split('.')[0].split('_')[-1],   # train / val / test
                dirname.split(os.sep)[-1]               # t15.2023.09.29
            ])

            file_path = os.path.join(dirname, filename)
            print("Loading:", dict_key_util)
            data_dict[dict_key_util] = load_h5py_file(file_path)

print("Keys in data_dict:", list(data_dict.keys())[:10])


Loading: train_t15.2023.08.11
Loading: test_t15.2023.08.13
Loading: train_t15.2023.08.13
Loading: val_t15.2023.08.13
Loading: test_t15.2023.08.18
Loading: train_t15.2023.08.18
Loading: val_t15.2023.08.18
Loading: test_t15.2023.08.20
Loading: train_t15.2023.08.20
Loading: val_t15.2023.08.20
Loading: test_t15.2023.08.25
Loading: train_t15.2023.08.25
Loading: val_t15.2023.08.25
Loading: test_t15.2023.08.27
Loading: train_t15.2023.08.27
Loading: val_t15.2023.08.27
Loading: test_t15.2023.09.01
Loading: train_t15.2023.09.01
Loading: val_t15.2023.09.01
Loading: test_t15.2023.09.03
Loading: train_t15.2023.09.03
Loading: val_t15.2023.09.03
Loading: test_t15.2023.09.24
Loading: train_t15.2023.09.24
Loading: val_t15.2023.09.24
Loading: test_t15.2023.09.29
Loading: train_t15.2023.09.29
Loading: val_t15.2023.09.29
Loading: test_t15.2023.10.01
Loading: train_t15.2023.10.01
Loading: val_t15.2023.10.01
Loading: test_t15.2023.10.06
Loading: train_t15.2023.10.06
Loading: val_t15.2023.10.06
Loading: test

In [2]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os

LOGIT_TO_PHONEME = [
'BLANK',    # "BLANK" = CTC blank symbol
'AA', 'AE', 'AH', 'AO', 'AW',
'AY', 'B', 'CH', 'D', 'DH',
'EH', 'ER', 'EY', 'F', 'G',
'HH', 'IH', 'IY', 'JH', 'K',
'L', 'M', 'N', 'NG', 'OW',
'OY', 'P', 'R', 'S', 'SH',
'T', 'TH', 'UH', 'UW', 'V',
'W', 'Y', 'Z', 'ZH',
' | ',    # "|" = silence token
]
# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [3]:
import h5py

def load_h5py_file(file_path):
    data = {
        'neural_features': [],
        'n_time_steps': [],
        'seq_class_ids': [],
        'seq_len': [],
        'transcriptions': [],
        'sentence_label': [],
        'session': [],
        'block_num': [],
        'trial_num': [],
    }
    # Open the hdf5 file for that day
    with h5py.File(file_path, 'r') as f:

        keys = list(f.keys())

        # For each trial in the selected trials on that day
        for key in keys:
            g = f[key]

            neural_features = g['input_features'][:]
            n_time_steps = g.attrs['n_time_steps']
            seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None
            seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None
            transcription = g['transcription'][:] if 'transcription' in g else None
            sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None
            session = g.attrs['session']
            block_num = g.attrs['block_num']
            trial_num = g.attrs['trial_num']

            data['neural_features'].append(neural_features)
            data['n_time_steps'].append(n_time_steps)
            data['seq_class_ids'].append(seq_class_ids)
            data['seq_len'].append(seq_len)
            data['transcriptions'].append(transcription)
            data['sentence_label'].append(sentence_label)
            data['session'].append(session)
            data['block_num'].append(block_num)
            data['trial_num'].append(trial_num)
    return data


In [4]:
data_dict = {}
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        if ('.hdf5' in filename):
            dict_key_util = '_'.join([filename.split('.')[0].split('_')[-1], dirname.split('/')[-1]])
            data_dict[dict_key_util] = load_h5py_file(os.path.join(dirname, filename))

* neural_features: Temporally binned (20 ms) neural features for each trial (512 X T).
* n_time_steps: Number of time steps per trial.
* seq_class_ids: Integer phoneme sequence labels for each trial. Integers correspond to phonemes using the mapping according to **LOGIT_TO_PHONEME**
* seq_len: Number of phoneme labels per trial.
* transcriptions: ASCII representation of sentence label for each trial.
* sentence_label: Raw text sentence label for each trial.
* session: Date that the trial's data was collected. Each date has a number of blocks, each block has a number of trials.
* block_num: Research block number that the trial is sourced from.
* trial_num: Trial number that the trial is sourced from.

In [5]:
import os
from pathlib import Path
import h5py
import pandas as pd

def load_h5py_file(file_path):
    data = {
        'neural_features': [],
        'n_time_steps': [],
        'seq_class_ids': [],
        'seq_len': [],
        'transcriptions': [],
        'sentence_label': [],
        'session': [],
        'block_num': [],
        'trial_num': [],
    }
    with h5py.File(file_path, 'r') as f:
        for key in f.keys():
            g = f[key]

            neural_features = g['input_features'][:]
            n_time_steps = g.attrs['n_time_steps']
            seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None
            seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None
            transcription = g['transcription'][:] if 'transcription' in g else None
            sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None
            session = g.attrs['session']
            block_num = g.attrs['block_num']
            trial_num = g.attrs['trial_num']

            data['neural_features'].append(neural_features)
            data['n_time_steps'].append(n_time_steps)
            data['seq_class_ids'].append(seq_class_ids)
            data['seq_len'].append(seq_len)
            data['transcriptions'].append(transcription)
            data['sentence_label'].append(sentence_label)
            data['session'].append(session)
            data['block_num'].append(block_num)
            data['trial_num'].append(trial_num)
    return data

# ✅ root of your local data (equivalent to /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final)
DATA_ROOT = Path("..") / "input" / "brain-to-text-25" / "t15_copyTask_neuralData" / "hdf5_data_final"

print("DATA_ROOT:", DATA_ROOT.resolve())

data_dict = {}

for dirname, _, filenames in os.walk(DATA_ROOT):
    for filename in filenames:
        if filename.endswith(".hdf5"):
            # 'train' / 'val' / 'test'
            split_name = filename.split('.')[0].split('_')[-1]
            # folder name like 't15.2023.09.29'
            date_name = os.path.basename(dirname)
            dict_key_util = "_".join([split_name, date_name])

            file_path = os.path.join(dirname, filename)
            print("Loading:", dict_key_util)
            data_dict[dict_key_util] = load_h5py_file(file_path)

print("Number of keys:", len(data_dict))
print("First 10 keys:", list(data_dict.keys())[:10])


DATA_ROOT: C:\Users\btlim\OneDrive\Desktop\brain_to_text_local\input\brain-to-text-25\t15_copyTask_neuralData\hdf5_data_final
Loading: train_t15.2023.08.11
Loading: test_t15.2023.08.13
Loading: train_t15.2023.08.13
Loading: val_t15.2023.08.13
Loading: test_t15.2023.08.18
Loading: train_t15.2023.08.18
Loading: val_t15.2023.08.18
Loading: test_t15.2023.08.20
Loading: train_t15.2023.08.20
Loading: val_t15.2023.08.20
Loading: test_t15.2023.08.25
Loading: train_t15.2023.08.25
Loading: val_t15.2023.08.25
Loading: test_t15.2023.08.27
Loading: train_t15.2023.08.27
Loading: val_t15.2023.08.27
Loading: test_t15.2023.09.01
Loading: train_t15.2023.09.01
Loading: val_t15.2023.09.01
Loading: test_t15.2023.09.03
Loading: train_t15.2023.09.03
Loading: val_t15.2023.09.03
Loading: test_t15.2023.09.24
Loading: train_t15.2023.09.24
Loading: val_t15.2023.09.24
Loading: test_t15.2023.09.29
Loading: train_t15.2023.09.29
Loading: val_t15.2023.09.29
Loading: test_t15.2023.10.01
Loading: train_t15.2023.10.01
Lo

In [6]:
print(data_dict.keys())
#test_df = pd.DataFrame.from_dict(data_dict['data_test.hdf5'])
#train_df = pd.DataFrame.from_dict(data_dict['data_train.hdf5'])
#val_df = pd.DataFrame.from_dict(data_dict['data_val.hdf5'])

test_df = pd.DataFrame()
train_df = pd.DataFrame()
val_df = pd.DataFrame()

for k in data_dict.keys():
    if 'train' in k:
        train_df = pd.concat([train_df, pd.DataFrame.from_dict(data_dict[k])], ignore_index=True)
    elif 'val' in k:
        val_df = pd.concat([val_df, pd.DataFrame.from_dict(data_dict[k])], ignore_index = True)
    elif 'test' in k:
        test_df = pd.concat([test_df, pd.DataFrame.from_dict(data_dict[k])], ignore_index=True)

dict_keys(['train_t15.2023.08.11', 'test_t15.2023.08.13', 'train_t15.2023.08.13', 'val_t15.2023.08.13', 'test_t15.2023.08.18', 'train_t15.2023.08.18', 'val_t15.2023.08.18', 'test_t15.2023.08.20', 'train_t15.2023.08.20', 'val_t15.2023.08.20', 'test_t15.2023.08.25', 'train_t15.2023.08.25', 'val_t15.2023.08.25', 'test_t15.2023.08.27', 'train_t15.2023.08.27', 'val_t15.2023.08.27', 'test_t15.2023.09.01', 'train_t15.2023.09.01', 'val_t15.2023.09.01', 'test_t15.2023.09.03', 'train_t15.2023.09.03', 'val_t15.2023.09.03', 'test_t15.2023.09.24', 'train_t15.2023.09.24', 'val_t15.2023.09.24', 'test_t15.2023.09.29', 'train_t15.2023.09.29', 'val_t15.2023.09.29', 'test_t15.2023.10.01', 'train_t15.2023.10.01', 'val_t15.2023.10.01', 'test_t15.2023.10.06', 'train_t15.2023.10.06', 'val_t15.2023.10.06', 'test_t15.2023.10.08', 'train_t15.2023.10.08', 'val_t15.2023.10.08', 'test_t15.2023.10.13', 'train_t15.2023.10.13', 'val_t15.2023.10.13', 'test_t15.2023.10.15', 'train_t15.2023.10.15', 'val_t15.2023.10.15',

In [7]:
print("len(data_dict):", len(data_dict))
print("first 10 keys:", list(data_dict.keys())[:10])

print("Train shape:", train_df.shape)
print("Val shape:",   val_df.shape)
print("Test shape:",  test_df.shape)

print("Sessions (train):", train_df['session'].nunique())
print("Sessions (val):",   val_df['session'].nunique())
print("Sessions (test):",  test_df['session'].nunique())


len(data_dict): 127
first 10 keys: ['train_t15.2023.08.11', 'test_t15.2023.08.13', 'train_t15.2023.08.13', 'val_t15.2023.08.13', 'test_t15.2023.08.18', 'train_t15.2023.08.18', 'val_t15.2023.08.18', 'test_t15.2023.08.20', 'train_t15.2023.08.20', 'val_t15.2023.08.20']
Train shape: (8072, 9)
Val shape: (1426, 9)
Test shape: (1450, 9)
Sessions (train): 45
Sessions (val): 41
Sessions (test): 41


In [8]:
print(train_df.describe())
print(val_df.describe())
print(test_df.describe())
print(train_df.columns)
print(val_df.columns)
#512 neural features (2 features [-4.5 RMS threshold crossings and spike band power] per electrode, 
#256 electrodes), binned at 20 ms resolution. The data were recorded from the speech motor cortex 
#via four high-density microelectrode arrays (64 electrodes each). The 512 features are ordered 
#as follows in all data files:
#0-64: ventral 6v threshold crossings
#65-128: area 4 threshold crossings
#129-192: 55b threshold crossings
#193-256: dorsal 6v threshold crossings
#257-320: ventral 6v spike band power
#321-384: area 4 spike band power
#385-448: 55b spike band power
#449-512: dorsal 6v spike band power
#for i in range(train_df['neural_features'].shape[0]):
    #print(train_df['neural_features'][i].shape) #trialx(n)x512, with n = n_time_steps
print(train_df.head())
for c in ['n_time_steps', 'seq_len', 'session', 'block_num', 'trial_num']:
    print("_______TRAIN_______")
    print(train_df[c].value_counts())
    print(train_df[c].nunique())
    print("_____VALIDATION_____")
    print(val_df[c].value_counts())
    print(val_df[c].nunique())
    print("_______TEST_______")
    print(test_df[c].value_counts())
    print(test_df[c].nunique())

       n_time_steps      seq_len    block_num    trial_num
count   8072.000000  8072.000000  8072.000000  8072.000000
mean     874.840560    26.541625     4.500248    22.602081
std      308.298035     9.154736     2.871122    14.403206
min      138.000000     3.000000     1.000000     0.000000
25%      655.000000    20.000000     2.000000    10.000000
50%      836.000000    26.000000     4.000000    21.000000
75%     1050.000000    33.000000     6.000000    35.000000
max     2475.000000   110.000000    14.000000    49.000000
       n_time_steps      seq_len    block_num    trial_num
count   1426.000000  1426.000000  1426.000000  1426.000000
mean     922.128331    29.026648     6.969144    11.048387
std      318.502705     9.288702     2.416122     7.094434
min      297.000000     8.000000     1.000000     0.000000
25%      695.000000    22.000000     6.000000     5.000000
50%      890.500000    29.000000     7.000000    11.000000
75%     1094.000000    35.000000     9.000000    17.0000

In [9]:
for i in train_df['seq_class_ids'][3][:train_df['seq_len'][3]]:
    print(LOGIT_TO_PHONEME[i])
print(train_df['seq_class_ids'][3][:train_df['seq_len'][3]])
print(train_df['sentence_label'][3])
print(train_df['transcriptions'][3][:train_df['seq_len'][3]])
for i in train_df['transcriptions'][3][:train_df['seq_len'][3]]:
    print(chr(i))
for i in train_df['seq_class_ids'][2][:train_df['seq_len'][2]]:
    print(LOGIT_TO_PHONEME[i])
print(train_df['seq_class_ids'][2][:train_df['seq_len'][2]])
print(train_df['sentence_label'][2])
print(train_df['transcriptions'][2][:train_df['seq_len'][2]])
for i in train_df['transcriptions'][2][:train_df['seq_len'][2]]:
    print(chr(i))

HH
AW
 | 
IH
Z
 | 
DH
AE
T
 | 
G
UH
D
 | 
[16  5 40 17 38 40 10  2 31 40 15 33  9 40]
How is that good?
[ 72 111 119  32 105 115  32 116 104  97 116  32 103 111]
H
o
w
 
i
s
 
t
h
a
t
 
g
o
W
AH
T
 | 
D
UW
 | 
DH
EY
 | 
L
AY
K
 | 
[36  3 31 40  9 34 40 10 13 40 21  6 20 40]
What do they like?
[ 87 104  97 116  32 100 111  32 116 104 101 121  32 108]
W
h
a
t
 
d
o
 
t
h
e
y
 
l


In [10]:
print(train_df['neural_features'][0][0][:])
print(train_df['n_time_steps'][:]*.02)
for c in train_df.columns:
    print(c)
    print(train_df[c].dtypes)

[ 2.3076649  -0.78699756 -0.64687246 -0.5465877   0.25500455 -0.37754795
 -0.31888878 -0.43742913 -0.552158   -0.6198629  -0.2722918   1.0336585
  0.27755055 -0.666718   -0.6310771  -0.78702307 -0.59295505  0.8399486
  0.85698867  0.34778863 -0.7790861  -0.843797    0.7688951  -0.77096075
 -0.437959    2.4420617   0.29788262 -0.60283345 -0.4174294  -0.5666092
 -0.7893149   0.16854596 -0.54237986 -0.3172244  -0.7257092  -0.6197359
  0.15344943 -0.38841027 -0.76966536 -0.27819777  0.53903705 -0.6611451
 -0.75468504 -0.6057648  -0.73880273 -0.4290835  -0.94724435  0.6137742
 -0.5463916  -0.41273063 -0.8792733   0.97371936  1.0565114  -0.5758789
  2.4920397  -0.42387414 -0.22300059 -0.4641142  -0.5492305  -0.46833596
  3.2247906  -0.11923911  1.5838451  -0.09712851 -0.7942411  -0.5100695
 -0.68580496 -0.5852925  -0.603436    0.50690377 -0.30034852  4.1031313
 -0.6948983  -0.5562878   1.5018272  -0.35887727 -0.5901224  -0.63912416
 -0.52937365 -0.16325521  0.74647266 -0.13449022 -0.15371595

In [11]:
print(train_df['transcriptions'][0].shape)
print(train_df['seq_class_ids'][0].shape)

(500,)
(500,)


### Time to get dirty ;)

In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

EXPERIMENT_NAME = "bilstm_large_sched"
CKPT_PATH = f"rnn_ctc_best_{EXPERIMENT_NAME}.pth"

# ===== Basic config =====
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
FEAT_DIM = 512                          # 512 neural features
BLANK_ID = 0                            # 'BLANK' is first in LOGIT_TO_PHONEME
NUM_CLASSES = len(LOGIT_TO_PHONEME)     # include BLANK

BATCH_SIZE   = 16
EPOCHS       = 30
BASE_LR      = 5e-3
WEIGHT_DECAY = 1e-4
PATIENCE     = 5                        # early stopping on val PER

print("Device:", DEVICE)
print("NUM_CLASSES:", NUM_CLASSES)


Device: cuda
NUM_CLASSES: 41


In [13]:
import math
import numpy as np

class BrainCTCDataset(Dataset):
    """
    Wraps train_df / val_df.
    
    Expected columns:
      - 'neural_features': (T, 512) numpy array
      - 'n_time_steps'   : int T
      - 'seq_class_ids'  : array/list of phoneme IDs
      - 'seq_len'        : int L (valid target length)
    """
    def __init__(self, df):
        df = df.copy()
        # keep only rows with labels
        df = df[df["seq_len"].notnull()]
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        feats = row["neural_features"]        # (T, 512)
        T = int(row["n_time_steps"])
        L = int(row["seq_len"])

        # convert to tensor
        if isinstance(feats, np.ndarray):
            feats = torch.from_numpy(feats).float()
        else:
            feats = feats.float()

        assert feats.shape[0] == T
        assert feats.shape[1] == FEAT_DIM

        seq_ids = row["seq_class_ids"]
        labels = torch.as_tensor(seq_ids[:L], dtype=torch.long)

        sample = {
            "feats": feats,      # (T, 512)
            "T": T,
            "labels": labels,    # (L,)
            "L": L,
        }
        return sample


def ctc_collate_fn(batch):
    """
    Pads inputs and concatenates labels for CTC.
    Returns:
      padded_feats: (B, max_T, 512)
      input_lengths: (B,)
      flat_labels: (sum_L,)
      label_lengths: (B,)
    """
    B = len(batch)
    T_list = [b["T"] for b in batch]
    L_list = [b["L"] for b in batch]
    max_T = max(T_list)

    padded_feats = torch.zeros(B, max_T, FEAT_DIM, dtype=torch.float32)
    flat_labels_list = []

    for i, b in enumerate(batch):
        T = b["T"]
        padded_feats[i, :T] = b["feats"]
        if b["L"] > 0:
            flat_labels_list.append(b["labels"])

    if flat_labels_list:
        flat_labels = torch.cat(flat_labels_list, dim=0)
    else:
        flat_labels = torch.zeros(0, dtype=torch.long)

    input_lengths = torch.as_tensor(T_list, dtype=torch.long)
    label_lengths = torch.as_tensor(L_list, dtype=torch.long)

    return padded_feats, input_lengths, flat_labels, label_lengths


In [14]:
train_ds = BrainCTCDataset(train_df)
val_ds   = BrainCTCDataset(val_df)

train_loader = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=ctc_collate_fn,
    num_workers=0,
)

val_loader = DataLoader(
    val_ds,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=ctc_collate_fn,
    num_workers=0,
)

print("len(train_ds) =", len(train_ds))
print("len(val_ds)   =", len(val_ds))

# quick sanity check: get one batch
batch = next(iter(train_loader))
for x in batch:
    print(type(x), getattr(x, "shape", None))


len(train_ds) = 8072
len(val_ds)   = 1426
<class 'torch.Tensor'> torch.Size([16, 948, 512])
<class 'torch.Tensor'> torch.Size([16])
<class 'torch.Tensor'> torch.Size([381])
<class 'torch.Tensor'> torch.Size([16])


In [15]:
class RNNCTCEncoder(nn.Module):
    """
    Larger BiLSTM encoder + linear CTC head.
    """
    def __init__(self, feat_dim=FEAT_DIM, hidden=384, num_layers=3, dropout=0.25):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=feat_dim,
            hidden_size=hidden,
            num_layers=num_layers,
            dropout=dropout if num_layers > 1 else 0.0,
            bidirectional=True,
        )
        self.proj = nn.Linear(2 * hidden, NUM_CLASSES)

    def forward(self, x, input_lengths):
        out, _ = self.lstm(x)      # (T, B, 2*hidden)
        logits = self.proj(out)    # (T, B, C)
        return logits              # lengths unchanged


model = RNNCTCEncoder().to(DEVICE)
criterion = nn.CTCLoss(blank=BLANK_ID, zero_infinity=True)
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=BASE_LR,
    weight_decay=WEIGHT_DECAY,
)

scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="min",
    factor=0.5,
    patience=2,
)

print(model)


RNNCTCEncoder(
  (lstm): LSTM(512, 384, num_layers=3, dropout=0.25, bidirectional=True)
  (proj): Linear(in_features=768, out_features=41, bias=True)
)


In [16]:
def greedy_decode(log_probs, input_lengths, blank_id=BLANK_ID):
    """
    log_probs: (T, B, C)
    input_lengths: (B,)
    Returns: list of list[int], length B
    """
    T, B, C = log_probs.shape
    preds = log_probs.detach().cpu().argmax(dim=-1)  # (T, B)

    decoded = []
    for b in range(B):
        T_b = input_lengths[b].item()
        seq = preds[:T_b, b].tolist()

        # CTC collapse: remove repeats & blanks
        collapsed = []
        prev = None
        for s in seq:
            if s == blank_id:
                prev = None
                continue
            if s != prev:
                collapsed.append(s)
                prev = s
        decoded.append(collapsed)
    return decoded


def phoneme_error_rate(pred_ids, target_ids):
    """
    Classic edit distance / PER.
    pred_ids, target_ids: list[int]
    """
    m, n = len(target_ids), len(pred_ids)
    dp = [[0]*(n+1) for _ in range(m+1)]
    for i in range(m+1):
        dp[i][0] = i
    for j in range(n+1):
        dp[0][j] = j

    for i in range(1, m+1):
        for j in range(1, n+1):
            cost = 0 if target_ids[i-1] == pred_ids[j-1] else 1
            dp[i][j] = min(
                dp[i-1][j] + 1,      # delete
                dp[i][j-1] + 1,      # insert
                dp[i-1][j-1] + cost  # substitute
            )
    return dp[m][n] / max(1, m)


def compute_batch_per(decoded_batch, flat_labels, label_lengths):
    """
    decoded_batch: list[list[int]] length B
    flat_labels: (sum_L,)
    label_lengths: (B,)
    Returns: average PER over batch
    """
    idx = 0
    per_sum = 0.0
    B = len(decoded_batch)

    for b in range(B):
        L_b = label_lengths[b].item()
        target_b = flat_labels[idx: idx + L_b].tolist()
        idx += L_b

        pred_b = decoded_batch[b]
        per_sum += phoneme_error_rate(pred_b, target_b)

    return per_sum / max(1, B)


In [17]:
import torch
from torch.utils.data import Dataset, DataLoader

# --- CORRECTED CONFIGURATION ---
# We reserve 0 for CTC Blank. 
# Your original data has 41 classes (0-40).
# We shift them to 1-41. So Total Classes = 42.
BLANK_ID = 0
NUM_CLASSES = 42  # 41 phonemes + 1 dedicated blank
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class BrainCTCDataset(Dataset):
    def __init__(self, df):
        df = df.copy()
        # Filter rows without length
        df = df[df["seq_len"].notnull()]
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]

        feats = row["neural_features"]
        T = int(row["n_time_steps"])
        L = int(row["seq_len"])

        # 1. Handle Features
        if isinstance(feats, np.ndarray):
            feats = torch.from_numpy(feats).float()
        else:
            feats = feats.float()

        # 2. Handle Labels (THE CRITICAL FIX)
        seq_ids = row["seq_class_ids"]
        # Take the valid sequence
        raw_labels = torch.as_tensor(seq_ids[:L], dtype=torch.long)
        
        # SHIFT LABELS: 0 becomes 1, 40 becomes 41.
        # Now 0 is free to be the CTC Blank.
        labels = raw_labels + 1 

        sample = {
            "feats": feats,      
            "T": T,
            "labels": labels,    
            "L": L,
        }
        return sample

TRAINING CODE

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import pandas as pd
import numpy as np
import os

# ==========================================
# 1. CONFIGURATION & DATASET
# ==========================================
BATCH_SIZE = 64        # BiLSTM is light, we can use larger batch
EPOCHS = 30
LR_MAX = 2e-3          # Peak learning rate
FEAT_DIM = 512
BLANK_ID = 0
NUM_CLASSES = 42       # 41 phonemes + 1 blank
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {DEVICE}")

class BrainCTCDataset(Dataset):
    def __init__(self, df):
        df = df.copy()
        df = df[df["seq_len"].notnull()]
        self.df = df.reset_index(drop=True)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        feats = row["neural_features"]
        if isinstance(feats, np.ndarray):
            feats = torch.from_numpy(feats).float()
        else:
            feats = feats.float()
            
        # LABELS: Shift +1 to avoid 0 collision
        seq_ids = row["seq_class_ids"]
        L = int(row["seq_len"])
        raw_labels = torch.as_tensor(seq_ids[:L], dtype=torch.long)
        labels = raw_labels + 1 

        return {
            "feats": feats,
            "T": int(row["n_time_steps"]),
            "labels": labels,
            "L": L
        }

# Re-create loaders with the fixed Dataset
train_ds = BrainCTCDataset(train_df)
val_ds   = BrainCTCDataset(val_df)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,  collate_fn=ctc_collate_fn, num_workers=0)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE, shuffle=False, collate_fn=ctc_collate_fn, num_workers=0)

# ==========================================
# 2. THE ROBUST MODEL (Plan A)
# ==========================================
class RobustBiLSTM(nn.Module):
    def __init__(self, feat_dim, num_classes, hidden=384, num_layers=4, dropout=0.3):
        super().__init__()
        self.ln = nn.LayerNorm(feat_dim)
        self.projection = nn.Linear(feat_dim, hidden)
        
        self.lstm = nn.LSTM(
            input_size=hidden,
            hidden_size=hidden,
            num_layers=num_layers,
            dropout=dropout,
            bidirectional=True,
            batch_first=False
        )
        self.output_proj = nn.Linear(hidden * 2, num_classes)

        # Initialize Bias to favor Blanks (Stability Hack)
        init_bias = torch.zeros(num_classes)
        init_bias[0] = 2.0 
        self.output_proj.bias.data = init_bias

    def forward(self, x, input_lengths):
        # x: (Time, Batch, Feats)
        x = self.ln(x)
        x = self.projection(x)
        x = F.relu(x)
        
        # Pack for efficiency
        x_packed = nn.utils.rnn.pack_padded_sequence(x, input_lengths.cpu(), enforce_sorted=False)
        out_packed, _ = self.lstm(x_packed)
        out, _ = nn.utils.rnn.pad_packed_sequence(out_packed)
        
        logits = self.output_proj(out)
        return logits

# ==========================================
# 3. UTILS (Decoding & PER)
# ==========================================
def greedy_decode(log_probs, input_lengths):
    T, B, C = log_probs.shape
    preds = log_probs.argmax(dim=-1).cpu()
    decoded = []
    for b in range(B):
        length = input_lengths[b].item()
        seq = preds[:length, b].tolist()
        # Collapse CTC
        collapsed = []
        prev = None
        for p in seq:
            if p != BLANK_ID and p != prev:
                collapsed.append(p)
            prev = p
        decoded.append(collapsed)
    return decoded

def compute_per(ref, hyp):
    # Simple Levenshtein distance
    import editdistance # Ensure this is installed, or use a custom function
    # Fallback if library missing:
    try:
        d = editdistance.eval(ref, hyp)
    except:
        # Quick fallback implementation
        m, n = len(ref), len(hyp)
        dp = [[0]*(n+1) for _ in range(m+1)]
        for i in range(m+1): dp[i][0] = i
        for j in range(n+1): dp[0][j] = j
        for i in range(1, m+1):
            for j in range(1, n+1):
                cost = 0 if ref[i-1] == hyp[j-1] else 1
                dp[i][j] = min(dp[i-1][j]+1, dp[i][j-1]+1, dp[i-1][j-1]+cost)
        d = dp[m][n]
    return d / len(ref) if len(ref) > 0 else 0

# ==========================================
# 4. TRAINING LOOP
# ==========================================
model = RobustBiLSTM(FEAT_DIM, NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR_MAX, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LR_MAX, steps_per_epoch=len(train_loader), epochs=EPOCHS, pct_start=0.15
)
criterion = nn.CTCLoss(blank=BLANK_ID, zero_infinity=True)

best_per = 1.0
print("Starting Training...")

for epoch in range(1, EPOCHS+1):
    model.train()
    train_loss = 0
    
    for batch in train_loader:
        feats, in_lens, labels, lbl_lens = batch
        feats = feats.to(DEVICE).transpose(0, 1) # T, B, F
        labels = labels.to(DEVICE)
        lbl_lens = lbl_lens.to(DEVICE)
        in_lens = in_lens.to(DEVICE)
        
        optimizer.zero_grad()
        logits = model(feats, in_lens)
        log_probs = F.log_softmax(logits, dim=2)
        
        loss = criterion(log_probs, labels, in_lens, lbl_lens)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        train_loss += loss.item()
        
    avg_train_loss = train_loss / len(train_loader)
    
    # --- VALIDATION ---
    model.eval()
    val_loss = 0
    val_per = 0
    count = 0
    
    with torch.no_grad():
        for batch in val_loader:
            feats, in_lens, labels, lbl_lens = batch
            feats = feats.to(DEVICE).transpose(0, 1)
            labels = labels.to(DEVICE)
            in_lens = in_lens.to(DEVICE)
            lbl_lens = lbl_lens.to(DEVICE)
            
            logits = model(feats, in_lens)
            log_probs = F.log_softmax(logits, dim=2)
            val_loss += criterion(log_probs, labels, in_lens, lbl_lens).item()
            
            # Compute PER
            decoded_batch = greedy_decode(log_probs, in_lens)
            # Reconstruct targets from flat labels
            idx = 0
            for i, hyp in enumerate(decoded_batch):
                L_target = lbl_lens[i].item()
                ref = labels[idx:idx+L_target].cpu().tolist()
                idx += L_target
                val_per += compute_per(ref, hyp)
                count += 1
                
    avg_val_loss = val_loss / len(val_loader)
    avg_per = val_per / count
    
    print(f"Epoch {epoch}: Train Loss={avg_train_loss:.4f} | Val Loss={avg_val_loss:.4f} | Val PER={avg_per:.4f}")
    
    if avg_per < best_per:
        best_per = avg_per
        torch.save(model.state_dict(), "best_bilstm_robust.pth")
        print(f"  >>> New Best Model Saved! PER: {best_per:.4f}")

Device: cuda
Starting Training...


KeyboardInterrupt: 

In [18]:
# ==========================================
# MISSING UTILITY FUNCTIONS
# ==========================================

def greedy_decode(log_probs, input_lengths):
    """
    Decodes the CTC output by taking the most likely token at each step.
    log_probs: (Time, Batch, Class)
    input_lengths: (Batch,)
    """
    # 1. Get max probability indices
    preds = log_probs.argmax(dim=-1).cpu() 
    
    decoded = []
    for b in range(log_probs.shape[1]): # Iterate over batch
        length = input_lengths[b].item()
        seq = preds[:length, b].tolist()
        
        # 2. CTC Collapse: Remove duplicates & Blanks
        collapsed = []
        prev = None
        for p in seq:
            if p != 0 and p != prev: # 0 is BLANK_ID
                collapsed.append(p)
            prev = p
        decoded.append(collapsed)
    return decoded

def compute_per(ref, hyp):
    """
    Calculates Phoneme Error Rate (Levenshtein Distance / Reference Length).
    """
    # Simple DP implementation of Edit Distance
    m, n = len(ref), len(hyp)
    
    # Create a table to store results of subproblems
    dp = [[0] * (n + 1) for _ in range(m + 1)]
 
    # Fill dp[][] in bottom up manner
    for i in range(m + 1):
        for j in range(n + 1):
            if i == 0:
                dp[i][j] = j    # Min. operations = j
            elif j == 0:
                dp[i][j] = i    # Min. operations = i
            elif ref[i - 1] == hyp[j - 1]:
                dp[i][j] = dp[i - 1][j - 1]
            else:
                dp[i][j] = 1 + min(dp[i][j - 1],      # Insert
                                   dp[i - 1][j],      # Remove
                                   dp[i - 1][j - 1])  # Replace
 
    distance = dp[m][n]
    
    if len(ref) == 0:
        return 0.0 if len(hyp) == 0 else 1.0
        
    return distance / len(ref)

TRAINING CODE 

In [19]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio

# ==========================================
# 1. CONFIGURATION
# ==========================================
BATCH_SIZE = 16          # Conformers are heavy; reduce batch size
EPOCHS = 40              # They need longer to converge
LR_MAX = 5e-4            # Lower LR is safer for Transformers
FEAT_DIM = 512
NUM_CLASSES = 42         # 41 phonemes + 1 blank (Fixed!)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Device: {DEVICE}")

# ==========================================
# 2. THE SOTA MODEL (Conformer)
# ==========================================
class ConformerCTC(nn.Module):
    def __init__(self, feat_dim, num_classes, d_model=256, n_head=4, num_layers=6):
        super().__init__()
        
        # 1. Adapter: Compress 512 -> 256
        self.projection = nn.Linear(feat_dim, d_model)
        
        # 2. The Brain: Conformer Encoder
        self.conformer = torchaudio.models.Conformer(
            input_dim=d_model,
            num_heads=n_head,
            ffn_dim=d_model * 4,
            num_layers=num_layers,
            depthwise_conv_kernel_size=31,
            dropout=0.1
        )
        
        # 3. Output Head
        self.output_proj = nn.Linear(d_model, num_classes)
        
        # Bias Init Trick (Stability)
        init_bias = torch.zeros(num_classes)
        init_bias[0] = 2.0 
        self.output_proj.bias.data = init_bias

    def forward(self, x, input_lengths):
        # Input x: (Time, Batch, Feats) -> Standard for our loader
        
        # A. Project & Permute
        x = self.projection(x)  # (T, B, 256)
        x = x.transpose(0, 1)   # (B, T, 256) -> Conformer expects Batch First
        
        # B. Conformer Pass
        # Returns: output, lengths
        out, _ = self.conformer(x, input_lengths)
        
        # C. Permute Back
        out = out.transpose(0, 1) # (T, B, 256)
        
        # D. Project to Classes
        logits = self.output_proj(out)
        return logits

# ==========================================
# 3. DATA AUGMENTATION (SpecAugment)
# ==========================================
# Crucial for Conformers to avoid overfitting
class SpecAugment(nn.Module):
    def __init__(self):
        super().__init__()
        self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=20)
        self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param=50)

    def forward(self, x):
        # x: (Time, Batch, Feats) -> Permute to (Batch, Feats, Time) for Torchaudio
        x = x.permute(1, 2, 0)
        x = self.freq_mask(x)
        x = self.time_mask(x)
        x = x.permute(2, 0, 1) # Back to (Time, Batch, Feats)
        return x

# ==========================================
# 4. TRAINING LOOP (With Schedulers)
# ==========================================
model = ConformerCTC(FEAT_DIM, NUM_CLASSES).to(DEVICE)
augmenter = SpecAugment().to(DEVICE)

# Conformer Needs AdamW + Cosine Scheduler
optimizer = torch.optim.AdamW(model.parameters(), lr=LR_MAX, weight_decay=1e-2)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LR_MAX, steps_per_epoch=len(train_loader), epochs=EPOCHS, pct_start=0.15
)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)

best_per = 1.0
print("Starting Conformer Training...")

for epoch in range(1, EPOCHS+1):
    model.train()
    train_loss = 0
    
    for batch in train_loader:
        feats, in_lens, labels, lbl_lens = batch
        feats = feats.to(DEVICE).transpose(0, 1)
        labels = labels.to(DEVICE)
        lbl_lens = lbl_lens.to(DEVICE)
        in_lens = in_lens.to(DEVICE)
        
        # Apply Augmentation
        feats = augmenter(feats)
        
        optimizer.zero_grad()
        logits = model(feats, in_lens)
        log_probs = F.log_softmax(logits, dim=2)
        
        loss = criterion(log_probs, labels, in_lens, lbl_lens)
        loss.backward()
        
        # Clip Gradients (Essential for Conformer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        train_loss += loss.item()
        
    avg_train_loss = train_loss / len(train_loader)
    
    # --- VALIDATION (No Augment) ---
    model.eval()
    val_per = 0
    count = 0
    
    with torch.no_grad():
        for batch in val_loader:
            feats, in_lens, labels, lbl_lens = batch
            feats = feats.to(DEVICE).transpose(0, 1)
            labels = labels.to(DEVICE)
            in_lens = in_lens.to(DEVICE)
            lbl_lens = lbl_lens.to(DEVICE)
            
            logits = model(feats, in_lens)
            log_probs = F.log_softmax(logits, dim=2)
            
            # Simple Greedy Decode for monitoring
            decoded_batch = greedy_decode(log_probs, in_lens)
            idx = 0
            for i, hyp in enumerate(decoded_batch):
                L_target = lbl_lens[i].item()
                ref = labels[idx:idx+L_target].cpu().tolist()
                idx += L_target
                val_per += compute_per(ref, hyp)
                count += 1
                
    avg_per = val_per / count
    print(f"Epoch {epoch}: Loss={avg_train_loss:.4f} | Val PER={avg_per:.4f}")
    
    if avg_per < best_per:
        best_per = avg_per
        torch.save(model.state_dict(), "best_conformer_sota.pth")
        print(f"  >>> New SOTA Model Saved! PER: {best_per:.4f}")

Device: cuda
Starting Conformer Training...
Epoch 1: Loss=3.8296 | Val PER=0.9627
  >>> New SOTA Model Saved! PER: 0.9627
Epoch 2: Loss=2.3755 | Val PER=0.7442
  >>> New SOTA Model Saved! PER: 0.7442
Epoch 3: Loss=1.5732 | Val PER=0.4319
  >>> New SOTA Model Saved! PER: 0.4319
Epoch 4: Loss=1.1913 | Val PER=0.3752
  >>> New SOTA Model Saved! PER: 0.3752
Epoch 5: Loss=1.0470 | Val PER=0.3726
  >>> New SOTA Model Saved! PER: 0.3726
Epoch 6: Loss=0.9905 | Val PER=0.3602
  >>> New SOTA Model Saved! PER: 0.3602
Epoch 7: Loss=1.0354 | Val PER=0.4226
Epoch 8: Loss=1.0560 | Val PER=0.3957
Epoch 9: Loss=1.0124 | Val PER=0.3959
Epoch 10: Loss=0.9706 | Val PER=0.3629
Epoch 11: Loss=0.9913 | Val PER=0.3631
Epoch 12: Loss=0.9638 | Val PER=0.3806
Epoch 13: Loss=0.9231 | Val PER=0.3856
Epoch 14: Loss=1.0044 | Val PER=0.3847
Epoch 15: Loss=1.0976 | Val PER=0.4009
Epoch 16: Loss=0.9424 | Val PER=0.3253
  >>> New SOTA Model Saved! PER: 0.3253
Epoch 17: Loss=0.8646 | Val PER=0.3245
  >>> New SOTA Model S

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
import pandas as pd
import numpy as np
import nltk
from torch.utils.data import Dataset, DataLoader
from pyctcdecode import build_ctcdecoder

# ==========================================
# 1. CONFIGURATION
# ==========================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_PATH = "best_conformer_sota.pth"
OUTPUT_FILE = "submission_english_fixed.csv"
FEAT_DIM = 512
NUM_CLASSES = 42

LOGIT_TO_PHONEME = [
    'BLANK', 'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 
    'EH', 'ER', 'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 
    'N', 'NG', 'OW', 'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UW', 
    'V', 'W', 'Y', 'Z', 'ZH', ' | '
]

# ==========================================
# 2. DICTIONARY & FIXING LOGIC
# ==========================================
print("Initializing Dictionary...")
try: nltk.data.find('corpora/cmudict')
except LookupError: nltk.download('cmudict', quiet=True)
from nltk.corpus import cmudict
d = cmudict.dict()

# 1. Build Phoneme Map
VALID_PHONEMES = set(LOGIT_TO_PHONEME[1:]) 
phoneme_to_word = {}
for word, prons in d.items():
    for pron in prons:
        clean_pron = tuple([p.strip('012') for p in pron])
        if clean_pron not in phoneme_to_word:
            phoneme_to_word[clean_pron] = word
        elif len(word) < len(phoneme_to_word[clean_pron]):
            phoneme_to_word[clean_pron] = word

def split_phoneme_string(smushed_str):
    tokens = []
    i = 0
    n = len(smushed_str)
    while i < n:
        match_found = False
        for length in [3, 2, 1]: 
            if i + length <= n:
                chunk = smushed_str[i : i+length]
                if chunk in VALID_PHONEMES:
                    tokens.append(chunk)
                    i += length
                    match_found = True
                    break
        if not match_found: i += 1
    return tuple(tokens)

# 2. Manual Fixes (Crucial for 0.48 score)
COMMON_FIXES = {
    "AY": "I", "EY": "A", "WIHDH": "WITH", "DHAH": "THE", "AHND": "AND",
    "TAYERD": "TIRED", "GEHNT": "GET", "SAHNG": "SONG", "DEHS": "THIS",
    "HHERR": "HER", "YUWR": "YOUR", "THIHNGK": "THINK", "HHAED": "HAD",
    "DHEY": "THEY", "HHAEV": "HAVE", "EHNJHOY": "ENJOY", "MUWVD": "MOVED",
    "SIHTIY": "CITY", "WERK": "WORK", "SIYZ": "SEES", "DHEHRZ": "THERES"
}

def translate_and_fix(phoneme_text):
    if not phoneme_text: return ""
    raw_groups = phoneme_text.split('|')
    sentence = []
    for group in raw_groups:
        clean_str = group.strip()
        if not clean_str: continue
        
        # A. Try Dictionary
        ph_tuple = split_phoneme_string(clean_str)
        if ph_tuple in phoneme_to_word:
            word = phoneme_to_word[ph_tuple]
        else:
            word = clean_str
            
        # B. Apply Fixes
        if word in COMMON_FIXES:
            word = COMMON_FIXES[word]
            
        sentence.append(word)
    return " ".join(sentence)

# ==========================================
# 3. MODEL (Conformer SOTA)
# ==========================================
class ConformerCTC(nn.Module):
    def __init__(self, feat_dim, num_classes, d_model=256, n_head=4, num_layers=6):
        super().__init__()
        self.projection = nn.Linear(feat_dim, d_model)
        self.conformer = torchaudio.models.Conformer(
            input_dim=d_model, num_heads=n_head, ffn_dim=d_model * 4,
            num_layers=num_layers, depthwise_conv_kernel_size=31, dropout=0.1
        )
        self.output_proj = nn.Linear(d_model, num_classes)

    def forward(self, x, input_lengths):
        x = self.projection(x)
        x = x.transpose(0, 1) 
        out, _ = self.conformer(x, input_lengths)
        out = out.transpose(0, 1) 
        return self.output_proj(out)

# ==========================================
# 4. INFERENCE
# ==========================================
class TestDataset(Dataset):
    def __init__(self, df):
        self.df = df.reset_index(drop=True)
    def __len__(self): return len(self.df)
    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        feats = row["neural_features"]
        if isinstance(feats, np.ndarray): feats = torch.from_numpy(feats).float()
        else: feats = feats.float()
        return {"feats": feats, "T": int(row["n_time_steps"])}

print("Loading SOTA Model...")
model = ConformerCTC(FEAT_DIM, NUM_CLASSES).to(DEVICE)
model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
model.eval()

# Decoder (Windows Safe: Alpha=0)
vocab_list = [x for x in LOGIT_TO_PHONEME if x != 'BLANK'] + ['BLANK']
decoder = build_ctcdecoder(labels=vocab_list, kenlm_model_path=None, alpha=0.0, beta=1.0)

test_ds = TestDataset(test_df)
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False)

final_predictions = []
print(f"Generating Predictions...")

with torch.no_grad():
    for i, batch in enumerate(test_loader):
        feats = batch['feats'].to(DEVICE).transpose(0, 1)
        in_lens = torch.tensor([batch['T']], device=DEVICE)
        
        # 1. Inference
        logits = model(feats, in_lens)
        probs = F.softmax(logits, dim=2)
        probs_np = probs.transpose(0, 1).cpu().numpy()[0]
        
        # 2. Decode
        probs_adjusted = np.roll(probs_np, -1, axis=-1)
        phoneme_text = decoder.decode(probs_adjusted)
        
        # 3. Translate & Fix
        english_text = translate_and_fix(phoneme_text)
        
        final_predictions.append(english_text)
        
        if i % 100 == 0: print(f"Sample {i}: {english_text}")

# Save
submission = pd.DataFrame({"id": range(len(final_predictions)), "text": final_predictions})
submission.to_csv(OUTPUT_FILE, index=False)
print(f"✅ Saved {OUTPUT_FILE}")

TRAINING CODE 

In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import torchaudio
import numpy as np
from torch.utils.data import DataLoader, Dataset

# ==========================================
# 1. CONFIGURATION
# ==========================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 32
EPOCHS = 40
LR_MAX = 1e-3
FEAT_DIM = 512
NUM_CLASSES = 42

print(f"Device: {DEVICE}")

# ==========================================
# 2. MISSING CLASS: SpecAugment
# ==========================================
class SpecAugment(nn.Module):
    def __init__(self):
        super().__init__()
        self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=20)
        self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param=50)

    def forward(self, x):
        # Input: (Time, Batch, Feats)
        # Permute for Torchaudio: (Batch, Feats, Time)
        x = x.permute(1, 2, 0)
        x = self.freq_mask(x)
        x = self.time_mask(x)
        # Back to: (Time, Batch, Feats)
        return x.permute(2, 0, 1)

# ==========================================
# 3. MODEL COMPONENTS
# ==========================================
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TransformerCTC(nn.Module):
    def __init__(self, feat_dim, num_classes, d_model=256, nhead=4, num_layers=6):
        super().__init__()
        self.input_proj = nn.Linear(feat_dim, d_model)
        self.pos_encoder = PositionalEncoding(d_model)
        encoder_layers = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model*4, dropout=0.1
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers)
        self.output_proj = nn.Linear(d_model, num_classes)
        
        # Stability Hack
        init_bias = torch.zeros(num_classes)
        init_bias[0] = 2.0 
        self.output_proj.bias.data = init_bias

    def forward(self, x, input_lengths):
        x = self.input_proj(x)
        x = self.pos_encoder(x)
        out = self.transformer_encoder(x)
        return self.output_proj(out)

# ==========================================
# 4. TRAINING SETUP
# ==========================================
# Ensure DataLoaders exist. If 'train_loader' is missing, 
# re-run the cell that defines BrainCTCDataset and train_loader!

model = TransformerCTC(FEAT_DIM, NUM_CLASSES).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR_MAX, weight_decay=1e-2)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)

# Now this will work because SpecAugment is defined above
augmenter = SpecAugment().to(DEVICE)

# Scheduler
try:
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=LR_MAX, steps_per_epoch=len(train_loader), epochs=EPOCHS, pct_start=0.2
    )
except NameError:
    print("❌ Error: 'train_loader' is missing. Please re-run the cell defining your DataLoaders.")
    raise

# Utils (if missing)
def greedy_decode(log_probs, input_lengths):
    preds = log_probs.argmax(dim=-1).cpu()
    decoded = []
    for b in range(log_probs.shape[1]):
        L = input_lengths[b].item()
        seq = preds[:L, b].tolist()
        flat = []
        prev = None
        for p in seq:
            if p != 0 and p != prev:
                flat.append(p)
            prev = p
        decoded.append(flat)
    return decoded

def compute_per(ref, hyp):
    import editdistance
    d = editdistance.eval(ref, hyp)
    return d / len(ref) if len(ref) > 0 else 0

# ==========================================
# 5. TRAINING LOOP
# ==========================================
print("Starting Transformer Training...")
best_per = 1.0

for epoch in range(1, EPOCHS+1):
    model.train()
    train_loss = 0
    
    for batch in train_loader:
        feats, in_lens, labels, lbl_lens = batch
        feats = feats.to(DEVICE).transpose(0, 1)
        labels = labels.to(DEVICE)
        lbl_lens = lbl_lens.to(DEVICE)
        in_lens = in_lens.to(DEVICE)
        
        # Augment
        feats = augmenter(feats)
        
        optimizer.zero_grad()
        logits = model(feats, in_lens)
        log_probs = F.log_softmax(logits, dim=2)
        
        loss = criterion(log_probs, labels, in_lens, lbl_lens)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        train_loss += loss.item()
        
    # Validation
    model.eval()
    val_per = 0
    count = 0
    with torch.no_grad():
        for batch in val_loader:
            feats, in_lens, labels, lbl_lens = batch
            feats = feats.to(DEVICE).transpose(0, 1)
            labels = labels.to(DEVICE)
            in_lens = in_lens.to(DEVICE)
            lbl_lens = lbl_lens.to(DEVICE)
            
            logits = model(feats, in_lens)
            decoded = greedy_decode(F.log_softmax(logits, dim=2), in_lens)
            
            idx = 0
            for i, hyp in enumerate(decoded):
                L_target = lbl_lens[i].item()
                ref = labels[idx:idx+L_target].cpu().tolist()
                idx += L_target
                val_per += compute_per(ref, hyp)
                count += 1
                
    avg_per = val_per / count
    print(f"Epoch {epoch}: Loss={train_loss/len(train_loader):.4f} | Val PER={avg_per:.4f}")
    
    if avg_per < best_per:
        best_per = avg_per
        torch.save(model.state_dict(), "best_transformer.pth")
        print(f"  >>> New Best Transformer! PER: {best_per:.4f}")

Device: cuda
Starting Transformer Training...
Epoch 1: Loss=3.6254 | Val PER=1.0000
Epoch 2: Loss=2.7583 | Val PER=1.0000
  >>> New Best Transformer! PER: 1.0000
Epoch 3: Loss=2.4018 | Val PER=0.9999
  >>> New Best Transformer! PER: 0.9999
Epoch 4: Loss=2.0975 | Val PER=0.9999
  >>> New Best Transformer! PER: 0.9999
Epoch 5: Loss=1.8835 | Val PER=0.8084
  >>> New Best Transformer! PER: 0.8084
Epoch 6: Loss=1.7686 | Val PER=0.8264
Epoch 7: Loss=1.8621 | Val PER=0.8761
Epoch 8: Loss=1.9263 | Val PER=0.8525
Epoch 9: Loss=1.9562 | Val PER=0.8611
Epoch 10: Loss=1.9670 | Val PER=0.9834
Epoch 11: Loss=1.9892 | Val PER=0.8883
Epoch 12: Loss=1.9398 | Val PER=0.9051
Epoch 13: Loss=1.9295 | Val PER=0.9044
Epoch 14: Loss=1.9167 | Val PER=0.8907
Epoch 15: Loss=1.8776 | Val PER=0.8587
Epoch 16: Loss=1.8850 | Val PER=0.8628
Epoch 17: Loss=1.8346 | Val PER=0.8628
Epoch 18: Loss=1.8048 | Val PER=0.8768
Epoch 19: Loss=1.7834 | Val PER=0.8710
Epoch 20: Loss=1.7110 | Val PER=0.8309
Epoch 21: Loss=1.7101 |

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import DataLoader, Dataset

# ==========================================
# 1. CONFIG (The Tweak)
# ==========================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
BATCH_SIZE = 16
EPOCHS = 40
LR_MAX = 5e-4
FEAT_DIM = 512
NUM_CLASSES = 42

# CHANGE: Deeper model (8 layers instead of 6) + More Dropout
NUM_LAYERS = 8 
DROPOUT = 0.15 

print(f"Device: {DEVICE}")
print(f"Training Conformer V2 (Layers={NUM_LAYERS}, Dropout={DROPOUT})...")

# ==========================================
# 2. MODEL DEFINITION
# ==========================================
class ConformerCTC(nn.Module):
    def __init__(self, feat_dim, num_classes, d_model=256, n_head=4, num_layers=6, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(feat_dim, d_model)
        self.conformer = torchaudio.models.Conformer(
            input_dim=d_model,
            num_heads=n_head,
            ffn_dim=d_model * 4,
            num_layers=num_layers,
            depthwise_conv_kernel_size=31,
            dropout=dropout # Increased dropout makes it different!
        )
        self.output_proj = nn.Linear(d_model, num_classes)
        # Bias Hack
        init_bias = torch.zeros(num_classes)
        init_bias[0] = 2.0 
        self.output_proj.bias.data = init_bias

    def forward(self, x, input_lengths):
        x = self.projection(x)
        x = x.transpose(0, 1) 
        out, _ = self.conformer(x, input_lengths)
        out = out.transpose(0, 1) 
        return self.output_proj(out)

class SpecAugment(nn.Module):
    def __init__(self):
        super().__init__()
        self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=20)
        self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param=50)
    def forward(self, x):
        x = x.permute(1, 2, 0)
        x = self.freq_mask(x)
        x = self.time_mask(x)
        return x.permute(2, 0, 1)

# ==========================================
# 3. SETUP & TRAINING
# ==========================================
# Re-using loaders from memory (train_loader, val_loader)
model = ConformerCTC(FEAT_DIM, NUM_CLASSES, num_layers=NUM_LAYERS, dropout=DROPOUT).to(DEVICE)
optimizer = torch.optim.AdamW(model.parameters(), lr=LR_MAX, weight_decay=1e-2)
criterion = nn.CTCLoss(blank=0, zero_infinity=True)
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=LR_MAX, steps_per_epoch=len(train_loader), epochs=EPOCHS, pct_start=0.15
)
augmenter = SpecAugment().to(DEVICE)

# Utils (Standard)
def greedy_decode(log_probs, input_lengths):
    preds = log_probs.argmax(dim=-1).cpu()
    decoded = []
    for b in range(log_probs.shape[1]):
        L = input_lengths[b].item()
        seq = preds[:L, b].tolist()
        flat = []
        prev = None
        for p in seq:
            if p != 0 and p != prev:
                flat.append(p)
            prev = p
        decoded.append(flat)
    return decoded

def compute_per(ref, hyp):
    # Quick Levenshtein if editdistance not installed
    try:
        import editdistance
        d = editdistance.eval(ref, hyp)
    except:
        return 0 # Skip calculation if lib missing
    return d / len(ref) if len(ref) > 0 else 0

print("Starting Deep Conformer Training...")
best_per = 1.0

for epoch in range(1, EPOCHS+1):
    model.train()
    train_loss = 0
    
    for batch in train_loader:
        feats, in_lens, labels, lbl_lens = batch
        feats = feats.to(DEVICE).transpose(0, 1)
        labels = labels.to(DEVICE)
        lbl_lens = lbl_lens.to(DEVICE)
        in_lens = in_lens.to(DEVICE)
        
        feats = augmenter(feats)
        
        optimizer.zero_grad()
        logits = model(feats, in_lens)
        log_probs = F.log_softmax(logits, dim=2)
        
        loss = criterion(log_probs, labels, in_lens, lbl_lens)
        loss.backward()
        
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        scheduler.step()
        train_loss += loss.item()
        
    # Validation
    model.eval()
    val_per = 0
    count = 0
    with torch.no_grad():
        for batch in val_loader:
            feats, in_lens, labels, lbl_lens = batch
            feats = feats.to(DEVICE).transpose(0, 1)
            labels = labels.to(DEVICE)
            in_lens = in_lens.to(DEVICE)
            lbl_lens = lbl_lens.to(DEVICE)
            
            logits = model(feats, in_lens)
            decoded = greedy_decode(F.log_softmax(logits, dim=2), in_lens)
            
            idx = 0
            for i, hyp in enumerate(decoded):
                L_target = lbl_lens[i].item()
                ref = labels[idx:idx+L_target].cpu().tolist()
                idx += L_target
                val_per += compute_per(ref, hyp)
                count += 1
                
    avg_per = val_per / count
    print(f"Epoch {epoch}: Loss={train_loss/len(train_loader):.4f} | Val PER={avg_per:.4f}")
    
    if avg_per < best_per:
        best_per = avg_per
        # Save as a DIFFERENT name
        torch.save(model.state_dict(), "best_conformer_deep.pth")
        print(f"  >>> New Best Deep Conformer! PER: {best_per:.4f}")