In [None]:
import os

import torch
from tqdm import tqdm
from ffcv.fields import BytesField, IntField, RGBImageField
from ffcv.writer import DatasetWriter

from data_utils.data_stats import *
from data_utils.dataloader import get_loader
from utils.metrics import topk_acc, real_acc, AverageMeter
from models.networks import get_model
from data_utils.dataset_to_beton import get_dataset

In [None]:
# Which model / dataset pair to evaluate
dataset = 'cifar10'                 # One of cifar10, cifar100, stl10, imagenet or imagenet21
architecture = 'B_12-Wi_1024'
checkpoint = 'in21k_cifar10'        # Weights pre-trained on ImageNet21k and fine-tuned on CIFAR10

# Where the .beton files live and at which resolutions images are handled
data_path = './beton/'
data_resolution = 32                # Resolution at which the data is stored
crop_resolution = 64                # Resolution of the fine-tuned model (64 for all provided models)

# Derived values and evaluation settings
num_classes = CLASS_DICT[dataset]
eval_batch_size = 1024

In [None]:
# If you have not done so already, produce the .beton file for CIFAR10 (see the README for how to do this for ImageNet)
def create_beton(dataset, mode, data_path, res):
    """Write one split of a dataset to an FFCV ``.beton`` file.

    The file is written to ``<data_path>/<dataset>/<mode>/<mode>_<res>.beton``;
    missing parent directories are created.

    Args:
        dataset: Dataset name, forwarded to ``get_dataset`` and used as a path
            component of the output file.
        mode: Split name (e.g. ``'test'``), forwarded to ``get_dataset`` and
            used both as a directory and as the file-name prefix.
        data_path: Root directory for the output; also forwarded to
            ``get_dataset``.
        res: Image resolution, forwarded to ``get_dataset`` and used as the
            ``max_resolution`` of the stored images.
    """
    # Keep the loaded dataset under its own name so that `dataset` still holds
    # the dataset *name* needed to build the output path below.
    ds = get_dataset(dataset, mode, data_path, res)

    # NOTE(review): layout assumed to match what `get_loader` reads from
    # `data_path` — confirm against data_utils.dataloader.
    write_path = os.path.join(data_path, dataset, mode, f"{mode}_{res}.beton")

    os.makedirs(os.path.dirname(write_path), exist_ok=True)

    writer = DatasetWriter(
        write_path,
        {
            "image": RGBImageField(write_mode="smart", max_resolution=res),
            "label": IntField(),
        },
        num_workers=0,
    )

    writer.from_indexed_dataset(ds, chunksize=100)


create_beton(dataset, 'test', data_path, data_resolution)

In [None]:
torch.backends.cuda.matmul.allow_tf32 = True
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Build the network and load the pre-trained weights selected in the config cell
# (`checkpoint`), rather than a hard-coded checkpoint name.
model = get_model(
    architecture=architecture,
    resolution=crop_resolution,
    num_classes=num_classes,
    checkpoint=checkpoint,
)
# Move to the device chosen above; an unconditional `.cuda()` would fail on
# CPU-only machines even though `device` falls back to CPU.
model = model.to(device)

In [None]:
# Evaluation loader over the test split: augmentation and mixup are disabled,
# and both the stored and the model-input resolutions are passed through.
loader = get_loader(
    dataset,
    mode="test",
    bs=eval_batch_size,
    dev=device,
    data_path=data_path,
    data_resolution=data_resolution,
    crop_resolution=crop_resolution,
    augment=False,
    mixup=0.0,
)

In [None]:
# Evaluation routine: runs the model over a loader and reports accuracy
@torch.no_grad()
def test(model, loader):
    """Evaluate ``model`` on every batch of ``loader`` without gradients.

    Each image batch is flattened to shape ``(batch, -1)`` before the forward
    pass. The metric depends on the notebook-level ``dataset`` variable: for
    ``'imagenet_real'`` top-1 comes from ``real_acc`` and top-5 is recorded
    as 0; otherwise both come from ``topk_acc``.

    Returns:
        ``(top1, top5)`` as reported by ``AverageMeter.get_avg(percentage=True)``,
        with every batch weighted by its size.
    """
    model.eval()
    top1_meter = AverageMeter()
    top5_meter = AverageMeter()

    for images, targets in tqdm(loader, desc="Evaluation"):
        batch_size = images.shape[0]
        logits = model(images.reshape(batch_size, -1))

        if dataset == 'imagenet_real':
            top1 = real_acc(logits, targets, k=5, avg=True)
            top5 = 0
        else:
            top1, top5 = topk_acc(logits, targets, k=5, avg=True)

        top1_meter.update(top1, batch_size)
        top5_meter.update(top5, batch_size)

    return top1_meter.get_avg(percentage=True), top5_meter.get_avg(percentage=True)

In [None]:
test_acc, test_top5 = test(model, loader)

# Report both metrics with four decimal places
print("Test Accuracy        ", f"{test_acc:.4f}")
print("Top 5 Test Accuracy          ", f"{test_top5:.4f}")