In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from scipy import signal
from sklearn.preprocessing import RobustScaler
import scipy.io as sio

In [1]:
class FingerFlexDataset(Dataset):
    """Sliding-window dataset pairing ECoG windows with time-shifted finger targets.

    Sample ``i`` is ``(ecog[i:i+window_size], finger[i+delay:i+delay+window_size])``
    with stride 1, i.e. the target window starts ``delay`` samples after the
    input window. Every start index for which BOTH slices are complete is used.

    ``X`` and ``Y`` are zero-copy strided views (shape ``(n_windows,
    window_size, ...)``) instead of stacked copies: with stride 1 a stacked
    copy is roughly ``window_size`` times larger than the recording. Because
    they are views, later in-place edits of the source arrays show through.

    Args:
        ecog: array indexed by time on axis 0. ``__getitem__`` transposes each
            window, so it is expected to be 2-D (time, channels).
        finger: array indexed by time on axis 0.
        window_size: number of samples per window (>= 1).
        delay: non-negative shift of the target window relative to the input.

    Raises:
        ValueError: if ``window_size < 1``, ``delay < 0``, or the recordings are
            too short to yield a single window.
    """

    def __init__(self, ecog, finger, window_size=256, delay=2):
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if delay < 0:
            raise ValueError(f"delay must be >= 0, got {delay}")
        ecog = np.asarray(ecog)
        finger = np.asarray(finger)

        # Largest valid start i satisfies i + window_size <= len(ecog) and
        # i + delay + window_size <= len(finger); +1 turns that into a count.
        n_windows = min(len(ecog) - window_size,
                        len(finger) - delay - window_size) + 1
        if n_windows <= 0:
            raise ValueError(
                f"recordings too short for window_size={window_size}, delay={delay} "
                f"(ecog has {len(ecog)} samples, finger has {len(finger)})"
            )

        self.X = self._windows(ecog, window_size, n_windows)
        self.Y = self._windows(finger[delay:], window_size, n_windows)

    @staticmethod
    def _windows(arr, window_size, n_windows):
        """Return the first ``n_windows`` stride-1 windows of ``arr`` as a view of shape (n, window_size, ...)."""
        view = np.lib.stride_tricks.sliding_window_view(arr, window_size, axis=0)[:n_windows]
        # sliding_window_view appends the window axis last; move it next to the
        # window index so each item has the same layout as arr[i:i+window_size].
        return np.moveaxis(view, -1, 1)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        # (window_size, channels) -> (channels, window_size) for Conv1d input.
        x = torch.tensor(self.X[idx], dtype=torch.float32).permute(1, 0)
        y = torch.tensor(self.Y[idx], dtype=torch.float32)
        return x, y

# ----------------------
# Model Architecture
# ----------------------
class FingerFlexModel(nn.Module):
    """1-D convolutional encoder-decoder with skip connections (U-Net style).

    Input:  (batch, input_channels, window_size)
    Output: (batch, window_size, output_channels)

    Six encoder stages each halve the time axis with max-pooling, and five
    decoder stages each double it. Each decoder output is concatenated with the
    matching encoder output. ``enc1`` already pools, so the last concatenation
    is at ``window_size // 2`` steps. The 1x1 output convolution is therefore
    upsampled once more to restore the full ``window_size`` time steps.

    Args:
        input_channels: number of input (ECoG) channels.
        output_channels: number of predicted traces (default 5).
        window_size: time steps per input window. It must be a positive
            multiple of 64 (2**6) so that all six pool/upsample stages line up.

    Raises:
        ValueError: if ``window_size`` is not a positive multiple of 64.
    """

    def __init__(self, input_channels, output_channels=5, window_size=256):
        super(FingerFlexModel, self).__init__()
        n_pools = 6
        if window_size <= 0 or window_size % (2 ** n_pools):
            raise ValueError(
                f"window_size must be a positive multiple of {2 ** n_pools}, got {window_size}"
            )
        self.window_size = window_size
        # Time steps entering each encoder stage: each stage halves the length,
        # and its LayerNorm must be sized to exactly that length.
        lengths = [window_size // 2 ** k for k in range(n_pools)]
        self.enc1 = self.conv_block(input_channels, 32, lengths[0])
        self.enc2 = self.conv_block(32, 32, lengths[1])
        self.enc3 = self.conv_block(32, 64, lengths[2])
        self.enc4 = self.conv_block(64, 64, lengths[3])
        self.enc5 = self.conv_block(64, 128, lengths[4])
        self.enc6 = self.conv_block(128, 128, lengths[5])
        # Decoder input widths include the concatenated skip channels.
        self.dec1 = self.deconv_block(128, 128)
        self.dec2 = self.deconv_block(256, 64)
        self.dec3 = self.deconv_block(128, 64)
        self.dec4 = self.deconv_block(128, 32)
        self.dec5 = self.deconv_block(64, 32)
        self.final_conv = nn.Conv1d(64, output_channels, kernel_size=1)

    def conv_block(self, in_c, out_c, length=256):
        """Conv -> LayerNorm over (channels, time) -> GELU -> Dropout -> halve time axis.

        ``length`` is the number of time steps entering this block. LayerNorm
        normalizes over the whole (out_c, length) slice, so it must match exactly.
        """
        return nn.Sequential(
            nn.Conv1d(in_c, out_c, kernel_size=3, padding=1),
            nn.LayerNorm([out_c, length]),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.MaxPool1d(kernel_size=2, stride=2)
        )

    def deconv_block(self, in_c, out_c):
        """Conv -> GELU -> Dropout -> double time axis (nearest-neighbour upsampling)."""
        return nn.Sequential(
            nn.Conv1d(in_c, out_c, kernel_size=3, padding=1),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Upsample(scale_factor=2, mode='nearest')
        )

    def forward(self, x):
        """Map (batch, input_channels, T) to (batch, T, output_channels)."""
        length = x.shape[-1]
        x1 = self.enc1(x)    # T/2
        x2 = self.enc2(x1)   # T/4
        x3 = self.enc3(x2)   # T/8
        x4 = self.enc4(x3)   # T/16
        x5 = self.enc5(x4)   # T/32
        x6 = self.enc6(x5)   # T/64
        d1 = self.dec1(x6)                                # T/32
        d2 = self.dec2(torch.cat([d1, x5], dim=1))        # T/16
        d3 = self.dec3(torch.cat([d2, x4], dim=1))        # T/8
        d4 = self.dec4(torch.cat([d3, x3], dim=1))        # T/4
        d5 = self.dec5(torch.cat([d4, x2], dim=1))        # T/2
        out = self.final_conv(torch.cat([d5, x1], dim=1))  # T/2
        # Restore the full input length so predictions align 1:1 with targets.
        out = F.interpolate(out, size=length, mode='nearest')
        return out.permute(0, 2, 1)

# ----------------------
# Combined Loss Function
# ----------------------
class CombinedLoss(nn.Module):
    """Weighted sum of MSE and cosine distance.

    ``loss = mse_weight * MSE(pred, target) + (1 - mse_weight) * (1 - mean cos_sim)``

    where the cosine similarity is taken along ``cos_dim`` and averaged over
    all remaining positions. The defaults (0.5, -1) reproduce the original
    ``0.5 * (mse + (1 - cos))``.

    NOTE(review): FingerFlexModel returns (batch, time, outputs), so the
    default ``cos_dim=-1`` compares the output vectors at each time step.
    ``evaluate`` scores per-output correlation over time, which corresponds more
    closely to ``cos_dim=1``. Confirm which axis is intended.

    Args:
        mse_weight: weight of the MSE term; the cosine term gets ``1 - mse_weight``.
        cos_dim: axis along which cosine similarity is computed.
    """

    def __init__(self, mse_weight=0.5, cos_dim=-1):
        super().__init__()
        self.mse = nn.MSELoss()
        self.mse_weight = mse_weight
        self.cos_dim = cos_dim

    def forward(self, pred, target):
        mse = self.mse(pred, target)
        cos = F.cosine_similarity(pred, target, dim=self.cos_dim).mean()
        return self.mse_weight * mse + (1.0 - self.mse_weight) * (1 - cos)

# ----------------------
# Training Loop
# ----------------------
def train_model(model, train_loader, epochs=10, lr=8.4e-5, loss_fn=None):
    """Train ``model`` in place with Adam and print the mean batch loss per epoch.

    Each batch is moved to the device of the model's parameters before the
    forward pass, so the same loop works for CPU and GPU models.

    Args:
        model: the ``nn.Module`` to optimize.
        train_loader: iterable of ``(x, y)`` tensor batches (e.g. a DataLoader).
        epochs: number of passes over ``train_loader``.
        lr: Adam learning rate.
        loss_fn: callable ``(pred, target) -> scalar tensor``; defaults to
            ``CombinedLoss()``.

    Returns:
        List with the mean per-batch loss of each epoch.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    if loss_fn is None:
        loss_fn = CombinedLoss()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    history = []
    for epoch in range(epochs):
        total_loss = 0.0
        n_batches = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            pred = model(x)
            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        if n_batches == 0:
            raise ValueError("train_loader yielded no batches")
        mean_loss = total_loss / n_batches
        history.append(mean_loss)
        print(f"Epoch {epoch + 1}: Loss = {mean_loss:.4f}")
    return history


In [2]:
from scipy.stats import pearsonr

def evaluate(model, test_loader):
    """Return the Pearson correlation between predictions and targets for each output channel.

    All batches are concatenated along axis 0. For each index ``i`` of the
    last axis, all remaining values are flattened and correlated. The number of
    channels is taken from the predictions rather than fixed at 5.

    Args:
        model: ``nn.Module`` mapping a batch ``x`` to predictions shaped like ``y``.
        test_loader: iterable of ``(x, y)`` tensor batches.

    Returns:
        List of Pearson r values, one per output channel (NaN if a channel is constant).

    Raises:
        ValueError: if ``test_loader`` yields no batches.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    all_preds, all_targets = [], []
    with torch.no_grad():
        for x, y in test_loader:
            y_hat = model(x.to(device))
            # .cpu() so this also works when the model lives on a GPU.
            all_preds.append(y_hat.cpu().numpy())
            all_targets.append(y.cpu().numpy())
    if not all_preds:
        raise ValueError("test_loader yielded no batches")

    preds = np.concatenate(all_preds, axis=0)
    targets = np.concatenate(all_targets, axis=0)
    n_outputs = preds.shape[-1]
    correlations = [
        pearsonr(preds[..., i].ravel(), targets[..., i].ravel())[0]
        for i in range(n_outputs)
    ]
    return correlations

In [4]:
# Load the raw training recordings (ECoG and data-glove traces) from the MATLAB file.
# NOTE(review): the layout of 'train_ecog' / 'train_dg' is not visible here
# (loadmat returns nested object arrays for MATLAB cell arrays). Confirm the
# shapes before windowing them with FingerFlexDataset.
train_data = sio.loadmat('raw_training_data.mat')
ecog, data_glove = train_data['train_ecog'], train_data['train_dg']