In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
# Print the full path of every file found under the read-only input tree.
for root_dir, _subdirs, file_names in os.walk('/kaggle/input'):
    for file_name in file_names:
        print(os.path.join(root_dir, file_name))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

# Brain-to-Text '25: Baseline Submission Notebook

This notebook will set up the environment, load the data, and run the official baseline model to generate a valid `submission.csv`.

**CRITICAL LIMITATION:**
The top-scoring baseline uses a 5-gram language model that requires **~300GB of RAM**, and the 3-gram model requires **~60GB**. A Kaggle notebook only has ~16GB (or ~30GB with High RAM).

Therefore, this notebook will generate a submission using:
1.  The **pre-trained RNN model** (provided by the hosts).
2.  The **1-gram language model** (which fits in Kaggle's memory).

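This will get you a valid score on the leaderboard. To improve, you will need to train a better neural model and/or use a high-RAM machine to run the larger language models.

> **Note:** the code cells below do not yet apply any language model. They train several adapter-based recurrent and Transformer models, greedily CTC-decode each one into space-separated phoneme sequences, and write those phoneme strings (not words) to the submission files.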

In [None]:
import h5py
import os

# Paths mirror CFG.DATA_DIR from the configuration cell further down.
DATA_DIR = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
SUBFOLDER = "t15.2025.03.14"

# The *test* split of a single session.
test_file_path = os.path.join(DATA_DIR, SUBFOLDER, "data_test.hdf5")
print(f"Inspecting test file: {test_file_path}")

# Best-effort inspection: any failure is printed instead of raised.
try:
    with h5py.File(test_file_path, "r") as h5_file:
        # Alphabetically first trial (e.g. 'trial_0000'); iterating a group yields member names.
        first_trial_key = sorted(h5_file)[0]
        print(f"Inspecting keys inside: {first_trial_key}")

        trial_group = h5_file[first_trial_key]
        print(f"Keys found: {list(trial_group.keys())}")

        for key, member in trial_group.items():
            print(f"  ‚Ä¢ {key}: shape {member.shape}, dtype {member.dtype}")

except Exception as e:
    print(f"An error occurred: {e}")

In [None]:
!pip install jiwer

In [None]:
# ============================================================
# STEP 1: Imports and Configuration
# ============================================================
print("Importing libraries...")
import os
import yaml
import h5py
import torch
import numpy as np
import pandas as pd
from torch import nn
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

# Configuration
class CFG:
    """Notebook-wide settings: model/training hyperparameters, device, seed and data paths."""

    # --- Model Hyperparameters ---
    # Number of attention heads (used for the Transformer only)
    N_HEAD = 8

    # --- Training ---
    EPOCHS = 5
    LR = 1e-3
    BATCH_SIZE = 32
    # Seed applied to torch and numpy right after this class so reruns are comparable.
    SEED = 42
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # --- Paths ---
    DATA_DIR = "/kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final"
    CHECKPOINT_PATH = "/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/best_checkpoint"
    ARGS_PATH = "/kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline/checkpoint/args.yaml"
    # NOTE(review): test trials are loaded from the per-session data_test.hdf5 files
    # under DATA_DIR; this path is not used for loading.
    COMPETITION_TEST_PATH = "/kaggle/input/brain-to-text-25/data_test.hdf5"

# The torch RNG drives weight init, DataLoader shuffling and temporal masking;
# numpy is seeded too for any numpy-based randomness.
torch.manual_seed(CFG.SEED)
np.random.seed(CFG.SEED)

print(f"Running on device: {CFG.DEVICE}")

In [None]:
# ============================================================
# STEP 2: Load Model Configuration (args.yaml)
# ============================================================
# Read the hyperparameters stored next to the pretrained checkpoint.
with open(CFG.ARGS_PATH, "r") as args_file:
    args = yaml.safe_load(args_file)

print("Loaded model config from args.yaml:")
# Top-level lookups with fallbacks matching the model defaults.
# NOTE(review): if args.yaml nests its model settings under a sub-key, these
# lookups silently fall back to the defaults — confirm against the file.
# (HIDDEN_SIZE, OUTPUT_SIZE and NUM_LAYERS are reassigned again in STEP 4.)
INPUT_SIZE, HIDDEN_SIZE, OUTPUT_SIZE, NUM_LAYERS = (
    args.get(option, fallback)
    for option, fallback in (
        ("input_size", 256),
        ("hidden_size", 512),
        ("output_size", 29),
        ("num_layers", 1),
    )
)

print(args)

In [None]:
# ============================================================
# STEP 2.5: INSPECT THE KEYS OF A TRAINING FILE
# ============================================================
print("Inspecting a TRAINING file to find the target key...")

# Change this path if needed;
# it is one of the training session folders.
TRAIN_SUBFOLDER = "t15.2025.01.10"
train_file_path = os.path.join(CFG.DATA_DIR, TRAIN_SUBFOLDER, "data_train.hdf5")

# Best-effort inspection: any failure is printed instead of raised.
try:
    with h5py.File(train_file_path, "r") as h5_file:
        first_trial_key = sorted(h5_file)[0]
        print(f"Inspecting keys inside trial: {first_trial_key}")

        trial_group = h5_file[first_trial_key]
        print(f"!!! T·∫§T C·∫¢ C√ÅC KEY T√åM TH·∫§Y: {list(trial_group.keys())} !!!")

        # Print details for every key that might be the target (anything except the inputs).
        candidate_keys = [name for name in trial_group.keys() if name != 'input_features']
        for key in candidate_keys:
            member = trial_group[key]
            print(f"  ‚Ä¢ T√¨m th·∫•y key kh·∫£ nghi: '{key}'")
            print(f"    Shape: {member.shape}, Dtype: {member.dtype}")
            print(f"    Sample data: {member[:10]}")

except Exception as e:
    print(f"L·ªói khi ki·ªÉm tra: {e}")

In [None]:
# ============================================================
# STEP 3: Define Dataset Loader (Seq2Seq)
# ============================================================
from torch.utils.data import ConcatDataset, Dataset

def temporal_mask(data, mask_percentage=0.05, mask_value=0.0):
    """
    Return a copy of a 2-D [sequence, features] input with random time steps masked.

    Args:
        data: tensor or array-like of shape [seq_len, n_features]. Never modified:
            tensors are cloned first (the previous version overwrote the caller's
            tensor in place); other inputs are converted to a new float32 tensor.
        mask_percentage: fraction of time steps (rows) to overwrite; the count is
            int(seq_len * mask_percentage), so short sequences may get no mask.
        mask_value: value written to every feature of each masked time step.

    Returns:
        A new tensor with the selected rows set to ``mask_value``.
    """
    if torch.is_tensor(data):
        data = data.clone()
    else:
        data = torch.tensor(data, dtype=torch.float32)

    # Unpacking enforces the 2-D [sequence, features] layout.
    seq_len, _ = data.shape
    num_to_mask = int(seq_len * mask_percentage)

    if num_to_mask > 0:
        # Distinct random row indices, drawn from the global torch RNG.
        mask_indices = torch.randperm(seq_len)[:num_to_mask]
        data[mask_indices, :] = mask_value

    return data

class BrainDataset(Dataset):
    """
    Reads trials from a single HDF5 file (e.g., data_train.hdf5).

    Every top-level group of the file is one trial; its name is the trial key.
    - input_key: name of the per-trial dataset holding the input features.
    - target_key: name of the per-trial dataset holding target class indices.
    - is_test: if True, __getitem__ also returns the trial key.
    - use_augmentation: if True (and not is_test), temporal masking is applied to x.

    A missing file produces an empty dataset (with a printed warning) instead of an error.
    """
    def __init__(self, hdf5_file, input_key="input_features", target_key="phoneme_indices", is_test=False, use_augmentation=False):
        self.file_path = hdf5_file
        self.input_key = input_key
        self.target_key = target_key  # key for sequence targets
        self.is_test = is_test
        self.use_augmentation = use_augmentation
        # File handle, opened lazily on the first __getitem__ call and then reused.
        self.file = None

        try:
            with h5py.File(self.file_path, "r") as f:
                self.trial_keys = sorted(list(f.keys()))
        except FileNotFoundError:
            # Handle cases where a subfolder might be missing a split
            print(f"Warning: File not found {self.file_path}, creating empty dataset.")
            self.trial_keys = []

    def __len__(self):
        return len(self.trial_keys)

    def __getitem__(self, idx):
        """
        Return (x, y), or (x, y, trial_key) when is_test.

        x is a float32 tensor of the trial's input features; y is an int64 tensor of
        target indices (empty when the trial has no target dataset).
        """
        if self.file is None:
            self.file = h5py.File(self.file_path, "r")

        trial_key = self.trial_keys[idx]
        trial_group = self.file[trial_key]

        x = torch.tensor(trial_group[self.input_key][:], dtype=torch.float32)
        if self.use_augmentation and not self.is_test:
            x = temporal_mask(x, mask_percentage=0.1)

        if self.target_key in trial_group:
            y_data = trial_group[self.target_key][:]
            # If the trial stores its true target length in a 'seq_len' attribute,
            # the array is treated as padded and only that prefix is kept. Otherwise
            # the padding would be counted in y_lengths and passed to CTCLoss, and
            # zero padding would put blank tokens inside the target.
            # NOTE(review): confirm the attribute name against the competition files.
            seq_len = trial_group.attrs.get("seq_len")
            if seq_len is not None:
                y_data = y_data[: int(seq_len)]
            y = torch.tensor(y_data, dtype=torch.long)
        else:
            # Empty long tensor as a placeholder for test/dummy targets
            y = torch.tensor([], dtype=torch.long)

        if self.is_test:
            return x, y, trial_key
        return x, y

class _EmptyDataset(Dataset):
    """Zero-length stand-in for a split that no session folder provides."""

    def __len__(self):
        return 0

    def __getitem__(self, idx):
        raise IndexError(idx)


def load_datasets():
    """
    Scan every session subfolder of CFG.DATA_DIR and build combined train, val and
    test datasets from the data_train/data_val/data_test.hdf5 files found there.

    Session folders are visited in sorted order so the concatenation (and therefore
    the order of test predictions) is deterministic; os.scandir order is arbitrary.
    Temporal-masking augmentation is enabled for the training split only.

    Returns:
        (train, val, test) datasets. A split with no non-empty files becomes a
        zero-length dataset instead of raising (ConcatDataset rejects an empty list).
    """
    datasets_by_split = {"train": [], "val": [], "test": []}

    subfolders = sorted(entry.path for entry in os.scandir(CFG.DATA_DIR) if entry.is_dir())
    print(f"Found {len(subfolders)} session folders.")

    for subfolder_path in subfolders:
        for split, collected in datasets_by_split.items():
            dataset = BrainDataset(
                os.path.join(subfolder_path, f"data_{split}.hdf5"),
                input_key="input_features",
                target_key="seq_class_ids",
                is_test=(split == "test"),
                use_augmentation=(split == "train"),
            )
            if len(dataset) > 0:
                collected.append(dataset)

    # Combine the per-session datasets of each split into one large dataset.
    return tuple(
        ConcatDataset(collected) if collected else _EmptyDataset()
        for collected in datasets_by_split.values()
    )

print("Loading Train/Val/Test data from session folders...")
train_dataset, val_dataset, test_dataset = load_datasets()
print("="*40)
print(f"Total Train samples: {len(train_dataset)}")
print(f"Total Val samples: {len(val_dataset)}")
print(f"Total Test samples: {len(test_dataset)}")
print("="*40)

# Check samples (guarded so an empty split prints a warning instead of raising)
if len(train_dataset) > 0:
    sample_x_train, sample_y_train = train_dataset[0]
    print(f"Train sample X shape: {sample_x_train.shape}")
    print(f"Train sample Y (indices): {sample_y_train}")
    print(f"Train sample Y shape: {sample_y_train.shape}, dtype: {sample_y_train.dtype}")
else:
    print(f"Warning: No training samples found under {CFG.DATA_DIR}")

if len(test_dataset) > 0:
    sample_x_test, sample_y_test, _ = test_dataset[0]
    print(f"Test sample X shape: {sample_x_test.shape}")
    print(f"Test sample Y (dummy): {sample_y_test}")
    print(f"Test sample Y shape: {sample_y_test.shape}, dtype: {sample_y_test.dtype}")
else:
    # Test trials come from each session's data_test.hdf5 (see load_datasets), not
    # from CFG.COMPETITION_TEST_PATH, so report the directory that was scanned.
    print(f"Warning: No test trials found in the data_test.hdf5 files under {CFG.DATA_DIR}")

In [None]:
# ============================================================
# STEP 3.5: INSPECT A TRAINING FILE
# ============================================================
print("üïµÔ∏è Inspecting a training file to find the correct label key...")
import h5py
import os

# --- Find the FIRST available data_train.hdf5 file ---
# Session folders are walked in sorted order so the same file is chosen every run.
subfolders = [entry.path for entry in os.scandir(CFG.DATA_DIR) if entry.is_dir()]
train_file_to_inspect = next(
    (
        candidate
        for candidate in (os.path.join(folder, "data_train.hdf5") for folder in sorted(subfolders))
        if os.path.exists(candidate)
    ),
    None,
)

if not train_file_to_inspect:
    print("Error: Could not find any data_train.hdf5 files to inspect.")
else:
    print(f"Inspecting file: {train_file_to_inspect}")
    # Best-effort inspection: any failure is printed instead of raised.
    try:
        with h5py.File(train_file_to_inspect, "r") as h5_file:
            first_trial_key = sorted(h5_file)[0]
            print(f"Inspecting keys inside trial: {first_trial_key}")

            trial_group = h5_file[first_trial_key]
            print(f"\n--- üí° ALL KEYS FOUND IN THIS TRIAL üí° ---")
            for key in trial_group:
                print(f"  ‚Ä¢ {key}")
            print(f"-------------------------------------------\n")
            print("Find the key that looks like labels (e.g., 'targets', 'labels', 'phonemes') and use it in the next step.")

    except Exception as e:
        print(f"An error occurred: {e}")

In [None]:
# ============================================================
# STEP 4: Define Model (Adapter for Fine-tuning)
# ============================================================

# --- 1. DEFINE VOCABULARY ---
# 39 phoneme symbols plus '|' (word boundary): 40 non-blank tokens.
VOCAB = [
    'AA', 'AE', 'AH', 'AO', 'AW', 'AY', 'B', 'CH', 'D', 'DH', 'EH', 'ER', 
    'EY', 'F', 'G', 'HH', 'IH', 'IY', 'JH', 'K', 'L', 'M', 'N', 'NG', 'OW',
    'OY', 'P', 'R', 'S', 'SH', 'T', 'TH', 'UH', 'UW', 'V', 'W', 'Y', 'Z', 
    'ZH', '|'  # '|' is the word boundary token
]
# Class 0 is reserved for the CTC blank, so the models emit len(VOCAB) + 1 classes.
OUTPUT_SIZE = len(VOCAB) + 1 # 40 + 1 = 41
BLANK_ID = 0

# --- 2. DEFINE MODEL PARAMETERS ---
# NOTE(review): these are meant to match the pretrained checkpoint's tensor shapes;
# confirm against the checkpoint's state_dict before trusting the "pretrained" runs.
DATA_INPUT_SIZE = 512       # feature size of the input data
ADAPTER_OUTPUT_SIZE = 256   # width of the linear adapter feeding the sequence model
HIDDEN_SIZE = 512           # hidden size of the recurrent layers
NUM_LAYERS = 1              # number of stacked recurrent / encoder layers
IS_BIDIRECTIONAL = False

# A bidirectional RNN feeds both directions into the classifier.
FC_INPUT_SIZE = HIDDEN_SIZE * (2 if IS_BIDIRECTIONAL else 1)

print(f"Vocabulary Config: {len(VOCAB)} phonemes + 1 blank = {OUTPUT_SIZE} classes.")
print(
    f"Model Config: Data({DATA_INPUT_SIZE}) -> Adapter({ADAPTER_OUTPUT_SIZE}) -> "
    f"RNN({ADAPTER_OUTPUT_SIZE}, {HIDDEN_SIZE}) -> FC({FC_INPUT_SIZE}, {OUTPUT_SIZE})"
)


### --- Model 1: Flexible Recurrent Model (Seq2Seq) --- ###
class RecurrentModel(nn.Module):
    """
    Sequence-to-sequence recurrent classifier:
    Linear adapter -> LSTM/GRU/RNN stack -> per-time-step Linear -> log-softmax.

    Args:
        model_type: "LSTM", "GRU" or "RNN".
        data_input_size: feature size of the raw input.
        adapter_output_size: adapter width, i.e. the recurrent input size.
        hidden_size: recurrent hidden size.
        output_size: number of output classes.
        num_layers: number of stacked recurrent layers.
        bidirectional: if truthy, the classifier sees 2 * hidden_size features.

    Raises:
        ValueError: if model_type is not one of the three supported names.
    """
    def __init__(self, model_type, data_input_size, adapter_output_size, 
                 hidden_size, output_size, num_layers, bidirectional):
        super().__init__()
        self.hidden_size = hidden_size
        self.bidirectional = bidirectional

        # Projects raw features to the recurrent input width.
        self.adapter_layer = nn.Linear(data_input_size, adapter_output_size)

        recurrent_cls = next(
            (
                cls
                for name, cls in (("LSTM", nn.LSTM), ("GRU", nn.GRU), ("RNN", nn.RNN))
                if model_type == name
            ),
            None,
        )
        if recurrent_cls is None:
            raise ValueError("Invalid model_type")
        self.rnn = recurrent_cls(
            input_size=adapter_output_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

        directions = 2 if bidirectional else 1
        self.fc = nn.Linear(hidden_size * directions, output_size)

    def forward(self, x):
        """Map batch-first x [B, S, data_input_size] to log-probabilities [B, S, output_size]."""
        projected = self.adapter_layer(x)
        sequence_out = self.rnn(projected)[0]  # discard the final hidden state(s)
        return self.fc(sequence_out).log_softmax(dim=2)

        
### --- Model 2: Transformer Encoder Model (Seq2Seq) --- ###
class TransformerEncModel(nn.Module):
    """
    Sequence-to-sequence Transformer classifier:
    Linear adapter -> (+ sinusoidal positions) -> TransformerEncoder -> Linear -> log-softmax.

    Self-attention alone ignores element order, so without positional information the
    encoder cannot tell time steps apart. Sinusoidal encodings are therefore added to
    the adapter output. They are computed on the fly and have no parameters, so the
    state_dict keys are unchanged.

    Args:
        data_input_size: feature size of the raw input.
        adapter_output_size: adapter width, used as the encoder's d_model.
        n_head: attention heads (must divide adapter_output_size).
        num_layers: number of encoder layers.
        dim_feedforward: hidden width of each encoder layer's feed-forward block.
        output_size: number of output classes.
        use_positional_encoding: set False to reproduce the original order-agnostic model.
    """
    def __init__(self, data_input_size, adapter_output_size, n_head, num_layers, 
                 dim_feedforward, output_size, use_positional_encoding=True):
        super().__init__()

        # Projects raw features to the encoder width.
        self.adapter_layer = nn.Linear(data_input_size, adapter_output_size)
        self.d_model = adapter_output_size
        self.use_positional_encoding = use_positional_encoding

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model, nhead=n_head, 
            dim_feedforward=dim_feedforward,
            batch_first=True, dropout=0.1
        )
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers)
        self.fc = nn.Linear(self.d_model, output_size)

    @staticmethod
    def _sinusoidal_encoding(seq_len, d_model, device, dtype):
        """Standard sin/cos position table of shape [seq_len, d_model] (handles odd d_model)."""
        positions = torch.arange(seq_len, device=device, dtype=torch.float32).unsqueeze(1)
        inv_freq = torch.pow(
            10000.0,
            -torch.arange(0, d_model, 2, device=device, dtype=torch.float32) / d_model,
        )
        angles = positions * inv_freq  # [seq_len, ceil(d_model / 2)]
        table = torch.zeros(seq_len, d_model, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        table[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        return table.to(dtype)

    def forward(self, x, lengths=None):
        """
        Map batch-first x [B, S, data_input_size] to log-probabilities [B, S, output_size].

        Args:
            lengths: optional true lengths (one per batch item). When given, time steps at
                or beyond an item's length are excluded from attention, so padding cannot
                influence the valid outputs.
        """
        hidden = self.adapter_layer(x)
        if self.use_positional_encoding:
            hidden = hidden + self._sinusoidal_encoding(
                hidden.size(1), self.d_model, hidden.device, hidden.dtype
            )

        padding_mask = None
        if lengths is not None:
            lengths = torch.as_tensor(lengths, device=hidden.device)
            steps = torch.arange(hidden.size(1), device=hidden.device)
            padding_mask = steps.unsqueeze(0) >= lengths.unsqueeze(1)  # True marks padding

        out = self.transformer_encoder(hidden, src_key_padding_mask=padding_mask)
        out = self.fc(out)
        return nn.functional.log_softmax(out, dim=2)

In [None]:
# ============================================================
# STEP 6: Prepare Dataloaders
# ============================================================
import torch.nn.utils.rnn as rnn_utils

def custom_collate(batch):
    """
    Collate samples for CTC training and inference.

    `batch` is a list of (x, y) samples, or (x, y, key) samples for the test split
    (detected from the first sample's length). Inputs are zero-padded along time and
    targets are padded with 0 (assumed to be the blank index), both batch-first. The
    unpadded lengths that CTCLoss needs are returned alongside. Empty test targets
    pad to a zero-width tensor.

    Returns:
        (padded_xs, padded_ys, x_lengths, y_lengths), with the tuple of keys appended
        for a test batch.
    """
    if len(batch[0]) == 3:
        xs, ys, keys = zip(*batch)
    else:
        keys = None
        xs, ys = zip(*batch)

    # Unpadded lengths (first dimension of each tensor), required by CTCLoss.
    x_lengths = torch.tensor(list(map(len, xs)), dtype=torch.long)
    y_lengths = torch.tensor(list(map(len, ys)), dtype=torch.long)

    collated = (
        rnn_utils.pad_sequence(xs, batch_first=True, padding_value=0.0),
        rnn_utils.pad_sequence(ys, batch_first=True, padding_value=0),
        x_lengths,
        y_lengths,
    )
    return collated if keys is None else collated + (keys,)

# Build the three loaders with the CTC-aware collate function; only the training
# split is shuffled.
train_loader, val_loader, test_loader = (
    DataLoader(split, batch_size=CFG.BATCH_SIZE, shuffle=shuffle_split, collate_fn=custom_collate)
    for split, shuffle_split in (
        (train_dataset, True),
        (val_dataset, False),
        (test_dataset, False),
    )
)

print("DataLoaders created with CTC-ready padding.")

# --- Sanity check: pull one (shuffled) training batch and report its shapes ---
try:
    x_batch, y_batch, x_len, y_len = next(iter(train_loader))
    print("\nChecking one batch from train_loader:")
    for batch_label, batch_shape in (("x_batch", x_batch.shape), ("y_batch", y_batch.shape)):
        print(f"  {batch_label} shape: {batch_shape}")
    print(f"  x_lengths shape: {x_len.shape}, sample: {x_len[:5]}")
    print(f"  y_lengths shape: {y_len.shape}, sample: {y_len[:5]}")
except Exception as e:
    print(f"\nCould not get batch from train_loader (is it empty?): {e}")

In [None]:
# ============================================================
# STEP 7: Experiment Setup (for CTC / Seq2Seq)
# ============================================================
import jiwer # For Word/Character/Phoneme Error Rate
import shutil
# ---
# 1. Build TOKEN_MAP from VOCAB: class id i + 1 -> VOCAB[i]; the blank id -> "".
# ---
TOKEN_MAP = dict(enumerate(VOCAB, start=1))
TOKEN_MAP[BLANK_ID] = "" # BLANK_ID was set to 0 in STEP 4

print("Token map created:")
print(f"  Index 0: '{TOKEN_MAP[0]}' (BLANK)")
print(f"  Index 1: '{TOKEN_MAP[1]}' (e.g., AA)")
print(f"  Index 40: '{TOKEN_MAP[40]}' (e.g., |)")

# 2. Experiments to run: every architecture twice, first using the pretrained
# checkpoint (True) and then from scratch (False).
# Format: (model_name, use_checkpoint_boolean)
experiments_to_run = [
    (architecture, use_checkpoint)
    for architecture in ("RNN", "LSTM", "GRU", "TRANSFORMER")
    for use_checkpoint in (True, False)
]

# 3. Results of every experiment, keyed by experiment name.
all_experiment_results = {}

# 4. Define the decoders
def greedy_decoder(logits, token_map, blank_id=None):
    """
    Greedy (best-path) CTC decoding of one utterance into a space-separated string.

    Args:
        logits: [time, classes] scores (log-probabilities work equally well);
            argmax is taken over the last dimension.
        token_map: dict from class index to token string; unknown indices become "?".
        blank_id: class index of the CTC blank. Defaults to the module-level BLANK_ID.

    Returns:
        Decoded tokens joined by single spaces, so jiwer.wer treats each phoneme as a
        "word" and therefore yields a phoneme error rate.
    """
    if blank_id is None:
        blank_id = BLANK_ID
    best_path = torch.argmax(logits, dim=-1)
    # CTC collapse: merge consecutive repeats first, then drop blanks, so that
    # "A blank A" decodes to "A A" while "A A" decodes to "A".
    collapsed = torch.unique_consecutive(best_path).tolist()
    return " ".join(token_map.get(idx, "?") for idx in collapsed if idx != blank_id)

def decode_true_target(target_indices, token_map):
    """Turn a 1-D tensor of reference class ids into a space-separated token string ("?" for unknown ids)."""
    tokens = (token_map.get(element.item(), "?") for element in target_indices)
    return " ".join(tokens)

# 5. Define the new training and validation functions
def train_one_epoch(epoch, model, train_loader, criterion, optimizer):
    """
    Run one CTC training pass over `train_loader`.

    Batches whose loss is inf or NaN are skipped (no backward/step) with a warning.

    Returns:
        Sum of (batch loss * batch size) over the batches that were not skipped,
        divided by the size of the whole training dataset.
    """
    model.train()
    loss_sum = 0.0

    progress = tqdm(train_loader, desc=f"Epoch {epoch} [Train]", leave=False)
    for inputs, targets, input_lengths, target_lengths in progress:
        inputs = inputs.to(CFG.DEVICE)
        targets = targets.to(CFG.DEVICE)
        input_lengths = input_lengths.to(CFG.DEVICE)
        target_lengths = target_lengths.to(CFG.DEVICE)

        optimizer.zero_grad()
        # The model emits batch-major log-probs [B, S, C]; CTCLoss expects [S, B, C].
        log_probs = model(inputs).permute(1, 0, 2)
        loss = criterion(log_probs, targets, input_lengths, target_lengths)

        non_finite = torch.isinf(loss) or torch.isnan(loss)
        if non_finite:
            print("Warning: Skipping batch with inf/nan loss")
        else:
            loss.backward()
            optimizer.step()
            loss_sum += loss.item() * inputs.size(0)

    return loss_sum / len(train_loader.dataset)

def validate_one_epoch(epoch, model, val_loader, criterion, token_map):
    """
    Evaluate `model` on `val_loader` without gradient tracking.

    Each utterance is greedily CTC-decoded (trimmed to its true input length) and
    compared with its reference (trimmed to its true target length).

    Returns:
        (mean loss, error rate): the loss is sum(batch loss * batch size) divided by the
        dataset size; the error rate is jiwer.wer over space-separated phoneme strings,
        i.e. a phoneme error rate.
    """
    model.eval()
    loss_sum = 0.0
    hypotheses = []
    references = []

    with torch.no_grad():
        progress = tqdm(val_loader, desc=f"Epoch {epoch} [Val]", leave=False)
        for inputs, targets, input_lengths, target_lengths in progress:
            inputs = inputs.to(CFG.DEVICE)
            targets = targets.to(CFG.DEVICE)
            input_lengths = input_lengths.to(CFG.DEVICE)
            target_lengths = target_lengths.to(CFG.DEVICE)

            log_probs = model(inputs)  # [B, S, C]
            # CTCLoss expects time-major input: [S, B, C].
            loss = criterion(log_probs.permute(1, 0, 2), targets, input_lengths, target_lengths)
            loss_sum += loss.item() * inputs.size(0)

            per_sample = zip(log_probs, targets, input_lengths, target_lengths)
            for sample_log_probs, sample_targets, n_frames, n_tokens in per_sample:
                hypotheses.append(greedy_decoder(sample_log_probs[:n_frames], token_map))
                references.append(decode_true_target(sample_targets[:n_tokens], token_map))

    # With space-separated phonemes, jiwer's word error rate is a phoneme error rate.
    error_rate = jiwer.wer(references, hypotheses)

    return loss_sum / len(val_loader.dataset), error_rate

print("Experiment setup for CTC complete.")

In [None]:
# ============================================================
# STEP 8: Run Experiment Loop (with Fine-tuning)
# ============================================================

# --- 8.1: Load the checkpoint data ---
print("Loading checkpoint weights from file...")
# SECURITY: weights_only=False unpickles arbitrary Python objects. This is acceptable
# only because CHECKPOINT_PATH is the competition-provided file; never use it on
# untrusted checkpoints.
checkpoint = torch.load(CFG.CHECKPOINT_PATH, map_location=CFG.DEVICE, weights_only=False)
print("Checkpoint data loaded.")
# Weights are read from the 'model_state_dict' entry when present; otherwise the
# loaded object itself is treated as the state_dict.
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
    pretrained_state_dict = checkpoint['model_state_dict']
else:
    pretrained_state_dict = checkpoint


def _build_model(model_name):
    """
    Instantiate the architecture named by `model_name` from the STEP 4 constants.

    Raises:
        ValueError: for a name other than "TRANSFORMER", "RNN", "LSTM" or "GRU".
    """
    if model_name == "TRANSFORMER":
        print("Initializing TransformerEncModel...")
        return TransformerEncModel(
            data_input_size=DATA_INPUT_SIZE,
            adapter_output_size=ADAPTER_OUTPUT_SIZE,
            n_head=CFG.N_HEAD,
            num_layers=NUM_LAYERS,
            dim_feedforward=HIDDEN_SIZE,
            output_size=OUTPUT_SIZE
        )
    if model_name in ["RNN", "LSTM", "GRU"]:
        print(f"Initializing RecurrentModel (Type: {model_name})...")
        return RecurrentModel(
            model_type=model_name,
            data_input_size=DATA_INPUT_SIZE,
            adapter_output_size=ADAPTER_OUTPUT_SIZE,
            hidden_size=HIDDEN_SIZE,
            output_size=OUTPUT_SIZE,
            num_layers=NUM_LAYERS,
            bidirectional=IS_BIDIRECTIONAL
        )
    raise ValueError(f"Unknown model name: {model_name}")


def _load_compatible_weights(model, state_dict):
    """
    Copy into `model` only the checkpoint tensors whose name AND shape both match.

    load_state_dict(strict=False) tolerates missing/unexpected names but still raises
    on a shape mismatch for a shared name, so such entries are filtered out first.

    Returns:
        (loaded_keys, skipped_keys): sorted names that were copied, and checkpoint
        names that were ignored.
    """
    own_state = model.state_dict()
    compatible = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in own_state and getattr(tensor, "shape", None) == own_state[name].shape
    }
    skipped = [name for name in state_dict if name not in compatible]
    model.load_state_dict(compatible, strict=False)
    return sorted(compatible), skipped


# --- 8.2: Start the main experiment loop ---
# This loop iterates through the (model_name, use_checkpoint) tuples defined in STEP 7
for model_name, use_checkpoint in experiments_to_run:
    experiment_name = f"{model_name}_{'pretrained' if use_checkpoint else 'scratch'}"

    print(f"\n{'='*20} üöÄ STARTING EXPERIMENT: {experiment_name} {'='*20}\n")

    # --- 8.3: Initialize Model ---
    model = _build_model(model_name).to(CFG.DEVICE)
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total Trainable Parameters: {total_params:,}")

    # --- 8.4: Load Pretrained Weights ---
    if use_checkpoint:
        print("Applying pretrained weights (name- and shape-matched tensors only)...")
        loaded_keys, skipped_keys = _load_compatible_weights(model, pretrained_state_dict)
        print(f"  > Loaded {len(loaded_keys)} tensors: {loaded_keys}")
        print(f"  > Skipped {len(skipped_keys)} checkpoint tensors (name or shape mismatch)")
        if not loaded_keys:
            print("  > WARNING: no checkpoint tensor matched this model; this run is effectively trained from scratch.")
    else:
        print("Training from scratch. No checkpoint loaded.")

    # --- 8.5: Define Loss and Optimizer ---
    criterion = nn.CTCLoss(blank=BLANK_ID, zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=CFG.LR)

    # --- 8.6: Training Loop ---
    history = {'train_loss': [], 'val_loss': [], 'error_rate': []}
    best_error_rate = float("inf")
    saved_best = False

    # Save file using the new experiment_name
    best_model_path = f"/kaggle/working/best_model_{experiment_name}.pth"

    print("Starting training...")

    for epoch in range(1, CFG.EPOCHS + 1):
        train_loss = train_one_epoch(epoch, model, train_loader, criterion, optimizer)
        val_loss, error_rate = validate_one_epoch(epoch, model, val_loader, criterion, TOKEN_MAP)

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['error_rate'].append(error_rate)

        print(f"Epoch {epoch}/{CFG.EPOCHS} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Error Rate (PER): {error_rate:.4f}")

        if error_rate < best_error_rate:
            best_error_rate = error_rate
            torch.save(model.state_dict(), best_model_path)
            saved_best = True
            print(f"‚úÖ Saved new best model for {experiment_name} (PER: {error_rate:.4f})!")

    # --- 8.7: Test / Prediction Loop ---
    print(f"Running predictions for {experiment_name}...")

    if saved_best:
        model.load_state_dict(torch.load(best_model_path, map_location=CFG.DEVICE))
    else:
        # No epoch improved on +inf (EPOCHS == 0 or every PER was NaN), so no file exists.
        print(f"Warning: no best checkpoint was saved for {experiment_name}; predicting with the final weights.")
    model.eval()

    all_pred_texts = []
    all_trial_keys = []

    if len(test_loader.dataset) > 0:
        with torch.no_grad():
            for x, y, x_lengths, y_lengths, keys in tqdm(test_loader, desc=f"Testing {experiment_name}", leave=False):
                x, x_lengths = x.to(CFG.DEVICE), x_lengths.to(CFG.DEVICE)
                y_pred = model(x)

                for i in range(x.size(0)):
                    pred_logits = y_pred[i, :x_lengths[i], :]
                    all_pred_texts.append(greedy_decoder(pred_logits, TOKEN_MAP))
                    all_trial_keys.append(keys[i])

    # --- 8.8: Generate Submission ---
    print(f"Generating submission file for {experiment_name}...")
    # NOTE(review): the ids are the per-file trial keys, and the test set concatenates
    # several session files, so the same key may appear more than once. Confirm the
    # id format the competition expects before submitting.
    submission_df = pd.DataFrame({'id': all_trial_keys, 'text': all_pred_texts})

    # Format text for submission (e.g., "AA B | K")
    submission_df['text'] = submission_df['text'].str.strip()

    # Save submission using the new experiment_name
    submission_path = f"/kaggle/working/submission_{experiment_name}.csv"
    submission_df.to_csv(submission_path, index=False)
    print(f"‚úÖ Submission file saved to {submission_path}")

    # --- 8.9: Store results ---
    all_experiment_results[experiment_name] = {
        'history': history,
        'best_val_loss': min(history['val_loss'], default=float("inf")),
        'best_error_rate': best_error_rate,
        'total_params': total_params
    }

print("\nüéâ All experiments complete! üéâ")

In [None]:
# ============================================================
# STEP 9: Compare Experiment Results (for CTC)
# ============================================================

print("üìä Experiment Results Summary\n")
sns.set(style="whitegrid", font_scale=1.1)

# --- Define a color map for model types (solid = pretrained, dashed = scratch) ---
model_colors = {
    "RNN": "blue",
    "LSTM": "green",
    "GRU": "red",
    "TRANSFORMER": "purple"
}


def _plot_experiment_curves(history_key, title, ylabel, best_key=None):
    """
    Plot `results['history'][history_key]` for every experiment on one figure.

    Line colour comes from the base model (the name before the first '_'). Solid
    lines are pretrained runs and dashed lines are from-scratch runs. If `best_key`
    is given, that summary value is appended to each legend label.
    """
    fig, ax = plt.subplots(figsize=(16, 9))
    for experiment_name, results in all_experiment_results.items():
        base_model = experiment_name.split('_')[0]  # e.g., "RNN"
        linestyle = '-' if "pretrained" in experiment_name else '--'
        label = experiment_name
        if best_key is not None:
            label = f"{experiment_name} (Best: {results[best_key]:.4f})"
        values = results['history'][history_key]
        # The training loop numbers epochs from 1, so start the x-axis at 1 as well.
        ax.plot(
            range(1, len(values) + 1),
            values,
            label=label,
            lw=2,
            color=model_colors.get(base_model, 'black'),
            linestyle=linestyle
        )
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')  # legend outside the axes
    ax.grid(True, which='both', linestyle='--', linewidth=0.5)
    fig.tight_layout()
    plt.show()


if not all_experiment_results:
    print("Warning: `all_experiment_results` is empty. Please run STEP 8 first.")
else:
    # --- 1. Print Summary Table ---
    results_df = (
        pd.DataFrame.from_dict(all_experiment_results, orient='index')
        .drop(columns='history')  # history lists are plotted below, not tabulated
    )
    try:
        print(results_df.to_markdown(floatfmt=".6f"))
    except ImportError:
        # DataFrame.to_markdown depends on the optional `tabulate` package.
        print(results_df.to_string(float_format="{:.6f}".format))

    # --- 2. Error Rate, 3. Validation Loss, 4. Training Loss ---
    _plot_experiment_curves(
        'error_rate',
        'Model Comparison: Validation Error Rate (PER)',
        'Phoneme Error Rate (PER) (lower is better)',
        best_key='best_error_rate'
    )
    _plot_experiment_curves(
        'val_loss',
        'Model Comparison: Validation Loss (CTC)',
        'Validation CTCLoss',
        best_key='best_val_loss'
    )
    _plot_experiment_curves(
        'train_loss',
        'Model Comparison: Training Loss (CTC)',
        'Training CTCLoss'
    )

In [None]:
# ============================================================
# STEP 10: Generate Final Submission
# ============================================================
import shutil # Make sure shutil is imported

print("Generating final submission file...\n")

if not all_experiment_results:
    print("Warning: `all_experiment_results` is empty. Please run STEP 8 first.")
else:
    # 1. Find the best model by LOWEST PER, using LOWEST validation loss as a tie-breaker.
    best_model_name = None
    best_model_error = float("inf")
    best_model_val_loss = float("inf")

    for model_name, results in all_experiment_results.items():
        current_error = results['best_error_rate']
        current_val_loss = results['best_val_loss']

        # Better if the PER is strictly lower, or equal with a lower validation loss.
        if current_error < best_model_error or \
           (current_error == best_model_error and current_val_loss < best_model_val_loss):
            best_model_error = current_error
            best_model_val_loss = current_val_loss
            best_model_name = model_name

    if best_model_name is None:
        # No candidate beat the +inf starting point (e.g. every PER is NaN), so there
        # is no submission to choose; previously this copied "submission_None.csv".
        print("Error: No experiment has a usable error rate; cannot choose a submission.")
    else:
        print(f"üèÜ The best overall model is: {best_model_name}")
        print(f"   > Best Phoneme Error Rate (PER): {best_model_error:.6f}")
        print(f"   > Best Validation Loss: {best_model_val_loss:.6f}")

        # 2. Paths: the per-experiment file written in STEP 8, and the file Kaggle reads.
        best_submission_path = f"/kaggle/working/submission_{best_model_name}.csv"
        final_submission_path = "/kaggle/working/submission.csv"

        # 3. Copy the best model's submission to "submission.csv" and show its head.
        try:
            shutil.copyfile(best_submission_path, final_submission_path)
            print(f"\n‚úÖ Successfully copied '{best_submission_path}' to '{final_submission_path}'")

            final_df = pd.read_csv(final_submission_path)
            print("\nFinal submission file head:")
            print(final_df.head())

        except FileNotFoundError:
            print(f"Error: Could not find '{best_submission_path}'. Make sure STEP 8 ran correctly.")
        except Exception as e:
            print(f"An error occurred while copying the file: {e}")