In [81]:
import torch
from tqdm import tqdm
import torch.nn as nn

from load_data import read_data_arrays, data_file_names, standardize_data, data_loader
from models import ChronoNet
from utils import cal_accuracy, evaluate_model 


# Run configuration: everything a reader might want to tune lives here.
BATCH_SIZE = 128
# Training is pinned to CPU. For Apple-silicon acceleration swap in:
#   torch.device("mps") if torch.backends.mps.is_available() else torch.device("cpu")
DEVICE = torch.device("cpu")
NUM_EPOCHS = 5
# Seed torch's global RNG so weight initialisation (and any torch-based
# shuffling inside data_loader — TODO confirm) is reproducible on
# Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

# --- Load the raw train / val / test splits -------------------------------
print("Reading Data....")
(train_features, val_features, test_features,
 train_labels, val_labels, test_labels) = read_data_arrays(data_file_names())

# --- Standardize the features of all three splits --------------------------
# NOTE(review): confirm standardize_data fits its statistics on the training
# split only, otherwise val/test information leaks into training.
print("Scaling Data....")
train_features, val_features, test_features = standardize_data(
    train_features, val_features, test_features
)

# --- Wrap each split in a batched loader placed on DEVICE ------------------
print("Data Loader....")
train_iter, val_iter, test_iter = (
    data_loader(feats, labels, DEVICE, BATCH_SIZE)
    for feats, labels in (
        (train_features, train_labels),
        (val_features, val_labels),
        (test_features, test_labels),
    )
)
    
print("Training Model....")
# Number of input channels ChronoNet is built for (presumably the EEG
# electrode count of the recordings — TODO confirm against load_data).
n_chans = 19
# nn.Module.to moves the parameters and returns the module itself.
model = ChronoNet(n_chans).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
# The model's outputs are passed to this loss as raw logits; the loss applies
# the sigmoid internally.
loss_func = nn.BCEWithLogitsLoss()

def _train_one_epoch(model, loss_func, optimizer, batches):
    """Run one optimisation pass over ``batches`` and return the mean batch loss.

    Args:
        model: the network being trained; switched to train mode here.
        loss_func: loss taking ``(logits, targets)``.
        optimizer: optimizer over ``model``'s parameters.
        batches: iterable of ``(x, y)`` batches.

    Returns:
        The average of the per-batch losses, or ``float("nan")`` when
        ``batches`` yields nothing (instead of failing on an undefined
        loop counter).
    """
    model.train()
    loss_sum, n_batches = 0.0, 0
    for x, y in tqdm(batches):
        # Clear gradients before the forward pass so nothing stale from
        # outside this loop can leak into the first step.
        optimizer.zero_grad()
        logits = model(x)
        # Match the target's shape explicitly. A bare .squeeze() would also
        # drop the batch dimension for a single-sample batch, and
        # BCEWithLogitsLoss then rejects the mismatched input/target sizes.
        logits = logits.reshape(y.shape)
        loss = loss_func(logits, y)
        loss.backward()
        optimizer.step()
        loss_sum += loss.item()
        n_batches += 1
    return loss_sum / n_batches if n_batches else float("nan")


for epoch in range(1, NUM_EPOCHS + 1):
    print("Epoch", epoch)
    train_loss = _train_one_epoch(model, loss_func, optimizer, train_iter)

    val_loss = evaluate_model(model, loss_func, val_iter)
    print("Train loss:", train_loss, "Accuracy: ",
          cal_accuracy(model, train_iter)[0])
    print("Val loss:", val_loss, ", Accuracy: ",
          cal_accuracy(model, val_iter)[0])

Reading Data....
Scaling Data....
Data Loader....
Training Model....
Epoch 1


100%|█████████████████████████████████████████| 554/554 [04:33<00:00,  2.02it/s]


Train loss: 0.5995515951826254 Accuracy:  0.8264265536723164
Val loss: 0.6213273151500806 , Accuracy:  0.7582203389830509
Epoch 2


100%|█████████████████████████████████████████| 554/554 [04:06<00:00,  2.25it/s]


Train loss: 0.5771762947098013 Accuracy:  0.8492655367231638
Val loss: 0.6040160601203506 , Accuracy:  0.7727966101694915
Epoch 3


100%|█████████████████████████████████████████| 554/554 [18:30<00:00,  2.01s/it]


Train loss: 0.5705024327827275 Accuracy:  0.8535310734463277
Val loss: 0.6002634596180272 , Accuracy:  0.7733474576271187
Epoch 4


100%|█████████████████████████████████████████| 554/554 [04:19<00:00,  2.13it/s]


Train loss: 0.564884297576622 Accuracy:  0.8598728813559322
Val loss: 0.599231883964023 , Accuracy:  0.7755084745762711
Epoch 5


100%|█████████████████████████████████████████| 554/554 [04:14<00:00,  2.17it/s]


Train loss: 0.5626794706613149 Accuracy:  0.8557627118644068
Val loss: 0.6021302448736655 , Accuracy:  0.7630508474576271


In [82]:
# Evaluate on the held-out test split. cal_accuracy is called once so the
# accuracy and the confusion matrix come from the same pass over the data,
# and the test set is only pushed through the model a single time.
test_metrics = cal_accuracy(model, test_iter)
print("Accuracy: ", test_metrics[0])
print("Confusion Matrix: ", test_metrics[1])

Accuracy:  0.7957811348563006
Confusion Matrix:  [[16351  1349]
 [ 5302  9566]]
