In [1]:
import os
import numpy as np
import h5py
from scipy import stats
import scipy.io
import mne
import math 
mne.set_log_level('error')

from random import shuffle
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.utils import shuffle

from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score


import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

import optuna


from utils.load import Load
from config.default import cfg

%load_ext autoreload
%autoreload 2


torch.manual_seed(42)
np.random.seed(42)

In [2]:
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

cuda


In [3]:
subject_data = {}
# Load the data  from the HDF5 file
target_dir = 'features'
tag = 'reproduced_with_bad'

for subject in cfg['subjects']:
    file_path = os.path.join(target_dir, tag+'_'+subject + '.h5')

    data = {}
    with h5py.File(file_path, 'r') as h5file:
        for key in h5file.keys():
            data[key] = np.array(h5file[key])

    subject_data[subject] = data


for subject_id in subject_data:
    print(subject_id)
    print(subject_data[subject_id].keys())

S1
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S2
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S3
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S4
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S5
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])


In [4]:
# 250 samples per subject
# 1250 total samples

In [180]:
class CustomDataset(Dataset):
    def __init__(self, subject_data, train_percent=0.8, seed=42, device=None, is_train=True):
        self.device = device
        self.is_train = is_train
        self.train_X, self.train_y, self.test_X, self.test_y = self.preprocess_data(subject_data, train_percent, seed)
        self.dim = self.train_X[0][0].shape

    def get_dim():
        return self.dim

    def preprocess_data(self, subject_data, train_percent, seed):
        np.random.seed(seed)
        global_train_features = []
        global_train_labels = []
        global_test_features = []
        global_test_labels = []

        for s, subject_id in enumerate(subject_data):
            data = subject_data[subject_id]

            for i, finger in enumerate(data):
                finger_data = data[finger]

               

                # Normalize (uncomment if needed)
                # finger_data = StandardScaler().fit_transform(finger_data)


                ids = torch.tensor(np.ones((len(finger_data))) * s).to(torch.int64).to(self.device)
                labels = torch.tensor(np.ones((len(finger_data))) * i).to(self.device)

                # To GPU
                finger_data = torch.tensor(finger_data).to(torch.float32).to(self.device)
                features = [(finger_data[d], ids[d]) for d in range(len(finger_data))]

                # Split
                train_features = features[:int(len(finger_data) * train_percent)]
                train_labels = labels[:int(len(finger_data) * train_percent)]
                test_features = features[int(len(finger_data) * train_percent):]
                test_labels = labels[int(len(finger_data) * train_percent):]



                global_train_features.extend(train_features)
                global_train_labels.extend(train_labels)
                global_test_features.extend(test_features)
                global_test_labels.extend(test_labels)

        # Shuffle
        global_train_features, global_train_labels = shuffle(global_train_features, global_train_labels)
        global_test_features, global_test_labels = shuffle(global_test_features, global_test_labels)

        return global_train_features, global_train_labels, global_test_features, global_test_labels

    
    def __len__(self):
        return len(self.train_y) if self.is_train else len(self.test_y)

    def __getitem__(self, idx):
        if self.is_train:
            return self.get_train_item(idx)
        else:
            return self.get_test_item(idx)

    def get_train_item(self, idx):
        features = self.train_X[idx][0]
        subject_id = self.train_X[idx][1]
        label = self.train_y[idx]

        return (features, subject_id), label

    def get_test_item(self, idx):
        features = self.test_X[idx][0]
        subject_id = self.test_X[idx][1]
        label = self.test_y[idx]

        return (features, subject_id), label




train_dataset = CustomDataset(subject_data, device=device, is_train=True)
test_dataset = CustomDataset(subject_data, device=device, is_train=False)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)

In [6]:
# feature = ([EEG], [subject_id])
# label = finger

In [19]:
for i, (feature, label) in enumerate(train_dataloader):
    print(feature[0].shape)
    print(feature[1].shape)
    print(label.shape)
    print('---------------')
    break

torch.Size([16, 8216])
torch.Size([16])
torch.Size([16])
---------------


In [500]:
class CrossAttention(nn.Module):
    def __init__(self, input_dim):
        super(CrossAttention, self).__init__()
        self.c = input_dim[0]
        self.t = input_dim[1]
        self.b = input_dim[2]
        
        self.user_dim = 5
        self.num_classes = 5

        #user_dim -> condition_dim
        # feature_dim -> EEG_dim
        self.embed_dim = 16
        self.k_dim = 16
        self.v_dim = 16

        self.layers = nn.ModuleDict({
            'reduce': nn.Linear(2,1),
            'W_q' : nn.Linear(self.t, self.embed_dim, bias = False),      # Query transformation
            'W_k' : nn.Linear(1, self.k_dim, bias = False),   # Key transformation
            'W_v': nn.Linear(1, self.v_dim, bias = False),  # Value transformation
            'dropout': nn.Dropout(0.2),
            'classifier' : nn.Linear(self.c * self.v_dim, self.num_classes)
        })
        self.softmax = nn.Softmax(dim=2)

    def forward(self, features, user_indices):
        features = features.reshape(-1,2)
        features = self.layers['reduce'](features)         
        features = features.view(-1, self.c, self. t)         


        # Convert user_indices to one_hot vectors
        user_one_hot = torch.zeros(user_indices.size(0), self.user_dim, device=user_indices.device)
        user_one_hot.scatter_(1, user_indices.unsqueeze(1), 1)
        user_one_hot = torch.unsqueeze(input=user_one_hot, dim=2)
      
   
        # Query Matrix
        query = self.layers['W_q'](features)

        
        # Key Matrix
        key = self.layers['W_k'](user_one_hot) 
        key = torch.transpose(key, 1, 2)


        # Attention
        attention_scores = torch.bmm(query, key) 
        attention_scores = attention_scores / math.sqrt(self.embed_dim)
        attention_probs = self.softmax(attention_scores) 
       
        # Value Matrix
        value = self.layers['W_v'](user_one_hot)
      
        # Calculate the attended features
        attended_features = torch.bmm(attention_probs, value) 


        # Flatten
        batch_size = attended_features.size(0)
        attended_features = attended_features.view(batch_size, -1)

        # Dropout
        attended_features = self.layers['dropout'](attended_features)
        
        # Classify the attended features
        output = self.layers['classifier'](attended_features)

        return output


In [526]:
class TorchAttention(nn.Module):
    def __init__(self, input_dim):
        super(TorchAttention, self).__init__()
        self.c = input_dim[0]
        self.t = input_dim[1]
        self.b = input_dim[2]
        
        self.user_dim = 5
        self.num_classes = 5

        #user_dim -> condition_dim
        # feature_dim -> EEG_dim
        self.embed_dim = 16
        self.k_dim = 16
        self.v_dim = 16

        self.attention = nn.MultiheadAttention(embed_dim=self.embed_dim, num_heads=1, batch_first= True, dropout=0.2)
        self.layers = nn.ModuleDict({
            'reduce': nn.Linear(2,1),
            'W_q' : nn.Linear(self.t, self.embed_dim, bias = False),      # Query transformation
            'W_k' : nn.Linear(1, self.k_dim, bias = False),   # Key transformation
            'W_v': nn.Linear(1, self.v_dim, bias = False),  # Value transformation
            'dropout': nn.Dropout(0.2),
            'classifier' : nn.Linear(self.c * self.v_dim, self.num_classes)
        })
        self.softmax = nn.Softmax(dim=2)

    def forward(self, features, user_indices):
        features = features.reshape(-1,2)
        features = self.layers['reduce'](features)         
        features = features.view(-1, self.c, self. t)         


        # Convert user_indices to one_hot vectors
        user_one_hot = torch.zeros(user_indices.size(0), self.user_dim, device=user_indices.device)
        user_one_hot.scatter_(1, user_indices.unsqueeze(1), 1)
        user_one_hot = torch.unsqueeze(input=user_one_hot, dim=2)
      
   
        # Query Matrix
        query = self.layers['W_q'](features)

        
        # Key Matrix
        key = self.layers['W_k'](user_one_hot) 
        
       
        # Value Matrix
        value = self.layers['W_v'](user_one_hot)
      
        # Calculate the attended features
        attended_features, _ = self.attention(query, key, value)
      
        # Flatten
        batch_size = attended_features.size(0)
        attended_features = attended_features.reshape(batch_size, -1)

        # Dropout
        attended_features = self.layers['dropout'](attended_features)
        
        # Classify the attended features
        output = self.layers['classifier'](attended_features)

        return output


In [527]:

learning_rate = 1e-3
epochs = 200


In [528]:
# Create model and pass data
#model = CrossAttention(train_dataset.dim)
model = TorchAttention(train_dataset.dim)
model.to(device)
summary(model, input_size=(5, 10, input_size));

Layer (type:depth-idx)                             Param #
├─MultiheadAttention: 1-1                          --
|    └─NonDynamicallyQuantizableLinear: 2-1        272
├─ModuleDict: 1-2                                  --
|    └─Linear: 2-2                                 3
|    └─Linear: 2-3                                 416
|    └─Linear: 2-4                                 16
|    └─Linear: 2-5                                 16
|    └─Dropout: 2-6                                --
|    └─Linear: 2-7                                 12,645
├─Softmax: 1-3                                     --
Total params: 13,368
Trainable params: 13,368
Non-trainable params: 0


In [529]:
for i, (feature, label) in enumerate(train_dataloader):
    model(feature[0], feature[1])
    break


In [530]:
def accuracy(dataloader):
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for features,  labels in dataloader:

            outputs = model(features = features[0], user_indices = features[1])
            _, predicted = torch.max(outputs.data, 1)
            total_predictions += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    accuracy = correct_predictions / total_predictions

    return accuracy * 100
    #print(f"Accuracy: {accuracy * 100:.2f}% ({correct_predictions}/{total_predictions})")

In [531]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    epoch_loss = 0.0

    for batch_features, batch_labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(features = batch_features[0], user_indices = batch_features[1])

        loss = criterion(outputs, batch_labels.long())
          
        # Backward propagation
        loss.backward()
        # Update the weights
        optimizer.step()

        epoch_loss += loss.item()

   

    if epoch % 10 == 9:
        train_accuracy = accuracy(train_dataloader)
        test_accuracy = accuracy(test_dataloader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%, Test accuracy: {test_accuracy:.2f}%")

print("#"*50)
print(f'Final_loss: {epoch_loss}')
print(f'Final train accuracy: {accuracy(train_dataloader):.2f}%')
print(f'Final test accuracy: {accuracy(test_dataloader):.2f}%')

Epoch 10/200, Loss: 91.97771072387695, Train accuracy: 41.70%, Test accuracy: 23.60%
Epoch 20/200, Loss: 83.3547922372818, Train accuracy: 48.00%, Test accuracy: 22.00%
Epoch 30/200, Loss: 80.10269731283188, Train accuracy: 48.00%, Test accuracy: 19.20%
Epoch 40/200, Loss: 78.47107994556427, Train accuracy: 49.80%, Test accuracy: 22.80%
Epoch 50/200, Loss: 78.2370075583458, Train accuracy: 48.20%, Test accuracy: 18.80%
Epoch 60/200, Loss: 76.95687717199326, Train accuracy: 49.10%, Test accuracy: 20.00%
Epoch 70/200, Loss: 75.79242211580276, Train accuracy: 49.80%, Test accuracy: 22.40%
Epoch 80/200, Loss: 75.36960899829865, Train accuracy: 52.00%, Test accuracy: 21.20%
Epoch 90/200, Loss: 74.15151876211166, Train accuracy: 51.40%, Test accuracy: 19.60%
Epoch 100/200, Loss: 74.62020021677017, Train accuracy: 53.60%, Test accuracy: 20.40%
Epoch 110/200, Loss: 73.56184983253479, Train accuracy: 52.10%, Test accuracy: 21.20%
Epoch 120/200, Loss: 72.79227709770203, Train accuracy: 53.40%, T