# Brain-to-Text Inference Pipeline

This notebook replicates the original `slightly_tidy_version_Mamba_GRU_LISA_Ensemble_with_TTA` notebook using modular code from `src/`.

**What this does:**
1. Load neural data (HDF5)
2. Run Mamba + GRU models
3. Ensemble predictions (LISA)
4. Compute metrics (WER/CER)
5. Show results
6. Optionally, apply test-time augmentation (TTA)

## 1. Setup & Imports

In [None]:
import sys
import torch
import numpy as np
import pandas as pd
from pathlib import Path

# Import from modular code
from src.models import MambaDecoder, GRUDecoderBaseline
from src.data_loader import BrainToTextDataset, create_data_loader
from src.decoding import (
    run_single_decoding_step,
    ensemble_logit_averaging,
    ensemble_majority_vote,
    LISAEnsemble,
    apply_test_time_augmentation
)
from src.utils import compute_wer, compute_cer, gauss_smooth, phoneme_ids_to_text

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 1.5. Kaggle + Hugging Face Authentication & Downloads

You can authenticate in two ways:

1) Interactive login (recommended the first time)
2) A `.env` file with credentials

**.env example:**
```
KAGGLE_USERNAME=your_username
KAGGLE_KEY=your_api_key
HUGGINGFACE_TOKEN=your_hf_token
```

In [None]:
import os
import kagglehub
from huggingface_hub import login, notebook_login

# Optional: load credentials from a .env file into os.environ
# Install once if needed: pip install python-dotenv
try:
    from dotenv import load_dotenv

    load_dotenv()
except ImportError:
    print("python-dotenv not available; skipping .env loading")

# kagglehub reads KAGGLE_USERNAME / KAGGLE_KEY from the environment,
# so an interactive login is only needed when they are missing
if os.getenv("KAGGLE_USERNAME") and os.getenv("KAGGLE_KEY"):
    print("Kaggle credentials found in environment")
else:
    print("No Kaggle credentials found; starting interactive login")
    kagglehub.login()

# Hugging Face login: use the token from .env if present, otherwise prompt.
# notebook_login() takes no token argument, so login(token=...) is used instead.
hf_token = os.getenv("HUGGINGFACE_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    notebook_login()

# Download competition dataset
competition_name = "brain-to-text-25"  # Verify the competition slug on Kaggle
print("Downloading dataset from Kaggle...")
dataset_path = kagglehub.competition_download(competition_name)
print(f"Dataset downloaded to: {dataset_path}")

In [None]:
from src.data_sources import download_all_sources

# Run this cell to download the Kaggle data sources used by the original
# notebook; you can delete it afterwards.
# Note: this environment differs from Kaggle's Python environment, so some
# libraries the notebook relies on may be missing.

sources = download_all_sources()

# Example: use the main dataset path from sources
# dataset_path = sources["brain_to_text_25_path"]

## 2. Load Data

Data will be loaded from the Kaggle download path.
**Optional:** Update paths below if using your own local files instead.

In [None]:
# Load from Kaggle download (if available) or specify manual paths
if 'sources' in locals() and "brain_to_text_25_path" in sources:
    dataset_path = sources["brain_to_text_25_path"]

if 'dataset_path' in locals():
    # Use Kaggle downloaded data
    data_path = f"{dataset_path}/neural_data.h5"  # Update filename if different
    metadata_path = f"{dataset_path}/metadata.csv"  # Update filename if different
    print(f"Using Kaggle dataset from: {dataset_path}")
else:
    # Manual paths (update these if not using Kaggle download)
    data_path = 'path/to/your/data.h5'
    metadata_path = 'path/to/your/metadata.csv'
    print("Using manual paths")

# Load dataset
dataset = BrainToTextDataset(
    data_path=data_path,
    metadata_path=metadata_path
)

# Create DataLoader
test_loader = create_data_loader(
    dataset=dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0  # 0 avoids multiprocessing issues on Windows; raise it on Linux/macOS for faster loading
)

print(f"Dataset size: {len(dataset)} samples")
print(f"Number of batches: {len(test_loader)}")
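Neural trials differ in length, so a batch of 32 trials has to be padded to a common number of time steps before it can be stacked into one tensor. How `create_data_loader` does this internally isn't shown here; the following is a minimal NumPy sketch of a padding collate step, using a hypothetical `pad_batch` helper that also returns the true lengths so padded frames can be ignored later.

```python
import numpy as np

def pad_batch(trials, pad_value=0.0):
    """Stack variable-length trials of shape [T_i, F] into one [B, T_max, F] array.

    Hypothetical helper. Returns the padded array and the original lengths so
    downstream code (decoding, metrics) can ignore the padded frames.
    """
    lengths = np.array([t.shape[0] for t in trials], dtype=np.int64)
    n_features = trials[0].shape[1]
    padded = np.full((len(trials), lengths.max(), n_features), pad_value, dtype=trials[0].dtype)
    for i, trial in enumerate(trials):
        padded[i, : trial.shape[0]] = trial  # copy real frames; the rest stays pad_value
    return padded, lengths

# Usage with two toy trials of 3 and 5 time steps, 4 features each
trials = [np.ones((3, 4), dtype=np.float32), np.ones((5, 4), dtype=np.float32)]
batch, lengths = pad_batch(trials)
print(batch.shape, lengths)  # (2, 5, 4) [3 5]
```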

## 3. Initialize Models

Create Mamba and GRU decoders with your configuration.

In [None]:
# Model configuration (must match the settings used to train the checkpoints)
config = {
    'neural_dim': 513,      # Neural feature dimension
    'n_units': 256,         # Hidden units per layer
    'n_days': 5,            # Number of recording days
    'n_classes': 40,        # Number of output classes (phonemes, plus the CTC blank if used)
    'n_layers': 3,          # Number of layers (shared by the Mamba and GRU models)
    'drop_path': 0.1        # Stochastic depth rate (active only in training mode)
}

# Initialize Mamba model
mamba_model = MambaDecoder(
    neural_dim=config['neural_dim'],
    n_units=config['n_units'],
    n_days=config['n_days'],
    n_classes=config['n_classes'],
    n_layers=config['n_layers'],
    drop_path=config['drop_path']
).to(device)

# Initialize GRU model
gru_model = GRUDecoderBaseline(
    neural_dim=config['neural_dim'],
    hidden_dim=config['n_units'],
    n_days=config['n_days'],
    n_classes=config['n_classes'],
    n_layers=config['n_layers']
).to(device)

print(f"Mamba parameters: {sum(p.numel() for p in mamba_model.parameters()):,}")
print(f"GRU parameters: {sum(p.numel() for p in gru_model.parameters()):,}")
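`config['drop_path']` is described only as a stochastic depth rate. Assuming `MambaDecoder` uses the common per-sample formulation (popularized by timm's `DropPath`), the sketch below shows what the rate does. Because it is active only in training mode, the `.eval()` calls in Section 4 switch it off for inference. `drop_path` here is a hypothetical standalone NumPy function, not the project's implementation.

```python
import numpy as np

def drop_path(residual, drop_prob, training, rng):
    """Stochastic depth on a residual branch of shape [B, T, D] (hypothetical sketch).

    In training, each sample's branch is zeroed with probability drop_prob and the
    surviving branches are scaled by 1 / (1 - drop_prob) so the expected value is
    unchanged. In evaluation it is the identity.
    """
    if not training or drop_prob == 0.0:
        return residual
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over the time and feature axes
    mask = rng.random((residual.shape[0], 1, 1)) < keep_prob
    return residual * mask / keep_prob

rng = np.random.default_rng(0)
branch = np.ones((4, 10, 8))
out_eval = drop_path(branch, 0.1, training=False, rng=rng)   # unchanged
out_train = drop_path(branch, 0.1, training=True, rng=rng)   # some samples may be zeroed
```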

## 4. Load Pretrained Weights

Models will be loaded from Kaggle download path (if available).
**Optional:** Update paths below if using your own local model files.

In [None]:
# Load from Kaggle download (if available) or specify manual paths
if 'dataset_path' in locals():
    # Use Kaggle downloaded models
    mamba_checkpoint = f"{dataset_path}/mamba_model.pth"  # Update filename if different
    gru_checkpoint = f"{dataset_path}/gru_model.pth"  # Update filename if different
    print(f"Loading models from Kaggle: {dataset_path}")
else:
    # Manual paths (update these if not using Kaggle download)
    mamba_checkpoint = 'path/to/mamba_model.pth'
    gru_checkpoint = 'path/to/gru_model.pth'
    print("Using manual model paths")

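# Load weights (weights_only=True restricts unpickling to tensors and plain containers)
mamba_model.load_state_dict(torch.load(mamba_checkpoint, map_location=device, weights_only=True))
gru_model.load_state_dict(torch.load(gru_checkpoint, map_location=device, weights_only=True))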

# Set to evaluation mode
mamba_model.eval()
gru_model.eval()

print("✓ Models loaded and ready for inference")
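`load_state_dict` fails if the saved file is not a bare state dict. Two common cases are weights nested under a key such as `model_state_dict`, and parameter names prefixed with `module.` (from `torch.nn.DataParallel`) or `_orig_mod.` (from `torch.compile`). Whether the checkpoints used here have either shape is an assumption; this hypothetical `unwrap_state_dict` helper is a minimal pure-Python sketch of normalizing them.

```python
from collections import OrderedDict

def unwrap_state_dict(checkpoint, prefixes=("module.", "_orig_mod.")):
    """Return a plain state dict from a loaded checkpoint (hypothetical helper).

    - If the weights are nested under 'model_state_dict' or 'state_dict', unwrap them.
    - Strip 'module.' (DataParallel) and '_orig_mod.' (torch.compile) name prefixes.
    """
    for key in ("model_state_dict", "state_dict"):
        if key in checkpoint and isinstance(checkpoint[key], dict):
            checkpoint = checkpoint[key]
            break
    cleaned = OrderedDict()
    for name, value in checkpoint.items():
        for prefix in prefixes:
            if name.startswith(prefix):
                name = name[len(prefix):]
        cleaned[name] = value
    return cleaned
```

Usage would look like `mamba_model.load_state_dict(unwrap_state_dict(torch.load(mamba_checkpoint, map_location=device, weights_only=True)))`.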

## 5. Run Inference

Process data through both models and collect predictions.

In [None]:
all_targets = []
mamba_logits_list = []
gru_logits_list = []

with torch.no_grad():
    for batch_idx, (neural_data, day_idx, targets) in enumerate(test_loader):
        # Move to device
        neural_data = neural_data.to(device)
        day_idx = day_idx.to(device)
        
        # Mamba inference
        mamba_logits = run_single_decoding_step(
            model=mamba_model,
            neural_data=neural_data,
            day_idx=day_idx,
            apply_smoothing=True,
            sigma=1.5
        )
        
        # GRU inference
        gru_logits = run_single_decoding_step(
            model=gru_model,
            neural_data=neural_data,
            day_idx=day_idx,
            apply_smoothing=True,
            sigma=1.5
        )
        
        # Store logits for ensemble
        mamba_logits_list.append(mamba_logits.cpu().numpy())
        gru_logits_list.append(gru_logits.cpu().numpy())
        all_targets.append(targets.numpy())
        
        if (batch_idx + 1) % 10 == 0:
            print(f"Processed {batch_idx + 1}/{len(test_loader)} batches")

print("\n✓ Inference complete")
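Both decoding calls pass `apply_smoothing=True, sigma=1.5`, and the notebook imports `gauss_smooth`, which suggests the neural features are Gaussian-smoothed along time before decoding. The internals of `src.utils.gauss_smooth` aren't shown here; this NumPy sketch illustrates the general operation, with the kernel width (`truncate`) and zero-padded edges as assumptions.

```python
import numpy as np

def gaussian_smooth_time(x, sigma=1.5, truncate=3.0):
    """Smooth x of shape [T, F] along the time axis with a normalized Gaussian kernel."""
    radius = int(truncate * sigma + 0.5)
    offsets = np.arange(-radius, radius + 1)
    kernel = np.exp(-0.5 * (offsets / sigma) ** 2)
    kernel /= kernel.sum()  # weights sum to 1, so constant signals are preserved away from the edges
    smoothed = np.empty_like(x, dtype=np.float64)
    for f in range(x.shape[1]):
        # Zero-pad by `radius` so "valid" convolution returns exactly T samples
        padded = np.pad(x[:, f], radius)
        smoothed[:, f] = np.convolve(padded, kernel, mode="valid")
    return smoothed
```

A larger `sigma` averages over more neighboring time bins, trading temporal precision for lower noise.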

## 6. Ensemble Predictions (LISA)

Combine Mamba and GRU predictions using the LISA ensemble method, weighting Mamba at 60% and GRU at 40%.

In [None]:
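# Concatenate all batches.
# Assumes every batch has the same number of time steps; if trial lengths vary
# between batches, pad the time axis to a common length before concatenating.
mamba_logits_all = np.concatenate(mamba_logits_list, axis=0)
gru_logits_all = np.concatenate(gru_logits_list, axis=0)
targets_all = np.concatenate(all_targets, axis=0)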

# LISA Ensemble
lisa = LISAEnsemble(strategy='weighted', weights=[0.6, 0.4])  # Mamba 60%, GRU 40%
ensemble_logits = lisa.aggregate([mamba_logits_all, gru_logits_all])

# Get predictions
ensemble_preds = np.argmax(ensemble_logits, axis=-1)
mamba_preds = np.argmax(mamba_logits_all, axis=-1)
gru_preds = np.argmax(gru_logits_all, axis=-1)

print(f"Ensemble predictions shape: {ensemble_preds.shape}")
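`LISAEnsemble` is project code whose internals aren't shown here. For comparison, a common way to ensemble CTC-style decoders is to average per-frame log-probabilities with the chosen weights (0.6/0.4 above), then greedy-decode: take the per-frame argmax, merge repeated labels, and remove the blank. This sketch assumes CTC training with blank ID 0 and that both models emit the same number of frames; neither is confirmed by this notebook, and `phoneme_ids_to_text` may already perform the collapse. All function names are hypothetical.

```python
import numpy as np

def log_softmax(logits, axis=-1):
    """Numerically stable log-softmax."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def weighted_logprob_ensemble(logits_list, weights):
    """Weighted average of per-model log-probabilities, each of shape [B, T, C]."""
    weights = np.asarray(weights, dtype=np.float64)
    weights = weights / weights.sum()  # normalize so the weights sum to 1
    return sum(w * log_softmax(l) for w, l in zip(weights, logits_list))

def ctc_greedy_decode(frame_ids, blank_id=0):
    """Collapse repeated frame labels, then drop blanks: [T] -> list of class IDs."""
    decoded, prev = [], None
    for idx in frame_ids:
        if idx != prev and idx != blank_id:
            decoded.append(int(idx))
        prev = idx
    return decoded
```

Applied here, that would be `weighted_logprob_ensemble([mamba_logits_all, gru_logits_all], [0.6, 0.4])`, then `ctc_greedy_decode` on each row of its `argmax(-1)`, truncated to the true sequence length if padding is present. Averaging log-probabilities rather than raw logits keeps models with different logit scales on equal footing.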

## 7. Compute Metrics

Calculate Word Error Rate (WER) and Character Error Rate (CER).

In [None]:
# Convert phoneme IDs to text
def batch_phoneme_to_text(phoneme_ids):
    """Convert batch of phoneme IDs to text."""
    texts = []
    for seq in phoneme_ids:
        text = phoneme_ids_to_text(seq)
        texts.append(text)
    return texts

# Convert predictions and targets to text
ensemble_texts = batch_phoneme_to_text(ensemble_preds)
mamba_texts = batch_phoneme_to_text(mamba_preds)
gru_texts = batch_phoneme_to_text(gru_preds)
target_texts = batch_phoneme_to_text(targets_all)

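# Compute metrics (mean of per-utterance WER/CER; corpus-level rates,
# i.e. total edits / total reference tokens, can differ)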
def compute_metrics(predictions, targets):
    wers = [compute_wer(pred, tgt) for pred, tgt in zip(predictions, targets)]
    cers = [compute_cer(pred, tgt) for pred, tgt in zip(predictions, targets)]
    return np.mean(wers), np.mean(cers)

ensemble_wer, ensemble_cer = compute_metrics(ensemble_texts, target_texts)
mamba_wer, mamba_cer = compute_metrics(mamba_texts, target_texts)
gru_wer, gru_cer = compute_metrics(gru_texts, target_texts)

# Print results
print("="*60)
print("RESULTS")
print("="*60)
print(f"Mamba Model:     WER = {mamba_wer:.4f} ({mamba_wer*100:.2f}%)  |  CER = {mamba_cer:.4f}")
print(f"GRU Model:       WER = {gru_wer:.4f} ({gru_wer*100:.2f}%)  |  CER = {gru_cer:.4f}")
print(f"LISA Ensemble:   WER = {ensemble_wer:.4f} ({ensemble_wer*100:.2f}%)  |  CER = {ensemble_cer:.4f}")
print("="*60)
print(f"Ensemble vs. Mamba: {(mamba_wer - ensemble_wer) / mamba_wer * 100:.1f}% relative WER reduction")
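`compute_metrics` averages per-utterance rates. WER is also often reported at the corpus level (total edits divided by total reference words), which gives long sentences more weight, so the two numbers can differ. The following self-contained sketch shows both pieces, Levenshtein edit distance and corpus-level aggregation, independent of `src.utils` (whose argument order and tokenization aren't shown here). The helper names are hypothetical.

```python
def edit_distance(ref, hyp):
    """Levenshtein distance between two sequences (substitutions, insertions, deletions)."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # deletion
                curr[j - 1] + 1,         # insertion
                prev[j - 1] + (r != h),  # substitution (free if tokens match)
            )
        prev = curr
    return prev[-1]

def corpus_error_rate(refs, hyps, unit="word"):
    """Total edits divided by total reference tokens across all utterances."""
    split = str.split if unit == "word" else list
    edits = sum(edit_distance(split(r), split(h)) for r, h in zip(refs, hyps))
    total = sum(len(split(r)) for r in refs)
    return edits / total
```

For example, with references `["the cat sat", "hello world"]` and hypotheses `["the cat sat down", "hello"]`, the corpus WER is 2/5 = 0.4, while the mean of the per-utterance WERs is (1/3 + 1/2) / 2 ≈ 0.417.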

## 8. Show Sample Predictions

In [None]:
# Show first 5 examples
print("\nSample Predictions:\n")
for i in range(min(5, len(target_texts))):
    print(f"Example {i+1}:")
    print(f"  Target:    {target_texts[i]}")
    print(f"  Mamba:     {mamba_texts[i]}")
    print(f"  GRU:       {gru_texts[i]}")
    print(f"  Ensemble:  {ensemble_texts[i]}")
    print()

## 9. Optional: Test-Time Augmentation (TTA)

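Test-time augmentation (TTA) runs the Mamba model on `n_augmentations` noise-perturbed copies of the input (noise level `noise_std`) and combines the outputs, which may improve accuracy. This example uses only the first batch, so its numbers are not directly comparable to the full-dataset results above.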

In [None]:
# Apply TTA to first batch as example
sample_batch = next(iter(test_loader))
neural_data, day_idx, targets = sample_batch
neural_data = neural_data.to(device)
day_idx = day_idx.to(device)

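# Run TTA (no gradients are needed at inference time)
with torch.no_grad():
    tta_logits = apply_test_time_augmentation(
        model=mamba_model,
        neural_data=neural_data,
        day_idx=day_idx,
        n_augmentations=5,
        noise_std=0.01
    )

# Convert to NumPy in case the TTA helper returns a (possibly GPU) tensor
if torch.is_tensor(tta_logits):
    tta_logits = tta_logits.cpu().numpy()

tta_preds = np.argmax(tta_logits, axis=-1)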
tta_texts = batch_phoneme_to_text(tta_preds)
target_texts_sample = batch_phoneme_to_text(targets.numpy())

tta_wer, tta_cer = compute_metrics(tta_texts, target_texts_sample)
print(f"TTA Results: WER = {tta_wer:.4f} ({tta_wer*100:.2f}%)  |  CER = {tta_cer:.4f}")
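`apply_test_time_augmentation` is project code; its `n_augmentations` and `noise_std` arguments suggest it perturbs the input with noise and combines the resulting outputs, but the details aren't shown here. A minimal NumPy sketch of that idea, assuming additive Gaussian noise, one clean pass, and averaging of log-probabilities (`noise_tta` and `predict_fn` are hypothetical names):

```python
import numpy as np

def noise_tta(predict_fn, x, n_augmentations=5, noise_std=0.01, seed=0):
    """Average log-probabilities over noisy copies of x (hypothetical helper).

    predict_fn maps an input array to logits of shape [..., C]. The first pass
    uses the clean input; the remaining passes add Gaussian noise with std noise_std.
    """
    rng = np.random.default_rng(seed)
    total = None
    for k in range(n_augmentations):
        x_aug = x if k == 0 else x + rng.normal(0.0, noise_std, size=x.shape)
        logits = predict_fn(x_aug)
        # Log-softmax so each pass contributes normalized scores
        shifted = logits - logits.max(axis=-1, keepdims=True)
        log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
        total = log_probs if total is None else total + log_probs
    return total / n_augmentations
```

Keeping one clean pass means the result never drifts far from the un-augmented prediction; with `noise_std=0` the function reduces to a single log-softmax.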

## Summary

This notebook demonstrated:
1. ✅ Loading neural data using modular `src.data_loader`
2. ✅ Running Mamba and GRU models from `src.models`
3. ✅ Ensemble predictions using LISA from `src.decoding`
4. ✅ Computing metrics using `src.utils`
5. ✅ Test-time augmentation on a sample batch

**Next steps:**
- Fine-tune ensemble weights for your specific data
- Experiment with different TTA strategies
- Deploy as production inference pipeline