In [1]:
import os
from mne.datasets import eegbci

from data_loader import load_raw_edf, edf_to_numpy

In [2]:
runs = [4, 8, 12]

train_subjects = [1, 2]
test_subjects = [3]

train_trials = []
train_labels = []
test_trials = []
test_labels = []

for kind in ['train', 'test']:
    with open(f'{kind}_annotations.csv', 'w') as f:
        f.write('filename, label\n')

for subject in train_subjects:
    for fname in eegbci.load_data(subject, runs):
        edf = load_raw_edf(fname)
        
        pre_name = fname.split('/')[-1].split('.')[0]
        trials, labels = edf_to_numpy(edf, max_length=656, save=True, data_dir='train', prefix=pre_name)
        
        train_trials.extend(trials)
        train_labels.extend(labels)

for subject in test_subjects:
    for fname in eegbci.load_data(subject, runs):
        edf = load_raw_edf(fname)
        
        pre_name = fname.split('/')[-1].split('.')[0]
        trials, labels = edf_to_numpy(edf, max_length=656, save=True, data_dir='test', prefix=pre_name)
        
        test_trials.extend(trials)
        test_labels.extend(labels)

Extracting EDF parameters from /Users/marktaylor/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R04.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /Users/marktaylor/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R08.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /Users/marktaylor/mne_data/MNE-eegbci-data/files/eegmmidb/1.0.0/S001/S001R12.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 19999  =      0.000 ...   124.994 secs...
Used Annotations descriptions: ['T0', 'T1', 'T2']
Extracting EDF parameters from /Users/marktaylor/mne_data/MNE-eegbci-data/files/eegmmidb

In [32]:
for fname in eegbci.load_data(subject, runs):
    print(fname.split('/')[-1].split('.')[0])

S002R04
S002R08
S002R12


In [29]:
# make all trials as long as shortest
max_trial_length = min([t.shape[1] for t in all_trials])
all_trials = [t[:,:max_trial_length] for t in all_trials]

In [12]:
import pandas as pd
pd.read_csv('annotations.csv')

Unnamed: 0,filename,label
0,S001R04_0.csv,T0
1,S001R04_1.csv,T2
2,S001R04_2.csv,T0
3,S001R04_3.csv,T1
4,S001R04_4.csv,T0
...,...,...
265,S003R12_25.csv,T1
266,S003R12_26.csv,T0
267,S003R12_27.csv,T2
268,S003R12_28.csv,T0


In [7]:
help(pd.read_csv)

Help on function read_csv in module pandas.io.parsers:

read_csv(filepath_or_buffer: Union[str, pathlib.Path, IO[~AnyStr]], sep=',', delimiter=None, header='infer', names=None, index_col=None, usecols=None, squeeze=False, prefix=None, mangle_dupe_cols=True, dtype=None, engine=None, converters=None, true_values=None, false_values=None, skipinitialspace=False, skiprows=None, skipfooter=0, nrows=None, na_values=None, keep_default_na=True, na_filter=True, verbose=False, skip_blank_lines=True, parse_dates=False, infer_datetime_format=False, keep_date_col=False, date_parser=None, dayfirst=False, cache_dates=True, iterator=False, chunksize=None, compression='infer', thousands=None, decimal: str = '.', lineterminator=None, quotechar='"', quoting=0, doublequote=True, escapechar=None, comment=None, encoding=None, dialect=None, error_bad_lines=True, warn_bad_lines=True, delim_whitespace=False, low_memory=True, memory_map=False, float_precision=None)
    Read a comma-separated values (csv) file in

In [20]:
torch.tensor(pd.read_csv('data/S001R04_0.csv', header=None).to_numpy()).shape

torch.Size([64, 672])

In [4]:
import torch
from torch.utils.data import Dataset

import pandas as pd

class eegDataset(Dataset):
    
    def __init__(self, annotations_file, data_dir, transform=None, target_transform=None):
        self.data_labels = pd.read_csv(annotations_file)
        self.data_dir = data_dir
        self.transform = transform
        self.target_transform = target_transform
    
    def __len__(self):
        return len(self.data_labels)
    
    def __getitem__(self, idx):
        data_path = os.path.join(self.data_dir, self.data_labels.iloc[idx, 0])
        data = pd.read_csv(data_path, header=None).to_numpy()
        label = self.data_labels.iloc[idx, 1]
        if self.transform:
            data = self.transform(data)
        if self.target_transform:
            label = self.target_transform(label)
        return data, label

In [33]:
f'{os.getcwd()}/data'

'/Users/marktaylor/classifyEEG/data'

In [6]:
from torchvision.transforms import ToTensor

train_data = eegDataset('train_annotations.csv', f'{os.getcwd()}/train', transform=ToTensor())
test_data = eegDataset('test_annotations.csv', f'{os.getcwd()}/test', transform=ToTensor())

In [7]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_data, batch_size=6, shuffle=False)
test_dataloader = DataLoader(test_data, batch_size=6, shuffle=False)

In [15]:
len(test_dataloader)

15

In [13]:
for t in train_dataloader:
    print(t)

[tensor([[[[-5.0000e-06, -1.2000e-05, -7.7000e-05,  ..., -2.3000e-05,
           -4.3000e-05, -2.1000e-05],
          [ 2.0000e-06, -2.4000e-05, -7.8000e-05,  ..., -3.1000e-05,
           -3.9000e-05, -2.1000e-05],
          [ 3.7000e-05,  1.0000e-06, -5.9000e-05,  ..., -3.3000e-05,
           -2.7000e-05, -1.0000e-05],
          ...,
          [-4.8000e-05, -4.2000e-05, -4.2000e-05,  ..., -6.9000e-05,
           -7.2000e-05, -7.0000e-05],
          [-3.9000e-05, -3.1000e-05, -2.9000e-05,  ..., -3.7000e-05,
           -5.4000e-05, -5.9000e-05],
          [-3.9000e-05, -3.4000e-05, -2.7000e-05,  ..., -4.3000e-05,
           -5.2000e-05, -4.9000e-05]]],


        [[[-2.5000e-05, -1.6000e-05,  7.0000e-06,  ..., -1.0000e-06,
            2.1000e-05,  3.0000e-05],
          [-2.0000e-05, -1.1000e-05,  1.0000e-05,  ...,  1.3000e-05,
            3.6000e-05,  3.2000e-05],
          [-8.0000e-06,  3.0000e-06,  1.8000e-05,  ...,  2.6000e-05,
            2.6000e-05,  2.7000e-05],
          ...,
  

[tensor([[[[ 2.7000e-05,  1.3000e-05,  2.0000e-06,  ..., -9.3000e-05,
           -7.8000e-05, -6.4000e-05],
          [ 4.2000e-05,  2.6000e-05,  4.0000e-06,  ..., -6.7000e-05,
           -5.6000e-05, -2.9000e-05],
          [ 6.2000e-05,  4.1000e-05,  2.0000e-05,  ..., -5.4000e-05,
           -3.8000e-05, -1.0000e-05],
          ...,
          [ 1.3000e-05, -3.0000e-06, -1.3000e-05,  ...,  8.0000e-06,
           -1.2000e-05, -2.4000e-05],
          [-1.0000e-06, -1.6000e-05, -2.4000e-05,  ...,  4.1000e-05,
            2.1000e-05,  2.0000e-06],
          [ 5.4000e-05,  2.8000e-05,  3.0000e-06,  ..., -9.0000e-06,
           -2.1000e-05, -4.2000e-05]]],


        [[[-2.2000e-05, -4.4000e-05, -2.9000e-05,  ...,  2.7000e-05,
            5.5000e-05,  4.1000e-05],
          [-2.0000e-06, -2.0000e-05, -5.0000e-06,  ...,  3.0000e-05,
            6.1000e-05,  4.9000e-05],
          [-4.0000e-06, -2.4000e-05, -1.5000e-05,  ...,  2.5000e-05,
            5.4000e-05,  4.8000e-05],
          ...,
  

[tensor([[[[-4.4000e-05, -3.6000e-05, -2.7000e-05,  ...,  1.4000e-05,
            1.5000e-05,  3.4000e-05],
          [-4.1000e-05, -2.7000e-05, -1.7000e-05,  ...,  3.3000e-05,
            3.5000e-05,  4.3000e-05],
          [-3.1000e-05, -1.3000e-05, -6.0000e-06,  ...,  5.0000e-05,
            4.6000e-05,  4.6000e-05],
          ...,
          [-2.5000e-05, -2.2000e-05, -2.9000e-05,  ...,  4.7000e-05,
            5.8000e-05,  4.8000e-05],
          [-7.6000e-05, -7.5000e-05, -9.0000e-05,  ...,  8.3000e-05,
            8.6000e-05,  6.7000e-05],
          [-6.3000e-05, -5.8000e-05, -5.4000e-05,  ...,  6.1000e-05,
            7.4000e-05,  7.0000e-05]]],


        [[[-1.2000e-05, -1.7000e-05, -1.0000e-06,  ...,  4.8000e-05,
            2.7000e-05,  4.4000e-05],
          [-2.0000e-06, -7.0000e-06,  2.4000e-05,  ...,  4.1000e-05,
            2.3000e-05,  3.7000e-05],
          [ 6.0000e-06, -3.0000e-06,  3.1000e-05,  ...,  4.1000e-05,
            2.7000e-05,  3.5000e-05],
          ...,
  

[tensor([[[[ 5.4000e-05,  2.9000e-05,  5.1000e-05,  ..., -5.1000e-05,
           -2.2000e-05,  3.0000e-06],
          [ 2.2000e-05,  3.0000e-06,  3.0000e-05,  ..., -3.3000e-05,
           -3.0000e-06,  2.3000e-05],
          [ 1.1000e-05,  2.0000e-06,  1.6000e-05,  ..., -2.7000e-05,
            4.0000e-06,  3.1000e-05],
          ...,
          [-6.7000e-05, -6.7000e-05, -6.7000e-05,  ..., -5.3000e-05,
           -4.4000e-05, -1.7000e-05],
          [-6.0000e-05, -5.8000e-05, -4.4000e-05,  ..., -8.0000e-06,
           -9.0000e-06,  4.0000e-06],
          [-5.0000e-06, -1.0000e-05, -1.4000e-05,  ..., -2.7000e-05,
           -2.2000e-05, -5.0000e-06]]],


        [[[-6.4000e-05, -7.6000e-05, -6.6000e-05,  ...,  3.9000e-05,
            2.4000e-05,  3.7000e-05],
          [-3.1000e-05, -2.9000e-05, -1.4000e-05,  ...,  3.0000e-05,
            1.8000e-05,  2.5000e-05],
          [-8.0000e-06,  3.0000e-06,  1.8000e-05,  ...,  3.0000e-05,
            1.1000e-05,  2.0000e-05],
          ...,
  

[tensor([[[[ 3.1000e-05, -1.0000e-06, -3.2000e-05,  ...,  5.0000e-06,
            0.0000e+00,  3.5000e-05],
          [ 6.5000e-05,  3.3000e-05, -3.0000e-06,  ..., -8.0000e-06,
           -1.8000e-05,  1.6000e-05],
          [ 6.0000e-05,  3.6000e-05,  1.0000e-06,  ...,  1.0000e-06,
           -1.9000e-05,  7.0000e-06],
          ...,
          [ 2.6000e-05,  3.0000e-06,  2.0000e-06,  ...,  4.2000e-05,
            3.9000e-05,  2.8000e-05],
          [ 5.3000e-05,  3.1000e-05,  3.1000e-05,  ...,  1.0000e-05,
            4.0000e-06, -7.0000e-06],
          [ 6.3000e-05,  4.8000e-05,  4.9000e-05,  ...,  4.2000e-05,
            3.5000e-05,  2.2000e-05]]],


        [[[-2.1000e-05,  2.0000e-06, -4.0000e-06,  ...,  1.8000e-05,
            2.9000e-05,  4.0000e-05],
          [-3.4000e-05, -1.5000e-05, -2.6000e-05,  ...,  4.4000e-05,
            4.9000e-05,  6.1000e-05],
          [-3.7000e-05, -2.5000e-05, -3.6000e-05,  ...,  2.9000e-05,
            3.9000e-05,  5.0000e-05],
          ...,
  

[tensor([[[[ 9.0000e-06,  1.3000e-05, -7.0000e-06,  ..., -1.0000e-06,
           -1.0000e-06, -2.3000e-05],
          [ 8.0000e-06,  2.0000e-05,  5.0000e-06,  ..., -3.0000e-06,
           -1.0000e-06, -1.6000e-05],
          [ 2.1000e-05,  2.7000e-05,  1.6000e-05,  ...,  7.0000e-06,
            8.0000e-06, -3.0000e-06],
          ...,
          [ 1.2000e-05,  2.0000e-05,  9.0000e-06,  ..., -2.0000e-05,
           -2.7000e-05, -1.3000e-05],
          [-1.1000e-05,  5.0000e-06,  0.0000e+00,  ..., -2.2000e-05,
           -2.7000e-05,  1.0000e-06],
          [ 1.0000e-05,  1.8000e-05,  7.0000e-06,  ..., -2.3000e-05,
           -2.9000e-05, -1.3000e-05]]],


        [[[-2.5000e-05, -2.1000e-05, -2.2000e-05,  ...,  5.0000e-06,
           -7.0000e-06,  1.3000e-05],
          [-1.9000e-05, -2.3000e-05, -2.6000e-05,  ...,  1.2000e-05,
            8.0000e-06,  2.9000e-05],
          [-4.0000e-06, -7.0000e-06, -1.0000e-05,  ...,  1.3000e-05,
            1.6000e-05,  3.7000e-05],
          ...,
  

[tensor([[[[ 1.4000e-05,  3.2000e-05,  2.3000e-05,  ...,  6.0000e-06,
            0.0000e+00, -5.0000e-06],
          [ 1.9000e-05,  3.7000e-05,  2.5000e-05,  ...,  4.0000e-06,
            7.0000e-06, -2.0000e-06],
          [ 1.8000e-05,  3.9000e-05,  2.9000e-05,  ..., -9.0000e-06,
           -2.0000e-06, -6.0000e-06],
          ...,
          [ 2.3000e-05,  2.4000e-05,  2.0000e-05,  ...,  1.0000e-05,
           -1.0000e-06, -2.9000e-05],
          [ 5.2000e-05,  4.6000e-05,  2.2000e-05,  ...,  2.1000e-05,
           -7.0000e-06, -3.6000e-05],
          [ 2.1000e-05,  2.3000e-05,  1.7000e-05,  ...,  8.0000e-06,
           -3.0000e-06, -3.1000e-05]]],


        [[[-2.3000e-05, -2.3000e-05, -3.9000e-05,  ..., -1.7000e-05,
           -1.7000e-05, -9.0000e-06],
          [-2.5000e-05, -2.3000e-05, -3.6000e-05,  ..., -2.0000e-05,
           -1.9000e-05, -1.4000e-05],
          [-2.2000e-05, -2.1000e-05, -3.3000e-05,  ..., -1.2000e-05,
           -7.0000e-06, -5.0000e-06],
          ...,
  

[tensor([[[[ 7.0000e-06,  1.1000e-05, -1.2000e-05,  ..., -1.4000e-05,
            1.7000e-05,  6.0000e-06],
          [ 1.3000e-05,  2.4000e-05,  7.0000e-06,  ..., -6.0000e-06,
            1.3000e-05,  6.0000e-06],
          [ 4.0000e-06,  1.7000e-05, -1.0000e-06,  ..., -3.0000e-06,
            9.0000e-06,  3.0000e-06],
          ...,
          [ 1.3000e-05,  2.0000e-06,  8.0000e-06,  ..., -2.2000e-05,
           -2.4000e-05, -2.3000e-05],
          [ 3.7000e-05,  1.6000e-05,  1.6000e-05,  ..., -4.9000e-05,
           -5.0000e-05, -2.4000e-05],
          [ 1.2000e-05,  1.0000e-06,  5.0000e-06,  ..., -2.5000e-05,
           -2.6000e-05, -2.4000e-05]]],


        [[[-7.0000e-06, -9.0000e-06, -5.0000e-06,  ..., -5.0000e-06,
           -1.2000e-05,  6.0000e-06],
          [ 1.0000e-06, -5.0000e-06, -2.0000e-06,  ..., -1.2000e-05,
           -1.4000e-05,  3.0000e-06],
          [-1.0000e-06, -4.0000e-06, -2.0000e-06,  ..., -1.4000e-05,
           -1.3000e-05, -1.0000e-06],
          ...,
  

[tensor([[[[-1.4000e-05, -1.0000e-05,  6.0000e-06,  ...,  9.0000e-06,
            9.0000e-06,  1.1000e-05],
          [-6.0000e-06, -2.0000e-06,  9.0000e-06,  ...,  3.4000e-05,
            2.9000e-05,  2.3000e-05],
          [-4.0000e-06,  1.0000e-06,  8.0000e-06,  ...,  4.3000e-05,
            3.1000e-05,  2.1000e-05],
          ...,
          [-1.8000e-05,  1.2000e-05,  1.9000e-05,  ...,  7.0000e-06,
            1.0000e-05,  2.7000e-05],
          [-3.3000e-05,  3.3000e-05,  4.5000e-05,  ...,  1.4000e-05,
            1.3000e-05,  3.8000e-05],
          [-1.9000e-05,  1.1000e-05,  1.7000e-05,  ...,  6.0000e-06,
            1.1000e-05,  2.6000e-05]]],


        [[[ 2.6000e-05,  1.8000e-05,  9.0000e-06,  ...,  2.0000e-06,
            1.8000e-05, -5.0000e-06],
          [ 2.7000e-05,  2.9000e-05,  2.2000e-05,  ..., -2.0000e-06,
            0.0000e+00, -2.5000e-05],
          [ 2.3000e-05,  3.2000e-05,  2.2000e-05,  ...,  6.0000e-06,
            1.0000e-06, -3.4000e-05],
          ...,
  

[tensor([[[[-3.5000e-05, -4.3000e-05, -2.4000e-05,  ...,  2.0000e-06,
            2.0000e-06, -4.0000e-06],
          [-3.3000e-05, -4.2000e-05, -1.8000e-05,  ..., -3.0000e-06,
            5.0000e-06, -6.0000e-06],
          [-3.1000e-05, -4.1000e-05, -1.5000e-05,  ...,  1.0000e-06,
            1.4000e-05,  1.0000e-06],
          ...,
          [-4.1000e-05, -4.0000e-06,  1.8000e-05,  ...,  1.7000e-05,
            3.0000e-06, -1.0000e-05],
          [-4.5000e-05, -4.0000e-06,  8.0000e-06,  ...,  3.0000e-06,
           -1.1000e-05, -3.9000e-05],
          [-4.0000e-05, -4.0000e-06,  1.6000e-05,  ...,  1.5000e-05,
            1.0000e-06, -1.1000e-05]]],


        [[[-1.3000e-05, -2.2000e-05,  5.0000e-06,  ..., -7.0000e-06,
           -4.0000e-06,  7.0000e-06],
          [-1.5000e-05, -2.1000e-05,  1.0000e-06,  ..., -1.7000e-05,
           -4.0000e-06,  6.0000e-06],
          [-7.0000e-06, -1.3000e-05,  5.0000e-06,  ..., -2.4000e-05,
            2.0000e-06,  4.0000e-06],
          ...,
  

In [8]:
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")

Feature batch shape: torch.Size([6, 1, 64, 656])


AttributeError: 'tuple' object has no attribute 'size'

In [10]:
train_features, train_labels

(tensor([[[[-5.0000e-06, -1.2000e-05, -7.7000e-05,  ..., -2.3000e-05,
            -4.3000e-05, -2.1000e-05],
           [ 2.0000e-06, -2.4000e-05, -7.8000e-05,  ..., -3.1000e-05,
            -3.9000e-05, -2.1000e-05],
           [ 3.7000e-05,  1.0000e-06, -5.9000e-05,  ..., -3.3000e-05,
            -2.7000e-05, -1.0000e-05],
           ...,
           [-4.8000e-05, -4.2000e-05, -4.2000e-05,  ..., -6.9000e-05,
            -7.2000e-05, -7.0000e-05],
           [-3.9000e-05, -3.1000e-05, -2.9000e-05,  ..., -3.7000e-05,
            -5.4000e-05, -5.9000e-05],
           [-3.9000e-05, -3.4000e-05, -2.7000e-05,  ..., -4.3000e-05,
            -5.2000e-05, -4.9000e-05]]],
 
 
         [[[-2.5000e-05, -1.6000e-05,  7.0000e-06,  ..., -1.0000e-06,
             2.1000e-05,  3.0000e-05],
           [-2.0000e-05, -1.1000e-05,  1.0000e-05,  ...,  1.3000e-05,
             3.6000e-05,  3.2000e-05],
           [-8.0000e-06,  3.0000e-06,  1.8000e-05,  ...,  2.6000e-05,
             2.6000e-05,  2.7000e-05