In [2]:
import numpy as np
import os
import random
import torch
from torch import nn
import torch.nn.functional as F
import tqdm
import torchvision
import torchvision.transforms as transforms
from vit import ViT
from sklearn.model_selection import train_test_split
from torch_loader import Data

def set_seed(seed=1):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (the default generator plus the current and all CUDA devices), then
    configures cuDNN for repeatable algorithm selection.

    Args:
        seed: Integer seed shared by all generators. Defaults to 1.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Deterministic kernels alone are not sufficient: with benchmark mode on,
    # cuDNN may auto-tune to a different algorithm from one run to the next.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def select_two_classes_from_sleep_edf(dataset, classes):
    """Restrict ``dataset`` to two classes and relabel them as 0 and 1.

    The dataset is modified in place: ``dataset.targets`` and ``dataset.data``
    keep only the samples whose label equals ``classes[0]`` or ``classes[1]``,
    and the surviving labels become 0 (for ``classes[0]``) and 1 (for
    ``classes[1]``).

    Args:
        dataset: Object exposing ``targets`` (a sequence of labels) and
            ``data`` (indexable with a NumPy boolean mask).
        classes: Two-element sequence holding the original labels to keep.

    Returns:
        The same ``dataset`` object, after filtering and relabelling.
    """
    targets = np.array(dataset.targets)
    is_first = targets == classes[0]
    is_second = targets == classes[1]
    keep = is_first | is_second

    # Build the new labels from masks of the ORIGINAL labels. Relabelling in
    # place in two passes merges both classes whenever classes[1] == 0, because
    # samples just rewritten to 0 would match the second pass and become 1.
    relabelled = np.where(is_first[keep], 0, 1).astype(targets.dtype)

    dataset.targets = relabelled.tolist()
    dataset.data = dataset.data[keep]
    return dataset

def prepare_dataloaders(batch_size, classes=[2, 4], dataset = Data):
    """Split ``dataset`` into train/test parts and wrap each in a DataLoader.

    The split holds out 20% of the samples for testing and shuffles before
    splitting. The training loader reshuffles on iteration; the test loader
    keeps a fixed order.

    Args:
        batch_size: Number of samples per batch for both loaders.
        classes: Pair of labels intended for two-class selection. Currently
            unused because that selection step is disabled.
        dataset: Collection of samples handed to ``train_test_split``.

    Returns:
        Tuple ``(trainloader, testloader, train_set, test_set)``.
    """
    # TASK: experiment with data augmentation (e.g. ToTensor + Normalize
    # transforms for the train and test sets); none is applied at the moment.
    train_set, test_set = train_test_split(dataset, test_size=0.2, shuffle=True)

    # Two-class selection via select_two_classes_from_sleep_edf is switched
    # off, so every class present in ``dataset`` is kept.
    make_loader = torch.utils.data.DataLoader
    trainloader = make_loader(train_set, batch_size=batch_size, shuffle=True)
    testloader = make_loader(test_set, batch_size=batch_size, shuffle=False)

    return trainloader, testloader, train_set, test_set


def _train_one_epoch(model, loader, loss_function, opt, sch, device, gradient_clipping):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    running_loss = 0.0
    for image, label in tqdm.tqdm(loader):
        image, label = image.to(device), label.to(device)
        opt.zero_grad()
        # The model returns (logits, attention matrices); only logits are used here.
        out, _ = model(image)
        loss = loss_function(out, label)
        loss.backward()
        running_loss += loss.item()
        # If the total gradient norm exceeds the threshold, scale it back down.
        if gradient_clipping > 0.0:
            nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
        opt.step()
        # The scheduler advances once per batch, so warm-up is counted in steps.
        sch.step()
    return running_loss / len(loader)


def _evaluate(model, loader, loss_function, device):
    """Return ``(mean batch loss, accuracy)`` of ``model`` over ``loader``."""
    model.eval()
    total_loss, total, correct = 0.0, 0, 0
    with torch.no_grad():
        for image, label in loader:
            image, label = image.to(device), label.to(device)
            out, _ = model(image)
            total_loss += loss_function(out, label).item()
            total += image.size(0)
            correct += (out.argmax(dim=1) == label).sum().item()
    return total_loss / len(loader), correct / total


def main(image_size=(128,128), patch_size=(16,16), channels=3, 
         embed_dim=128, num_heads=4, num_layers=4, num_classes=5,
         pos_enc='learnable', pool='cls', dropout=0.3, fc_dim=None, 
         num_epochs=50, batch_size=32, lr=1e-4, warmup_steps=625,
         weight_decay=1e-3, gradient_clipping=1,
         checkpoint_path='model.pth', best_checkpoint_path='best_model.pth'
    ):
    """Train a ViT classifier and report validation metrics after every epoch.

    Builds the data loaders, the model, an AdamW optimiser and a linear
    learning-rate warm-up schedule, then alternates a training pass and a
    validation pass for ``num_epochs`` epochs using cross-entropy loss.

    Args:
        image_size, patch_size, channels, embed_dim, num_heads, num_layers,
        num_classes, pos_enc, pool, dropout, fc_dim: Forwarded unchanged to
            the ``ViT`` constructor.
        num_epochs: Number of train/validate cycles.
        batch_size: Batch size passed to ``prepare_dataloaders``.
        lr: Learning rate for AdamW (reached once warm-up completes).
        warmup_steps: Number of optimiser steps over which the learning-rate
            multiplier grows linearly from 0 to 1. A value <= 0 disables
            warm-up.
        weight_decay: Weight decay for AdamW.
        gradient_clipping: Maximum gradient norm; values <= 0 disable clipping.
        checkpoint_path: File that receives the whole model object at the end
            of every epoch.
        best_checkpoint_path: File that receives the ``state_dict`` of the
            epoch with the lowest validation loss so far.

    Side effects:
        Prints per-epoch metrics and writes the two checkpoint files.
    """
    loss_function = nn.CrossEntropyLoss()

    train_iter, test_iter, _, _ = prepare_dataloaders(batch_size=batch_size, dataset = Data)

    model = ViT(image_size=image_size, patch_size=patch_size, channels=channels, 
                embed_dim=embed_dim, num_heads=num_heads, num_layers=num_layers,
                pos_enc=pos_enc, pool=pool, dropout=dropout, fc_dim=fc_dim, 
                num_classes=num_classes
    )

    # Resolve the device once instead of re-querying CUDA for every batch.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    def warmup_factor(step):
        # Guard against warmup_steps == 0, which would otherwise divide by zero.
        if warmup_steps <= 0:
            return 1.0
        return min(step / warmup_steps, 1.0)

    opt = torch.optim.AdamW(lr=lr, params=model.parameters(), weight_decay=weight_decay)
    sch = torch.optim.lr_scheduler.LambdaLR(opt, warmup_factor)

    # training loop
    best_val_loss = float('inf')
    for e in range(num_epochs):
        print(f'\n epoch {e}')
        train_loss = _train_one_epoch(
            model, train_iter, loss_function, opt, sch, device, gradient_clipping
        )

        val_loss, acc = _evaluate(model, test_iter, loss_function, device)
        print(f'-- train loss {train_loss:.3f} -- validation accuracy {acc:.3f} -- validation loss: {val_loss:.3f}')

        # Keep the best weights in their own file: writing them to the same
        # path as the per-epoch snapshot below would overwrite them at once.
        if val_loss <= best_val_loss:
            torch.save(model.state_dict(), best_checkpoint_path)
            best_val_loss = val_loss

        # Snapshot of the full model after the most recent epoch.
        torch.save(model, checkpoint_path)

In [3]:
if __name__ == "__main__":
    # Optional GPU pinning, currently disabled:
    #   os.environ["CUDA_VISIBLE_DEVICES"] = str(0)

    # Report which device is available; main() does its own CUDA check.
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print(f"Model will run on {device}")

    # Fix all RNG seeds before any data splitting or weight initialisation.
    set_seed(seed=1)
    main()

Model will run on cpu

 epoch 0


100%|██████████| 31/31 [00:02<00:00, 10.60it/s]


-- train loss 1.728 -- validation accuracy 0.146 -- validation loss: 1.504

 epoch 1


100%|██████████| 31/31 [00:02<00:00, 10.61it/s]


-- train loss 1.681 -- validation accuracy 0.526 -- validation loss: 1.366

 epoch 2


100%|██████████| 31/31 [00:03<00:00,  8.90it/s]


-- train loss 1.587 -- validation accuracy 0.526 -- validation loss: 1.429

 epoch 3


100%|██████████| 31/31 [00:03<00:00,  9.68it/s]


-- train loss 1.529 -- validation accuracy 0.526 -- validation loss: 1.674

 epoch 4


100%|██████████| 31/31 [00:02<00:00, 11.17it/s]


-- train loss 1.510 -- validation accuracy 0.526 -- validation loss: 1.794

 epoch 5


100%|██████████| 31/31 [00:02<00:00, 11.28it/s]


-- train loss 1.495 -- validation accuracy 0.526 -- validation loss: 1.827

 epoch 6


100%|██████████| 31/31 [00:02<00:00, 11.24it/s]


-- train loss 1.436 -- validation accuracy 0.526 -- validation loss: 1.674

 epoch 7


100%|██████████| 31/31 [00:02<00:00, 11.06it/s]


-- train loss 1.448 -- validation accuracy 0.526 -- validation loss: 1.687

 epoch 8


100%|██████████| 31/31 [00:02<00:00, 10.90it/s]


-- train loss 1.419 -- validation accuracy 0.526 -- validation loss: 1.553

 epoch 9


100%|██████████| 31/31 [00:02<00:00, 10.51it/s]


-- train loss 1.408 -- validation accuracy 0.526 -- validation loss: 1.481

 epoch 10


100%|██████████| 31/31 [00:02<00:00, 11.21it/s]


-- train loss 1.376 -- validation accuracy 0.518 -- validation loss: 1.414

 epoch 11


100%|██████████| 31/31 [00:02<00:00, 11.34it/s]


-- train loss 1.358 -- validation accuracy 0.538 -- validation loss: 1.474

 epoch 12


100%|██████████| 31/31 [00:02<00:00, 10.98it/s]


-- train loss 1.309 -- validation accuracy 0.555 -- validation loss: 1.668

 epoch 13


100%|██████████| 31/31 [00:02<00:00, 11.25it/s]


-- train loss 1.295 -- validation accuracy 0.538 -- validation loss: 1.494

 epoch 14


100%|██████████| 31/31 [00:02<00:00, 11.08it/s]


-- train loss 1.276 -- validation accuracy 0.543 -- validation loss: 1.419

 epoch 15


100%|██████████| 31/31 [00:02<00:00, 11.16it/s]


-- train loss 1.256 -- validation accuracy 0.543 -- validation loss: 1.585

 epoch 16


100%|██████████| 31/31 [00:02<00:00, 11.18it/s]


-- train loss 1.230 -- validation accuracy 0.543 -- validation loss: 1.431

 epoch 17


100%|██████████| 31/31 [00:02<00:00, 11.04it/s]


-- train loss 1.193 -- validation accuracy 0.559 -- validation loss: 1.241

 epoch 18


100%|██████████| 31/31 [00:02<00:00, 11.18it/s]


-- train loss 1.114 -- validation accuracy 0.619 -- validation loss: 1.061

 epoch 19


100%|██████████| 31/31 [00:02<00:00, 11.14it/s]


-- train loss 0.988 -- validation accuracy 0.623 -- validation loss: 1.062

 epoch 20


100%|██████████| 31/31 [00:02<00:00, 11.17it/s]


-- train loss 0.968 -- validation accuracy 0.656 -- validation loss: 1.069

 epoch 21


100%|██████████| 31/31 [00:02<00:00, 11.23it/s]


-- train loss 0.965 -- validation accuracy 0.615 -- validation loss: 0.997

 epoch 22


100%|██████████| 31/31 [00:02<00:00, 10.92it/s]


-- train loss 0.918 -- validation accuracy 0.664 -- validation loss: 0.986

 epoch 23


100%|██████████| 31/31 [00:02<00:00, 11.27it/s]


-- train loss 0.908 -- validation accuracy 0.538 -- validation loss: 1.146

 epoch 24


100%|██████████| 31/31 [00:02<00:00, 11.15it/s]


-- train loss 0.915 -- validation accuracy 0.636 -- validation loss: 0.987

 epoch 25


100%|██████████| 31/31 [00:02<00:00, 10.61it/s]


-- train loss 0.946 -- validation accuracy 0.632 -- validation loss: 0.919

 epoch 26


100%|██████████| 31/31 [00:02<00:00, 10.66it/s]


-- train loss 0.933 -- validation accuracy 0.664 -- validation loss: 1.124

 epoch 27


100%|██████████| 31/31 [00:02<00:00, 10.76it/s]


-- train loss 0.913 -- validation accuracy 0.640 -- validation loss: 0.935

 epoch 28


100%|██████████| 31/31 [00:02<00:00, 10.86it/s]


-- train loss 0.884 -- validation accuracy 0.583 -- validation loss: 1.041

 epoch 29


100%|██████████| 31/31 [00:02<00:00, 10.86it/s]


-- train loss 0.919 -- validation accuracy 0.648 -- validation loss: 0.961

 epoch 30


100%|██████████| 31/31 [00:02<00:00, 10.58it/s]


-- train loss 0.849 -- validation accuracy 0.644 -- validation loss: 0.979

 epoch 31


100%|██████████| 31/31 [00:02<00:00, 11.01it/s]


-- train loss 0.870 -- validation accuracy 0.632 -- validation loss: 0.908

 epoch 32


100%|██████████| 31/31 [00:02<00:00, 10.75it/s]


-- train loss 0.858 -- validation accuracy 0.648 -- validation loss: 0.895

 epoch 33


100%|██████████| 31/31 [00:02<00:00, 10.58it/s]


-- train loss 0.859 -- validation accuracy 0.668 -- validation loss: 1.009

 epoch 34


100%|██████████| 31/31 [00:02<00:00, 10.77it/s]


-- train loss 0.849 -- validation accuracy 0.644 -- validation loss: 0.937

 epoch 35


100%|██████████| 31/31 [00:02<00:00, 10.56it/s]


-- train loss 0.839 -- validation accuracy 0.640 -- validation loss: 0.876

 epoch 36


100%|██████████| 31/31 [00:02<00:00, 10.81it/s]


-- train loss 0.846 -- validation accuracy 0.648 -- validation loss: 0.849

 epoch 37


100%|██████████| 31/31 [00:02<00:00, 10.34it/s]


-- train loss 0.823 -- validation accuracy 0.640 -- validation loss: 0.887

 epoch 38


100%|██████████| 31/31 [00:02<00:00, 10.70it/s]


-- train loss 0.824 -- validation accuracy 0.684 -- validation loss: 0.882

 epoch 39


100%|██████████| 31/31 [00:02<00:00, 10.75it/s]


-- train loss 0.801 -- validation accuracy 0.668 -- validation loss: 0.878

 epoch 40


100%|██████████| 31/31 [00:02<00:00, 10.65it/s]


-- train loss 0.823 -- validation accuracy 0.668 -- validation loss: 0.911

 epoch 41


100%|██████████| 31/31 [00:02<00:00, 10.74it/s]


-- train loss 0.827 -- validation accuracy 0.652 -- validation loss: 0.855

 epoch 42


100%|██████████| 31/31 [00:02<00:00, 10.49it/s]


-- train loss 0.798 -- validation accuracy 0.672 -- validation loss: 0.876

 epoch 43


100%|██████████| 31/31 [00:02<00:00, 10.65it/s]


-- train loss 0.797 -- validation accuracy 0.656 -- validation loss: 1.000

 epoch 44


100%|██████████| 31/31 [00:02<00:00, 10.60it/s]


-- train loss 0.793 -- validation accuracy 0.668 -- validation loss: 0.871

 epoch 45


100%|██████████| 31/31 [00:02<00:00, 10.61it/s]


-- train loss 0.811 -- validation accuracy 0.676 -- validation loss: 0.806

 epoch 46


100%|██████████| 31/31 [00:02<00:00, 10.65it/s]


-- train loss 0.760 -- validation accuracy 0.680 -- validation loss: 0.831

 epoch 47


100%|██████████| 31/31 [00:02<00:00, 10.63it/s]


-- train loss 0.761 -- validation accuracy 0.684 -- validation loss: 0.875

 epoch 48


100%|██████████| 31/31 [00:02<00:00, 10.60it/s]


-- train loss 0.783 -- validation accuracy 0.676 -- validation loss: 0.855

 epoch 49


100%|██████████| 31/31 [00:02<00:00, 10.67it/s]


-- train loss 0.780 -- validation accuracy 0.668 -- validation loss: 0.836
