In [15]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
'/content/drive/MyDrive/Columbia Spring 2023/Signal Modeling/Project-EEG-Classifier/BMEN4420-EEG-Classifier-Repo'

In [88]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, SGD
import random
import scipy.io as scio
from torch.utils.data import ConcatDataset, Dataset, DataLoader, random_split, RandomSampler
import numpy as np


#from Models.Transformer import TransformerModel
#from Models.PositionalEncoding import LearnedPositionalEncoding


In [83]:
# CHECK GPU RESOURCES
cuda = torch.cuda.is_available()
print("GPU available:", cuda)

torch.manual_seed(4460)# you don't have to set random seed beyond this block
np.random.seed(4460)

GPU available: False


In [17]:
sub01 = scio.loadmat('/content/drive/MyDrive/Columbia Spring 2023/Signal Modeling/Project-EEG-Classifier/Signal_Processing_FC/Signal_Processing_FC/Subject_1.mat')
sub02 = scio.loadmat('/content/drive/MyDrive/Columbia Spring 2023/Signal Modeling/Project-EEG-Classifier/Signal_Processing_FC/Signal_Processing_FC/Subject_2.mat')
sub03 = scio.loadmat('/content/drive/MyDrive/Columbia Spring 2023/Signal Modeling/Project-EEG-Classifier/Signal_Processing_FC/Signal_Processing_FC/Subject_3.mat')
sub04 = scio.loadmat('/content/drive/MyDrive/Columbia Spring 2023/Signal Modeling/Project-EEG-Classifier/Signal_Processing_FC/Signal_Processing_FC/Subject_4.mat')
sub05 = scio.loadmat('/content/drive/MyDrive/Columbia Spring 2023/Signal Modeling/Project-EEG-Classifier/Signal_Processing_FC/Signal_Processing_FC/Subject_5.mat')
sub06 = scio.loadmat('/content/drive/MyDrive/Columbia Spring 2023/Signal Modeling/Project-EEG-Classifier/Signal_Processing_FC/Signal_Processing_FC/Subject_6.mat')
sub07 = scio.loadmat('/content/drive/MyDrive/Columbia Spring 2023/Signal Modeling/Project-EEG-Classifier/Signal_Processing_FC/Signal_Processing_FC/Subject_7.mat')
sub08 = scio.loadmat('/content/drive/MyDrive/Columbia Spring 2023/Signal Modeling/Project-EEG-Classifier/Signal_Processing_FC/Signal_Processing_FC/Subject_8.mat')
data = {'sub01':sub01,'sub02':sub02,'sub03':sub03,'sub04':sub04,'sub05':sub05,'sub06':sub06,'sub07':sub07,'sub08':sub08}


In [103]:
class EEGData():
  def __init__(self, sample, label):
    self.x = sample
    self.y = label
    self.indices = list(range(np.size(self.y,0)))
  def __getitem__(self, index):
    return self.x[self.indices[index]], self.y[self.indices[index]]
  def shuffle(self):
    random.shuffle(self.indices)
  def __len__(self):
    return (np.size(self.y,0))

In [93]:
class EEGPT(nn.Module):
  def __init__(
      self,
      eeg_channels = 60,
      time_len = 1200
               ):
    super(EEGPT,self).__init__()
    # BUILD SPATIAL PATH
    ## CNN MODULE
    self.Conv1_s = nn.Conv1d(in_channels=eeg_channels, out_channels=eeg_channels, kernel_size=17, stride=1, padding="same")
    self.AvgPool1_s = nn.AvgPool1d(kernel_size=32,stride=32)
    self.Conv2_s = nn.Conv1d(in_channels=eeg_channels,out_channels=eeg_channels,kernel_size=15,stride=1,padding="valid") # output should be 
    ## TRANSFORMER MODULE
    self.PosEnc1_s = PositionalEncoder(embedding_dim=eeg_channels,max_length=1000)
    self.Transf1_s = EncoderTransformer(inSize=eeg_channels,outSize=4,numLayers=3,hiddenSize=1,numHeads=6,dropout=0.01)

    # BUILD TEMPORAL PATH
    # CNN MODULE
    self.dwconv1_t = nn.Conv1d(in_channels=eeg_channels,out_channels=eeg_channels, kernel_size=eeg_channels, stride=1, groups = eeg_channels, bias=False, padding="same")
    self.AvgPool1_t = nn.AvgPool2d(kernel_size=8)    
    # TRANSFORMER MODULE
    self.PosEnc1_t = PositionalEncoder(embedding_dim=eeg_channels,max_length=1000)
    self.Transf1_t = EncoderTransformer(inSize=time_len,outSize=4,numLayers=3,hiddenSize=1,numHeads=6,dropout=0.01)
    # Build Fully Connected Path
    self.fc1 = nn.Linear(8,2)

  def forward(self, x):
    # Spatial Pass
    x_s = self.Conv1_s(x_s)
    x_s = self.AvgPool1_s(x_s)
    x_s = self.Conv2_s(x_s)
    x_s = self.PosEnc1_s(x_s)
    x_s = self.Transf1_s(x_s)
    # Temporal Pass
    x_t = self.dwconv1_t(x)
    x_t = self.AvgPool1_t(x_t)
    x_t = x_t.permute(0,2,1) # transpose to present time wise vectors to transformer encoder    
    x_t = self.PosEnc1_t(x_t)
    x_t = self.Transf1_t(x_t)
    # Concatenation
    x_cat = torch.cat(x_s, x_t)
    # Output Pass: Fully Connected into Softmax
    x = self.fc1(x_cat)
    x = F.log_softmax(x)
    return x

class EncoderTransformer():
  def __init__(self, inSize, outSize, numLayers=3, hiddenSize=1, numHeads=8, dropout=0.01):
    self.encoderLayer = nn.TransformerEncoderLayer(d_model=inSize, nhead=numHeads, dim_feedforward=hiddenSize, dropout=dropout)
    self.encoder = nn.TransformerEncoder(self.encoderLayer,num_layers=numLayers)
    self.fc1 = nn.Linear(outSize, outSize)
  def forward(self, x):
    x = self.encoder(x)
    x = self.fc1(x)
    return x

class PositionalEncoder(nn.Module):
  def __init__(self, embedding_dim, max_length=1000):
    super(PositionalEncoder,self).__init__()
    pe = torch.zeros(max_length, embedding_dim)
    position = torch.arange(0, max_length,dtype=float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, embedding_dim, 2).float()
        * (-torch.log(torch.tensor(10000.0))/embedding_dim)
    )
    pe[:,0::2] = torch.sin(position * div_term)
    pe[:,1::2] = torch.cos(position * div_term)
    pe.unsqueeze(0).transpose(0,1)
    self.register_buffer('pe',pe)

  def forward(self, x):
    return x + self.pe[:x.size(0),:]



In [104]:
# COMPOSE MEGA DATASET FROM ALL SUBJECT TENSORS
numSets = 8
megaSet = []
for i in range(numSets):
  subSetX = data[('sub0'+str(i+1))]['X_EEG_TRAIN']
  subSetY = data[('sub0'+str(i+1))]['Y_EEG_TRAIN']
  #print(np.size(subSetY,0))
  for j in range(np.size(subSetY,0)):    
    subx = subSetX[:,:,j]    
    suby = subSetY[j,:]
    miniSet = EEGData(subx,suby)
    megaSet.append(miniSet)
    
    # DEBUGGING PRINTS
    #print(np.size(subSetY,0))
    #print(np.shape(subSetX))
    #print(np.shape(subSetY))
    #print(miniSet.__len__())

MegaSet = ConcatDataset(megaSet)
#MegaSet = RandomSampler(MegaSet)


# Load Dataset using EEGData and Dataloader
trainset, validset = random_split(MegaSet,[458, 115])
trainloader = DataLoader(trainset,batch_size=3,shuffle=True)
validloader = DataLoader(validset,batch_size=3,shuffle=True)

In [105]:
# Build Model
eegpt = EEGPT(eeg_channels=60, time_len=1200)

# Call Optimizer
adam = Adam(eegpt.parameters(),lr=0.001)

In [106]:
# COUNT MODEL PARAMETERS
param_count = 0;
for param in eegpt.parameters():
    param_count += param.numel()

print('number of model params: ', param_count)

number of model params:  118938


In [107]:
# MODEL TRAINING
EPOCHS = 25
train_epoch_loss = list()
validation_epoch_loss = list()
for epoch in range(EPOCHS):
  train_loss = list()
  valid_loss = list()
  eegpt.train() # put model in train mode
  for batch_index, (sample, label) in enumerate(trainloader):
    if cuda:
      train_pred = eegpt(sample.cuda())
      # calculate loss
      loss_fun = nn.CrossEntropyLoss()
      loss = loss_fun(train_pred, label.cuda())
      train_loss.append(loss.cpu().data.item())
      # reset gradient
      adam.zero_grad()
      # back propagation
      loss.backward()
      # Update parameters
      adam.step()
train_epoch_loss.append(np.mean(train_loss))

  return _methods._mean(a, axis=axis, dtype=dtype,
  ret = ret.dtype.type(ret / rcount)


In [None]:
#@title TABS REFERENCE

class up_conv_3D(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(up_conv_3D, self).__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor = 2),
            nn.Conv3d(ch_in, ch_out, kernel_size = 3, stride = 1, padding = 1, bias = True),
            nn.GroupNorm(8, ch_out),
            # nn.BatchNorm3d(ch_out),
            nn.ReLU(inplace = True)
        )

    def forward(self,x):
        x = self.up(x)
        return x


class conv_block_3D(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(conv_block_3D, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(ch_in, ch_out, kernel_size = 3, stride = 1, padding = 1, bias = True),
            nn.GroupNorm(8, ch_out),
            nn.ReLU(inplace = True),
            nn.Conv3d(ch_out, ch_out, kernel_size = 3, stride = 1, padding = 1, bias = True),
            nn.GroupNorm(8, ch_out),
            nn.ReLU(inplace = True)
        )

    def forward(self,x):
        x = self.conv(x)
        return x

class resconv_block_3D(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(resconv_block_3D, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv3d(ch_in, ch_out, kernel_size = 3, stride = 1, padding = 1, bias = True),
            nn.GroupNorm(8, ch_out),
            nn.ReLU(inplace = True),
            nn.Conv3d(ch_out, ch_out, kernel_size = 3, stride = 1, padding = 1, bias = True),
            nn.GroupNorm(8, ch_out),
            nn.ReLU(inplace = True)
        )
        self.Conv_1x1 = nn.Conv3d(ch_in, ch_out, kernel_size = 1, stride = 1, padding = 0)

    def forward(self,x):

        residual = self.Conv_1x1(x)
        x = self.conv(x)
        return residual + x

# Can add squeeze excitation layers if you want to try that as well.
class ChannelSELayer3D(nn.Module):
    """
    3D extension of Squeeze-and-Excitation (SE) block described in:
        *Hu et al., Squeeze-and-Excitation Networks, arXiv:1709.01507*
        *Zhu et al., AnatomyNet, arXiv:arXiv:1808.05238*
    """

    def __init__(self, num_channels, reduction_ratio=8):
        """
        :param num_channels: No of input channels
        :param reduction_ratio: By how much should the num_channels should be reduced
        """
        super(ChannelSELayer3D, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool3d(1)
        num_channels_reduced = num_channels // reduction_ratio
        self.reduction_ratio = reduction_ratio
        self.fc1 = nn.Linear(num_channels, num_channels_reduced, bias=True)
        self.fc2 = nn.Linear(num_channels_reduced, num_channels, bias=True)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_tensor):
        """
        :param input_tensor: X, shape = (batch_size, num_channels, D, H, W)
        :return: output tensor
        """
        batch_size, num_channels, D, H, W = input_tensor.size()
        # Average along each channel
        squeeze_tensor = self.avg_pool(input_tensor)

        # channel excitation
        fc_out_1 = self.relu(self.fc1(squeeze_tensor.view(batch_size, num_channels)))
        fc_out_2 = self.sigmoid(self.fc2(fc_out_1))

        output_tensor = torch.mul(input_tensor, fc_out_2.view(batch_size, num_channels, 1, 1, 1))

        return output_tensor

class TABS(nn.Module):
    def __init__(
        self,
        img_dim = 192,
        patch_dim = 8,
        img_ch = 1,
        output_ch = 3,
        embedding_dim = 512,
        num_heads = 8,
        num_layers = 4,
        hidden_dim = 1728,
        dropout_rate = 0.1,
        attn_dropout_rate = 0.1,
        ):
        super(TABS,self).__init__()

        self.Maxpool = nn.MaxPool3d(kernel_size=2,stride=2)

        self.Conv1 = resconv_block_3D(ch_in=img_ch,ch_out=8)

        self.Conv2 = resconv_block_3D(ch_in=8,ch_out=16)

        self.Conv3 = resconv_block_3D(ch_in=16,ch_out=32)

        self.Conv4 = resconv_block_3D(ch_in=32,ch_out=64)

        self.Conv5 = resconv_block_3D(ch_in=64,ch_out=128)

        self.Up5 = up_conv_3D(ch_in=128,ch_out=64)
        self.Up_conv5 = resconv_block_3D(ch_in=128, ch_out=64)

        self.Up4 = up_conv_3D(ch_in=64,ch_out=32)
        self.Up_conv4 = resconv_block_3D(ch_in=64, ch_out=32)

        self.Up3 = up_conv_3D(ch_in=32,ch_out=16)
        self.Up_conv3 = resconv_block_3D(ch_in=32, ch_out=16)

        self.Up2 = up_conv_3D(ch_in=16,ch_out=8)
        self.Up_conv2 = resconv_block_3D(ch_in=16, ch_out=8)

        self.Conv_1x1 = nn.Conv3d(8,output_ch,kernel_size=1,stride=1,padding=0)
        self.gn = nn.GroupNorm(8, 128)
        self.relu = nn.ReLU(inplace=True)

        self.num_patches = int((img_dim // patch_dim) ** 3)
        self.seq_length = self.num_patches
        self.flatten_dim = 128 * img_ch

        self.position_encoding = LearnedPositionalEncoding(
            self.seq_length, embedding_dim, self.seq_length
        )

        self.act = nn.Softmax(dim=1)

        self.reshaped_conv = conv_block_3D(512, 128)

        self.transformer = TransformerModel(
            embedding_dim,
            num_layers,
            num_heads,
            hidden_dim,

            dropout_rate,
            attn_dropout_rate,
        )

        self.conv_x = nn.Conv3d(
            128,
            embedding_dim,
            kernel_size=3,
            stride=1,
            padding=1
            )

        self.pre_head_ln = nn.LayerNorm(embedding_dim)

        self.img_dim = 192
        self.patch_dim = 8
        self.img_ch = 1
        self.output_ch = 3
        self.embedding_dim = 512

    def forward(self,x):
        # encoding path
        x1 = self.Conv1(x)

        x2 = self.Maxpool(x1)
        x2 = self.Conv2(x2)

        x3 = self.Maxpool(x2)
        x3 = self.Conv3(x3)

        x4 = self.Maxpool(x3)
        x4 = self.Conv4(x4)

        x5 = self.Maxpool(x4)
        x = self.Conv5(x5)

        x = self.gn(x)
        x = self.relu(x)
        x = self.conv_x(x)

        x = x.permute(0, 2, 3, 4, 1).contiguous()
        x = x.view(x.size(0), -1, self.embedding_dim)

        x = self.position_encoding(x)

        x, intmd_x = self.transformer(x)
        x = self.pre_head_ln(x)

        encoder_outputs = {}
        all_keys = []
        for i in [1, 2, 3, 4]:
            val = str(2 * i - 1)
            _key = 'Z' + str(i)
            all_keys.append(_key)
            encoder_outputs[_key] = intmd_x[val]
        all_keys.reverse()

        x = encoder_outputs[all_keys[0]]
        x = self._reshape_output(x)
        x = self.reshaped_conv(x)

        d5 = self.Up5(x)
        d5 = torch.cat((x4,d5),dim=1)
        d5 = self.Up_conv5(d5)

        d4 = self.Up4(d5)
        d4 = torch.cat((x3,d4),dim=1)
        d4 = self.Up_conv4(d4)

        d3 = self.Up3(d4)
        d3 = torch.cat((x2,d3),dim=1)
        d3 = self.Up_conv3(d3)

        d2 = self.Up2(d3)
        d2 = torch.cat((x1,d2),dim=1)
        d2 = self.Up_conv2(d2)

        d1 = self.Conv_1x1(d2)

        d1 = self.act(d1)

        return d1

    def _reshape_output(self, x):
        x = x.view(
            x.size(0),
            int(self.img_dim//2 / self.patch_dim),
            int(self.img_dim//2 / self.patch_dim),
            int(self.img_dim//2 / self.patch_dim),
            self.embedding_dim,
        )
        x = x.permute(0, 4, 1, 2, 3).contiguous()

        return x
