In [None]:
import numpy as np
import math
import glob
import random
import itertools
import datetime
import time
import datetime
import sys
import scipy.io
from torch.utils.data import DataLoader
from torch.autograd import Variable
# from torchsummary import summary
import torch.autograd as autograd

In [None]:
import mne
from torch.utils.data import DataLoader

from braindecode.datasets import TUH
from braindecode.preprocessing import create_fixed_length_windows

mne.set_log_level(verbose='ERROR')  # silence MNE messages emitted every time a window is extracted

In [None]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

from torch import nn
from torch import Tensor

In [None]:
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce



In [None]:
import matplotlib.pyplot as plt
# from torch.utils.tensorboard import SummaryWriter
from torch.backends import cudnn
# Favour reproducibility over speed: ask cuDNN for deterministic kernels and
# turn off its per-shape autotuner.
cudnn.deterministic = True
cudnn.benchmark = False

In [None]:
from braindecode.datasets.tuh import _TUHMock as TUH  # noqa F811

In [None]:
# Training split of the corpus.
# NOTE(review): at this point `TUH` is braindecode's `_TUHMock` (imported as TUH in an
# earlier cell, shadowing the real class). That is presumably a test mock rather than a
# reader of the real EDF files — confirm before trusting any results.
TUH_PATH = 'edf/train/'
tuh = TUH(
    path=TUH_PATH,
    recording_ids=None,
    target_name='gender',  # single decoding target; later cells map 'M'/'F' to 0/1
    preload=False,  # do not load the signals eagerly
    add_physician_reports=False,
)
tuh.description

In [None]:
# Peek at the last item of the (non-windowed) dataset and its target.
x, y = tuh[-1]
for caption, item in (('x:', x), ('y:', y)):
    print(caption, item)

In [None]:
# Cut every training recording into consecutive 1000-sample windows
# (stride == size, so windows do not overlap).
tuh_windows = create_fixed_length_windows(
    tuh,
    window_size_samples=1000,
    window_stride_samples=1000,
    start_offset_samples=0,
    stop_offset_samples=None,
    drop_last_window=False,
    mapping={'M': 0, 'F': 1},  # encode the string gender labels as integers
)
# Keep each recording's window count in the description table; it is needed
# when the windows are loaded again later on.
tuh_windows.set_description(
    {"n_windows": list(map(len, tuh_windows.datasets))}
)

In [None]:
# Peek at the last window: its signal, its integer target and its index info.
x, y, ind = tuh_windows[-1]
for caption, item in (('x:', x), ('y:', y), ('ind:', ind)):
    print(caption, item)

In [None]:
dl = DataLoader(tuh_windows, batch_size=4)
# Walk the entire loader once (every window is read). Afterwards the loop
# variables still hold the final batch, which is printed below.
for batch_X, batch_y, batch_ind in dl:
    pass
for caption, item in (('batch_X:', batch_X), ('batch_y:', batch_y), ('batch_ind:', batch_ind)):
    print(caption, item)

In [None]:
# Dev split of the corpus.
# NOTE(review): `TUH` may be braindecode's `_TUHMock` (see the `_TUHMock as TUH`
# import in an earlier cell) rather than the real reader — confirm.
TUH_PATH = 'edf/dev/'
tuh1 = TUH(
    path=TUH_PATH,
    recording_ids=None,
    target_name='gender',  # single decoding target; later cells map 'M'/'F' to 0/1
    preload=False,  # do not load the signals eagerly
    add_physician_reports=False,
)
tuh1.description

In [None]:
# Same windowing as the training split: consecutive, non-overlapping 1000-sample windows.
tuh_windows1 = create_fixed_length_windows(
    tuh1,
    window_size_samples=1000,
    window_stride_samples=1000,
    start_offset_samples=0,
    stop_offset_samples=None,
    drop_last_window=False,
    mapping={'M': 0, 'F': 1},  # encode the string gender labels as integers
)
# Keep each recording's window count in the description table; it is needed
# when the windows are loaded again later on.
tuh_windows1.set_description(
    {"n_windows": list(map(len, tuh_windows1.datasets))}
)

In [None]:
# Display the windowed dev-set dataset (the cell's last expression is rendered).
tuh_windows1

In [None]:
dll = DataLoader(
    dataset=tuh_windows1,
    batch_size=4,
)
# Walk the entire dev loader once (every window is read). Afterwards the *1
# loop variables hold the final dev batch.
for batch_X1, batch_y1, batch_ind1 in dll:
    pass
# Report the dev batch just loaded (not the training batch from the earlier cell).
print('batch_X:', batch_X1.shape)
print('batch_y:', batch_y1)
print('batch_ind:', batch_ind1)

In [None]:
class PatchEmbedding(nn.Module):
    """Shallow CNN front-end that turns a window into a sequence of tokens.

    The pipeline is:
    - a (1, 25) temporal conv and a (20, 1) conv across the channel axis;
    - BatchNorm, ELU, (1, 75)/stride-15 average pooling and dropout;
    - a 1x1 conv mapping the 40 feature maps to ``emb_size``;
    - flattening of the remaining (h, w) grid into tokens.

    Input: 3-D tensor, presumably (batch, eeg_channels, time) — TODO confirm.
    With more than 20 input rows the (20, 1) conv leaves several rows, which
    become additional tokens.
    Output: (batch, tokens, emb_size).
    """

    def __init__(self, emb_size=40):
        super().__init__()

        n_filters = 40
        self.shallownet = nn.Sequential(
            nn.Conv2d(1, n_filters, (1, 25), (1, 1)),          # along the time axis
            nn.Conv2d(n_filters, n_filters, (20, 1), (1, 1)),  # across 20 rows of the channel axis
            nn.BatchNorm2d(n_filters),
            nn.ELU(),
            nn.AvgPool2d((1, 75), (1, 15)),  # pooling slices the time axis into 'patches', as in ViT
            nn.Dropout(0.5),
        )
        self.projection = nn.Sequential(
            # 1x1 conv rather than a bare transpose; the author notes it may fit slightly better
            nn.Conv2d(n_filters, emb_size, (1, 1), stride=(1, 1)),
            Rearrange('b e (h) (w) -> b (h w) e'),
        )

    def forward(self, x: Tensor) -> Tensor:
        _batch, _channels, _time = x.shape  # unpacking also rejects non-3-D input
        feature_map = self.shallownet(x.unsqueeze(1))  # add a singleton "image channel" axis
        return self.projection(feature_map)


class MultiHeadAttention(nn.Module):
    """Multi-head self-attention over a sequence of tokens.

    Args:
        emb_size: token width; split evenly across ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    ``forward(x, mask=None)`` takes ``x`` of shape (batch, tokens, emb_size) and
    returns a tensor of the same shape. ``mask`` is an optional boolean tensor
    broadcastable to (batch, heads, queries, keys). Positions where it is False
    are excluded from attention.
    """

    def __init__(self, emb_size, num_heads, dropout):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.keys = nn.Linear(emb_size, emb_size)
        self.queries = nn.Linear(emb_size, emb_size)
        self.values = nn.Linear(emb_size, emb_size)
        self.att_drop = nn.Dropout(dropout)
        self.projection = nn.Linear(emb_size, emb_size)

    def _split_heads(self, t: Tensor) -> Tensor:
        # (b, n, h*d) -> (b, h, n, d); same layout as einops "b n (h d) -> b h n d".
        b, n, _ = t.shape
        return t.view(b, n, self.num_heads, -1).permute(0, 2, 1, 3)

    def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
        queries = self._split_heads(self.queries(x))
        keys = self._split_heads(self.keys(x))
        values = self._split_heads(self.values(x))
        energy = torch.einsum('bhqd, bhkd -> bhqk', queries, keys)
        if mask is not None:
            # masked_fill is out-of-place, so its result must be reassigned.
            # Use the minimum of energy's own dtype so reduced precision does not overflow.
            fill_value = torch.finfo(energy.dtype).min
            energy = energy.masked_fill(~mask, fill_value)

        scaling = self.emb_size ** (1 / 2)
        att = F.softmax(energy / scaling, dim=-1)
        att = self.att_drop(att)
        out = torch.einsum('bhal, bhlv -> bhav ', att, values)
        # (b, h, n, d) -> (b, n, h*d); same as einops "b h n d -> b n (h d)".
        b, h, n, d = out.shape
        out = out.permute(0, 2, 1, 3).reshape(b, n, h * d)
        out = self.projection(out)
        return out


class ResidualAdd(nn.Module):
    """Wrap ``fn`` with a residual (skip) connection: returns ``fn(x, **kwargs) + x``."""

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        # Out-of-place add. An in-place ``+=`` on fn's output would corrupt the
        # caller's tensor when fn returns its input unchanged (e.g. nn.Identity).
        # It would also break backward for ops that save their output (e.g. sigmoid).
        return self.fn(x, **kwargs) + x


class FeedForwardBlock(nn.Sequential):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear.

    The hidden layer is ``expansion`` times wider than ``emb_size``; the output
    is projected back to ``emb_size``.
    """

    def __init__(self, emb_size, expansion, drop_p):
        hidden = expansion * emb_size
        layers = [
            nn.Linear(emb_size, hidden),
            nn.GELU(),
            nn.Dropout(drop_p),
            nn.Linear(hidden, emb_size),
        ]
        super().__init__(*layers)


class GELU(nn.Module):
    """Exact (erf-based) Gaussian Error Linear Unit: x * Phi(x)."""

    def forward(self, input: Tensor) -> Tensor:
        half_input = input * 0.5
        return half_input * (1.0 + torch.erf(input / math.sqrt(2.0)))


class TransformerEncoderBlock(nn.Sequential):
    """Pre-norm Transformer encoder layer made of two residual sub-layers.

    1. LayerNorm -> MultiHeadAttention -> Dropout, added back to the input.
    2. LayerNorm -> FeedForwardBlock   -> Dropout, added back to the input.
    """

    def __init__(self,
                 emb_size,
                 num_heads=10,
                 drop_p=0.5,
                 forward_expansion=4,
                 forward_drop_p=0.5):
        attention_sublayer = nn.Sequential(
            nn.LayerNorm(emb_size),
            MultiHeadAttention(emb_size, num_heads, drop_p),
            nn.Dropout(drop_p),
        )
        feedforward_sublayer = nn.Sequential(
            nn.LayerNorm(emb_size),
            FeedForwardBlock(emb_size, expansion=forward_expansion, drop_p=forward_drop_p),
            nn.Dropout(drop_p),
        )
        super().__init__(
            ResidualAdd(attention_sublayer),
            ResidualAdd(feedforward_sublayer),
        )


class TransformerEncoder(nn.Sequential):
    """``depth`` TransformerEncoderBlocks (default heads/dropout) applied in sequence."""

    def __init__(self, depth, emb_size):
        blocks = []
        for _ in range(depth):
            blocks.append(TransformerEncoderBlock(emb_size))
        super().__init__(*blocks)


class ClassificationHead(nn.Sequential):
    """MLP classifier applied to the flattened token sequence.

    Args:
        emb_size: token width. Not used by the layers directly; ``in_features``
            must equal n_tokens * emb_size.
        n_classes: number of output logits.
        in_features: size of the flattened (tokens x emb_size) input. The default
            4880 keeps the previously hard-coded value (presumably 122 tokens x
            emb_size 40 for this notebook's windows — confirm if the input changes).

    ``forward(x)`` returns ``(x_flat, logits)``, where ``x_flat`` has shape
    (batch, in_features) and ``logits`` has shape (batch, n_classes).
    """

    def __init__(self, emb_size, n_classes, in_features=4880):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(in_features, 256),
            nn.ELU(),
            nn.Dropout(0.5),
            nn.Linear(256, 32),
            nn.ELU(),
            nn.Dropout(0.3),
            nn.Linear(32, n_classes),
        )

    def forward(self, x):
        x = x.reshape(x.size(0), -1)
        out = self.fc(x)
        return x, out


class Conformer(nn.Sequential):
    """Full model: CNN patch embedding -> Transformer encoder -> classification head.

    The forward output is whatever ClassificationHead returns. Extra keyword
    arguments are accepted and ignored.
    """

    def __init__(self, emb_size=40, depth=6, n_classes=2, **kwargs):
        embedding = PatchEmbedding(emb_size)
        encoder = TransformerEncoder(depth, emb_size)
        head = ClassificationHead(emb_size, n_classes)
        super().__init__(embedding, encoder, head)





In [None]:
class ExP():
    """Training harness for the Conformer on the windowed TUH data.

    Reads the notebook globals ``tuh_windows`` (train) and ``tuh_windows1`` (dev).
    Requires a CUDA device: the model and losses are moved with ``.cuda()``.
    """

    def __init__(self, nsub):
        super(ExP, self).__init__()
        self.batch_size = 4  # the batch size the DataLoaders in train() actually use
        self.n_epochs = 500
        self.c_dim = 4
        self.lr = 0.0002
        self.b1 = 0.5
        self.b2 = 0.999
        self.dimension = (190, 50)
        self.nSub = nsub

        self.start_epoch = 0
        self.root = '/Data/strict_TE/'

        self.Tensor = torch.cuda.FloatTensor
        self.LongTensor = torch.cuda.LongTensor

        self.criterion_l1 = torch.nn.L1Loss().cuda()
        self.criterion_l2 = torch.nn.MSELoss().cuda()
        self.criterion_cls = torch.nn.CrossEntropyLoss().cuda()
        gpus = [0]
        self.model = Conformer().cuda()
        self.model = nn.DataParallel(self.model, device_ids=list(range(len(gpus))))
        self.model = self.model.cuda()

    def train(self):
        """Train ``self.model`` for ``self.n_epochs`` epochs and save the weights.

        Prints the epoch index and the mean training loss of each epoch. At the
        end, the unwrapped model's state_dict is written to ``model.pth``. The dev
        loader is built and stored on ``self.test_dataloader`` but is not
        evaluated here.

        Raises:
            RuntimeError: if the training DataLoader yields no batches.
        """
        # Shuffle training windows; otherwise every epoch sees the batches in the
        # same recording-by-recording order.
        self.dataloader = DataLoader(dataset=tuh_windows, batch_size=self.batch_size, shuffle=True)
        self.test_dataloader = DataLoader(dataset=tuh_windows1, batch_size=self.batch_size)

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=self.lr, betas=(self.b1, self.b2))

        for e in range(self.n_epochs):
            self.model.train()
            running_loss = 0.0
            n_batches = 0
            for img, label, _ in self.dataloader:
                img = img.cuda().type(self.Tensor)
                label = label.cuda().type(self.LongTensor)

                _, outputs = self.model(img)
                loss = self.criterion_cls(outputs, label)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # detach() so the accumulated value does not keep each batch's graph alive
                running_loss = running_loss + loss.detach()
                n_batches += 1

            if n_batches == 0:
                raise RuntimeError('training DataLoader yielded no batches')
            print('epoch', e)
            print('loss', float(running_loss / n_batches))

        torch.save(self.model.module.state_dict(), 'model.pth')

In [None]:
def main(seed=None):
    """Seed every RNG, train one Conformer via ExP and report the wall-clock duration.

    Args:
        seed: RNG seed. ``None`` (default) draws one with ``np.random.randint(2021)``;
            pass an int to make the run reproducible. The seed used is always printed.
    """
    subject = 1
    starttime = datetime.datetime.now()

    seed_n = np.random.randint(2021) if seed is None else seed
    print('seed is ' + str(seed_n))
    random.seed(seed_n)
    np.random.seed(seed_n)
    torch.manual_seed(seed_n)
    torch.cuda.manual_seed(seed_n)
    torch.cuda.manual_seed_all(seed_n)

    print('Subject %d' % subject)
    exp = ExP(subject)
    exp.train()

    endtime = datetime.datetime.now()
    print('subject %d duration: ' % subject + str(endtime - starttime))


if __name__ == "__main__":
    print(time.asctime(time.localtime(time.time())))
    main()
    print(time.asctime(time.localtime(time.time())))

In [None]:
# Rebuild the architecture on the CPU and load the weights written by ExP.train().
# The checkpoint was saved from a CUDA model, so map it to the CPU explicitly;
# without map_location, torch.load fails on a machine with no GPU.
model = Conformer()
model.load_state_dict(torch.load('model.pth', map_location='cpu'))
model.eval()

In [None]:
# Predict every dev-set window batch by batch; pred/true each end up holding one
# array per batch. no_grad skips building an autograd graph that inference never uses.
pred = []
true = []
with torch.no_grad():
    for test_data, test_label, _ in dll:
        _, logits = model(test_data.float())  # same float32 cast as in training
        y_pred = logits.argmax(dim=1)
        pred.append(y_pred.numpy())
        true.append(test_label.numpy())

In [None]:
# Window-level accuracy on the dev set. pred/true hold one array per batch, so
# the denominator is the number of concatenated labels; len(true) counts batches.
y_pred_all = np.concatenate(pred)
y_true_all = np.concatenate(true)
assert y_pred_all.shape == y_true_all.shape, "prediction/label count mismatch"
acc = float(np.mean(y_pred_all == y_true_all))
acc

In [None]:
# Number of dev-set windows that received a prediction.
np.concatenate(pred).shape[0]