# BCI Project: Classifying preprocessed OpenMIIR data and hvEEGNet encoded data with a Conformer model

## By BrainRot: Lotte Michels & Selma Ancel

This notebook trains the Conformer model of Song et al. (2023) on both the preprocessed OpenMIIR data (Stober, 2017) and the OpenMIIR data encoded by the hvEEGNet encoder. Please refer to our project submission for a textual description of the training steps. The markdown cells and code comments in this notebook also outline the steps performed.

### References:

* Song, Y., Zheng, Q., Liu, B., & Gao, X. (2023). EEG Conformer: Convolutional Transformer for EEG Decoding and Visualization. IEEE Transactions on Neural Systems and Rehabilitation Engineering, 31, 710-719. https://doi.org/10.1109/TNSRE.2022.3230250
* Stober, S. (2017). Toward Studying Music Cognition with Information Retrieval Techniques: Lessons Learned from the OpenMIIR Initiative. Frontiers in Psychology, 8. https://doi.org/10.3389/fpsyg.2017.01255. Related code: https://github.com/sstober/openmiir






In [1]:
!pip install mne

Collecting mne
  Downloading mne-1.9.0-py3-none-any.whl.metadata (20 kB)
Downloading mne-1.9.0-py3-none-any.whl (7.4 MB)
Installing collected packages: mne
Successfully installed mne-1.9.0


In [2]:
import mne
import torch
import numpy as np
from torch.utils.data import TensorDataset, DataLoader
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

In [3]:
# Connect to drive to load the data
from google.colab import drive
drive.mount('/content/drive')
!ls
%cd /content/drive/My Drive/Colab_Notebooks/EEG_DATA

Mounted at /content/drive
c1_latent_features.npy	c2_latent_features.npy	drive  sample_data
/content/drive/My Drive/Colab_Notebooks/EEG_DATA


# Load and process the pre-processed OpenMIIR data

In [4]:
# Load the data
raw1 = mne.io.read_raw_fif('01_preprocessed.fif', preload=True)
raw4 = mne.io.read_raw_fif('04_preprocessed.fif', preload=True)
raw6 = mne.io.read_raw_fif('06_preprocessed.fif', preload=True)
raw7 = mne.io.read_raw_fif('07_preprocessed.fif', preload=True)
raw9 = mne.io.read_raw_fif('09_preprocessed.fif', preload=True)
raw11 = mne.io.read_raw_fif('11_preprocessed.fif', preload=True)
raw12 = mne.io.read_raw_fif('12_preprocessed.fif', preload=True)
raw13 = mne.io.read_raw_fif('13_preprocessed.fif', preload=True)
raw14 = mne.io.read_raw_fif('14_preprocessed.fif', preload=True)

Opening raw data file 01_preprocessed.fif...
    Range : 0 ... 309770 =      0.000 ...  4840.156 secs
Ready.
Reading 0 ... 309770  =      0.000 ...  4840.156 secs...
Opening raw data file 04_preprocessed.fif...
    Range : 0 ... 310003 =      0.000 ...  4843.797 secs
Ready.
Reading 0 ... 310003  =      0.000 ...  4843.797 secs...
Opening raw data file 06_preprocessed.fif...
    Range : 0 ... 304128 =      0.000 ...  4752.000 secs
Ready.
Reading 0 ... 304128  =      0.000 ...  4752.000 secs...
Opening raw data file 07_preprocessed.fif...
    Range : 0 ... 315054 =      0.000 ...  4922.719 secs
Ready.
Reading 0 ... 315054  =      0.000 ...  4922.719 secs...
Opening raw data file 09_preprocessed.fif...
    Range : 0 ... 305069 =      0.000 ...  4766.703 secs
Ready.
Reading 0 ... 305069  =      0.000 ...  4766.703 secs...
Opening raw data file 11_preprocessed.fif...
    Range : 0 ... 330609 =      0.000 ...  5165.766 secs
Ready.
Reading 0 ... 330609  =      0.000 ...  5165.766 secs...
Opening raw data file 12_preprocessed.fif...
    Range : 0 ... 310668 =      0.000 ...  4854.188 secs
Ready.
Reading 0 ... 310668  =      0.000 ...  4854.188 secs...
Opening raw data file 13_preprocessed.fif...
    Range : 0 ... 313988 =      0.000 ...  4906.062 secs
Ready.
Reading 0 ... 313988  =      0.000 ...  4906.062 secs...
Opening raw data file 14_preprocessed.fif...
    Range : 0 ... 308178 =      0.000 ...  4815.281 secs
Ready.
Reading 0 ... 308178  =      0.000 ...  4815.281 secs...


In [5]:
events1 = mne.find_events(raw1, stim_channel='STI 014')
events4 = mne.find_events(raw4, stim_channel='STI 014')
events6 = mne.find_events(raw6, stim_channel='STI 014')
events7 = mne.find_events(raw7, stim_channel='STI 014')
events9 = mne.find_events(raw9, stim_channel='STI 014')
events11 = mne.find_events(raw11, stim_channel='STI 014')
events12 = mne.find_events(raw12, stim_channel='STI 014')
events13 = mne.find_events(raw13, stim_channel='STI 014')
events14 = mne.find_events(raw14, stim_channel='STI 014')

# List the event codes to drop for each condition
# Music imagery (2nd condition): drop everything except codes ending in 2
exclude_codes1 = [1111, 2000, 2001, 11, 21, 31, 41,
                 111, 121, 131, 141, 211, 221, 231,
                 241, 13, 23, 33, 43, 113, 123, 133,
                 143, 213, 223, 233, 243, 14, 24, 34,
                 44, 114, 124, 134, 144, 214, 224, 234, 244]

# Music perception (1st condition): drop everything except codes ending in 1
exclude_codes2 = [1111, 2000, 2001, 12, 22, 32, 42, 112, 122, 132, 142,
                 212, 222, 232, 242, 13, 23, 33, 43, 113, 123, 133,
                 143, 213, 223, 233, 243, 14, 24, 34, 44, 114, 124,
                 134, 144, 214, 224, 234, 244]

# Keep only rows where the event code (3rd column) is NOT in exclude_codes
filtered_events1 = events1[~np.isin(events1[:, 2], exclude_codes1)]
filtered_events4 = events4[~np.isin(events4[:, 2], exclude_codes1)]
filtered_events6 = events6[~np.isin(events6[:, 2], exclude_codes1)]
filtered_events7 = events7[~np.isin(events7[:, 2], exclude_codes1)]
filtered_events9 = events9[~np.isin(events9[:, 2], exclude_codes1)]
filtered_events11 = events11[~np.isin(events11[:, 2], exclude_codes1)]
filtered_events12 = events12[~np.isin(events12[:, 2], exclude_codes1)]
filtered_events13 = events13[~np.isin(events13[:, 2], exclude_codes1)]
filtered_events14 = events14[~np.isin(events14[:, 2], exclude_codes1)]

filtered_events1_per = events1[~np.isin(events1[:, 2], exclude_codes2)]
filtered_events4_per = events4[~np.isin(events4[:, 2], exclude_codes2)]
filtered_events6_per = events6[~np.isin(events6[:, 2], exclude_codes2)]
filtered_events7_per = events7[~np.isin(events7[:, 2], exclude_codes2)]
filtered_events9_per = events9[~np.isin(events9[:, 2], exclude_codes2)]
filtered_events11_per = events11[~np.isin(events11[:, 2], exclude_codes2)]
filtered_events12_per = events12[~np.isin(events12[:, 2], exclude_codes2)]
filtered_events13_per = events13[~np.isin(events13[:, 2], exclude_codes2)]
filtered_events14_per = events14[~np.isin(events14[:, 2], exclude_codes2)]

360 events found on stim channel STI 014
Event IDs: [  11   12   13   14   21   22   23   24   31   32   33   34   41   42
   43   44  111  112  113  114  121  122  123  124  131  132  133  134
  141  142  143  144  211  212  213  214  221  222  223  224  231  232
  233  234  241  242  243  244 1111 2001]
360 events found on stim channel STI 014
Event IDs: [  11   12   13   14   21   22   23   24   31   32   33   34   41   42
   43   44  111  112  113  114  121  122  123  124  131  132  133  134
  141  142  143  144  211  212  213  214  221  222  223  224  231  232
  233  234  241  242  243  244 1111 2000 2001]
360 events found on stim channel STI 014
Event IDs: [  11   12   13   14   21   22   23   24   31   32   33   34   41   42
   43   44  111  112  113  114  121  122  123  124  131  132  133  134
  141  142  143  144  211  212  213  214  221  222  223  224  231  232
  233  234  241  242  243  244 1111 2000 2001]
360 events found on stim channel STI 014
Event IDs: [  11   12   13  
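
The exclusion lists in the cell above can also be expressed arithmetically. The sketch below assumes that trigger codes below 1000 follow the pattern `stimulus_id * 10 + condition`; this reading is consistent with the event IDs printed above, but it is an assumption here. `select_condition` is a hypothetical helper for illustration, not part of MNE or the OpenMIIR code, and the events array is synthetic.

```python
import numpy as np

def select_condition(events, condition):
    """Keep rows of an MNE-style events array whose code encodes `condition`.

    Assumes codes < 1000 are stimulus_id * 10 + condition (an assumption
    about the trigger scheme); larger codes are treated as non-trial markers.
    """
    codes = events[:, 2]
    is_trial = codes < 1000
    return events[is_trial & (codes % 10 == condition)]

# Tiny synthetic events array: columns are (sample, previous value, event code).
demo_events = np.array([
    [100, 0, 11],    # condition 1
    [200, 0, 12],    # condition 2
    [300, 0, 243],   # condition 3
    [400, 0, 1111],  # non-trial marker
    [500, 0, 212],   # condition 2
])

imagery = select_condition(demo_events, condition=2)
perception = select_condition(demo_events, condition=1)
print(imagery[:, 2])     # codes kept for condition 2
print(perception[:, 2])  # codes kept for condition 1
```

With this approach the kept codes follow from a single rule, so the two long exclusion lists cannot drift out of sync.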

In [6]:
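# Create the epochs for music imagery

event_id = None  # keep every event code that is left after filtering
tmin = -0.2  # start of each epoch (200 ms before the trigger)
tmax = 6.9  # end of each epoch; the shortest song is 6.9 seconds long
detrend = 0  # constant (DC) detrending; defined here but not passed to mne.Epochs below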

#01
picks1 = mne.pick_types(raw1.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs1 = mne.Epochs(raw1, filtered_events1, event_id, tmin, tmax, preload=True, proj=False, picks=picks1, verbose=False)
times1 = beat_epochs1.times
# 04
picks4 = mne.pick_types(raw4.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs4 = mne.Epochs(raw4, filtered_events4, event_id, tmin, tmax, preload=True, proj=False, picks=picks4, verbose=False)
times4 = beat_epochs4.times
#06
picks6 = mne.pick_types(raw6.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs6 = mne.Epochs(raw6, filtered_events6, event_id, tmin, tmax, preload=True, proj=False, picks=picks6, verbose=False)
times6 = beat_epochs6.times
#07
picks7 = mne.pick_types(raw7.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs7 = mne.Epochs(raw7, filtered_events7, event_id, tmin, tmax, preload=True, proj=False, picks=picks7, verbose=False)
times7 = beat_epochs7.times
#09
picks9 = mne.pick_types(raw9.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs9 = mne.Epochs(raw9, filtered_events9, event_id, tmin, tmax, preload=True, proj=False, picks=picks9, verbose=False)
times9 = beat_epochs9.times
#11
picks11 = mne.pick_types(raw11.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs11 = mne.Epochs(raw11, filtered_events11, event_id, tmin, tmax, preload=True, proj=False, picks=picks11, verbose=False)
times11 = beat_epochs11.times
#12
picks12 = mne.pick_types(raw12.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs12 = mne.Epochs(raw12, filtered_events12, event_id, tmin, tmax, preload=True, proj=False, picks=picks12, verbose=False)
times12 = beat_epochs12.times
#13
picks13 = mne.pick_types(raw13.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs13 = mne.Epochs(raw13, filtered_events13, event_id, tmin, tmax, preload=True, proj=False, picks=picks13, verbose=False)
times13 = beat_epochs13.times
#14
picks14 = mne.pick_types(raw14.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs14 = mne.Epochs(raw14, filtered_events14, event_id, tmin, tmax, preload=True, proj=False, picks=picks14, verbose=False)
times14 = beat_epochs14.times

In [7]:
# Create the epochs for music perception (reuses event_id, tmin and tmax from the previous cell)

#01
picks1 = mne.pick_types(raw1.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs1_per = mne.Epochs(raw1, filtered_events1_per, event_id, tmin, tmax, preload=True, proj=False, picks=picks1, verbose=False)
times1 = beat_epochs1_per.times
# 04
picks4 = mne.pick_types(raw4.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs4_per = mne.Epochs(raw4, filtered_events4_per, event_id, tmin, tmax, preload=True, proj=False, picks=picks4, verbose=False)
times4 = beat_epochs4_per.times
#06
picks6 = mne.pick_types(raw6.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs6_per = mne.Epochs(raw6, filtered_events6_per, event_id, tmin, tmax, preload=True, proj=False, picks=picks6, verbose=False)
times6 = beat_epochs6_per.times
#07
picks7 = mne.pick_types(raw7.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs7_per = mne.Epochs(raw7, filtered_events7_per, event_id, tmin, tmax, preload=True, proj=False, picks=picks7, verbose=False)
times7 = beat_epochs7_per.times
#09
picks9 = mne.pick_types(raw9.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs9_per = mne.Epochs(raw9, filtered_events9_per, event_id, tmin, tmax, preload=True, proj=False, picks=picks9, verbose=False)
times9 = beat_epochs9_per.times
#11
picks11 = mne.pick_types(raw11.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs11_per = mne.Epochs(raw11, filtered_events11_per, event_id, tmin, tmax, preload=True, proj=False, picks=picks11, verbose=False)
times11 = beat_epochs11_per.times
#12
picks12 = mne.pick_types(raw12.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs12_per = mne.Epochs(raw12, filtered_events12_per, event_id, tmin, tmax, preload=True, proj=False, picks=picks12, verbose=False)
times12 = beat_epochs12_per.times
#13
picks13 = mne.pick_types(raw13.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs13_per = mne.Epochs(raw13, filtered_events13_per, event_id, tmin, tmax, preload=True, proj=False, picks=picks13, verbose=False)
times13 = beat_epochs13_per.times
#14
picks14 = mne.pick_types(raw14.info, meg=False, eeg=True, eog=False, stim=False, exclude=["bads"])
beat_epochs14_per = mne.Epochs(raw14, filtered_events14_per, event_id, tmin, tmax, preload=True, proj=False, picks=picks14, verbose=False)
times14 = beat_epochs14_per.times

In [8]:
# Check if the music imagery epochs have been created correctly
beat_epochs1

General
  MNE object type: Epochs
  Measurement date: 2015-01-28 at 17:39:57 UTC
  Participant: Unknown
  Experimenter: Unknown
Acquisition
  Total number of events: 60
  Events counts: 112: 5, 12: 5, 122: 5, 132: 5, 142: 5, 212: 5, 22: 5, 222: 5, 232: 5, 242: 5, 32: 5, 42: 5
  Time range: -0.203 – 6.906 s
  Baseline: -0.203 – 0.000 s
  Sampling frequency: 64.00 Hz


In [9]:
# Check if the music perception epochs have been created correctly
beat_epochs1_per

General
  MNE object type: Epochs
  Measurement date: 2015-01-28 at 17:39:57 UTC
  Participant: Unknown
  Experimenter: Unknown
Acquisition
  Total number of events: 60
  Events counts: 11: 5, 111: 5, 121: 5, 131: 5, 141: 5, 21: 5, 211: 5, 221: 5, 231: 5, 241: 5, 31: 5, 41: 5
  Time range: -0.203 – 6.906 s
  Baseline: -0.203 – 0.000 s
  Sampling frequency: 64.00 Hz


In [10]:
# Add all the X and y data for music imagery together
X1 = beat_epochs1.get_data()
y1 = beat_epochs1.events[:, 2]
X4 = beat_epochs4.get_data()
y4 = beat_epochs4.events[:, 2]
X6 = beat_epochs6.get_data()
y6 = beat_epochs6.events[:, 2]
X7 = beat_epochs7.get_data()
y7 = beat_epochs7.events[:, 2]
X9 = beat_epochs9.get_data()
y9 = beat_epochs9.events[:, 2]
X11 = beat_epochs11.get_data()
y11 = beat_epochs11.events[:, 2]
X12 = beat_epochs12.get_data()
y12 = beat_epochs12.events[:, 2]
X13 = beat_epochs13.get_data()
y13 = beat_epochs13.events[:, 2]
X14 = beat_epochs14.get_data()
y14 = beat_epochs14.events[:, 2]

X = np.concatenate((X1, X4, X6, X7, X9, X11, X12, X13, X14), axis=0)
y = np.concatenate((y1, y4, y6, y7, y9, y11, y12, y13, y14), axis=0)

print(X.shape)
print(y.shape)

(540, 64, 456)
(540,)
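
The shape `(540, 64, 456)` can be checked by hand from the epoch settings. The sketch below assumes that the epoch boundaries are rounded to the nearest sample index, which reproduces the time range of -0.203 to 6.906 s reported for the epochs; the variable names are illustrative only.

```python
sfreq = 64.0            # sampling frequency reported for the epochs
tmin, tmax = -0.2, 6.9  # epoch limits used when creating the epochs

# Assumption: epoch boundaries are rounded to the nearest sample index.
first = int(round(tmin * sfreq))
last = int(round(tmax * sfreq))
n_times = last - first + 1

n_subjects, trials_per_subject, n_channels = 9, 60, 64
print(first / sfreq, last / sfreq)  # epoch limits in seconds after rounding
print((n_subjects * trials_per_subject, n_channels, n_times))
```

Nine recordings with 60 trials each give the 540 trials on the first axis, and the rounded limits give 456 samples per trial.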


In [11]:
# Add all the X and y data for music perception together
X1_per = beat_epochs1_per.get_data()
y1_per = beat_epochs1_per.events[:, 2]
X4_per = beat_epochs4_per.get_data()
y4_per = beat_epochs4_per.events[:, 2]
X6_per = beat_epochs6_per.get_data()
y6_per = beat_epochs6_per.events[:, 2]
X7_per = beat_epochs7_per.get_data()
y7_per = beat_epochs7_per.events[:, 2]
X9_per = beat_epochs9_per.get_data()
y9_per = beat_epochs9_per.events[:, 2]
X11_per = beat_epochs11_per.get_data()
y11_per = beat_epochs11_per.events[:, 2]
X12_per = beat_epochs12_per.get_data()
y12_per = beat_epochs12_per.events[:, 2]
X13_per = beat_epochs13_per.get_data()
y13_per = beat_epochs13_per.events[:, 2]
X14_per = beat_epochs14_per.get_data()
y14_per = beat_epochs14_per.events[:, 2]

X_per = np.concatenate((X1_per, X4_per, X6_per, X7_per, X9_per, X11_per, X12_per, X13_per, X14_per), axis=0)
y_per = np.concatenate((y1_per, y4_per, y6_per, y7_per, y9_per, y11_per, y12_per, y13_per, y14_per), axis=0)

print(X_per.shape)
print(y_per.shape)

(540, 64, 456)
(540,)


In [12]:
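# Scale each channel of each imagery trial to the range [-1, 1] along the time axis.
# Note: the mean subtraction below is cancelled by the min-max step, so the output
# spans [-1, 1] but is not zero-mean in general.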
mean = np.mean(X, axis=2, keepdims=True)
zero_mean_data = X - mean

min_val = np.min(zero_mean_data, axis=2, keepdims=True)
max_val = np.max(zero_mean_data, axis=2, keepdims=True)

range_val = max_val - min_val
range_val[range_val == 0] = 1

norm_X = 2 * (zero_mean_data - min_val) / range_val - 1

In [13]:
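# Scale each channel of each perception trial to the range [-1, 1] along the time axis.
# Note: the mean subtraction below is cancelled by the min-max step, so the output
# spans [-1, 1] but is not zero-mean in general.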
mean = np.mean(X_per, axis=2, keepdims=True)
zero_mean_data = X_per - mean

min_val = np.min(zero_mean_data, axis=2, keepdims=True)
max_val = np.max(zero_mean_data, axis=2, keepdims=True)

range_val = max_val - min_val
range_val[range_val == 0] = 1

norm_X_per = 2 * (zero_mean_data - min_val) / range_val - 1
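
The two normalization cells apply the same operation to different inputs, so they could share one function. The sketch below is a minimal version on synthetic data; `minmax_scale_trials` is a hypothetical helper name. It also shows that subtracting the per-channel mean beforehand does not change the result, because min-max scaling is invariant to a constant shift.

```python
import numpy as np

def minmax_scale_trials(x):
    """Scale every (trial, channel) time series to the range [-1, 1]."""
    lo = x.min(axis=2, keepdims=True)
    hi = x.max(axis=2, keepdims=True)
    span = hi - lo
    span[span == 0] = 1  # flat channels: avoid division by zero
    return 2 * (x - lo) / span - 1

rng = np.random.default_rng(0)
demo = rng.normal(loc=5.0, scale=2.0, size=(3, 4, 50))  # synthetic (trial, channel, time)

scaled = minmax_scale_trials(demo)
# Subtracting the per-channel mean first gives the same result.
scaled_centered = minmax_scale_trials(demo - demo.mean(axis=2, keepdims=True))

print(scaled.min(), scaled.max())
print(np.allclose(scaled, scaled_centered))
```

If zero-mean data is actually required, a different scheme (for example dividing the centered data by its maximum absolute value) would be needed; that would be a change to the method, not a description of it.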

In [14]:
# Encode labels
le = LabelEncoder()
y_encoded = le.fit_transform(y)

# Split the data (80% train / 20% test, stratified by stimulus label)
X_train_img, X_test_img, y_train_img, y_test_img = train_test_split(norm_X, y_encoded, test_size=0.2, random_state=42, stratify=y_encoded)

[[-0.09926307  0.15910763 -0.00437945 ... -0.12787348  0.07994161
  -0.07088793]
 [ 0.09904717  0.43484837  0.31625373 ...  0.05276527  0.10450267
   0.20693644]
 [-0.01227497  0.39773527  0.1143693  ...  0.05873385  0.52895528
   0.22092594]
 ...
 [ 0.88369485  0.21168037  0.63429255 ...  0.36182016  0.43158993
   0.1406986 ]
 [ 0.75055187  0.44282568  0.45782235 ...  0.34331597  0.42493087
   0.09597021]
 [ 0.56440037 -0.25077971  0.14689489 ...  0.23180674  0.29671478
   0.2026068 ]]
[122  32 232 132 142 112  12  22  42 212 222 242 222  42 112  22 212 132
  32 232 122 242  12 142 142 222 232 122 212  42 112  12 132  32  22 242
  42 212  22 122 112 232 222 242  12 132 142  32 132 122 242  32 222 142
 232 112  42  12  22 212 122  32 232 132 142 112  12  22  42 212 222 242
 222  42 112  22 212 132  32 232 122 242  12 142 142 222 232 122 212  42
 112  12 132  32  22 242  42 212  22 122 112 232 222 242  12 132 142  32
 132 122 242  32 222 142 232 112  42  12  22 212 122  32 232 132 142 1

In [15]:
# Encode labels
le = LabelEncoder()
y_encoded_per = le.fit_transform(y_per)

# Split the data (80% train / 20% test, stratified by stimulus label)
X_train_per, X_test_per, y_train_per, y_test_per = train_test_split(norm_X_per, y_encoded_per, test_size=0.2, random_state=42, stratify=y_encoded_per)
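
To see what the stratified split does to the class balance, the sketch below repeats the same `train_test_split` call on synthetic labels. It assumes 12 classes with 45 trials each (540 trials in total); the placeholder features carry no information and only stand in for the EEG array.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic stand-in: 12 classes with 45 trials each (assumed balance).
labels = np.repeat(np.arange(12), 45)
features = np.zeros((labels.size, 1))  # placeholder features

_, _, y_tr, y_te = train_test_split(
    features, labels, test_size=0.2, random_state=42, stratify=labels
)

train_counts = np.bincount(y_tr, minlength=12)
test_counts = np.bincount(y_te, minlength=12)
print(train_counts)  # trials per class in the training set
print(test_counts)   # trials per class in the test set
```

Under this assumption every class contributes the same number of trials to each split, so accuracy can be compared directly against the chance level of a balanced problem.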

In [16]:
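# Create labels for the group condition.
# LabelEncoder numbers the 12 stimulus codes in ascending order, so indices 0-3, 4-7 and 8-11
# correspond to event codes 1x-4x, 11x-14x and 21x-24x respectively.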
group_y_train_img = []
group_y_test_img = []
for i in y_train_img:
    if i < 4:
        group_y_train_img.append(0)
    elif i < 8:
        group_y_train_img.append(1)
    else:
        group_y_train_img.append(2)

for i in y_test_img:
    if i < 4:
        group_y_test_img.append(0)
    elif i < 8:
        group_y_test_img.append(1)
    else:
        group_y_test_img.append(2)

group_y_train_per = []
group_y_test_per = []
for i in y_train_per:
    if i < 4:
        group_y_train_per.append(0)
    elif i < 8:
        group_y_train_per.append(1)
    else:
        group_y_train_per.append(2)

for i in y_test_per:
    if i < 4:
        group_y_test_per.append(0)
    elif i < 8:
        group_y_test_per.append(1)
    else:
        group_y_test_per.append(2)

In [17]:
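# Create labels for the meter condition.
# Indices 0, 1, 4, 5, 8, 9 (the first two stimuli of each group) get label 0;
# the remaining indices get label 1.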
meter_y_train_img = []
meter_y_test_img = []
for i in y_train_img:
  if i in [0, 1, 4, 5, 8, 9]:
    meter_y_train_img.append(0)
  elif i in [2, 3, 6, 7, 10, 11]:
    meter_y_train_img.append(1)

for i in y_test_img:
  if i in [0, 1, 4, 5, 8, 9]:
    meter_y_test_img.append(0)
  elif i in [2, 3, 6, 7, 10, 11]:
    meter_y_test_img.append(1)

meter_y_train_per = []
meter_y_test_per = []
for i in y_train_per:
  if i in [0, 1, 4, 5, 8, 9]:
    meter_y_train_per.append(0)
  elif i in [2, 3, 6, 7, 10, 11]:
    meter_y_train_per.append(1)

for i in y_test_per:
  if i in [0, 1, 4, 5, 8, 9]:
    meter_y_test_per.append(0)
  elif i in [2, 3, 6, 7, 10, 11]:
    meter_y_test_per.append(1)
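
The group and meter labels built in the two cells above follow a regular pattern over the encoded indices 0-11, so they can be written with integer arithmetic. This is a sketch of an equivalent formulation, not the code used for the reported runs; `group_labels` and `meter_labels` are hypothetical helper names, and they return NumPy arrays where the loops build Python lists.

```python
import numpy as np

def group_labels(y):
    """Map stimulus indices 0-11 to three groups of four consecutive indices."""
    return np.asarray(y) // 4

def meter_labels(y):
    """Map indices [0, 1, 4, 5, 8, 9] to 0 and [2, 3, 6, 7, 10, 11] to 1."""
    return (np.asarray(y) % 4) // 2

y_demo = np.arange(12)
print(group_labels(y_demo))
print(meter_labels(y_demo))
```

For example, `group_labels(y_train_img)` would replace the first loop. The equivalence only holds while the encoded labels stay in the range 0-11.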

# Conformer model

In [28]:
"""
EEG Conformer

Convolutional Transformer for EEG decoding

Couple CNN and Transformer in a concise manner with amazing results
"""

import argparse
import os
gpus = [0]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
import numpy as np
import math
import glob
import random
import itertools
import datetime
import time
import sys
import scipy.io

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.autograd as autograd
from torch import Tensor
from torch.autograd import Variable
from torch.utils.data import DataLoader, Dataset
from torch.backends import cudnn
from torchsummary import summary

import torchvision.transforms as transforms
from torchvision.models import vgg19
from torchvision.transforms import Compose, Resize, ToTensor
from torchvision.utils import save_image, make_grid

import matplotlib.pyplot as plt
from PIL import Image
from sklearn.decomposition import PCA
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
# from common_spatial_pattern import csp
# from torch.utils.tensorboard import SummaryWriter

# Make cuDNN deterministic for reproducibility
cudnn.benchmark = False
cudnn.deterministic = True

# Convolution module:
# use convolutions to capture local features instead of position embeddings.
class PatchEmbedding(nn.Module):
    def __init__(self, num_channels, emb_size=40):
        # self.patch_size = patch_size
        super().__init__()

        self.shallownet = nn.Sequential(
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.Conv2d(40, 40, (num_channels, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.AvgPool2d((1, 75), (1, 15)),  # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
            nn.Dropout(0.5),
        )

        self.projection = nn.Sequential(
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),  # transpose, conv could enhance fitting ability slightly
            Rearrange('b e (h) (w) -> b (h w) e'),
        )


    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.shallownet(x)
        x = self.projection(x)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            # masked_fill is not in-place, so the result must be assigned back
            energy = energy.masked_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )


class GELU(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))


class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))


class TransformerEncoder(nn.Sequential):
    def __init__(self, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])


class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size, n_classes, size):
        super().__init__()

        # global average pooling head (defined here but not used in forward, which uses self.fc)
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(size, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return x, out

class Conformer(nn.Sequential):
    def __init__(self, num_classes, num_channels, size, emb_size=40, depth=6):
        super().__init__(
            PatchEmbedding(emb_size=emb_size, num_channels=num_channels),
            TransformerEncoder(depth=depth, emb_size=emb_size),
            ClassificationHead(emb_size, num_classes, size)
        )


class ExP():
    def __init__(self, nsub, nclasses, nchannels, size):
        super(ExP, self).__init__()
        self.batch_size = 72
        self.n_epochs = 2000
        self.c_dim = 12
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.dimension = (190, 50)
        self.nSub = nsub
        self.num_classes = nclasses
        self.num_channels = nchannels
        self.size = size

        self.start_epoch = 0


        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = Conformer(self.num_classes, self.num_channels, self.size).cuda()
        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
        self.model = self.model.cuda()
        # summary(self.model, (1, 64, 456))

    def get_source_data(self, X_train, y_train, X_test, y_test, enc):
        # NOTE: check whether a validation set is needed,
        # and whether the data segmentation matches the methods being compared

        # train data
        #self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub)
        self.train_data = X_train
        self.train_label = y_train

        #self.train_data = np.transpose(self.train_data, (2, 1, 0))
        # raw EEG (enc is False) still needs the conv-channel axis: (trial, 1, electrode, time)
        if not enc:
          self.train_data = np.expand_dims(self.train_data, axis=1)
        self.train_label = np.transpose(self.train_label)

        self.allData = self.train_data
        self.allLabel = self.train_label #[0]

        shuffle_num = np.random.permutation(len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]
        self.allLabel = self.allLabel[shuffle_num]

        # test data
        #self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub)
        self.test_data = X_test
        self.test_label = y_test

        #self.test_data = np.transpose(self.test_data, (2, 1, 0))
        # raw EEG (enc is False) still needs the conv-channel axis: (trial, 1, electrode, time)
        if not enc:
          self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label#[0]

        # data shape: (trial, conv channel, electrode channel, time samples)
        print(self.allData.shape)
        return self.allData, self.allLabel, self.testData, self.testLabel


    def train(self, X_train, y_train, X_test, y_test, enc):

        img, label, test_data, test_label = self.get_source_data(X_train, y_train, X_test, y_test, enc)

        img = torch.from_numpy(img)
        label = torch.from_numpy(label)

        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_data = Variable(test_data.type(self.Tensor))
        test_label = Variable(test_label.type(self.LongTensor))

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        # Train the Conformer model
        total_step = len(self.dataloader)
        curr_lr = self.lr

        for e in range(self.n_epochs):
            # in_epoch = time.time()
            self.model.train()
            for i, (img, label) in enumerate(self.dataloader):

                img = Variable(img.cuda().type(self.Tensor))
                label = Variable(label.cuda().type(self.LongTensor))

                tok, outputs = self.model(img)

                loss = self.criterion_cls(outputs, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()


            # out_epoch = time.time()


            # test process (runs after every epoch)
            if (e + 1) % 1 == 0:
                self.model.eval()
                # no gradients are needed for evaluation
                with torch.no_grad():
                    Tok, Cls = self.model(test_data)
                    loss_test = self.criterion_cls(Cls, test_label)
                y_pred = torch.max(Cls, 1)[1]
                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
                # training accuracy is measured on the last mini-batch of the epoch only
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))

                print('Epoch:', e,
                      '  Train loss: %.6f' % loss.detach().cpu().numpy(),
                      '  Test loss: %.6f' % loss_test.detach().cpu().numpy(),
                      '  Train accuracy %.6f' % train_acc,
                      '  Test accuracy is %.6f' % acc)

                num = num + 1
                averAcc = averAcc + acc
                if acc > bestAcc:
                    bestAcc = acc
                    Y_true = test_label
                    Y_pred = y_pred


        averAcc = averAcc / num
        print('The average accuracy is:', averAcc)
        print('The best accuracy is:', bestAcc)

        return bestAcc, averAcc, Y_true, Y_pred
        # writer.close()
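
The `size` argument of `ClassificationHead` has to equal the number of features left after `PatchEmbedding`, because `forward` flattens the token sequence before the fully connected layers. The sketch below derives that number from the layer settings in the model code above, using the usual output-length formula for unpadded convolution and pooling; `conformer_flat_size` is a hypothetical helper, not part of the original Conformer code.

```python
def conformer_flat_size(n_times, emb_size=40, conv_kernel=25,
                        pool_kernel=75, pool_stride=15):
    """Number of features entering the fully connected classification head."""
    # temporal convolution with stride 1 and no padding
    after_conv = n_times - conv_kernel + 1
    # average pooling along time produces the token sequence
    n_tokens = (after_conv - pool_kernel) // pool_stride + 1
    # the spatial convolution collapses the electrode axis to 1,
    # so each token carries emb_size features
    return n_tokens * emb_size

print(conformer_flat_size(456))  # 24 tokens * 40 features = 960
```

For epochs of 456 samples this gives 960, which is the value to pass as `size`. Inputs with a different number of time samples need a different `size`.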

# Classification of the OpenMIIR data

## Stimulus (12) Classification

### Music Imagery

In [34]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)


print('Music Imagery (stimulus)')
exp = ExP(1, 12, 64, 960)

enc = False
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_img, y_train_img, X_test_img, y_test_img, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Imagery (stimulus)
(432, 1, 64, 456)
Epoch: 0   Train loss: 2.560309   Test loss: 2.491183   Train accuracy 0.138889   Test accuracy is 0.092593
Epoch: 1   Train loss: 2.476555   Test loss: 2.488458   Train accuracy 0.138889   Test accuracy is 0.092593
Epoch: 2   Train loss: 2.580481   Test loss: 2.500107   Train accuracy 0.069444   Test accuracy is 0.055556
Epoch: 3   Train loss: 2.489622   Test loss: 2.502063   Train accuracy 0.152778   Test accuracy is 0.046296
Epoch: 4   Train loss: 2.461510   Test loss: 2.503032   Train accuracy 0.097222   Test accuracy is 0.037037
Epoch: 5   Train loss: 2.488692   Test loss: 2.496148   Train accuracy 0.125000   Test accuracy is 0.092593
Epoch: 6   Train loss: 2.456931   Test loss: 2.493709   Train accuracy 0.111111   Test accuracy is 0.111111
Epoch: 7   Train loss: 2.548489   Test loss: 2.495150   Train accuracy 0.055556   Test accuracy is 0.120370
Epoch: 8   Train loss: 2.497465   Test loss: 2.521341   Train accuracy 0.125000   
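
The log above follows the fixed format printed by `train`, so it can be parsed back into numbers, for example to plot learning curves. The sketch below is only an illustration: `LOG_PATTERN` and `parse_log_line` are hypothetical names that do not appear in this notebook, and incomplete lines (such as the truncated last line above) are skipped rather than guessed.

```python
import re

# Pattern mirroring the print statement in `train`
# (hypothetical helper, not part of the original notebook).
LOG_PATTERN = re.compile(
    r"Epoch:\s*(\d+)\s+Train loss:\s*([\d.]+)\s+Test loss:\s*([\d.]+)"
    r"\s+Train accuracy\s*([\d.]+)\s+Test accuracy is\s*([\d.]+)"
)


def parse_log_line(line):
    """Return the metrics of one complete log line as a dict, or None if the line is incomplete."""
    match = LOG_PATTERN.search(line)
    if match is None:
        return None
    epoch, train_loss, test_loss, train_acc, test_acc = match.groups()
    return {
        "epoch": int(epoch),
        "train_loss": float(train_loss),
        "test_loss": float(test_loss),
        "train_acc": float(train_acc),
        "test_acc": float(test_acc),
    }


# Two lines copied from the output above: one complete, one truncated.
log = [
    "Epoch: 0   Train loss: 2.560309   Test loss: 2.491183   Train accuracy 0.138889   Test accuracy is 0.092593",
    "Epoch: 8   Train loss: 2.497465   Test loss: 2.521341   Train accuracy 0.125000   ",
]
records = [r for r in map(parse_log_line, log) if r is not None]
print(records)
```

The resulting list of dicts can be passed to a plotting or table library of choice.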

### Music Perception

In [35]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)

print('Music Perception (stimulus)')
exp = ExP(1, 12, 64, 960)

enc = False
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_per, y_train_per, X_test_per, y_test_per, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Perception (stimulus)
(432, 1, 64, 456)
Epoch: 0   Train loss: 2.560942   Test loss: 2.498019   Train accuracy 0.097222   Test accuracy is 0.074074
Epoch: 1   Train loss: 2.506402   Test loss: 2.498619   Train accuracy 0.152778   Test accuracy is 0.074074
Epoch: 2   Train loss: 2.560747   Test loss: 2.507602   Train accuracy 0.055556   Test accuracy is 0.074074
Epoch: 3   Train loss: 2.478660   Test loss: 2.524446   Train accuracy 0.083333   Test accuracy is 0.055556
Epoch: 4   Train loss: 2.523980   Test loss: 2.541183   Train accuracy 0.111111   Test accuracy is 0.046296
Epoch: 5   Train loss: 2.540038   Test loss: 2.561084   Train accuracy 0.111111   Test accuracy is 0.055556
Epoch: 6   Train loss: 2.455421   Test loss: 2.573547   Train accuracy 0.111111   Test accuracy is 0.046296
Epoch: 7   Train loss: 2.472241   Test loss: 2.594215   Train accuracy 0.083333   Test accuracy is 0.037037
Epoch: 8   Train loss: 2.452838   Test loss: 2.612417   Train accuracy 0.138889

## Group (3) Classification

### Music Imagery

In [36]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)

print('Music Imagery (group)')
exp = ExP(1, 3, 64, 960)

enc = False
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_img, group_y_train_img, X_test_img, group_y_test_img, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Imagery (group)
(432, 1, 64, 456)
Epoch: 0   Train loss: 1.105184   Test loss: 1.123243   Train accuracy 0.402778   Test accuracy is 0.314815
Epoch: 1   Train loss: 1.160038   Test loss: 1.120264   Train accuracy 0.347222   Test accuracy is 0.324074
Epoch: 2   Train loss: 1.156032   Test loss: 1.127722   Train accuracy 0.291667   Test accuracy is 0.324074
Epoch: 3   Train loss: 1.092264   Test loss: 1.139009   Train accuracy 0.416667   Test accuracy is 0.259259
Epoch: 4   Train loss: 1.150053   Test loss: 1.165143   Train accuracy 0.291667   Test accuracy is 0.203704
Epoch: 5   Train loss: 1.147364   Test loss: 1.171531   Train accuracy 0.347222   Test accuracy is 0.222222
Epoch: 6   Train loss: 1.151778   Test loss: 1.179364   Train accuracy 0.333333   Test accuracy is 0.277778
Epoch: 7   Train loss: 1.099300   Test loss: 1.215855   Train accuracy 0.375000   Test accuracy is 0.212963
Epoch: 8   Train loss: 1.088334   Test loss: 1.234483   Train accuracy 0.458333   Tes

### Music Perception

In [37]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)


print('Music Perception (group)')
exp = ExP(1, 3, 64, 960)

enc = False
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_per, group_y_train_per, X_test_per, group_y_test_per, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Perception (group)
(432, 1, 64, 456)
Epoch: 0   Train loss: 1.086698   Test loss: 1.109043   Train accuracy 0.472222   Test accuracy is 0.342593
Epoch: 1   Train loss: 1.150658   Test loss: 1.107053   Train accuracy 0.388889   Test accuracy is 0.314815
Epoch: 2   Train loss: 1.120194   Test loss: 1.116007   Train accuracy 0.388889   Test accuracy is 0.277778
Epoch: 3   Train loss: 1.073728   Test loss: 1.134446   Train accuracy 0.416667   Test accuracy is 0.314815
Epoch: 4   Train loss: 1.123021   Test loss: 1.125493   Train accuracy 0.333333   Test accuracy is 0.296296
Epoch: 5   Train loss: 1.080178   Test loss: 1.125783   Train accuracy 0.402778   Test accuracy is 0.370370
Epoch: 6   Train loss: 1.124660   Test loss: 1.143194   Train accuracy 0.388889   Test accuracy is 0.324074
Epoch: 7   Train loss: 1.076769   Test loss: 1.155498   Train accuracy 0.388889   Test accuracy is 0.333333
Epoch: 8   Train loss: 1.101189   Test loss: 1.177519   Train accuracy 0.458333   

## Meter (2) Classification

### Music Imagery

In [38]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)


print('Music Imagery (meter)')
exp = ExP(1, 2, 64, 960)

enc = False
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_img, meter_y_train_img, X_test_img, meter_y_test_img, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Imagery (meter)
(432, 1, 64, 456)
Epoch: 0   Train loss: 0.755551   Test loss: 0.695480   Train accuracy 0.375000   Test accuracy is 0.490741
Epoch: 1   Train loss: 0.715464   Test loss: 0.696742   Train accuracy 0.527778   Test accuracy is 0.490741
Epoch: 2   Train loss: 0.734071   Test loss: 0.699576   Train accuracy 0.458333   Test accuracy is 0.490741
Epoch: 3   Train loss: 0.686742   Test loss: 0.700147   Train accuracy 0.541667   Test accuracy is 0.500000
Epoch: 4   Train loss: 0.681268   Test loss: 0.689956   Train accuracy 0.597222   Test accuracy is 0.555556
Epoch: 5   Train loss: 0.699450   Test loss: 0.688218   Train accuracy 0.458333   Test accuracy is 0.527778
Epoch: 6   Train loss: 0.748968   Test loss: 0.689637   Train accuracy 0.402778   Test accuracy is 0.546296
Epoch: 7   Train loss: 0.704064   Test loss: 0.691787   Train accuracy 0.597222   Test accuracy is 0.527778
Epoch: 8   Train loss: 0.742487   Test loss: 0.714148   Train accuracy 0.472222   Tes

### Music Perception

In [39]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)


print('Music Perception (meter)')
exp = ExP(1, 2, 64, 960)

enc = False
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_per, meter_y_train_per, X_test_per, meter_y_test_per, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Perception (meter)
(432, 1, 64, 456)
Epoch: 0   Train loss: 0.728528   Test loss: 0.697160   Train accuracy 0.472222   Test accuracy is 0.500000
Epoch: 1   Train loss: 0.715868   Test loss: 0.705193   Train accuracy 0.458333   Test accuracy is 0.500000
Epoch: 2   Train loss: 0.744452   Test loss: 0.723211   Train accuracy 0.513889   Test accuracy is 0.490741
Epoch: 3   Train loss: 0.681076   Test loss: 0.727151   Train accuracy 0.486111   Test accuracy is 0.500000
Epoch: 4   Train loss: 0.656404   Test loss: 0.719911   Train accuracy 0.625000   Test accuracy is 0.509259
Epoch: 5   Train loss: 0.694981   Test loss: 0.736910   Train accuracy 0.541667   Test accuracy is 0.416667
Epoch: 6   Train loss: 0.679088   Test loss: 0.761200   Train accuracy 0.555556   Test accuracy is 0.481481
Epoch: 7   Train loss: 0.674219   Test loss: 0.922719   Train accuracy 0.597222   Test accuracy is 0.462963
Epoch: 8   Train loss: 0.702339   Test loss: 0.994185   Train accuracy 0.611111   
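
The test accuracies logged in this section can be read against chance level. The sketch below makes two assumptions that are not stated explicitly in the notebook: the classes are balanced (the stratified split suggests this), and the test set holds 108 trials (432 training trials at `test_size=0.2`, consistent with logged values such as 0.092593 = 10/108). It computes the chance accuracy and the smallest number of correct predictions that a one-sided binomial test at alpha = 0.05 would treat as above chance. The function names are hypothetical.

```python
from math import comb


def binomial_tail(n, k, p):
    """P(X >= k) for X ~ Binomial(n, p)."""
    return sum(comb(n, i) * p**i * (1 - p) ** (n - i) for i in range(k, n + 1))


def min_correct_above_chance(n_trials, n_classes, alpha=0.05):
    """Smallest number of correct predictions whose one-sided p-value is below alpha."""
    p_chance = 1.0 / n_classes  # assumes balanced classes
    for k in range(n_trials + 1):
        if binomial_tail(n_trials, k, p_chance) < alpha:
            return k
    return None


n_test = 108  # assumed test set size, see the lead-in
for n_classes in (12, 3, 2):  # stimulus, group, and meter classification
    k = min_correct_above_chance(n_test, n_classes)
    print(f"{n_classes} classes: chance = {1 / n_classes:.3f}, "
          f"needed = {k}/{n_test} = {k / n_test:.3f}")
```

This threshold applies to a single, pre-selected evaluation. The best accuracy returned by `train` is a maximum over all epochs on the test set, so it is optimistically biased and should be compared against chance with extra caution.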

# Classification of the encoded hvEEGNet data


## Load and process the hvEEGNet encoded data

In [44]:
# Features for music imagery
encoded_data_img = np.load("/content/c2_latent_features.npy")
X_train_img_enc, X_test_img_enc, y_train_img_enc, y_test_img_enc = train_test_split(encoded_data_img, y_encoded, test_size=0.2, random_state=42, stratify = y_encoded)
X_train_img_enc = X_train_img_enc.transpose(0, 2, 1, 3)
X_test_img_enc = X_test_img_enc.transpose(0, 2, 1, 3)
# Features for music perception
encoded_data_per = np.load("/content/c1_latent_features.npy")
X_train_per_enc , X_test_per_enc , y_train_per_enc , y_test_per_enc  = train_test_split(encoded_data_per, y_encoded, test_size=0.2, random_state=42, stratify = y_encoded)
X_train_per_enc = X_train_per_enc.transpose(0, 2, 1, 3)
X_test_per_enc = X_test_per_enc.transpose(0, 2, 1, 3)

In [45]:
print(X_train_img_enc.shape)

(432, 1, 16, 45)


In [46]:
# Create labels for the group condition
group_y_train_img_enc = []
group_y_test_img_enc = []
for i in y_train_img_enc:
    if i < 4:
        group_y_train_img_enc.append(0)
    elif i < 8:
        group_y_train_img_enc.append(1)
    else:
        group_y_train_img_enc.append(2)

for i in y_test_img_enc:
    if i < 4:
        group_y_test_img_enc.append(0)
    elif i < 8:
        group_y_test_img_enc.append(1)
    else:
        group_y_test_img_enc.append(2)

group_y_train_per_enc = []
group_y_test_per_enc = []
for i in y_train_per_enc:
    if i < 4:
        group_y_train_per_enc.append(0)
    elif i < 8:
        group_y_train_per_enc.append(1)
    else:
        group_y_train_per_enc.append(2)

for i in y_test_per_enc:
    if i < 4:
        group_y_test_per_enc.append(0)
    elif i < 8:
        group_y_test_per_enc.append(1)
    else:
        group_y_test_per_enc.append(2)
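
The four loops above apply the same mapping: stimulus labels 0 to 3 become group 0, labels 4 to 7 become group 1, and all remaining labels become group 2. A small helper could express this once. The sketch below is hypothetical (`to_group_labels` is not defined in this notebook) and assumes integer stimulus labels in the range 0 to 11.

```python
def to_group_labels(stimulus_labels):
    """Map stimulus labels to the three groups used above: <4 -> 0, <8 -> 1, otherwise 2."""
    groups = []
    for label in stimulus_labels:
        if label < 4:
            groups.append(0)
        elif label < 8:
            groups.append(1)
        else:
            groups.append(2)
    return groups


# Example with all 12 stimulus labels.
print(to_group_labels(range(12)))  # [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]
```

With such a helper, each label list becomes a single call, for example `group_y_train_img_enc = to_group_labels(y_train_img_enc)`.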

In [60]:
# Create labels for the meter condition
meter_y_train_img_enc = []
meter_y_test_img_enc = []
for i in y_train_img_enc:
  if i in [0, 1, 4, 5, 8, 9]:
    meter_y_train_img_enc.append(0)
  elif i in [2, 3, 6, 7, 10, 11]:
    meter_y_train_img_enc.append(1)

for i in y_test_img_enc:
  if i in [0, 1, 4, 5, 8, 9]:
    meter_y_test_img_enc.append(0)
  elif i in [2, 3, 6, 7, 10, 11]:
    meter_y_test_img_enc.append(1)

meter_y_train_per_enc = []
meter_y_test_per_enc = []
for i in y_train_per_enc:
  if i in [0, 1, 4, 5, 8, 9]:
    meter_y_train_per_enc.append(0)
  elif i in [2, 3, 6, 7, 10, 11]:
    meter_y_train_per_enc.append(1)

for i in y_test_per_enc:
  if i in [0, 1, 4, 5, 8, 9]:
    meter_y_test_per_enc.append(0)
  elif i in [2, 3, 6, 7, 10, 11]:
    meter_y_test_per_enc.append(1)
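
The meter loops above also share one mapping: stimulus labels 0, 1, 4, 5, 8, 9 become meter class 0 and labels 2, 3, 6, 7, 10, 11 become meter class 1. Note that the loops silently skip any label outside these two lists, which would leave the label list shorter than the data. The sketch below is a hypothetical alternative (`to_meter_labels`, `METER_0`, and `METER_1` are not defined in this notebook) that raises an error in that case instead.

```python
# Stimulus labels per meter class, copied from the loops above.
METER_0 = {0, 1, 4, 5, 8, 9}
METER_1 = {2, 3, 6, 7, 10, 11}


def to_meter_labels(stimulus_labels):
    """Map stimulus labels to the two meter classes; fail loudly on unknown labels."""
    meters = []
    for label in stimulus_labels:
        if label in METER_0:
            meters.append(0)
        elif label in METER_1:
            meters.append(1)
        else:
            raise ValueError(f"Unexpected stimulus label: {label}")
    return meters


# Example with all 12 stimulus labels.
print(to_meter_labels(range(12)))  # [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
```

Each label list then becomes a single call, for example `meter_y_train_img_enc = to_meter_labels(y_train_img_enc)`.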

## Conformer for the encoded data

In [53]:
# use conv to capture local features, instead of position embedding.
class PatchEmbedding(nn.Module):
    def __init__(self, num_channels, emb_size=40):
        # self.patch_size = patch_size
        super().__init__()

        self.shallownet = nn.Sequential(
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.Conv2d(40, 40, (num_channels, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.AvgPool2d((1, 15), (1, 5)),  # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
            nn.Dropout(0.5),
        )

        self.projection = nn.Sequential(
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),  # transpose, conv could enhance fitting ability slightly
            Rearrange('b e (h) (w) -> b (h w) e'),
        )


    def forward(self, x: Tensor) -> Tensor:
        b, _, _, _ = x.shape
        x = self.shallownet(x)
        x = self.projection(x)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            # masked_fill is not in-place, so the result must be assigned back
            energy = energy.masked_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )


class GELU(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))


class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))


class TransformerEncoder(nn.Sequential):
    def __init__(self, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])


class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size, n_classes, size):
        super().__init__()

        # global average pooling
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(size, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(32, n_classes)
        )

    def forward(self, x):
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return x, out

class Conformer(nn.Sequential):
    def __init__(self, num_classes, num_channels, size, emb_size=40, depth=6):
        super().__init__(
            PatchEmbedding(emb_size=emb_size, num_channels=num_channels),
            TransformerEncoder(depth=depth, emb_size=emb_size),
            ClassificationHead(emb_size, num_classes, size)
        )


class ExP():
    def __init__(self, nsub, nclasses, nchannels, size):
        super(ExP, self).__init__()
        self.batch_size = 72
        self.n_epochs = 2000
        self.c_dim = 12
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.dimension = (190, 50)
        self.nSub = nsub
        self.num_classes = nclasses
        self.num_channels = nchannels
        self.size = size

        self.start_epoch = 0


        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()

        self.model = Conformer(self.num_classes, self.num_channels, self.size).cuda()
        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
        self.model = self.model.cuda()
        # summary(self.model, (1, 64, 456))

    def get_source_data(self, X_train, y_train, X_test, y_test, enc):
        # NOTE: check whether a validation set is needed, and whether the data
        # segmentation matches the one used by the methods being compared

        # train data
        #self.total_data = scipy.io.loadmat(self.root + 'A0%dT.mat' % self.nSub)
        self.train_data = X_train
        self.train_label = y_train

        #self.train_data = np.transpose(self.train_data, (2, 1, 0))
        # Raw data lacks the conv-channel axis; the encoded data already has it
        if not enc:
            self.train_data = np.expand_dims(self.train_data, axis=1)
        self.train_label = np.transpose(self.train_label)

        self.allData = self.train_data
        self.allLabel = self.train_label #[0]

        shuffle_num = np.random.permutation(len(self.allData))
        self.allData = self.allData[shuffle_num, :, :, :]
        self.allLabel = self.allLabel[shuffle_num]

        # test data
        #self.test_tmp = scipy.io.loadmat(self.root + 'A0%dE.mat' % self.nSub)
        self.test_data = X_test
        self.test_label = y_test

        #self.test_data = np.transpose(self.test_data, (2, 1, 0))
        if not enc:
            self.test_data = np.expand_dims(self.test_data, axis=1)
        self.test_label = np.transpose(self.test_label)

        self.testData = self.test_data
        self.testLabel = self.test_label#[0]

        # data shape: (trial, conv channel, electrode channel, time samples)
        print(self.allData.shape)
        return self.allData, self.allLabel, self.testData, self.testLabel


    def train(self, X_train, y_train, X_test, y_test, enc):

        img, label, test_data, test_label = self.get_source_data(X_train, y_train, X_test, y_test, enc)

        img = torch.from_numpy(img)
        label = torch.from_numpy(label)

        dataset = torch.utils.data.TensorDataset(img, label)
        self.dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        test_data = torch.from_numpy(test_data)
        test_label = torch.from_numpy(test_label)
        test_dataset = torch.utils.data.TensorDataset(test_data, test_label)
        self.test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        test_data = Variable(test_data.type(self.Tensor))
        test_label = Variable(test_label.type(self.LongTensor))

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        # Train the cnn model
        total_step = len(self.dataloader)
        curr_lr = self.lr

        for e in range(self.n_epochs):
            # in_epoch = time.time()
            self.model.train()
            for i, (img, label) in enumerate(self.dataloader):

                img = Variable(img.cuda().type(self.Tensor))
                label = Variable(label.cuda().type(self.LongTensor))

                tok, outputs = self.model(img)

                loss = self.criterion_cls(outputs, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()


            # out_epoch = time.time()


            # test process (runs after every epoch)
            if (e + 1) % 1 == 0:
                self.model.eval()
                # Gradients are not needed for evaluation
                with torch.no_grad():
                    Tok, Cls = self.model(test_data)
                    loss_test = self.criterion_cls(Cls, test_label)

                y_pred = torch.max(Cls, 1)[1]
                acc = float((y_pred == test_label).cpu().numpy().astype(int).sum()) / float(test_label.size(0))
                # Note: train accuracy is computed on the last mini-batch of the epoch only
                train_pred = torch.max(outputs, 1)[1]
                train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))

                print('Epoch:', e,
                      '  Train loss: %.6f' % loss.detach().cpu().numpy(),
                      '  Test loss: %.6f' % loss_test.detach().cpu().numpy(),
                      '  Train accuracy %.6f' % train_acc,
                      '  Test accuracy is %.6f' % acc)

                num = num + 1
                averAcc = averAcc + acc
                if acc > bestAcc:
                    bestAcc = acc
                    Y_true = test_label
                    Y_pred = y_pred


        averAcc = averAcc / num
        print('The average accuracy is:', averAcc)
        print('The best accuracy is:', bestAcc)

        return bestAcc, averAcc, Y_true, Y_pred
        # writer.close()
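
The `size` argument of `ExP` has to match the flattened length of the transformer output that `ClassificationHead.fc` receives, i.e. the number of tokens times `emb_size`. For the `PatchEmbedding` defined above (temporal convolution with kernel 25, then average pooling with kernel 15 and stride 5, both without padding), this follows from the number of time samples. The helper below is a hypothetical sketch for checking that value and is not part of the notebook.

```python
def conformer_fc_size(n_samples, conv_kernel=25, pool_kernel=15, pool_stride=5, emb_size=40):
    """Flattened feature length after PatchEmbedding and the transformer encoder.

    The (num_channels, 1) convolution collapses the electrode axis to 1,
    so only the time axis determines the number of tokens.
    """
    after_conv = n_samples - conv_kernel + 1                   # valid convolution along time
    n_tokens = (after_conv - pool_kernel) // pool_stride + 1   # average pooling, no padding
    return n_tokens * emb_size


# Encoded data has 45 time samples: 45 -> 21 after the conv -> 2 tokens -> 2 * 40.
print(conformer_fc_size(45))  # 80
```

The same arithmetic gives 960 for 456 time samples if the pooling kernel and stride are 75 and 15; those values are an assumption about the raw-data model, whose `PatchEmbedding` is defined earlier in the notebook.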

## Stimulus (12) Classification with encoded data

### Music Imagery

In [55]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)

print('Music Imagery (stimulus)')
exp = ExP(1, 12, 16, 80)

enc = True
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_img_enc, y_train_img_enc, X_test_img_enc, y_test_img_enc, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Imagery (stimulus)
(432, 1, 16, 45)
Epoch: 0   Train loss: 2.580720   Test loss: 2.492373   Train accuracy 0.055556   Test accuracy is 0.083333
Epoch: 1   Train loss: 2.577056   Test loss: 2.490440   Train accuracy 0.041667   Test accuracy is 0.101852
Epoch: 2   Train loss: 2.549945   Test loss: 2.489599   Train accuracy 0.097222   Test accuracy is 0.092593
Epoch: 3   Train loss: 2.554906   Test loss: 2.490283   Train accuracy 0.069444   Test accuracy is 0.064815
Epoch: 4   Train loss: 2.580858   Test loss: 2.489111   Train accuracy 0.097222   Test accuracy is 0.046296
Epoch: 5   Train loss: 2.577107   Test loss: 2.488560   Train accuracy 0.069444   Test accuracy is 0.046296
Epoch: 6   Train loss: 2.519996   Test loss: 2.485822   Train accuracy 0.069444   Test accuracy is 0.046296
Epoch: 7   Train loss: 2.491054   Test loss: 2.486178   Train accuracy 0.180556   Test accuracy is 0.046296
Epoch: 8   Train loss: 2.639588   Test loss: 2.486567   Train accuracy 0.041667   T

### Music Perception

In [56]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)

print('Music Perception (stimulus)')
exp = ExP(1, 12, 16, 80)

enc = True
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_per_enc, y_train_per_enc, X_test_per_enc, y_test_per_enc, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Perception (stimulus)
(432, 1, 16, 45)
Epoch: 0   Train loss: 2.618979   Test loss: 2.498488   Train accuracy 0.027778   Test accuracy is 0.074074
Epoch: 1   Train loss: 2.599030   Test loss: 2.500230   Train accuracy 0.083333   Test accuracy is 0.064815
Epoch: 2   Train loss: 2.491683   Test loss: 2.499096   Train accuracy 0.111111   Test accuracy is 0.074074
Epoch: 3   Train loss: 2.637296   Test loss: 2.498282   Train accuracy 0.069444   Test accuracy is 0.074074
Epoch: 4   Train loss: 2.533016   Test loss: 2.497605   Train accuracy 0.083333   Test accuracy is 0.083333
Epoch: 5   Train loss: 2.576912   Test loss: 2.498473   Train accuracy 0.041667   Test accuracy is 0.074074
Epoch: 6   Train loss: 2.555480   Test loss: 2.496220   Train accuracy 0.055556   Test accuracy is 0.064815
Epoch: 7   Train loss: 2.496372   Test loss: 2.497215   Train accuracy 0.083333   Test accuracy is 0.083333
Epoch: 8   Train loss: 2.586824   Test loss: 2.496435   Train accuracy 0.125000 

## Group (3) Classification with encoded data

### Music Imagery

In [57]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)

print('Music Imagery (group)')
exp = ExP(1, 3, 16, 80)

enc = True
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_img_enc, group_y_train_img_enc, X_test_img_enc, group_y_test_img_enc, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Imagery (group)
(432, 1, 16, 45)
Epoch: 0   Train loss: 1.131751   Test loss: 1.112242   Train accuracy 0.402778   Test accuracy is 0.333333
Epoch: 1   Train loss: 1.183155   Test loss: 1.110631   Train accuracy 0.291667   Test accuracy is 0.333333
Epoch: 2   Train loss: 1.235963   Test loss: 1.108821   Train accuracy 0.305556   Test accuracy is 0.296296
Epoch: 3   Train loss: 1.143146   Test loss: 1.106554   Train accuracy 0.361111   Test accuracy is 0.314815
Epoch: 4   Train loss: 1.168406   Test loss: 1.103312   Train accuracy 0.347222   Test accuracy is 0.287037
Epoch: 5   Train loss: 1.131219   Test loss: 1.103207   Train accuracy 0.361111   Test accuracy is 0.314815
Epoch: 6   Train loss: 1.097847   Test loss: 1.101685   Train accuracy 0.402778   Test accuracy is 0.314815
Epoch: 7   Train loss: 1.233155   Test loss: 1.102625   Train accuracy 0.180556   Test accuracy is 0.277778
Epoch: 8   Train loss: 1.094999   Test loss: 1.103398   Train accuracy 0.402778   Test

### Music Perception

In [58]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)


print('Music Perception (group)')
exp = ExP(1, 3, 16, 80)

enc = True
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_per_enc, group_y_train_per_enc, X_test_per_enc, group_y_test_per_enc, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Perception (group)
(432, 1, 16, 45)
Epoch: 0   Train loss: 1.121874   Test loss: 1.116040   Train accuracy 0.375000   Test accuracy is 0.342593
Epoch: 1   Train loss: 1.137179   Test loss: 1.115800   Train accuracy 0.333333   Test accuracy is 0.342593
Epoch: 2   Train loss: 1.140076   Test loss: 1.110305   Train accuracy 0.388889   Test accuracy is 0.305556
Epoch: 3   Train loss: 1.145172   Test loss: 1.108686   Train accuracy 0.347222   Test accuracy is 0.296296
Epoch: 4   Train loss: 1.162485   Test loss: 1.105791   Train accuracy 0.277778   Test accuracy is 0.305556
Epoch: 5   Train loss: 1.122366   Test loss: 1.103516   Train accuracy 0.333333   Test accuracy is 0.324074
Epoch: 6   Train loss: 1.131126   Test loss: 1.104172   Train accuracy 0.347222   Test accuracy is 0.342593
Epoch: 7   Train loss: 1.196831   Test loss: 1.104831   Train accuracy 0.250000   Test accuracy is 0.342593
Epoch: 8   Train loss: 1.104249   Test loss: 1.103752   Train accuracy 0.375000   T

## Meter (2) Classification with encoded data

### Music Imagery

In [61]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)


print('Music Imagery (meter)')
exp = ExP(1, 2, 16, 80)

enc = True
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_img_enc, meter_y_train_img_enc, X_test_img_enc, meter_y_test_img_enc, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Imagery (meter)
(432, 1, 16, 45)
Epoch: 0   Train loss: 0.794522   Test loss: 0.693855   Train accuracy 0.458333   Test accuracy is 0.509259
Epoch: 1   Train loss: 0.757519   Test loss: 0.697428   Train accuracy 0.472222   Test accuracy is 0.509259
Epoch: 2   Train loss: 0.814661   Test loss: 0.700344   Train accuracy 0.388889   Test accuracy is 0.509259
Epoch: 3   Train loss: 0.779602   Test loss: 0.699582   Train accuracy 0.430556   Test accuracy is 0.500000
Epoch: 4   Train loss: 0.837734   Test loss: 0.699104   Train accuracy 0.458333   Test accuracy is 0.500000
Epoch: 5   Train loss: 0.705722   Test loss: 0.698657   Train accuracy 0.527778   Test accuracy is 0.518519
Epoch: 6   Train loss: 0.782341   Test loss: 0.698376   Train accuracy 0.430556   Test accuracy is 0.500000
Epoch: 7   Train loss: 0.762866   Test loss: 0.698247   Train accuracy 0.416667   Test accuracy is 0.490741
Epoch: 8   Train loss: 0.822930   Test loss: 0.698215   Train accuracy 0.375000   Test

### Music Perception

In [62]:
seed_n = 42
print('seed is ' + str(seed_n))
random.seed(seed_n)
np.random.seed(seed_n)
torch.manual_seed(seed_n)
torch.cuda.manual_seed(seed_n)
torch.cuda.manual_seed_all(seed_n)


print('Music Perception (meter)')
exp = ExP(1, 2, 16, 80)

enc = True
bestAcc, averAcc, Y_true, Y_pred = exp.train(X_train_per_enc, meter_y_train_per_enc, X_test_per_enc, meter_y_test_per_enc, enc)
print('THE BEST ACCURACY IS ' + str(bestAcc))

seed is 42
Music Perception (meter)
(432, 1, 16, 45)
Epoch: 0   Train loss: 0.740980   Test loss: 0.695933   Train accuracy 0.541667   Test accuracy is 0.509259
Epoch: 1   Train loss: 0.777287   Test loss: 0.700849   Train accuracy 0.500000   Test accuracy is 0.490741
Epoch: 2   Train loss: 0.825644   Test loss: 0.705810   Train accuracy 0.458333   Test accuracy is 0.490741
Epoch: 3   Train loss: 0.735018   Test loss: 0.706054   Train accuracy 0.513889   Test accuracy is 0.490741
Epoch: 4   Train loss: 0.746035   Test loss: 0.704973   Train accuracy 0.486111   Test accuracy is 0.500000
Epoch: 5   Train loss: 0.688938   Test loss: 0.705956   Train accuracy 0.583333   Test accuracy is 0.500000
Epoch: 6   Train loss: 0.802301   Test loss: 0.705290   Train accuracy 0.472222   Test accuracy is 0.500000
Epoch: 7   Train loss: 0.732600   Test loss: 0.703024   Train accuracy 0.472222   Test accuracy is 0.509259
Epoch: 8   Train loss: 0.809065   Test loss: 0.701334   Train accuracy 0.416667   T