In [1]:
import os
import pathlib
import pickle
import random
import tarfile
import time
import uuid
from copy import deepcopy
from typing import Callable, Optional

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import PIL.Image
from PIL import Image
import requests
from sklearn.model_selection import train_test_split
import timm
import torch
import torch.nn as nn
import torchvision.models as models
from torch import optim
from torch.utils.data import DataLoader, random_split, Dataset, Subset
from torchsummary import summary
from torchvision import datasets, transforms
from torchvision.transforms import v2
from tqdm import tqdm

# Pick the compute device once and report it.
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')

device = torch.device("cuda" if train_on_gpu else "cpu")

# Seed every RNG the notebook draws from (augmentations, DataLoader shuffling,
# initialisation of the new classifier head) so runs are repeatable, up to any
# nondeterministic CUDA kernels.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators

CUDA is available!  Training on GPU ...


### Dataloading

In [2]:
# Shared preprocessing constants for both pipelines.
# NOTE(review): the mean/std look like rounded OpenAI-CLIP statistics; confirm they match
# timm.data.resolve_model_data_config(...) for the backbone actually in use.
IMG_SIZE = (448, 448)
NORM_MEAN = (0.4815, 0.4578, 0.4082)
NORM_STD = (0.2686, 0.2613, 0.2758)

# Training pipeline: resize, then light geometric and photometric augmentation.
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.RandomAffine(0, shear=5, scale=(0.8, 1.2)),
    # transforms.RandomGrayscale(p=0.1),  # optional, currently disabled
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

# Evaluation pipeline: deterministic resize + normalisation only (no augmentation).
val_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(NORM_MEAN, NORM_STD),
])

In [3]:
class TransformedDataset(Dataset):
    """Wrap a ``(sample, label)`` dataset and apply ``transform`` to each sample on access.

    Lets the train and validation ``Subset``s of one underlying dataset use different
    preprocessing pipelines. Labels are passed through unchanged.

    Args:
        dataset: Indexable dataset whose items are ``(sample, label)`` pairs.
        transform: Any callable applied to each sample (a torchvision ``Compose``,
            a v2 transform, a timm transform, a plain function, ...). ``None`` returns
            samples untouched.
    """

    def __init__(self, dataset: Dataset, transform: Optional[Callable] = None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        img, label = self.dataset[index]

        # Test against None, not truthiness: a callable that defines __len__/__bool__
        # must still be applied.
        if self.transform is not None:
            img = self.transform(img)

        return img, label

def stratified_split(dataset, val_split=0., random_state=None):
    """Return train/validation index arrays that preserve per-class proportions.

    Args:
        dataset: Object exposing a ``targets`` sequence of class labels
            (e.g. a torchvision ``ImageFolder``).
        val_split: Validation share passed to ``train_test_split`` as ``test_size``
            (a fraction, or an absolute count). ``0`` puts every index in the
            training split and returns an empty validation split.
        random_state: Seed forwarded to ``train_test_split`` so the split can be
            reproduced; ``None`` gives a different split on each call.

    Returns:
        ``(train_indices, val_indices)`` as integer numpy arrays.
    """
    targets = np.asarray(dataset.targets)
    all_indices = np.arange(targets.shape[0])

    # train_test_split rejects test_size == 0, so handle "no validation" explicitly.
    if val_split == 0:
        return all_indices, np.array([], dtype=all_indices.dtype)

    train_indices, val_indices = train_test_split(
        all_indices,
        test_size=val_split,
        stratify=targets,
        random_state=random_state,
    )
    return train_indices, val_indices

In [4]:
# Load the full labelled dataset once; the train/val views below are index subsets of it,
# each with its own transform.
dataset = datasets.ImageFolder(root = "./final_data")

# train_indices, val_indices = stratified_split(dataset, val_split=0.2)

# Load the indices from the saved pickle files so every model uses the same train/val split.
# NOTE: pickle.load can execute arbitrary code; only load split files you produced yourself.
with open('train_indices.pkl', 'rb') as f:
    train_indices = pickle.load(f)

with open('val_indices.pkl', 'rb') as f:
    val_indices = pickle.load(f)

# The saved indices are positions in `dataset`'s sample list. Adding or removing files under
# ./final_data after the split was made silently changes which images they select, so fail
# fast on the inconsistencies that can be detected here.
_split_idx = np.concatenate([np.asarray(train_indices).ravel(), np.asarray(val_indices).ravel()])
assert np.unique(_split_idx).size == _split_idx.size, "saved train/val indices overlap or repeat"
assert _split_idx.size == 0 or (_split_idx.min() >= 0 and _split_idx.max() < len(dataset)), \
    "saved split indices are out of range for ./final_data"
if _split_idx.size != len(dataset):
    print(f'Warning: saved split covers {_split_idx.size} samples but ./final_data has {len(dataset)}')

# Split the dataset into training and validation
train_dataset = Subset(dataset, train_indices)
val_dataset = Subset(dataset, val_indices)

transformed_train = TransformedDataset(train_dataset, transform)
transformed_val = TransformedDataset(val_dataset, val_transform)

In [5]:
# Shuffle only the training stream; validation keeps a fixed order so per-epoch metrics
# are computed over identical batches.
train_loader = DataLoader(dataset=transformed_train, batch_size=32, shuffle=True)
val_loader = DataLoader(dataset=transformed_val, batch_size=32, shuffle=False)

In [6]:
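# # Disabled: materialises every transformed sample and moves each whole split to `device`
# # (for the commented-out evaluate_all() calls in the training loop). Re-enabling it would
# # freeze one random augmentation per training image and needs both splits in device memory.
# # Store all datapoints from transformed_train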
# train_images = []
# train_labels = []
# for i in range(len(transformed_train)):
#     img, label = transformed_train[i]
#     train_images.append(img)
#     train_labels.append(label)

# # Store all datapoints from transformed_val
# val_images = []
# val_labels = []
# for i in range(len(transformed_val)):
#     img, label = transformed_val[i]
#     val_images.append(img)
#     val_labels.append(label)

# train_full = (torch.stack(train_images).to(device), torch.tensor(train_labels).to(device))
# val_full = (torch.stack(val_images).to(device), torch.tensor(val_labels).to(device))

### Model Instantiation

In [7]:
# Define the Vision Transformer model
class CustomViTModel(nn.Module):
    """timm vision-transformer backbone (frozen) with a new trainable MLP classifier head.

    Args:
        num_classes: Number of logits produced by the new head.
        model_name: timm model identifier for the backbone; defaults to the checkpoint
            this notebook was built around.
        pretrained: Passed to ``timm.create_model``; ``True`` loads pretrained weights.
        hidden_dim: Width of the hidden layer of the new head.

    The head replaces ``base_vit.head`` as ``Sequential(Linear, ReLU, Linear)``, so
    checkpoints saved with the default arguments keep the ``base_vit.head.0.*`` /
    ``base_vit.head.2.*`` parameter names.
    """

    def __init__(self, num_classes=10,
                 model_name='eva02_large_patch14_448.mim_m38m_ft_in22k_in1k',
                 pretrained=True, hidden_dim=256):
        super().__init__()
        # Load the (optionally pre-trained) backbone.
        self.base_vit = timm.create_model(model_name, pretrained=pretrained)

        # timm's own preprocessing for this backbone, exposed as attributes.
        # NOTE(review): nothing in this notebook reads these; its DataLoaders use the
        # hand-written `transform` / `val_transform` pipelines instead.
        data_config = timm.data.resolve_model_data_config(self.base_vit)
        self.transforms_train = timm.data.create_transform(**data_config, is_training=True)
        self.transforms_val = timm.data.create_transform(**data_config, is_training=False)

        # Freeze every backbone parameter. The head below is created afterwards, so its
        # parameters keep requires_grad=True and are the only trainable ones until the
        # backbone is explicitly unfrozen.
        self.base_vit.requires_grad_(False)

        # Replace the classifier head, sized from the backbone's existing head.
        in_features = self.base_vit.head.in_features
        self.base_vit.head = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Return class logits from the backbone with its replaced head."""
        return self.base_vit(x)

# Build the classifier with one output per class folder found by ImageFolder and put it on
# the selected device (nn.Module.to moves the module in place and returns it).
num_classes = len(dataset.classes)
model = CustomViTModel(num_classes=num_classes).to(device)

# Tag embedded in checkpoint filenames, e.g. best_model_acc_{model_paradigm}.pth.
model_paradigm = 'ViT'

### Training Setup - Model Evaluation

In [8]:
def top_k_accuracy(output, target, k=5):
    """Fraction of samples whose true label is among the ``k`` highest scores.

    Args:
        output: Score tensor indexed as ``(batch, num_classes)``.
        target: Integer class labels, one per row of ``output``.
        k: Number of top predictions to consider. Clamped to ``num_classes`` so that
            e.g. top-5 accuracy does not crash on a problem with fewer than 5 classes.

    Returns:
        A detached 1-element tensor holding the accuracy in ``[0, 1]``.
    """
    batch_size = target.size(0)
    k = min(k, output.size(1))
    _, pred = output.topk(k, 1, True, True)  # (batch, k) indices of the k largest scores
    pred = pred.t()  # (k, batch) so each column can be compared with the target row
    correct = pred.eq(target.reshape(1, -1).expand_as(pred))
    # topk indices are distinct per sample, so each sample contributes at most one hit.
    correct_k = correct.reshape(-1).float().sum(0, keepdim = True)
    return correct_k.mul_(1.0 / batch_size).detach()

def evaluate(model, loss_fn, data_loader, device=None):
    """Compute sample-weighted loss and accuracy metrics of ``model`` over ``data_loader``.

    Runs under ``torch.no_grad()`` with the model in eval mode, then restores the
    train/eval mode the model was in before the call.

    Args:
        model: Classifier returning one row of class scores per input sample.
        loss_fn: Loss returning the batch mean (e.g. ``nn.CrossEntropyLoss()``); it is
            re-weighted by batch size so the result is a per-sample mean.
        data_loader: Iterable of ``(inputs, labels)`` batches.
        device: Where each batch is moved. Defaults to the device of the model's
            first parameter.

    Returns:
        ``(mean_loss, accuracy, top_1_accuracy, top_5_accuracy)`` as Python floats.

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    if device is None:
        device = next(model.parameters()).device

    was_training = model.training
    model.eval()

    loss_sum = 0
    correct = 0
    total = 0
    top_1_sum = 0
    top_5_sum = 0

    try:
        with torch.no_grad():
            for batchX, batchY in tqdm(data_loader, desc = "Validating"):
                batchX, batchY = batchX.to(device), batchY.to(device)
                batch_size = batchX.size(0)

                output = model(batchX)
                predicted_labels = torch.argmax(output, dim = 1)

                # Accumulate as tensors; .item() is called once at the end.
                loss_sum += loss_fn(output, batchY).detach() * batch_size
                correct += (predicted_labels == batchY.type(torch.long)).sum().detach()
                total += batch_size
                top_1_sum += top_k_accuracy(output, batchY, k=1) * batch_size
                top_5_sum += top_k_accuracy(output, batchY, k=5) * batch_size
    finally:
        model.train(was_training)

    if total == 0:
        raise ValueError("evaluate: data_loader yielded no samples")

    return loss_sum.item() / total, correct.item() / total, top_1_sum.item() / total, top_5_sum.item() / total

def evaluate_all(model, loss_fn, allX, allY):
    """Evaluate ``model`` on one pre-materialised set of inputs in a single forward pass.

    Args:
        model: Classifier returning one row of class scores per input.
        loss_fn: Loss applied once to all logits against the labels cast to ``long``.
        allX: Every input sample, stacked along dimension 0.
        allY: Matching labels.

    Returns:
        ``(loss, accuracy, top_1_accuracy, top_5_accuracy)`` as Python floats.
    """
    model.eval()

    inputs = allX.to(device)
    targets = allY.to(device)
    n_samples = inputs.size(0)

    with torch.no_grad():
        logits = model(inputs)
        labels = targets.type(torch.long)

        mean_loss = loss_fn(logits, labels).detach()
        n_correct = (torch.argmax(logits, dim=1) == labels).sum().detach()
        top1 = top_k_accuracy(logits, targets, k=1)
        top5 = top_k_accuracy(logits, targets, k=5)

    return mean_loss.item(), n_correct.item() / n_samples, top1.item(), top5.item()

In [9]:
def plot_model_history(his):
    """Plot per-epoch loss and accuracy curves from a training-history dict.

    Loss goes on the left (blue) y-axis and accuracy on a twin right (red) y-axis;
    dashed lines are training, solid lines are validation.

    Args:
        his: Mapping with list entries ``'train_loss'``, ``'val_loss'``,
            ``'train_acc'`` and ``'val_acc'``, one value per recorded epoch.
    """
    _, ax = plt.subplots(figsize=(8, 5))
    ln1 = ax.plot(his['train_loss'], 'b--', label='loss')
    ln2 = ax.plot(his['val_loss'], 'b-', label='val_loss')
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss', color='blue')
    ax.tick_params(axis='y', colors="blue")

    ax2 = ax.twinx()
    ln3 = ax2.plot(his['train_acc'], 'r--', label='accuracy')
    ln4 = ax2.plot(his['val_acc'], 'r-', label='val_accuracy')
    ax2.set_ylabel('accuracy', color='red')
    ax2.tick_params(axis='y', colors="red")

    lns = ln1 + ln2 + ln3 + ln4
    labels = [l.get_label() for l in lns]
    # Put the combined legend on the twin axes: it is drawn after `ax`, so a legend
    # attached to `ax` would be painted over by the accuracy curves.
    ax2.legend(lns, labels, loc='center right')
    # Grid on the accuracy axes, targeted explicitly (plt.grid() would hit whichever
    # axes is current, which after twinx() is this twin).
    ax2.grid(True)
    plt.show()

In [10]:
# Name -> optimiser class lookup, so switching optimiser means editing one string.
optim_dict = {
    name: getattr(optim, name)
    for name in ("Adam", "Adadelta", "Adagrad", "Adamax", "AdamW", "ASGD",
                 "NAdam", "RMSprop", "RAdam", "Rprop", "SGD")
}

# Loss and optimiser.
# NOTE: different learning rates were used for different models at different stages of experimentation.
learning_rate = 1e-4
loss_fn = nn.CrossEntropyLoss()
optimiser = optim_dict["Adam"](model.parameters(), lr=learning_rate)
num_epochs = 50

In [11]:
# Per-epoch metric log shared by both training phases; one list per metric.
history = {metric: [] for metric in ('train_loss', 'val_loss', 'train_acc', 'val_acc')}

In [12]:
# Phase 1: train with the backbone frozen (only the new head has requires_grad=True).
# The best validation loss and the best validation accuracy are tracked separately and
# each is checkpointed to its own file.
best_val_loss = float('inf')
best_val_acc = -1

# Early stopping - based on validation accuracy: stop after `patience` consecutive epochs
# without a new best val_acc (see the patience logic at the end of the epoch).
patience_counter = 0
patience = 20

for epoch in range(num_epochs):
    # Re-enter train mode every epoch because evaluate() switches the model to eval mode.
    model.train()

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch + 1}')

    # Sample-weighted sums over the epoch; kept as tensors inside the loop and converted
    # to Python numbers once after it.
    running_loss = 0.0
    running_correct = 0
    total = 0

    for X, y in progress_bar:
        X, y = X.to(device), y.to(device)

        outputs = model(X)

        loss = loss_fn(outputs, y)

        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

        # loss is a batch mean, so scale by the batch size before summing.
        running_loss += loss.detach() * X.size(0)
        running_correct += (torch.argmax(outputs, dim = 1) == y.type(torch.long)).sum().detach()
        total += X.size(0)
    
    running_loss = running_loss.item()
    running_correct = running_correct.item()

    # Evaluate the model after training is done instead of using running averages
    # train_loss, train_acc = evaluate_all(model, loss_fn, train_full[0], train_full[1])
    # The train metrics used here are running averages over the epoch (train mode, augmented
    # inputs, weights changing mid-epoch), so they are not strictly comparable to val metrics.
    train_loss, train_acc = running_loss / total, running_correct / total
    # val_loss, val_acc = evaluate_all(model, loss_fn, val_full[0], val_full[1])
    val_loss, val_acc, top_1, top_5 = evaluate(model, loss_fn, val_loader)

    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    # Checkpoint on best validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), f'best_model_{model_paradigm}.pth')

    # Patience is counted based on validation accuracy; also checkpoint on best accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), f'best_model_acc_{model_paradigm}.pth')
        patience_counter = 0
    else:
        patience_counter += 1

    # torch.save(model.state_dict(), f'model_{model_paradigm}_epoch_{epoch+1}.pth')
    
    tqdm.write(f'Loss: {train_loss:.4f} - Accuracy: {train_acc*100:.4f}% - Val Loss: {val_loss:.4f} - Val Accuracy: {val_acc*100:.4f}% - Top 1 Accuracy: {top_1} - Top 5 Accuracy: {top_5}')

    if patience_counter == patience:
        print(f'Early stopping: patience limit reached after epoch {epoch + 1}')
        break

  x = F.scaled_dot_product_attention(
Epoch 1: 100%|██████████| 144/144 [02:03<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.19it/s]


Loss: 0.8932 - Accuracy: 75.5493% - Val Loss: 0.4557 - Val Accuracy: 85.9130% - Top 1 Accuracy: 0.8591304347826086 - Top 5 Accuracy: 0.9973913043478261


Epoch 2: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.3868 - Accuracy: 88.0574% - Val Loss: 0.3302 - Val Accuracy: 89.8261% - Top 1 Accuracy: 0.8982608695652174 - Top 5 Accuracy: 0.9982608695652174


Epoch 3: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:51<00:00,  1.43s/it]


Loss: 0.3069 - Accuracy: 90.2763% - Val Loss: 0.2880 - Val Accuracy: 91.0435% - Top 1 Accuracy: 0.9104347826086957 - Top 5 Accuracy: 0.9991304347826087


Epoch 4: 100%|██████████| 144/144 [02:12<00:00,  1.09it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.17it/s]


Loss: 0.2647 - Accuracy: 91.5815% - Val Loss: 0.2744 - Val Accuracy: 91.0435% - Top 1 Accuracy: 0.9104347826086957 - Top 5 Accuracy: 0.9982608695652174


Epoch 5: 100%|██████████| 144/144 [02:05<00:00,  1.14it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.17it/s]


Loss: 0.2411 - Accuracy: 92.1906% - Val Loss: 0.2594 - Val Accuracy: 91.1304% - Top 1 Accuracy: 0.9113043478260869 - Top 5 Accuracy: 0.9982608695652174


Epoch 6: 100%|██████████| 144/144 [02:05<00:00,  1.15it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.2200 - Accuracy: 92.9954% - Val Loss: 0.2402 - Val Accuracy: 91.2174% - Top 1 Accuracy: 0.9121739130434783 - Top 5 Accuracy: 0.9982608695652174


Epoch 7: 100%|██████████| 144/144 [02:04<00:00,  1.15it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.1985 - Accuracy: 93.7133% - Val Loss: 0.2424 - Val Accuracy: 92.0000% - Top 1 Accuracy: 0.92 - Top 5 Accuracy: 0.9982608695652174


Epoch 8: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.1881 - Accuracy: 93.6915% - Val Loss: 0.2296 - Val Accuracy: 91.7391% - Top 1 Accuracy: 0.9173913043478261 - Top 5 Accuracy: 0.9982608695652174


Epoch 9: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.1706 - Accuracy: 94.9532% - Val Loss: 0.2316 - Val Accuracy: 90.8696% - Top 1 Accuracy: 0.908695652173913 - Top 5 Accuracy: 0.9991304347826087


Epoch 10: 100%|██████████| 144/144 [02:04<00:00,  1.15it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.1618 - Accuracy: 94.6269% - Val Loss: 0.2215 - Val Accuracy: 92.6087% - Top 1 Accuracy: 0.9260869565217391 - Top 5 Accuracy: 0.9982608695652174


Epoch 11: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.1539 - Accuracy: 94.9967% - Val Loss: 0.2369 - Val Accuracy: 89.8261% - Top 1 Accuracy: 0.8982608695652174 - Top 5 Accuracy: 0.9982608695652174


Epoch 12: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.1386 - Accuracy: 95.5188% - Val Loss: 0.2313 - Val Accuracy: 91.3913% - Top 1 Accuracy: 0.9139130434782609 - Top 5 Accuracy: 1.0


Epoch 13: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.1385 - Accuracy: 95.5406% - Val Loss: 0.2258 - Val Accuracy: 92.1739% - Top 1 Accuracy: 0.9217391304347826 - Top 5 Accuracy: 0.9982608695652174


Epoch 14: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.1284 - Accuracy: 95.8451% - Val Loss: 0.2294 - Val Accuracy: 91.5652% - Top 1 Accuracy: 0.9156521739130434 - Top 5 Accuracy: 0.9982608695652174


Epoch 15: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.1198 - Accuracy: 96.3454% - Val Loss: 0.2440 - Val Accuracy: 90.9565% - Top 1 Accuracy: 0.9095652173913044 - Top 5 Accuracy: 0.9982608695652174


Epoch 16: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.17it/s]


Loss: 0.1134 - Accuracy: 96.4542% - Val Loss: 0.2198 - Val Accuracy: 91.3043% - Top 1 Accuracy: 0.9130434782608695 - Top 5 Accuracy: 0.9982608695652174


Epoch 17: 100%|██████████| 144/144 [02:04<00:00,  1.15it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.19it/s]


Loss: 0.1084 - Accuracy: 96.5847% - Val Loss: 0.2180 - Val Accuracy: 92.1739% - Top 1 Accuracy: 0.9217391304347826 - Top 5 Accuracy: 0.9991304347826087


Epoch 18: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.19it/s]


Loss: 0.1006 - Accuracy: 96.8675% - Val Loss: 0.2237 - Val Accuracy: 91.7391% - Top 1 Accuracy: 0.9173913043478261 - Top 5 Accuracy: 0.9991304347826087


Epoch 19: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.0992 - Accuracy: 96.9328% - Val Loss: 0.2183 - Val Accuracy: 91.5652% - Top 1 Accuracy: 0.9156521739130434 - Top 5 Accuracy: 0.9991304347826087


Epoch 20: 100%|██████████| 144/144 [02:04<00:00,  1.16it/s]
Validating: 100%|██████████| 36/36 [00:30<00:00,  1.18it/s]


Loss: 0.0950 - Accuracy: 97.0198% - Val Loss: 0.2298 - Val Accuracy: 92.0870% - Top 1 Accuracy: 0.9208695652173913 - Top 5 Accuracy: 0.9991304347826087


Epoch 21:  18%|█▊        | 26/144 [00:22<01:43,  1.14it/s]


KeyboardInterrupt: 

In [13]:
# Reload the phase-1 checkpoint with the best validation accuracy.
# weights_only=True restricts unpickling to tensors and plain containers (safer, and
# silences torch.load's FutureWarning about the default); map_location lets a checkpoint
# saved on GPU load on whatever `device` is active.
model.load_state_dict(torch.load(f'best_model_acc_{model_paradigm}.pth', map_location=device, weights_only=True))

  model.load_state_dict(torch.load(f'best_model_acc_{model_paradigm}.pth'))


<All keys matched successfully>

In [14]:
# Phase 2: unfreeze every parameter (backbone + head) and fine-tune end to end at a much
# lower learning rate. num_epochs is currently set to 1.
# NOTE(review): the recorded run of the next cell failed before finishing its first batch;
# the cause is not visible in the output. If it is CUDA out-of-memory from training the whole
# backbone at this batch size, consider a smaller batch, gradient accumulation or mixed precision.
model.requires_grad_(True)

learning_rate = 5e-7
loss_fn = nn.CrossEntropyLoss()
optimiser = optim_dict["Adam"](model.parameters(), lr=learning_rate)
num_epochs = 1

In [15]:
# Phase 2 training loop (whole network unfrozen). Same procedure as phase 1 but
# checkpoints go to the best_model_warmed_* files.
# NOTE(review): resetting the best-so-far values here means the first fine-tuning epoch
# always writes best_model_warmed_acc_*.pth, even if it is worse than the phase-1 best
# saved in best_model_acc_*.pth.
best_val_loss = float('inf')
best_val_acc = -1

# Early stopping - based on validation accuracy: stop after `patience` consecutive epochs
# without a new best val_acc (see the patience logic at the end of the epoch).
patience_counter = 0
patience = 20

for epoch in range(num_epochs):
    # Re-enter train mode every epoch because evaluate() switches the model to eval mode.
    model.train()

    progress_bar = tqdm(train_loader, desc=f'Epoch {epoch + 1}')

    # Sample-weighted sums over the epoch; kept as tensors inside the loop and converted
    # to Python numbers once after it.
    running_loss = 0.0
    running_correct = 0
    total = 0

    for X, y in progress_bar:
        X, y = X.to(device), y.to(device)

        outputs = model(X)

        loss = loss_fn(outputs, y)

        loss.backward()
        optimiser.step()
        optimiser.zero_grad()

        # loss is a batch mean, so scale by the batch size before summing.
        running_loss += loss.detach() * X.size(0)
        running_correct += (torch.argmax(outputs, dim = 1) == y.type(torch.long)).sum().detach()
        total += X.size(0)
    
    running_loss = running_loss.item()
    running_correct = running_correct.item()

    # Evaluate the model after training is done instead of using running averages
    # train_loss, train_acc = evaluate_all(model, loss_fn, train_full[0], train_full[1])
    # The train metrics used here are running averages over the epoch (train mode, augmented
    # inputs, weights changing mid-epoch), so they are not strictly comparable to val metrics.
    train_loss, train_acc = running_loss / total, running_correct / total
    # val_loss, val_acc = evaluate_all(model, loss_fn, val_full[0], val_full[1])
    val_loss, val_acc, top_1, top_5 = evaluate(model, loss_fn, val_loader)

    # Appended after the phase-1 entries, so `history` spans both phases.
    history['train_loss'].append(train_loss)
    history['val_loss'].append(val_loss)
    history['train_acc'].append(train_acc)
    history['val_acc'].append(val_acc)

    # Checkpoint on best validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), f'best_model_warmed_{model_paradigm}.pth')

    # Patience is counted based on validation accuracy; also checkpoint on best accuracy.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), f'best_model_warmed_acc_{model_paradigm}.pth')
        patience_counter = 0
    else:
        patience_counter += 1

    # torch.save(model.state_dict(), f'model_{model_paradigm}_epoch_{epoch+1}.pth')
    
    tqdm.write(f'Loss: {train_loss:.4f} - Accuracy: {train_acc*100:.4f}% - Val Loss: {val_loss:.4f} - Val Accuracy: {val_acc*100:.4f}% - Top 1 Accuracy: {top_1} - Top 5 Accuracy: {top_5}')

    if patience_counter == patience:
        print(f'Early stopping: patience limit reached after epoch {epoch + 1}')
        break

Epoch 1:   0%|          | 0/144 [00:04<?, ?it/s]

Unexpected exception formatting exception. Falling back to standard exception



Traceback (most recent call last):
  File "c:\Users\Raven\anaconda3\Lib\site-packages\IPython\core\interactiveshell.py", line 2168, in showtraceback
    stb = self.InteractiveTB.structured_traceback(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Raven\anaconda3\Lib\site-packages\IPython\core\ultratb.py", line 1454, in structured_traceback
    return FormattedTB.structured_traceback(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Raven\anaconda3\Lib\site-packages\IPython\core\ultratb.py", line 1345, in structured_traceback
    return VerboseTB.structured_traceback(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Raven\anaconda3\Lib\site-packages\IPython\core\ultratb.py", line 1192, in structured_traceback
    formatted_exception = self.format_exception_as_a_whole(etype, evalue, etb, number_of_lines_of_context,
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "c:\Users\Raven

In [None]:
# Re-measure training loss/accuracy under the same conditions as validation: eval mode and
# no augmentation. The train numbers in `history` are running averages collected in train
# mode on augmented images while the weights were still changing, so they are not directly
# comparable to the validation numbers. train_loader is not reused here because it applies
# the random training augmentations.
train_eval_loader = DataLoader(
    TransformedDataset(train_dataset, val_transform), batch_size=32, shuffle=False
)
train_loss, train_acc, top_1, top_5 = evaluate(model, loss_fn, train_eval_loader)
print(f'Train (eval mode, no augmentation) Loss: {train_loss:.4f} - Accuracy: {train_acc*100:.4f}% - Top 5 Accuracy: {top_5}')

In [None]:
# Loss (left axis) and accuracy (right axis) curves for every epoch recorded in `history`.
plot_model_history(his=history)

In [None]:
# Load the best fully fine-tuned (phase-2) checkpoint, selected by validation accuracy.
# That file is only written once a fine-tuning epoch completes, so fall back to the
# phase-1 best when it does not exist (and say so).
_ckpt_path = f'best_model_warmed_acc_{model_paradigm}.pth'
if not os.path.exists(_ckpt_path):
    print(f'{_ckpt_path} not found; falling back to best_model_acc_{model_paradigm}.pth')
    _ckpt_path = f'best_model_acc_{model_paradigm}.pth'
# weights_only=True avoids unpickling arbitrary objects; map_location targets the active device.
model.load_state_dict(torch.load(_ckpt_path, map_location=device, weights_only=True))

In [None]:
# Held-out test set, preprocessed exactly like validation (no augmentation).
test_dataset = datasets.ImageFolder(root = "./test_data")
# ImageFolder derives label indices from the class sub-folders it finds, so a test folder
# with a missing or extra class would silently shift labels relative to the model outputs.
assert test_dataset.class_to_idx == dataset.class_to_idx, (
    "test_data class folders differ from final_data; label indices would be misaligned"
)
transformed_test = TransformedDataset(test_dataset, val_transform)
test_loader = DataLoader(transformed_test, batch_size=32, shuffle=False)

In [None]:
# Final held-out evaluation of the loaded checkpoint.
test_metrics = evaluate(model, loss_fn, test_loader)
test_loss, test_acc, top_1, top_5 = test_metrics
print(f'Test Loss: {test_loss:.4f} - Test Accuracy: {test_acc*100:.4f}% - Top 1 Accuracy: {top_1} - Top 5 Accuracy: {top_5}')