In [None]:
from src.load_dataset.dataset import create_balanced_loader

train_loader, val_loader = create_balanced_loader(
    features_path="data/features/audio_embeddings_feature_selection_emotion.pkl",
    labels_path="data/features/data_emotion.p",
    batch_size=64,
    augment=True,
    T=25,
)

batch, labels, mask = next(iter(train_loader))
print(batch.shape, labels.shape, mask.shape)


Loaded simplified spike dataset: N=13708, F=1611
Class distribution: Counter({np.int64(0): 6436, np.int64(4): 2308, np.int64(1): 1636, np.int64(6): 1606, np.int64(3): 1003, np.int64(5): 361, np.int64(2): 358})
torch.Size([64, 25, 1611]) torch.Size([64]) torch.Size([64, 25])


In [None]:
# src/model_training/train_spike_attention_masked.py
import torch
import torch.nn as nn
import torch.optim as optim
from torch import autocast, GradScaler
from tqdm import tqdm
import logging
import os
from datetime import datetime
import importlib

import src.model_training.model
importlib.reload(src.model_training.model)
from src.model_training.model import SpikeAttentionNet

In [27]:
def train_spike_attention(
    train_loader,
    val_loader,
    in_dim=1611,
    num_classes=7,
    embed_dim=256,
    num_heads=4,
    lr=1e-4,
    epochs=100,
    log_interval=50,
    save_path="checkpoints",
    use_bfloat16=True,
):
    # ---- Setup ----
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SpikeAttentionNet(
        in_dim=in_dim, embed_dim=embed_dim, num_heads=num_heads, num_classes=num_classes
    ).to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-4)
    scaler = GradScaler()
    dtype = torch.bfloat16 if use_bfloat16 else torch.float16

    # ---- Logging ----
    os.makedirs(save_path, exist_ok=True)
    log_file = os.path.join(save_path, f"train_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s | %(levelname)s | %(message)s",
        handlers=[logging.FileHandler(log_file), logging.StreamHandler()],
    )

    logging.info(f"Using device: {device}")
    logging.info(f"Training SpikeAttentionNet with {sum(p.numel() for p in model.parameters()):,} parameters")

    best_val_acc = 0.0

    # ---- Epoch Loop ----
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, total_acc = 0.0, 0.0

        for batch, labels, mask in tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}"):
            batch, labels, mask = batch.to(device), labels.to(device), mask.to(device)

            optimizer.zero_grad(set_to_none=True)
            with autocast(device_type="cuda", dtype=dtype):
                logits = model(batch, mask=mask)
                loss = criterion(logits, labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            preds = logits.argmax(dim=1)
            total_acc += (preds == labels).float().mean().item()
            total_loss += loss.item()

        avg_train_loss = total_loss / len(train_loader)
        avg_train_acc = total_acc / len(train_loader)

        # ---- Validation ----
        model.eval()
        val_loss, val_acc = 0.0, 0.0
        with torch.no_grad():
            for batch, labels, mask in val_loader:
                batch, labels, mask = batch.to(device), labels.to(device), mask.to(device)
                with autocast(device_type="cuda", dtype=dtype):
                    logits = model(batch, mask=mask)
                    loss = criterion(logits, labels)

                preds = logits.argmax(dim=1)
                val_acc += (preds == labels).float().mean().item()
                val_loss += loss.item()

        avg_val_loss = val_loss / len(val_loader)
        avg_val_acc = val_acc / len(val_loader)

        # ---- Logging ----
        logging.info(
            f"Epoch {epoch:03d} | "
            f"Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.4f} | "
            f"Val Loss: {avg_val_loss:.4f} | Val Acc: {avg_val_acc:.4f}"
        )
        
        os.makedirs(save_path, exist_ok=True)

        # ---- Save Best Model ----
        if avg_val_acc > best_val_acc:
            best_val_acc = avg_val_acc
            torch.save(model.state_dict(), os.path.join(save_path, "best_model.pt"))
            logging.info(f"✅ New best model saved at epoch {epoch} with val acc={best_val_acc:.4f}")

        # ---- Periodic Console Print ----
        if epoch % log_interval == 0 or epoch == 1:
            print(
                f"Epoch {epoch:03d} | "
                f"Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.4f} | "
                f"Val Loss: {avg_val_loss:.4f} | Val Acc: {avg_val_acc:.4f}"
            )

    logging.info(f"Training completed. Best validation accuracy: {best_val_acc:.4f}")
    return model


In [28]:
from src.load_dataset.dataset import create_balanced_loader
from torch.utils.data import random_split, DataLoader

# Step 1: Load all data once
full_loader = create_balanced_loader(
    features_path="data/features/audio_embeddings_feature_selection_emotion.pkl",
    labels_path="data/features/data_emotion.p",
    batch_size=64,
    augment=True,
    T=25,
)


Loaded simplified spike dataset: N=13708, F=1611
Class distribution: Counter({np.int64(0): 6436, np.int64(4): 2308, np.int64(1): 1636, np.int64(6): 1606, np.int64(3): 1003, np.int64(5): 361, np.int64(2): 358})


In [29]:
# Extract the underlying dataset
full_dataset = full_loader.dataset
dataset_len = len(full_dataset)
train_len = int(0.8 * dataset_len)
val_len = dataset_len - train_len

# Step 2: Split the dataset
train_dataset, val_dataset = random_split(full_dataset, [train_len, val_len])

# Step 3: Create two loaders
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False)

In [30]:
# Done: Pass to your training function
model = train_spike_attention(
    train_loader=train_loader,
    val_loader=val_loader,
    in_dim=1611,
    num_classes=7,
    epochs=100,
)


2025-10-22 18:26:55,873 | INFO | Using device: cuda
2025-10-22 18:26:55,874 | INFO | Training SpikeAttentionNet with 999,175 parameters
Epoch 1/100: 100%|██████████| 172/172 [00:08<00:00, 19.32it/s]
2025-10-22 18:27:06,894 | INFO | Epoch 001 | Train Loss: 1.5544 | Train Acc: 0.4689 | Val Loss: 1.5560 | Val Acc: 0.4617
2025-10-22 18:27:06,903 | INFO | ✅ New best model saved at epoch 1 with val acc=0.4617


Epoch 001 | Train Loss: 1.5544 | Train Acc: 0.4689 | Val Loss: 1.5560 | Val Acc: 0.4617


Epoch 2/100: 100%|██████████| 172/172 [00:10<00:00, 16.84it/s]
2025-10-22 18:27:19,232 | INFO | Epoch 002 | Train Loss: 1.5443 | Train Acc: 0.4717 | Val Loss: 1.5537 | Val Acc: 0.4617
Epoch 3/100: 100%|██████████| 172/172 [00:08<00:00, 19.53it/s]
2025-10-22 18:27:30,153 | INFO | Epoch 003 | Train Loss: 1.5375 | Train Acc: 0.4712 | Val Loss: 1.5474 | Val Acc: 0.4617
Epoch 4/100: 100%|██████████| 172/172 [00:09<00:00, 18.65it/s]
2025-10-22 18:27:41,396 | INFO | Epoch 004 | Train Loss: 1.5292 | Train Acc: 0.4717 | Val Loss: 1.5350 | Val Acc: 0.4617
Epoch 5/100: 100%|██████████| 172/172 [00:08<00:00, 21.30it/s]
2025-10-22 18:27:51,391 | INFO | Epoch 005 | Train Loss: 1.5277 | Train Acc: 0.4712 | Val Loss: 1.5373 | Val Acc: 0.4617
Epoch 6/100: 100%|██████████| 172/172 [00:08<00:00, 20.72it/s]
2025-10-22 18:28:01,626 | INFO | Epoch 006 | Train Loss: 1.5221 | Train Acc: 0.4716 | Val Loss: 1.5302 | Val Acc: 0.4617
Epoch 7/100: 100%|██████████| 172/172 [00:08<00:00, 20.49it/s]
2025-10-22 18:28:

KeyboardInterrupt: 