In [7]:
import sys
sys.path.append('../model-soups')

from datasets import ImageNet, ImageNetV2, ImageNetSketch, ImageNetR, ObjectNet, ImageNetA
from utils import get_model_from_sd, test_model_on_dataset

# Root directory holding the evaluation datasets. Replace with your own
# location (e.g. '/path/to/dataset/') before running the notebook.
data_location = '/mnt/tmp/'
# Batch size handed to each dataset class when it is constructed.
batch_size = 256
# Data-loading worker count.
# NOTE(review): not referenced by the cells visible here — confirm where it is consumed.
workers = 4

### Download model weights trained with Model Stock 
- Download link: [link](https://www.dropbox.com/scl/fo/shyzn327ge206n3fimomd/AKY_mVAN_kFiR2ejpAGPHQ8?rlkey=qk76yc6jisv8fjf2zdtrw4ki6&st=m43urvv3&dl=0)
- CLIP ViT-B/32 Model Stock: `clip_vit_b_32_model_stock.pt`
- CLIP ViT-B/32 Model Stock*: `clip_vit_b_32_model_stock_star.pt`


In [9]:
import torch
# Load both released checkpoints onto the CPU so the cell does not depend on
# the device the tensors were saved from being present on this machine.
# NOTE(review): torch.load unpickles the file, and these checkpoints come from
# an external download link — only load files you trust (consider passing
# weights_only=True on PyTorch versions that support it).
model_stock = torch.load('/model/path/clip_vit_b_32_model_stock.pt', map_location='cpu')
model_stock_star = torch.load('/model/path/clip_vit_b_32_model_stock_star.pt', map_location='cpu')


In [10]:
import clip
import numpy as np

# reference: https://github.com/mlfoundations/model-soups/blob/main/main.py#L90-L103
def eval_model(model_state_dict):
    """Evaluate a model state dict on ImageNet and five distribution-shift test sets.

    A CLIP ViT-B/32 base model is loaded on the CPU and combined with
    ``model_state_dict`` through ``get_model_from_sd``. The resulting model is
    then scored with ``test_model_on_dataset`` on ImageNet, ImageNetV2,
    ImageNetSketch, ImageNetR, ObjectNet and ImageNetA, in that order.

    Args:
        model_state_dict: State dict passed to ``get_model_from_sd``.

    Returns:
        dict: Maps each dataset class name to the value returned by
        ``test_model_on_dataset`` for that dataset.

    Side effects:
        Prints a progress line per dataset, the ImageNet result, and the mean
        result over the five shifted datasets.
    """
    # The shifted datasets are listed explicitly so the out-of-distribution
    # average does not depend on where ImageNet sits in the results dict.
    ood_dataset_classes = [ImageNetV2, ImageNetSketch, ImageNetR, ObjectNet, ImageNetA]

    base_model, preprocess = clip.load('ViT-B/32', 'cpu', jit=False)
    model = get_model_from_sd(model_state_dict, base_model)

    results = {}
    for dataset_cls in [ImageNet, *ood_dataset_classes]:
        name = dataset_cls.__name__
        print(f'Evaluating on {name}.')
        # NOTE(review): the notebook-level `workers` setting is not forwarded
        # here — confirm whether the dataset classes accept a worker count.
        dataset = dataset_cls(preprocess, data_location, batch_size)
        results[name] = test_model_on_dataset(model, dataset)

    ood_accuracies = [results[cls.__name__] for cls in ood_dataset_classes]
    print(f'In-distribution (ImageNet): {results["ImageNet"]}')
    print(f'Average of out-of-distributions: {np.mean(ood_accuracies)}')
    return results

In [None]:
# Evaluate the `model_stock` checkpoint; the returned per-dataset results dict
# is the value displayed for this cell.
eval_model(model_stock)

Evaluating on ImageNet.
[0% 0/196]	Acc: 89.45	Data (t) 9.881	Batch (t) 20.911
[10% 20/196]	Acc: 82.57	Data (t) 0.013	Batch (t) 0.155
[20% 40/196]	Acc: 81.98	Data (t) 0.013	Batch (t) 0.177
[31% 60/196]	Acc: 81.35	Data (t) 0.013	Batch (t) 0.164
[41% 80/196]	Acc: 81.75	Data (t) 0.013	Batch (t) 0.173
[51% 100/196]	Acc: 80.85	Data (t) 0.013	Batch (t) 0.169
[61% 120/196]	Acc: 80.79	Data (t) 0.013	Batch (t) 0.158
[71% 140/196]	Acc: 80.38	Data (t) 0.013	Batch (t) 0.148
[82% 160/196]	Acc: 80.08	Data (t) 0.013	Batch (t) 0.148
[92% 180/196]	Acc: 79.75	Data (t) 0.013	Batch (t) 0.152
Evaluating on ImageNetV2.
[0% 0/40]	Acc: 78.52	Data (t) 5.789	Batch (t) 5.996
[50% 20/40]	Acc: 70.35	Data (t) 0.013	Batch (t) 0.147
Evaluating on ImageNetSketch.
[0% 0/199]	Acc: 40.62	Data (t) 6.046	Batch (t) 6.243
[10% 20/199]	Acc: 38.84	Data (t) 0.013	Batch (t) 0.195
[20% 40/199]	Acc: 41.76	Data (t) 0.013	Batch (t) 0.192
[30% 60/199]	Acc: 42.32	Data (t) 0.013	Batch (t) 0.190
[40% 80/199]	Acc: 43.99	Data (t) 0.013	Bat

{'ImageNet': 0.79894,
 'ImageNetV2': 0.6893,
 'ImageNetSketch': 0.4621037945331997,
 'ImageNetR': 0.6429333333333334,
 'ObjectNet': 0.4630666523096802,
 'ImageNetA': 0.2922666666666667}

In [12]:
# Evaluate the `model_stock_star` checkpoint; the returned per-dataset results
# dict is the value displayed for this cell.
eval_model(model_stock_star)

Evaluating on ImageNet.
[0% 0/196]	Acc: 88.67	Data (t) 5.686	Batch (t) 5.915
[10% 20/196]	Acc: 83.63	Data (t) 0.013	Batch (t) 0.178
[20% 40/196]	Acc: 83.55	Data (t) 0.013	Batch (t) 0.181
[31% 60/196]	Acc: 83.07	Data (t) 0.014	Batch (t) 0.159
[41% 80/196]	Acc: 83.37	Data (t) 0.013	Batch (t) 0.167
[51% 100/196]	Acc: 82.32	Data (t) 0.014	Batch (t) 0.182
[61% 120/196]	Acc: 82.36	Data (t) 0.013	Batch (t) 0.166
[71% 140/196]	Acc: 81.88	Data (t) 0.013	Batch (t) 0.152
[82% 160/196]	Acc: 81.63	Data (t) 0.013	Batch (t) 0.148
[92% 180/196]	Acc: 81.21	Data (t) 0.013	Batch (t) 0.154
Evaluating on ImageNetV2.
[0% 0/40]	Acc: 80.08	Data (t) 4.866	Batch (t) 5.078
[50% 20/40]	Acc: 71.84	Data (t) 0.013	Batch (t) 0.147
Evaluating on ImageNetSketch.
[0% 0/199]	Acc: 50.39	Data (t) 6.453	Batch (t) 6.660
[10% 20/199]	Acc: 38.36	Data (t) 0.013	Batch (t) 0.184
[20% 40/199]	Acc: 42.04	Data (t) 0.013	Batch (t) 0.179
[30% 60/199]	Acc: 42.85	Data (t) 0.014	Batch (t) 0.188
[40% 80/199]	Acc: 44.46	Data (t) 0.013	Batc

{'ImageNet': 0.8133,
 'ImageNetV2': 0.7048,
 'ImageNetSketch': 0.4581933227220028,
 'ImageNetR': 0.6095666666666667,
 'ObjectNet': 0.42871756218369766,
 'ImageNetA': 0.23986666666666667}