In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from tqdm.notebook import tqdm
from colorist import Color
import time
import os

import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchsummary import summary
from torch.optim import Adam

from models.pyramid_ViG import DeepGCN
from timm.models import create_model
from timm.scheduler import CosineLRScheduler
from trainer import Trainer

# Pick the compute device once: use CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [2]:
# Build the 'PyramidVIG_Tiny_GELU' architecture through timm's create_model
# factory and move it to the selected device in one step
# (nn.Module.to returns the module itself).
model = create_model('PyramidVIG_Tiny_GELU').to(device)

<models.pyramid_ViG.PyramidVIG_Tiny_GELU.<locals>.OptInit object at 0x00000272BDD8E470>
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos
using relative_pos


In [3]:
# Smoke-test the forward pass on a random batch to confirm the output shape.
# The check runs in eval mode under no_grad so that the random input neither
# updates the BatchNorm running statistics nor builds an autograd graph.
# The model's previous train/eval mode is restored afterwards.
was_training = model.training
model.eval()
with torch.no_grad():
    # Allocate the batch directly on the target device instead of copying it.
    rand_X = torch.randn(4, 3, 256, 256, device=device)
    sample_output = model(rand_X)
model.train(was_training)
print(sample_output.shape)

torch.Size([4, 7])


In [4]:
# torchsummary's summary() prints the layer table itself; wrapping it in
# print() only appends the function's return value to the cell output.
# The trailing semicolon keeps the notebook from echoing that return value.
summary(model, (3, 256, 256));

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1         [-1, 24, 128, 128]             672
       BatchNorm2d-2         [-1, 24, 128, 128]              48
              GELU-3         [-1, 24, 128, 128]               0
            Conv2d-4           [-1, 48, 64, 64]          10,416
       BatchNorm2d-5           [-1, 48, 64, 64]              96
              GELU-6           [-1, 48, 64, 64]               0
            Conv2d-7           [-1, 48, 64, 64]          20,784
       BatchNorm2d-8           [-1, 48, 64, 64]              96
              Stem-9           [-1, 48, 64, 64]               0
           Conv2d-10           [-1, 48, 64, 64]           2,352
      BatchNorm2d-11           [-1, 48, 64, 64]              96
     DenseDilated-12           [-1, 2, 4096, 9]               0
DenseDilatedKnnGraph-13           [-1, 2, 4096, 9]               0
           Conv2d-14          [-1, 9

In [5]:
# --- Tunable settings, gathered in one place so both pipelines stay in sync ---
IMAGE_SIZE = (256, 256)            # spatial size every image is resized to
NORMALIZE_MEAN = [0.485, 0.456, 0.406]  # per-channel mean (the widely used ImageNet statistics)
NORMALIZE_STD = [0.229, 0.224, 0.225]   # per-channel std  (the widely used ImageNet statistics)
BATCH_SIZE = 32
NUM_WORKERS = 4

data_directory = './data'
learning_rate = 2e-3

# One Normalize instance shared by both pipelines, so training and validation
# inputs are guaranteed to be scaled identically.
normalize = transforms.Normalize(mean=NORMALIZE_MEAN, std=NORMALIZE_STD)

# Training pipeline: resize, random flips/rotation for augmentation, then
# tensor conversion and normalization.
train_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ToTensor(),
    normalize,
])

# Validation pipeline: deterministic, with no augmentation.
valid_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    normalize,
])

# ImageFolder derives the class labels from the sub-directory names.
train_dataset = ImageFolder(data_directory + '/train', transform=train_transform)
valid_dataset = ImageFolder(data_directory + '/valid', transform=valid_transform)

# Only the training loader shuffles; validation keeps the dataset order.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=True)
valid_loader = DataLoader(valid_dataset, batch_size=BATCH_SIZE,
                          num_workers=NUM_WORKERS, pin_memory=True)

train_count = len(train_dataset)
valid_count = len(valid_dataset)

optimizer = Adam(model.parameters(), lr=learning_rate)
# NOTE(review): timm's CosineLRScheduler appears to apply warmup_lr_init only
# when warmup_t > 0 (its default is 0), so this argument may currently be a
# no-op. Confirm, then either set warmup_t or drop the argument. Also confirm
# that t_initial=20 matches the intended run length and the granularity
# (epoch vs. step) at which the scheduler is stepped.
scheduler = CosineLRScheduler(optimizer, t_initial=20, warmup_lr_init=1e-6)

# Collect the key facts about this run for a quick visual sanity check.
preparation_summary = {
    'data_directory': data_directory,
    'train_count': train_count,
    'valid_count': valid_count,
    'learning_rate': learning_rate,
    'class_names': list(train_dataset.classes),
}

summary_df = pd.DataFrame(list(preparation_summary.values()),
                          index=preparation_summary.keys(),
                          columns=['Value'])
display(summary_df)

Unnamed: 0,Value
data_directory,./data
train_count,38569
valid_count,938
learning_rate,0.002
class_names,"[akiec, bcc, bkl, df, mel, nv, vasc]"


In [6]:
# Hand the model, both data loaders, the optimizer, the LR scheduler and the
# target device to the project's Trainer helper.
trainer = Trainer(
    model,
    train_loader,
    valid_loader,
    optimizer,
    scheduler,
    device,
)

Created new directory for model checkpoints at ./model_checkpoints/2024-03-11_04.37.57


In [7]:
# Start training: the Trainer instance is callable and is given the number of
# epochs to run.
trainer(epochs=200)

[35mEpoch 1[0m:   0%|          | 0/1206 [00:00<?, ?it/s]

[35mEpoch 1[0m: 100%|██████████| 1206/1206 [07:21<00:00,  2.73it/s]


Train set ===> Average Loss: [31m0.0395[0m | Accuracy: 19889/38569 ([36m51.57%[0m)
Test set  ===> Average Loss: [31m0.0175[0m | Accuracy: 748/938 ([32m79.74%[0m)
Model saved at ./model_checkpoints/2024-03-11_04.37.57/best_model.pth
Best Accuracy: [[1;32m79.744%[0m]


[35mEpoch 2[0m: 100%|██████████| 1206/1206 [07:23<00:00,  2.72it/s]


Train set ===> Average Loss: [31m0.0320[0m | Accuracy: 23295/38569 ([36m60.40%[0m)
Test set  ===> Average Loss: [31m0.0144[0m | Accuracy: 786/938 ([32m83.80%[0m)
Model saved at ./model_checkpoints/2024-03-11_04.37.57/best_model.pth
Best Accuracy: [[1;32m83.795%[0m]


[35mEpoch 3[0m: 100%|██████████| 1206/1206 [07:44<00:00,  2.60it/s]


Train set ===> Average Loss: [31m0.0280[0m | Accuracy: 25249/38569 ([36m65.46%[0m)
Test set  ===> Average Loss: [31m0.0125[0m | Accuracy: 814/938 ([32m86.78%[0m)
Model saved at ./model_checkpoints/2024-03-11_04.37.57/best_model.pth
Best Accuracy: [[1;32m86.78%[0m]


[35mEpoch 4[0m: 100%|██████████| 1206/1206 [08:13<00:00,  2.45it/s]


Train set ===> Average Loss: [31m0.0254[0m | Accuracy: 26516/38569 ([36m68.75%[0m)
Test set  ===> Average Loss: [31m0.0159[0m | Accuracy: 770/938 ([32m82.09%[0m)
Best Accuracy: [[1;32m86.78%[0m]


[35mEpoch 5[0m: 100%|██████████| 1206/1206 [08:19<00:00,  2.41it/s]


Train set ===> Average Loss: [31m0.0230[0m | Accuracy: 27635/38569 ([36m71.65%[0m)
Test set  ===> Average Loss: [31m0.0141[0m | Accuracy: 784/938 ([32m83.58%[0m)
Best Accuracy: [[1;32m86.78%[0m]


[35mEpoch 6[0m: 100%|██████████| 1206/1206 [08:16<00:00,  2.43it/s]


Train set ===> Average Loss: [31m0.0210[0m | Accuracy: 28682/38569 ([36m74.37%[0m)
Test set  ===> Average Loss: [31m0.0138[0m | Accuracy: 777/938 ([32m82.84%[0m)
Best Accuracy: [[1;32m86.78%[0m]


[35mEpoch 7[0m: 100%|██████████| 1206/1206 [08:05<00:00,  2.48it/s]


Train set ===> Average Loss: [31m0.0192[0m | Accuracy: 29543/38569 ([36m76.60%[0m)
Test set  ===> Average Loss: [31m0.0130[0m | Accuracy: 805/938 ([32m85.82%[0m)
Best Accuracy: [[1;32m86.78%[0m]


[35mEpoch 8[0m: 100%|██████████| 1206/1206 [07:50<00:00,  2.57it/s]


Train set ===> Average Loss: [31m0.0179[0m | Accuracy: 30076/38569 ([36m77.98%[0m)
Test set  ===> Average Loss: [31m0.0111[0m | Accuracy: 823/938 ([32m87.74%[0m)
Model saved at ./model_checkpoints/2024-03-11_04.37.57/best_model.pth
Best Accuracy: [[1;32m87.74%[0m]


[35mEpoch 9[0m: 100%|██████████| 1206/1206 [07:28<00:00,  2.69it/s]


Train set ===> Average Loss: [31m0.0165[0m | Accuracy: 30772/38569 ([36m79.78%[0m)
Test set  ===> Average Loss: [31m0.0117[0m | Accuracy: 816/938 ([32m86.99%[0m)
Best Accuracy: [[1;32m87.74%[0m]


[35mEpoch 10[0m: 100%|██████████| 1206/1206 [07:36<00:00,  2.64it/s]


Train set ===> Average Loss: [31m0.0153[0m | Accuracy: 31339/38569 ([36m81.25%[0m)
Test set  ===> Average Loss: [31m0.0124[0m | Accuracy: 808/938 ([32m86.14%[0m)
Best Accuracy: [[1;32m87.74%[0m]


[35mEpoch 11[0m: 100%|██████████| 1206/1206 [07:41<00:00,  2.61it/s]


Train set ===> Average Loss: [31m0.0143[0m | Accuracy: 31910/38569 ([36m82.73%[0m)
Test set  ===> Average Loss: [31m0.0135[0m | Accuracy: 794/938 ([32m84.65%[0m)
Best Accuracy: [[1;32m87.74%[0m]


[35mEpoch 12[0m: 100%|██████████| 1206/1206 [07:42<00:00,  2.61it/s]


Train set ===> Average Loss: [31m0.0131[0m | Accuracy: 32341/38569 ([36m83.85%[0m)
Test set  ===> Average Loss: [31m0.0136[0m | Accuracy: 802/938 ([32m85.50%[0m)
Best Accuracy: [[1;32m87.74%[0m]


[35mEpoch 13[0m: 100%|██████████| 1206/1206 [07:36<00:00,  2.64it/s]


Train set ===> Average Loss: [31m0.0121[0m | Accuracy: 32817/38569 ([36m85.09%[0m)
Test set  ===> Average Loss: [31m0.0112[0m | Accuracy: 826/938 ([32m88.06%[0m)
Model saved at ./model_checkpoints/2024-03-11_04.37.57/best_model.pth
Best Accuracy: [[1;32m88.06%[0m]


[35mEpoch 14[0m: 100%|██████████| 1206/1206 [07:21<00:00,  2.73it/s]


Train set ===> Average Loss: [31m0.0112[0m | Accuracy: 33378/38569 ([36m86.54%[0m)
Test set  ===> Average Loss: [31m0.0103[0m | Accuracy: 829/938 ([32m88.38%[0m)
Model saved at ./model_checkpoints/2024-03-11_04.37.57/best_model.pth
Best Accuracy: [[1;32m88.38%[0m]


[35mEpoch 15[0m: 100%|██████████| 1206/1206 [07:21<00:00,  2.73it/s]


Train set ===> Average Loss: [31m0.0102[0m | Accuracy: 33800/38569 ([36m87.64%[0m)
Test set  ===> Average Loss: [31m0.0113[0m | Accuracy: 830/938 ([32m88.49%[0m)
Model saved at ./model_checkpoints/2024-03-11_04.37.57/best_model.pth
Best Accuracy: [[1;32m88.486%[0m]


[35mEpoch 16[0m:  77%|███████▋  | 932/1206 [05:49<01:39,  2.75it/s]