In [1]:
from __future__ import print_function

import glob
from itertools import chain
import os
import random
import zipfile
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from linformer import Linformer
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms
from tqdm.notebook import tqdm
from util import data_creator2,EEGDataset,Resize
from vit_pytorch.efficient import ViT
from sklearn.preprocessing import normalize

In [2]:
batch_size = 64
epochs = 5
lr = 0.001
gamma = 0.7
seed = 42
interval=1250
aug=2
modelname='model_5s'
inpath='data_in2'
out_path='data_out4'
out_path_F='data_out_F'
num_file=np.int16(np.array([11,7,8,8]))
num_file2=np.int16(np.array([19,16]))

In [4]:


data_creator2(inpath,out_path,num_file=num_file2,interval=interval,aug=aug,do_fft=False,stack_3_ch=False)
#data_creator(inpath,out_path,num_file=num_file,interval=interval,aug=aug,do_fft=True)


(27745493, 8)
(27082764, 8)


1

In [5]:
#Train Model in time domain: Set up
def get_device():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu') # don't have GPU 
    return device
device = get_device()
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)
EEGtransforms = transforms.Compose([
    transforms.ToTensor(),
    Resize(224)
])
data_train = EEGDataset(os.path.join("data_out4", "train"), transform=EEGtransforms)
data_test = EEGDataset(os.path.join("data_out4", "test"), transform=EEGtransforms)
train_loader = DataLoader(dataset=data_train, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=data_test, batch_size=batch_size, shuffle=True)

efficient_transformer = Linformer(
    dim=128,
    seq_len=49+1,  
    depth=12,
    heads=8,
    k=64
)
model = ViT(
    dim=128,
    image_size=256,
    patch_size=32,
    num_classes=2,
    transformer=efficient_transformer,
    channels=1,
).to(device)
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)


In [6]:
# Time Domain Train Model
best_val_acc = None
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    model.train()
    count=0
    prob_0=None
    #pred = None
    for data, label in train_loader:
        data = data.float().to(device)
        label = label.to(device)
        output = model(data)
        loss = criterion(output, label).to(device)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = output.argmax(dim = 1)
        curr_prob_0=output[:,0]

        #curr_pred = output.argmax(dim = 1)
        #if pred:
        #    pred = torch.cat(curr_pred.cpu().numpy())
      #  else:
       #     pred = curr_pred
      #  pred = pred.numpy()
        
        acc = (pred == label).float().mean()
        #curr_tp_count = (pred == 1) and (label == 1)
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)
    with torch.no_grad():
        model.eval()
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        tp_count=0
        fp_count=0
        tn_count=0
        fn_count=0
        prob_1=None
        whole_label=None
        #best_val_acc=0
        for data, label in test_loader:
            #print(data.shape)
            data = data.float().to(device)
            label = label.to(device)
                        
            val_output = model(data)
            pred=val_output.argmax(dim=1)
            val_loss = criterion(val_output, label)
            tp_count+=torch.logical_and(pred == 1, label == 1).sum()
            fp_count+=torch.logical_and(pred == 1, label == 0).sum()
            tn_count+=torch.logical_and(pred == 0, label == 0).sum()
            fn_count+=torch.logical_and(pred == 0, label == 1).sum()
            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(test_loader)
            epoch_val_loss += val_loss / len(test_loader)
            #curr_pred = val_output.argmax(dim = 1)
            curr_prob_1=val_output[:,1]
            if prob_1 == None:
                
                prob_1=curr_prob_1
            else:
                prob_1=torch.cat((prob_1,curr_prob_1))
            if whole_label==None:
                
                whole_label=label
            else:
                whole_label=torch.cat((whole_label,label))
        tpr=tp_count/(tp_count+fn_count)
        fpr=fp_count/(tn_count+fp_count)
        #tn=tn_count/(tn_count+fp_count)
        #fn=fn_count/(tn_count+fn_count)            
        if best_val_acc:
            if best_val_acc < epoch_val_accuracy:
                best_val_acc = epoch_val_accuracy
                torch.save(model,"model_Time.pt")
                optm_tpr=tp_count/(tp_count+fn_count)
                optm_fpr=fp_count/(tn_count+fp_count)
                optm_true_label=whole_label.cpu().numpy()
                optm_prob_1=prob_1.cpu().numpy()
                #optm_tn=tn_count/(fp_count+tn_count)
                #optm_fn=fn_count/(tn_count+fn_count)
        else:
            best_val_acc = epoch_val_accuracy
            torch.save(model,"model_Time.pt")
            optm_tpr=tp_count/(tp_count+fn_count)
            optm_fpr=fp_count/(fp_count+tn_count)
            #optm_tn=tn_count/(tn_count+fn_count)
            optm_true_label=whole_label.cpu().numpy()
            optm_prob_1=prob_1.cpu().numpy()            
        print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}"
    )
    print(f"tp : {tp_count} - fp : {fp_count} - tn: {tn_count} - fn : {fn_count} - tpr:{tpr:.4f}- fpr: {fpr:.4f}\n")

Epoch : 1 - loss : 0.4460 - acc: 0.7856 - val_loss : 0.3859 - val_acc: 0.8271
tp : 18505 - fp : 4429 - tn: 17763 - fn : 3157 - tpr:0.8543- fpr: 0.1996

Epoch : 2 - loss : 0.2374 - acc: 0.9003 - val_loss : 0.2595 - val_acc: 0.8912
tp : 19560 - fp : 2670 - tn: 19522 - fn : 2102 - tpr:0.9030- fpr: 0.1203

Epoch : 3 - loss : 0.1688 - acc: 0.9332 - val_loss : 0.2704 - val_acc: 0.8964
tp : 20489 - fp : 3375 - tn: 18817 - fn : 1173 - tpr:0.9458- fpr: 0.1521

Epoch : 4 - loss : 0.1318 - acc: 0.9496 - val_loss : 0.2402 - val_acc: 0.9057
tp : 20296 - fp : 2768 - tn: 19424 - fn : 1366 - tpr:0.9369- fpr: 0.1247

Epoch : 5 - loss : 0.1111 - acc: 0.9584 - val_loss : 0.3103 - val_acc: 0.8905
tp : 20059 - fp : 3193 - tn: 18999 - fn : 1603 - tpr:0.9260- fpr: 0.1439



In [4]:
#Frequency Domain
def get_device():
    if torch.cuda.is_available():
        device = torch.device('cuda:0')
    else:
        device = torch.device('cpu') # don't have GPU 
    return device

device = get_device()
def seed_everything(seed):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

seed_everything(seed)

train_transforms = transforms.Compose([
    transforms.ToTensor(),
    Resize(224)
])
data_train2 = EEGDataset(os.path.join("data_out3", "train"), transform=train_transforms)
data_test2 = EEGDataset(os.path.join("data_out3", "test"), transform=train_transforms)
train_loader2 = DataLoader(dataset=data_train2, batch_size=batch_size, shuffle=True)
test_loader2 = DataLoader(dataset=data_test2, batch_size=batch_size, shuffle=True)

efficient_transformer = Linformer(
    dim=128,
    seq_len=49+1,  
    depth=12,
    heads=8,
    k=64
)
model = ViT(
    dim=128,
    image_size=256,
    patch_size=32,
    num_classes=2,
    transformer=efficient_transformer,
    channels=1,
).to(device)
# loss function
criterion = nn.CrossEntropyLoss()
# optimizer
optimizer = optim.Adam(model.parameters(), lr=lr)
# scheduler
scheduler = StepLR(optimizer, step_size=1, gamma=gamma)

In [7]:
#Frequency Domain
best_val_acc = None
for epoch in range(epochs):
    epoch_loss = 0
    epoch_accuracy = 0
    model.train()
    count=0
    prob_0=None
    #pred = None
    for data, label in train_loader:
        data = data.float().to(device)
        label = label.to(device)
        output = model(data)
        loss = criterion(output, label).to(device)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = output.argmax(dim = 1)
        curr_prob_0=output[:,0]

        #curr_pred = output.argmax(dim = 1)
        #if pred:
        #    pred = torch.cat(curr_pred.cpu().numpy())
      #  else:
       #     pred = curr_pred
      #  pred = pred.numpy()
        
        acc = (pred == label).float().mean()
        #curr_tp_count = (pred == 1) and (label == 1)
        epoch_accuracy += acc / len(train_loader)
        epoch_loss += loss / len(train_loader)
    with torch.no_grad():
        model.eval()
        epoch_val_accuracy = 0
        epoch_val_loss = 0
        tp_count=0
        fp_count=0
        tn_count=0
        fn_count=0
        prob_1=None
        whole_label=None
        #best_val_acc=0
        for data, label in test_loader:
            #print(data.shape)
            data = data.float().to(device)
            label = label.to(device)
                        
            val_output = model(data)
            pred=val_output.argmax(dim=1)
            val_loss = criterion(val_output, label)
            tp_count+=torch.logical_and(pred == 1, label == 1).sum()
            fp_count+=torch.logical_and(pred == 1, label == 0).sum()
            tn_count+=torch.logical_and(pred == 0, label == 0).sum()
            fn_count+=torch.logical_and(pred == 0, label == 1).sum()
            acc = (val_output.argmax(dim=1) == label).float().mean()
            epoch_val_accuracy += acc / len(test_loader)
            epoch_val_loss += val_loss / len(test_loader)
            #curr_pred = val_output.argmax(dim = 1)
            curr_prob_1=val_output[:,1]
            if prob_1 == None:
                
                prob_1=curr_prob_1
            else:
                prob_1=torch.cat((prob_1,curr_prob_1))
            if whole_label==None:
                
                whole_label=label
            else:
                whole_label=torch.cat((whole_label,label))
        tpr=tp_count/(tp_count+fn_count)
        fpr=fp_count/(tn_count+fp_count)
        #tn=tn_count/(tn_count+fp_count)
        #fn=fn_count/(tn_count+fn_count)            
        if best_val_acc:
            if best_val_acc < epoch_val_accuracy:
                best_val_acc = epoch_val_accuracy
                torch.save(model,"model_Frequency.pt")
                optm_tpr=tp_count/(tp_count+fn_count)
                optm_fpr=fp_count/(tn_count+fp_count)
                optm_true_label=whole_label.cpu().numpy()
                optm_prob_1=prob_1.cpu().numpy()
                #optm_tn=tn_count/(fp_count+tn_count)
                #optm_fn=fn_count/(tn_count+fn_count)
        else:
            best_val_acc = epoch_val_accuracy
            torch.save(model,"model_Frequency.pt")
            optm_tpr=tp_count/(tp_count+fn_count)
            optm_fpr=fp_count/(fp_count+tn_count)
            #optm_tn=tn_count/(tn_count+fn_count)
            optm_true_label=whole_label.cpu().numpy()
            optm_prob_1=prob_1.cpu().numpy()            
        print(
        f"Epoch : {epoch+1} - loss : {epoch_loss:.4f} - acc: {epoch_accuracy:.4f} - val_loss : {epoch_val_loss:.4f} - val_acc: {epoch_val_accuracy:.4f}"
    )
    print(f"tp : {tp_count} - fp : {fp_count} - tn: {tn_count} - fn : {fn_count} - tpr:{tpr:.4f}- fpr: {fpr:.4f}\n")

Epoch : 1 - loss : 0.6084 - acc: 0.6597 - val_loss : 0.5510 - val_acc: 0.7186
tp : 15126 - fp : 5808 - tn: 16384 - fn : 6536 - tpr:0.6983- fpr: 0.2617

Epoch : 2 - loss : 0.5328 - acc: 0.7265 - val_loss : 0.5240 - val_acc: 0.7334
tp : 16237 - fp : 6273 - tn: 15919 - fn : 5425 - tpr:0.7496- fpr: 0.2827

Epoch : 3 - loss : 0.4977 - acc: 0.7526 - val_loss : 0.5077 - val_acc: 0.7502
tp : 16462 - fp : 5764 - tn: 16428 - fn : 5200 - tpr:0.7599- fpr: 0.2597

Epoch : 4 - loss : 0.4631 - acc: 0.7740 - val_loss : 0.4708 - val_acc: 0.7721
tp : 16043 - fp : 4373 - tn: 17819 - fn : 5619 - tpr:0.7406- fpr: 0.1971



KeyboardInterrupt: 