In [4]:
import inspect

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

from vit_transformer import VitTransformer

In [5]:
from transformers import ViTFeatureExtractor
from datasets import load_dataset

In [None]:
# Architecture hyper-parameters.
# NOTE(review): none of these are referenced by the cells below (the model is
# restored whole from a checkpoint) — presumably kept to document its config.
encoder_layers = 6
embed_dim = 512
num_heads = 8

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device:{device}")



In [None]:
# CIFAR-10 from the Hugging Face hub, fetched as separate train and test splits.
num_classes = 10
train_set, test_set = load_dataset("cifar10", split=["train", "test"])

In [8]:
# Preprocessing config shipped with the pretrained checkpoint; its mean/std drive
# the normalization step of the evaluation transforms.
# NOTE(review): ViTFeatureExtractor is deprecated in recent `transformers`
# releases in favour of ViTImageProcessor — confirm the pinned version.
feature_extractor = ViTFeatureExtractor.from_pretrained("google/vit-base-patch16-224-in21k")
normalize = Normalize(
    mean=feature_extractor.image_mean,
    std=feature_extractor.image_std,
)

# Input geometry constants (not referenced by the cells below; presumably they
# mirror the training configuration).
image_h_w = 224
img_channels = 3
patch_h_w = 16

In [9]:
def _resolve_transform_size(size):
    """Convert a feature-extractor ``size`` into a value torchvision's ``Resize``/``CenterCrop`` accept.

    Older ``transformers`` feature extractors expose a plain int here, while newer
    image processors expose a dict such as ``{"height": 224, "width": 224}`` or
    ``{"shortest_edge": 224}``. torchvision rejects dicts, so they are mapped to an
    ``(h, w)`` tuple or an int. Any other value is returned unchanged.
    """
    if isinstance(size, dict):
        if "height" in size and "width" in size:
            return (size["height"], size["width"])
        if "shortest_edge" in size:
            return size["shortest_edge"]
    return size


# Deterministic evaluation pipeline: resize, centre-crop, convert to a tensor,
# then normalize with the statistics from the feature extractor.
_transform_size = _resolve_transform_size(feature_extractor.size)
_val_test_transforms_steps = Compose(
    [
        Resize(_transform_size),
        CenterCrop(_transform_size),
        ToTensor(),
        normalize,
    ])


def val_test_transforms(images):
    """Apply the evaluation pipeline to a batch of raw dataset rows.

    Args:
        images: iterable of rows, each with an ``'img'`` entry (presumably a PIL
            image, since ``.convert("RGB")`` is called on it).

    Returns:
        A list with one transformed tensor per row, in input order.
    """
    transformed_images = [_val_test_transforms_steps(curr_image['img'].convert("RGB")) for curr_image in images]
    return transformed_images

In [10]:
def collate_func(images):
    """Batch raw dataset rows for the DataLoader.

    The rows are passed through untouched under ``"images"`` and left for the
    caller to transform. Their ``"label"`` values are gathered, in order, into
    a single tensor under ``"labels"``.
    """
    label_values = [row["label"] for row in images]
    return {"images": images, "labels": torch.tensor(label_values)}


In [11]:
# Restore a checkpoint dict whose 'model' entry holds the whole model object.
# SECURITY NOTE: torch.load unpickles arbitrary objects — only load checkpoints you trust.
epoch_model_to_load = 1
checkpoint_path = f"vit_model_epoch_{epoch_model_to_load}.pth"
# PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle a full
# model object. Opt out explicitly when this torch version has the flag; older
# versions that lack it behave as before.
_load_kwargs = (
    {"weights_only": False}
    if "weights_only" in inspect.signature(torch.load).parameters
    else {}
)
load_model = torch.load(checkpoint_path, map_location=device, **_load_kwargs)
model = load_model['model']
print("load model of epoch: " + str(epoch_model_to_load))

load model of epoch: 1


In [20]:
# A single batch holding the whole test split, kept in dataset order.
# NOTE(review): after transforming, this puts len(test_set) image tensors
# (presumably 3x224x224 floats each) in memory at once. Consider a smaller
# batch_size if memory is tight; `test` must then aggregate metrics across batches.
batch_size = len(test_set)
test_loader = DataLoader(
    test_set,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=collate_func,
    pin_memory=True,
)

In [21]:
def test(model, test_loader, metrics):
    model.eval()
    metrics_results = []
    for i, batch_data in enumerate(test_loader):
        with torch.no_grad():
            transformed_images = torch.stack(val_test_transforms(batch_data['images']))
            transformed_images = transformed_images.to(device)

            labels = batch_data['labels']
            labels = labels.to(device)
            labels = labels.contiguous().view(-1)  # dims: [batch_size * 1]

            preds = model(transformed_images)
            _, idxes = torch.max(preds, dim=1)

            for met_criterion in metrics:
                metric = load_metric(met_criterion)
                metrics_results.append(metric.compute(predictions=idxes, references=labels))

    return metrics_results

In [None]:
from datasets import load_metric

# Score the restored model on the CIFAR-10 test split and show the results.
metrics = ["accuracy"]
metrics_results = test(model, test_loader, metrics)
print(metrics_results)