In [None]:
import numpy as np
import glob
import os
import pandas as pd
import scipy.signal as signal
import mne
import random
import torchvision
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.datasets import DatasetFolder

from mne import preprocessing, Epochs
import utils
import matplotlib.pyplot as plt

## Data pre-processing

### Band pass filtering and events

In [None]:
# Directory holding the .gdf recordings. Both glob patterns below are built
# from it so the training and evaluation files always come from the same place.
eeg_file_path = 'data/eeg_data_A/'

# glob returns matches in arbitrary (filesystem-dependent) order; sorting makes
# the processing order identical from run to run.
eeg_training_files = sorted(glob.glob(os.path.join(eeg_file_path, 'A0*T.gdf')))

# Uncomment to work on a random subset of two training recordings:
#eeg_training_files = random.sample(eeg_training_files,2)

eeg_eval_files = sorted(glob.glob(os.path.join(eeg_file_path, 'A0*E.gdf')))

# utils.band_pass_filter returns a pair for each list of files.
eeg_train_obj, epoch_train_obj = utils.band_pass_filter(eeg_training_files)
eeg_eval_obj, epoch_eval_obj = utils.band_pass_filter(eeg_eval_files)


In [None]:
# Convert the filtered objects to tensors: training first, then evaluation.
eeg_data = utils.raw_to_tensor(eeg_train_obj)
eeg_test_data = utils.raw_to_tensor(eeg_eval_obj)

# Copy of the evaluation tensor with its three axes reversed (0 <-> 2).
# The matching transpose / expand_dims steps for the training tensor are
# currently disabled, so eeg_data keeps the layout returned by raw_to_tensor.
eeg_test_set = np.transpose(eeg_test_data, (2, 1, 0))

# Size argument handed to utils.split_tensor for every item.
split_size = 1000

# Split each item of eeg_data and gather all pieces in one flat list.
smaller_tensors = [
    piece
    for source_tensor in eeg_data
    for piece in utils.split_tensor(source_tensor, split_size)
]

print(len(smaller_tensors))

In [None]:

class EEGDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of samples.

    Args:
        data: Any object supporting ``len()`` and integer indexing. It is
            stored as-is (no copy) on ``self.data``.
    """

    def __init__(self, data):
        self.data = data

    def __getitem__(self, idx):
        """Return the sample stored at position ``idx``."""
        sample = self.data[idx]
        return sample

    def __len__(self):
        """Return how many samples the wrapped collection holds."""
        n_samples = len(self.data)
        return n_samples
    
# Wrap the list of split tensors in a Dataset so it can be used with torch's
# data utilities. This rebinds `eeg_data`, which until now held the tensor
# returned by utils.raw_to_tensor.
eeg_data= EEGDataset(smaller_tensors)

# Pass every sample through the augmentation helpers from utils.
# NOTE(review): `tensor` is only a loop-local name; the values returned by
# time_shift / add_noise / time_warp are never written back into `eeg_data`
# or stored anywhere else. Unless those helpers modify their argument in
# place (not visible from here), this loop leaves the dataset unchanged.
# Confirm whether the augmented tensors were meant to be kept.
for tensor in eeg_data:
    tensor = utils.time_shift(tensor, shift=10)
    tensor = utils.add_noise(tensor, noise_level=0.1)
    tensor = utils.time_warp(tensor, factor=0.8)

In [None]:
# Largest first-dimension size among the samples in eeg_data.
# NOTE(review): no active code in view reads this value; it only fed the
# padding step that is disabled below.
max_length = max(sample.shape[0] for sample in eeg_data)

# Disabled padding step. It is kept as comments instead of a bare
# triple-quoted string: a string literal is still evaluated and, being the
# last expression of the cell, would be echoed as the cell's output.
#
# padded_tensors = []
# for tensor in eeg_data:
#     padding_size = max_length - tensor.shape[0]
#     if padding_size > 0:
#         padded_tensor = torch.nn.functional.pad(tensor, (0, 0, 0, 0, 0, padding_size))
#     else:
#         padded_tensor = tensor
#     padded_tensors.append(padded_tensor)
#
# split_ratio = 0.66

# Disabled alternative: split with the split_data function from utils.py
# eeg_train_set, eeg_val_set = utils.split_data(eeg_data, eeg_training_files, .8)


In [None]:


# Share of the dataset held out for validation. The previous hard-coded split
# was 2500 training / 224 validation items (2724 in total); this fraction
# reproduces exactly those sizes for a 2724-item dataset, but no longer raises
# when the number of items changes (e.g. when a subset of files is loaded).
VAL_FRACTION = 224 / 2724
SPLIT_SEED = 0  # fixed seed so the same items land in each subset on every run

n_total = len(eeg_data)
n_val = round(n_total * VAL_FRACTION)
n_train = n_total - n_val  # by construction n_train + n_val == n_total

eeg_train_set, eeg_val_set = random_split(
    eeg_data,
    [n_train, n_val],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


print("Number of items in training set:", len(eeg_train_set))
print("Shape of first item in training set:", eeg_train_set[0].shape)

print("Number of items in validation set:", len(eeg_val_set))
print("Shape of first item in validation set:", eeg_val_set[0].shape)

In [None]:


BATCH_SIZE = 70

# Only the training data is shuffled. The validation loss is an average over
# all batches, so visiting the validation items in a fixed order is enough.
train_loader = DataLoader(eeg_train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(eeg_val_set, batch_size=BATCH_SIZE, shuffle=False)
# NOTE(review): the test loader is built from eeg_test_data; the transposed
# copy eeg_test_set created earlier is not used here. Confirm which one is
# intended.
test_loader = DataLoader(eeg_test_data, batch_size=BATCH_SIZE, shuffle=False)

# Sanity check: look at a single batch instead of iterating over (and
# printing the full contents of) every batch in the training set.
first_batch = next(iter(train_loader), None)
print("Number of training batches:", len(train_loader))
if first_batch is not None:
    print("First training batch shape:", tuple(first_batch.shape))
    print("First training batch dtype:", first_batch.dtype)


In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Transformer hyper-parameters. embed_size is also the size of the vector
# produced by ConvNet for each input window.
embed_size = 25
nhead = 5
num_layers = 1


class ConvNet(nn.Module):
    """Convolutional front-end mapping one EEG window to an embedding vector.

    A convolution along the last (sample) axis is followed by a convolution
    that spans the whole channel axis, an ELU, and a linear projection.

    Args:
        n_channels: Size of the channel axis of the input (height of the
            second convolution's kernel). Defaults to 25.
        n_samples: Size of the last axis of the input. Defaults to 1000.
        n_filters: Number of feature maps of both convolutions. Defaults to 40.
        kernel_width: Kernel size of the first convolution along the last
            axis. Defaults to 4.
        embed_dim: Size of the output vector. ``None`` (default) uses the
            module-level ``embed_size``.

    Raises:
        ValueError: If ``n_samples`` is smaller than ``kernel_width``.
    """

    def __init__(self, n_channels=25, n_samples=1000, n_filters=40,
                 kernel_width=4, embed_dim=None):
        super().__init__()
        if embed_dim is None:
            embed_dim = embed_size
        if n_samples < kernel_width:
            raise ValueError(
                f"n_samples ({n_samples}) must be >= kernel_width ({kernel_width})"
            )
        self.n_channels = n_channels
        self.n_samples = n_samples

        self.conv1 = nn.Conv2d(1, n_filters, (1, kernel_width), (1, 1))
        self.conv2 = nn.Conv2d(n_filters, n_filters, (n_channels, 1), (1, 1))
        self.elu1 = nn.ELU()

        self.flatten = nn.Flatten()
        # Unpadded, stride-1 convolutions: conv1 shrinks the last axis to
        # n_samples - kernel_width + 1 (997 with the defaults) and conv2
        # collapses the channel axis to height 1.
        conv_width = n_samples - kernel_width + 1
        output_size = n_filters * 1 * conv_width
        self.fc1 = nn.Linear(output_size, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of windows.

        Args:
            x: Tensor of shape ``[N, n_channels, n_samples]``.

        Returns:
            Tensor of shape ``[N, embed_dim]``.

        Raises:
            ValueError: If ``x`` does not have the expected shape.
        """
        if x.dim() != 3 or tuple(x.shape[1:]) != (self.n_channels, self.n_samples):
            raise ValueError(
                f"expected input of shape [N, {self.n_channels}, {self.n_samples}], "
                f"got {tuple(x.shape)}"
            )
        x = x.unsqueeze(1)  # becomes [N, 1, n_channels, n_samples]
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.elu1(x)
        x = self.flatten(x)
        x = self.fc1(x)
        return x


In [None]:


# Convolutional front-end that turns each input window into an embedding.
conv_net = ConvNet().to(device)

# Both transformer layer types share the same width and number of heads.
layer_config = {"d_model": embed_size, "nhead": nhead}
encoder_layer = nn.TransformerEncoderLayer(**layer_config).to(device)
decoder_layer = nn.TransformerDecoderLayer(**layer_config).to(device)

# Stack num_layers copies of each layer.
transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers).to(device)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers).to(device)

# Parameters handed to the optimizer: encoder weights followed by decoder
# weights.
# NOTE(review): conv_net.parameters() is not part of this list, so conv_net
# keeps its initial weights during training. Confirm this is intended.
model_params = [
    *transformer_encoder.parameters(),
    *transformer_decoder.parameters(),
]


In [None]:

NUM_EPOCHS = 100

criterion = nn.MSELoss()
optimizer = optim.SGD(model_params,
                        lr=0.001,
                        momentum=0.9)

# Training loss of every batch, in order; plotted after training.
losses = []


def _reconstruction_loss(batch):
    """Return the MSE between the embedded batch and its reconstruction.

    The batch is moved to `device`, embedded by conv_net, encoded, and decoded
    with the embedding itself as the decoder target sequence.
    """
    # conv_net's weights are not in the optimizer's parameter list, so it acts
    # as a fixed feature extractor: skipping autograd here leaves the
    # transformer updates unchanged while avoiding gradients nobody consumes.
    with torch.no_grad():
        embedded = conv_net(batch.to(device))
    # NOTE(review): conv_net returns a 2-D [N, embed_size] tensor. PyTorch's
    # transformer modules interpret a 2-D input as one unbatched sequence, so
    # attention runs across the items of the batch. Confirm this is intended.
    memory = transformer_encoder(embedded)
    reconstruction = transformer_decoder(embedded, memory)
    return criterion(reconstruction, embedded)


for epoch in range(NUM_EPOCHS):
    # --- training: dropout active, gradients tracked ---
    transformer_encoder.train()
    transformer_decoder.train()
    train_loss = 0.0
    for batch in train_loader:
        optimizer.zero_grad()
        loss = _reconstruction_loss(batch)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        train_loss += batch_loss
        losses.append(batch_loss)

    # --- validation: dropout disabled, no gradient bookkeeping ---
    transformer_encoder.eval()
    transformer_decoder.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            val_loss += _reconstruction_loss(batch).item()

    # max(..., 1) avoids a ZeroDivisionError if a loader has no batches.
    print("Epoch: {} Train Loss: {} Val Loss: {}".format(
                  epoch,
                  train_loss/max(len(train_loader), 1),
                  val_loss/max(len(val_loader), 1)))

plt.plot(losses)
plt.xlabel('Batch')
plt.ylabel('Loss')
plt.title('Loss per Batch')
plt.show()


Attention Block

Transformer Block

Encoder