In [1]:
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch.optim as optim
from torchsummary import summary


from torch.utils.data import DataLoader

from src.dataset.MI_dataset_all_subjects import MI_Dataset as MI_Dataset_all_subjects
from src.dataset.MI_dataset_single_subject import MI_Dataset as MI_Dataset_single_subject

from config.default import cfg


from models.eegnet import EEGNet

from utils.eval import accuracy

%load_ext autoreload
%autoreload 2


In [2]:
device =  torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [4]:
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(42)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

In [5]:
subject = 1
train_runs = [0,1,2,3,4]
test_runs = [5]


train_dataset = MI_Dataset_single_subject(subject, train_runs, device = device)
test_dataset = MI_Dataset_single_subject(subject, test_runs, device = device)

train_dataloader = DataLoader(train_dataset,  batch_size=cfg['train']['batch_size'], shuffle=True, drop_last=True)
test_dataloader = DataLoader(test_dataset,  batch_size=cfg['train']['batch_size'], shuffle=False, drop_last=True)

In [6]:
print(f"Train dataset: {len(train_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

for features, label in train_dataloader:
    print(features.shape)
    print(label)
    break
    


Train dataset: 240 samples
Test dataset: 48 samples
torch.Size([48, 22, 1001])
tensor([3, 0, 2, 0, 3, 1, 1, 3, 2, 1, 3, 3, 0, 0, 0, 3, 3, 0, 3, 1, 3, 1, 3, 1,
        2, 3, 2, 0, 3, 1, 1, 3, 0, 2, 2, 2, 0, 0, 3, 2, 3, 1, 1, 2, 1, 1, 2, 2],
       device='cuda:0')


In [7]:
class conv_block(nn.Module):
    def __init__(self,):
        super(conv_block,self).__init__()
        self.conv_block_1 = nn.Sequential(
                nn.Conv2d(1, 16, kernel_size=(1,64), bias=False,padding='same'),
                nn.BatchNorm2d(16),
                # problem,
            )
        self.depthwise = nn.Conv2d(16, 16, (22,1), stride=1, padding=0, dilation=1, groups=16, bias=False)
        self.pointwise = nn.Conv2d(16, 16*2, 1, 1, 0, 1, 1, bias=False)
        self.conv_block_2 = nn.Sequential(
                nn.BatchNorm2d(32),
                nn.ELU(),
                nn.Dropout(0.5),
                nn.AvgPool2d(kernel_size=(1,8)),
                nn.Conv2d(32, 32, kernel_size=(1,16), bias=False,padding='same'),
                nn.BatchNorm2d(32),
                nn.ELU(),
                nn.AvgPool2d(kernel_size=(1, 7)),
                nn.Dropout(0.5),
            )
    def forward(self, x):
        x = self.conv_block_1(x)
        x = self.depthwise(x)
        x = self.pointwise(x)
        out = self.conv_block_2(x)
        # nn.utils.clip_grad_norm_(self.depthwise.parameters(), max_norm=1.0)  
        return out


In [12]:
class EEGNet(nn.Module):
    def __init__(self, num_classes: int = 4, channels: int = 22, samples: int = 1001,
        dropout_rate: float = 0.5, kernel_length: int = 64, num_filters1: int = 16,
        depth_multiplier: int = 2, num_filters2: int = 32, norm_rate: float = 0.25) -> None:
        super(EEGNet, self).__init__()

        self.channels = channels
        self.samples = samples

        # First convolutional block
        # Temporal convolutional to learn frequency filters
        self.conv1 = nn.Conv2d(1, num_filters1, (1, kernel_length), padding=(0, kernel_length // 2), bias=False)
        self.bn1 = nn.BatchNorm2d(num_filters1)
        
        # Depthwise convolutional block
        # Connected to each feature map individually, to learn frequency-specific spatial filters
        self.dw_conv1 = nn.Conv2d(num_filters1, num_filters1 * depth_multiplier, (channels, 1), groups=num_filters1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_filters1 * depth_multiplier)
        self.activation = nn.ELU()
        self.avg_pool1 = nn.AvgPool2d((1, 4))
        self.dropout1 = nn.Dropout(dropout_rate)

        # Separable convolutional block
        # Learns a temporal summary for each feature map individually, 
        # followed by a pointwise convolution, which learns how to optimally mix the feature maps together
        self.sep_conv1 = nn.Conv2d(num_filters1 * depth_multiplier, num_filters1 * depth_multiplier, (1, 16), groups=num_filters1 * depth_multiplier, padding=(0, 8), bias=False)
        self.conv2 = nn.Conv2d(num_filters1 * depth_multiplier, num_filters2, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(num_filters2)
        self.avg_pool2 = nn.AvgPool2d((1, 8))
        self.dropout2 = nn.Dropout(dropout_rate)

       

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.view(-1,  1, self.channels, self.samples)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.dw_conv1(x)
        x = self.bn2(x)
        x = self.activation(x)
        x = self.avg_pool1(x)
        x = self.dropout1(x)

        x = self.sep_conv1(x)
        x = self.conv2(x)
        x = self.bn3(x)
        x = self.activation(x)
        x = self.avg_pool2(x)
        x = self.dropout2(x)


        return x

In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self, input_size = 32, num_heads=2):
        super().__init__()
        self.input_size = input_size
        
        self.embed_dim = self.input_size    
        self.num_heads = num_heads
        
        # Instantiate the PyTorch's built-in MultiheadAttention module with batch_first=True
        self.multihead_attn = nn.MultiheadAttention(embed_dim=self.embed_dim, num_heads=self.num_heads, dropout=0.5,  batch_first=True)
        
        # Output linear layer remains the same
        self.W_O = nn.Linear(input_size, input_size)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x):
        # With batch_first=True, x is expected to be (batch_size, sequence_length, embedding_dim)
        # We can pass x directly without transposing
        x = x.permute(0,2,1)
        #print('mha forward', x.shape)
        attn_output, attn_output_weights = self.multihead_attn(x, x, x)
        
        # Since we are using batch_first=True, attn_output is already in the shape (batch_size, sequence_length, embedding_dim)
        # So, no need to transpose it before passing through the final linear layer
        output = self.W_O(attn_output)
        output = self.dropout(output)
        output = output.permute(0,2,1)
        return output


In [13]:
class CausalConv1d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, dilation = 1):
        super(CausalConv1d, self).__init__()
        self.padding = (kernel_size - 1) * dilation
        self.conv1d = nn.Conv1d(in_channels, out_channels, kernel_size, dilation=dilation)
        nn.init.kaiming_uniform_(self.conv1d.weight, nonlinearity='linear')

    def forward(self, x):
        x = F.pad(x, (self.padding, 0))
        return self.conv1d(x)
    
class TCN_block(nn.Module):
    def __init__(self, depth=2):
        super(TCN_block, self).__init__()
        self.depth = depth


        self.Activation_1 = nn.ELU()
        self.TCN_Residual_1 = nn.Sequential(
            #可能问题的所在
            CausalConv1d(32, 32, 4, dilation=1),
            nn.BatchNorm1d(32),
            nn.ELU(),
            nn.Dropout(0.3),
            CausalConv1d(32, 32, 4, dilation=1),
            nn.BatchNorm1d(32),
            nn.ELU(),
            nn.Dropout(0.3),
        )
        
        self.TCN_Residual = nn.ModuleList()
        self.Activation = nn.ModuleList()
        for i in range(depth-1):
            TCN_Residual_n = nn.Sequential(
            CausalConv1d(32, 32, 4, dilation=2**(i+1)),
            nn.BatchNorm1d(32),
            nn.ELU(),
            nn.Dropout(0.3),
            CausalConv1d(32, 32, 4, dilation=2**(i+1)),
            nn.BatchNorm1d(32),
            nn.ELU(),
            nn.Dropout(0.3),
        )
            self.TCN_Residual.append(TCN_Residual_n)
            self.Activation.append(nn.ELU())   
        
    def forward(self, x):
        block = self.TCN_Residual_1(x)
        # print(block.shape)
        block += x
        block = self.Activation_1(block)
        
        for i in range(self.depth-1):
            block_o = block
            block = self.TCN_Residual[i](block)
            block += block_o
            # block = torch.add(block_o,block)
            block = self.Activation[i](block)
        return block[:, :, -1]

In [14]:

class ATCNet(nn.Module):
    def __init__(self, num_classes = 4, fe = 'conv', fuze = 'average') -> None:
        super(ATCNet, self).__init__()
        self.num_classes = num_classes
        if fe == 'conv':
            self.conv_block = conv_block()
        if fe == 'eegnet':
            self.conv_block = EEGNet()  

        self.fuze = fuze# 'average' or 'concat'

        self.attention_list = nn.ModuleList()
        self.TCN_list = nn.ModuleList()
        for i in range(5):
            self.attention_list.append(MultiHeadAttention())
            self.TCN_list.append(TCN_block())
        
        if self.fuze == 'average':
            self.fuze_layers = nn.ModuleList()
            for i in range(5):
                self.fuze_layers.append(nn.Linear(32, num_classes))

        # Fully connected layer
        self.flatten = nn.Flatten()
        asd = 160 #2080 # 544
        self.dense = nn.Linear(asd, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.unsqueeze(1)
        block1 = self.conv_block(x) 
        #print(block1.shape) # [B,32,1,17]
        n_windows = 5
        block1 = block1.squeeze(2)


        sliding_window_output = []
        for i in range(n_windows):
            start = i
            end = block1.shape[2]-n_windows+i+1
            block2 = block1[:,:, start:end]
            #print(block2.shape) # [B,32,13]
            block2 = self.attention_list[i](block2)
            #sliding_window_output.append(block2)
            #print(block2.shape)  # [B, 32, 13]
            block2 = self.TCN_list[i](block2)
            #print(block2.shape)   # [B, 32]
            
            sliding_window_output.append(block2)

        if self.fuze == 'concat':
            output = torch.cat(sliding_window_output, dim=1)
            #print(output.shape) # [B, 32, 65]
            #output = block1
            output = self.flatten(output)
            #print(output.shape)
            output = self.dense(output)
        if self.fuze == 'average':
            for i in range(5):
                sliding_window_output[i] = self.fuze_layers[i](sliding_window_output[i])
            output = torch.stack(sliding_window_output, dim=2) # check dim may permute required
            output = torch.mean(output, dim=2)
            

        return output


In [15]:
dummy = torch.randn(1, 22, 1001).to(device)
model = ATCNet().to(device)
model(dummy)

tensor([[-0.2433, -0.2354,  0.3645, -0.0508]], device='cuda:0',
       grad_fn=<MeanBackward1>)

In [11]:
#model = ATCNet().to(device)
#summary(model, ( 22, 1000))

In [12]:

out = model(dummy)
out

tensor([[ 0.1788, -0.2459, -0.4350,  0.0487]], device='cuda:0',
       grad_fn=<MeanBackward1>)

In [13]:
# Test forward pass
model(next(iter(train_dataloader))[0]);

In [16]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=cfg['train']['learning_rate'], weight_decay=cfg['train']['weight_decay'])

# Training loop
for epoch in range(cfg['train']['n_epochs']):
    epoch_loss = 0.0

    for batch_features, batch_labels in train_dataloader:
        optimizer.zero_grad()
        outputs = model(batch_features)
        loss = criterion(outputs, batch_labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()

    if epoch % 10 == 9:
        train_accuracy = accuracy(model, train_dataloader)
        test_accuracy = accuracy(model, test_dataloader)
        print(f"Epoch {epoch + 1}/{cfg['train']['n_epochs']}, Loss: {epoch_loss}, Train accuracy: {train_accuracy:.2f}%, Test accuracy: {test_accuracy:.2f}%")

print("#"*50)
print(f'Final_loss: {epoch_loss}')
print(f'Final train accuracy: {accuracy(model, train_dataloader):.2f}%')
print(f'Final test accuracy: {accuracy(model, test_dataloader):.2f}%')

Epoch 10/200, Loss: 5.9273258447647095, Train accuracy: 44.17%, Test accuracy: 60.42%
Epoch 20/200, Loss: 5.295874655246735, Train accuracy: 49.58%, Test accuracy: 45.83%
Epoch 30/200, Loss: 4.981917023658752, Train accuracy: 51.67%, Test accuracy: 45.83%
Epoch 40/200, Loss: 4.745257079601288, Train accuracy: 50.00%, Test accuracy: 54.17%
Epoch 50/200, Loss: 4.471916615962982, Train accuracy: 59.58%, Test accuracy: 50.00%
Epoch 60/200, Loss: 4.001338601112366, Train accuracy: 65.83%, Test accuracy: 54.17%
Epoch 70/200, Loss: 3.6973262429237366, Train accuracy: 67.92%, Test accuracy: 41.67%
Epoch 80/200, Loss: 3.550718069076538, Train accuracy: 70.83%, Test accuracy: 50.00%
Epoch 90/200, Loss: 3.490455389022827, Train accuracy: 74.17%, Test accuracy: 47.92%
Epoch 100/200, Loss: 2.7460367679595947, Train accuracy: 75.83%, Test accuracy: 45.83%
Epoch 110/200, Loss: 2.9169896245002747, Train accuracy: 78.75%, Test accuracy: 50.00%
Epoch 120/200, Loss: 3.3092034459114075, Train accuracy: 81

### Results:

conv block 56%  

conv + mha 56%

conv + mha + tc concat (no dropout) 72% (max)

conv + mha + tc concat = 54%   //  68% (max)

conv + mha + tc concat = 58%   //  64% (max)