<a href="https://colab.research.google.com/github/lakatosgabor/clip_test/blob/main/CLIPImageNet_Zero_Shot_Evaluation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -q -U jax jaxlib
!pip install -q pandas
!pip install -q ipywidgets
!pip install -q -U flax
!pip install -q sentence-transformers
!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q transformers
!pip install -q torch torchvision

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.7/1.7 MB[0m [31m20.8 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m85.0/85.0 MB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.6/1.6 MB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.0/86.0 kB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.7/7.7 MB[0m [31m16.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m30.4 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m295.0/295.0 kB[0m [31m27.0 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.8/3.8 MB[0m [31m42.4 MB/s[0m eta 

In [2]:
import os
import sys
import json

import numpy as np
import pandas as pd

os.environ['TOKENIZERS_PARALLELISM'] = "false"

import transformers
from transformers import AutoTokenizer

import torch
import torchvision
from torchvision import transforms
from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode
from tqdm.notebook import tqdm

sys.path.append('.')

from modeling_hybrid_clip import FlaxHybridCLIP

In [4]:
import jax
from jax import numpy as jnp
from transformers import FlaxVisionTextDualEncoderModel

TOKENIZER_NAME = "SZTAKI-HLT/hubert-base-cc"
# Checkpoint directory loaded below.
# NOTE(review): the path lives under /content/drive, but no visible cell mounts
# Google Drive -- confirm it is mounted before running this cell.
MODEL_PATH = "/content/drive/MyDrive/model/40"
# Every tokenized text is padded (and truncated) to exactly this many tokens.
MAX_TEXT_LENGTH = 96

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_NAME, cache_dir=None, use_fast=True)
model = FlaxVisionTextDualEncoderModel.from_pretrained(MODEL_PATH)


def tokenize(texts):
    """Tokenize `texts` into fixed-length NumPy arrays.

    Args:
        texts: a string or list of strings to encode.

    Returns:
        Tuple `(input_ids, attention_mask)` taken from the tokenizer output.
    """
    inputs = tokenizer(
        texts,
        max_length=MAX_TEXT_LENGTH,
        padding="max_length",
        # Without truncation a text longer than MAX_TEXT_LENGTH tokens would
        # produce rows of different lengths, which cannot be returned as "np".
        truncation=True,
        return_tensors="np",
    )
    return inputs['input_ids'], inputs['attention_mask']


def language_model(queries):
    """Return the model's text features for `queries` as a NumPy array."""
    input_ids, attention_mask = tokenize(queries)
    return np.asarray(model.get_text_features(input_ids, attention_mask))


def image_model(images):
    """Return the model's image features for a batch of images as a NumPy array.

    Args:
        images: torch tensor with 4 axes; axis 1 is moved to the last position
            before the batch is handed to the model (presumably NCHW -> NHWC,
            given that the preprocessing uses ToTensor -- confirm).
    """
    channels_last = images.permute(0, 2, 3, 1).numpy()
    return np.asarray(model.get_image_features(channels_last))

You are using a model of type hybrid-clip to instantiate a model of type vision-text-dual-encoder. This is not supported for all configurations of models and can yield errors.


In [5]:
!wget -N -q https://raw.githubusercontent.com/clip-italian/clip-italian/imagenet_templates/evaluation/imagenet_labels_IT.tsv
classes_df = pd.read_csv('/content/imagenet_data.tsv', sep='\t', header=0)
imagenet_classes = list(classes_df['query_short'])  # list(classes_df['query_long_translated'])
imagenet_templates = [
    'fotó {}.',
]

print(f"{len(imagenet_classes)} classes, {len(imagenet_templates)} templates")

1000 classes, 1 templates


In [6]:
# Validation-time image pipeline: bicubic resize, 224x224 centre crop,
# conversion to a tensor, then per-channel normalisation with fixed statistics.
val_preprocess = transforms.Compose(
    [
        transforms.Resize(size=[224], interpolation=InterpolationMode.BICUBIC),
        transforms.CenterCrop(size=224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        ),
    ]
)

In [7]:
print('Downloading Imagenet validation set...')
!wget -N -q --show-progress https://image-net.org/data/ILSVRC/2012/ILSVRC2012_img_val.tar
print('Downloading Imagenet devkit...')
!wget -N -q --show-progress https://image-net.org/data/ILSVRC/2012/ILSVRC2012_devkit_t12.tar.gz
print('Done.')

images = torchvision.datasets.ImageNet('./', split='val', transform=val_preprocess)
loader = torch.utils.data.DataLoader(
    images,
    batch_size=1024,
    shuffle=False,
    num_workers=2,
    persistent_workers=True,
    drop_last=False
)

Downloading Imagenet validation set...
Downloading Imagenet devkit...
Done.


In [8]:
def zeroshot_classifier(classnames, templates):
    """Build a zero-shot classification matrix from class names and prompt templates.

    Each class name is inserted into every template, the resulting prompts are
    embedded with `language_model`, L2-normalised, averaged, and the average is
    L2-normalised again.

    Args:
        classnames: iterable of class-name strings.
        templates: iterable of format strings containing one `{}` placeholder.

    Returns:
        NumPy array with one column per class (class embeddings stacked on axis 1).
    """

    def embed_class(name):
        # One prompt per template for this class.
        prompts = [pattern.format(name) for pattern in templates]
        per_prompt = language_model(prompts)
        per_prompt = per_prompt / np.linalg.norm(per_prompt, axis=-1, keepdims=True)
        # Average over prompts, then bring the mean back to unit length.
        centroid = np.mean(per_prompt, axis=0)
        return centroid / np.linalg.norm(centroid, axis=-1)

    columns = [embed_class(name) for name in tqdm(classnames)]
    return np.stack(columns, axis=1)


# Build the text-side classifier once for all classes; zeroshot_classifier
# stacks the class embeddings on axis 1, so each column belongs to one class.
zeroshot_weights = zeroshot_classifier(imagenet_classes, imagenet_templates)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [9]:
def accuracy(output, target, topk=(1,)):
    """Count, for each k in `topk`, the samples whose target is among the top-k scores.

    Args:
        output: array-like of scores with samples on axis 0 and classes on axis 1.
        target: array-like of integer class indices, one per sample.
        topk: iterable of k values to evaluate.

    Returns:
        List of Python floats, one per entry of `topk`, each the NUMBER of
        samples (not the fraction) whose target is within the k highest scores.
    """
    output = torch.from_numpy(np.asarray(output))
    target = torch.from_numpy(np.asarray(target))
    # torch.topk rejects k larger than the class dimension, so clamp it; any
    # k >= number of classes then simply covers every class.
    max_k = min(max(topk), output.shape[1])
    # Row r of `pred` holds, for every sample, the index of its r-th best class.
    pred = output.topk(max_k, dim=1, largest=True, sorted=True)[1].t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    # `.item()` yields a Python scalar directly; calling float() on a
    # one-element ndarray is deprecated in recent NumPy releases.
    return [float(correct[:k].sum().item()) for k in topk]

In [10]:
# k values to report, with one running count of correct predictions per k.
top_ns = [1, 5, 10, 100]
acc_counters = [0. for _ in top_ns]
# Number of images evaluated so far.
n = 0.

# The batch is bound to `image_batch` so the loop does not overwrite `images`,
# the dataset object the loader was built from.
for image_batch, target in tqdm(loader):
    target = target.numpy()
    # predict
    image_features = image_model(image_batch)
    image_features = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True)
    # Scaled similarity between each image and each class column.
    logits = 100. * image_features @ zeroshot_weights

    # measure accuracy
    accs = accuracy(logits, target, topk=top_ns)
    for j, hits in enumerate(accs):
        acc_counters[j] += hits
    n += image_batch.shape[0]

# Convert the counts into percentages of all evaluated images.
tops = {f'top{k}': hits / n * 100 for k, hits in zip(top_ns, acc_counters)}

print(tops)

  0%|          | 0/49 [00:00<?, ?it/s]

{'top1': 2.6780000000000004, 'top5': 8.581999999999999, 'top10': 13.267999999999999, 'top100': 43.052}
