In [1]:
import os
import numpy as np
import h5py
from scipy import stats
import scipy.io
import mne
import math 
mne.set_log_level('error')

from random import shuffle
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.utils import shuffle

from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score


import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

import optuna


from utils.load import Load
from config.default import cfg

%load_ext autoreload
%autoreload 2


In [2]:
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

cuda


In [75]:
subject_data = {}
# Load the data  from the HDF5 file
target_dir = 'features'
tag = 'reproduced_with_bad'

for subject in cfg['subjects']:
    file_path = os.path.join(target_dir, tag+'_'+subject + '.h5')

    data = {}
    with h5py.File(file_path, 'r') as h5file:
        for key in h5file.keys():
            data[key] = np.array(h5file[key])

    subject_data[subject] = data


for subject_id in subject_data:
    print(subject_id)
    print(subject_data[subject_id].keys())

S1
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S2
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S3
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S4
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S5
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])


In [76]:
# 250 samples per subject
# 1250 total samples

In [88]:
train_percent = 0.8

np.random.seed(42)
global_train_features = []
global_train_labels = []
global_test_features = []
global_test_labels = []
for s, subject_id in enumerate(subject_data):
    data = subject_data[subject_id]

    for i, finger in enumerate(data):
        
        finger_data = data[finger]
        
        
        # RESHAPE
        finger_data = finger_data.reshape(finger_data.shape[0], -1)

        # Normalize
        #finger_data = StandardScaler().fit_transform(finger_data)

       
        np.random.shuffle(finger_data)

        # TO GPU
        finger_data = torch.tensor(finger_data).to(torch.float32).to(device)
        
        ids = torch.tensor(np.ones((len(finger_data))) * s).to(torch.int64).to(device)
        features = []
        for d in range(len(finger_data)):
            features.append((finger_data[d], ids[d]))

        labels = torch.tensor(np.ones((len(finger_data))) * i).to(device)

        train_features = features[:int(len(finger_data)*train_percent)]
        train_labels = labels[:int(len(finger_data)*train_percent)]
        test_features = features[int(len(finger_data)*train_percent):]
        test_labels = labels[int(len(finger_data)*train_percent):]

        global_train_features.extend(train_features)
        global_train_labels.extend(train_labels)
        global_test_features.extend(test_features)
        global_test_labels.extend(test_labels)
        
      

shuffler = np.random.permutation(len(global_train_features))
global_train_features = [global_train_features[i] for i in shuffler]
global_train_labels = [global_train_labels[i] for i in shuffler]

shuffler = np.random.permutation(len(global_test_features))
global_test_features = [global_test_features[i] for i in shuffler]
global_test_labels = [global_test_labels[i] for i in shuffler]

train_X = global_train_features
train_y = global_train_labels
test_X = global_test_features
test_y = global_test_labels

In [89]:
# feature = ([EEG], [subject_id])
# label = finger

In [90]:
# First 5 samples

for i in range(5):
    print(f'Feautres shape: {train_X[i][0].shape}')
    print(f'Subject_id: {train_X[i][1]}')
    print(f'Label: {train_y[i]}')
    print('------------------')

Feautres shape: torch.Size([8216])
Subject_id: 1
Label: 3.0
------------------
Feautres shape: torch.Size([8216])
Subject_id: 3
Label: 3.0
------------------
Feautres shape: torch.Size([8216])
Subject_id: 2
Label: 3.0
------------------
Feautres shape: torch.Size([8216])
Subject_id: 3
Label: 3.0
------------------
Feautres shape: torch.Size([8216])
Subject_id: 3
Label: 2.0
------------------


In [91]:
class CustomDataset(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        features = self.X[idx][0]
        subject_id = self.X[idx][1]
        label = self.y[idx]
        
        return (features, subject_id), label

train_dataset = CustomDataset(train_X, train_y)
train_dataloader = DataLoader(train_dataset, batch_size=16, shuffle=True)

test_dataset = CustomDataset(test_X, test_y)
test_dataloader = DataLoader(test_dataset, batch_size=16, shuffle=True)

In [106]:
class CrossAttention(nn.Module):
    def __init__(self, feature_dim, user_dim, hidden_dim, feautre_hidden_dim,  num_classes):
        super(CrossAttention, self).__init__()
        self.feature_dim = feature_dim
        self.user_dim = user_dim
        self.hidden_dim = hidden_dim
        self.feature_hidden_dim = feautre_hidden_dim
        self.num_queries = 3

        self.layers = nn.ModuleDict({
            'feature extractor' : nn.Linear(self.feature_dim,  self.feature_hidden_dim),
            'batch norm': nn.BatchNorm1d(self.feature_hidden_dim),
            'query layer' : nn.Linear( self.feature_hidden_dim, self.feature_hidden_dim * self.hidden_dim ),  # Query transformation
            'key layer' : nn.Linear(self.user_dim, self.hidden_dim *self.user_dim ),  # Key transformation
            'value layer': nn.Linear(self.user_dim, self.hidden_dim *self.user_dim ),  # Value transformation
            'dropout': nn.Dropout(0.2),
            'classifier' : nn.Linear(self.feature_hidden_dim * self.hidden_dim, num_classes)
        })
        self.softmax = nn.Softmax(dim=2)

    def forward(self, features, user_indices):
        # Transform features to hidden_dim
        features = self.layers['feature extractor'](features)  
        features = self.layers['batch norm'](features)
       
        # Convert user_indices to one_hot vectors
        user_one_hot = torch.zeros(user_indices.size(0), self.user_dim, device=user_indices.device)
        user_one_hot.scatter_(1, user_indices.unsqueeze(1), 1)
     



        # Transform user_one_hot to query
        query = self.layers['query layer'](features)  
        query = query.view(-1, self.feature_hidden_dim, self.hidden_dim )  


        # Transform features to keys and values
        keys = self.layers['key layer'](user_one_hot)  
        keys = keys.view(-1,  self.hidden_dim, self.user_dim)  
        values = self.layers['value layer'](user_one_hot) 
        values = values.view(-1,  self.user_dim, self.hidden_dim) 
        # print(f'query: {query.shape}')
        # print(f'keys: {keys.shape}')
        # print(f'values: {values.shape}')

       
        # Calculate attention scores
        attention_scores = torch.bmm(query, keys) 

        # Normalize the attention scores
        attention_scores = attention_scores / math.sqrt(self.hidden_dim)

        attention_probs = self.softmax(attention_scores) 
     
        # Calculate the attended features
        attended_features = torch.bmm(attention_probs, values) 
        attended_features = attended_features.view(-1, self.feature_hidden_dim * self.hidden_dim)
        
        # Dropout
        attended_features = self.layers['dropout'](attended_features)
        
        # Classify the attended features
        output = self.layers['classifier'](attended_features)

        return output


In [119]:
# Hyperparameters
input_size = 8216
hidden_dim = 2
feautre_hidden_dim = 8

num_subjects = 5
output_size = 5
learning_rate = 1e-3


epochs = 100


In [120]:
# Create model and pass data
model = CrossAttention(input_size, num_subjects, hidden_dim, feautre_hidden_dim, output_size)
model.to(device)
summary(model, input_size=(5, 10, input_size));

Layer (type:depth-idx)                   Param #
├─ModuleDict: 1-1                        --
|    └─Linear: 2-1                       65,736
|    └─BatchNorm1d: 2-2                  16
|    └─Linear: 2-3                       144
|    └─Linear: 2-4                       60
|    └─Linear: 2-5                       60
|    └─Dropout: 2-6                      --
|    └─Linear: 2-7                       85
├─Softmax: 1-2                           --
Total params: 66,101
Trainable params: 66,101
Non-trainable params: 0


In [121]:
output = model(train_X[0][0].unsqueeze(0), train_X[0][1].unsqueeze(0))
print(output.shape)
print(output)

torch.Size([1, 5])
tensor([[-0.5374, -0.3970, -0.0480,  0.1009, -0.1018]], device='cuda:0',
       grad_fn=<AddmmBackward0>)


In [68]:
def accuracy(dataloader):
    correct_predictions = 0
    total_predictions = 0

    with torch.no_grad():
        for features,  labels in dataloader:

            outputs = model(features = features[0], user_indices = features[1])
            _, predicted = torch.max(outputs.data, 1)
            total_predictions += labels.size(0)
            correct_predictions += (predicted == labels).sum().item()

    accuracy = correct_predictions / total_predictions

    return accuracy * 100
    #print(f"Accuracy: {accuracy * 100:.2f}% ({correct_predictions}/{total_predictions})")

In [122]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop
for epoch in range(epochs):
    epoch_loss = 0.0

    for batch_features, batch_labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(features = batch_features[0], user_indices = batch_features[1])
        loss = criterion(outputs, batch_labels.long())
          
        # Backward propagation
        loss.backward()
        # Update the weights
        optimizer.step()

        epoch_loss += loss.item()

   

    if epoch % 10 == 0:
        train_accuracy = accuracy(train_dataloader)
        test_accuracy = accuracy(test_dataloader)
        print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%, Test accuracy: {test_accuracy:.2f}%")



Epoch 1/100, Loss: 102.6247946023941, Train accuracy: 23.60%, Test accuracy: 22.40%
Epoch 11/100, Loss: 31.48391565680504, Train accuracy: 86.00%, Test accuracy: 18.00%
Epoch 21/100, Loss: 11.367946427315474, Train accuracy: 96.80%, Test accuracy: 18.40%
Epoch 31/100, Loss: 5.2778927269391716, Train accuracy: 97.80%, Test accuracy: 16.80%
Epoch 41/100, Loss: 9.077177939703688, Train accuracy: 91.60%, Test accuracy: 23.60%
Epoch 51/100, Loss: 1.878758427221328, Train accuracy: 99.70%, Test accuracy: 20.00%
Epoch 61/100, Loss: 6.611761980224401, Train accuracy: 99.00%, Test accuracy: 19.20%
Epoch 71/100, Loss: 0.45220420393161476, Train accuracy: 99.90%, Test accuracy: 19.60%
Epoch 81/100, Loss: 0.25665980941266753, Train accuracy: 100.00%, Test accuracy: 20.00%
Epoch 91/100, Loss: 0.16855421553191263, Train accuracy: 100.00%, Test accuracy: 19.60%
