In [1]:
! pip install open_clip_torch
! pip install --upgrade datasets
! pip install relplot

Collecting open_clip_torch
  Downloading open_clip_torch-2.29.0-py3-none-any.whl.metadata (31 kB)
Collecting ftfy (from open_clip_torch)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading open_clip_torch-2.29.0-py3-none-any.whl (1.5 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.5/1.5 MB[0m [31m41.9 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy, open_clip_torch
Successfully installed ftfy-6.3.1 open_clip_torch-2.29.0
Collecting datasets
  Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting m

In [2]:
import contextlib
import pathlib
import subprocess

import datasets
import open_clip
import relplot
import torch

In [3]:
# Datasets
DATASET_URLS = {
    "llama-test": "https://drive.google.com/uc?id=1pop_ltmF9mpiMDqi1lcTbpmOyTo96if1",
    "llama-train": "https://drive.google.com/uc?id=1Wyc8U_I2UCjT863ndpDfhgSLxjfArk3A",
    "beach-train": "https://drive.google.com/uc?id=1F6ozO15919KpPP_57Z0jD5dHOqU_w-nb",
    "beach-test": "https://drive.google.com/uc?id=1o5MpdWKKrC80I4zbJr9u2L_zKPUXdQgy",
    "beach-test-ood": "https://drive.google.com/uc?id=163NyfmWarIAAOjUqH3QPedQokzfmNoFH",
}

In [4]:
# Download each zipped split from Google Drive and extract it under /content.
# - Splits whose directory already exists are skipped, so Restart & Run All does not
#   re-download ~1.3 GB (delete the directory to force a fresh download).
# - check=True stops the notebook if gdown/unzip fails, instead of silently continuing.
DOWNLOAD_DIR = pathlib.Path("/content")
for split_name in ("llama-test", "llama-train"):
    if (DOWNLOAD_DIR / split_name).exists():
        continue
    zip_path = DOWNLOAD_DIR / f"{split_name}.zip"
    subprocess.run(["gdown", DATASET_URLS[split_name], "-O", str(zip_path)], check=True)
    # -o: overwrite without prompting (an interactive prompt would hang the cell); -q: quiet.
    subprocess.run(["unzip", "-o", "-q", str(zip_path), "-d", str(DOWNLOAD_DIR)], check=True)

Downloading...
From (original): https://drive.google.com/uc?id=1pop_ltmF9mpiMDqi1lcTbpmOyTo96if1
From (redirected): https://drive.google.com/uc?id=1pop_ltmF9mpiMDqi1lcTbpmOyTo96if1&confirm=t&uuid=8168121f-0798-43eb-aa80-865ecb614f72
To: /content/llama-test.zip
100% 50.0M/50.0M [00:01<00:00, 31.5MB/s]
Archive:  /content/llama-test.zip
   creating: llama-test/
  inflating: llama-test/state.json   
  inflating: llama-test/dataset_info.json  
  inflating: llama-test/data-00000-of-00001.arrow  
Downloading...
From (original): https://drive.google.com/uc?id=1Wyc8U_I2UCjT863ndpDfhgSLxjfArk3A
From (redirected): https://drive.google.com/uc?id=1Wyc8U_I2UCjT863ndpDfhgSLxjfArk3A&confirm=t&uuid=71d191e0-a9d1-4ed0-9b6d-89f74c354151
To: /content/llama-train.zip
100% 1.28G/1.28G [00:36<00:00, 34.9MB/s]
Archive:  /content/llama-train.zip
   creating: llama-train/
  inflating: llama-train/state.json  
  inflating: llama-train/dataset_info.json  
  inflating: llama-train/data-00002-of-00003.arrow  
  inf

In [5]:
class ZeroShotClassifier:
    """Zero-shot scorer for a single class name using an open_clip model.

    An image's score is the dot product of its L2-normalised image embedding with
    an L2-normalised text embedding. The text embedding is the average of the
    embeddings of every prompt in ``TEMPLATES`` filled with the class name.

    Usage: construct once, call ``set_text(class_name)``, then ``score_image`` or
    ``score_image_batch``.
    """

    # Prompt templates; each is formatted with the class name and the resulting
    # (normalised) text embeddings are averaged in ``set_text``.
    TEMPLATES = [
        "itap of a {}.",
        "a bad photo of the {}.",
        "a origami {}.",
        "a photo of the large {}.",
        "a {} in a video game.",
        "art of the {}.",
        "a photo of the small {}.",
    ]

    def __init__(self, model_name, pretrained_source, device=None):
        """Load the model, its preprocessing transform and its tokenizer.

        Args:
            model_name: open_clip architecture name, e.g. "ViT-B-32-quickgelu".
            pretrained_source: open_clip pretrained tag, e.g. "openai".
            device: torch device (or device string) to run on. Defaults to "cuda"
                when available, otherwise "cpu".
        """
        self.model_name = model_name
        self.pretrained_source = pretrained_source
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.model_name,
            pretrained=self.pretrained_source,
        )
        # Inputs are moved to the same device in set_text / get_image_batch_features.
        self.model.to(self.device)
        self.model.eval()
        self.tokenizer = open_clip.get_tokenizer(self.model_name)
        self.text = None
        self.text_features = None

    def _autocast(self):
        """Mixed precision on CUDA only; CPU inference stays in full fp32."""
        if self.device.type == "cuda":
            return torch.autocast(device_type="cuda")
        return contextlib.nullcontext()

    def _require_text_features(self):
        """Return the stored text embedding, or raise if ``set_text`` was never called."""
        if self.text_features is None:
            raise RuntimeError("call set_text(...) before scoring images")
        return self.text_features

    def set_text(self, text):
        """Encode ``text`` with every template and store the averaged embedding.

        Each prompt embedding is L2-normalised, the prompts are averaged, and the
        mean is re-normalised. The result is stored as a float32 NumPy vector in
        ``self.text_features``.
        """
        self.text = text
        prompts = [template.format(text) for template in self.TEMPLATES]
        tokens = self.tokenizer(prompts).to(self.device)
        with torch.no_grad(), self._autocast():
            per_prompt = self.model.encode_text(tokens)
            per_prompt = per_prompt / per_prompt.norm(dim=-1, keepdim=True)
            mean = per_prompt.mean(dim=0)
            mean = mean / mean.norm()
        # .float(): autocast may yield fp16; .cpu(): .numpy() requires host memory.
        self.text_features = mean.float().cpu().numpy()

    def get_image_features(self, image):
        """Return the normalised embedding of one image as an array of shape (1, D)."""
        return self.get_image_batch_features([image])

    def get_image_batch_features(self, images):
        """Return normalised embeddings for a sequence of images, shape (N, D).

        Raises:
            ValueError: if ``images`` is empty.
        """
        images = list(images)
        if not images:
            raise ValueError("images must contain at least one image")
        # preprocess transforms a single image, so apply it per image and stack.
        pixel_batch = torch.stack([self.preprocess(image) for image in images]).to(self.device)
        with torch.no_grad(), self._autocast():
            features = self.model.encode_image(pixel_batch)
            features = features / features.norm(dim=-1, keepdim=True)
        return features.float().cpu().numpy()

    def score_image(self, image, with_features=False):
        """Score one image against the current text.

        Returns:
            ``{"score": float}``, plus ``"features"`` (array of shape (1, D)) when
            ``with_features`` is true.

        Raises:
            RuntimeError: if ``set_text`` has not been called.
        """
        text_features = self._require_text_features()
        image_features = self.get_image_features(image)
        results = {"score": (image_features @ text_features).item()}
        if with_features:
            results["features"] = image_features
        return results

    def score_image_batch(self, images, with_features=False):
        """Score a batch of images against the current text.

        Returns a column-oriented dict, the layout ``Dataset.map(batched=True)``
        expects: ``{"score": list of N floats}``, plus ``"features"`` (array of
        shape (N, D)) when ``with_features`` is true.

        Raises:
            RuntimeError: if ``set_text`` has not been called.
            ValueError: if ``images`` is empty.
        """
        text_features = self._require_text_features()
        image_features = self.get_image_batch_features(images)
        results = {"score": (image_features @ text_features).tolist()}
        if with_features:
            results["features"] = image_features
        return results

In [6]:
def build_scores_dataset(zero_shot_classifier, image_dataset, with_features=False):
    """Map ``image_dataset`` to a dataset of zero-shot scores.

    Every row keeps its ``label`` and gains the keys returned by
    ``zero_shot_classifier.score_image`` (``score``, plus ``features`` when
    ``with_features`` is true). The ``image`` column is dropped from the result.
    Examples are scored one at a time.
    """
    def score_example(example):
        label = example["label"]
        scored = zero_shot_classifier.score_image(example["image"], with_features)
        return dict(label=label, **scored)

    # TODO: make this work with batching (map(batched=True) + a batch scorer).
    return image_dataset.map(score_example, remove_columns=["image"])

In [None]:
DATA_DIR = pathlib.Path("/content")  # where the download cell extracted each split
# The zipped train split is ~1.28 GB, so scoring it is opt-in.
SCORE_TRAIN_SPLIT = False

zsc = ZeroShotClassifier("ViT-B-32-quickgelu", "openai")
labels = ["llama"]

# Results are kept per label so adding labels does not overwrite earlier results.
# The loop variables (train_image_dataset, test_scores_dataset, ...) are still left
# pointing at the last label's data for any later cells that use them.
train_scores_by_label = {}
test_scores_by_label = {}
for label in labels:
    zsc.set_text(label)

    train_image_dataset_path = DATA_DIR / f"{label}-train"
    train_image_dataset = datasets.load_from_disk(train_image_dataset_path)
    if SCORE_TRAIN_SPLIT:
        train_scores_by_label[label] = build_scores_dataset(zsc, train_image_dataset)

    test_image_dataset_path = DATA_DIR / f"{label}-test"
    test_image_dataset = datasets.load_from_disk(test_image_dataset_path)
    test_scores_dataset = build_scores_dataset(zsc, test_image_dataset)
    test_scores_by_label[label] = test_scores_dataset


Map:   0%|          | 0/550 [00:00<?, ? examples/s]