In [1]:
import torch
import numpy as np
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from multiprocessing import Pool
from DTran import DualTran
from DTRN2 import DTRN
from EEGTran import EEGTran
from scipy.signal import butter, filtfilt, iirnotch, lfilter
import os
from scipy.signal import resample
import pywt
#from EEGFormer import EEGFormer
from pytorch_memlab import profile, set_target_gpu, profile_every
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"

print("PyTorch Version: ",torch.__version__)

PyTorch Version:  2.0.1+cu118


In [2]:
import pandas as pd
import torch
from torch.utils.data import Dataset
# stop using pandas for this it doesn't support multiprocessing

class EEGDataset(Dataset):
    def __init__(self, dir, chunksize, label_column='label', samples=60000, num_chunks=58):
        self.dir = dir
        self.chunksize = chunksize
        self.label_column = label_column
        self.samples = samples
        self.n_channels = 128
        self.n_timestamps = 500
        self.num_chunks = num_chunks
        
        # Initialize variables for chunk loading
        self.current_chunk = []
        self.current_labels = []
        self.current_row = []
        self.chunk_idx = torch.randint(0, 20, (1,)).item()
        self.reader = None

        # Load the first chunk
        self.load_next_chunk()

    def load_next_chunk(self):
        try:
            self.current_chunk = pd.read_parquet(self.dir + f'data_chunk_{self.chunk_idx}.parquet', engine='pyarrow')
            self.current_labels = pd.read_parquet(self.dir + f'label_chunk_{self.chunk_idx}.parquet', engine='pyarrow')
            print("Loaded chunk", self.chunk_idx)
            self.chunk_idx = torch.randint(0, 20, (1,)).item()
            
        except:
            self.chunk_idx = torch.randint(0, 20, (1,)).item()

    def __len__(self):
        return self.samples

    def bandpass_filter(self, data, lowcut, highcut, fs, order=5):
        nyq = 0.5 * fs  # Nyquist frequency, which is half of fs
        low = lowcut / nyq
        high = highcut / nyq
        b, a = butter(order, [low, high], btype='band')
        y = lfilter(b, a, data)
        return y
    
    def highpass_filter(self, data, cutoff, fs, order=5):
        """
        Apply high-pass filter to data.
        
        Parameters:
        - data: The signal data (numpy array)
        - cutoff: Cutoff frequency for high-pass filter
        - fs: Sampling frequency of the data
        - order: Order of the filter (default is 4)
        
        Returns:
        - Filtered data
        """
        nyq = 0.5 * fs  # Nyquist frequency
        normal_cutoff = cutoff / nyq
        b, a = butter(order, normal_cutoff, btype='high', analog=False)
        y = filtfilt(b, a, data)  # filtfilt is used to apply the filter forwards and backwards to avoid phase shifts
        return y
    
    def __getitem__(self, idx):
        if len(self.current_row) == 0:
       
            self.load_next_chunk()
            self.current_row = list(range(len(self.current_chunk)))
            np.random.shuffle(self.current_row)
            self.curr_row = self.current_row.pop()
            

        row = self.current_chunk.iloc[self.curr_row].values
        label = self.current_labels.iloc[self.curr_row].values
        self.curr_row = self.current_row.pop()
        
        assert row.shape[0] == self.n_channels * self.n_timestamps, "Unexpected number of columns."
        row = row.reshape(self.n_channels, self.n_timestamps)

        f_sample = 250.0 # Change this to your actual sample frequency
        f_notch = 50.0
        quality_factor = 30.0  # This defines the width of the notch

        # Design the notch filter
        b, a = iirnotch(f_notch, quality_factor, f_sample)
        filtered_data = np.empty_like(row)
        for i in range(row.shape[0]):
            filtered_data[i] = filtfilt(b, a, row[i])
        
        highpass_data = self.highpass_filter(filtered_data, .1, 250.0)
        coeffs = pywt.wavedec(highpass_data, 'db4', level=3)
        cA3, cD3, cD2, cD1 = coeffs

        sigma = np.median(np.abs(cD1)) / 0.6745
        n = len(filtered_data)
        threshold = sigma * np.sqrt(2 * np.log(n))

        #cA3 = pywt.threshold(cA3, threshold*2, mode='soft')
        cD3 = pywt.threshold(cD3, threshold*2, mode='soft')
        cD2 = pywt.threshold(cD2, threshold*2, mode='soft')
        cD1 = pywt.threshold(cD1, threshold*2, mode='soft')

        denoised_signal = pywt.waverec([cA3, cD3, cD2, cD1], 'db4')

        row = denoised_signal

        row = torch.from_numpy(row).float()
        label = torch.tensor(label).long();
    
        return row, label


In [3]:
train = 'data\\train_data_chunks\\'
dataset = EEGDataset(train, chunksize=256, label_column='label', samples=120000, num_chunks=59)
trainloader = DataLoader(dataset, batch_size=8, shuffle=False, drop_last=True, pin_memory=True)

test = 'data\\test_data_chunks\\'
dataset = EEGDataset(test, chunksize=256, label_column='label', samples=20000, num_chunks=10)
testloader = DataLoader(dataset, batch_size=8, shuffle=False, drop_last=True, pin_memory=True)

Loaded chunk 4
Loaded chunk 9


In [4]:
if torch.cuda.is_available():
    print("CUDA is available.")
    device = torch.device("cuda:0")
else:
    print("CUDA is not available.")

CUDA is available.


In [5]:
import torch
import math
import torch.nn as nn

class Depthwise1DCNN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, num_kernels):
        super(Depthwise1DCNN, self).__init__()

        self.conv1 = nn.Conv1d(in_channels, in_channels * num_kernels, kernel_size, groups=in_channels, padding='valid')
        self.layer_norm1 = nn.LayerNorm(491)
        self.conv2 = nn.Conv1d(in_channels * num_kernels, in_channels * num_kernels, kernel_size, groups=in_channels, padding='valid')
        self.layer_norm2 = nn.LayerNorm(482)
        self.conv3 = nn.Conv1d(in_channels * num_kernels, in_channels * num_kernels, kernel_size//2, groups=in_channels, padding='valid')
        self.layer_norm3 = nn.LayerNorm(478)
        self.relu = nn.ReLU()

    @profile_every(1)
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.layer_norm1(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.layer_norm2(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = self.layer_norm3(x)

        return x
    
class RegionalTran(nn.Module):
    def __init__(self, d_model, num_heads, num_features, num_blocks):
        super(RegionalTran, self).__init__()
        self.linear_map = nn.Linear(d_model, d_model)
        self.pos_enc = PositionalEncoding(d_model=d_model, max_len=num_features)
        self.transformer_block = nn.ModuleList([TransformerBlock(d_model, num_heads) for _ in range(num_blocks)])

    @profile_every(1)
    def forward(self, x):
        print(x.size())
        x = self.linear_map(x)
        x = self.pos_enc(x)
        for transformer in self.transformer_block:
            x = transformer(x)
        return x
    
class SynchronousTran(nn.Module):
    def __init__(self, d_model, num_heads, num_features, num_blocks):
        super(SynchronousTran, self).__init__()
        print(d_model)
        self.linear_map = nn.Linear(d_model, d_model)
        self.pos_enc = PositionalEncoding(d_model=d_model, max_len=num_features)
        self.transformer_block = nn.ModuleList([TransformerBlock(d_model, num_heads) for _ in range(num_blocks)])

    def forward(self, x):
        x = x.permute(0, 2, 1, 3)
        x = self.linear_map(x)
        x = self.pos_enc(x)
        for transformer in self.transformer_block:
            x = transformer(x)
        return x

class TemporalTran(nn.Module):
    def __init__(self, d_model, num_heads, num_features, num_blocks, compress=250):
        super(TemporalTran, self).__init__()
        self.compress = compress
        self.linear_map = nn.Linear(d_model, d_model)
        self.pos_enc = PositionalEncoding(d_model=d_model, max_len=num_features)
        self.transformer_block = nn.ModuleList([TransformerBlock(d_model, num_heads) for _ in range(num_blocks)])
    
    def compress_temporal(self, x):
        B, C, S, D = x.size()

        assert D % self.compress == 0

        segment_size = D // self.compress

        reshaped_x = x.view(B, C, S, D, segment_size)

        compressed_x = torch.mean(reshaped_x, dim=-1)

        return compressed_x

    def forward(self, x):
        x = self.compress_temporal(x)
        x = x.permute(0, 3, 2, 1)
        x = self.linear_map(x)
        x = self.pos_enc(x)
        for transformer in self.transformer_block:
            x = transformer(x)
        return x
    
class Decoder(nn.Module):
    def __init__(self, channels, features, temporal, N):
        super(Decoder, self).__init__()
        self.normConv = nn.Conv2d(features, 1, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels=1, out_channels=N, kernel_size=(5, 5), padding=(2, 2))
        self.conv3 = nn.Conv2d(in_channels=N, out_channels=N, kernel_size=(5, 5), padding=(2, 2), stride=(2, 2))
        self.linear = nn.Linear(N * (features // 2) * (temporal // 2), 11)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = x.permute(0, 2, 3, 1)
        x = self.normConv(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.relu(x)
        x = x.view(x.size(0), -1)
        x = self.linear(x)
        return x

class EEGFormer(nn.Module):
    def __init__(self, d_model=478, channels=128, features=128, N=128):
        super(EEGFormer, self).__init__()
        self.channels = channels
        self.embedding = Depthwise1DCNN(channels, features, 10, features)
        self.regional = RegionalTran(d_model, 8, features, 1)
        self.synchronous = SynchronousTran(d_model, 8, features, 1)
        self.temporal = TemporalTran(features, 10, features, 1)
        self.decoder = Decoder(channels, features, 250, N)
        
    @profile_every(1)
    def forward(self, x):
        x = self.embedding(x)
        x = x.view(x.size(0), self.channels, -1, x.size(-1))
        x = self.regional(x)
        x = self.synchronous(x)
        x = self.temporal(x)
        x = self.decoder(x)
        return x

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin((position * div_term)*(d_model/max_len))
        pe[:, 1::2] = torch.cos((position * div_term)*(d_model/max_len))
        pe = pe.unsqueeze(0)

        self.register_buffer('pe', pe)

    @profile_every(1)
    def forward(self, x):
        x = x + self.pe[:, :x.size(1), :]
        return x
    
class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads):
        super(TransformerBlock, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_size = d_model // num_heads
        self.multi_head_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForward(d_model)
        self.layer_norm = nn.LayerNorm(d_model)

    @profile_every(1)
    def forward(self, x):
        x = x + self.multi_head_attn(x)
        x = self.layer_norm(x)
        x = x + self.feed_forward(x)
        x = self.layer_norm(x)
        return x
    
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_size = d_model // num_heads
        self.query = nn.Linear(d_model, d_model)
        self.key = nn.Linear(d_model, d_model)
        self.value = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(0.1)
        self.projection = nn.Linear(d_model, d_model)

    @profile_every(1)
    def forward(self, x):

        B, S, C, L = x.size()

        x = x.view(B*S, C, L)

        query = self.query(x)
        key = self.key(x)
        value = self.value(x)

        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_size)
        scores = self.softmax(scores)
        scores = self.dropout(scores)

        context = torch.matmul(scores, value)

        context = context.view(B, S, C, L)
        
        context = self.projection(context)
        return context


class FeedForward(nn.Module):
    def __init__(self, d_model):
        super(FeedForward, self).__init__()
        self.MLP = nn.Sequential(
            nn.Linear(d_model, d_model * 2),
            nn.ReLU(),
            nn.Linear(d_model * 2, d_model)
        )
        self.layer_norm = nn.LayerNorm(d_model)

    @profile_every(1)
    def forward(self, x):
        x = x + self.MLP(x)
        x = self.layer_norm(x)
        return x



In [6]:
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
from focal_loss.focal_loss import FocalLoss
import wandb
#writer = SummaryWriter()
torch.cuda.empty_cache()

'''wandb.init(
    # set the wandb project where this run will be logged
    project="my-awesome-project",
    
    # track hyperparameters and run metadata
    config={
    "learning_rate": 0.00001,
    "architecture": "EEGTran",
    "dataset": "MNSIT-8B"
    }
)'''

#odel = DualTran().to(device)
model = EEGFormer().to(device)
#wandb.watch(model, log="all")
set_target_gpu(0)

#model.load_state_dict(torch.load('model_overfit.pt'))
weights = torch.FloatTensor([2, 3.2, 0.7])
#criterion = torch.nn.CrossEntropyLoss()
criterion = FocalLoss(gamma=0.8)
s = torch.nn.Softmax(dim=1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)


478


In [7]:
import sys
def minMaxNorm(tensor):
    # Calculate the minimum and maximum values across the seq_len dimension (i.e., dim=-1)
    min_vals = torch.min(tensor, dim=2, keepdim=True)[0]
    max_vals = torch.max(tensor, dim=2, keepdim=True)[0]

    # Perform the min-max normalization
    normalized_tensor = (tensor - min_vals) / (max_vals - min_vals + 1e-10)  # Added epsilon to prevent division by zero

    return normalized_tensor

In [8]:
def z_norm(data):
    """
    Apply z-score normalization to EEG data.
    
    Parameters:
    - data: The EEG data (numpy array)
    
    Returns:
    - Normalized data
    """
    mean = torch.mean(data, dim=2, keepdim=True)
    std = torch.std(data, dim=2, keepdim=True)
    normalized_data = (data - mean) / (std + 1e-10)  # Added epsilon to prevent division by zero
    return normalized_data

In [9]:
for i in range(10):
    total = 0
    for j, (data, label) in enumerate(trainloader):
        data = data.to(device)
        data = z_norm(data)

        output = model(data)
        label= torch.where(label.flatten() == -1, torch.tensor(10), label.flatten())
        
        label = label.to(device)
        loss = criterion(s(output), label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total += loss.item()
        '''if np.random.randint(0, 100) == 1:
            print(loss.item())
            print(one_hot[0])
            print(output[0])'''
        
        #writer.add_scalar('Loss/train', loss.item(), j + i * len(trainloader))
        #wandb.log({"loss/train": loss.item()})

    print(f'epoch{i}', total / len(trainloader))

    train_loss = 0
    for k, (data, label) in enumerate(testloader):
        data = data.to(device)
        data = z_norm(data)
        output = model(data)
    
        label = torch.where(label.flatten() == -1, torch.tensor(10), label.flatten())
        label = label.to(device)
        loss = criterion(s(output), label)

        #writer.add_scalar('Loss/test', loss.item(), k + i * len(trainloader))
        #wandb.log({"loss/test": loss.item()})

        train_loss += loss.item()
    print(f'test_loss{i}', train_loss / len(testloader))
#wandb.finish()
    
        

Loaded chunk 6


  merged[byte_cols] = merged[byte_cols].applymap(readable_size)


## Depthwise1DCNN.forward

active_bytes reserved_bytes line code                            
         all            all                                      
        peak           peak                                      
     184.16M        196.00M   17     @profile_every(1)           
                              18     def forward(self, x):       
     430.16M        442.00M   19         x = self.conv1(x)       
     676.16M        688.00M   20         x = self.relu(x)        
     677.16M        688.00M   21         x = self.layer_norm1(x) 
       1.45G          1.46G   22         x = self.conv2(x)       
       1.13G          1.46G   23         x = self.relu(x)        
       1.13G          1.46G   24         x = self.layer_norm2(x) 
       1.88G          1.97G   25         x = self.conv3(x)       
       1.60G          1.97G   26         x = self.relu(x)        
       1.60G          1.97G   27         x = self.layer_norm3(x) 
                              28                 

  merged[byte_cols] = merged[byte_cols].applymap(readable_size)


## PositionalEncoding.forward

active_bytes reserved_bytes line code                                      
         all            all                                                
        peak           peak                                                
       1.84G          1.97G  146     @profile_every(1)                     
                             147     def forward(self, x):                 
       2.08G          2.21G  148         x = x + self.pe[:, :x.size(1), :] 
       2.08G          2.21G  149         return x                          


  merged[byte_cols] = merged[byte_cols].applymap(readable_size)


## MultiHeadAttention.forward

active_bytes reserved_bytes line code                                                                                    
         all            all                                                                                              
        peak           peak                                                                                              
       1.84G          2.21G  182     @profile_every(1)                                                                   
                             183     def forward(self, x):                                                               
                             184                                                                                         
       1.84G          2.21G  185         B, S, C, L = x.size()                                                           
                             186                                                                                   

  merged[byte_cols] = merged[byte_cols].applymap(readable_size)


## FeedForward.forward

active_bytes reserved_bytes line code                           
         all            all                                     
        peak           peak                                     
       3.39G          3.45G  215     @profile_every(1)          
                             216     def forward(self, x):      
       4.32G          4.38G  217         x = x + self.MLP(x)    
       4.33G          4.38G  218         x = self.layer_norm(x) 
       4.33G          4.38G  219         return x               


  merged[byte_cols] = merged[byte_cols].applymap(readable_size)
  merged[byte_cols] = merged[byte_cols].applymap(readable_size)


## TransformerBlock.forward

active_bytes reserved_bytes line code                                    
         all            all                                              
        peak           peak                                              
       1.84G          2.21G  161     @profile_every(1)                   
                             162     def forward(self, x):               
       3.39G          3.44G  163         x = x + self.multi_head_attn(x) 
       3.39G          3.45G  164         x = self.layer_norm(x)          
       4.56G          4.61G  165         x = x + self.feed_forward(x)    
       4.56G          4.62G  166         x = self.layer_norm(x)          
       4.56G          4.62G  167         return x                        
## RegionalTran.forward

active_bytes reserved_bytes line code                                               
         all            all                                                         
        peak           peak         

  merged[byte_cols] = merged[byte_cols].applymap(readable_size)


## PositionalEncoding.forward

active_bytes reserved_bytes line code                                      
         all            all                                                
        peak           peak                                                
       5.03G          5.08G  146     @profile_every(1)                     
                             147     def forward(self, x):                 
       5.26G          5.32G  148         x = x + self.pe[:, :x.size(1), :] 
       5.26G          5.32G  149         return x                          


  merged[byte_cols] = merged[byte_cols].applymap(readable_size)


## MultiHeadAttention.forward

active_bytes reserved_bytes line code                                                                                    
         all            all                                                                                              
        peak           peak                                                                                              
       5.03G          5.32G  182     @profile_every(1)                                                                   
                             183     def forward(self, x):                                                               
                             184                                                                                         
       5.03G          5.32G  185         B, S, C, L = x.size()                                                           
                             186                                                                                   

OutOfMemoryError: CUDA out of memory. Tried to allocate 478.00 MiB (GPU 0; 8.00 GiB total capacity; 7.04 GiB already allocated; 0 bytes free; 7.08 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
torch.save(model.state_dict(), 'model_overfit.pt')