In [1]:
import os
import numpy as np
import h5py
from scipy import stats
import scipy.io
import mne
import math 
mne.set_log_level('error')

from random import shuffle
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.utils import shuffle

from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score


import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

import optuna


from utils.load import Load
from config.default import cfg

%load_ext autoreload
%autoreload 2


torch.manual_seed(42)
np.random.seed(42)

In [2]:
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

cuda


In [3]:
subject_data = {}
# Load the data  from the HDF5 file
target_dir = 'features'
tag = 'reproduced_with_bad'

for subject in cfg['subjects']:
    file_path = os.path.join(target_dir, tag+'_'+subject + '.h5')

    data = {}
    with h5py.File(file_path, 'r') as h5file:
        for key in h5file.keys():
            data[key] = np.array(h5file[key])

    subject_data[subject] = data


for subject_id in subject_data:
    print(subject_id)
    print(subject_data[subject_id].keys())

S1
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S2
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S3
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S4
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S5
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])


In [4]:
# 250 samples per subject
# 1250 total samples

In [5]:
class CustomDataset(Dataset):
    def __init__(self, subject_data, train_percent=0.8, seed=42, device=None, is_train=True):
        self.device = device
        self.is_train = is_train
        self.train_X, self.train_y, self.test_X, self.test_y = self.preprocess_data(subject_data, train_percent, seed)


    def preprocess_data(self, subject_data, train_percent, seed):
        np.random.seed(seed)
        global_train_features = []
        global_train_labels = []
        global_test_features = []
        global_test_labels = []

        for s, subject_id in enumerate(subject_data):
            data = subject_data[subject_id]

            for i, finger in enumerate(data):
                finger_data = data[finger]

                # Reshape
                finger_data = finger_data.reshape(finger_data.shape[0], -1)

                # Normalize (uncomment if needed)
                # finger_data = StandardScaler().fit_transform(finger_data)


                ids = torch.tensor(np.ones((len(finger_data))) * s).to(torch.int64).to(self.device)
                labels = torch.tensor(np.ones((len(finger_data))) * i).to(self.device)

                # To GPU
                finger_data = torch.tensor(finger_data).to(torch.float32).to(self.device)
                features = [(finger_data[d], ids[d]) for d in range(len(finger_data))]

                # Split
                train_features = features[:int(len(finger_data) * train_percent)]
                train_labels = labels[:int(len(finger_data) * train_percent)]
                test_features = features[int(len(finger_data) * train_percent):]
                test_labels = labels[int(len(finger_data) * train_percent):]



                global_train_features.extend(train_features)
                global_train_labels.extend(train_labels)
                global_test_features.extend(test_features)
                global_test_labels.extend(test_labels)

        # Shuffle
        global_train_features, global_train_labels = shuffle(global_train_features, global_train_labels)
        global_test_features, global_test_labels = shuffle(global_test_features, global_test_labels)

        return global_train_features, global_train_labels, global_test_features, global_test_labels

    
    def __len__(self):
        return len(self.train_y) if self.is_train else len(self.test_y)

    def __getitem__(self, idx):
        if self.is_train:
            return self.get_train_item(idx)
        else:
            return self.get_test_item(idx)

    def get_train_item(self, idx):
        features = self.train_X[idx][0]
        subject_id = self.train_X[idx][1]
        label = self.train_y[idx]

        return (features, subject_id), label

    def get_test_item(self, idx):
        features = self.test_X[idx][0]
        subject_id = self.test_X[idx][1]
        label = self.test_y[idx]

        return (features, subject_id), label




train_dataset = CustomDataset(subject_data, device=device, is_train=True)
test_dataset = CustomDataset(subject_data, device=device, is_train=False)

train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)

In [6]:
# feature = ([EEG], [subject_id])
# label = finger

In [19]:
for i, (feature, label) in enumerate(train_dataloader):
    print(feature[0].shape)
    print(feature[1].shape)
    print(label.shape)
    print('---------------')
    break

torch.Size([16, 8216])
torch.Size([16])
torch.Size([16])
---------------


In [152]:
# Hyperparameters
input_size = 8216
hidden_dim = 2
feautre_hidden_dim = 8

num_subjects = 5
output_size = 5

In [156]:
class CrossAttention(nn.Module):
    def __init__(self, feature_dim, user_dim, hidden_dim, feautre_hidden_dim,  num_classes):
        super(CrossAttention, self).__init__()
        self.feature_dim = feature_dim
        self.user_dim = user_dim
        self.hidden_dim = hidden_dim
        self.feature_hidden_dim = feautre_hidden_dim
        self.num_queries = 3


        #user_dim -> condition_dim
        # feature_dim -> EEG_dim
        self.embed_dim = 4
        self.k_dim = 4
        self.v_dim = 4

        self.layers = nn.ModuleDict({
            'feature extractor' : nn.Linear(self.feature_dim,  self.feature_hidden_dim),
            'batch norm': nn.BatchNorm1d(self.feature_hidden_dim),
            'query layer' : nn.Linear( self.feature_hidden_dim, self.embed_dim),      # Query transformation
            'key layer' : nn.Linear(self.user_dim, self.k_dim  ),   # Key transformation
            'value layer': nn.Linear(self.user_dim, self.v_dim  ),  # Value transformation
            'dropout': nn.Dropout(0.2),
            'classifier' : nn.Linear(self.feature_hidden_dim * self.embed_dim, num_classes)
        })
        self.softmax = nn.Softmax(dim=2)

    def forward(self, features, user_indices):
        
        # Convert user_indices to one_hot vectors
        user_one_hot = torch.zeros(user_indices.size(0), self.user_dim, device=user_indices.device)
        user_one_hot.scatter_(1, user_indices.unsqueeze(1), 1)

        # Transform features to hidden_dim
        features = self.layers['feature extractor'](features)  
        features = self.layers['batch norm'](features)
       

        # Query Matrix
        d_q = self.layers['query layer'](features)
        d_q = torch.unsqueeze(input=d_q, dim=1)
        features = torch.unsqueeze(input=features, dim=2)  
        query = torch.bmm(features, d_q)


        # Key Matrix
        d_k = self.layers['key layer'](user_one_hot)  
        d_k = torch.unsqueeze(input=d_k, dim=1)
        user_one_hot = torch.unsqueeze(input=user_one_hot, dim=2)
        key = torch.bmm(user_one_hot, d_k)
        user_one_hot = torch.square(input=user_one_hot)
        key = torch.transpose(key, 1, 2)

       
        # Attention
        attention_scores = torch.bmm(query, key) 
        attention_scores = attention_scores / math.sqrt(self.embed_dim)
        attention_probs = self.softmax(attention_scores) 
       
        # Value Matrix
        user_one_hot = torch.squeeze(user_one_hot)
        d_v = self.layers['value layer'](user_one_hot) 
        user_one_hot = torch.unsqueeze(input=user_one_hot, dim=2)
        d_v = torch.unsqueeze(input=d_v, dim=1)
        values = torch.bmm(user_one_hot, d_v)
    
       
        # Calculate the attended features
        attended_features = torch.bmm(attention_probs, values) 
        batch_size = attended_features.size(0)
        attended_features = attended_features.view(batch_size, -1)
        
        # Dropout
        attended_features = self.layers['dropout'](attended_features)
        
        # Classify the attended features
        output = self.layers['classifier'](attended_features)

        return output


In [157]:

learning_rate = 1e-3
epochs = 200


In [158]:
# Create model and pass data
model = CrossAttention(input_size, num_subjects, hidden_dim, feautre_hidden_dim, output_size)
model.to(device)
summary(model, input_size=(5, 10, input_size));

Layer (type:depth-idx)                   Param #
├─ModuleDict: 1-1                        --
|    └─Linear: 2-1                       65,736
|    └─BatchNorm1d: 2-2                  16
|    └─Linear: 2-3                       36
|    └─Linear: 2-4                       24
|    └─Linear: 2-5                       24
|    └─Dropout: 2-6                      --
|    └─Linear: 2-7                       165
├─Softmax: 1-2                           --
Total params: 66,001
Trainable params: 66,001
Non-trainable params: 0


In [159]:
for i, (feature, label) in enumerate(train_dataloader):
    model(feature[0], feature[1])
    break


In [160]:
def accuracy(dataloader):
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for features,  labels in dataloader:

            outputs = model(features = features[0], user_indices = features[1])
            _, predicted = torch.max(outputs.data, 1)
            total_predictions += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    accuracy = correct_predictions / total_predictions

    return accuracy * 100
    #print(f"Accuracy: {accuracy * 100:.2f}% ({correct_predictions}/{total_predictions})")

In [162]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    epoch_loss = 0.0

    for batch_features, batch_labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(features = batch_features[0], user_indices = batch_features[1])

        loss = criterion(outputs, batch_labels.long())
          
        # Backward propagation
        loss.backward()
        # Update the weights
        optimizer.step()

        epoch_loss += loss.item()

   

    if epoch % 10 == 9:
        train_accuracy = accuracy(train_dataloader)
        test_accuracy = accuracy(test_dataloader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%, Test accuracy: {test_accuracy:.2f}%")

print("#"*50)
print(f'Final_loss: {epoch_loss}')
print(f'Final train accuracy: {accuracy(train_dataloader):.2f}%')
print(f'Final test accuracy: {accuracy(test_dataloader):.2f}%')

Epoch 10/200, Loss: 67.65774589776993, Train accuracy: 55.60%, Test accuracy: 21.60%
Epoch 20/200, Loss: 65.16666907072067, Train accuracy: 57.10%, Test accuracy: 21.60%
Epoch 30/200, Loss: 70.23433631658554, Train accuracy: 57.50%, Test accuracy: 18.80%
Epoch 40/200, Loss: 61.79432076215744, Train accuracy: 60.40%, Test accuracy: 20.00%
Epoch 50/200, Loss: 61.38509237766266, Train accuracy: 60.80%, Test accuracy: 23.60%
Epoch 60/200, Loss: 62.58268052339554, Train accuracy: 58.00%, Test accuracy: 22.80%
Epoch 70/200, Loss: 59.481901705265045, Train accuracy: 59.20%, Test accuracy: 23.20%
Epoch 80/200, Loss: 65.20154601335526, Train accuracy: 56.30%, Test accuracy: 24.40%
Epoch 90/200, Loss: 63.99209460616112, Train accuracy: 55.80%, Test accuracy: 25.20%
Epoch 100/200, Loss: 64.48908042907715, Train accuracy: 57.60%, Test accuracy: 24.80%
Epoch 110/200, Loss: 63.35892075300217, Train accuracy: 58.60%, Test accuracy: 26.80%
Epoch 120/200, Loss: 68.09774720668793, Train accuracy: 55.00%