In [1]:
import os
import numpy as np
import h5py
from scipy import stats
import scipy.io
import mne

mne.set_log_level('error')

from random import shuffle
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import accuracy_score
from sklearn.model_selection import KFold
from sklearn.utils import shuffle



import torch
import torch.nn as nn
import torch.optim as optim
from torchsummary import summary

import optuna


from utils.load import Load
from config.default import cfg

%load_ext autoreload
%autoreload 2


In [2]:
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

cuda


In [3]:
subject_data = {}
# Load the data  from the HDF5 file
target_dir = 'features'
tag = 'reproduced_with_bad'

for subject in cfg['subjects']:
    file_path = os.path.join(target_dir, tag+'_'+subject + '.h5')

    data = {}
    with h5py.File(file_path, 'r') as h5file:
        for key in h5file.keys():
            data[key] = np.array(h5file[key])

    subject_data[subject] = data


for subject_id in subject_data:
    print(subject_id)
    print(subject_data[subject_id].keys())

S1
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S2
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S3
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S4
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])
S5
dict_keys(['index', 'little', 'middle', 'ring', 'thumb'])


In [4]:
# 250 samples per subject
# 1250 total samples

In [5]:
def create_mini_batches(data, batch_size_per_subject):
    subject_data = []
    subject_labels = []

    for subject_id, (subject, data) in enumerate(data.items()):
        features = np.concatenate(list(data.values()), axis=0)
        labels = []
        for i, key in enumerate(data.keys()):
            labels.append(np.ones(len(data[key])) * i)
        
        labels = np.concatenate(labels, axis=0)
    

        features = features.reshape(features.shape[0], -1)
        
        np.random.seed(42)
        indices = np.arange(labels.shape[0])
        np.random.shuffle(indices)

        # Apply shuffled indices to features and labels
        features = features[indices]
        labels = labels[indices]

        subject_data.append(features)
        subject_labels.append(labels)

    
    mini_batches = []
    
    for i in range(0, len(subject_data[0]), batch_size_per_subject):

        mini_batch = []
        mini_features = []
        mini_labels = []

        for s in range(len(subject_data)):
            mini_features.append(subject_data[s][i:i+batch_size_per_subject])
            mini_labels.append(subject_labels[s][i:i+batch_size_per_subject])

        mini_features = torch.Tensor(np.stack(mini_features, axis=0)).to(device)
        mini_labels = torch.Tensor(np.stack(mini_labels, axis=0)).to(device)
        mini_featurs = mini_features.permute(1, 0, 2)
        mini_labels = mini_labels.permute(0, 1)


        mini_batches.append((mini_features, mini_labels))
    
    return mini_batches

In [6]:
mini_batches = create_mini_batches(subject_data, batch_size_per_subject=10)
# mini_batches: A list of tuples, with each tuple representing a mini-batch, having a length equal to 'batch_size':
#   - Each tuple contains:
#       1. batch_data (list): A list of length 'num_subjects', where each element is a PyTorch tensor of shape (batch_size, num_features) representing one subject's data.
#       2. batch_labels (tensor): A PyTorch tensor of shape (batch_size * num_subjects,) containing the concatenated labels for all subjects in the mini-batch.

print(f'Number of minibatches: {len(mini_batches)}')

print(f'Number of subjects: {len(mini_batches[0][0])}')

# 25 minibatches
# 5 * 10 samples per minibatch (10 samples per subject)
# 1280 samples total

Number of minibatches: 25
Number of subjects: 5


In [7]:
train_mini_batches = mini_batches[:20]
test_mini_batches = mini_batches[20:]

In [8]:
class CustomMLP(nn.Module):
    def __init__(self, input_size, subject_hidden_size, shared_hidden_size, num_subjects, output_size):
        super(CustomMLP, self).__init__()
        self.subject_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(input_size, subject_hidden_size),
                nn.ReLU()
            ) for _ in range(num_subjects)
        ])
        self.shared_layer = nn.Sequential(
            # nn.Linear(subject_hidden_size, shared_hidden_size),
            # nn.ReLU(),
            #nn.Linear(shared_hidden_size, output_size),
            nn.Linear(subject_hidden_size, output_size),
        )

    def forward(self, data):
        subject_outputs = []
        for x, subject_layer in zip(data, self.subject_layers):


            # Pass the input data (x) through the subject-specific layer
            hidden_output = subject_layer(x)
            
            # Pass the output from the subject-specific layer through the shared layer
            shared_output = self.shared_layer(hidden_output)
            

            subject_outputs.append(shared_output)
        return torch.stack(subject_outputs)


In [9]:
# Hyperparameters
batch_size = 10
input_size = 8216
subject_hidden_size = 1
shared_hidden_size = 5000
num_subjects = 5
output_size = 5
learning_rate = 0.001


epochs = 100


In [10]:
model = CustomMLP(input_size, subject_hidden_size, shared_hidden_size, num_subjects, output_size)
summary(model, input_size=(5, 10, input_size));

Layer (type:depth-idx)                   Param #
├─ModuleList: 1-1                        --
|    └─Sequential: 2-1                   --
|    |    └─Linear: 3-1                  8,217
|    |    └─ReLU: 3-2                    --
|    └─Sequential: 2-2                   --
|    |    └─Linear: 3-3                  8,217
|    |    └─ReLU: 3-4                    --
|    └─Sequential: 2-3                   --
|    |    └─Linear: 3-5                  8,217
|    |    └─ReLU: 3-6                    --
|    └─Sequential: 2-4                   --
|    |    └─Linear: 3-7                  8,217
|    |    └─ReLU: 3-8                    --
|    └─Sequential: 2-5                   --
|    |    └─Linear: 3-9                  8,217
|    |    └─ReLU: 3-10                   --
├─Sequential: 1-2                        --
|    └─Linear: 2-6                       10
Total params: 41,095
Trainable params: 41,095
Non-trainable params: 0


In [12]:
def calculate_subject_accuracy(model, test_mini_batches):
    subject_correct = {}
    subject_total_samples = {}
    
    for batch_data, batch_labels in test_mini_batches:
        subject_outputs = model(batch_data)
        
        for subject_output, subject_label in zip(subject_outputs, batch_labels):
            _, predicted = torch.max(subject_output.data, 1)
            
            for pred, true_label in zip(predicted, subject_label):
                true_label = true_label.item()
                pred = pred.item()
                
                if true_label not in subject_total_samples:
                    subject_total_samples[true_label] = 0
                if true_label not in subject_correct:
                    subject_correct[true_label] = 0
                
                subject_total_samples[true_label] += 1
                subject_correct[true_label] += int(pred == true_label)
    
    subject_accuracy = {}
    for subject in subject_total_samples:
        subject_accuracy[subject] = subject_correct[subject] / subject_total_samples[subject]
        
    return subject_accuracy

In [13]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate)
model.to(device)

# Training loop
for epoch in range(epochs):
    epoch_loss = 0.0

    epoch_accuracy = 0.0
    total_samples = 0
    correct = 0
    for i, (batch_data, batch_labels) in enumerate(train_mini_batches):

        #print(f'batch_data_shape: {batch_data.shape}')
        #print(f'batch_labels_shape: {batch_labels.shape}')
 


        optimizer.zero_grad()

        # Forward propagate the data through each subject's first layer and then through the shared second layer
        subject_outputs = model(batch_data)
        #print(f'subject_outputs_shape: {subject_outputs.shape}')


        # Create an empty list to store the losses for each subject
        losses = []
        # Iterate over the subject outputs and ground-truth labels
        for subject_output, subject_labels in zip(subject_outputs, batch_labels):
            loss = criterion(subject_output, subject_labels.long())
            losses.append(loss)

        # Calculate the average loss across subjects
        loss = torch.mean(torch.stack(losses))

        # Backward propagation
        loss.backward()

        # Update the weights
        optimizer.step()

        epoch_loss += loss.item()

        # Calculate the accuracy
        
     
        for subject_output, subject_label in zip(subject_outputs, batch_labels):
            _, predicted = torch.max(subject_output.data, 1)
            total_samples += subject_label.size(0)
            correct += (predicted == subject_label).sum().item()
    epoch_accuracy += correct / total_samples


    test_epoch_accuracy = 0.0
    test_total_samples = 0
    test_correct = 0
    for batch_data, batch_labels in test_mini_batches:
            subject_outputs = model(batch_data)
            for subject_output, subject_label in zip(subject_outputs, batch_labels):
                _, predicted = torch.max(subject_output.data, 1)
                test_total_samples += subject_label.size(0)
                test_correct += (predicted == subject_label).sum().item()
    test_epoch_accuracy += test_correct / test_total_samples

      

    print(f"Epoch {epoch + 1}/{epochs}, Loss: {epoch_loss}, Accuracy: {epoch_accuracy}, Test Accuracy: {test_epoch_accuracy}")
    print(calculate_subject_accuracy(model, test_mini_batches))


Epoch 1/100, Loss: 36.01104915142059, Accuracy: 0.205, Test Accuracy: 0.18
{4.0: 0.0, 1.0: 0.0, 2.0: 1.0, 3.0: 0.0, 0.0: 0.0}
Epoch 2/100, Loss: 35.911216735839844, Accuracy: 0.205, Test Accuracy: 0.18
{4.0: 0.0, 1.0: 0.0, 2.0: 1.0, 3.0: 0.0, 0.0: 0.0}
Epoch 3/100, Loss: 35.82422959804535, Accuracy: 0.205, Test Accuracy: 0.18
{4.0: 0.0, 1.0: 0.0, 2.0: 1.0, 3.0: 0.0, 0.0: 0.0}
Epoch 4/100, Loss: 35.74298298358917, Accuracy: 0.205, Test Accuracy: 0.18
{4.0: 0.0, 1.0: 0.0, 2.0: 1.0, 3.0: 0.0, 0.0: 0.0}
Epoch 5/100, Loss: 35.66886007785797, Accuracy: 0.205, Test Accuracy: 0.18
{4.0: 0.0, 1.0: 0.0, 2.0: 1.0, 3.0: 0.0, 0.0: 0.0}
Epoch 6/100, Loss: 35.60084044933319, Accuracy: 0.205, Test Accuracy: 0.18
{4.0: 0.0, 1.0: 0.0, 2.0: 1.0, 3.0: 0.0, 0.0: 0.0}
Epoch 7/100, Loss: 35.53451216220856, Accuracy: 0.205, Test Accuracy: 0.18
{4.0: 0.0, 1.0: 0.0, 2.0: 1.0, 3.0: 0.0, 0.0: 0.0}
Epoch 8/100, Loss: 35.466535806655884, Accuracy: 0.205, Test Accuracy: 0.18
{4.0: 0.0, 1.0: 0.0, 2.0: 1.0, 3.0: 0.0, 