# **Setup**

**Imports**

In [None]:
# @title

import torch
import scipy
import numpy as np
import os
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader as DL
from torch.utils.data import TensorDataset as TData
import torch.optim as optim
from tqdm.auto import tqdm
from google.colab import files


In [None]:
# @title
import zipfile
import os

To run the model, first create a new folder named EEGData in the Files tab on the left. Running the following cells will extract and preprocess the data and set up the train, validation, and test dataloaders for you.

In [None]:
# @title
# Path of the zip archive that holds the EEG recordings.
zip_file = "/content/data_nathan_LR.zip"

# Unpack every member of the archive into the EEGData folder (change the
# target here if the data should live somewhere else).
with zipfile.ZipFile(zip_file, mode='r') as archive:
    archive.extractall(path="EEGData")

# Show what ended up in the target folder as a quick sanity check.
extracted_entries = os.listdir("EEGData")
print(extracted_entries)

['data_nathan_LR']


In [None]:
# @title

import random

from google.colab import drive
drive.mount('/content/drive')

new_base = 'EEGData/data_nathan_LR'
# NOTE: this rebinds `files`, which the imports cell bound to the
# `google.colab.files` module; from here on `files` is the list of file names.
files = os.listdir(new_base)
print(files)
paths = [os.path.join(new_base, fname) for fname in files]

# File names look like 'EEGMouse_Nathan_3_54_1329.npy'. After dropping the
# extension and splitting on '_', field 2 is the class label and field 3 is
# the session number.
# os.path.splitext is used instead of str.rstrip('.npy'): rstrip removes any
# trailing run of the characters '.', 'n', 'p', 'y', not the literal suffix.
file_splits = [os.path.splitext(fname)[0].split('_')[3] for fname in files]
sessions = map(int, set(file_splits))
# The 8 lowest-numbered sessions are skipped.
sorted_sessions = sorted(sessions)[8:]

print("Sorted sessions:", sorted_sessions)

# Sessions excluded by hand. Filtering (rather than list.remove) keeps this
# cell from crashing when one of them is not present in the data folder.
excluded_sessions = {31, 34, 41}
sorted_sessions = [s for s in sorted_sessions if s not in excluded_sessions]
print(sorted_sessions) # 17 sessions in total  45-57 new
# may not be needed if using k-fold testing

# Hold out 12 random sessions: 6 for validation, 6 for testing. No seed is
# set, so the split differs on every run.
chosen_sessions = random.sample(sorted_sessions, k=12)
val_sessions, test_sessions = chosen_sessions[:6], chosen_sessions[6:]
print(val_sessions, test_sessions)
# done for electrode grouping to perform 2D convolution
def channel_rearrangment(sig, channel_order):
    """Return a copy of ``sig`` whose rows follow ``channel_order``.

    Args:
        sig: array indexed by channel along its first axis.
        channel_order: 1-based channel numbers; entry ``i`` names the source
            row that is written to output row ``i``.

    Returns:
        A new array shaped like ``sig``. Rows beyond ``len(channel_order)``
        stay zero. ``sig`` itself is not modified.
    """
    # Channel numbers are 1-based, array rows are 0-based.
    positions = np.asarray(channel_order, dtype=int) - 1
    reordered = np.zeros_like(sig)
    reordered[:positions.size] = np.asarray(sig)[positions]
    return reordered

# 1-based channel numbers in the order they should appear after rearrangement:
# row i of the rearranged signal is taken from channel ordered_channels[i].
ordered_channels = [1, 9, 11, 3, 2, 12, 10, 4, 13, 5, 15, 7, 14, 16, 6, 8]
# applying a bandpass filter
def bandpass_filter(signal, crit_freq=(1, 40), sampling_freq=125, plot=False, channel=0):
    """Band-pass filter a multi-channel signal along axis 1.

    A 4th-order Butterworth band-pass is designed with ``scipy.signal.butter``
    and applied forwards and backwards with ``scipy.signal.filtfilt``.

    Args:
        signal: 2-D array; filtering runs along axis 1 and the optional plot
            selects a row with ``channel``.
        crit_freq: ``(low, high)`` cut-off frequencies, in the same units as
            ``sampling_freq``.
        sampling_freq: sampling rate handed to scipy as ``fs``.
        plot: when truthy, plot the min-max normalised input and output of
            one row on a new figure.
        channel: row index shown in the plot.

    Returns:
        The filtered array, same shape as ``signal``.
    """
    # A bare `import scipy` does not guarantee that the `scipy.signal`
    # submodule has been loaded, so import it explicitly. It is aliased
    # because the first parameter of this function is itself called `signal`.
    from scipy import signal as sp_signal

    order = 4
    b, a = sp_signal.butter(order, crit_freq, btype='bandpass', fs=sampling_freq)
    processed_signal = sp_signal.filtfilt(b, a, signal, axis=1)

    if plot:
        def _min_max_rows(values):
            # Scale every row independently to the [0, 1] range.
            low = np.min(values, axis=1, keepdims=True)
            high = np.max(values, axis=1, keepdims=True)
            return (values - low) / (high - low)

        normed_signal = _min_max_rows(np.asarray(signal))
        normed_filt = _min_max_rows(processed_signal)

        plt.figure()
        plt.xlabel('Time')
        plt.ylabel(f'Normalized amplitude of channel {channel}')
        plt.title(f'{crit_freq[0]}-{crit_freq[1]}Hz bandpass filter')
        # The x axis is the sample index, not seconds.
        plt.plot(np.arange(normed_signal[channel].size), normed_signal[channel], label='Input')
        plt.plot(np.arange(normed_filt[channel].size), normed_filt[channel], label='Transformed')
        plt.legend()

    return processed_signal


# function to segment eeg data based on sampling freq(Hz), window_size(s), and window_shift(s)
def segmentation(signal, sampling_freq=125, window_size=1, window_shift=0.016):
    """Cut a signal into overlapping fixed-length windows along axis 1.

    Args:
        signal: 2-D array; windows are taken along axis 1.
        sampling_freq: sampling rate in Hz.
        window_size: window length in seconds.
        window_shift: step between consecutive window starts, in seconds.

    Returns:
        A list of slices ``signal[:, start:start + w]``. Only windows that
        fit completely inside the signal are returned, so the list is empty
        when the signal is shorter than one window.

    Raises:
        ValueError: if the window length or the shift truncates to fewer
            than one sample. A zero shift used to make this loop forever.
    """
    # Both lengths are truncated (not rounded) to whole samples.
    w_size = int(sampling_freq * window_size)
    w_shift = int(sampling_freq * window_shift)
    if w_size < 1:
        raise ValueError(
            f"window_size={window_size} at {sampling_freq} Hz is shorter than one sample"
        )
    if w_shift < 1:
        raise ValueError(
            f"window_shift={window_shift} at {sampling_freq} Hz is shorter than one sample"
        )

    last_start = signal.shape[1] - w_size
    return [
        signal[:, start:start + w_size]
        for start in range(0, last_start + 1, w_shift)
    ]

# applying all preprocessing steps to create train and test data
train_eeg = []
train_labels = []
valid_eeg = []
valid_labels = []
test_eeg = []
test_labels = []
for name, path in zip(files, paths):
    # 'EEGMouse_Nathan_3_54_1329.npy' -> ['3', '54', '1329']: label, session,
    # trailing id. os.path.splitext drops the extension; str.rstrip('.npy')
    # would strip a trailing run of the characters '.', 'n', 'p', 'y'.
    details = os.path.splitext(name)[0].split('_')[2:]
    sig = np.load(path) # loading signal
    sig = sig[:, 1:] # removing first time step because it is inaccurate
    if sig.shape[1] == 0 or int(details[1]) not in sorted_sessions: # excluding empty sample elements
        continue
    session = int(details[1])
    reindexed_signal = channel_rearrangment(sig, ordered_channels)
    filtered_sig = bandpass_filter(reindexed_signal, [5, 40], 125) # bandpass filter
    # Per-channel standard scaling; a channel with zero variance yields NaNs,
    # which is what the check below filters out.
    normed_sig = (filtered_sig - np.mean(filtered_sig, 1, keepdims=True)) / np.std(filtered_sig, 1, keepdims=True)
    if np.isnan(normed_sig).any(): # excluding sample elements with nans
        print(name)
        continue
    # NOTE(review): segmentation assumes 128 Hz while the filter above assumes
    # 125 Hz - confirm the true sampling rate. 128 * 1.5 = 192 samples per
    # window, which is the time length EEGCapsNet defaults to.
    signals = segmentation(normed_sig, 128, window_size = 1.5, window_shift = 0.0175) # segmentation
    labels = [int(details[0])] * len(signals)
    if session in test_sessions:
        test_eeg.extend(signals)
        test_labels.extend(labels)
    elif session in val_sessions:
        valid_eeg.extend(signals)
        valid_labels.extend(labels)
    else:
        train_eeg.extend(signals)
        train_labels.extend(labels)


def _segments_to_tensor(segments, split_name):
    """Copy a list of equally shaped 2-D windows into one (N, rows, cols) tensor."""
    if not segments:
        raise ValueError(f"no EEG segments were collected for the {split_name} split")
    stacked = torch.zeros((len(segments), segments[0].shape[0], segments[0].shape[1]))
    # Windows are copied one at a time into the preallocated tensor so that no
    # second full-size array (in the NumPy dtype) has to exist alongside it.
    for row, segment in enumerate(segments):
        stacked[row] = torch.from_numpy(segment.copy())
    return stacked


class_to_idx = {1:0, 3:1}


def _one_hot_labels(labels):
    """Encode raw class labels as one-hot rows using class_to_idx."""
    columns = torch.tensor([class_to_idx[label] for label in labels], dtype=torch.long)
    encoded = torch.zeros(len(labels), len(class_to_idx))
    encoded[torch.arange(len(labels)), columns] = 1
    return encoded


train_eeg_tensor = _segments_to_tensor(train_eeg, "train")
valid_eeg_tensor = _segments_to_tensor(valid_eeg, "validation")
test_eeg_tensor = _segments_to_tensor(test_eeg, "test")

train_label_tensor = _one_hot_labels(train_labels)
valid_label_tensor = _one_hot_labels(valid_labels)
test_label_tensor = _one_hot_labels(test_labels)

train_ds = TData(train_eeg_tensor, train_label_tensor)
valid_ds = TData(valid_eeg_tensor, valid_label_tensor)
test_ds = TData(test_eeg_tensor, test_label_tensor)
# NOTE(review): the validation and test loaders also shuffle and drop the last
# incomplete batch, so a few samples are left out of every evaluation pass.
train_dl = DL(train_ds, batch_size=64, shuffle= True, drop_last = True)
valid_dl = DL(valid_ds, batch_size=64, shuffle= True, drop_last = True)
test_dl = DL(test_ds, batch_size=64, shuffle = True, drop_last = True)

Mounted at /content/drive
['EEGMouse_Nathan_3_54_1329.npy', 'EEGMouse_Nathan_1_29_867.npy', 'EEGMouse_Nathan_1_26_799.npy', 'EEGMouse_Nathan_3_19_684.npy', 'EEGMouse_Nathan_1_13_560.npy', 'EEGMouse_Nathan_1_16_624.npy', 'EEGMouse_Nathan_3_19_674.npy', 'EEGMouse_Nathan_3_4_370.npy', 'EEGMouse_Nathan_3_56_1377.npy', 'EEGMouse_Nathan_1_50_1246.npy', 'EEGMouse_Nathan_3_47_1174.npy', 'EEGMouse_Nathan_1_52_1290.npy', 'EEGMouse_Nathan_1_27_817.npy', 'EEGMouse_Nathan_1_57_1398.npy', 'EEGMouse_Nathan_3_37_1028.npy', 'EEGMouse_Nathan_3_19_680.npy', 'EEGMouse_Nathan_3_22_730.npy', 'EEGMouse_Nathan_3_40_1063.npy', 'EEGMouse_Nathan_3_36_994.npy', 'EEGMouse_Nathan_1_37_1017.npy', 'EEGMouse_Nathan_1_5_385.npy', 'EEGMouse_Nathan_3_35_976.npy', 'EEGMouse_Nathan_3_10_495.npy', 'EEGMouse_Nathan_3_7_451.npy', 'EEGMouse_Nathan_3_32_924.npy', 'EEGMouse_Nathan_1_30_885.npy', 'EEGMouse_Nathan_3_32_922.npy', 'EEGMouse_Nathan_1_21_707.npy', 'EEGMouse_Nathan_3_1_338.npy', 'EEGMouse_Nathan_1_44_1125.npy', 'EEGMou

  normed_sig = (filtered_sig - np.mean(filtered_sig, 1, keepdims=True)) / np.std(filtered_sig, 1, keepdims=True) # standard scaling


EEGMouse_Nathan_1_51_1262.npy
EEGMouse_Nathan_1_50_1240.npy
EEGMouse_Nathan_1_47_1173.npy


In [None]:
# Draw one batch from the training loader as a quick check that it yields
# (EEG windows, one-hot labels) pairs.
b, l = next(iter(train_dl))

Below is the model code, which you have already gone through

In [None]:
# @title
class Squash(nn.Module):
    """Squash non-linearity applied to capsule vectors along the last axis.

    Each vector keeps its direction and is rescaled to length
    ``1 - 1 / (exp(norm) + eps)``.
    """

    def __init__(self, eps=1e-20):
        super().__init__()
        # Small constant that keeps both divisions finite for zero vectors.
        self.eps = eps

    def forward(self, x):
        length = torch.linalg.norm(x, ord=2, dim=-1, keepdim=True)
        direction = x / (length + self.eps)
        scale = 1 - 1 / (torch.exp(length) + self.eps)
        return scale * direction

class Routing(nn.Module):
    """Grouped self-attention routing from N0 input capsules to N1 output capsules.

    Every group has its own weights. Input ``x`` has shape
    ``(..., G, N0, D0)`` and the output has shape ``(..., G, N1, D1)``.

    Args:
        groups: number of independent capsule groups G.
        in_dims: ``(N0, D0)`` - capsules per group and their dimension on input.
        out_dims: ``(N1, D1)`` - capsules per group and their dimension on output.
    """

    def __init__(self, groups, in_dims, out_dims):
        super(Routing, self).__init__()
        N0, D0 = in_dims
        N1, self.D1 = out_dims
        self.W = nn.Parameter(torch.empty(groups, N1, N0, D0, self.D1))
        nn.init.kaiming_normal_(self.W)
        self.b = nn.Parameter(torch.zeros(groups, N1, N0, 1))
        self.squash = Squash()

    def forward(self, x):
        # Prediction vectors: every input capsule votes for every output capsule.
        u = torch.einsum('...gni,gknid->...gknd', x, self.W)  # (..., G, N1, N0, D1)

        # Agreement of each vote with the sum of all votes for the same output
        # capsule (dot products summed over the N0 axis).
        c = torch.einsum("...ij,...kj->...i", u, u)  # (..., G, N1, N0)

        c = c[..., None]  # (..., G, N1, N0, 1) for bias broadcasting
        c = c / (self.D1 ** 0.5)  # stabilize

        # Normalise over the OUTPUT-capsule axis N1. It is addressed from the
        # right (dim=-3) because the group axis now sits at position 1; a
        # softmax over position 1 would normalise across groups and collapse
        # to a constant 1 whenever groups == 1.
        c = torch.softmax(c, dim=-3) + self.b

        ## new capsules
        s = torch.sum(u * c, dim=-2)  # (..., G, N1, D1)

        return self.squash(s)


class ReconstructionNet(nn.Module):
    """Three-layer MLP decoder that emits a complex-valued reconstruction.

    The last layer produces twice ``prod(input_size)`` values per sample;
    the first half is used as the real part and the second half as the
    imaginary part of the returned tensor.

    Args:
        input_size: shape of one reconstructed sample (without batch axis).
        num_classes: number of class capsules feeding the decoder.
        num_capsules: dimension of each class capsule.
    """

    def __init__(self, input_size=(1, 28, 28), num_classes=2, num_capsules=64):
        super().__init__()
        self.input_size = input_size
        flat_capsules = num_capsules * num_classes
        self.fc1 = nn.Linear(in_features=flat_capsules, out_features=512)
        self.fc2 = nn.Linear(512, 1024)
        self.fc3 = nn.Linear(1024, np.prod(input_size) * 2)
        self.relu = nn.ReLU()
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the three weight matrices (Xavier normal, ReLU gain)."""
        relu_gain = nn.init.calculate_gain('relu')
        for layer in (self.fc1, self.fc2, self.fc3):
            nn.init.xavier_normal_(layer.weight, gain=relu_gain)

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        flat = self.fc3(hidden)
        # (B, 2, *input_size); axis 1 always has size 2 here, so the squeeze
        # leaves the shape unchanged.
        parts = flat.view(flat.size(0), 2, *self.input_size).squeeze(1)
        real_part, imag_part = parts[:, 0], parts[:, 1]
        # Complex tensor so the output can be compared with a Fourier transform.
        return torch.complex(real_part, imag_part)


class CapsMask(nn.Module):
    """Zero out every class capsule except one and flatten the result.

    With ``y_true`` the given mask selects the capsules; without it the
    capsule with the largest length is kept for each sample.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, y_true=None):
        if y_true is None:
            # No labels: keep the longest capsule of each sample.
            lengths = torch.sqrt(torch.sum(x**2, dim=-1))
            winners = torch.argmax(lengths, dim=1)
            selector = F.one_hot(winners, num_classes=lengths.shape[1])
        else:
            # Labels available: use them directly as the mask.
            selector = y_true

        kept = x * selector.unsqueeze(-1)
        # Flatten capsules into one vector per sample.
        return kept.view(x.shape[0], -1)


class CapsLen(nn.Module):
    """Length of each capsule vector, computed along the last axis.

    ``eps`` is added under the square root so the result (and its gradient)
    stays finite for all-zero capsules.
    """

    def __init__(self, eps=1e-7):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        squared_length = torch.sum(x**2, dim=-1)
        # (batch_size, num_capsules)
        return torch.sqrt(squared_length + self.eps)


In [None]:
# @title
class EEGCapsNet(nn.Module):
    """Multi-scale capsule network over a 16-electrode EEG window.

    The same window is viewed at four scales - 16 single channels, 8 local
    pairs, 4 regions and 2 hemispheres. Each scale has two grouped
    convolution stacks that collapse the time axis, followed by capsule
    routing. The four scales are aligned to 16 capsules of dimension 8 each,
    concatenated, and routed to ``num_classes`` output capsules. In training
    the channel, local and region scales also feed auxiliary class capsules
    and decoders.

    Args:
        input_size: ``(1, electrodes, time_samples)``; only the time length
            (``input_size[2]``) is used to size the convolutions, the number
            of electrodes is fixed at 16 by the layer definitions.
        num_classes: number of output class capsules.
    """

    def __init__(self, input_size=(1, 16, 192), num_classes=2):
        super(EEGCapsNet, self).__init__()
        # The second convolution of every stack spans whatever is left of the
        # time axis after the first one, so its output has a single time point.
        n_time = input_size[2]

        # ---- channel scale: 16 groups of 1 electrode ----
        self.channelCapsTemporal_1 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=128, kernel_size=(1, 64), groups=16, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, n_time - 63), groups=128)  # collapse to 1 point
        )
        self.channelCapsTemporal_2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=128, kernel_size=(1, 24), groups=16, padding=0),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(1, n_time - 23), groups=128)  # collapse to 1 point
        )

        # Routing(groups, (n_in, d_in), (n_out, d_out)).
        # channelShrink is declared but not used by forward().
        self.channelRouting = Routing(16, (8, 4), (4, 8))
        self.channelShrink = Routing(8, (8, 8), (4, 8))
        self.channelDeepSuper = Routing(1, (64, 8), (2, 32))
        self.channel_align = Routing(16, (4, 8), (1, 8))

        # ---- local scale: 8 groups of 2 electrodes ----
        self.localCapsSpatial_1 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=64, kernel_size=(2, 36), groups=8, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, n_time - 35), groups=64)  # collapse to 1 point
        )

        self.localCapsSpatial_2 = nn.Sequential(
            nn.Conv2d(in_channels=8, out_channels=64, kernel_size=(2, 16), groups=8),
            nn.BatchNorm2d(64),
            nn.ELU(),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=(1, n_time - 15), groups=64),  # collapse to 1 point
            nn.Dropout(p=0.3)
        )

        # localShrink and localRegion are declared but not used by forward().
        self.localRouting = Routing(8, (8, 4), (4, 8))
        self.localShrink = Routing(8, (8, 8), (4, 16))
        self.localRegion = Routing(4, (8, 16), (4, 16))
        self.localDeepSuper = Routing(1, (32, 8), (2, 64))
        self.local_align = Routing(8, (4, 8), (2, 8))

        # ---- region scale: 4 groups of 4 electrodes ----
        self.regionCapsSpatial_1 = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=(4, 24), groups=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, n_time - 23), groups=32)  # collapse to 1 point
        )

        self.regionCapsSpatial_2 = nn.Sequential(
            nn.Conv2d(in_channels=4, out_channels=32, kernel_size=(4, 32), groups=4),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(1, n_time - 31), groups=32)  # collapse to 1 point
        )

        # regionShrink and regionHemi are declared but not used by forward().
        self.regionRouting = Routing(4, (8, 4), (4, 16))
        self.regionShrink = Routing(4, (8, 16), (4, 16))
        self.regionHemi = Routing(2, (8, 16), (4, 16))
        self.regionDeepSuper = Routing(1, (16, 16), (2, 64))
        self.region_align = Routing(4, (4, 16), (4, 8))

        # ---- hemisphere scale: 2 groups of 8 electrodes ----
        self.hemiCapsSpatial_1 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=16, kernel_size=(8, 30), groups=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(1, n_time - 29), groups=16)  # collapse to 1 point
        )

        self.hemiCapsSpatial_2 = nn.Sequential(
            nn.Conv2d(in_channels=2, out_channels=16, kernel_size=(8, 60), groups=2),
            nn.BatchNorm2d(16),
            nn.ReLU(),
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(1, n_time - 59), groups=16)  # collapse to 1 point
        )

        # hemiShrink is declared but not used by forward().
        self.hemiRouting = Routing(2, (8, 4), (4, 16))
        self.hemiShrink = Routing(2, (8, 16), (4, 16))
        self.hemi_align = Routing(2, (4, 16), (8, 8))

        # ---- output capsules and decoders ----
        self.out = Routing(1, (64, 8), (num_classes, 64)) #changed from (1, (8, 16), (num_classes, 64))
        self.generator = ReconstructionNet(input_size, num_classes)
        self.channel_generator = ReconstructionNet(input_size, num_classes, num_capsules=32)
        self.local_generator = ReconstructionNet(input_size, num_classes, num_capsules=64)
        self.region_generator = ReconstructionNet(input_size, num_classes, num_capsules=64)
        self.mask = CapsMask()
        self.capsLen = CapsLen()

    def _primary_caps(self, data, conv_a, conv_b, groups):
        """Run both conv stacks of one scale and pack them as (B, groups, 8, 4) capsules.

        ``data`` is the (B, 16, 1, T) window; it is regrouped to
        (B, groups, 16 // groups, T) so each group of electrodes becomes one
        convolution channel. Each stack yields 16 features per group.
        """
        batch = data.size(0)
        # The time length is read from the tensor rather than hard-coded, so
        # a model built with a different input_size[2] works as well.
        grouped = data.view(batch, groups, data.size(1) // groups, data.size(-1))
        first = conv_a(grouped).view(batch, groups, 16)
        second = conv_b(grouped).view(batch, groups, 16)
        return torch.cat((first, second), dim=2).view(batch, groups, 8, 4)

    def forward(self, x, y_true=None, mode='train'):
        """Classify a batch of EEG windows.

        Args:
            x: tensor of shape (B, 1, 16, T).
            y_true: optional one-hot labels (B, num_classes) used to select
                the capsule that is decoded in ``'train'`` mode; when omitted
                the longest capsule is decoded instead.
            mode: ``'train'``, ``'eval'`` or ``'test'``.

        Returns:
            ``'test'``: ``pred`` - class capsule lengths (B, num_classes).
            ``'eval'``: ``(pred, reconstruction)``.
            ``'train'``: ``(pred, x, pred_regions, x_regions, pred_locals,
            x_locals, pred_channels, x_channels)`` - prediction and
            reconstruction of the main head followed by the three auxiliary
            heads.

        Raises:
            ValueError: if ``mode`` is not one of the three supported values.
        """
        if mode not in ("train", "eval", "test"):
            raise ValueError(f"mode must be 'train', 'eval' or 'test', got {mode!r}")

        # (B, 1, 16, T) -> (B, 16, 1, T): electrodes become conv channels.
        data = x.permute(0, 2, 1, 3)
        batch = data.size(0)

        channels = self._primary_caps(data, self.channelCapsTemporal_1, self.channelCapsTemporal_2, 16)
        channels = self.channelRouting(channels)  # (B, 16, 4, 8)
        channelsAlign = self.channel_align(channels).view(batch, 1, 16, 8)
        deep_channels = self.channelDeepSuper(channels.view(batch, 1, 64, 8))

        local = self._primary_caps(data, self.localCapsSpatial_1, self.localCapsSpatial_2, 8)
        local = self.localRouting(local)  # (B, 8, 4, 8)
        localsAlign = self.local_align(local).view(batch, 1, 16, 8)
        deep_locals = self.localDeepSuper(local.view(batch, 1, 32, 8))

        regions = self._primary_caps(data, self.regionCapsSpatial_1, self.regionCapsSpatial_2, 4)
        regions = self.regionRouting(regions)  # (B, 4, 4, 16)
        regionsAlign = self.region_align(regions).view(batch, 1, 16, 8)
        deep_regions = self.regionDeepSuper(regions.view(batch, 1, 16, 16))

        hemis = self._primary_caps(data, self.hemiCapsSpatial_1, self.hemiCapsSpatial_2, 2)
        hemis = self.hemiRouting(hemis)  # (B, 2, 4, 16)
        hemisAlign = self.hemi_align(hemis).view(batch, 1, 16, 8)

        # Single routing layer between the four scales and the output: the
        # aligned capsules are simply concatenated, without cross-scale mixing.
        concatenated = torch.cat((channelsAlign, localsAlign, regionsAlign, hemisAlign), dim=2)
        out = self.out(concatenated).squeeze(1)  # (B, num_classes, 64)

        pred = self.capsLen(out)

        if mode == "test":
            return pred
        if mode == "eval":
            return pred, self.generator(self.mask(out))

        # mode == "train": main head plus the three deep-supervision heads.
        deep_channels = deep_channels.squeeze(1)
        deep_locals = deep_locals.squeeze(1)
        deep_regions = deep_regions.squeeze(1)

        x = self.generator(self.mask(out, y_true))
        x_channels = self.channel_generator(self.mask(deep_channels, y_true))
        x_locals = self.local_generator(self.mask(deep_locals, y_true))
        x_regions = self.region_generator(self.mask(deep_regions, y_true))
        pred_channels = self.capsLen(deep_channels)
        pred_locals = self.capsLen(deep_locals)
        pred_regions = self.capsLen(deep_regions)

        return pred, x, pred_regions, x_regions, pred_locals, x_locals, pred_channels, x_channels

# **Training**

Below is the custom loss function for training the model

In [None]:
# @title
class MarginLoss(nn.Module):
    """Capsule margin loss.

    For the true class the capsule length is pushed above ``m_pos``; for
    every other class it is pushed below ``m_neg``, with the second term
    down-weighted by ``lambda_``.

    Args:
        m_pos: lower margin for the length of the correct class capsule.
        m_neg: upper margin for the length of the other class capsules.
        lambda_: weight of the absent-class term.
    """

    def __init__(self, m_pos=0.9, m_neg=0.1, lambda_=0.5):
        super(MarginLoss, self).__init__()
        self.m_pos = m_pos
        self.m_neg = m_neg
        self.lambda_ = lambda_

    def forward(self, targets, digit_probs):
        """Return the batch-mean margin loss.

        Args:
            targets: one-hot labels with as many elements as ``digit_probs``.
            digit_probs: capsule lengths of shape (batch, num_classes).

        Raises:
            ValueError: if ``targets`` cannot be matched element for element
                with ``digit_probs``.
        """
        # The previous `assert targets.shape is not digit_probs.shape` compared
        # object identity and could never fail. Check the shapes for real, and
        # accept targets that only differ by squeezed singleton axes.
        if targets.shape != digit_probs.shape:
            if targets.numel() != digit_probs.numel():
                raise ValueError(
                    f"targets shape {tuple(targets.shape)} does not match "
                    f"digit_probs shape {tuple(digit_probs.shape)}"
                )
            targets = targets.reshape(digit_probs.shape)

        present_losses = (
            targets * torch.clamp_min(self.m_pos - digit_probs, min=0.0) ** 2
        )
        absent_losses = (1 - targets) * torch.clamp_min(
            digit_probs - self.m_neg, min=0.0
        ) ** 2
        losses = present_losses + self.lambda_ * absent_losses
        return torch.mean(torch.sum(losses, dim=1))


class ReconstructionLoss(nn.Module):
    """Mean-squared reconstruction error, with special handling for complex data.

    For dtypes other than float32/float64 (the decoder in this file emits
    complex tensors) the loss is the MSE of the magnitudes plus the MSE of
    the phases; otherwise it is a plain MSE.
    """

    def forward(self, reconstructions, input_images):
        # A target that only differs by singleton axes, e.g. (B, C, T) against
        # a (B, 1, C, T) reconstruction, would otherwise be broadcast to
        # (B, B, C, T) and every reconstruction would be compared with every
        # sample of the batch. Align the shapes whenever the element counts agree.
        if (input_images.shape != reconstructions.shape
                and input_images.numel() == reconstructions.numel()):
            input_images = input_images.reshape(reconstructions.shape)

        mse = torch.nn.MSELoss(reduction="mean")
        if reconstructions.dtype not in (torch.float32, torch.float64):
            magnitude_term = mse(torch.abs(reconstructions), torch.abs(input_images))
            phase_term = mse(torch.angle(reconstructions), torch.angle(input_images))
            return magnitude_term + phase_term
        return mse(reconstructions, input_images)


class TotalLoss(nn.Module):
    """Margin loss plus a down-weighted reconstruction loss.

    Args:
        m_pos, m_neg, lambda_: forwarded to ``MarginLoss``.
        recon_factor: weight applied to the reconstruction term.
    """

    def __init__(self, m_pos=0.9, m_neg=0.1, lambda_=0.5, recon_factor=0.0005):
        super().__init__()
        self.margin_loss = MarginLoss(m_pos, m_neg, lambda_)
        self.recon_loss = ReconstructionLoss()
        self.recon_factor = recon_factor

    def forward(self, input_images, targets, reconstructions, digit_probs):
        # Singleton axes are squeezed out of the labels before the margin term.
        classification_term = self.margin_loss(targets.squeeze(), digit_probs)
        reconstruction_term = self.recon_loss(reconstructions, input_images)
        return classification_term + self.recon_factor * reconstruction_term


Run this training loop. It may take a while even with the T4 Runtime!

In [None]:
!pip install --upgrade sympy

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
# Weights of the four (prediction, reconstruction) pairs returned by the model
# in train mode - main head first, then regions, locals, channels - decaying
# by a factor of 1.5 and normalised to sum to 1.
deep_super_weights = np.array([1 / (1.5 ** i) for i in range(4)])
deep_super_weights = deep_super_weights / np.sum(deep_super_weights)
model = EEGCapsNet().to(device)
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0005)
criterion = TotalLoss()
num_epochs = 10

train_losses = []
val_losses = []
accs = []
max_acc = 0
for epoch in range(num_epochs):
    model.train()
    pbar_batch = tqdm(total=len(train_dl))
    total_loss = 0
    # Running counters for the whole epoch, so the displayed accuracy is
    # averaged the same way as the displayed loss.
    correct_train = 0
    total_train = 0
    for batch, (data, labels) in enumerate(train_dl):
        data = data.unsqueeze(1).to(device)
        labels = labels.to(device)
        # The reconstruction target is the FFT of the tensor that is fed to
        # the model, so it has the same shape as the decoder output.
        x_fft = torch.fft.fft(data, dim=-1)
        optimizer.zero_grad()
        # Pass the labels so the decoders reconstruct from the capsule of the
        # true class during training.
        outs = model(data, labels)
        full_loss = 0
        for k in range(0, 8, 2):
            loss = criterion(x_fft, labels, outs[k+1], outs[k])
            full_loss = full_loss + float(deep_super_weights[k // 2]) * loss
        full_loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm = 5)
        optimizer.step()
        total_loss += full_loss.item()
        predicted = torch.argmax(outs[0], -1)
        target_idx = torch.argmax(labels, -1)
        correct_train += (predicted == target_idx).float().sum().item()
        total_train += target_idx.shape[0]
        pbar_batch.set_description(f"Batch {batch + 1}    loss={total_loss / (batch + 1):0.4f}      accuracy={correct_train/total_train:0.4f}")
        pbar_batch.update(1)
    pbar_batch.close()
    train_losses.append(total_loss / len(train_dl))
    total_val_loss = 0.0
    total_accuracy = 0.0
    # set model to evaluation mode, which changes the behavior
    # of some layers like dropout and batch normalization
    model.eval()
    with torch.no_grad():
        pbar = tqdm(total=len(valid_dl))
        for j, (x, y) in enumerate(valid_dl):
            # Use the same device as training instead of assuming CUDA.
            x = x.float().to(device).unsqueeze(1)
            y = y.to(device)
            pred, img = model(x, mode='eval')
            x_fft = torch.fft.fft(x, dim=-1)
            # Only the main head is evaluated here, so this value is not
            # directly comparable with the weighted four-head training loss.
            loss = criterion(x_fft, y, img, pred)
            total_val_loss += loss.item()
            predicted = torch.argmax(pred, -1)
            labels = torch.argmax(y, -1)
            accuracy = (predicted == labels).float().mean().item()
            total_accuracy += accuracy
            pbar.set_description(f"val loss={total_val_loss / (j + 1):0.4f}    val acc={total_accuracy / (j + 1):0.4f}")
            pbar.update(1)
        pbar.close()
        val_losses.append(total_val_loss/len(valid_dl))
        accs.append(total_accuracy/len(valid_dl))
    # save model if accuracy is best seen
    if accs[-1] > max_acc:
        # dictionary with model state dict, optimizer state dict, and best accuracy
        checkpoint = {'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict(), 'best_acc': accs[-1]}
        # save dictionary to specified file path if it exists or create new one otherwise
        torch.save(checkpoint, 'SepConv2.pth.tar')
        print('Model Saved')
        max_acc = accs[-1]

  0%|          | 0/3557 [00:00<?, ?it/s]

  return F.mse_loss(input, target, reduction=self.reduction)


  0%|          | 0/553 [00:00<?, ?it/s]

Model Saved


  0%|          | 0/3557 [00:00<?, ?it/s]

  0%|          | 0/553 [00:00<?, ?it/s]

Model Saved


  0%|          | 0/3557 [00:00<?, ?it/s]

  0%|          | 0/553 [00:00<?, ?it/s]

Model Saved


  0%|          | 0/3557 [00:00<?, ?it/s]

  0%|          | 0/553 [00:00<?, ?it/s]

Model Saved


  0%|          | 0/3557 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:

# Loss curves. The plotted values come from TotalLoss (margin loss plus
# weighted reconstruction loss), not from a cross-entropy.
plt.figure()
plt.plot(train_losses, label = 'train loss')
plt.plot(val_losses, label = 'val loss')
plt.legend()
plt.ylabel('Margin + Reconstruction Loss')
plt.xlabel('Epoch')
plt.title('Loss over Epochs')


# `accs` holds the per-epoch accuracy measured on the validation loader.
plt.figure()
plt.plot(accs)
plt.ylabel('Validation Accuracy')
plt.xlabel('Epoch')

# **Interpretability**


In [None]:
!pip install captum
from captum.attr import LayerGradCam, LayerAttribution

def visualize_importances(feature_names, importances, title="Average Feature Importances", plot=True, axis_title="Features"):
    """Print one importance value per feature and optionally draw a bar chart.

    Args:
        feature_names: sequence of labels, one per feature.
        importances: sequence of numeric scores indexed like ``feature_names``.
        title: heading that is printed and used as the chart title.
        plot: when truthy, draw the bar chart on a new figure.
        axis_title: label of the x axis.
    """
    print(title)
    for position, feature in enumerate(feature_names):
        score = importances[position]
        print(feature, ": ", "%.3f" % score)

    bar_positions = np.arange(len(feature_names))
    if not plot:
        return

    plt.figure(figsize=(12, 6))
    plt.bar(bar_positions, importances, align='center')
    plt.xticks(bar_positions, feature_names, wrap=True)
    plt.xlabel(axis_title)
    plt.title(title)
# LayerConductance is not part of the `from captum.attr import ...` line at the
# top of this cell, so import it here; without it the next statement raises
# NameError.
from captum.attr import LayerConductance

conductance = LayerConductance(model, model.out)
#test_input_tensor = torch.from_numpy(test_features).type(torch.FloatTensor)
#conductance_vals = conductance.attribute(data.to(device), target=1).detach().numpy()
#visualize_importances(range(12),np.mean(cond_vals, axis=0),title="Average Neuron Importances", axis_title="Neurons")


# Grad-CAM on the output routing layer. `data` and `labels` are whatever the
# training cell left behind. The model is called in 'test' mode, where it
# returns only the class scores, and target=1 selects the score at index 1.
layer_gradcam = LayerGradCam(model, model.out)

attr = layer_gradcam.attribute(data.to(device), target=1, additional_forward_args=(labels.to(device), 'test'))
attr = attr.squeeze()

In [None]:
# Assuming 'attr' contains the attributions and has shape (channels, time)
attr_abs = torch.abs(attr) # taking the absolute values
# Average over axis 1 of the squeezed attribution tensor.
# NOTE(review): the exact meaning of the remaining axes depends on the shape
# Captum returns for model.out - confirm before interpreting the curves.
attr_mean = torch.mean(attr_abs, dim=1)
plt.plot(attr_mean.cpu().detach().numpy())
plt.xlabel('Data Point Index')
plt.ylabel('Attribution')
plt.title('LayerGradCAM Attributions')


#INPROGRESS: get the color key to work
# Get the colormap. plt.get_cmap is used because matplotlib.cm.get_cmap
# (reached through plt.cm.get_cmap) was removed in Matplotlib 3.9.
cmap = plt.get_cmap('viridis')

# Define the number of color patches in the key
num_patches = 5

# Generate color patches and value ranges
values = np.linspace(attr_mean.min().item(), attr_mean.max().item(), num_patches)
colors = cmap(np.linspace(0, 1, num_patches))

# Create a custom legend with color patches and labels
patches = [
    plt.Rectangle((0, 0), 1, 1, color=colors[i], label=f'{values[i]:.2f} - {values[i+1]:.2f}' if i < num_patches - 1 else f'{values[i]:.2f}+')
    for i in range(num_patches)
]

plt.legend(handles=patches, loc='best', title='Attribution Values')

plt.show()
# for i in range(attr_mean.shape[0]):
#     plt.figure()
#     plt.imshow(attr_mean[i].cpu().detach().numpy(), cmap='viridis', aspect='auto', interpolation='nearest')
#     plt.colorbar(label='Attribution')
#     plt.xlabel('Time')
#     plt.ylabel('Channels')
#     plt.title(f'LayerGradCAM Attributions for Sample {i}')
#     plt.show()
