In [None]:
# Mount Google Drive so the EEG dataset under /content/drive/MyDrive/... (read by the
# training and inference cells below) is reachable from this Colab runtime.
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [None]:
!pip install torch
!pip install torchsummary
!pip install torchvision
!pip install scipy
!pip install einops
!pip install transformers
!pip install transformers[torch]

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 (from torch)
  Using cached nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)
Collectin

In [None]:
import argparse
import os
gpus = [0]
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(map(str, gpus))
import numpy as np
import math
import scipy.io
import random

import torchvision.transforms as transforms
from torchvision.utils import save_image, make_grid

from torch.utils.data import DataLoader
from torch.autograd import Variable
from torchsummary import summary
import torch.autograd as autograd
from torchvision.models import vgg19

import torch.nn as nn
import torch.nn.functional as F
import torch
import torch.nn.init as init

from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from sklearn.decomposition import PCA

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor
from PIL import Image
from torchvision.transforms import Compose, Resize, ToTensor
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
# from common_spatial_pattern import csp

import matplotlib.pyplot as plt
# from torch.utils.tensorboard import SummaryWriter
from torch.backends import cudnn
cudnn.benchmark = False
cudnn.deterministic = True

# writer = SummaryWriter('./TensorBoardX/')


from transformers import Trainer, TrainingArguments
from torch.utils.data import Dataset
import torch
import numpy as np
import os
import scipy.io

from torch.nn import MSELoss


**Model**

In [None]:
"""
EEG Conformer

Convolutional Transformer for EEG decoding

Couples a convolutional front end with a Transformer encoder in a compact architecture.
"""
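# NOTE: the dataset and checkpoint paths used below are hard-coded for Colab with Google Drive mounted at /content/drive; update them for other environments.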



# Convolution module
# Convolutions capture local features and take the place of an explicit position embedding.
class PatchEmbedding(nn.Module):
    """Convolutional front end that turns a 4-channel map into a token sequence.

    The first ``Conv2d`` requires 4 input channels, so the input is expected to be
    ``(N, 4, H, W)``. The (2, 2) and (22, 1) kernels plus the (1, 3) pooling need
    ``H >= 23`` and ``W >= 4``. The output is ``(N, H' * W', emb_size)``, one token
    per remaining spatial position.

    Args:
        emb_size: width of each output token.
    """

    def __init__(self, emb_size=40):
        super().__init__()
        # Layer order (and therefore the state_dict indices shallownet.0 ... .5)
        # must stay fixed so previously saved checkpoints still load.
        conv_stack = [
            nn.Conv2d(4, 40, (2, 2), (1, 1)),
            nn.Conv2d(40, 40, (22, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.AvgPool2d((1, 3), (1, 1)),
            nn.Dropout(0.5),
        ]
        self.shallownet = nn.Sequential(*conv_stack)

        # 1x1 conv maps the 40 feature maps to emb_size, then every (h, w)
        # position is flattened into its own token.
        self.projection = nn.Sequential(
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        features = self.shallownet(x)
        return self.projection(features)


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a ``(batch, tokens, emb_size)`` sequence.

    Args:
        emb_size: token width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t: Tensor) -> Tensor:
        # (b, n, h*d) -> (b, h, n, d); head index is the outer factor of the last axis.
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_heads, -1).transpose(1, 2)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        """Attend over ``x``.

        Args:
            x: tensor of shape ``(batch, tokens, emb_size)``.
            mask: optional boolean tensor broadcastable to
                ``(batch, heads, queries, keys)``; ``True`` marks key positions
                that may be attended, ``False`` positions are excluded.

        Returns:
            Tensor of shape ``(batch, tokens, emb_size)``.
        """
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so the result must be reassigned.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        # NOTE(review): scores are scaled by sqrt(emb_size), not by the per-head
        # width sqrt(emb_size / num_heads) of standard scaled dot-product
        # attention. Kept unchanged so trained checkpoints behave identically —
        # confirm this is intended.
        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        # (b, h, n, d) -> (b, n, h*d)
        b, _, n, _ = out.shape
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.projection(out)


class ResidualAdd(nn.Module):
    """Wrap a module ``fn`` with a skip connection, computing ``fn(x) + x``.

    Args:
        fn: module applied to the input; its output must be addable to ``x``.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add. An in-place `+=` on fn's output breaks autograd when
        # fn's last op saved that output for backward (e.g. sigmoid), and it
        # mutates the caller's `x` when fn returns its input unchanged.
        return self.fn(x, **kwargs) + x


class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    Widens each token from ``emb_size`` to ``expansion * emb_size`` and projects
    it back, so input and output widths match.

    Args:
        emb_size: token width.
        expansion: hidden-layer width multiplier.
        drop_p: dropout probability after the activation.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden_size = expansion * emb_size
        layers = (
            nn.Linear(emb_size, hidden_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden_size, emb_size),
        )
        super().__init__(*layers)


class GELU(nn.Module):
    """Exact (erf-based) GELU activation: ``x * Phi(x)``, with Phi the standard normal CDF.

    NOTE(review): FeedForwardBlock uses ``nn.GELU``; this class is not referenced
    by the code visible in this notebook section.
    """

    def forward(self, input: Tensor) -> Tensor:
        normal_cdf = 0.5 * (1.0 + torch.erf(input / math.sqrt(2.0)))
        return input * normal_cdf


class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm Transformer encoder layer made of two residual sub-blocks.

    1. LayerNorm -> MultiHeadAttention -> Dropout, added back to the input.
    2. LayerNorm -> FeedForwardBlock -> Dropout, added back to the input.

    Args:
        emb_size: token width.
        num_heads: attention heads (``emb_size`` must be divisible by it).
        drop_p: dropout used on attention weights and after each sub-block.
        forward_expansion: hidden-width multiplier of the feed-forward block.
        forward_drop_p: dropout inside the feed-forward block.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        attention_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feedforward_branch = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(ResidualAdd(attention_branch), ResidualAdd(feedforward_branch))


class TransformerEncoder(nn.Sequential):
    """Stack of ``depth`` TransformerEncoderBlock layers, each of width ``emb_size``."""

    def __init__(self, depth, emb_size):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(emb_size))
        super().__init__(*blocks)

class ClassificationHead(nn.Sequential):
    """Regression head: mean-pool over tokens, LayerNorm, then a Linear layer with ``b`` outputs.

    Despite its name, the head applies no softmax. It maps a
    ``(batch, tokens, emb_size)`` tensor to ``(batch, b)`` raw values.
    The class subclasses ``nn.Sequential`` but registers its layers under
    ``regression_head`` and overrides ``forward``.

    Args:
        emb_size: token width.
        b: number of output values.
    """

    def __init__(self, emb_size, b):
        super().__init__()
        token_mean = Reduce('b n e -> b e', reduction='mean')
        self.regression_head = nn.Sequential(
            token_mean,
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, b),
        )

    def forward(self, x):
        return self.regression_head(x)

class Conformer(nn.Module):
    """EEG Conformer regressor: conv patch embedding -> Transformer encoder -> regression head.

    NOTE(review): ``forward`` prepends a singleton axis, so a ``(batch, H, W)``
    input becomes ``(1, batch, H, W)``. The batch axis is therefore consumed as
    the 4 input channels of ``PatchEmbedding``'s first ``Conv2d``, and the head
    returns a ``(1, batch)`` tensor, one value per slot. This only works with
    exactly 4 samples per call (``pad_collate`` tops batches up), and the samples
    are mixed rather than processed independently — confirm this is intended.

    Args:
        batch: number of regression outputs (one per batch slot).
        emb_size: token width.
        depth: number of Transformer encoder blocks.
        n_classes: unused; kept for signature compatibility.
        **kwargs: ignored.
    """

    def __init__(self, batch = 4,emb_size=40, depth=6, n_classes=4, **kwargs):
        super().__init__()
        self.patch_embedding = PatchEmbedding(emb_size)
        self.transformer_encoder = TransformerEncoder(depth, emb_size)
        self.classification_head = ClassificationHead(emb_size, batch)

    def forward(self, input, label=None):
        """Run the model. ``label`` is accepted (the Trainer passes the whole batch dict) and ignored."""
        x = input[None, :, :, :]
        for stage in (self.patch_embedding, self.transformer_encoder, self.classification_head):
            x = stage(x)
        return x


from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

batch_size = 4


def pad_collate(batch):
    """Collate dataset items into one zero-padded batch for the Conformer.

    ``Conformer.forward`` feeds the whole batch into a ``Conv2d`` with 4 input
    channels. The number of samples is therefore topped up to a multiple of
    ``batch_size`` by appending items drawn with ``random.choice`` from ``batch``.

    Inputs are zero-padded at the end of *every* dimension to the largest size
    found in the batch. ``torch.nn.utils.rnn.pad_sequence`` only pads the first
    dimension and fails when later dimensions differ. That happens, for example,
    when the last window of a recording is shorter along its second (time) axis.

    Args:
        batch: non-empty list of dicts with tensor ``'input'`` values of equal
            rank and tensor ``'label'`` values of equal shape.

    Returns:
        ``{'input': Tensor(n, *max_shape), 'label': labels stacked on a new first axis}``,
        where ``n`` is ``len(batch)`` rounded up to a multiple of ``batch_size``.

    Raises:
        ValueError: if ``batch`` is empty or the inputs differ in rank.
    """
    if not batch:
        raise ValueError("pad_collate received an empty batch")

    inputs = [item['input'] for item in batch]
    labels = [item['label'] for item in batch]

    shortfall = len(batch) % batch_size
    if shortfall > 0:
        for _ in range(batch_size - shortfall):
            extra = random.choice(batch)
            inputs.append(extra['input'])
            labels.append(extra['label'])

    ranks = {t.dim() for t in inputs}
    if len(ranks) != 1:
        raise ValueError(f"pad_collate expects inputs of equal rank, got ranks {sorted(ranks)}")
    target_shape = [max(sizes) for sizes in zip(*(t.shape for t in inputs))]

    padded = []
    for t in inputs:
        # F.pad takes (before, after) pairs starting from the LAST dimension.
        pad = []
        for dim in reversed(range(t.dim())):
            pad += [0, target_shape[dim] - t.shape[dim]]
        padded.append(F.pad(t, pad) if pad else t)
    inputs_padded = torch.stack(padded)

    labels = torch.stack(labels)

    return {'input': inputs_padded, 'label': labels}



class EEGDataset(Dataset):
    """Dataset of fixed-length windows cut from labelled ``.mat`` EEG recordings.

    Expected layout: ``root_dir/<name>(<int label>)/<file>.mat``. Every ``.mat``
    file except ``FFT.mat`` must hold an array under the key ``'data'``; its
    second axis (presumably time — TODO confirm) is split into consecutive
    windows of ``max_timesteps`` columns, and the last window may be shorter.
    Each window is one sample whose label is the integer inside the folder
    name's parentheses.

    Directory listings are sorted so that sample indices are reproducible across
    runs and filesystems. Entries whose name contains parentheses but which are
    not directories are skipped.

    Args:
        root_dir: directory containing the labelled sub-folders.
        max_timesteps: window length along the second axis of ``data``; must be > 0.

    Raises:
        ValueError: if ``max_timesteps`` is not positive.
    """

    def __init__(self, root_dir, max_timesteps=1000):
        if max_timesteps <= 0:
            raise ValueError(f"max_timesteps must be positive, got {max_timesteps}")
        self.max_timesteps = max_timesteps
        self.data_files = []
        self.portion_counts = []  # Number of windows contributed by each file
        self.labels = []
        for folder_name in sorted(os.listdir(root_dir)):
            folder_path = os.path.join(root_dir, folder_name)
            if '(' not in folder_name or ')' not in folder_name or not os.path.isdir(folder_path):
                continue
            label = int(folder_name.split('(')[-1].split(')')[0])
            for file_name in sorted(os.listdir(folder_path)):
                if not file_name.endswith('.mat') or file_name == 'FFT.mat':
                    continue
                file_path = os.path.join(folder_path, file_name)
                # The file is loaded here only to learn its length.
                n_timesteps = scipy.io.loadmat(file_path)['data'].shape[1]
                self.data_files.append(file_path)
                self.labels.append(label)
                self.portion_counts.append(math.ceil(n_timesteps / max_timesteps))

        # _portion_ends[i] = total windows in files 0..i, i.e. the exclusive end index of file i.
        self._portion_ends = []
        running_total = 0
        for count in self.portion_counts:
            running_total += count
            self._portion_ends.append(running_total)

    def __len__(self):
        return sum(self.portion_counts)

    def __getitem__(self, idx):
        """Return window ``idx`` as ``{'input', 'label_ids', 'label'}``.

        ``'input'`` is a float32 tensor slice of ``data[:, start:end]``;
        ``'label_ids'`` is the plain int label and ``'label'`` the same value as
        a long tensor. Negative indices count from the end.

        Raises:
            IndexError: if ``idx`` is outside ``[-len(self), len(self))``.
        """
        from bisect import bisect_right

        total = self._portion_ends[-1] if self._portion_ends else 0
        position = idx + total if idx < 0 else idx
        if not 0 <= position < total:
            raise IndexError(f"index {idx} is out of range for EEGDataset of length {total}")

        # Files contributing zero windows share an end offset with their predecessor;
        # bisect_right skips past them to the file that actually owns `position`.
        file_idx = bisect_right(self._portion_ends, position)
        first_window = self._portion_ends[file_idx - 1] if file_idx > 0 else 0
        portion_idx = position - first_window

        label = self.labels[file_idx]
        data = scipy.io.loadmat(self.data_files[file_idx])['data']
        start_idx = portion_idx * self.max_timesteps
        end_idx = min(start_idx + self.max_timesteps, data.shape[1])
        data_portion = torch.tensor(data[:, start_idx:end_idx], dtype=torch.float32)
        return {'input': data_portion, 'label_ids': label, 'label': torch.tensor(label, dtype=torch.long)}

def model_init():
    """Build a fresh, untrained Conformer with its default hyper-parameters.

    Passed to ``Trainer(model_init=...)`` so the Trainer constructs the model
    itself; the inference cell reuses it before loading saved weights.
    """
    model = Conformer()
    return model



class CustomTrainer(Trainer):
    """Hugging Face ``Trainer`` that fits the Conformer as a regressor with MSE loss.

    Batches from ``pad_collate`` carry ``'input'`` and ``'label'``. The whole
    dict is forwarded to the model, whose ``forward`` accepts and ignores
    ``label``, and the labels are compared with the predictions as floats.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_fn = MSELoss()

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return the MSE between model predictions and the batch labels.

        Args:
            model: the model being trained.
            inputs: batch dict with at least ``'input'`` and ``'label'``.
            return_outputs: if True, also return the raw predictions.
            **kwargs: extra arguments that newer ``transformers`` releases pass to
                ``compute_loss`` (e.g. ``num_items_in_batch``). They are accepted
                and ignored so this override keeps working after an upgrade.

        Returns:
            ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is True.
        """
        outputs = model(**inputs)
        # The Conformer returns a (1, n) tensor (one value per batch slot) while the
        # labels arrive as (n,). Reshape to the prediction's exact shape so a size
        # mismatch raises instead of silently broadcasting inside MSELoss.
        labels = inputs["label"].float().reshape(outputs.shape)
        loss = self.loss_fn(outputs, labels)
        return (loss, outputs) if return_outputs else loss

# Optimisation settings for the Hugging Face Trainer. Checkpoints and logs go to
# ./results and ./logs. save_total_limit is not set, so checkpoints are presumably
# not rotated — confirm disk usage on long runs.
training_args = TrainingArguments(
    output_dir='./results',
    logging_dir='./logs',
    logging_steps=100,
    num_train_epochs=1000,
    per_device_train_batch_size=batch_size,
    learning_rate=1e-3,
    weight_decay=0.01,
    lr_scheduler_type='cosine',
    warmup_ratio=0.1,
)

# Tops every batch up to a multiple of `batch_size` samples and zero-pads inputs.
data_collator = pad_collate

# model_init (rather than a model instance) lets the Trainer build the network itself.
trainer = CustomTrainer(
    args=training_args,
    model_init=model_init,
    data_collator=data_collator,
    train_dataset=EEGDataset(root_dir='/content/drive/MyDrive/be_lab/data/G01_data_cut'),
)

trainer.train()

Step,Training Loss
100,52.7511
200,40.8862
300,35.3739
400,29.8642
500,23.0157
600,15.4687
700,8.7348
800,5.7819
900,4.7551
1000,4.6867


KeyboardInterrupt: 

In [None]:
# Save the trained model
model_path = './trained_conformer_model.pth'
torch.save(trainer.model.state_dict(), model_path)
print(f"Model saved to {model_path}")

Model saved to ./trained_conformer_model.pth


In [None]:
def perform_inference(model_path="/content/trained_conformer_model.pth",
                      root_dir='/content/drive/MyDrive/be_lab/data/G01_data_cut',
                      verbose=True):
    """Evaluate a saved Conformer checkpoint and report the mean absolute error.

    Samples are loaded one at a time. ``pad_collate`` tops each batch up to
    ``batch_size`` by repeating that sample, so the model emits one prediction
    per slot. Every slot is scored, and the average is taken over all
    predictions, not over batches.

    NOTE(review): the default ``root_dir`` is the training directory, so the
    result is a training-set error, not a held-out estimate.

    Args:
        model_path: path of a state_dict written by ``torch.save``.
        root_dir: root directory passed to ``EEGDataset``.
        verbose: if True, print each batch's predictions and labels.

    Returns:
        Mean absolute error per prediction as a float, or None if the dataset is empty.
    """
    model = model_init()
    # map_location lets a checkpoint saved from a GPU load on a CPU-only runtime;
    # weights_only avoids unpickling arbitrary objects from the file.
    state_dict = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()

    inference_dataset = EEGDataset(root_dir=root_dir)
    inference_loader = DataLoader(inference_dataset, batch_size=1, collate_fn=pad_collate)

    print(f"Dataset size: {len(inference_dataset)}")
    if len(inference_dataset) == 0:
        print("Dataset is empty. Check the dataset path and contents.")
        return None
    print(f"First item keys: {inference_dataset[0].keys()}")

    total_error = 0.0
    total_predictions = 0
    batches = 0
    with torch.no_grad():
        for batch in inference_loader:
            batches += 1
            true_labels = batch['label']
            outputs = model(batch['input'])
            if verbose:
                print(outputs)
                print(true_labels)
            # outputs is (1, n) and true_labels is (n,); broadcasting pairs them slot by slot.
            abs_errors = (outputs - true_labels).abs()
            total_error += abs_errors.sum().item()
            total_predictions += abs_errors.numel()

    average_error = total_error / total_predictions
    print("average_error = ", average_error)
    print("batches = ", batches)
    print("total_error = ", total_error)
    print("len of dataset = ", len(inference_dataset))
    return average_error

perform_inference()



Dataset size: 103
First item keys: dict_keys(['input', 'label_ids', 'label'])
tensor([[9.6928, 7.2218, 7.3124, 7.4642]])
tensor([8, 8, 8, 8])
tensor([[9.4314, 7.6260, 7.6931, 7.8083]])
tensor([8, 8, 8, 8])
tensor([[8.3761, 8.0902, 8.0881, 8.1228]])
tensor([8, 8, 8, 8])
tensor([[8.6492, 8.0383, 8.0531, 8.1041]])
tensor([8, 8, 8, 8])
tensor([[4.4689, 7.2810, 7.1072, 7.0110]])
tensor([8, 8, 8, 8])
tensor([[9.5410, 7.4924, 7.5389, 7.7073]])
tensor([8, 8, 8, 8])
tensor([[8.4441, 8.0914, 8.0883, 8.1382]])
tensor([8, 8, 8, 8])
tensor([[8.2870, 8.0791, 8.0772, 8.1012]])
tensor([8, 8, 8, 8])
tensor([[6.2735, 7.8982, 7.7887, 7.7632]])
tensor([8, 8, 8, 8])
tensor([[8.4300, 8.0948, 8.0896, 8.1424]])
tensor([8, 8, 8, 8])
tensor([[8.8687, 7.9895, 7.9990, 8.0928]])
tensor([8, 8, 8, 8])
tensor([[6.7142, 7.9817, 7.9123, 7.8666]])
tensor([8, 8, 8, 8])
tensor([[7.3173, 8.0974, 8.0375, 8.0359]])
tensor([8, 8, 8, 8])
tensor([[5.7062, 7.7351, 7.6034, 7.5429]])
tensor([8, 8, 8, 8])
tensor([[7.9883, 8.1305, 8