In [65]:
import torch
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.transforms import transforms

from tqdm.auto import tqdm
import os, shutil, cv2
from utils.model_loaders import vit_small_patch16_224, vit_base_patch16_224
import plotly.express as ex
from utils.image_denorm import image_vizformat

In [2]:
# Validation split; shuffle=False keeps a fixed sample order so runs are comparable.
data_path = "lib/data/dataset_50/val"
BATCH_SIZE = 1
NUM_WORKERS = 4

# Resize to 224x224 (the input size in the model name), convert to a [0, 1]
# tensor, then Normalize((x - 0.5) / 0.5) maps each channel to [-1, 1].
image_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

dataset = ImageFolder(data_path, transform=image_transform)
# Keep the plain DataLoader here: progress bars belong to the loop that consumes it.
# Wrapping it in tqdm at definition time created an idle bar that never advanced,
# and a nested bar once the evaluation loop wrapped it again.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, shuffle=False)

  0%|          | 0/2500 [00:47<?, ?it/s]

In [66]:
model_path = "lib/benchmarking_model/jx_vit_base_p16_224_raw_images_24_max_validation_accuracy.pth"

# Build the architecture, then load the benchmarked checkpoint into it.
# NOTE(review): `model_path` used to be defined but never read, so the evaluation
# scored whatever weights `vit_base_patch16_224(pretrained=True)` returns instead
# of this checkpoint. Assumes the helper builds a classification head that matches
# the checkpoint (e.g. the 50 classes suggested by "dataset_50") — TODO confirm.
model = vit_base_patch16_224(pretrained=True)

# map_location="cpu" lets a checkpoint saved from a GPU load on any machine.
checkpoint = torch.load(model_path, map_location="cpu")
if isinstance(checkpoint, torch.nn.Module):
    # The whole module was pickled; use it directly.
    # (Recent torch versions default to weights_only=True and may refuse to load this form.)
    model = checkpoint
else:
    # Accept a bare state_dict, or a training dict nesting it under "state_dict"
    # — TODO confirm the actual save format of this checkpoint.
    state_dict = checkpoint.get("state_dict", checkpoint) if isinstance(checkpoint, dict) else checkpoint
    # strict=True (the default): a head or shape mismatch fails loudly instead of
    # silently evaluating partially loaded weights.
    model.load_state_dict(state_dict)

criterion = torch.nn.CrossEntropyLoss()

In [None]:
def embd_to_class(batched_embed):
    """Convert a batch of per-class scores into predicted class indices.

    Each sample (indexed by the first dimension) is flattened, and the position of
    its maximum value is returned. For the usual ``(batch, num_classes)`` logits this
    is the predicted class per sample. Ties resolve to the first maximal index, as
    with ``torch.argmax``.

    Args:
        batched_embed: tensor with at least 2 dimensions whose first dimension
            indexes samples (presumably model logits — confirm for other callers).

    Returns:
        1-D int64 tensor of length ``batched_embed.shape[0]``, on the same device as
        the input. The previous version returned a float32 CPU tensor built one
        element at a time.
    """
    # A single vectorized argmax replaces the per-row Python loop (one device sync
    # per row on GPU). flatten(start_dim=1) reproduces the old per-row
    # torch.argmax(embd), which flattened each row before reducing.
    return torch.argmax(batched_embed.flatten(start_dim=1), dim=1)

def measure_accuracy(yhat, y):
    """Count how many predictions in a batch match their targets.

    Despite the name this returns a *count* of correct predictions, not a fraction.
    Callers divide by the number of samples to get accuracy.

    Args:
        yhat: model outputs for the batch; converted to class indices via
            ``embd_to_class``.
        y: target class indices for the same samples (tensor or sequence of ints).

    Returns:
        float: number of positions where the predicted class equals the target.
        Like the original ``zip``-based loop, only the overlapping prefix of the two
        sequences is compared.
    """
    yhat_idx = embd_to_class(yhat)
    # Move the targets to the predictions' device so the comparison works whether
    # the indices live on CPU or GPU. One vectorized comparison replaces a
    # per-element bool(tensor), which forced a host sync for every sample.
    targets = torch.as_tensor(y, device=yhat_idx.device)
    n_common = min(len(yhat_idx), len(targets))
    return float((yhat_idx[:n_common] == targets[:n_common]).sum().item())


# Evaluate `model` on the validation loader: per-sample mean loss and top-1 accuracy.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# eval() disables training-only behaviour such as dropout. Without .to(device),
# the inputs would land on the GPU while the weights stayed wherever they were built.
model = model.to(device).eval()

# `dataloader` may already be a tqdm wrapper; unwrap it so the bar below is not nested.
batches = dataloader.iterable if isinstance(dataloader, tqdm) else dataloader

total_loss = 0.0
total_correct = 0.0
n_seen = 0

# no_grad: inference only, so skip building an autograd graph for every batch.
with torch.no_grad():
    for xv, yv in tqdm(batches, desc="validation"):
        xv, yv = xv.to(device), yv.to(device)
        yv_hat = model(xv)
        n_batch = yv.shape[0]
        # CrossEntropyLoss() uses its default 'mean' reduction; re-weight by batch
        # size so a smaller final batch is averaged correctly.
        total_loss += criterion(yv_hat, yv).item() * n_batch
        total_correct += measure_accuracy(yv_hat, yv)
        n_seen += n_batch

# Normalize by the number of *samples*, not batches. The old `/ len(dataloader)`
# was only correct when batch_size == 1.
val_loss = total_loss / n_seen if n_seen else float("nan")
val_accuracy = total_correct / n_seen if n_seen else float("nan")
{"val_loss": val_loss, "val_accuracy": val_accuracy}