In [1]:
!pip install jiwer scipy



# Optimized Brain-to-Text '25
## Phases 1, 2 & 3: Signal Processing + Adapter + Beam Search
This notebook implements:
- **Phase 1:** Gaussian Smoothing, Robust Normalization, Time/Channel Masking
- **Phase 2:** Subject-Specific Projection (Adapter Layer)
- **Phase 3:** CTC Beam Search Decoder

In [2]:
import os
# This helps with memory fragmentation
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import h5py
import torch
import numpy as np
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torch.nn import functional as F
import torch.nn.utils.rnn as rnn_utils
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import jiwer
from scipy.ndimage import gaussian_filter1d
import heapq
from collections import defaultdict

class CFG:
    N_HEAD = 4 # Keep this lower if needed
    EPOCHS = 5
    LR = 1e-3
    BATCH_SIZE = 8  # <--- Drop this from 32 to 8
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
    DATA_DIR = DATA_DIR = r"/mnt/data_ssd/ai_workspace/venv_torch/t15_copyTask_neuralData/hdf5_data_final"
    CHECKPOINT_PATH = r"/mnt/data_ssd/ai_workspace/venv_torch/t15_pretrained_rnn_baseline"
    
print(f"Running on device: {CFG.DEVICE}")

  from .autonotebook import tqdm as notebook_tqdm


Running on device: cuda


In [3]:
# --- Signal Processing Functions (Phase 1) ---

def gaussian_smoothing_1d(data, sigma=20):
    """Applies 1D Gaussian smoothing along the time dimension."""
    return gaussian_filter1d(data, sigma=sigma, axis=0)

def robust_scale(data, median, iqr):
    """Scales data using robust statistics (Median and IQR)."""
    return (data - median) / (iqr + 1e-6)

def temporal_mask(data, mask_percentage=0.05, mask_value=0.0):
    """Applies temporal masking to a 2D tensor [Sequence, Features]."""
    if not torch.is_tensor(data):
        data = torch.tensor(data, dtype=torch.float32)
    seq_len, _ = data.shape
    num_to_mask = int(seq_len * mask_percentage)
    if num_to_mask > 0:
        mask_indices = torch.randperm(seq_len)[:num_to_mask]
        data[mask_indices, :] = mask_value
    return data

def channel_mask(data, mask_percentage=0.1, mask_value=0.0):
    """Masks a random subset of channels."""
    if not torch.is_tensor(data):
        data = torch.tensor(data, dtype=torch.float32)
    _, n_channels = data.shape
    num_to_mask = int(n_channels * mask_percentage)
    if num_to_mask > 0:
        mask_indices = torch.randperm(n_channels)[:num_to_mask]
        data[:, mask_indices] = mask_value
    return data

In [4]:
# --- Data Loading with Session IDs (Phase 2) ---

SESSION_TO_ID = {}

class BrainDataset(Dataset):
    def __init__(self, hdf5_file, session_id, input_key="input_features", target_key="seq_class_ids", 
                 is_test=False, use_augmentation=False, smoothing_sigma=0, robust_stats=None):
        self.file_path = hdf5_file
        self.session_id = session_id
        self.input_key = input_key
        self.target_key = target_key
        self.is_test = is_test
        self.use_augmentation = use_augmentation
        self.smoothing_sigma = smoothing_sigma
        self.robust_stats = robust_stats 
        self.file = None
        
        try:
            with h5py.File(self.file_path, "r") as f:
                self.trial_keys = sorted(list(f.keys()))
        except FileNotFoundError:
            print(f"Warning: File not found {self.file_path}, creating empty dataset.")
            self.trial_keys = []

    def __len__(self):
        return len(self.trial_keys)

    def __getitem__(self, idx):
        if self.file is None:
            self.file = h5py.File(self.file_path, "r")
            
        trial_key = self.trial_keys[idx]
        trial_group = self.file[trial_key]
        x_data = trial_group[self.input_key][:]
        
        if self.smoothing_sigma > 0:
            x_data = gaussian_smoothing_1d(x_data, sigma=self.smoothing_sigma)
        
        x = torch.tensor(x_data, dtype=torch.float32)
        
        if self.robust_stats is not None:
            x = robust_scale(x, self.robust_stats['median'], self.robust_stats['iqr'])
        
        if self.use_augmentation and not self.is_test:
            x = temporal_mask(x, mask_percentage=0.1)
            x = channel_mask(x, mask_percentage=0.1)
        
        if self.target_key in trial_group:
            y_data = trial_group[self.target_key][:]
            y = torch.tensor(y_data, dtype=torch.long)
        else:
            y = torch.tensor([], dtype=torch.long)
        
        if self.is_test:
            return x, y, self.session_id, trial_key
        else:
            return x, y, self.session_id

def custom_collate(batch):
    is_test = len(batch[0]) == 4
    if is_test:
        xs, ys, session_ids, keys = zip(*batch)
    else:
        xs, ys, session_ids = zip(*batch)
        
    x_lengths = torch.tensor([len(x) for x in xs], dtype=torch.long)
    y_lengths = torch.tensor([len(y) for y in ys], dtype=torch.long)
    session_ids = torch.tensor(session_ids, dtype=torch.long)
    
    padded_xs = rnn_utils.pad_sequence(xs, batch_first=True, padding_value=0.0)
    padded_ys = rnn_utils.pad_sequence(ys, batch_first=True, padding_value=0)
    
    if is_test:
        return padded_xs, padded_ys, x_lengths, y_lengths, session_ids, keys
    else:
        return padded_xs, padded_ys, x_lengths, y_lengths, session_ids

def load_datasets(smoothing_sigma=20):
    global SESSION_TO_ID
    train_datasets, val_datasets, test_datasets = [], [], []

    subfolders = sorted([f.path for f in os.scandir(CFG.DATA_DIR) if f.is_dir()])
    print(f"Found {len(subfolders)} session folders.")
    
    SESSION_TO_ID = {os.path.basename(path): i for i, path in enumerate(subfolders)}
    print(f"Session mapping: {SESSION_TO_ID}")
    
    robust_stats = None
    
    for subfolder_path in subfolders:
        session_name = os.path.basename(subfolder_path)
        session_id = SESSION_TO_ID[session_name]
        
        train_set = BrainDataset(os.path.join(subfolder_path, "data_train.hdf5"), session_id, 
                                 is_test=False, use_augmentation=True, smoothing_sigma=smoothing_sigma, robust_stats=robust_stats)
        val_set = BrainDataset(os.path.join(subfolder_path, "data_val.hdf5"), session_id, 
                               is_test=False, use_augmentation=False, smoothing_sigma=smoothing_sigma, robust_stats=robust_stats)
        test_set = BrainDataset(os.path.join(subfolder_path, "data_test.hdf5"), session_id, 
                                is_test=True, use_augmentation=False, smoothing_sigma=smoothing_sigma, robust_stats=robust_stats)
        
        if len(train_set) > 0: train_datasets.append(train_set)
        if len(val_set) > 0: val_datasets.append(val_set)
        if len(test_set) > 0: test_datasets.append(test_set)
            
    return ConcatDataset(train_datasets), ConcatDataset(val_datasets), ConcatDataset(test_datasets), len(SESSION_TO_ID)

print("Loading datasets...")
train_dataset, val_dataset, test_dataset, NUM_SESSIONS = load_datasets(smoothing_sigma=20)
print(f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}, Sessions: {NUM_SESSIONS}")

train_loader = DataLoader(train_dataset, batch_size=CFG.BATCH_SIZE, shuffle=True, collate_fn=custom_collate)
val_loader = DataLoader(val_dataset, batch_size=CFG.BATCH_SIZE, shuffle=False, collate_fn=custom_collate)
test_loader = DataLoader(test_dataset, batch_size=CFG.BATCH_SIZE, shuffle=False, collate_fn=custom_collate)

Loading datasets...
Found 45 session folders.
Session mapping: {'t15.2023.08.11': 0, 't15.2023.08.13': 1, 't15.2023.08.18': 2, 't15.2023.08.20': 3, 't15.2023.08.25': 4, 't15.2023.08.27': 5, 't15.2023.09.01': 6, 't15.2023.09.03': 7, 't15.2023.09.24': 8, 't15.2023.09.29': 9, 't15.2023.10.01': 10, 't15.2023.10.06': 11, 't15.2023.10.08': 12, 't15.2023.10.13': 13, 't15.2023.10.15': 14, 't15.2023.10.20': 15, 't15.2023.10.22': 16, 't15.2023.11.03': 17, 't15.2023.11.04': 18, 't15.2023.11.17': 19, 't15.2023.11.19': 20, 't15.2023.11.26': 21, 't15.2023.12.03': 22, 't15.2023.12.08': 23, 't15.2023.12.10': 24, 't15.2023.12.17': 25, 't15.2023.12.29': 26, 't15.2024.02.25': 27, 't15.2024.03.03': 28, 't15.2024.03.08': 29, 't15.2024.03.15': 30, 't15.2024.03.17': 31, 't15.2024.04.25': 32, 't15.2024.04.28': 33, 't15.2024.05.10': 34, 't15.2024.06.14': 35, 't15.2024.07.19': 36, 't15.2024.07.21': 37, 't15.2024.07.28': 38, 't15.2025.01.10': 39, 't15.2025.01.12': 40, 't15.2025.03.14': 41, 't15.2025.03.16': 42, 

In [5]:
class SubjectAdapter(nn.Module):
    def __init__(self, input_dim, num_sessions):
        super().__init__()
        # Use a 3D weight tensor for all sessions at once
        self.weight = nn.Parameter(torch.stack([torch.eye(input_dim) for _ in range(num_sessions)]))
        self.bias = nn.Parameter(torch.zeros(num_sessions, input_dim))
    
    def forward(self, x, session_ids):
        # x: [Batch, Seq, Dim]
        # session_ids: [Batch]
        w = self.weight[session_ids] # [Batch, Dim, Dim]
        b = self.bias[session_ids].unsqueeze(1) # [Batch, 1, Dim]
        
        # Process whole batch with one matrix multiplication
        return torch.bmm(x, w) + b

In [6]:
from torch.utils.checkpoint import checkpoint

class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

class FeedForwardModule(nn.Module):
    def __init__(self, dim, expansion_factor=4, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim)
        self.linear1 = nn.Linear(dim, dim * expansion_factor)
        self.swish = Swish()
        self.dropout1 = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim * expansion_factor, dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x):
        out = self.layer_norm(x)
        out = self.linear1(out)
        out = self.swish(out)
        out = self.dropout1(out)
        out = self.linear2(out)
        out = self.dropout2(out)
        return out

class ConvolutionModule(nn.Module):
    def __init__(self, dim, kernel_size=31, dropout=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(dim)
        self.pointwise_conv1 = nn.Conv1d(dim, dim * 2, kernel_size=1)
        self.glu = nn.GLU(dim=1)
        self.depthwise_conv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, groups=dim)
        self.batch_norm = nn.BatchNorm1d(dim)
        self.swish = Swish()
        self.pointwise_conv2 = nn.Conv1d(dim, dim, kernel_size=1)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = self.layer_norm(x)
        out = out.transpose(1, 2)
        out = self.pointwise_conv1(out)
        out = self.glu(out)
        out = self.depthwise_conv(out)
        out = self.batch_norm(out)
        out = self.swish(out)
        out = self.pointwise_conv2(out)
        out = self.dropout(out)
        out = out.transpose(1, 2)
        return out

class ConformerBlock(nn.Module):
    def __init__(self, dim, n_head, conv_kernel_size=31, dropout=0.1):
        super().__init__()
        self.ff1 = FeedForwardModule(dim, dropout=dropout)
        self.self_attn_layer_norm = nn.LayerNorm(dim)
        # We no longer use nn.MultiheadAttention to ensure Flash Attention 2 is used
        self.n_head = n_head
        self.dropout_p = dropout
        self.conv_module = ConvolutionModule(dim, kernel_size=conv_kernel_size, dropout=dropout)
        self.ff2 = FeedForwardModule(dim, dropout=dropout)
        self.final_layer_norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = x + 0.5 * self.ff1(x)
        residual = x
        x_norm = self.self_attn_layer_norm(x)
        
        # Blackwell Optimization: Scaled Dot Product Attention (Flash Attention)
        # This prevents the 2.8GB allocation by calculating attention in blocks
        attn_out = torch.nn.functional.scaled_dot_product_attention(
            x_norm, x_norm, x_norm, 
            dropout_p=self.dropout_p if self.training else 0.0,
            is_causal=False
        )
        
        x = residual + self.dropout(attn_out)
        x = x + self.conv_module(x)
        x = x + 0.5 * self.ff2(x)
        x = self.final_layer_norm(x)
        return x

class ConformerEncoder(nn.Module):
    def __init__(self, input_dim, encoder_dim, n_layers, n_head, output_dim):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, encoder_dim)
        self.layers = nn.ModuleList([ConformerBlock(encoder_dim, n_head) for _ in range(n_layers)])
        self.output_proj = nn.Linear(encoder_dim, output_dim)

    def forward(self, x):
        x = self.input_proj(x)
        for layer in self.layers:
            # Gradient Checkpointing: Trades 20% speed for 60-70% VRAM saving
            x = checkpoint(layer, x, use_reentrant=False)
        x = self.output_proj(x)
        return nn.functional.log_softmax(x, dim=2)

In [7]:
# --- Full Model with Subject Adapter (Phase 2) ---

class BrainToTextModel(nn.Module):
    def __init__(self, input_dim, encoder_dim, n_layers, n_head, output_dim, num_sessions):
        super().__init__()
        self.adapter = SubjectAdapter(input_dim, num_sessions)
        self.encoder = ConformerEncoder(input_dim, encoder_dim, n_layers, n_head, output_dim)
    
    def forward(self, x, session_ids):
        x = self.adapter(x, session_ids)
        return self.encoder(x)

INPUT_DIM = 512
ENCODER_DIM = 256
N_LAYERS = 4
N_HEAD = 4
OUTPUT_DIM = 41  # 40 phonemes + 1 blank

model = BrainToTextModel(INPUT_DIM, ENCODER_DIM, N_LAYERS, N_HEAD, OUTPUT_DIM, NUM_SESSIONS).to(CFG.DEVICE)
print(f"BrainToTextModel initialized with {sum(p.numel() for p in model.parameters())} parameters")

BrainToTextModel initialized with 17000489 parameters


In [8]:
# --- CTC Beam Search Decoder (Phase 3) ---

def ctc_beam_search(log_probs, beam_width=10, blank_id=0):
    """
    CTC Prefix Beam Search Decoder.
    
    Args:
        log_probs: [T, num_classes] numpy array of log probabilities
        beam_width: Number of beams to keep
        blank_id: ID of the blank token (usually 0)
    
    Returns:
        Best decoded sequence (list of token IDs, excluding blanks)
    """
    T, num_classes = log_probs.shape
    
    # Each beam is: (prefix_tuple, (log_prob_blank, log_prob_non_blank))
    # prefix_tuple: the decoded sequence so far
    # log_prob_blank: prob of this prefix ending in blank
    # log_prob_non_blank: prob of this prefix ending in non-blank
    
    NEG_INF = float('-inf')
    beams = {(): (0.0, NEG_INF)}  # Start with empty prefix, prob=1 for blank path
    
    for t in range(T):
        new_beams = defaultdict(lambda: (NEG_INF, NEG_INF))
        
        for prefix, (pb, pnb) in beams.items():
            # Total probability of this prefix
            p_prefix = np.logaddexp(pb, pnb)
            
            for c in range(num_classes):
                p_c = log_probs[t, c]
                
                if c == blank_id:
                    # Extend with blank - prefix stays same
                    new_pb, new_pnb = new_beams[prefix]
                    new_pb = np.logaddexp(new_pb, p_prefix + p_c)
                    new_beams[prefix] = (new_pb, new_pnb)
                else:
                    # Extend with non-blank character
                    new_prefix = prefix + (c,)
                    
                    # Case 1: Previous was blank OR different character
                    new_pb, new_pnb = new_beams[new_prefix]
                    new_pnb = np.logaddexp(new_pnb, p_prefix + p_c)
                    
                    # Case 2: Repeat character (only if last char is same)
                    if len(prefix) > 0 and prefix[-1] == c:
                        # Can only extend via blank path (pb)
                        new_pnb = np.logaddexp(new_pnb, pb + p_c)
                        # Also, staying at same prefix
                        old_pb, old_pnb = new_beams[prefix]
                        old_pnb = np.logaddexp(old_pnb, pnb + p_c)
                        new_beams[prefix] = (old_pb, old_pnb)
                    
                    new_beams[new_prefix] = (new_pb, new_pnb)
        
        # Prune to beam_width
        scored = [(np.logaddexp(pb, pnb), prefix) for prefix, (pb, pnb) in new_beams.items()]
        scored.sort(reverse=True)
        beams = {prefix: new_beams[prefix] for _, prefix in scored[:beam_width]}
    
    # Return best prefix
    best_prefix = max(beams.keys(), key=lambda p: np.logaddexp(*beams[p]))
    return list(best_prefix)

def beam_search_decoder(logits, token_map, beam_width=10):
    """
    Decode logits using CTC beam search.
    
    Args:
        logits: [T, num_classes] tensor (log probabilities)
        token_map: dict mapping token IDs to phoneme strings
        beam_width: Number of beams
    
    Returns:
        Decoded phoneme string
    """
    log_probs = logits.cpu().numpy()
    decoded_ids = ctc_beam_search(log_probs, beam_width=beam_width, blank_id=0)
    phonemes = [token_map.get(i, "?") for i in decoded_ids]
    return " ".join(phonemes)

print("CTC Beam Search Decoder defined.")

CTC Beam Search Decoder defined.


In [9]:
# --- Decoders: Greedy vs Beam Search ---

VOCAB = [
    'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER', 
    'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW',
    'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y', 'Z', 
    'ZH', '|'
]
TOKEN_MAP = {i + 1: phoneme for i, phoneme in enumerate(VOCAB)}
TOKEN_MAP[0] = ""  # blank

def greedy_decoder(logits, token_map):
    """Simple greedy CTC decoder (baseline)."""
    pred_indices = torch.argmax(logits, dim=-1)
    collapsed_indices = torch.unique_consecutive(pred_indices)
    final_indices = [idx.item() for idx in collapsed_indices if idx.item() != 0]
    phonemes = [token_map.get(i, "?") for i in final_indices]
    return " ".join(phonemes)

# --- Test Beam Search ---
print("Testing beam search decoder...")
dummy_logits = torch.randn(50, 41)  # 50 timesteps, 41 classes
dummy_logits = F.log_softmax(dummy_logits, dim=-1)

greedy_result = greedy_decoder(dummy_logits, TOKEN_MAP)
beam_result = beam_search_decoder(dummy_logits, TOKEN_MAP, beam_width=5)

print(f"Greedy: {greedy_result[:50]}...")
print(f"Beam:   {beam_result[:50]}...")
print("Beam search decoder working!")

Testing beam search decoder...
Greedy: HH D N L OW Z AE AO B ER D NG V UH S OW AW HH ZH U...
Beam:   HH D N L OW Z AE AO B ER D NG V UH S OW AW HH ZH U...
Beam search decoder working!


In [10]:
from torch.amp import autocast, GradScaler

scaler = GradScaler('cuda')

def train_one_epoch(epoch, model, train_loader, criterion, optimizer):
    model.train()
    running_loss = 0.0
    for x, y, x_lengths, y_lengths, session_ids in tqdm(train_loader, desc=f"Epoch {epoch} [Train]", leave=False):
        x = x.to(CFG.DEVICE)
        y = y.to(CFG.DEVICE)
        x_lengths = x_lengths.to(CFG.DEVICE)
        y_lengths = y_lengths.to(CFG.DEVICE)
        session_ids = session_ids.to(CFG.DEVICE)
        
        optimizer.zero_grad()
        
        # BF16 is native to Blackwell and uses half the memory of FP32
        with autocast('cuda', dtype=torch.bfloat16):
            y_pred = model(x, session_ids)
            y_pred_for_loss = y_pred.permute(1, 0, 2)
            loss = criterion(y_pred_for_loss, y, x_lengths, y_lengths)
            
        if torch.isnan(loss) or torch.isinf(loss): continue
        
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        
        running_loss += loss.item() * x.size(0)
    return running_loss / len(train_loader.dataset)

def validate_one_epoch(epoch, model, val_loader, criterion, token_map, use_beam_search=False, beam_width=10):
    model.eval()
    val_loss = 0.0
    all_pred, all_true = [], []
    
    with torch.no_grad():
        with autocast('cuda', dtype=torch.bfloat16):
            for x, y, x_lengths, y_lengths, session_ids in tqdm(val_loader, desc=f"Epoch {epoch} [Val]", leave=False):
                x = x.to(CFG.DEVICE)
                y = y.to(CFG.DEVICE)
                x_lengths = x_lengths.to(CFG.DEVICE)
                y_lengths = y_lengths.to(CFG.DEVICE)
                session_ids = session_ids.to(CFG.DEVICE)
                
                y_pred = model(x, session_ids)
                y_pred_for_loss = y_pred.permute(1, 0, 2)
                loss = criterion(y_pred_for_loss, y, x_lengths, y_lengths)
                val_loss += loss.item() * x.size(0)
                
                # Move to CPU for decoding
                y_pred_cpu = y_pred.to(torch.float32).cpu()
                
                for i in range(x.size(0)):
                    pred_logits = y_pred_cpu[i, :x_lengths[i], :]
                    true_indices = y[i, :y_lengths[i]]
                    
                    if use_beam_search:
                        pred_text = beam_search_decoder(pred_logits, token_map, beam_width)
                    else:
                        pred_text = greedy_decoder(pred_logits, token_map)
                        
                    true_text = " ".join([token_map.get(idx.item(), "?") for idx in true_indices])
                    all_pred.append(pred_text)
                    all_true.append(true_text)
                
    wer = jiwer.wer(all_true, all_pred)
    return val_loss / len(val_loader.dataset), wer

criterion = nn.CTCLoss(blank=0, zero_infinity=True)
optimizer = torch.optim.AdamW(model.parameters(), lr=CFG.LR)

In [11]:
import gc

# 1. Clear GPU memory manually
torch.cuda.empty_cache()
gc.collect()

# 2. Re-init loader with safe Batch Size for 16GB VRAM
CFG.BATCH_SIZE = 8
train_loader = DataLoader(train_dataset, batch_size=CFG.BATCH_SIZE, shuffle=True, collate_fn=custom_collate)
val_loader = DataLoader(val_dataset, batch_size=CFG.BATCH_SIZE, shuffle=False, collate_fn=custom_collate)

print("Starting Training (Blackwell Optimized)...")
for epoch in range(1, CFG.EPOCHS + 1):
    train_loss = train_one_epoch(epoch, model, train_loader, criterion, optimizer)
    val_loss, wer = validate_one_epoch(epoch, model, val_loader, criterion, TOKEN_MAP, use_beam_search=False)
    print(f"Epoch {epoch} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | WER: {wer:.4f}")

Starting Training (Blackwell Optimized)...


                                                                                            

Epoch 1 | Train Loss: -0.2677 | Val Loss: -0.3312 | WER: 0.9504


                                                                                            

Epoch 2 | Train Loss: -0.2789 | Val Loss: -0.3355 | WER: 0.8265


                                                                                            

Epoch 3 | Train Loss: -0.2804 | Val Loss: -0.3368 | WER: 0.7219


                                                                                            

Epoch 4 | Train Loss: -0.2810 | Val Loss: -0.3303 | WER: 0.7274


                                                                                            

Epoch 5 | Train Loss: -0.2821 | Val Loss: -0.3295 | WER: 0.7621




In [13]:
# --- Compare Greedy vs Beam Search WER ---
# Run this after training to compare decoding strategies:

print("Comparing Greedy vs Beam Search...")
_, wer_greedy = validate_one_epoch(0, model, val_loader, criterion, TOKEN_MAP, use_beam_search=False)
_, wer_beam = validate_one_epoch(0, model, val_loader, criterion, TOKEN_MAP, use_beam_search=True, beam_width=10)
print(f"Greedy WER: {wer_greedy:.4f}")
print(f"Beam (k=10) WER: {wer_beam:.4f}")
print(f"Improvement: {(wer_greedy - wer_beam) * 100:.2f}%")

Comparing Greedy vs Beam Search...


                                                                                            

Greedy WER: 0.7621
Beam (k=10) WER: 0.7676
Improvement: -0.55%


