In [16]:
import argparse
import torch
import numpy as np
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import accuracy_score, classification_report
from btbench_train_test_splits import generate_splits_SS_ST
from braintreebank_subject import Subject
from scipy import signal
from sklearn.metrics import roc_auc_score
import os
import json
from datetime import datetime
import gc
import psutil

def compute_spectrogram(data, fs=2048, max_freq=2000):
    """Compute spectrogram for a single trial of data.
    
    Args:
        data (numpy.ndarray): Input voltage data of shape (n_channels, n_samples) or (batch_size, n_channels, n_samples)
        fs (int): Sampling frequency in Hz
    
    Returns:
        numpy.ndarray: Spectrogram representation
    """
    # For 1 second of data at 2048Hz, we'll use larger window
    nperseg = 256  # 125ms window
    noverlap = 0  # 0% overlap
    
    f, t, Sxx = signal.spectrogram(
        data, 
        fs=fs,
        nperseg=nperseg,
        noverlap=noverlap,
        window='boxcar'
    )
    
    return np.log10(Sxx[:, (f<max_freq) & (f>0)] + 1e-10)

class MLPClassifierGPU(nn.Module):
    def __init__(self, input_size, hidden_sizes, num_classes):
        super().__init__()
        layers = []
        prev_size = input_size
        for hidden_size in hidden_sizes:
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.ReLU(),
                nn.Dropout(0.2)
            ])
            prev_size = hidden_size
        layers.append(nn.Linear(prev_size, num_classes))
        self.model = nn.Sequential(*layers)
        
    def forward(self, x):
        return self.model(x)

subject_id = 3
trial_id = 0
eval_name = "volume"
k_folds = 5
spectrogram = False
"""Run MLP classification for a given subject, trial, and eval_name.

Args:
    subject_id (int): Subject ID
    trial_id (int): Trial ID
    eval_name (str): eval_name name (e.g., "rms" for volume classification)
    k_folds (int): Number of cross-validation folds
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")


# Load subject data
print(f"Loading subject {subject_id}...")
subject = Subject(subject_id, cache=True)
subject.load_neural_data(trial_id)

# Generate train/test splits
print(f"Generating train/test splits for subject {subject_id}...")
train_datasets, test_datasets = generate_splits_SS_ST(
    test_subject=subject,
    test_trial_id=trial_id,
    eval_name=eval_name,
    k_folds=k_folds
)

# Store results for each fold
fold_accuracies = []
fold_results = []

for fold, (train_data, test_data) in enumerate(zip(train_datasets, test_datasets)):
    print(f"\nProcessing fold {fold + 1}/{len(train_datasets)}")

    print(f"Train data shape: {train_data[0][0].shape}")
    print(f"Test data shape: {test_data[0][0].shape}")
    print(f"Length of train data: {len(train_data)}")
    print(f"Length of test data: {len(test_data)}")

    sample_time_from, sample_time_to = 1024, 3072 # get the first second of neural data after word onset
    
    # Convert dataset to tensors
    print(f"Processing train data...")
    X_train = []
    y_train = []
    for i in range(len(train_data)):
        features, label = train_data[i]
        features = features[:, sample_time_from:sample_time_to]
        if spectrogram: 
            features = torch.from_numpy(compute_spectrogram(features.numpy())).float()
        X_train.append(features.flatten())
        y_train.append(label)
    X_train = torch.stack(X_train)
    y_train = torch.tensor(y_train, dtype=torch.long)

    print(f"Processing test data...")
    X_test = []
    y_test = []
    for i in range(len(test_data)):
        features, label = test_data[i]
        features = features[:, sample_time_from:sample_time_to]
        if spectrogram:
            features = torch.from_numpy(compute_spectrogram(features.numpy())).float()
        X_test.append(features.flatten())
        y_test.append(label)
    X_test = torch.stack(X_test)
    y_test = torch.tensor(y_test, dtype=torch.long)

    # Move data to GPU
    X_train = X_train.to(device)
    y_train = y_train.to(device)
    X_test = X_test.to(device)
    y_test = y_test.to(device)
    break

Using device: cuda
Loading subject 3...


Generating train/test splits for subject 3...

Processing fold 1/5
Train data shape: torch.Size([124, 5120])
Test data shape: torch.Size([124, 5120])
Length of train data: 4303
Length of test data: 1076
Processing train data...
Processing test data...


In [17]:

# Create data loaders
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)

# Initialize model
n_classes = len(torch.unique(y_train))
model = MLPClassifierGPU(
    input_size=X_train.shape[1],
    hidden_sizes=[256, 128],
    num_classes=n_classes
).to(device)

# Training setup
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

# Training loop
print(f"Training MLP...")
model.train()
for epoch in range(100):  # Max 100 epochs
    total_loss = 0
    for batch_X, batch_y in train_loader:
        optimizer.zero_grad()
        outputs = model(batch_X)
        loss = criterion(outputs, batch_y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    
    if epoch % 1 == 0:
        print(f'Epoch {epoch}, Loss: {total_loss/len(train_loader):.4f}')

Training MLP...
Epoch 0, Loss: 25.5866
Epoch 1, Loss: 17.4546
Epoch 2, Loss: 14.5913
Epoch 3, Loss: 13.3507
Epoch 4, Loss: 10.6591
Epoch 5, Loss: 6.5103
Epoch 6, Loss: 5.4744
Epoch 7, Loss: 4.4197
Epoch 8, Loss: 4.3732
Epoch 9, Loss: 5.2332
Epoch 10, Loss: 4.7851
Epoch 11, Loss: 5.4961
Epoch 12, Loss: 3.4100
Epoch 13, Loss: 2.8403
Epoch 14, Loss: 1.7685
Epoch 15, Loss: 2.0723
Epoch 16, Loss: 1.8226
Epoch 17, Loss: 1.4923
Epoch 18, Loss: 2.0489
Epoch 19, Loss: 2.9949
Epoch 20, Loss: 2.1192
Epoch 21, Loss: 1.5683
Epoch 22, Loss: 1.7264
Epoch 23, Loss: 1.2500
Epoch 24, Loss: 1.3130
Epoch 25, Loss: 1.1035
Epoch 26, Loss: 1.2370
Epoch 27, Loss: 1.4844
Epoch 28, Loss: 1.4532
Epoch 29, Loss: 1.0816
Epoch 30, Loss: 0.9842
Epoch 31, Loss: 0.9395
Epoch 32, Loss: 1.7183
Epoch 33, Loss: 1.9088
Epoch 34, Loss: 1.2867
Epoch 35, Loss: 0.9715
Epoch 36, Loss: 0.7021
Epoch 37, Loss: 0.7989
Epoch 38, Loss: 0.8460
Epoch 39, Loss: 0.6240
Epoch 40, Loss: 1.0993
Epoch 41, Loss: 0.5955
Epoch 42, Loss: 0.6249


In [18]:

# Evaluation
model.eval()
with torch.no_grad():
    y_pred = model(X_test)
    y_score = torch.softmax(y_pred, dim=1)
    y_pred = torch.argmax(y_pred, dim=1)

# Convert predictions back to CPU for metric calculation
y_pred_cpu = y_pred.cpu().numpy()
y_test_cpu = y_test.cpu().numpy()
y_score_cpu = y_score.cpu().numpy()

# Calculate metrics
accuracy = accuracy_score(y_test_cpu, y_pred_cpu)

if n_classes == 2:
    auroc = roc_auc_score(y_test_cpu, y_score_cpu[:, 1])
else:
    auroc_per_class = []
    for i in range(n_classes):
        y_binary = (y_test_cpu == i).astype(int)
        auroc_per_class.append(roc_auc_score(y_binary, y_score_cpu[:, i]))
    auroc = np.mean(auroc_per_class)
    
fold_accuracies.append(accuracy)

# Store fold results
fold_results.append({
    'fold': fold + 1,
    'accuracy': float(accuracy),
    'auroc': float(auroc),
    'n_train_samples': len(y_train),
    'n_test_samples': len(y_test),
    'classification_report': classification_report(y_test_cpu, y_pred_cpu, output_dict=True)
})

if n_classes > 2:
    fold_results[-1]['auroc_per_class'] = {f'class_{i}': float(auc) for i, auc in enumerate(auroc_per_class)}

# Print fold results
print(f"Fold {fold + 1} Accuracy: {accuracy:.4f}")
print("\nClassification Report:")
print(classification_report(y_test_cpu, y_pred_cpu))

# Clean up memory
del X_train, y_train, X_test, y_test, y_pred
del model, accuracy, auroc
if n_classes > 2:
    del auroc_per_class, y_score
torch.cuda.empty_cache()
gc.collect()

# Calculate and print overall results
mean_accuracy = np.mean(fold_accuracies)
std_accuracy = np.std(fold_accuracies)
print("\nOverall Results:")
print(f"Mean Accuracy: {mean_accuracy:.4f}")
print(f"Std Accuracy: {std_accuracy:.4f}")

# Save results to JSON
results = {
    'subject_id': subject_id,
    'trial_id': trial_id,
    'eval_name': eval_name,
    'k_folds': k_folds,
    'mean_accuracy': float(mean_accuracy),
    'std_accuracy': float(std_accuracy),
    'fold_results': fold_results,
    'n_classes': int(n_classes)
}
results

Fold 1 Accuracy: 0.6041

Classification Report:
              precision    recall  f1-score   support

           0       0.82      0.58      0.67       768
           1       0.39      0.68      0.49       308

    accuracy                           0.60      1076
   macro avg       0.60      0.63      0.58      1076
weighted avg       0.69      0.60      0.62      1076


Overall Results:
Mean Accuracy: 0.6041
Std Accuracy: 0.0000


{'subject_id': 3,
 'trial_id': 0,
 'eval_name': 'volume',
 'k_folds': 5,
 'mean_accuracy': 0.604089219330855,
 'std_accuracy': 0.0,
 'fold_results': [{'fold': 1,
   'accuracy': 0.604089219330855,
   'auroc': 0.6651722301136365,
   'n_train_samples': 4303,
   'n_test_samples': 1076,
   'classification_report': {'0': {'precision': 0.8154981549815498,
     'recall': 0.5755208333333334,
     'f1-score': 0.6748091603053435,
     'support': 768.0},
    '1': {'precision': 0.3895131086142322,
     'recall': 0.6753246753246753,
     'f1-score': 0.49406175771971494,
     'support': 308.0},
    'accuracy': 0.604089219330855,
    'macro avg': {'precision': 0.602505631797891,
     'recall': 0.6254227543290043,
     'f1-score': 0.5844354590125292,
     'support': 1076.0},
    'weighted avg': {'precision': 0.6935619149433213,
     'recall': 0.604089219330855,
     'f1-score': 0.6230710562194945,
     'support': 1076.0}}}],
 'n_classes': 2}