In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import Compose, ToTensor, Lambda, Resize, Normalize
from PIL import Image, ImageDraw
from torch.utils.data import DataLoader, Dataset
import pandas as pd
import numpy as np

from sklearn.metrics import accuracy_score,f1_score

from tqdm import tqdm
from transformers import ViTForImageClassification, TrainingArguments, Trainer

## Finetuning Vision Transformer (ViT) for Image Classification

### Parameter

In [2]:
DIRECTROY = 'data'
MODEL_PATH = 'models'
IMG_SIZE = 224
BATCH_SIZE = 32
EPOCHS = 50
LR = 0.0001
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
device

device(type='cuda')

In [4]:
df_train = pd.read_csv(f'{DIRECTROY}/reduced_train.csv') 
df_test = pd.read_csv(f'{DIRECTROY}/reduced_test.csv') 
num_classes = len(df_train['newid'].unique())

In [5]:
df_test_public = df_test[df_test['Usage'] == 'Public']
df_test_private = df_test[df_test['Usage'] == 'Private']

In [6]:
len(df_test_public), len(df_test_private)

(2026, 4052)

# Divide the train into multiple chunks

Due to lack of ram, i will have to divide the dataloader into multiple dataloader and save it in SSD.

In [7]:
image_transforms = Compose([
    Resize((IMG_SIZE, IMG_SIZE)),
    ToTensor(), 
    Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
])

In [8]:
from transformers import CLIPTokenizerFast

### Custom Dataset
Creating image with border like in the EDA

In [9]:
class CustomDataset(Dataset):
    def __init__(self, df, transforms, directory):
        self.tokenizer =  CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch16")
        self.df = df
        self.transforms = transforms
        self.directory = directory
        self.labels = torch.Tensor(df['newid'].values).long()
        self.imgs = torch.cat([ self.transforms(self.resize_img(Image.open(f'{DIRECTROY}/{self.directory}/{x}')).convert('RGB')).half().reshape(1,3,IMG_SIZE,IMG_SIZE) for x in tqdm(df['name'].values)])
        self.tokenized = self.tokenizer(df['label'].tolist(), padding=True, truncation=True, return_tensors="pt")
        self.input_ids = self.tokenized['input_ids']
        self.attention_mask = self.tokenized['attention_mask']
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        img = self.imgs[idx]
        label = self.labels[idx]
        input_ids = self.input_ids[idx]
        attention_mask = self.attention_mask[idx]
        return img, label, input_ids, attention_mask
    
    @staticmethod
    def resize_img(a):
        w, h = a.size
        if w < h:
            scale = h/IMG_SIZE
            w = int(w/scale)
            h = IMG_SIZE
            a = a.resize((w, h))
            lb= np.array([a.load()[0,x] for x in range(h)])
            rb = np.array([a.load()[w-1,x] for x in range(h)])
            lb = lb.mean(axis=0).astype('uint8')
            rb = rb.mean(axis=0).astype('uint8')
            pic = Image.new('RGB', (h, h), color = (255, 255, 255))
            imgl = Image.new('RGB', (h//2, h), color = tuple(lb))
            imgr = Image.new('RGB', (h//2, h), color = tuple(rb))
            
            pic.paste(imgl, (0, 0))
            pic.paste(imgr, (h//2, 0))
            
            pic.paste(a, (h//2-w//2, 0))

        elif w>h:
            scale = w/IMG_SIZE
            h = int(h/scale)
            w = IMG_SIZE
            a = a.resize((w, h))
            
            lb= np.array([a.load()[x,0] for x in range(w)])
            rb = np.array([a.load()[x,h-1] for x in range(w)])
            lb = lb.mean(axis=0).astype('uint8')
            rb = rb.mean(axis=0).astype('uint8')
            
            pic = Image.new('RGB', (w, w), color = (255, 255, 255))
            imgl = Image.new('RGB', (w, w//2), color = tuple(lb))
            imgr = Image.new('RGB', (w, w//2), color = tuple(rb))
            
            pic.paste(imgl, (0, 0))
            pic.paste(imgr, (0, w//2))
            
            pic.paste(a, (0, w//2-h//2))

        else:
            a = a.resize((IMG_SIZE, IMG_SIZE))
            pic = a
        return pic

In [10]:
df_train = df_train.sample(frac=1).reset_index(drop=True)
df_test_public = df_test_public.sample(frac=1).reset_index(drop=True)
df_test_private = df_test_private.sample(frac=1).reset_index(drop=True)

In [11]:
len(df_train['newid'].unique()), len(df_test['newid'].unique())

(649, 649)

In [12]:
df_train.head()

Unnamed: 0,name,class,group,label,newid
0,86940.jpg,6048,195,warship leog set,330
1,10895.jpg,1363,23,blue workout short,90
2,49510.jpg,4164,145,white earphones,364
3,40239.jpg,3735,60,pink luggage with vertical stripes,638
4,8492.jpg,1107,23,black jogging trousers with white stripe,21


In [37]:
import math

In [13]:
train_dataset = CustomDataset(df_train, image_transforms, 'train')
torch.save(train_dataset, f'{DIRECTROY}/train_dataset/train_dataset_reduced_prompts.pth')

100%|██████████| 29243/29243 [06:16<00:00, 77.66it/s]


In [38]:
for i in range(math.ceil(len(df_train)/18192)):
    train_dataset = CustomDataset(df_train[i*18192:(i+1)*18192], image_transforms, 'train')
    torch.save(train_dataset, f'{DIRECTROY}/train_dataset/train_dataset_reduced_{i}.pth')
    del train_dataset

100%|██████████| 18192/18192 [04:03<00:00, 74.67it/s]
100%|██████████| 11051/11051 [02:23<00:00, 76.78it/s]


In [39]:
for i in range(math.ceil(len(df_test)/18192)):
    test_dataset = CustomDataset(df_test[i*18192:(i+1)*18192], image_transforms, 'test')
    torch.save(test_dataset, f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_{i}.pth')
    del test_dataset

# for i in range(math.ceil(len(df_test_private)/16192)):
#     test_dataset = CustomDataset(df_test_private[i*16192:(i+1)*16192], image_transforms, 'test')
#     torch.save(test_dataset, f'{DIRECTROY}/test_private_dataset/test_private_reduced_dataset_{i}.pth')
#     del test_dataset


100%|██████████| 6078/6078 [01:08<00:00, 88.28it/s] 


### Load ViT

In [40]:
model = ViTForImageClassification.from_pretrained('google/vit-base-patch16-224-in21k', num_labels=num_classes)
model = model.to(device)


Some weights of ViTForImageClassification were not initialized from the model checkpoint at google/vit-base-patch16-224-in21k and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [41]:
model.parameters

<bound method Module.parameters of ViTForImageClassification(
  (vit): ViTModel(
    (embeddings): ViTEmbeddings(
      (patch_embeddings): ViTPatchEmbeddings(
        (projection): Conv2d(3, 768, kernel_size=(16, 16), stride=(16, 16))
      )
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): ViTEncoder(
      (layer): ModuleList(
        (0-11): 12 x ViTLayer(
          (attention): ViTAttention(
            (attention): ViTSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
            (output): ViTSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.0, inplace=False)
            )
          )
          (intermediate): ViTIntermediate(
      

### Training

In [42]:
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr = LR)
scheduler = optim.lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.1, total_iters=EPOCHS)

In [44]:
max_accuracy = 0.0

for epoch in range(EPOCHS):
    model.train()
    train_loss = 0.0
    
    # Training loop
    print('Training epoch:', epoch+1)
    len_train = 0
    for i in range(2):
        try:
            train_dataset = torch.load(f'{DIRECTROY}/train_dataset/train_dataset_reduced_{i}.pth')
            train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        except:
            break
        
        for inputs, labels, input_ids, attention_mask in tqdm(train_dataloader, desc=f'Batch {i+1}/{2}'):
            optimizer.zero_grad()
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss = criterion(outputs.logits, labels)

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
        len_train += len(train_dataset)
        del train_dataset
        
    scheduler.step()    
    train_loss/=len_train
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {train_loss}')
    
    eval_loss = 0.0
    model.eval()
    
    true_labels = []
    pred_labels = []
    
    print('Evaluating epoch:', epoch+1)
    with torch.no_grad():
        len_test = 0
        for i in range(1):
            try:
                test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_{i}.pth')
                test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
            except:
                break

            for inputs, labels, input_ids, attention_mask in tqdm(test_dataloader):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                outputs = model(inputs)
                loss = criterion(outputs.logits, labels).to(device)
                eval_loss += loss.item()
                
                outputs = torch.argmax(outputs.logits, 1).flatten().cpu().numpy()
                labels = labels.flatten().cpu().numpy()
                
                true_labels.extend(labels)
                pred_labels.extend(outputs)
            
            len_test += len(test_dataset)
            del test_dataset
        
        print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {eval_loss/len_test}')
        print(f'Accuracy: {accuracy_score(true_labels, pred_labels)}')
        print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
        print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')
        if accuracy_score(true_labels, pred_labels) > max_accuracy:
            max_accuracy = accuracy_score(true_labels, pred_labels)
            torch.save(model.state_dict(), f'{MODEL_PATH}/vit_reduced_model_{epoch+1}.pth')
            torch.save(optimizer.state_dict(), f'{MODEL_PATH}/optimizer/vit_reduced_optimizer_{epoch+1}.pth')
            

Training epoch: 1


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:24<00:00,  4.10it/s]


Epoch 1/50, Loss: 0.14573990254682134
Evaluating epoch: 1


100%|██████████| 190/190 [00:17<00:00, 11.15it/s]


Epoch 1/50, Loss: 0.09127754226018034
Accuracy: 0.5378413951957881
F1 Score Weighted: 0.45008926881853034
F1 Score Macro: 0.38259801624946366
Training epoch: 2


Batch 1/2: 100%|██████████| 569/569 [02:19<00:00,  4.09it/s]
Batch 2/2: 100%|██████████| 346/346 [01:24<00:00,  4.08it/s]


Epoch 2/50, Loss: 0.049591023773663236
Evaluating epoch: 2


100%|██████████| 190/190 [00:17<00:00, 11.01it/s]


Epoch 2/50, Loss: 0.03428297731369567
Accuracy: 0.7777229351760447
F1 Score Weighted: 0.7493998221332343
F1 Score Macro: 0.7244291157540531
Training epoch: 3


Batch 1/2: 100%|██████████| 569/569 [02:18<00:00,  4.11it/s]
Batch 2/2: 100%|██████████| 346/346 [01:24<00:00,  4.09it/s]


Epoch 3/50, Loss: 0.014537394342584975
Evaluating epoch: 3


100%|██████████| 190/190 [00:16<00:00, 11.22it/s]


Epoch 3/50, Loss: 0.022389712749283815
Accuracy: 0.8311944718657454
F1 Score Weighted: 0.8195922461130848
F1 Score Macro: 0.8113776621747909
Training epoch: 4


Batch 1/2: 100%|██████████| 569/569 [02:20<00:00,  4.04it/s]
Batch 2/2: 100%|██████████| 346/346 [01:25<00:00,  4.05it/s]


Epoch 4/50, Loss: 0.005975931678778153
Evaluating epoch: 4


100%|██████████| 190/190 [00:17<00:00, 11.05it/s]


Epoch 4/50, Loss: 0.018364740447970197
Accuracy: 0.8483053636064495
F1 Score Weighted: 0.8445080940737778
F1 Score Macro: 0.8404901268621798
Training epoch: 5


Batch 1/2: 100%|██████████| 569/569 [02:17<00:00,  4.14it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.13it/s]


Epoch 5/50, Loss: 0.003213073838799881
Evaluating epoch: 5


100%|██████████| 190/190 [00:16<00:00, 11.24it/s]


Epoch 5/50, Loss: 0.01796701523791306
Accuracy: 0.8547219480092135
F1 Score Weighted: 0.8522202059222657
F1 Score Macro: 0.848309361242108
Training epoch: 6


Batch 1/2: 100%|██████████| 569/569 [02:18<00:00,  4.12it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.12it/s]


Epoch 6/50, Loss: 0.002325428578635351
Evaluating epoch: 6


100%|██████████| 190/190 [00:16<00:00, 11.22it/s]


Epoch 6/50, Loss: 0.02039609242790857
Accuracy: 0.8382691674893057
F1 Score Weighted: 0.8360674552399182
F1 Score Macro: 0.8354821041042875
Training epoch: 7


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.17it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.16it/s]


Epoch 7/50, Loss: 0.0022506316160632277
Evaluating epoch: 7


100%|██████████| 190/190 [00:16<00:00, 11.20it/s]


Epoch 7/50, Loss: 0.02031139097494901
Accuracy: 0.8433695294504772
F1 Score Weighted: 0.8412348748248998
F1 Score Macro: 0.8412909084723578
Training epoch: 8


Batch 1/2: 100%|██████████| 569/569 [02:17<00:00,  4.15it/s]
Batch 2/2: 100%|██████████| 346/346 [01:22<00:00,  4.17it/s]


Epoch 8/50, Loss: 0.0011626219869239475
Evaluating epoch: 8


100%|██████████| 190/190 [00:16<00:00, 11.30it/s]


Epoch 8/50, Loss: 0.019108110456995595
Accuracy: 0.8538993089832182
F1 Score Weighted: 0.8525541394994909
F1 Score Macro: 0.8532708808379842
Training epoch: 9


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:22<00:00,  4.17it/s]


Epoch 9/50, Loss: 0.0011413313922211561
Evaluating epoch: 9


100%|██████████| 190/190 [00:16<00:00, 11.24it/s]


Epoch 9/50, Loss: 0.023172094380926166
Accuracy: 0.8287265547877591
F1 Score Weighted: 0.8262422323615807
F1 Score Macro: 0.8299175787407709
Training epoch: 10


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.17it/s]


Epoch 10/50, Loss: 0.0013352026106759835
Evaluating epoch: 10


100%|██████████| 190/190 [00:16<00:00, 11.20it/s]


Epoch 10/50, Loss: 0.02146194301660235
Accuracy: 0.8456729187232642
F1 Score Weighted: 0.8425024241692859
F1 Score Macro: 0.8415448775537622
Training epoch: 11


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.17it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.16it/s]


Epoch 11/50, Loss: 0.0008156375936115467
Evaluating epoch: 11


100%|██████████| 190/190 [00:16<00:00, 11.29it/s]


Epoch 11/50, Loss: 0.02361547756068563
Accuracy: 0.8338269167489306
F1 Score Weighted: 0.8309433018620547
F1 Score Macro: 0.8292314936779909
Training epoch: 12


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.16it/s]


Epoch 12/50, Loss: 0.0011516809015841608
Evaluating epoch: 12


100%|██████████| 190/190 [00:16<00:00, 11.31it/s]


Epoch 12/50, Loss: 0.02105271387819229
Accuracy: 0.8455083909180652
F1 Score Weighted: 0.8434422117626995
F1 Score Macro: 0.8428856300773334
Training epoch: 13


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:22<00:00,  4.17it/s]


Epoch 13/50, Loss: 0.000836544698308008
Evaluating epoch: 13


100%|██████████| 190/190 [00:16<00:00, 11.29it/s]


Epoch 13/50, Loss: 0.02214386651474632
Accuracy: 0.8402435011516947
F1 Score Weighted: 0.8383634082986706
F1 Score Macro: 0.8371327771418673
Training epoch: 14


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.17it/s]


Epoch 14/50, Loss: 0.0009671804546215652
Evaluating epoch: 14


100%|██████████| 190/190 [00:16<00:00, 11.28it/s]


Epoch 14/50, Loss: 0.0223157157371873
Accuracy: 0.8402435011516947
F1 Score Weighted: 0.8388676751253572
F1 Score Macro: 0.8388778709307733
Training epoch: 15


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:22<00:00,  4.17it/s]


Epoch 15/50, Loss: 0.0010105260511310177
Evaluating epoch: 15


100%|██████████| 190/190 [00:16<00:00, 11.23it/s]


Epoch 15/50, Loss: 0.022713002445395054
Accuracy: 0.8374465284633102
F1 Score Weighted: 0.8354437933383115
F1 Score Macro: 0.8331318983341596
Training epoch: 16


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.17it/s]
Batch 2/2: 100%|██████████| 346/346 [01:22<00:00,  4.18it/s]


Epoch 16/50, Loss: 0.0006838950092306457
Evaluating epoch: 16


100%|██████████| 190/190 [00:16<00:00, 11.26it/s]


Epoch 16/50, Loss: 0.02193513100435453
Accuracy: 0.8463310299440605
F1 Score Weighted: 0.8433556525570296
F1 Score Macro: 0.8433211919180104
Training epoch: 17


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.15it/s]
Batch 2/2: 100%|██████████| 346/346 [01:25<00:00,  4.06it/s]


Epoch 17/50, Loss: 0.0003850296496785661
Evaluating epoch: 17


100%|██████████| 190/190 [00:17<00:00, 11.17it/s]


Epoch 17/50, Loss: 0.02102273195105819
Accuracy: 0.8550510036196117
F1 Score Weighted: 0.8529228373815335
F1 Score Macro: 0.8521525491461085
Training epoch: 18


Batch 1/2: 100%|██████████| 569/569 [02:17<00:00,  4.14it/s]
Batch 2/2: 100%|██████████| 346/346 [01:22<00:00,  4.17it/s]


Epoch 18/50, Loss: 0.0001986946664409073
Evaluating epoch: 18


100%|██████████| 190/190 [00:17<00:00, 11.05it/s]


Epoch 18/50, Loss: 0.02241643195885732
Accuracy: 0.8496215860480422
F1 Score Weighted: 0.8462687210784139
F1 Score Macro: 0.8455719290787578
Training epoch: 19


Batch 1/2: 100%|██████████| 569/569 [02:17<00:00,  4.13it/s]
Batch 2/2: 100%|██████████| 346/346 [01:24<00:00,  4.11it/s]


Epoch 19/50, Loss: 0.0008832855096085872
Evaluating epoch: 19


100%|██████████| 190/190 [00:17<00:00, 11.12it/s]


Epoch 19/50, Loss: 0.021663261488078185
Accuracy: 0.8506087528792365
F1 Score Weighted: 0.8482072667913492
F1 Score Macro: 0.8471723176138761
Training epoch: 20


Batch 1/2: 100%|██████████| 569/569 [02:17<00:00,  4.13it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.14it/s]


Epoch 20/50, Loss: 0.00041185608493791645
Evaluating epoch: 20


100%|██████████| 190/190 [00:17<00:00, 11.14it/s]


Epoch 20/50, Loss: 0.022260042544477013
Accuracy: 0.8530766699572228
F1 Score Weighted: 0.8500867093721122
F1 Score Macro: 0.8495926349223979
Training epoch: 21


Batch 1/2: 100%|██████████| 569/569 [02:17<00:00,  4.14it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.16it/s]


Epoch 21/50, Loss: 0.0007386841103199447
Evaluating epoch: 21


100%|██████████| 190/190 [00:16<00:00, 11.31it/s]


Epoch 21/50, Loss: 0.024322647407072358
Accuracy: 0.834814083580125
F1 Score Weighted: 0.8324470635322984
F1 Score Macro: 0.8319920034806033
Training epoch: 22


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.17it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.17it/s]


Epoch 22/50, Loss: 0.0005174366625659033
Evaluating epoch: 22


100%|██████████| 190/190 [00:16<00:00, 11.23it/s]


Epoch 22/50, Loss: 0.02113734166695319
Accuracy: 0.8566962816716025
F1 Score Weighted: 0.8534438943267805
F1 Score Macro: 0.8523601114677894
Training epoch: 23


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.17it/s]


Epoch 23/50, Loss: 0.0002485097833241033
Evaluating epoch: 23


100%|██████████| 190/190 [00:16<00:00, 11.25it/s]


Epoch 23/50, Loss: 0.019615839466859575
Accuracy: 0.8626192826587693
F1 Score Weighted: 0.8606000346746759
F1 Score Macro: 0.8600062762829117
Training epoch: 24


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:24<00:00,  4.12it/s]


Epoch 24/50, Loss: 0.00013212004166517682
Evaluating epoch: 24


100%|██████████| 190/190 [00:17<00:00, 11.17it/s]


Epoch 24/50, Loss: 0.02124054182476207
Accuracy: 0.8558736426456071
F1 Score Weighted: 0.854254868723441
F1 Score Macro: 0.8537816764958248
Training epoch: 25


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.16it/s]


Epoch 25/50, Loss: 0.0007282051153206233
Evaluating epoch: 25


100%|██████████| 190/190 [00:16<00:00, 11.23it/s]


Epoch 25/50, Loss: 0.024818544457883396
Accuracy: 0.8381046396841066
F1 Score Weighted: 0.8350292752619815
F1 Score Macro: 0.8368754123792098
Training epoch: 26


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.17it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.15it/s]


Epoch 26/50, Loss: 0.00031952447061521223
Evaluating epoch: 26


100%|██████████| 190/190 [00:17<00:00, 11.17it/s]


Epoch 26/50, Loss: 0.02121683032688289
Accuracy: 0.861796643632774
F1 Score Weighted: 0.8595228987367851
F1 Score Macro: 0.8593143676316556
Training epoch: 27


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.16it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.15it/s]


Epoch 27/50, Loss: 0.0001357601484809945
Evaluating epoch: 27


100%|██████████| 190/190 [00:16<00:00, 11.20it/s]


Epoch 27/50, Loss: 0.02224517815602562
Accuracy: 0.859986837775584
F1 Score Weighted: 0.8574346328365929
F1 Score Macro: 0.8556761613072088
Training epoch: 28


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.17it/s]
Batch 2/2: 100%|██████████| 346/346 [01:22<00:00,  4.18it/s]


Epoch 28/50, Loss: 0.0003794377338279796
Evaluating epoch: 28


100%|██████████| 190/190 [00:17<00:00, 11.15it/s]


Epoch 28/50, Loss: 0.024186335061109857
Accuracy: 0.8405725567620927
F1 Score Weighted: 0.8385789577932882
F1 Score Macro: 0.838112621940565
Training epoch: 29


Batch 1/2: 100%|██████████| 569/569 [02:16<00:00,  4.17it/s]
Batch 2/2: 100%|██████████| 346/346 [01:23<00:00,  4.16it/s]


Epoch 29/50, Loss: 0.00024532172882561195
Evaluating epoch: 29


100%|██████████| 190/190 [00:16<00:00, 11.22it/s]


Epoch 29/50, Loss: 0.022219705306231137
Accuracy: 0.8538993089832182
F1 Score Weighted: 0.85176557305243
F1 Score Macro: 0.8508067435974935
Training epoch: 30


Batch 1/2:  31%|███       | 177/569 [00:42<01:34,  4.14it/s]


KeyboardInterrupt: 

In [None]:
model.load_state_dict(torch.load('models/vit_reduced_model_23.pth'))

<All keys matched successfully>

In [None]:
with torch.no_grad():
    true_labels = []
    pred_labels = []
    len_test = 0
    for i in range(1):
        try:
            test_dataset = torch.load(f'{DIRECTROY}/test_public_dataset/test_public_reduced_dataset_{i}.pth')
            test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        except:
            break

        for inputs, labels in tqdm(test_dataloader):
            inputs = inputs.to(device)
            labels = labels.to(device)
            
            outputs = model(inputs)
            loss = criterion(outputs.logits, labels).to(device)
            eval_loss += loss.item()
            
            outputs = torch.argmax(outputs.logits, 1).flatten().cpu().numpy()
            labels = labels.flatten().cpu().numpy()
            
            true_labels.extend(labels)
            pred_labels.extend(outputs)
        
        len_test += len(test_dataset)
        del test_dataset
    print(f'Epoch {epoch+1}/{EPOCHS}, Loss: {eval_loss/len_test}')
    print(f'Accuracy: {accuracy_score(true_labels, pred_labels)}')
    print(f'F1 Score Weighted: {f1_score(true_labels, pred_labels, average="weighted")}')
    print(f'F1 Score Macro: {f1_score(true_labels, pred_labels, average="macro")}')

100%|██████████| 186/186 [00:16<00:00, 10.98it/s]

Epoch 93/100, Loss: 0.0652661323842611
Accuracy: 0.7968382105617222
F1 Score Weighted: 0.7896039842586119
F1 Score Macro: 0.7896039842586119



