In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import Compose, ToTensor, Lambda, Resize, Normalize
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np

from sklearn.metrics import accuracy_score,f1_score

from tqdm import tqdm
from transformers import ViTForImageClassification

## Finetuning Vision Transformer (ViT) for Image Classification

### Parameter

In [5]:
DIRECTROY = 'data'
MODEL_PATH = 'models'
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 100
LR = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
device

device(type='cuda')

In [6]:
df_train = pd.read_csv(f'{DIRECTROY}/reduced_train.csv') 
df_test = pd.read_csv(f'{DIRECTROY}/reduced_test.csv') 
num_classes = len(df_train['newid'].unique())

In [5]:
df_test_public = df_test[df_test['Usage'] == 'Public']
df_test_private = df_test[df_test['Usage'] == 'Private']

In [6]:
len(df_test_public), len(df_test_private)

(2026, 4052)

# Divide the train into multiple chunks

Due to lack of ram, i will have to divide the dataloader into multiple dataloader and save it in SSD.

In [13]:
image_transforms = Compose([
    Resize((IMG_SIZE, IMG_SIZE)),
    ToTensor(), 
    Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
])

In [6]:
from transformers import CLIPTokenizerFast

### Custom Dataset
Creating image with border like in the EDA

In [14]:
class CustomDataset(Dataset):
    def __init__(self, df, transforms, directory):
        self.tokenizer =  CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")
        self.df = df
        self.transforms = transforms
        self.directory = directory
        self.labels = torch.Tensor(df['newid'].values).long()
        self.imgs = torch.cat([ self.transforms(self.resize_img(Image.open(f'{DIRECTROY}/{self.directory}/{x}')).convert('RGB')).half().reshape(1,3,IMG_SIZE,IMG_SIZE) for x in tqdm(df['name'].values)])
        self.tokenized = self.tokenizer(df['label'].tolist(), padding=True, truncation=True, return_tensors="pt")
        self.input_ids = self.tokenized['input_ids']
        self.attention_mask = self.tokenized['attention_mask']
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        label = self.labels[idx]
        input_ids = self.input_ids[idx]
        attention_mask = self.attention_mask[idx]
        return img, label, input_ids, attention_mask
    
    @staticmethod
    def resize_img(a):
        w, h = a.size
        if w < h:
            scale = h/IMG_SIZE
            w = int(w/scale)
            h = IMG_SIZE
            a = a.resize((w, h))
            lb= np.array([a.load()[0,x] for x in range(h)])
            rb = np.array([a.load()[w-1,x] for x in range(h)])
            lb = lb.mean(axis=0).astype('uint8')
            rb = rb.mean(axis=0).astype('uint8')
            pic = Image.new('RGB', (h, h), color = (255, 255, 255))
            imgl = Image.new('RGB', (h//2, h), color = tuple(lb))
            imgr = Image.new('RGB', (h//2, h), color = tuple(rb))
            
            pic.paste(imgl, (0, 0))
            pic.paste(imgr, (h//2, 0))
            
            pic.paste(a, (h//2-w//2, 0))

        elif w>h:
            scale = w/IMG_SIZE
            h = int(h/scale)
            w = IMG_SIZE
            a = a.resize((w, h))
            
            lb= np.array([a.load()[x,0] for x in range(w)])
            rb = np.array([a.load()[x,h-1] for x in range(w)])
            lb = lb.mean(axis=0).astype('uint8')
            rb = rb.mean(axis=0).astype('uint8')
            
            pic = Image.new('RGB', (w, w), color = (255, 255, 255))
            imgl = Image.new('RGB', (w, w//2), color = tuple(lb))
            imgr = Image.new('RGB', (w, w//2), color = tuple(rb))
            
            pic.paste(imgl, (0, 0))
            pic.paste(imgr, (0, w//2))
            
            pic.paste(a, (0, w//2-h//2))

        else:
            a = a.resize((IMG_SIZE, IMG_SIZE))
            pic = a
        return pic

In [10]:
df_train = df_train.sample(frac=1).reset_index(drop=True)
df_test_public = df_test_public.sample(frac=1).reset_index(drop=True)
df_test_private = df_test_private.sample(frac=1).reset_index(drop=True)

In [11]:
len(df_train['newid'].unique()), len(df_test['newid'].unique())

(649, 649)

In [12]:
df_train.head()

Unnamed: 0,name,class,group,label,newid
0,86940.jpg,6048,195,warship leog set,330
1,10895.jpg,1363,23,blue workout short,90
2,49510.jpg,4164,145,white earphones,364
3,40239.jpg,3735,60,pink luggage with vertical stripes,638
4,8492.jpg,1107,23,black jogging trousers with white stripe,21


In [8]:
import math

In [38]:
for i in range(math.ceil(len(df_train)/18192)):
    train_dataset = CustomDataset(df_train[i*18192:(i+1)*18192], image_transforms, 'train')
    torch.save(train_dataset, f'{DIRECTROY}/train_dataset/train_dataset_reduced_{i}.pth')
    del train_dataset

100%|██████████| 18192/18192 [04:03<00:00, 74.67it/s]
100%|██████████| 11051/11051 [02:23<00:00, 76.78it/s]


In [39]:
for i in range(math.ceil(len(df_test)/18192)):
    test_dataset = CustomDataset(df_test[i*18192:(i+1)*18192], image_transforms, 'test')
    torch.save(test_dataset, f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_{i}.pth')
    del test_dataset

# for i in range(math.ceil(len(df_test_private)/16192)):
#     test_dataset = CustomDataset(df_test_private[i*16192:(i+1)*16192], image_transforms, 'test')
#     torch.save(test_dataset, f'{DIRECTROY}/test_private_dataset/test_private_reduced_dataset_{i}.pth')
#     del test_dataset


100%|██████████| 6078/6078 [01:08<00:00, 88.28it/s] 


### Load ViT

In [7]:
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=num_classes)
model = model.to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
model.parameters

<bound method Module.parameters of ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
      

### Training

In [9]:
model.load_state_dict(torch.load('models/vit_reduced_model_50.pth'))

<All keys matched successfully>

In [10]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = LR)
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=EPOCHS)

In [11]:
optimizer.load_state_dict(torch.load('models/optimizer/vit_reduced_optimizer_50.pth'))

### Train with augumentation dataset

In [15]:
max_accuracy = 0.0

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    
    # Training loop
    print('Training epoch:', epoch+1)
    len_train = 0
    
    for i in range(2):
        train_dataset = torch.load(f'{DIRECTROY}/train_dataset/train_dataset_reduced_aug_{i}.pth')
        train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

    
        for inputs, labels, input_ids, attention_mask in tqdm(train_dataloader):
            optimizer.zero_grad()
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
        len_train += len(train_dataset)
        del train_dataloader
        del train_dataset
    
        
    scheduler.step()    
    train_loss/=len_train
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {train_loss}')
    
    eval_loss = 0.0
    model.eval()
    
    true_labels = []
    pred_labels = []
    
    print('Evaluating epoch:', epoch+1)
    with torch.no_grad():
        len_test = 0
        
        test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_0.pth')
        test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
          

        for inputs, labels, input_ids, attention_mask in tqdm(test_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs.logits, labels).to(device)
            eval_loss += loss.item()
            
            outputs = torch.argmax(outputs.logits, 1).flatten().cpu().numpy()
            labels = labels.flatten().cpu().numpy()
            
            true_labels.extend(labels)
            pred_labels.extend(outputs)
        
        len_test += len(test_dataset)
        
        
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {eval_loss/len_test}')
    print(f'Accuracy: {accuracy_score(true_labels, pred_labels)}')
    print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
    print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')
    if accuracy_score(true_labels, pred_labels) > max_accuracy:
        max_accuracy = accuracy_score(true_labels, pred_labels)
        torch.save(model.state_dict(), f'{MODEL_PATH}/vit_reduced_aug_model_{epoch+1}.pth')
        torch.save(optimizer.state_dict(), f'{MODEL_PATH}/optimizer/vit_reduced_aug_optimizer_{epoch+1}.pth')
            

Training epoch: 1


100%|██████████| 1137/1137 [04:31<00:00,  4.19it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 1/100, Loss: 0.002826779798214895
Evaluating epoch: 1


100%|██████████| 190/190 [00:16<00:00, 11.19it/s]


Epoch 1/100, Loss: 0.021054488424696403
Accuracy: 0.8622902270483712
F1 Score Weighted: 0.8601472368156395
F1 Score Macro: 0.8592582098358568
Training epoch: 2


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 2/100, Loss: 0.00028310826972598933
Evaluating epoch: 2


100%|██████████| 190/190 [00:16<00:00, 11.24it/s]


Epoch 2/100, Loss: 0.02171462320523213
Accuracy: 0.8672260612043435
F1 Score Weighted: 0.8652154017881426
F1 Score Macro: 0.8646283568520763
Training epoch: 3


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 3/100, Loss: 0.00014897810717638994
Evaluating epoch: 3


100%|██████████| 190/190 [00:16<00:00, 11.24it/s]


Epoch 3/100, Loss: 0.022345452762660558
Accuracy: 0.8673905890095426
F1 Score Weighted: 0.8641620241007695
F1 Score Macro: 0.862487480783788
Training epoch: 4


100%|██████████| 1137/1137 [04:31<00:00,  4.19it/s]
100%|██████████| 1121/1121 [04:27<00:00,  4.19it/s]


Epoch 4/100, Loss: 0.00018699670594224724
Evaluating epoch: 4


100%|██████████| 190/190 [00:16<00:00, 11.18it/s]


Epoch 4/100, Loss: 0.022163705594419582
Accuracy: 0.8655807831523528
F1 Score Weighted: 0.8634557924513893
F1 Score Macro: 0.8623069650926886
Training epoch: 5


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:29<00:00,  4.17it/s]


Epoch 5/100, Loss: 0.00018763876359488294
Evaluating epoch: 5


100%|██████████| 190/190 [00:17<00:00, 11.17it/s]


Epoch 5/100, Loss: 0.022495161939991022
Accuracy: 0.8650871997367555
F1 Score Weighted: 0.8632372460267665
F1 Score Macro: 0.8643853580101796
Training epoch: 6


100%|██████████| 1137/1137 [04:33<00:00,  4.16it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 6/100, Loss: 0.00013123143487120496
Evaluating epoch: 6


100%|██████████| 190/190 [00:17<00:00, 11.15it/s]


Epoch 6/100, Loss: 0.02359859478017728
Accuracy: 0.8631128660743665
F1 Score Weighted: 0.8602880350796863
F1 Score Macro: 0.8583728414249396
Training epoch: 7


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:29<00:00,  4.17it/s]


Epoch 7/100, Loss: 0.00018625325532397932
Evaluating epoch: 7


100%|██████████| 190/190 [00:16<00:00, 11.22it/s]


Epoch 7/100, Loss: 0.024119916365948648
Accuracy: 0.8553800592300099
F1 Score Weighted: 0.8531474828628285
F1 Score Macro: 0.8497608835375504
Training epoch: 8


100%|██████████| 1137/1137 [04:33<00:00,  4.16it/s]
100%|██████████| 1121/1121 [04:29<00:00,  4.16it/s]


Epoch 8/100, Loss: 0.000111856329538486
Evaluating epoch: 8


100%|██████████| 190/190 [00:17<00:00, 11.16it/s]


Epoch 8/100, Loss: 0.023368659889876048
Accuracy: 0.8588351431391905
F1 Score Weighted: 0.8578929757310353
F1 Score Macro: 0.8564587821015583
Training epoch: 9


100%|██████████| 1137/1137 [04:33<00:00,  4.16it/s]
100%|██████████| 1121/1121 [04:29<00:00,  4.17it/s]


Epoch 9/100, Loss: 0.0001437746950989598
Evaluating epoch: 9


100%|██████████| 190/190 [00:16<00:00, 11.20it/s]


Epoch 9/100, Loss: 0.022836436544811305
Accuracy: 0.8644290885159592
F1 Score Weighted: 0.8622584220002074
F1 Score Macro: 0.8608241758129968
Training epoch: 10


100%|██████████| 1137/1137 [04:33<00:00,  4.16it/s]
100%|██████████| 1121/1121 [04:29<00:00,  4.17it/s]


Epoch 10/100, Loss: 0.0001306935879101372
Evaluating epoch: 10


100%|██████████| 190/190 [00:16<00:00, 11.22it/s]


Epoch 10/100, Loss: 0.024621560353433786
Accuracy: 0.859986837775584
F1 Score Weighted: 0.8570578531530691
F1 Score Macro: 0.8564886489678885
Training epoch: 11


100%|██████████| 1137/1137 [04:33<00:00,  4.16it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 11/100, Loss: 0.00014824959449029758
Evaluating epoch: 11


100%|██████████| 190/190 [00:16<00:00, 11.20it/s]


Epoch 11/100, Loss: 0.024481757985738723
Accuracy: 0.8609740046067785
F1 Score Weighted: 0.8584290276642025
F1 Score Macro: 0.8579310990943488
Training epoch: 12


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 12/100, Loss: 7.827206634293126e-05
Evaluating epoch: 12


100%|██████████| 190/190 [00:17<00:00, 11.17it/s]


Epoch 12/100, Loss: 0.024226835530003647
Accuracy: 0.8645936163211583
F1 Score Weighted: 0.8611152887150688
F1 Score Macro: 0.8584788001712161
Training epoch: 13


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:29<00:00,  4.17it/s]


Epoch 13/100, Loss: 0.00010221189263852264
Evaluating epoch: 13


100%|██████████| 190/190 [00:17<00:00, 11.14it/s]


Epoch 13/100, Loss: 0.02520698013995241
Accuracy: 0.859657782165186
F1 Score Weighted: 0.8564553126554206
F1 Score Macro: 0.8545422798880722
Training epoch: 14


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 14/100, Loss: 0.00010692011445443478
Evaluating epoch: 14


100%|██████████| 190/190 [00:17<00:00, 11.14it/s]


Epoch 14/100, Loss: 0.024173537785710332
Accuracy: 0.8688713392563343
F1 Score Weighted: 0.8654248330199644
F1 Score Macro: 0.8636047183165684
Training epoch: 15


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 15/100, Loss: 0.00011315844700379801
Evaluating epoch: 15


100%|██████████| 190/190 [00:17<00:00, 11.11it/s]


Epoch 15/100, Loss: 0.024829338661490934
Accuracy: 0.8642645607107601
F1 Score Weighted: 0.8611641665483447
F1 Score Macro: 0.8599632519767978
Training epoch: 16


100%|██████████| 1137/1137 [04:32<00:00,  4.18it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.18it/s]


Epoch 16/100, Loss: 7.854324203993845e-05
Evaluating epoch: 16


100%|██████████| 190/190 [00:16<00:00, 11.22it/s]


Epoch 16/100, Loss: 0.02472501278765038
Accuracy: 0.8621256992431721
F1 Score Weighted: 0.860131968567929
F1 Score Macro: 0.8586830087661274
Training epoch: 17


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 17/100, Loss: 8.622997230329104e-05
Evaluating epoch: 17


100%|██████████| 190/190 [00:16<00:00, 11.18it/s]


Epoch 17/100, Loss: 0.026666587116017475
Accuracy: 0.8566962816716025
F1 Score Weighted: 0.854364074595975
F1 Score Macro: 0.8542553539535132
Training epoch: 18


100%|██████████| 1137/1137 [04:34<00:00,  4.14it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 18/100, Loss: 8.906489277021337e-05
Evaluating epoch: 18


100%|██████████| 190/190 [00:16<00:00, 11.26it/s]


Epoch 18/100, Loss: 0.025508216656477862
Accuracy: 0.8609740046067785
F1 Score Weighted: 0.858670995257768
F1 Score Macro: 0.8577263549834327
Training epoch: 19


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 19/100, Loss: 0.00012338047167765825
Evaluating epoch: 19


100%|██████████| 190/190 [00:17<00:00, 11.15it/s]


Epoch 19/100, Loss: 0.02567574925553749
Accuracy: 0.8606449489963804
F1 Score Weighted: 0.8585401912165932
F1 Score Macro: 0.8583401292164416
Training epoch: 20


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 20/100, Loss: 4.9878632468256755e-05
Evaluating epoch: 20


100%|██████████| 190/190 [00:16<00:00, 11.21it/s]


Epoch 20/100, Loss: 0.025339972100976955
Accuracy: 0.8624547548535703
F1 Score Weighted: 0.8604829584273334
F1 Score Macro: 0.85793286391611
Training epoch: 21


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:27<00:00,  4.19it/s]


Epoch 21/100, Loss: 9.081845939014396e-05
Evaluating epoch: 21


100%|██████████| 190/190 [00:16<00:00, 11.26it/s]


Epoch 21/100, Loss: 0.02568693759620798
Accuracy: 0.8649226719315565
F1 Score Weighted: 0.8620075608870637
F1 Score Macro: 0.8604634521257619
Training epoch: 22


100%|██████████| 1137/1137 [04:31<00:00,  4.19it/s]
100%|██████████| 1121/1121 [04:27<00:00,  4.19it/s]


Epoch 22/100, Loss: 7.468867001498574e-05
Evaluating epoch: 22


100%|██████████| 190/190 [00:16<00:00, 11.20it/s]


Epoch 22/100, Loss: 0.024964258831823946
Accuracy: 0.8668970055939453
F1 Score Weighted: 0.8651503045844964
F1 Score Macro: 0.8628880215323613
Training epoch: 23


100%|██████████| 1137/1137 [04:31<00:00,  4.19it/s]
100%|██████████| 1121/1121 [04:27<00:00,  4.19it/s]


Epoch 23/100, Loss: 0.00011840606893640604
Evaluating epoch: 23


100%|██████████| 190/190 [00:16<00:00, 11.26it/s]


Epoch 23/100, Loss: 0.0248096645881895
Accuracy: 0.8650871997367555
F1 Score Weighted: 0.8634180724060733
F1 Score Macro: 0.8638042487801688
Training epoch: 24


100%|██████████| 1137/1137 [04:31<00:00,  4.18it/s]
100%|██████████| 1121/1121 [04:27<00:00,  4.19it/s]


Epoch 24/100, Loss: 6.51950856237628e-05
Evaluating epoch: 24


100%|██████████| 190/190 [00:16<00:00, 11.21it/s]


Epoch 24/100, Loss: 0.02560334391695058
Accuracy: 0.8614675880223758
F1 Score Weighted: 0.859417339218577
F1 Score Macro: 0.8580394256623031
Training epoch: 25


100%|██████████| 1137/1137 [04:31<00:00,  4.18it/s]
100%|██████████| 1121/1121 [04:27<00:00,  4.19it/s]


Epoch 25/100, Loss: 9.426892709505834e-05
Evaluating epoch: 25


100%|██████████| 190/190 [00:16<00:00, 11.20it/s]


Epoch 25/100, Loss: 0.025635055535818442
Accuracy: 0.8570253372820007
F1 Score Weighted: 0.8548773284699058
F1 Score Macro: 0.8534195814018185
Training epoch: 26


100%|██████████| 1137/1137 [04:31<00:00,  4.19it/s]
100%|██████████| 1121/1121 [04:29<00:00,  4.16it/s]


Epoch 26/100, Loss: 4.726642203459449e-05
Evaluating epoch: 26


100%|██████████| 190/190 [00:17<00:00, 11.17it/s]


Epoch 26/100, Loss: 0.024703365993892133
Accuracy: 0.8632773938795657
F1 Score Weighted: 0.8620954439029511
F1 Score Macro: 0.8636804153450555
Training epoch: 27


100%|██████████| 1137/1137 [04:33<00:00,  4.16it/s]
100%|██████████| 1121/1121 [04:29<00:00,  4.16it/s]


Epoch 27/100, Loss: 7.570096222862704e-05
Evaluating epoch: 27


100%|██████████| 190/190 [00:17<00:00, 11.14it/s]


Epoch 27/100, Loss: 0.026975989003773713
Accuracy: 0.8542283645936163
F1 Score Weighted: 0.8516598756499314
F1 Score Macro: 0.8540034309766167
Training epoch: 28


100%|██████████| 1137/1137 [04:33<00:00,  4.16it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 28/100, Loss: 6.40216179042223e-05
Evaluating epoch: 28


100%|██████████| 190/190 [00:16<00:00, 11.20it/s]


Epoch 28/100, Loss: 0.027064154745429243
Accuracy: 0.8538993089832182
F1 Score Weighted: 0.8517427445540809
F1 Score Macro: 0.8515827621133464
Training epoch: 29


100%|██████████| 1137/1137 [04:33<00:00,  4.16it/s]
100%|██████████| 1121/1121 [04:29<00:00,  4.17it/s]


Epoch 29/100, Loss: 7.7074763491558e-05
Evaluating epoch: 29


100%|██████████| 190/190 [00:17<00:00, 11.13it/s]


Epoch 29/100, Loss: 0.026398468431317948
Accuracy: 0.8583415597235933
F1 Score Weighted: 0.8558507936866044
F1 Score Macro: 0.856484642947511
Training epoch: 30


100%|██████████| 1137/1137 [04:33<00:00,  4.16it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 30/100, Loss: 8.314025304437252e-05
Evaluating epoch: 30


100%|██████████| 190/190 [00:16<00:00, 11.19it/s]


Epoch 30/100, Loss: 0.025656624415974577
Accuracy: 0.8608094768015795
F1 Score Weighted: 0.858447207009448
F1 Score Macro: 0.856097182825439
Training epoch: 31


100%|██████████| 1137/1137 [04:33<00:00,  4.16it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 31/100, Loss: 5.667606228990327e-05
Evaluating epoch: 31


100%|██████████| 190/190 [00:16<00:00, 11.18it/s]


Epoch 31/100, Loss: 0.02497968253092785
Accuracy: 0.8655807831523528
F1 Score Weighted: 0.8633606197996099
F1 Score Macro: 0.8623247663992026
Training epoch: 32


100%|██████████| 1137/1137 [04:32<00:00,  4.17it/s]
100%|██████████| 1121/1121 [04:28<00:00,  4.17it/s]


Epoch 32/100, Loss: 3.0714596420579514e-05
Evaluating epoch: 32


100%|██████████| 190/190 [00:16<00:00, 11.20it/s]


Epoch 32/100, Loss: 0.02649306217441029
Accuracy: 0.8575189206975979
F1 Score Weighted: 0.8552995968034782
F1 Score Macro: 0.8538429016796815
Training epoch: 33


100%|██████████| 1137/1137 [04:34<00:00,  4.15it/s]
100%|██████████| 1121/1121 [04:29<00:00,  4.17it/s]


Epoch 33/100, Loss: 8.966803801321443e-05
Evaluating epoch: 33


100%|██████████| 190/190 [00:17<00:00, 11.15it/s]


Epoch 33/100, Loss: 0.026145660539844573
Accuracy: 0.8571898650871997
F1 Score Weighted: 0.8549658295503484
F1 Score Macro: 0.8558709849082771
Training epoch: 34


 56%|█████▋    | 641/1137 [02:34<01:59,  4.14it/s]


KeyboardInterrupt: 

In [14]:
len(train_dataset)

36384

In [None]:
model.load_state_dict(torch.load('models/vit_reduced_model_23.pth'))

<All keys matched successfully>

In [None]:
with torch.no_grad():
    true_labels = []
    pred_labels = []
    len_test = 0
    for i in range(1):
        try:
            test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_{i}.pth')
            test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        except:
            break

        for inputs, labels in tqdm(test_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs.logits, labels).to(device)
            eval_loss += loss.item()
            
            outputs = torch.argmax(outputs.logits, 1).flatten().cpu().numpy()
            labels = labels.flatten().cpu().numpy()
            
            true_labels.extend(labels)
            pred_labels.extend(outputs)
        
        len_test += len(test_dataset)
        del test_dataset
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {eval_loss/len_test}')
    print(f'Accuracy: {accuracy_score(true_labels, pred_labels)}')
    print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
    print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')

100%|██████████| 186/186 [00:16<00:00, 10.98it/s]

Epoch 93/100, Loss: 0.0652661323842611
Accuracy: 0.7968382105617222
F1 Score Weighted: 0.7896039842586119
F1 Score Macro: 0.7896039842586119



