In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models

from sklearn import decomposition
from sklearn import manifold
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import numpy as np

import copy
from collections import namedtuple
import os
import random
import shutil
import time
from tqdm.notebook import tqdm
from torchinfo import summary
import pickle

In [2]:
SEED = 853

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [3]:
data_dir = 'CV/data'
images_dir = os.path.join(data_dir, 'images')
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')

classes = os.listdir(images_dir)

In [4]:
pretrained_size = 388
pretrained_means = [0.485, 0.456, 0.406]
pretrained_stds= [0.229, 0.224, 0.225]

train_transforms = transforms.Compose([
                           transforms.Resize(pretrained_size),
                           transforms.RandomRotation(5),
                           transforms.RandomHorizontalFlip(0.5),
                           transforms.RandomCrop(pretrained_size, padding = 10),
                           transforms.ToTensor(),
                           transforms.Normalize(mean = pretrained_means, 
                                                std = pretrained_stds)
                       ])

test_transforms = transforms.Compose([
                           transforms.Resize(pretrained_size),
                           transforms.CenterCrop(pretrained_size),
                           transforms.ToTensor(),
                           transforms.Normalize(mean = pretrained_means, 
                                                std = pretrained_stds)
                       ])

In [5]:
train_data = datasets.ImageFolder(root = train_dir, 
                                  transform = train_transforms)

test_data = datasets.ImageFolder(root = test_dir, 
                                 transform = test_transforms)

In [6]:
VALID_RATIO = 0.9

n_train_examples = int(len(train_data) * VALID_RATIO)
n_valid_examples = len(train_data) - n_train_examples

train_data, valid_data = data.random_split(train_data, 
                                           [n_train_examples, n_valid_examples])

In [7]:
valid_data = copy.deepcopy(valid_data)
valid_data.dataset.transform = test_transforms

In [8]:
print(f'Number of training examples: {len(train_data)}')
print(f'Number of validation examples: {len(valid_data)}')
print(f'Number of testing examples: {len(test_data)}')

Number of training examples: 33202
Number of validation examples: 3690
Number of testing examples: 9225


In [9]:
BATCH_SIZE = 128

train_iterator = data.DataLoader(train_data, 
                                 shuffle = True, 
                                 batch_size = BATCH_SIZE,
                                num_workers = 14)

valid_iterator = data.DataLoader(valid_data, 
                                 batch_size = BATCH_SIZE,
                                num_workers = 14)

test_iterator = data.DataLoader(test_data, 
                                batch_size = BATCH_SIZE,
                                num_workers = 14)

In [10]:
from pytorch_pretrained_vit import ViT
pretrained_model = ViT('B_16_imagenet1k', pretrained=True)

pretrained_model.fc = nn.Linear(768, len(test_data.classes))

Loaded pretrained weights.


In [11]:
for param in pretrained_model.parameters():
    param.requires_grad = False
for param in pretrained_model.fc.parameters():
    param.requires_grad = True

In [12]:
summary(pretrained_model)

Layer (type:depth-idx)                                  Param #
ViT                                                     --
├─Conv2d: 1-1                                           (590,592)
├─PositionalEmbedding1D: 1-2                            (443,136)
├─Transformer: 1-3                                      --
│    └─ModuleList: 2-1                                  --
│    │    └─Block: 3-1                                  (7,087,872)
│    │    └─Block: 3-2                                  (7,087,872)
│    │    └─Block: 3-3                                  (7,087,872)
│    │    └─Block: 3-4                                  (7,087,872)
│    │    └─Block: 3-5                                  (7,087,872)
│    │    └─Block: 3-6                                  (7,087,872)
│    │    └─Block: 3-7                                  (7,087,872)
│    │    └─Block: 3-8                                  (7,087,872)
│    │    └─Block: 3-9                                  (7,087,872)
│    │    └─Blo

In [13]:
device = torch.device('cuda')

criterion = nn.CrossEntropyLoss().to(device)

model = pretrained_model.to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', verbose = True, patience=5)

In [14]:
EPOCHS = 500
STEPS_PER_EPOCH = len(train_iterator)
TOTAL_STEPS = EPOCHS * STEPS_PER_EPOCH

MAX_LRS = [p['lr'] for p in optimizer.param_groups]

In [15]:
def calculate_topk_accuracy(y_pred, y, k = 1):
    with torch.no_grad():
        batch_size = y.shape[0]
        _, top_pred = y_pred.topk(k, 1)
        top_pred = top_pred.t()
        correct = top_pred.eq(y.view(1, -1).expand_as(top_pred))
        correct_1 = correct[:1].reshape(-1).float().sum(0, keepdim = True)
        correct_k = correct[:k].reshape(-1).float().sum(0, keepdim = True)
        acc_1 = correct_1 / batch_size
        acc_k = correct_k / batch_size
    return acc_1, acc_k

In [16]:
def train(model, iterator, optimizer, criterion, scheduler, device):
    
    epoch_loss = 0
    epoch_acc_1 = 0
    epoch_acc_5 = 0
    
    model.train()
    
    for (x, y) in tqdm(iterator):
        
        x = x.to(device)
        y = y.to(device)
        
        optimizer.zero_grad()
                
        y_pred = model(x)
        
        loss = criterion(y_pred, y)
        
        acc_1, acc_5 = calculate_topk_accuracy(y_pred, y)
        
        loss.backward()
        
        optimizer.step()
        
        epoch_loss += loss.item()
        epoch_acc_1 += acc_1.item()
        epoch_acc_5 += acc_5.item()
        
    epoch_loss /= len(iterator)
    epoch_acc_1 /= len(iterator)
    epoch_acc_5 /= len(iterator)
        
    return epoch_loss, epoch_acc_1, epoch_acc_5

In [17]:
def evaluate(model, iterator, criterion, device):
    
    epoch_loss = 0
    epoch_acc_1 = 0
    epoch_acc_5 = 0
    
    model.eval()
    
    with torch.no_grad():
        
        for (x, y) in iterator:

            x = x.to(device)
            y = y.to(device)

            y_pred = model(x)

            loss = criterion(y_pred, y)

            acc_1, acc_5 = calculate_topk_accuracy(y_pred, y)

            epoch_loss += loss.item()
            epoch_acc_1 += acc_1.item()
            epoch_acc_5 += acc_5.item()
        
    epoch_loss /= len(iterator)
    epoch_acc_1 /= len(iterator)
    epoch_acc_5 /= len(iterator)
        
    return epoch_loss, epoch_acc_1, epoch_acc_5

In [18]:
def epoch_time(start_time, end_time):
    elapsed_time = end_time - start_time
    elapsed_mins = int(elapsed_time / 60)
    elapsed_secs = int(elapsed_time - (elapsed_mins * 60))
    return elapsed_mins, elapsed_secs

In [19]:
#torch.save(model.state_dict(), '/scratch/dpj7913/CV/vit/best-model_at_' + str(epoch) + '.pth')

In [20]:
model.load_state_dict(torch.load('/scratch/dpj7913/CV/vit/best-model_at_4.pth'))

<All keys matched successfully>

In [None]:
torch.cuda.empty_cache()

best_valid_loss = float('inf')
accuracies = {}

for epoch in tqdm(range(5, EPOCHS)):
    
    start_time = time.monotonic()
    
    train_loss, train_acc_1, train_acc_5 = train(model, train_iterator, optimizer, criterion, scheduler, device)
    valid_loss, valid_acc_1, valid_acc_5 = evaluate(model, valid_iterator, criterion, device)
        
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), '/scratch/dpj7913/CV/vit/best-model_at_' + str(epoch) + '.pth')

    end_time = time.monotonic()

    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    print(f'Epoch: {epoch+1:02} | Epoch Time: {epoch_mins}m {epoch_secs}s')
    print(f'\tTrain Loss: {train_loss:.3f} | Train Acc @1: {train_acc_1*100:6.2f}% | ' \
          f'Train Acc @5: {train_acc_5*100:6.2f}%')
    print(f'\tValid Loss: {valid_loss:.3f} | Valid Acc @1: {valid_acc_1*100:6.2f}% | ' \
          f'Valid Acc @5: {valid_acc_5*100:6.2f}%')
    
    scheduler.step(valid_loss)
    accuracies[epoch] = {'train': train_acc_1, 'val': valid_acc_1}
    
    if epoch%50 == 0:
        f = open("accuracies_vit.pkl", "wb")
        pickle.dump(accuracies, f)
        f.close()

  0%|          | 0/495 [00:00<?, ?it/s]

  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 06 | Epoch Time: 9m 46s
	Train Loss: 0.746 | Train Acc @1:  66.89% | Train Acc @5:  66.89%
	Valid Loss: 0.790 | Valid Acc @1:  63.80% | Valid Acc @5:  63.80%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 07 | Epoch Time: 9m 47s
	Train Loss: 0.746 | Train Acc @1:  67.03% | Train Acc @5:  67.03%
	Valid Loss: 0.789 | Valid Acc @1:  64.49% | Valid Acc @5:  64.49%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 08 | Epoch Time: 9m 47s
	Train Loss: 0.748 | Train Acc @1:  66.87% | Train Acc @5:  66.87%
	Valid Loss: 0.787 | Valid Acc @1:  64.07% | Valid Acc @5:  64.07%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 09 | Epoch Time: 9m 46s
	Train Loss: 0.743 | Train Acc @1:  66.93% | Train Acc @5:  66.93%
	Valid Loss: 0.796 | Valid Acc @1:  64.19% | Valid Acc @5:  64.19%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 10 | Epoch Time: 9m 47s
	Train Loss: 0.743 | Train Acc @1:  67.22% | Train Acc @5:  67.22%
	Valid Loss: 0.818 | Valid Acc @1:  62.16% | Valid Acc @5:  62.16%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 11 | Epoch Time: 9m 46s
	Train Loss: 0.744 | Train Acc @1:  67.01% | Train Acc @5:  67.01%
	Valid Loss: 0.801 | Valid Acc @1:  63.59% | Valid Acc @5:  63.59%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 12 | Epoch Time: 9m 46s
	Train Loss: 0.749 | Train Acc @1:  66.85% | Train Acc @5:  66.85%
	Valid Loss: 0.804 | Valid Acc @1:  63.39% | Valid Acc @5:  63.39%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 13 | Epoch Time: 9m 46s
	Train Loss: 0.744 | Train Acc @1:  67.12% | Train Acc @5:  67.12%
	Valid Loss: 0.788 | Valid Acc @1:  64.50% | Valid Acc @5:  64.50%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 14 | Epoch Time: 9m 46s
	Train Loss: 0.745 | Train Acc @1:  67.05% | Train Acc @5:  67.05%
	Valid Loss: 0.776 | Valid Acc @1:  65.79% | Valid Acc @5:  65.79%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 15 | Epoch Time: 9m 46s
	Train Loss: 0.746 | Train Acc @1:  67.41% | Train Acc @5:  67.41%
	Valid Loss: 0.776 | Valid Acc @1:  65.20% | Valid Acc @5:  65.20%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 16 | Epoch Time: 9m 47s
	Train Loss: 0.744 | Train Acc @1:  66.87% | Train Acc @5:  66.87%
	Valid Loss: 0.789 | Valid Acc @1:  64.99% | Valid Acc @5:  64.99%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 17 | Epoch Time: 9m 46s
	Train Loss: 0.748 | Train Acc @1:  66.91% | Train Acc @5:  66.91%
	Valid Loss: 0.811 | Valid Acc @1:  62.35% | Valid Acc @5:  62.35%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 18 | Epoch Time: 9m 46s
	Train Loss: 0.748 | Train Acc @1:  66.59% | Train Acc @5:  66.59%
	Valid Loss: 0.778 | Valid Acc @1:  65.00% | Valid Acc @5:  65.00%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 19 | Epoch Time: 9m 46s
	Train Loss: 0.748 | Train Acc @1:  66.73% | Train Acc @5:  66.73%
	Valid Loss: 0.776 | Valid Acc @1:  64.36% | Valid Acc @5:  64.36%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 20 | Epoch Time: 9m 47s
	Train Loss: 0.748 | Train Acc @1:  66.78% | Train Acc @5:  66.78%
	Valid Loss: 0.779 | Valid Acc @1:  65.68% | Valid Acc @5:  65.68%
Epoch    15: reducing learning rate of group 0 to 1.0000e-04.


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 21 | Epoch Time: 9m 46s
	Train Loss: 0.733 | Train Acc @1:  67.59% | Train Acc @5:  67.59%
	Valid Loss: 0.781 | Valid Acc @1:  65.61% | Valid Acc @5:  65.61%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 22 | Epoch Time: 9m 47s
	Train Loss: 0.730 | Train Acc @1:  68.03% | Train Acc @5:  68.03%
	Valid Loss: 0.780 | Valid Acc @1:  65.58% | Valid Acc @5:  65.58%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 23 | Epoch Time: 9m 46s
	Train Loss: 0.731 | Train Acc @1:  67.42% | Train Acc @5:  67.42%
	Valid Loss: 0.780 | Valid Acc @1:  65.57% | Valid Acc @5:  65.57%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 24 | Epoch Time: 9m 47s
	Train Loss: 0.729 | Train Acc @1:  67.92% | Train Acc @5:  67.92%
	Valid Loss: 0.783 | Valid Acc @1:  64.95% | Valid Acc @5:  64.95%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 25 | Epoch Time: 9m 46s
	Train Loss: 0.729 | Train Acc @1:  68.06% | Train Acc @5:  68.06%
	Valid Loss: 0.777 | Valid Acc @1:  65.57% | Valid Acc @5:  65.57%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 26 | Epoch Time: 9m 47s
	Train Loss: 0.730 | Train Acc @1:  67.79% | Train Acc @5:  67.79%
	Valid Loss: 0.780 | Valid Acc @1:  64.93% | Valid Acc @5:  64.93%
Epoch    21: reducing learning rate of group 0 to 1.0000e-05.


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 27 | Epoch Time: 9m 46s
	Train Loss: 0.730 | Train Acc @1:  68.11% | Train Acc @5:  68.11%
	Valid Loss: 0.780 | Valid Acc @1:  64.96% | Valid Acc @5:  64.96%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 28 | Epoch Time: 9m 47s
	Train Loss: 0.727 | Train Acc @1:  68.11% | Train Acc @5:  68.11%
	Valid Loss: 0.780 | Valid Acc @1:  64.74% | Valid Acc @5:  64.74%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 29 | Epoch Time: 9m 46s
	Train Loss: 0.728 | Train Acc @1:  67.87% | Train Acc @5:  67.87%
	Valid Loss: 0.780 | Valid Acc @1:  64.90% | Valid Acc @5:  64.90%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 30 | Epoch Time: 9m 46s
	Train Loss: 0.728 | Train Acc @1:  68.13% | Train Acc @5:  68.13%
	Valid Loss: 0.780 | Valid Acc @1:  64.87% | Valid Acc @5:  64.87%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 31 | Epoch Time: 9m 46s
	Train Loss: 0.729 | Train Acc @1:  67.95% | Train Acc @5:  67.95%
	Valid Loss: 0.780 | Valid Acc @1:  64.98% | Valid Acc @5:  64.98%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 32 | Epoch Time: 9m 47s
	Train Loss: 0.727 | Train Acc @1:  67.90% | Train Acc @5:  67.90%
	Valid Loss: 0.780 | Valid Acc @1:  64.98% | Valid Acc @5:  64.98%
Epoch    27: reducing learning rate of group 0 to 1.0000e-06.


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 33 | Epoch Time: 9m 46s
	Train Loss: 0.726 | Train Acc @1:  68.05% | Train Acc @5:  68.05%
	Valid Loss: 0.780 | Valid Acc @1:  65.01% | Valid Acc @5:  65.01%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 34 | Epoch Time: 9m 46s
	Train Loss: 0.728 | Train Acc @1:  68.10% | Train Acc @5:  68.10%
	Valid Loss: 0.780 | Valid Acc @1:  65.01% | Valid Acc @5:  65.01%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 35 | Epoch Time: 9m 46s
	Train Loss: 0.728 | Train Acc @1:  67.89% | Train Acc @5:  67.89%
	Valid Loss: 0.780 | Valid Acc @1:  65.01% | Valid Acc @5:  65.01%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 36 | Epoch Time: 9m 46s
	Train Loss: 0.729 | Train Acc @1:  67.91% | Train Acc @5:  67.91%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 37 | Epoch Time: 9m 47s
	Train Loss: 0.727 | Train Acc @1:  68.11% | Train Acc @5:  68.11%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 38 | Epoch Time: 9m 47s
	Train Loss: 0.725 | Train Acc @1:  68.51% | Train Acc @5:  68.51%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%
Epoch    33: reducing learning rate of group 0 to 1.0000e-07.


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 39 | Epoch Time: 9m 47s
	Train Loss: 0.727 | Train Acc @1:  68.14% | Train Acc @5:  68.14%
	Valid Loss: 0.779 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 40 | Epoch Time: 9m 46s
	Train Loss: 0.728 | Train Acc @1:  68.17% | Train Acc @5:  68.17%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 41 | Epoch Time: 9m 46s
	Train Loss: 0.725 | Train Acc @1:  68.35% | Train Acc @5:  68.35%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 42 | Epoch Time: 9m 46s
	Train Loss: 0.725 | Train Acc @1:  68.46% | Train Acc @5:  68.46%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 43 | Epoch Time: 9m 46s
	Train Loss: 0.727 | Train Acc @1:  67.57% | Train Acc @5:  67.57%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 44 | Epoch Time: 9m 46s
	Train Loss: 0.729 | Train Acc @1:  68.08% | Train Acc @5:  68.08%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%
Epoch    39: reducing learning rate of group 0 to 1.0000e-08.


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 45 | Epoch Time: 9m 47s
	Train Loss: 0.729 | Train Acc @1:  68.25% | Train Acc @5:  68.25%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 46 | Epoch Time: 9m 47s
	Train Loss: 0.726 | Train Acc @1:  68.01% | Train Acc @5:  68.01%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 47 | Epoch Time: 9m 47s
	Train Loss: 0.725 | Train Acc @1:  68.35% | Train Acc @5:  68.35%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]

Epoch: 48 | Epoch Time: 9m 47s
	Train Loss: 0.726 | Train Acc @1:  68.27% | Train Acc @5:  68.27%
	Valid Loss: 0.780 | Valid Acc @1:  65.04% | Valid Acc @5:  65.04%


  0%|          | 0/260 [00:00<?, ?it/s]