In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from data import get_datasets, get_dataloaders
from random_texts import CLIPZeroShotClassifier

In [3]:
datasets, classnames = get_datasets(fraction=1e-3)
for name, dataset in datasets.items():
    print(len(dataset))

Loading dataset shards:   0%|          | 0/27 [00:00<?, ?it/s]

120
15
21
51
52
20


In [4]:
baseline_model = CLIPZeroShotClassifier(classnames)
dataloaders = get_dataloaders(datasets, baseline_model.preprocess)

100%|██████████| 345/345 [00:25<00:00, 13.68it/s]


In [5]:
from torch import nn
import torch
from tqdm import tqdm
from pprint import pprint


def evaluate(model: nn.Module) -> dict[str, float]:
    model.eval()
    results = {}
    with torch.no_grad():
        for name, dataloader in dataloaders.items():
            correct = 0
            total = 0
            for batch in tqdm(dataloader, desc=f"Evaluating {name}"):
                images = batch["image"]
                labels = batch["label"]
                logits = model(images)
                correct += (logits.argmax(dim=-1) == labels).float().sum()
                total += len(labels)
            results[name] = correct / total
    return results

## Zero-shot model

In [6]:
baseline_results = evaluate(baseline_model)
pprint(baseline_results)

Evaluating ID: 100%|██████████| 4/4 [00:06<00:00,  1.51s/it]
Evaluating OOD_infograph: 100%|██████████| 1/1 [00:00<00:00,  1.25it/s]
Evaluating OOD_painting: 100%|██████████| 1/1 [00:01<00:00,  1.02s/it]
Evaluating OOD_quickdraw: 100%|██████████| 2/2 [00:02<00:00,  1.20s/it]
Evaluating OOD_real: 100%|██████████| 2/2 [00:02<00:00,  1.26s/it]
Evaluating OOD_clipart: 100%|██████████| 1/1 [00:00<00:00,  1.03it/s]

{'ID': tensor(0.8083),
 'OOD_clipart': tensor(0.4000),
 'OOD_infograph': tensor(0.6667),
 'OOD_painting': tensor(0.6190),
 'OOD_quickdraw': tensor(0.1569),
 'OOD_real': tensor(0.7692)}



