In [2]:
import numpy as np
import os
import random
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from vit import ViT
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from torch_loader import *
from bhutan_loader import *

def prepare_dataloaders(batch_size, dataset = remaining_dataset, seed = 0):
    """Split ``dataset`` into train / validation / test subsets and build their loaders.

    The split is 70 % train, 15 % validation and 15 % test, with sizes derived
    from ``len(dataset)`` (rounding remainders end up in the validation and
    test subsets).  Both ``random_split`` calls draw from a single generator
    seeded with ``seed``, so calling this function again on the same dataset
    yields exactly the same three subsets.  That matters here because the
    evaluation cell calls this function a second time after training and must
    receive the same held-out samples, not ones the model was trained on.

    The function used to ignore both arguments (it always split
    ``Bhutan_dataset`` with a batch size of 25); they are now honoured.  The
    earlier commented-out SLEEP-EDF variant (batch size 128, 80/20 train/test
    split with no validation subset) is not implemented by this function.

    Args:
        batch_size: Batch size of the train and test loaders.
        dataset: Dataset to split; it must support ``len()`` and indexing as
            required by ``torch.utils.data.random_split``.
        seed: Seed of the generator that drives both splits.

    Returns:
        ``(train_loader, test_loader, valid_loader,
        train_dataset, test_dataset, valid_dataset)`` -- note that the test
        entries come before the validation entries.
    """
    n_total = len(dataset)
    n_train_valid = int(0.85 * n_total)
    n_test = n_total - n_train_valid
    n_train = int(0.70 * n_total)
    n_valid = n_train_valid - n_train

    # One seeded generator for both splits keeps the whole partition reproducible.
    split_rng = torch.Generator().manual_seed(seed)
    train_valid_dataset, test_dataset = torch.utils.data.random_split(
        dataset,
        [n_train_valid, n_test],
        generator=split_rng,
    )
    train_dataset, valid_dataset = torch.utils.data.random_split(
        train_valid_dataset,
        [n_train, n_valid],
        generator=split_rng,
    )

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=3)
    # The validation subset is served as one single batch; max(..., 1) avoids an
    # invalid batch size of 0 when the subset is empty.
    valid_loader = DataLoader(valid_dataset, batch_size=max(len(valid_dataset), 1), shuffle=False, num_workers=3)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=3)

    return train_loader, test_loader, valid_loader, train_dataset, test_dataset, valid_dataset

def _evaluate(model, data_iter, loss_function, device):
    """Score ``model`` on every batch of ``data_iter`` without tracking gradients.

    Switches the model to eval mode and leaves it there.

    Returns:
        ``(mean_loss, accuracy, y_true, y_pred)`` where ``mean_loss`` is the
        average of the per-batch losses, ``accuracy`` is the fraction of
        correctly classified samples, and ``y_true`` / ``y_pred`` are flat
        lists of labels.  ``mean_loss`` and ``accuracy`` are NaN when the
        loader yields nothing.
    """
    model.eval()
    total_loss, n_correct, n_seen = 0.0, 0, 0
    y_true, y_pred = [], []
    with torch.no_grad():
        for image, label in data_iter:
            image, label = image.to(device), label.to(device)
            out, _ = model(image)  # the model returns (logits, attention matrices)
            total_loss += loss_function(out, label).item()
            pred = out.argmax(dim=1)
            y_pred.extend(pred.cpu().numpy())
            y_true.extend(label.cpu().numpy())
            n_seen += image.size(0)
            n_correct += (label == pred).sum().item()
    n_batches = len(data_iter)
    mean_loss = total_loss / n_batches if n_batches else float('nan')
    accuracy = n_correct / n_seen if n_seen else float('nan')
    return mean_loss, accuracy, y_true, y_pred


# Settings used so far: for BHUTAN, 19 channels and batch size 25 with pool='cls' and pos_enc='fixed';
# for SLEEP-EDF, 3 channels and batch size 128 with pool='cls' and pos_enc='learnable'.
def main(image_size=(224,224), patch_size=(56,56), channels=19, 
         embed_dim=224, num_heads=4, num_layers=4, num_classes=5,
         pos_enc='fixed', pool='cls', dropout=0.3, fc_dim=None, 
         num_epochs=50, batch_size=25, lr=1e-3, warmup_steps=625,
         weight_decay=0.1, gradient_clipping=1
    ):
    """Train a ViT on ``Bhutan_dataset`` and keep the best checkpoint.

    Each epoch trains on the train loader and then scores the model on the
    test loader.  Whenever the test loss is at least as good as the best seen
    so far, the whole model object is saved to ``model2_ViT_Bhutan.pth``
    (in eval mode), so that file always holds the best epoch and can be
    reloaded directly with ``torch.load``.  Previously the file was rewritten
    unconditionally at the end of every epoch, which discarded the best
    checkpoint.

    NOTE: checkpoint selection uses the *test* loader; the validation loader
    returned by ``prepare_dataloaders`` is deliberately not touched here and
    is the one scored by the later evaluation cell.

    Args:
        image_size, patch_size, channels, embed_dim, num_heads, num_layers,
        num_classes, pos_enc, pool, dropout, fc_dim: Forwarded unchanged to
            the ``ViT`` constructor.
        num_epochs: Number of passes over the train loader.
        batch_size: Forwarded to ``prepare_dataloaders``.
        lr: Peak learning rate of AdamW.
        warmup_steps: Number of optimiser steps over which the learning rate
            ramps linearly from 0 to ``lr``; values <= 0 disable the warm-up.
        weight_decay: AdamW weight decay.
        gradient_clipping: Maximum gradient norm; ``None`` or a value <= 0
            disables clipping.

    Returns:
        Dict of per-epoch lists: ``train_acc``, ``test_acc``, ``train_loss``,
        ``test_loss`` (scalars) and ``y_pred`` / ``y_true`` (label lists from
        the test loader).  ``train_acc`` is measured on the training batches
        while the model is in train mode (dropout active).
    """
    out_dict = {'train_acc': [],
              'test_acc': [],
              'train_loss': [],
              'test_loss': [],
              'y_pred': [],
              'y_true': []}

    checkpoint_path = 'model2_ViT_Bhutan.pth'
    loss_function = nn.CrossEntropyLoss()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_iter, test_iter, _, _, _, _ = prepare_dataloaders(batch_size=batch_size, dataset = Bhutan_dataset)

    model = ViT(image_size=image_size, patch_size=patch_size, channels=channels, 
                embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers,
                pos_enc=pos_enc, pool=pool, dropout=dropout, fc_dim=fc_dim, 
                num_classes=num_classes
    ).to(device)

    opt = torch.optim.AdamW(lr=lr, params=model.parameters(), weight_decay=weight_decay)
    # Linear warm-up of the learning rate; guarded so warmup_steps <= 0 cannot divide by zero.
    sch = torch.optim.lr_scheduler.LambdaLR(
        opt, lambda step: min(step / warmup_steps, 1.0) if warmup_steps > 0 else 1.0
    )

    # training loop
    best_val_loss = float('inf')
    for e in range(num_epochs):
        print(f'\n epoch {e}')
        model.train()
        train_loss, train_correct, train_seen = 0.0, 0, 0
        for image, label in tqdm.tqdm(train_iter):
            image, label = image.to(device), label.to(device)
            opt.zero_grad()
            out, _ = model(image)
            loss = loss_function(out, label)
            loss.backward()
            train_loss += loss.item()
            train_correct += (out.detach().argmax(dim=1) == label).sum().item()
            train_seen += label.size(0)
            # if the total gradient norm exceeds gradient_clipping, scale it back down to that value.
            if gradient_clipping and gradient_clipping > 0.0:
                nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
            opt.step()
            sch.step()

        train_loss /= max(len(train_iter), 1)
        train_acc = train_correct / train_seen if train_seen else float('nan')

        val_loss, acc, y_true, y_pred = _evaluate(model, test_iter, loss_function, device)
        # Standard error of the accuracy, treated as a binomial proportion: sqrt(p * (1 - p) / n).
        n_eval = len(y_true)
        sem_p = np.sqrt(acc * (1.0 - acc) / n_eval) if n_eval else float('nan')
        print(f'-- train loss {train_loss:.3f} -- test accuracy {acc:.3f} -- test loss: {val_loss:.3f} -- SEM {sem_p * 100:.3f}')

        if val_loss <= best_val_loss:
            best_val_loss = val_loss
            # Save the whole model (not just the state dict) because the
            # evaluation cell reloads this file and calls it directly.
            torch.save(model, checkpoint_path)

        out_dict['train_acc'].append(train_acc)
        out_dict['test_acc'].append(acc)
        out_dict['train_loss'].append(train_loss)
        out_dict['test_loss'].append(val_loss)
        out_dict['y_pred'].append(y_pred)
        out_dict['y_true'].append(y_true)

    return out_dict

In [3]:
if __name__ == "__main__":
    # Training entry point: report which device is available, then train.
    # Optional: pin a GPU beforehand with os.environ["CUDA_VISIBLE_DEVICES"] = str(0).
    if torch.cuda.is_available():
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    print(f"Model will run on {device}")
    # Optional: call set_seed(seed=1) here to start every run from the same
    # random state (set_seed is not defined in the cells shown here).
    main()

Model will run on cpu

 epoch 0


100%|██████████| 31/31 [00:02<00:00, 10.60it/s]


-- train loss 1.728 -- validation accuracy 0.146 -- validation loss: 1.504

 epoch 1


100%|██████████| 31/31 [00:02<00:00, 10.61it/s]


-- train loss 1.681 -- validation accuracy 0.526 -- validation loss: 1.366

 epoch 2


100%|██████████| 31/31 [00:03<00:00,  8.90it/s]


-- train loss 1.587 -- validation accuracy 0.526 -- validation loss: 1.429

 epoch 3


100%|██████████| 31/31 [00:03<00:00,  9.68it/s]


-- train loss 1.529 -- validation accuracy 0.526 -- validation loss: 1.674

 epoch 4


100%|██████████| 31/31 [00:02<00:00, 11.17it/s]


-- train loss 1.510 -- validation accuracy 0.526 -- validation loss: 1.794

 epoch 5


100%|██████████| 31/31 [00:02<00:00, 11.28it/s]


-- train loss 1.495 -- validation accuracy 0.526 -- validation loss: 1.827

 epoch 6


100%|██████████| 31/31 [00:02<00:00, 11.24it/s]


-- train loss 1.436 -- validation accuracy 0.526 -- validation loss: 1.674

 epoch 7


100%|██████████| 31/31 [00:02<00:00, 11.06it/s]


-- train loss 1.448 -- validation accuracy 0.526 -- validation loss: 1.687

 epoch 8


100%|██████████| 31/31 [00:02<00:00, 10.90it/s]


-- train loss 1.419 -- validation accuracy 0.526 -- validation loss: 1.553

 epoch 9


100%|██████████| 31/31 [00:02<00:00, 10.51it/s]


-- train loss 1.408 -- validation accuracy 0.526 -- validation loss: 1.481

 epoch 10


100%|██████████| 31/31 [00:02<00:00, 11.21it/s]


-- train loss 1.376 -- validation accuracy 0.518 -- validation loss: 1.414

 epoch 11


100%|██████████| 31/31 [00:02<00:00, 11.34it/s]


-- train loss 1.358 -- validation accuracy 0.538 -- validation loss: 1.474

 epoch 12


100%|██████████| 31/31 [00:02<00:00, 10.98it/s]


-- train loss 1.309 -- validation accuracy 0.555 -- validation loss: 1.668

 epoch 13


100%|██████████| 31/31 [00:02<00:00, 11.25it/s]


-- train loss 1.295 -- validation accuracy 0.538 -- validation loss: 1.494

 epoch 14


100%|██████████| 31/31 [00:02<00:00, 11.08it/s]


-- train loss 1.276 -- validation accuracy 0.543 -- validation loss: 1.419

 epoch 15


100%|██████████| 31/31 [00:02<00:00, 11.16it/s]


-- train loss 1.256 -- validation accuracy 0.543 -- validation loss: 1.585

 epoch 16


100%|██████████| 31/31 [00:02<00:00, 11.18it/s]


-- train loss 1.230 -- validation accuracy 0.543 -- validation loss: 1.431

 epoch 17


100%|██████████| 31/31 [00:02<00:00, 11.04it/s]


-- train loss 1.193 -- validation accuracy 0.559 -- validation loss: 1.241

 epoch 18


100%|██████████| 31/31 [00:02<00:00, 11.18it/s]


-- train loss 1.114 -- validation accuracy 0.619 -- validation loss: 1.061

 epoch 19


100%|██████████| 31/31 [00:02<00:00, 11.14it/s]


-- train loss 0.988 -- validation accuracy 0.623 -- validation loss: 1.062

 epoch 20


100%|██████████| 31/31 [00:02<00:00, 11.17it/s]


-- train loss 0.968 -- validation accuracy 0.656 -- validation loss: 1.069

 epoch 21


100%|██████████| 31/31 [00:02<00:00, 11.23it/s]


-- train loss 0.965 -- validation accuracy 0.615 -- validation loss: 0.997

 epoch 22


100%|██████████| 31/31 [00:02<00:00, 10.92it/s]


-- train loss 0.918 -- validation accuracy 0.664 -- validation loss: 0.986

 epoch 23


100%|██████████| 31/31 [00:02<00:00, 11.27it/s]


-- train loss 0.908 -- validation accuracy 0.538 -- validation loss: 1.146

 epoch 24


100%|██████████| 31/31 [00:02<00:00, 11.15it/s]


-- train loss 0.915 -- validation accuracy 0.636 -- validation loss: 0.987

 epoch 25


100%|██████████| 31/31 [00:02<00:00, 10.61it/s]


-- train loss 0.946 -- validation accuracy 0.632 -- validation loss: 0.919

 epoch 26


100%|██████████| 31/31 [00:02<00:00, 10.66it/s]


-- train loss 0.933 -- validation accuracy 0.664 -- validation loss: 1.124

 epoch 27


100%|██████████| 31/31 [00:02<00:00, 10.76it/s]


-- train loss 0.913 -- validation accuracy 0.640 -- validation loss: 0.935

 epoch 28


100%|██████████| 31/31 [00:02<00:00, 10.86it/s]


-- train loss 0.884 -- validation accuracy 0.583 -- validation loss: 1.041

 epoch 29


100%|██████████| 31/31 [00:02<00:00, 10.86it/s]


-- train loss 0.919 -- validation accuracy 0.648 -- validation loss: 0.961

 epoch 30


100%|██████████| 31/31 [00:02<00:00, 10.58it/s]


-- train loss 0.849 -- validation accuracy 0.644 -- validation loss: 0.979

 epoch 31


100%|██████████| 31/31 [00:02<00:00, 11.01it/s]


-- train loss 0.870 -- validation accuracy 0.632 -- validation loss: 0.908

 epoch 32


100%|██████████| 31/31 [00:02<00:00, 10.75it/s]


-- train loss 0.858 -- validation accuracy 0.648 -- validation loss: 0.895

 epoch 33


100%|██████████| 31/31 [00:02<00:00, 10.58it/s]


-- train loss 0.859 -- validation accuracy 0.668 -- validation loss: 1.009

 epoch 34


100%|██████████| 31/31 [00:02<00:00, 10.77it/s]


-- train loss 0.849 -- validation accuracy 0.644 -- validation loss: 0.937

 epoch 35


100%|██████████| 31/31 [00:02<00:00, 10.56it/s]


-- train loss 0.839 -- validation accuracy 0.640 -- validation loss: 0.876

 epoch 36


100%|██████████| 31/31 [00:02<00:00, 10.81it/s]


-- train loss 0.846 -- validation accuracy 0.648 -- validation loss: 0.849

 epoch 37


100%|██████████| 31/31 [00:02<00:00, 10.34it/s]


-- train loss 0.823 -- validation accuracy 0.640 -- validation loss: 0.887

 epoch 38


100%|██████████| 31/31 [00:02<00:00, 10.70it/s]


-- train loss 0.824 -- validation accuracy 0.684 -- validation loss: 0.882

 epoch 39


100%|██████████| 31/31 [00:02<00:00, 10.75it/s]


-- train loss 0.801 -- validation accuracy 0.668 -- validation loss: 0.878

 epoch 40


100%|██████████| 31/31 [00:02<00:00, 10.65it/s]


-- train loss 0.823 -- validation accuracy 0.668 -- validation loss: 0.911

 epoch 41


100%|██████████| 31/31 [00:02<00:00, 10.74it/s]


-- train loss 0.827 -- validation accuracy 0.652 -- validation loss: 0.855

 epoch 42


100%|██████████| 31/31 [00:02<00:00, 10.49it/s]


-- train loss 0.798 -- validation accuracy 0.672 -- validation loss: 0.876

 epoch 43


100%|██████████| 31/31 [00:02<00:00, 10.65it/s]


-- train loss 0.797 -- validation accuracy 0.656 -- validation loss: 1.000

 epoch 44


100%|██████████| 31/31 [00:02<00:00, 10.60it/s]


-- train loss 0.793 -- validation accuracy 0.668 -- validation loss: 0.871

 epoch 45


100%|██████████| 31/31 [00:02<00:00, 10.61it/s]


-- train loss 0.811 -- validation accuracy 0.676 -- validation loss: 0.806

 epoch 46


100%|██████████| 31/31 [00:02<00:00, 10.65it/s]


-- train loss 0.760 -- validation accuracy 0.680 -- validation loss: 0.831

 epoch 47


100%|██████████| 31/31 [00:02<00:00, 10.63it/s]


-- train loss 0.761 -- validation accuracy 0.684 -- validation loss: 0.875

 epoch 48


100%|██████████| 31/31 [00:02<00:00, 10.60it/s]


-- train loss 0.783 -- validation accuracy 0.676 -- validation loss: 0.855

 epoch 49


100%|██████████| 31/31 [00:02<00:00, 10.67it/s]


-- train loss 0.780 -- validation accuracy 0.668 -- validation loss: 0.836


In [None]:
# Score the saved model on the validation subset and report accuracy with its standard error.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#IF SLEEP-EDF
#model = torch.load('model1_ViT.pth')

#IF BHUTAN
# map_location lets a checkpoint written on a GPU machine load on a CPU-only one.
# NOTE(review): the file holds a whole pickled model, so only load checkpoints you
# trust; recent PyTorch releases may need weights_only=False here -- confirm
# against the installed version.
model = torch.load('model2_ViT_Bhutan.pth', map_location=device)
model = model.to(device)
model.eval()  # make sure dropout is off regardless of the mode the model was saved in
# NOTE(review): these numbers are only held-out results if prepare_dataloaders
# reproduces the split that was used during training.
_, _, valid_loader, _, _, _ = prepare_dataloaders(batch_size=32, dataset = Bhutan_dataset)

y_pred_validation = []
y_true_validation = []

#FOR SLEEP-EDF
# x_test, y_test = x_test.to(device), y_test.to(device)
# with torch.no_grad():
#     test_output, atten_matrix_list = model(x_test)
#     test_predicted = test_output.argmax(dim = 1)
# y_pred_validation.extend(test_predicted.cpu().numpy())
# y_true_validation.extend(y_test.cpu().numpy())

# #FOR BHUTAN
# Iterate over every batch instead of only the first one, so the result stays
# complete even if the loader is ever built with more than one batch.
with torch.no_grad():
    for images, labels in valid_loader:
        images, labels = images.to(device), labels.to(device)
        test_output, atten_matrix_list = model(images)
        test_predicted = test_output.argmax(dim = 1)
        y_pred_validation.extend(test_predicted.cpu().numpy())
        y_true_validation.extend(labels.cpu().numpy())

# Accuracy as the fraction of matching labels, and its standard error as a
# binomial proportion: std = sqrt(p * (1 - p)), SEM = std / sqrt(n).
acc_val = float(np.mean(np.asarray(y_true_validation) == np.asarray(y_pred_validation)))
std_val = np.sqrt(acc_val * (1.0 - acc_val))
sem_val = std_val / np.sqrt(len(y_true_validation))
sem_p_val = sem_val * 100
print(acc_val * 100, sem_p_val)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Confusion matrix of the validation predictions, normalised over the true
# labels (normalize='true'), for class ids 0-4.
#
# SLEEP-EDF variant -- same calls, but with
#   display_labels = ["stage W", "Stage 1", "Stage 2", "Stage 3/4", "Stage R"]

# BHUTAN variant (active)
cm = confusion_matrix(
    y_true_validation,
    y_pred_validation,
    labels=[0,1,2,3,4],
    normalize='true',
)
disp = ConfusionMatrixDisplay(
    confusion_matrix=cm,
    display_labels=["EB", "EMLR", "EC", "EO", "JC"],
)
disp.plot()
plt.show()

In [None]:
from sklearn.metrics import classification_report

# Per-class precision / recall / F1 of the validation predictions.
# `labels` is passed explicitly (the same ids the confusion-matrix cell uses) so
# the five target names always line up with the five classes.  Without it,
# sklearn infers the classes from the data and raises a ValueError when a class
# is missing from both y_true and y_pred.
print(classification_report(
    y_true_validation,
    y_pred_validation,
    labels=[0, 1, 2, 3, 4],
    target_names=["EB", "EMLR", "EC", "EO", "JC"],
))