In [141]:
import numpy as np
import scipy.io
import torch
from torch.utils.data import DataLoader
import os
import pandas as pd
import matplotlib.pyplot as plt

data from https://www.kaggle.com/datasets/amananandrai/complete-eeg-dataset/data

In [142]:
labels=np.array([0,1,1,1,0,1,0,1,1,0,0,1,1,1,0,1,1,1,1,0,1,0,0,1,1,1,1,1,1,1,0,1,1,1,1,1])
len(labels)

36

In [143]:
counts = [str(i) for i in range(31)]
counts = ['0' + i if len(i) == 1 else i for i in counts]


In [144]:
# set construct_files to True if you want to construct the files
construct_files = False
if construct_files : 
    np.random.seed(0)

    for file_name in counts : 
        df = pd.read_csv('./../data/kaggle_2/s'+file_name+'.csv', header=None).transpose().to_numpy()
        for i in range(31*4): 
            sample = pd.DataFrame(df[:, 250*i:250*i+1000]) # slicing of 250 
            random_float = np.random.rand()
            if random_float < 0.1:
                sample.to_csv('./../data/test/'+file_name+'_'+str(i)+'_'+'.csv', index=False, header=False)
            elif random_float < 0.3 : 
                sample.to_csv('./../data/validation/'+file_name+'_'+str(i)+'_'+'.csv', index=False, header=False)
            else:
                sample.to_csv('./../data/train/'+file_name+'_'+str(i)+'_'+'.csv', index=False, header=False)
            


Chaque eeg = 19 x 31 000
On va les découper en patchs de 1000

In [145]:
class Mydataset(torch.utils.data.Dataset):
    def __init__(self, path_to_data):
        self.path_to_data = path_to_data # par example './../data/train/
        self.X = os.listdir(self.path_to_data) # the list of the files in the train set 


    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        labels=np.array([0,1,1,1,0,1,0,1,1,0,0,1,1,1,0,1,1,1,1,0,1,0,0,1,1,1,1,1,1,1,0,1,1,1,1,1])
        x = pd.read_csv(self.path_to_data + self.X[idx], header=None).to_numpy()
        record_number = int(self.X[idx].split('_')[0])
        return torch.tensor(x), torch.tensor(labels[record_number]) # un eeg sur 1000 échantillonages et le label correspondant



In [146]:
dataloader_train = DataLoader(Mydataset('./../data/train/'), batch_size=5, shuffle=True)
dataloader_val = DataLoader(Mydataset('./../data/validation/'), batch_size=5, shuffle=True)
dataloader_test = DataLoader(Mydataset('./../data/test/'), batch_size=5, shuffle=True)

for batch in dataloader_train :
  print(batch[0].shape)
  print(batch[1])
  break


torch.Size([5, 19, 1000])
tensor([1, 1, 1, 1, 1], dtype=torch.int32)


chatgpt code : 

In [147]:
import torch
import torch.nn as nn

class EEGFeatureExtractor(nn.Module):
    def __init__(self, feature_dim=100):
        super(EEGFeatureExtractor, self).__init__()
        # input [batch_size, 19, 1000]
        self.conv1= nn.Conv1d(19, 32, 3, padding=1)
        self.conv2= nn.Conv1d(32, 64, 3, padding=1)
        self.conv3= nn.Conv1d(64, 128, 3, padding=1)
        self.conv4= nn.Conv1d(6, 10, 3, padding=1)   
        self.conv5 = nn.Conv1d(64, 15, 3, padding=1)
        self.pool = nn.MaxPool1d(1, 13)
        self.dropout = nn.Dropout(0.5)
        self.fc = nn.Linear(750, 100)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.conv1(x))
        x = self.pool(x)
        x = self.relu(self.conv2(x))
        x = self.pool(x)
        x = torch.transpose(x, 1, 2)
        x = self.relu(self.conv4(x))
        x = torch.transpose(x, 1, 2)
        x = self.conv5(x)
        x = x.flatten()
        x = self.dropout(x)
        x = self.fc(x)
        return self.relu(x)


In [148]:
for batch in dataloader_train :
  x = batch[0]
  model = EEGFeatureExtractor()
  print(model(x.float()).shape)  
  break

torch.Size([100])


In [149]:
class EEGClassifier(nn.Module):
    def __init__(self, feature_extractor, num_classes = 2):
        super(EEGClassifier, self).__init__()
        self.feature_extractor = feature_extractor
        self.classifier = nn.Linear(100, num_classes)

    def forward(self, x):
        features = self.feature_extractor(x)
        return self.classifier(features)
