### Import Libraries

In [1]:
import os
import time
import torch
import numpy as np
import timm
from datasets import load_dataset

import wandb

torch.cuda.empty_cache()

  from .autonotebook import tqdm as notebook_tqdm


### Parameters

In [2]:
# selection: you can select multiple but consider the memory
selected_datasets = {'Oxford-IIIT-Pets': 'jonathancui/oxford-pets',
                     'Oxford-Flowers-102': 'nelorth/oxford-flowers',
                     'CIFAR-10': 'cifar10',
                     'CIFAR-100': 'cifar100',
                     'ImageNet': 'imagenet-1k',
                     # 'ImageNetReaL': 'imagenet-1k',
} # couldn't find VTAB

selected_models = {
                   # 'ViT-L-16': 'vit_large_patch16_224.augreg_in21k',
                   #'ViT-L-16-FT': 'vit_large_patch16_224.augreg_in21k_ft_in1k',
                   # 'EfficientNet-L2': 'tf_efficientnetv2_l.in21k',
                   'ViT-H-14': 'vit_huge_patch14_224_in21k',
                   #'BiT-L-ResNet152x4': 'resnetv2_152x4_bit.goog_in21k',
                   
}
# fine-tuning params
batch_size = 2048
total_batches = 64

config = {'selected_dataset': selected_datasets,
           'selected_model': selected_models,
           'batch_size': batch_size}

In [3]:
def evaluate(model, data_loader, device):
    y_true = np.array([], dtype=np.int32)
    y_pred = np.array([], dtype=np.int32)
    
    with torch.no_grad():
        for i, data in enumerate(data_loader):
            if i == total_batches: break
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs.data, 1)
            print(predicted.shape, labels.shape)
            
            y_true = np.concatenate((y_true, labels.squeeze().cpu()))
            y_pred = np.concatenate((y_pred, predicted.cpu()))
    
    accuracy = np.sum(y_true == y_pred) / len(y_true)
    percentage_accuracy = accuracy * 100
    return percentage_accuracy

### Eval

In [4]:
table = wandb.Table(columns=['Dataset', *selected_models.keys()])
for model_title, model_id in selected_models.items():
    for dataset_title, data_id in selected_datasets.items():
        print(f'Evaluating: {model_title}-{dataset_title} ...')
        dataset = load_dataset(data_id)
        key = [*dataset['train'].features.keys()][-1]
        if dataset_title == 'ImageNet':
            model = timm.create_model(model_id, pretrained=True)
            model.to('cuda:0')
            model.requires_grad_(False)
            model = torch.nn.DataParallel(model, device_ids=[0])
        else:
            model = timm.create_model(model_id, pretrained=True, num_classes=len(dataset['train'].features[key].names))
            model.to('cuda:0')
            model.requires_grad_(False)
            model = torch.nn.DataParallel(model, device_ids=[0])
            model.load_state_dict(torch.load(f'{model_title}-{dataset_title}.pt'))
        model_configs = timm.data.resolve_data_config(model=model)
        data_transform = timm.data.create_transform(**model_configs)
        def collate_fn(batch_items):
            keys = [*batch_items[0].keys()]
            pixel_values = torch.stack([data_transform(example[keys[0]].convert('RGB')) for example in batch_items])
            labels = torch.tensor([[example[keys[-1]]] for example in batch_items])
            return (pixel_values, labels)
        test = 'validation' if 'validation' in dataset else 'test'
        dataloader = timm.data.create_loader(dataset[test],
                                model_configs['input_size'],
                                batch_size=config['batch_size'],
                                collate_fn=lambda x: collate_fn(x))
        run = wandb.init(project='try 02 - accuracy',
                     # time
                     name=model_title + '-' + dataset_title + ' ' + time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())),
                     tags=[model_title, dataset_title],
                     job_type='eval')
        wandb.config=config,
    
        accuracy = evaluate(model, dataloader, 'cuda:0')
        wandb.log({'Model': model_title, 'Dataset': dataset_title, 'Accuracy': accuracy})
        run.finish()
        del model
        del dataloader
        del dataset
        torch.cuda.empty_cache()
        

Evaluating: ViT-L-16-Oxford-IIIT-Pets ...


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mmuhammad-shahbaz[0m ([33murbanity[0m). Use [1m`wandb login --relogin`[0m to force relogin


torch.Size([2048]) torch.Size([2048, 1])
torch.Size([1621]) torch.Size([1621, 1])


0,1
Accuracy,▁

0,1
Accuracy,2.72554
Dataset,Oxford-IIIT-Pets
Model,ViT-L-16


Evaluating: ViT-L-16-Oxford-Flowers-102 ...


torch.Size([1020]) torch.Size([1020, 1])


0,1
Accuracy,▁

0,1
Accuracy,1.27451
Dataset,Oxford-Flowers-102
Model,ViT-L-16


Evaluating: ViT-L-16-CIFAR-10 ...


torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([1808]) torch.Size([1808, 1])


0,1
Accuracy,▁

0,1
Accuracy,25.85
Dataset,CIFAR-10
Model,ViT-L-16


Evaluating: ViT-L-16-CIFAR-100 ...


torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([1808]) torch.Size([1808, 1])


0,1
Accuracy,▁

0,1
Accuracy,12.01
Dataset,CIFAR-100
Model,ViT-L-16


Evaluating: ViT-L-16-ImageNet ...


torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([848]

0,1
Accuracy,▁

0,1
Accuracy,0.0
Dataset,ImageNet
Model,ViT-L-16


Evaluating: EfficientNet-L2-Oxford-IIIT-Pets ...


torch.Size([2048]) torch.Size([2048, 1])
torch.Size([1621]) torch.Size([1621, 1])


0,1
Accuracy,▁

0,1
Accuracy,6.07795
Dataset,Oxford-IIIT-Pets
Model,EfficientNet-L2


Evaluating: EfficientNet-L2-Oxford-Flowers-102 ...


torch.Size([1020]) torch.Size([1020, 1])


0,1
Accuracy,▁

0,1
Accuracy,18.43137
Dataset,Oxford-Flowers-102
Model,EfficientNet-L2


Evaluating: EfficientNet-L2-CIFAR-10 ...


torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([1808]) torch.Size([1808, 1])


0,1
Accuracy,▁

0,1
Accuracy,58.01
Dataset,CIFAR-10
Model,EfficientNet-L2


Evaluating: EfficientNet-L2-CIFAR-100 ...


torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([1808]) torch.Size([1808, 1])


0,1
Accuracy,▁

0,1
Accuracy,50.77
Dataset,CIFAR-100
Model,EfficientNet-L2


Evaluating: EfficientNet-L2-ImageNet ...


torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([2048]) torch.Size([2048, 1])
torch.Size([848]

0,1
Accuracy,▁

0,1
Accuracy,0.0
Dataset,ImageNet
Model,EfficientNet-L2
