In [None]:
import numpy as np
import glob
import os
import pandas as pd
import scipy.signal as signal
import mne
import torchvision
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision.datasets import DatasetFolder

from mne import preprocessing, Epochs
import utils
import matplotlib.pyplot as plt

## Data pre-processing

### Band pass filtering and events

In [None]:
# Locate the recordings: files matching A0*T.gdf are used for training and
# A0*E.gdf for evaluation. glob() returns paths in arbitrary filesystem order, so
# sort them to make the sample order (and any split built on it) reproducible.
eeg_file_path = 'data/eeg_data_A/'
eeg_training_files = sorted(glob.glob(os.path.join(eeg_file_path, 'A0*T.gdf')))
eeg_eval_files = sorted(glob.glob(os.path.join(eeg_file_path, 'A0*E.gdf')))
assert eeg_training_files, f"no training files matching A0*T.gdf in {eeg_file_path}"
assert eeg_eval_files, f"no evaluation files matching A0*E.gdf in {eeg_file_path}"

# Band-pass filter via utils.py. It returns a pair, named here as the filtered
# recording object and an epochs object (semantics defined in utils.py).
eeg_train_obj, epoch_train_obj = utils.band_pass_filter(eeg_training_files)
eeg_eval_obj, epoch_eval_obj = utils.band_pass_filter(eeg_eval_files)

eeg_data = utils.raw_to_tensor(eeg_train_obj)
eeg_test_data = utils.raw_to_tensor(eeg_eval_obj)

# Reverse the axis order (0,1,2) -> (2,1,0) identically for train and test so both
# sets share one layout. NOTE(review): the meaning of each axis is determined by
# utils.raw_to_tensor; presumably the sample axis ends up first — confirm.
eeg_data = np.transpose(eeg_data, (2, 1, 0))
eeg_test_set = np.transpose(eeg_test_data, (2, 1, 0))

In [None]:
# Pad every sample along dim 0 to the longest sample so they can be stacked into a
# single tensor. NOTE(review): if `eeg_data` is already a regular (non-ragged)
# array, every padding_size below is 0 and the padding is a no-op.
split_ratio = 0.8  # fraction of samples for training; the rest is validation

sample_tensors = [torch.as_tensor(sample) for sample in eeg_data]
assert sample_tensors, "eeg_data is empty"
max_length = max(sample.shape[0] for sample in sample_tensors)

padded_tensors = []
for sample in sample_tensors:
    padding_size = max_length - sample.shape[0]
    if padding_size > 0:
        # F.pad takes (left, right) pairs starting from the LAST dim, so the pair for
        # dim 0 must come last; building it from sample.dim() works for any rank.
        pad = (0, 0) * (sample.dim() - 1) + (0, padding_size)
        sample = torch.nn.functional.pad(sample, pad)
    padded_tensors.append(sample)

padded_data = torch.stack(padded_tensors)
print(padded_data.shape)

# Use the split_data function from utils.py on the WHOLE padded dataset
# (the previous version passed only the last padded sample).
eeg_train_set, eeg_val_set = utils.split_data(padded_data, eeg_training_files, split_ratio)


In [None]:

BATCH_SIZE = 1

# Only the training split needs shuffling. Validation/test losses are averaged over
# all batches, so a fixed order keeps those passes deterministic.
train_loader = DataLoader(eeg_train_set, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(eeg_val_set, batch_size=BATCH_SIZE, shuffle=False)
# Use `eeg_test_set`, which is transposed exactly like the training data.
# `eeg_test_data` is the un-transposed original and has a different axis order.
test_loader = DataLoader(eeg_test_set, batch_size=BATCH_SIZE, shuffle=False)



In [None]:
# Model hyper-parameters.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
embed_size = 4   # d_model: must equal the last (feature) dim of each input sequence
nhead = 2        # attention heads; embed_size must be divisible by nhead
num_layers = 6   # depth of both the encoder and the decoder stacks

# Template layers. nn.TransformerEncoder/Decoder clone them num_layers times.
# batch_first=True means inputs are laid out as (batch, seq, feature).
encoder_layer = nn.TransformerEncoderLayer(
    d_model=embed_size,
    nhead=nhead,
    batch_first=True,
).to(device)
decoder_layer = nn.TransformerDecoderLayer(
    d_model=embed_size,
    nhead=nhead,
    batch_first=True,
).to(device)

transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers).to(device)
transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers).to(device)

# One flat parameter list (encoder first, then decoder) so a single optimizer
# updates both stacks.
model_params = [*transformer_encoder.parameters(), *transformer_decoder.parameters()]


In [None]:

# Reconstruction objective: the decoder is trained to reproduce its input sequence.
# NOTE(review): the decoder is fed `src` itself as the target with no tgt_mask, so
# it can trivially copy its input. Consider a masked/shifted target if a real
# reconstruction task is intended.
NUM_EPOCHS = 20000  # kept from the original; reduce for quick experiments

criterion = nn.MSELoss()
optimizer = optim.SGD(model_params,
                      lr=0.001,
                      momentum=0.9)


def _to_model_input(batch):
    """Move a loader batch to the model's device as float32 (the modules' default dtype).

    Uses `.to(device)` rather than `.cuda()` so the cell also runs on CPU-only
    machines, matching where the model was placed. NOTE(review): assumes each batch
    is a single (batch, seq, embed_size) tensor, not an (x, y) tuple — confirm
    against utils.split_data.
    """
    return batch.to(device, dtype=torch.float32)


losses = []  # per-batch training loss, plotted below
for epoch in range(NUM_EPOCHS):
    transformer_encoder.train()
    transformer_decoder.train()
    train_loss = 0.0
    for batch in train_loader:
        src_data = _to_model_input(batch)

        optimizer.zero_grad()
        memory = transformer_encoder(src_data)
        out_batch = transformer_decoder(src_data, memory)
        loss = criterion(out_batch, src_data)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        losses.append(loss.item())

    # Validation: eval mode disables dropout, and no_grad skips building the graph.
    transformer_encoder.eval()
    transformer_decoder.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch in val_loader:
            src_batch = _to_model_input(batch)
            memory = transformer_encoder(src_batch)
            # Decode and score the VALIDATION batch. The previous code used
            # `src_data` (the last training batch), so "val loss" was train data.
            out_batch = transformer_decoder(src_batch, memory)
            loss = criterion(out_batch, src_batch)
            val_loss += loss.item()

    # Guard against empty loaders instead of raising ZeroDivisionError.
    mean_train = train_loss / len(train_loader) if len(train_loader) else float('nan')
    mean_val = val_loss / len(val_loader) if len(val_loader) else float('nan')
    print("Epoch: {} Train Loss: {} Val Loss: {}".format(epoch, mean_train, mean_val))

fig, ax = plt.subplots()
ax.plot(losses)
ax.set_xlabel('Batch')
ax.set_ylabel('Loss')
ax.set_title('Loss per Batch')
plt.show()


Attention Block

Transformer Block

Encoder