## COLAB TOOLS

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [2]:

import os
import sys

GOOGLE_DRIVE_PATH_AFTER_MYDRIVE = "Colab Notebooks/prj_neuroread_analysis/neuroread/"
GOOGLE_DRIVE_PATH = os.path.join("/content", "drive", "MyDrive", GOOGLE_DRIVE_PATH_AFTER_MYDRIVE)
print(os.listdir(GOOGLE_DRIVE_PATH))

# Add to sys so we can import .py files.
sys.path.append(GOOGLE_DRIVE_PATH)
os.chdir(GOOGLE_DRIVE_PATH)

# Install unavailable packages
import pip
def import_or_install(package):
    try:
        __import__(package)
    except ImportError:
        pip.main(['install', package])

import_or_install("mne")


['train_cl_eeg2speech_rochester_v3_test_gridsearch.ipynb', '.git', '.DS_Store', '.gitignore', 'EEG', 'LICENSE', 'train_cl_eeg2speech_rochester_v1.ipynb', 'train_cl_eeg2speech_rochester_v2.ipynb', 'train_cl_eeg2speech_rochester_v3.ipynb', '.ipynb_checkpoints', 'train_cl_eeg2speech_rochester_v3_test_old.ipynb', 'runs', 'train_cl_eeg2speech_rochester_v3_test.ipynb', 'train_cl_eeg2speech_rochester_v4_gridseaerch.ipynb', 'train_cl_eeg2speech_2.ipynb', 'train_cl_eeg2speech_rochester_subj_2.ipynb', 'README.md', 'train_eeg2speech_rochester.ipynb']


Please see https://github.com/pypa/pip/issues/5599 for advice on fixing the underlying issue.
To avoid this problem you can invoke Python with '-m pip' instead of running pip directly.


Output()

In [3]:
%load_ext autoreload
%autoreload 2

In [4]:
from psutil import virtual_memory
ram_gb = virtual_memory().total / 1e9
print('Your runtime has {:.1f} gigabytes of available RAM\n'.format(ram_gb))

gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
 print('Not connected to a GPU')
else:
 print(gpu_info)

Your runtime has 89.6 gigabytes of available RAM

Wed Mar  1 13:12:23 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.85.12    Driver Version: 525.85.12    CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   29C    P0    48W / 400W |      0MiB / 40960MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-------------------------------------

## Main code

In [5]:
import os, sys, glob

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

import numpy as np

import mne

import matplotlib
import matplotlib.pyplot as plt
import time

from torchsummary import summary
from torch.utils.tensorboard import SummaryWriter

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [6]:
def eval_model_cl(dl, model, device=torch.device('cpu'), verbose=True):
    """ 
    This function calculates the loss on data, setting backward gradients and batchnorm
    off. This function is written for contrasting learning where the model takes in two
    inputs.

    Args:

    Returns:
      loss_test: Mean loss of all test samples (scalar)

    """
    losses, losses_X1, losses_X2 = [], [], []
    model.to(device)  # inplace for model
    # Set the model in evaluation mode
    model.eval()

    with torch.no_grad():
        for idx_batch, (X1b, X2b) in enumerate(dl):

            X1b = X1b.to(device)
            X2b = X2b.to(device)

            X1b_features, X2b_features, logit_sc = model(X1b, X2b)

            # Normalize features
            X1b_f_n = X1b_features / X1b_features.norm(dim=1, keepdim=True)
            X2b_f_n = X2b_features / X2b_features.norm(dim=1, keepdim=True)

            logits_per_X1 = logit_sc * X1b_f_n @ X2b_f_n.t()
            logits_per_X2 = logits_per_X1.t()

            # Number of labels equals to the 1st dimension of X1b
            labels = torch.arange(X1b.shape[0], device=device)

            # Batch Loss 
            loss_X1 = F.cross_entropy(logits_per_X1, labels)
            loss_X2 = F.cross_entropy(logits_per_X2, labels)
            loss_batch   = (loss_X1 + loss_X2) / 2
            losses.append(loss_batch.item())
            losses_X1.append(loss_X1.item())
            losses_X2.append(loss_X2.item())

        # Epoch loss (mean of batch losses)
        loss  = sum(losses) / len(losses)
        loss_X1 = sum(losses_X1) / len(losses_X1)
        loss_X2 = sum(losses_X2) / len(losses_X2)

        if verbose:
          print(f"====> Validation loss: {loss:.4f},  X1 loss: {loss_X1:.4f}   X2 loss: {loss_X2:.4f}")

        return loss, loss_X1, loss_X2


In [7]:
def unfold_raw(raw, window_size=None, stride=None):
    """
    This function unfolds raw MNE object into a list of raw objects
    Args:
        raw: a raw MNE object cropped by rejecting bad segments.
    Returns:
        raw_unfolded: a raw MNE object unfolded by applying a sliding window.
    """
    if window_size is None:
        window_size = int(5 * raw.info['sfreq'])
    if stride is None:
        stride = window_size
    nchans = len(raw.ch_names)
    sig = torch.tensor(raw.get_data(), dtype=torch.float32).unsqueeze(0).unsqueeze(0)
    sig_unf = F.unfold(sig, (nchans, window_size), stride=stride , padding=0)
    sig_unf = sig_unf.permute(0, 2, 1).reshape(-1, sig_unf.shape[-1], nchans, window_size)
    return sig_unf

In [8]:
def rm_repeated_annotations(raw):
    """This functions taskes in raw MNE obejct and removes repeated annotations"""
    annots = raw.annotations.copy()
    annots_drop = []
    for k in annots:
        annots_drop.extend([k for kk in annots if (k['onset'] > kk['onset']) and (k['onset']+k['duration'] < kk['onset']+kk['duration']) ])

    annots_updated = [i for i in annots if i not in annots_drop]
    onsets = [i['onset'] for i in annots_updated]
    durations = [i['duration'] for i in annots_updated]
    descriptions = [i['description'] for i in annots_updated]
    print('Initial num of annots: %d  Num of removed annots: %d  Num of retained annots:  %d' % (len(annots), len(annots_drop), len(annots_updated)))
    print(f' New annots: {annots_updated}')
    raw.set_annotations(mne.Annotations(onsets, durations, descriptions) ) 
    return raw

## Read Data

In [9]:
subj_ids = [1, 2, 3, 4, 5]
fs = 128
window_size = int(5 * fs)
stride_size_train, stride_size_val, stride_size_test = int(2.5 * fs), int(5 * fs), int(5 * fs)
n_channs = 129 # 128 for eeg, 1 for env
batch_size = int(32)
print('-------------------------------------')
print(f'window_size: {window_size}  stride_size_test: {stride_size_test}')

dataset_name = ['rochester_data', 'natural_speech']
outputs_path = f'../outputs/'
data_path = os.path.join(outputs_path, dataset_name[0], dataset_name[1])
print(f'data_path: {data_path}')

-------------------------------------
window_size: 640  stride_size_test: 640
data_path: ../outputs/rochester_data/natural_speech


In [10]:
raws_train_windowed, raws_val_windowed, raws_test_windowed = [], [], []

for subj_id in subj_ids:
    subj_path = os.path.join(data_path, f'subj_{subj_id}')

    # load subject raw MNE object
    raw = mne.io.read_raw(os.path.join(subj_path, 'after_ica_raw.fif'), preload=True)
    # drop M1 and M2 channels
    raw.drop_channels(['M1', 'M2'])
    assert raw.info['nchan'] == n_channs

    raw = rm_repeated_annotations(raw)
    annots = raw.annotations.copy()
    raw_split = [raw.copy().crop(t1, t2) for t1, t2 in zip(annots.onset[:-1]+annots.duration[:-1], annots.onset[1:])]

    # Pick the split with the longest duration for validation, supposedly less noisy
    ix_val = np.argmax([i.get_data().shape[1] for i in raw_split])
    raw_val = [raw_split.pop(ix_val)] # create a list to make it iterable. later may be used for multiple splits

    # Pick the next split with the longest duration for testing, supposedly less noisy
    ix_test = np.argmax([i.get_data().shape[1] for i in raw_split])
    raw_test = [raw_split.pop(ix_test)]
    
    # creat list of unfolded tensor raw objects
    fs = raw.info['sfreq']
    raws_train_windowed.extend([unfold_raw(i, window_size=window_size, stride=stride_size_train) for i in raw_split if i.get_data().shape[1] > window_size])
    raws_val_windowed.extend([unfold_raw(i, window_size=window_size, stride=stride_size_val) for i in raw_val if i.get_data().shape[1] > window_size])
    raws_test_windowed.extend([unfold_raw(i, window_size=window_size, stride=stride_size_test) for i in raw_test if i.get_data().shape[1] > window_size])
    print("-------------------------------------")
    print('N train: %d  N val: %d  N test: %d' % (len(raws_train_windowed), len(raws_val_windowed), len(raws_test_windowed)))

# concatenate all in second dimension
sigs_train = torch.cat(raws_train_windowed, dim=1).permute(1, 0, 2, 3)
sigs_val = torch.cat(raws_val_windowed, dim=1).permute(1, 0, 2, 3)
sigs_test = torch.cat(raws_test_windowed, dim=1).permute(1, 0, 2, 3)
print(f"Shape Trian: {sigs_train.shape}  Shape Val: {sigs_val.shape}  Shape Test: {sigs_test.shape}")

eegs_train = sigs_train[:, :, :-1, :]
eegs_val = sigs_val[:, :, :-1, :]
eegs_test = sigs_test[:, :, :-1, :]
print("-------------------------------------")
print(f"Shape EEG Train: {eegs_train.shape}  Val: {eegs_val.shape}  Test: {eegs_test.shape}")

# To avoid information leakage, we estimate the mean and std from the training set only.
mean_eeg_train =  eegs_train.mean()
std_eeg_train = eegs_train.std()
print(f"Mean: {mean_eeg_train}  Std: {std_eeg_train}")

envs_train = sigs_train[:, :, [-1], :]
envs_val = sigs_val[:, :, [-1], :]
envs_test = sigs_test[:, :, [-1], :]
print(f"Shape Env Train: {envs_train.shape}  Val: {envs_val.shape}  Test: {envs_test.shape}")

# Estimate mean and std of the Envelope data set
mean_env_train =  envs_train.mean()
std_env_train = envs_train.std()
print(f"Mean Env: {mean_env_train}  Std Env: {std_env_train}")

# Normalize the data
eegs_train = (eegs_train - mean_eeg_train) / std_eeg_train
eegs_val = (eegs_val - mean_eeg_train) / std_eeg_train
eegs_test = (eegs_test - mean_eeg_train) / std_eeg_train

envs_train = (envs_train - mean_env_train) / std_env_train
envs_val = (envs_val - mean_env_train) / std_env_train
envs_test = (envs_test - mean_env_train) / std_env_train



Opening raw data file ../outputs/rochester_data/natural_speech/subj_1/after_ica_raw.fif...
    Range : 0 ... 464571 =      0.000 ...  3629.461 secs
Ready.
Reading 0 ... 464571  =      0.000 ...  3629.461 secs...
Initial num of annots: 42  Num of removed annots: 19  Num of retained annots:  23
 New annots: [OrderedDict([('onset', 0.0), ('duration', 0.0), ('description', 'bad'), ('orig_time', None)]), OrderedDict([('onset', 176.784332), ('duration', 2.0346832275390625), ('description', 'bad'), ('orig_time', None)]), OrderedDict([('onset', 357.548248), ('duration', 1.66888427734375), ('description', 'bad'), ('orig_time', None)]), OrderedDict([('onset', 537.407166), ('duration', 1.8746337890625), ('description', 'bad'), ('orig_time', None)]), OrderedDict([('onset', 719.272339), ('duration', 0.75445556640625), ('description', 'bad'), ('orig_time', None)]), OrderedDict([('onset', 899.217163), ('duration', 1.8060302734375), ('description', 'bad'), ('orig_time', None)]), OrderedDict([('onset',

### Pytorch dataloader

In [11]:
class MyDataset(Dataset):
    def __init__(self, eeg, env):
        self.eeg = eeg
        self.env = env
    
    def __getitem__(self, index):
        return self.eeg[index], self.env[index]
    
    def __len__(self):
        return len(self.eeg)
    
dataset_train = MyDataset(eegs_train, envs_train)
dataloader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, drop_last=True)

dl_val = DataLoader(MyDataset(eegs_val, envs_val), batch_size=batch_size, shuffle=True, drop_last=True)

## Model

In [12]:
class Conv2d(nn.Conv2d):
    def __init__(self, in_channels, out_channels, kernel_size, **kargs):
        super().__init__(in_channels, out_channels, kernel_size, **kargs)

    def __call__(self, inp):
        self.out = super().__call__(inp)

        if self.out.requires_grad:
            self.out.retain_grad()

        return self.out
    
    # -----------------------------------------------------------------------------------------------
class Flatten:
    
  def __call__(self, x):
    self.out = x.view(x.shape[0], -1)
    return self.out
  
  def parameters(self):
    return []
  
  # -----------------------------------------------------------------------------------------------
class Linear(nn.Linear):
    def __init__(self, x, y, **kargs):
        super().__init__(x, y, **kargs)

    def __call__(self, inp):
        self.out = super().__call__(inp)
        return self.out
  # -----------------------------------------------------------------------------------------------
   
class ELU(nn.ELU):
    def __init__(self, alpha=1.0, inplace=False):
        super().__init__(alpha=1.0, inplace=False)

    def __call__(self, inp):
        self.out = super().__call__(inp)
        if self.out.requires_grad:
            self.out.retain_grad()
        return self.out

  # -----------------------------------------------------------------------------------------------
class Sequential:
  
    def __init__(self, layers):
        self.layers = layers

    def __call__(self, x):
        for layer in self.layers:
            x = layer(x)
        self.out = x
        return self.out

    def parameters(self):
        # get parameters of all layers and stretch them out into one list
        return [p for layer in self.layers for p in layer.parameters()]

    def named_parameters(self):
        # get parameters of all layers and stretch them out into one list
        return ((n, p) for layer in self.layers for n, p in layer.named_parameters())

In [13]:
# My implementation of the shallow convnet

fs = 64 # sampling rate
T = 5 * fs # number of time points in each trial
C = 64 # number of EEG channels
F1 = 8 # number of channels (depth) in the first conv layer
D = 2 # number of spatial filters in the second conv layer
F2 = D * F1 # number of channels (depth) in the pont-wise conv layer
num_classes = 4 # number of classes

shallow_covnet = Sequential([
    Conv2d(1, 40, (1, int(fs//2)), padding='same', bias=True),
    Conv2d(40, 40, (C, 1), padding=(0, 0), bias=False), nn.BatchNorm2d(40, affine=True), 
    nn.AvgPool2d((1, 75), (1, 15)), nn.Dropout(0.5),
    Conv2d(40, 4, kernel_size=(1, 30), padding='same', stride=(1, 1), bias=True),
    nn.Flatten(1, -1), # Flatten start_dim=1, end_dim=-1
    Linear(62*4, 4, bias=True),
])



In [19]:
## EEG Encoder with LINEAR

class EEGEncoderWithLinear(nn.Module):
    def __init__(self,             
            fs = 128, # sampling rate
            T = 5, # lenght of each trial in seconds
            C = 128, # number of EEG channels
            F1 = 8, # 8 or 4 number of channels (depth) in the first conv layer
            D = 2, # number of spatial filters in the second conv layer
            F2 = None # number of channels (depth) in the pont-wise conv layer
        ):
        super(EEGEncoderWithLinear, self).__init__()

        if F2 is None:
            F2 = D * F1

        self.eeg_encoder = nn.Sequential(
            Conv2d(1, F1, (1, int(fs/2)), padding='same', bias=True, groups=1),
            nn.BatchNorm2d(F1, affine=True),
            Conv2d(F1, out_channels=D*F1, kernel_size=(C, 1), padding=(0, 0), bias=False, groups=F1),
            nn.BatchNorm2d(D*F1, affine=True), ELU(), nn.AvgPool2d(1, 4), nn.Dropout(0.25),
                    
            Conv2d(F2, F2, (1, int(fs/(2*4))), padding='same', bias=False, groups=D*F1),
            Conv2d(D*F1, F2, kernel_size=(1, 1), padding=(0, 0), groups=1, bias=False),
            nn.BatchNorm2d(F2, affine=True), ELU(), nn.AvgPool2d(1, 8), nn.Dropout(0.25),

            nn.Flatten(),
            nn.Linear(F2*int((T*fs)//(8*4)), int(fs/4))
        ) 

    def forward(self, x):
        x = self.eeg_encoder(x)
        return x


def normalize_weights_eegnet(eeg_encoder):

    for ix, (name, param) in enumerate(eeg_encoder.named_parameters()):
        if  name == 'weight' and param.ndim==4 and ix==1: # normalize conv weights to max norm 1
            param.data = torch.renorm(param.data, 2, 0, maxnorm=1)
        elif name == 'weight' and param.ndim==2: # normalize fc weights to max norm 0.25
            param.data = torch.renorm(param.data, 2, 0, maxnorm=0.25)


eeg_encoder_with_linear = EEGEncoderWithLinear()

# Test the model, add no grad
with torch.no_grad():
    print(eeg_encoder_with_linear(eegs_train[:32, :, :, :]).shape)

#summary(eeg_encoder_with_linear, (1, 128, 640))

torch.Size([32, 32])


In [20]:
## EEG Encoder NO LINEAR

class EEGEncoderNoLinear(nn.Module):
    def __init__(self,             
            fs = 128, # sampling rate
            T = 5, # lenght of each trial in seconds
            C = 128, # number of EEG channels
            F1 = 8, # 8 or 4 number of channels (depth) in the first conv layer
            D = 2, # number of spatial filters in the second conv layer
            F2 = None # number of channels (depth) in the pont-wise conv layer
        ):
        super(EEGEncoderNoLinear, self).__init__()

        if F2 is None:
            F2 = D * F1

        self.eeg_encoder = nn.Sequential(
            Conv2d(1, F1, (1, int(fs/2)), padding='same', bias=True, groups=1),
            nn.BatchNorm2d(F1, affine=True),
            Conv2d(F1, out_channels=D*F1, kernel_size=(C, 1), padding=(0, 0), bias=False, groups=F1),
            nn.BatchNorm2d(D*F1, affine=True), ELU(), nn.AvgPool2d(1, 4), nn.Dropout(0.25),
                    
            Conv2d(F2, F2, (1, int(fs/(2*4))), padding='same', bias=False, groups=D*F1),
            Conv2d(D*F1, F2, kernel_size=(1, 1), padding=(0, 0), groups=1, bias=False),
            nn.BatchNorm2d(F2, affine=True), ELU(), nn.AvgPool2d(1, 8), nn.Dropout(0.25),

            nn.Flatten(),
            #nn.Linear(F2*int((T*fs)//(8*4)), int(fs/4))
        ) 

    def forward(self, x):
        x = self.eeg_encoder(x)
        return x


def normalize_weights_eegnet(eeg_encoder):

    for ix, (name, param) in enumerate(eeg_encoder.named_parameters()):
        if  name == 'weight' and param.ndim==4 and ix==1: # normalize conv weights to max norm 1
            param.data = torch.renorm(param.data, 2, 0, maxnorm=1)
        elif name == 'weight' and param.ndim==2: # normalize fc weights to max norm 0.25
            param.data = torch.renorm(param.data, 2, 0, maxnorm=0.25)


eeg_encoder_no_linear = EEGEncoderNoLinear()

# Test the model, add no grad
with torch.no_grad():
    print(eeg_encoder_no_linear(eegs_train[:32, :, :, :]).shape)

#summary(eeg_encoder_no_linear, (1, 128, 640))

torch.Size([32, 320])


In [21]:
class EnvEncoder3ConvNoLinear(nn.Module):

    def __init__(self,             
            fs = 128, # sampling rate
            T = 5, # lenght of each trial in seconds
            F1 = 4
        ):
        super(EnvEncoder3ConvNoLinear, self).__init__()

        self.env_encoder = nn.Sequential(
            Conv2d(1, F1, (1, int(fs//2)), padding='same', bias=True),
            nn.BatchNorm2d(F1, affine=True), ELU(), nn.AvgPool2d(1, 2), nn.Dropout(0.5),
            Conv2d(F1, F1, (1, int(fs//4)), padding='same', bias=False, groups=1),
            nn.BatchNorm2d(F1, affine=True), ELU(), nn.AvgPool2d(1, 2), nn.Dropout(0.5),
            Conv2d(F1, F1*4, (1, int(fs//8)), padding='same', bias=False, groups=1),
            nn.BatchNorm2d(F1*4, affine=True), ELU(), nn.AvgPool2d(1, 8), nn.Dropout(0.5),
            nn.Flatten(),
            #nn.Linear(F1*int((T*fs)//(2*8)), int(fs/4))
        ) 

    def forward(self, x):
        x = self.env_encoder(x)
        return x

env_encoder3conv_no_linear = EnvEncoder3ConvNoLinear()


# Test the model, add no grad
with torch.no_grad():
    print(env_encoder3conv_no_linear(envs_train[:32, :, :, :]).shape)
#summary(env_encoder3conv_no_linear, (1, 1, 640))

torch.Size([32, 320])


In [22]:
class EnvEncoder2ConvNoLinear(nn.Module):

    def __init__(self,             
            fs = 128, # sampling rate
            T = 5, # lenght of each trial in seconds
            F1 = 4
        ):
        super(EnvEncoder2ConvNoLinear, self).__init__()

        self.env_encoder = nn.Sequential(
            Conv2d(1, F1, (1, int(fs//2)), padding='same', bias=True),
            nn.BatchNorm2d(F1, affine=True), ELU(), nn.AvgPool2d(1, 2), nn.Dropout(0.5),
            Conv2d(F1, F1*4, (1, int(fs//4)), padding='same', bias=False, groups=1),
            nn.BatchNorm2d(F1*4, affine=True), ELU(), nn.AvgPool2d(1, 16), nn.Dropout(0.5),
            nn.Flatten(),
            #nn.Linear(F1*int((T*fs)//(2*8)), int(fs/4))
        ) 

    def forward(self, x):
        x = self.env_encoder(x)
        return x

env_encoder2conv_no_linear = EnvEncoder2ConvNoLinear()


# Test the model, add no grad
with torch.no_grad():
    print(env_encoder2conv_no_linear(envs_train[:32, :, :, :]).shape)
#summary(env_encoder2conv_no_linear, (1, 1, 640))

torch.Size([32, 320])


In [23]:
class EnvEncoder2ConvWithLinear(nn.Module):

    def __init__(self,             
            fs = 128, # sampling rate
            T = 5, # lenght of each trial in seconds
            F1 = 4
        ):
        super(EnvEncoder2ConvWithLinear, self).__init__()

        self.env_encoder = nn.Sequential(
            Conv2d(1, F1, (1, int(fs//2)), padding='same', bias=True),
            nn.BatchNorm2d(F1, affine=True), ELU(), nn.AvgPool2d(1, 2), nn.Dropout(0.5),
            Conv2d(F1, F1*4, (1, int(fs//4)), padding='same', bias=False, groups=1),
            nn.BatchNorm2d(F1*4, affine=True), ELU(), nn.AvgPool2d(1, 16), nn.Dropout(0.5),
            nn.Flatten(),
            nn.Linear(F1*4*int((T*fs)//(8*4)), int(fs/4))
        ) 

    def forward(self, x):
        x = self.env_encoder(x)
        return x

env_encoder2conv_with_linear = EnvEncoder2ConvWithLinear()


# Test the model, add no grad
with torch.no_grad():
    print(env_encoder2conv_with_linear(envs_train[:32, :, :, :]).shape)
#summary(env_encoder2conv_with_linear, (1, 1, 640))

torch.Size([32, 32])


In [27]:
class EnvEncoder3ConvWithLinear(nn.Module):

    def __init__(self,             
            fs = 128, # sampling rate
            T = 5, # lenght of each trial in seconds
            F1 = 4
        ):
        super(EnvEncoder3ConvWithLinear, self).__init__()

        self.env_encoder = nn.Sequential(
            Conv2d(1, F1, (1, int(fs//2)), padding='same', bias=True),
            nn.BatchNorm2d(F1, affine=True), ELU(), nn.AvgPool2d(1, 2), nn.Dropout(0.5),
            Conv2d(F1, F1, (1, int(fs//4)), padding='same', bias=False, groups=1),
            nn.BatchNorm2d(F1, affine=True), ELU(), nn.AvgPool2d(1, 2), nn.Dropout(0.5),
            Conv2d(F1, F1*4, (1, int(fs//8)), padding='same', bias=False, groups=1),
            nn.BatchNorm2d(F1*4, affine=True), ELU(), nn.AvgPool2d(1, 8), nn.Dropout(0.5),
            nn.Flatten(),
            nn.Linear(F1*4*int((T*fs)//(4*8)), int(fs/4))
        ) 

    def forward(self, x):
        x = self.env_encoder(x)
        return x

env_encoder3conv_with_linear = EnvEncoder3ConvWithLinear()


# Test the model, add no grad
with torch.no_grad():
    print(env_encoder3conv_with_linear(envs_train[:32, :, :, :]).shape)
#summary(env_encoder3conv_with_linear, (1, 1, 640))

torch.Size([32, 32])


In [28]:
class CES(nn.Module):
    def __init__(self, 
                 eeg_encoder= None,
                 env_encoder = None): 
        super().__init__()

        self.eeg_encoder = eeg_encoder
        self.env_encoder = env_encoder
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def encode_eeg(self, x):
        return self.eeg_encoder(x)
    
    def encode_env(self, x):
        return self.env_encoder(x)
    
    def forward(self, eeg, env):
        eeg_features = self.encode_eeg(eeg)
        env_features = self.encode_env(env)
        return eeg_features, env_features, self.logit_scale.exp()
  

model = CES();
model.to(device)
#for n,p in model.named_parameters():
    #print(n, p.shape)


CES()

In [29]:
print(" Models with no Linear Layers")
eeg_encoder_no_linear = EEGEncoderNoLinear()
env_encoder3conv_no_linear = EnvEncoder2ConvNoLinear()
ces_eeg_0lin_env_3conv_0lin = CES(eeg_encoder=eeg_encoder_no_linear.eeg_encoder, env_encoder=env_encoder2conv_no_linear.env_encoder)
#summary(ces_eeg_0lin_env_3conv_0lin, [(1, 128, 640), (1, 1, 640)])

eeg_encoder_no_linear = EEGEncoderNoLinear()
env_encoder2conv_no_linear = EnvEncoder2ConvNoLinear()
ces_eeg_0lin_env_2conv_0lin = CES(eeg_encoder=eeg_encoder_no_linear.eeg_encoder, env_encoder=env_encoder2conv_no_linear.env_encoder)
#summary(ces_eeg_0lin_env_2conv_0lin, [(1, 128, 640), (1, 1, 640)])


print(" Models with Linear Layers")
eeg_encoder_with_linear = EEGEncoderWithLinear()
env_encoder3conv_with_linear = EnvEncoder2ConvWithLinear()
ces_eeg_1lin_env_3conv_1lin = CES(eeg_encoder=eeg_encoder_with_linear.eeg_encoder, env_encoder=env_encoder3conv_with_linear.env_encoder)
#summary(ces_eeg_1lin_env_3conv_1lin, [(1, 128, 640), (1, 1, 640)])

eeg_encoder_with_linear = EEGEncoderWithLinear()
env_encoder2conv_with_linear = EnvEncoder2ConvWithLinear()
ces_eeg_1lin_env_2conv_1lin = CES(eeg_encoder=eeg_encoder_with_linear.eeg_encoder, env_encoder=env_encoder2conv_with_linear.env_encoder)
#summary(ces_eeg_1lin_env_2conv_1lin, [(1, 128, 640), (1, 1, 640)])

models_name = ["eeg0lin_env3conv0lin", "eeg0lin_env2conv0lin", "eeg1lin_env3conv1lin", "eeg1lin_env2conv1lin"]
models_dict = {"eeg0lin_env3conv0lin": ces_eeg_0lin_env_3conv_0lin, "eeg0lin_env2conv0lin": ces_eeg_0lin_env_2conv_0lin, 
               "eeg1lin_env3conv1lin": ces_eeg_1lin_env_3conv_1lin, "eeg1lin_env2conv1lin": ces_eeg_1lin_env_2conv_1lin}


 Models with no Linear Layers
 Models with Linear Layers


In [30]:

lossi = []
udri = [] # update / data ratio 
ud = []

lr = 0.001

for name, model in models_dict.items():

    # Reset for the new model in the loop
    print(f"+--------------New model: {name}----------------------+")
    writer = SummaryWriter(log_dir=f"runs/{name}_{time.strftime('%Y%m%d_%H%M%S')}")
    model.to(device)
    optimizer = optim.NAdam(model.parameters(), lr=lr)
    cnt = 0
    loss_batches = []


    for epoch in range(1, 100):

        print(f"====== Epoch: {epoch}")

        model.train()
        for ix_batch, (Xb_eeg, Xb_env) in enumerate(dataloader):

            # send to device
            Xb_eeg = Xb_eeg.to(device)
            Xb_env = Xb_env.to(device)

            # Zero out gradients
            optimizer.zero_grad()

            # forward pass
            eeg_features, env_features, logit_scale = model(Xb_eeg, Xb_env) 


            # normalize features
            eeg_features_n = eeg_features / eeg_features.norm(dim=1, keepdim=True)
            env_features_n = env_features / env_features.norm(dim=1, keepdim=True)

            # logits
            logits_per_eeg = logit_scale * eeg_features_n @ env_features_n.t()
            logits_per_env = logits_per_eeg.t()

            #loss function
            labels = torch.arange(batch_size).to(device)
            loss_eeg = F.cross_entropy(logits_per_eeg, labels)
            loss_env = F.cross_entropy(logits_per_env, labels)
            loss   = (loss_eeg + loss_env)/2

            # backward pass
            loss.backward()
            optimizer.step()

            loss_batches.append(loss.item())
            cnt += 1

            with torch.no_grad():
                #ud = {f"p{ix}":(lr*p.grad.std() / p.data.std()).log10().item() for ix, p in enumerate(model.parameters()) if p.ndim==4 }
                #writer.add_scalars('UpdateOData/ud', ud, cnt)
                writer.add_scalar('Loss/train_batch', loss.item(), cnt)
            
            #break   

        loss_epoch = loss_batches[-(ix_batch + 1):]  # mean loss across batches
        loss_epoch = sum(loss_epoch) / len(loss_epoch)
        writer.add_scalar('Loss/train_epoch', loss_epoch, epoch)
        #for pname, p in model.named_parameters():
        #writer.add_histogram(f'Params/{pname}', p, epoch)
        #writer.add_histogram(f'Grads/{pname}', p.grad, epoch)

        loss_val, *_ = eval_model_cl(dl_val, model, device=device)
        writer.add_scalar('Loss/val_epoch', loss_val, epoch)

        
        # normalize weights
        with torch.no_grad():
            normalize_weights_eegnet(model.eeg_encoder)

        model.train()
            
    #break   


+--------------New model: eeg0lin_env3conv0lin----------------------+


====> Validation loss: 3.2588,  X1 loss: 3.2384   X2 loss: 3.2792
====> Validation loss: 3.2573,  X1 loss: 3.2559   X2 loss: 3.2586
====> Validation loss: 3.2296,  X1 loss: 3.2194   X2 loss: 3.2397
====> Validation loss: 3.1470,  X1 loss: 3.1281   X2 loss: 3.1660
====> Validation loss: 3.0635,  X1 loss: 3.0351   X2 loss: 3.0919
====> Validation loss: 3.0362,  X1 loss: 3.0204   X2 loss: 3.0519
====> Validation loss: 3.0325,  X1 loss: 3.0180   X2 loss: 3.0470
====> Validation loss: 3.0200,  X1 loss: 3.0087   X2 loss: 3.0313
====> Validation loss: 2.9823,  X1 loss: 2.9603   X2 loss: 3.0044
====> Validation loss: 2.9753,  X1 loss: 2.9604   X2 loss: 2.9902
====> Validation loss: 3.0359,  X1 loss: 3.0277   X2 loss: 3.0441
====> Validation loss: 3.0435,  X1 loss: 3.0265   X2 loss: 3.0605
====> Validation loss: 2.9389,  X1 loss: 2.9221   X2 loss: 2.9557
====> Validation loss: 2.9226,  X1 loss: 2.9036   X2 loss: 2.9417
====> Validation loss: 3.0037,  X1 loss: 2.9955   X2 loss: 3.0120
====> Vali