In [1]:
import numpy as np
import math
import glob
import random
import itertools
import datetime
import time
import datetime
import sys
import scipy.io
from torch.utils.data import DataLoader
from torch.autograd import Variable
# from torchsummary import summary
import torch.autograd as autograd

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import mne
from torch.utils.data import DataLoader

from braindecode.datasets import TUH
from braindecode.preprocessing import create_fixed_length_windows

mne.set_log_level('ERROR')  # avoid messages everytime a window is extracted

In [3]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor

In [4]:
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce



In [5]:
import matplotlib.pyplot as plt
# from torch.utils.tensorboard import SummaryWriter
from torch.backends import cudnn
cudnn.benchmark = False
cudnn.deterministic = True

In [6]:
from braindecode.datasets.tuh import _TUHMock as TUH  # noqa F811

In [7]:
TUH_PATH = 'edf/train/'
tuh = TUH(
    path=TUH_PATH,
    recording_ids=None,
    target_name=('gender'),  # use both age and gender as decoding target
    preload=False,
    add_physician_reports=False,
)
tuh.description

Unnamed: 0,path,version,year,month,day,subject,session,segment,age,gender
0,tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001...,v1.1.0,2003,2,5,58,1,0,0,M
1,tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004...,v1.1.0,2014,9,30,9932,4,13,53,F
2,tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s0...,v1.1.0,2014,12,14,12331,3,2,39,M
3,tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001...,v1.1.0,2015,12,30,0,1,0,37,M
4,tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s0...,v1.2.0,2016,1,15,14928,4,7,83,F


In [8]:
x, y = tuh[-1]
print('x:', x)
print('y:', y)

x: [[-2.29440914]
 [-1.57676045]
 [-0.32175798]
 [-1.60447638]
 [ 1.34726146]
 [ 1.35778785]
 [ 0.7181814 ]
 [-1.25348204]
 [-0.39194036]
 [ 0.80359228]
 [ 0.65341874]
 [-1.09040041]
 [ 2.17115861]
 [-0.90001031]
 [ 0.59748181]
 [ 0.58172382]
 [ 1.81903619]
 [ 2.14273693]
 [ 0.94823322]
 [ 1.32471369]
 [-1.03531365]]
y: F


In [9]:
tuh_windows = create_fixed_length_windows(
    tuh,
    start_offset_samples=0,
    stop_offset_samples=None,
    window_size_samples=1000,
    window_stride_samples=1000,
    drop_last_window=False,
    mapping={'M': 0, 'F': 1},  # map non-digit targets
)
# store the number of windows required for loading later on
tuh_windows.set_description({
    "n_windows": [len(d) for d in tuh_windows.datasets]})

In [10]:
x, y, ind = tuh_windows[-1]
print('x:', x)
print('y:', y)
print('ind:', ind)

x: [[ 0.95744467  0.36770442  0.20845759 ... -1.0786871  -0.61109
  -2.294409  ]
 [-1.1028883   1.0068827  -1.1743277  ...  0.771221   -1.3385423
  -1.5767604 ]
 [ 0.27898267  0.3863388  -0.2679523  ...  2.9042215  -0.33213955
  -0.32175797]
 ...
 [ 1.338765   -0.22695391  0.34812006 ...  0.10860196  0.2009789
   0.94823325]
 [ 1.5344774  -0.991343   -0.02419238 ... -0.64405024 -1.1567156
   1.3247137 ]
 [-0.4681851   1.0513295  -0.43032187 ...  0.20396124  0.15305676
  -1.0353136 ]]
y: 1
ind: [3, 2600, 3600]


In [11]:
dl = DataLoader(
    dataset=tuh_windows,
    batch_size=4,
)
for batch_X, batch_y, batch_ind in dl:
    pass
print('batch_X:', batch_X)
print('batch_y:', batch_y)
print('batch_ind:', batch_ind)

batch_X: tensor([[[-1.2357,  0.6497, -0.8300,  ..., -0.7639,  0.0224,  0.3111],
         [-0.8634,  0.5799, -0.3252,  ..., -0.9947, -1.6502, -1.1434],
         [ 0.0595, -0.1587, -1.6838,  ...,  0.5377, -0.4645,  0.4578],
         ...,
         [ 0.2055, -0.9867, -0.8866,  ..., -0.1409,  2.0982,  0.8535],
         [ 0.4756,  0.2467, -0.3098,  ...,  0.0427,  0.8393, -0.7588],
         [-0.3236,  0.6192,  1.4197,  ...,  0.5149, -0.2055,  0.5038]],

        [[-0.7843, -0.4270, -0.0223,  ..., -0.3390,  0.5544, -0.2855],
         [ 0.4028, -0.1024,  1.1524,  ...,  1.5556,  0.4898,  1.3321],
         [-1.1180, -1.0933, -0.2780,  ...,  0.0785,  0.9793,  1.2615],
         ...,
         [ 0.0842,  0.1249,  0.5489,  ..., -1.2234, -0.4355,  0.3331],
         [ 0.2752,  0.3448,  0.1283,  ...,  2.1729,  1.3250,  1.0280],
         [ 0.9441,  1.0467,  0.1602,  ..., -0.7809, -0.9615,  0.8170]],

        [[-0.3747, -1.1219,  1.6295,  ...,  0.8105,  0.4693,  0.3093],
         [-0.5658, -0.8672,  0.3879,

In [12]:
TUH_PATH = 'edf/dev/'
tuh1 = TUH(
    path=TUH_PATH,
    recording_ids=None,
    target_name=('gender'),  # use both age and gender as decoding target
    preload=False,
    add_physician_reports=False,
)
tuh1.description

Unnamed: 0,path,version,year,month,day,subject,session,segment,age,gender
0,tuh_eeg/v1.1.0/edf/02_tcp_le/000/00000058/s001...,v1.1.0,2003,2,5,58,1,0,0,M
1,tuh_eeg/v1.1.0/edf/01_tcp_ar/099/00009932/s004...,v1.1.0,2014,9,30,9932,4,13,53,F
2,tuh_eeg/v1.1.0/edf/03_tcp_ar_a/123/00012331/s0...,v1.1.0,2014,12,14,12331,3,2,39,M
3,tuh_eeg/v1.1.0/edf/01_tcp_ar/000/00000000/s001...,v1.1.0,2015,12,30,0,1,0,37,M
4,tuh_eeg/v1.2.0/edf/03_tcp_ar_a/149/00014928/s0...,v1.2.0,2016,1,15,14928,4,7,83,F


In [13]:
tuh_windows1 = create_fixed_length_windows(
    tuh1,
    start_offset_samples=0,
    stop_offset_samples=None,
    window_size_samples=1000,
    window_stride_samples=1000,
    drop_last_window=False,
    mapping={'M': 0, 'F': 1},  # map non-digit targets
)
# store the number of windows required for loading later on
tuh_windows1.set_description({
    "n_windows": [len(d) for d in tuh_windows1.datasets]})

In [14]:
tuh_windows1

<braindecode.datasets.base.BaseConcatDataset at 0x7f295f6544d0>

In [19]:
dll = DataLoader(
    dataset=tuh_windows1,
    batch_size=4,
)
for batch_X1, batch_y1, batch_ind1 in dll:
    batch_X1
print('batch_X:', batch_X.shape)
print('batch_y:', batch_y)
print('batch_ind:', batch_ind)

batch_X: torch.Size([4, 21, 1000])
batch_y: tensor([1, 1, 1, 1])
batch_ind: [tensor([0, 1, 2, 3]), tensor([   0, 1000, 2000, 2600]), tensor([1000, 2000, 3000, 3600])]


In [43]:
class PatchEmbedding(nn.Module):
    def __init__(self, emb_size=40):
        # self.patch_size = patch_size
        super().__init__()

        self.shallownet = nn.Sequential(
            nn.Conv2d(1, 40, (1, 25), (1, 1)),
            nn.Conv2d(40, 40, (20, 1), (1, 1)),
            nn.BatchNorm2d(40),
            nn.ELU(),
            nn.AvgPool2d((1, 75), (1, 15)),  # pooling acts as slicing to obtain 'patch' along the time dimension as in ViT
            nn.Dropout(0.5),
        )

        self.projection = nn.Sequential(
            nn.Conv2d(40, emb_size, (1, 1), stride=(1, 1)),  # transpose, conv could enhance fiting ability slightly
            Rearrange('b e (h) (w) -> b (h w) e'),
        )


    def forward(self, x: Tensor) -> Tensor:
        b, _, _ = x.shape
        # print('x',x.shape)
        x = x[ :,None, :, :]
        # print('x',x.shape)
        x = self.shallownet(x)
        x = self.projection(x)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = rearrange(self.queries(x), "b n (h d) -> b h n d", h=self.num_heads)
        keys = rearrange(self.keys(x), "b n (h d) -> b h n d", h=self.num_heads)
        values = rearrange(self.values(x), "b n (h d) -> b h n d", h=self.num_heads)
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)  
        if mask is not None:
            fill_value = torch.finfo(torch.float32).min
            energy.mask_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        res = x
        x = self.fn(x, **kwargs)
        x += res
        return x


class FeedForwardBlock(nn.Sequential):
    def __init__(self, emb_size, expansion, drop_p):
        super().__init__(
            nn.Linear(emb_size, expansion * emb_size),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(expansion * emb_size, emb_size),
        )


class GELU(nn.Module):
    def forward(self, input: Tensor) -> Tensor:
        return input*0.5*(1.0+torch.erf(input/math.sqrt(2.0)))


class TransformerEncoderBlock(nn.Sequential):
    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        super().__init__(
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                MultiHeadAttention(emb_size, num_heads, drop_p),
                nn.Dropout(drop_p)
            )),
            ResidualAdd(nn.Sequential(
                nn.LayerNorm(emb_size),
                FeedForwardBlock(
                    emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
                nn.Dropout(drop_p)
            )
            ))


class TransformerEncoder(nn.Sequential):
    def __init__(self, depth, emb_size):
        super().__init__(*[TransformerEncoderBlock(emb_size) for _ in range(depth)])


class ClassificationHead(nn.Sequential):
    def __init__(self, emb_size, n_classes):
        super().__init__()
        
        # global average pooling
        self.clshead = nn.Sequential(
            Reduce('b n e -> b e', reduction='mean'),
            nn.LayerNorm(emb_size),
            nn.Linear(emb_size, n_classes)
        )
        self.fc = nn.Sequential(
            nn.Linear(4880, 256),#4x4880
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Linear(32, 4)
        )

    def forward(self, x):
        x = x.contiguous().view(x.size(0), -1)
        out = self.fc(x)
        return x, out


class Conformer(nn.Sequential):
    def __init__(self, emb_size=40, depth=6, n_classes=2, **kwargs):
        super().__init__(

            PatchEmbedding(emb_size),
            TransformerEncoder(depth, emb_size),
            ClassificationHead(emb_size, n_classes)
        )





In [47]:
class ExP():
    def __init__(self, nsub):
        super(ExP, self).__init__()
        self.batch_size = 72
        self.n_epochs = 500
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.dimension = (190, 50)
        self.nSub = nsub

        self.start_epoch = 0
        self.root = '/Data/strict_TE/'

        # self.log_write = open("./results/log_subject%d.txt" % self.nSub, "w")


        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()
        gpus = [0]
        self.model = Conformer().cuda()
        self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
        self.model = self.model.cuda()
        # summary(self.model, (1, 22, 1000))





    def train(self):


        self.dataloader = DataLoader(dataset=tuh_windows,batch_size=4,)#torch.utils.data.DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=True)

        self.test_dataloader = DataLoader(dataset=tuh_windows1,batch_size=4,)#torch.utils.data.DataLoader(dataset=test_dataset, batch_size=self.batch_size, shuffle=True)

        # Optimizers
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        bestAcc = 0
        averAcc = 0
        num = 0
        Y_true = 0
        Y_pred = 0

        # Train the cnn model
        total_step = len(self.dataloader)
        curr_lr = self.lr

        for e in range(self.n_epochs):
            pred = []
            true = []
            
            # in_epoch = time.time()
            self.model.train()
            for img, label, i in self.dataloader:

                img = Variable(img.cuda().type(self.Tensor))
                label = Variable(label.cuda().type(self.LongTensor))

                # data augmentation
                # aug_data, aug_label = self.interaug(self.allData, self.allLabel)
                # img = torch.cat((img, aug_data))
                # label = torch.cat((label, aug_label))


                tok, outputs = self.model(img)

                loss = self.criterion_cls(outputs, label) 

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

            print('epoch',e)
            print('loss', loss)
            # out_epoch = time.time()

#             for test_data, test_label, i in self.test_dataloader:
#             # test process
#                 if (e + 1) % 1 == 0:
#                     self.model.eval()
#                     Tok, Cls = self.model(test_data)


#                     loss_test = self.criterion_cls(Cls, test_label)
#                     y_pred = torch.max(Cls, 1)[1]
#                     pred.append(y_pred)
#                     true.append(test_label)
                    
#             acc = float((pred == true).cpu().numpy().astype(int).sum()) / float(len(true))
#             # train_pred = torch.max(outputs, 1)[1]
#             # train_acc = float((train_pred == label).cpu().numpy().astype(int).sum()) / float(label.size(0))

#             print('Epoch:', e,
#                   '  Train loss: %.6f' % loss.detach().cpu().numpy(),
#                   '  Test loss: %.6f' % loss_test.detach().cpu().numpy(),
#                   # '  Train accuracy %.6f' % train_acc,
#                   '  Test accuracy is %.6f' % acc)

#             # self.log_write.write(str(e) + "    " + str(acc) + "\n")
#             num = num + 1
#             averAcc = averAcc + acc
#             if acc > bestAcc:
#                 bestAcc = acc
#                 Y_true = true
#                 Y_pred = pred


        torch.save(self.model.module.state_dict(), 'model.pth')
#         averAcc = averAcc / num
#         print('The average accuracy is:', averAcc)
#         print('The best accuracy is:', bestAcc)
#         # self.log_write.write('The average accuracy is: ' + str(averAcc) + "\n")
#         # self.log_write.write('The best accuracy is: ' + str(bestAcc) + "\n")

#         return bestAcc, averAcc, Y_true, Y_pred
#         # writer.close()




In [48]:
def main():
    best = 0
    aver = 0
    # result_write = open("preprocess/sub_result.txt", "w")

    for i in range(1):
        starttime = datetime.datetime.now()


        seed_n = np.random.randint(2021)
        print('seed is ' + str(seed_n))
        random.seed(seed_n)
        np.random.seed(seed_n)
        torch.manual_seed(seed_n)
        torch.cuda.manual_seed(seed_n)
        torch.cuda.manual_seed_all(seed_n)


        print('Subject %d' % (i+1))
        exp = ExP(i + 1)

        # bestAcc, averAcc, Y_true, Y_pred = 
        exp.train()
#         print('THE BEST ACCURACY IS ' + str(bestAcc))
#         # result_write.write('Subject ' + str(i + 1) + ' : ' + 'Seed is: ' + str(seed_n) + "\n")
#         # result_write.write('Subject ' + str(i + 1) + ' : ' + 'The best accuracy is: ' + str(bestAcc) + "\n")
#         # result_write.write('Subject ' + str(i + 1) + ' : ' + 'The average accuracy is: ' + str(averAcc) + "\n")

#         endtime = datetime.datetime.now()
#         print('subject %d duration: '%(i+1) + str(endtime - starttime))
#         best = best + bestAcc
#         aver = aver + averAcc
#         if i == 0:
#             yt = Y_true
#             yp = Y_pred
#         else:
#             yt = torch.cat((yt, Y_true))
#             yp = torch.cat((yp, Y_pred))


#     best = best / 9
#     aver = aver / 9
#     print('average best accuracy', best)
#     print('average accuracy', aver)
    # result_write.write('**The average Best accuracy is: ' + str(best) + "\n")
    # result_write.write('The average Aver accuracy is: ' + str(aver) + "\n")
    # result_write.close()


if __name__ == "__main__":
    print(time.asctime(time.localtime(time.time())))
    main()
    print(time.asctime(time.localtime(time.time())))

Wed Mar  1 09:54:02 2023
seed is 1585
Subject 1
epoch 0
loss tensor(2.2751, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 1
loss tensor(1.6702, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 2
loss tensor(1.3254, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 3
loss tensor(1.8329, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 4
loss tensor(1.4305, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 5
loss tensor(1.5214, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 6
loss tensor(1.8577, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 7
loss tensor(1.2755, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 8
loss tensor(1.8299, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 9
loss tensor(1.0020, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 10
loss tensor(0.9433, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 11
loss tensor(1.6243, device='cuda:0', grad_fn=<NllLossBackward0>)
epoch 12
loss tensor(0.9573, device='cuda:0', grad_fn=<NllLossBackward0>)


In [49]:
model = Conformer()
#self.model = nn.DataParallel(self.model, device_ids=[i for i in range(len(gpus))])
model = model#.cuda()
model.load_state_dict(torch.load('model.pth'))
model.eval()

Conformer(
  (0): PatchEmbedding(
    (shallownet): Sequential(
      (0): Conv2d(1, 40, kernel_size=(1, 25), stride=(1, 1))
      (1): Conv2d(40, 40, kernel_size=(20, 1), stride=(1, 1))
      (2): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): ELU(alpha=1.0)
      (4): AvgPool2d(kernel_size=(1, 75), stride=(1, 15), padding=0)
      (5): Dropout(p=0.5, inplace=False)
    )
    (projection): Sequential(
      (0): Conv2d(40, 40, kernel_size=(1, 1), stride=(1, 1))
      (1): Rearrange('b e (h) (w) -> b (h w) e')
    )
  )
  (1): TransformerEncoder(
    (0): TransformerEncoderBlock(
      (0): ResidualAdd(
        (fn): Sequential(
          (0): LayerNorm((40,), eps=1e-05, elementwise_affine=True)
          (1): MultiHeadAttention(
            (keys): Linear(in_features=40, out_features=40, bias=True)
            (queries): Linear(in_features=40, out_features=40, bias=True)
            (values): Linear(in_features=40, out_features=40, bias=True

In [54]:
pred = []
true =[]
for test_data, test_label, batch_ind1 in dll:
    Tok, Cls = model(test_data)
    # loss_test = self.criterion_cls(Cls, test_label)
    y_pred = torch.max(Cls, 1)[1]
    pred.append(y_pred.numpy())
    true.append(test_label.numpy())

In [59]:
acc = float((np.concatenate(pred) == np.concatenate(true)).sum()) / float(len(true))
acc

2.2

In [61]:
len(np.concatenate(pred))

20