In [11]:
import datetime
import math
import os
import pickle
import random
import time

import h5py
import numpy as np
import torch
from torch.autograd import Variable
from torchinfo import summary
from tqdm import tqdm

In [12]:
# Project-local modules. They are reloaded on every run of this cell so that
# edits to the .py files are picked up without restarting the kernel.
import importlib

import dataloader
import model
import test
import train

# Reload all four local modules (previously `test` was skipped, so edits to
# test.py were silently ignored until a kernel restart). `model` goes first so
# modules reloaded after it re-bind to the fresh model code if they import from
# it (presumably they do — confirm).
for _module in (model, dataloader, train, test):
    importlib.reload(_module)
del _module

# NOTE(review): star imports hide where names such as Net / load_dataset come
# from; consider importing them explicitly.
from model import *
from dataloader import *

# These deliberately shadow the module names `train` / `test` with the functions
# of the same name; the plain `import train` / `import test` lines above re-bind
# the names to the modules whenever this cell is re-run, so the reload still works.
from train import train
from test import test

# Parameters: data paths, label encoding, and training/model hyper-parameters

In [13]:
# Run configuration: data locations, label encoding, and training/model
# hyper-parameters. The shape-dependent keys (num_samples, first_in_channels,
# L, length) are filled in later from the loaded training data.
SEED = 42  # change for a different, but still reproducible, run

config = {
    # data: pickled feature/target files, resolved relative to "root"
    "root" : "./data/pickles/",
    "train_feature" : "Epilepsy_feature.pickle",
    "train_target" : "Epilepsy_target.pickle",
    "test_feature" : "Epilepsy_feature_test.pickle",
    "test_target" : "Epilepsy_target_test.pickle",
    # label name -> integer class id; its size sets the number of output classes
    "labelencodedict" : { 'EPILEPSY' : 0,
                    'WALKING' : 1,
                    'RUNNING' :2,
                    'SAWING' :3},
    # training
    "epochs" : 3000,
    "learning_rate" : 1e-3,
    # model (each key is passed to Net under the same keyword name)
    "first_out_channels" : 32,
    "first_kernel_size" : 5,
    "second_in_channels" : 32,
    "second_out_channels" : 16,
    "second_kernel_size" : 3,
    "dim_model" : 64,
    "dim_inner" : 128,
    "num_heads" : 8,
    "dropout_rate" : 0.15,
    "squeeze_factor" : 2,
    "sec_kernel_size" : 3,
    "sec_stride" : 1,
    "N" : 3,
    "gamma" : 2,
    "u" : 2,
    # NOTE(review): this is handed to torch.optim.Adam, i.e. a coupled L2 penalty;
    # 0.1 is large for that — consider AdamW if decoupled decay was intended.
    "weight_decay" : 0.1,
    "device" : torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    # seed for every RNG below, so Restart & Run All reproduces the run
    "seed" : SEED,
}

# Seed Python, NumPy and PyTorch (torch.manual_seed covers all devices) so that
# weight initialisation, dropout and any shuffling drawn from these RNGs are
# repeatable. GPU runs may still differ slightly because of nondeterministic
# cuDNN kernels.
random.seed(config["seed"])
np.random.seed(config["seed"])
torch.manual_seed(config["seed"])

# Main: build the model, train it, and evaluate on the test split

In [15]:
if __name__ == "__main__":

    # Optimisation hyper-parameters from the config cell.
    epochs, learning_rate, weight_decay = (
        config[key] for key in ("epochs", "learning_rate", "weight_decay")
    )

    # Per split: features plus two label objects. y_* is handed to the model as
    # `ys`; *_y is handed to train()/test().
    # NOTE(review): how the two label forms differ is defined in dataloader — confirm.
    X_train, y_train, train_y, X_test, y_test, test_y = load_dataset(config)

    # Shape-derived settings, written into config in the same order as before:
    # axis 0 -> sample count, axis 1 -> input channels, axis 2 -> sequence length.
    for key, axis in (
        ("num_samples", 0),
        ("first_in_channels", 1),
        ("L", 2),
        ("length", 2),
    ):
        config[key] = X_train.shape[axis]

    # Every Net keyword except `ys` and `num_classes` shares its name with a config key.
    net_config_keys = (
        "dim_model", "gamma", "u",
        "first_in_channels", "first_out_channels", "first_kernel_size",
        "second_in_channels", "second_out_channels", "second_kernel_size",
        "N", "L",  # slice
        "dim_inner", "num_samples", "length", "num_heads", "dropout_rate",
        "squeeze_factor", "sec_kernel_size", "sec_stride",
    )
    md = Net(
        ys=y_train,
        num_classes=len(config["labelencodedict"]),
        **{name: config[name] for name in net_config_keys},
    ).to(config['device'])

    optimizer = torch.optim.Adam(md.parameters(), lr=learning_rate, weight_decay=weight_decay)
    # NOTE(review): criterion is not passed to train()/test() below; presumably they
    # build their own loss — confirm before relying on this object.
    criterion = torch.nn.CrossEntropyLoss()

    train_loss, train_acc = train(X_train, train_y, optimizer, model=md, epochs=epochs)

    # Swap the labels stored on the model to the test split before evaluating.
    # NOTE(review): if Net.forward reads `ys`, evaluation sees ground-truth labels —
    # verify there is no label leakage into the test accuracy.
    md.ys = y_test
    test_loss, test_acc = test(X_test, test_y, model=md)

  attention_prob = torch.nn.functional.softmax(scaled_att_score)
  final_score = torch.matmul(attention_prob, v)


Epoch: 0, Loss: 1.2092, Train acc: 83.2117%
Epoch: 1, Loss: 1.0545, Train acc: 86.8613%
Epoch: 2, Loss: 0.8499, Train acc: 91.2409%
Epoch: 3, Loss: 0.6755, Train acc: 90.5109%
Epoch: 4, Loss: 0.5625, Train acc: 92.3358%
Epoch: 5, Loss: 0.4849, Train acc: 91.6058%
Epoch: 6, Loss: 0.4020, Train acc: 91.6058%
Epoch: 7, Loss: 0.3560, Train acc: 93.7956%
Epoch: 8, Loss: 0.3351, Train acc: 94.5255%
Epoch: 9, Loss: 0.2932, Train acc: 95.2555%
Epoch: 10, Loss: 0.2505, Train acc: 96.7153%
Epoch: 11, Loss: 0.2125, Train acc: 95.9854%
Epoch: 12, Loss: 0.1563, Train acc: 97.4453%
Epoch: 13, Loss: 0.1283, Train acc: 99.6350%
Epoch: 14, Loss: 0.1468, Train acc: 98.9051%
Epoch: 15, Loss: 0.1024, Train acc: 98.9051%
Epoch: 16, Loss: 0.0881, Train acc: 98.9051%
Epoch: 17, Loss: 0.0696, Train acc: 98.9051%
Epoch: 18, Loss: 0.0601, Train acc: 99.6350%
Epoch: 19, Loss: 0.0573, Train acc: 99.6350%
Epoch: 20, Loss: 0.0585, Train acc: 99.2701%
Epoch: 21, Loss: 0.0355, Train acc: 99.6350%
Epoch: 22, Loss: 0.0

KeyboardInterrupt: 