In [2]:
from transformers import (CLIPProcessor, CLIPModel, CLIPTextModel, AutoTokenizer, AutoProcessor, CLIPVisionModel)
from datasets import load_dataset
from matplotlib import pyplot as plt
import numpy as np
import math
from sklearn.metrics import accuracy_score
from sklearn.metrics.pairwise import cosine_similarity
import itertools
from PIL import Image
from zipfile import ZipFile

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [3]:
# Hugging Face Hub identifier of the pre-trained CLIP checkpoint used throughout
base_model_name = 'openai/clip-vit-base-patch32'

# Load the checkpoint's weights and its matching processor (tokenizer + image preprocessing);
# both are fetched from the Hub on first use and cached locally afterwards.
model, processor = (
    CLIPModel.from_pretrained(base_model_name),
    CLIPProcessor.from_pretrained(base_model_name),
)

Downloading (…)lve/main/config.json: 100%|██████████| 4.19k/4.19k [00:00<00:00, 1.71MB/s]
Downloading pytorch_model.bin: 100%|██████████| 605M/605M [00:28<00:00, 21.0MB/s] 
Downloading (…)rocessor_config.json: 100%|██████████| 316/316 [00:00<00:00, 188kB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 568/568 [00:00<00:00, 239kB/s]
Downloading (…)olve/main/vocab.json: 100%|██████████| 862k/862k [00:00<00:00, 8.14MB/s]
Downloading (…)olve/main/merges.txt: 100%|██████████| 525k/525k [00:00<00:00, 5.32MB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 2.22M/2.22M [00:00<00:00, 11.7MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 389/389 [00:00<00:00, 262kB/s]


In [None]:
class CLIPWrapper:
    """Zero-shot positive/negative attribute scorer built on a CLIP model.

    For every attribute name two prompts are built, one from ``template_pos`` and
    one from ``template_neg`` (each formatted with ``attr=<attribute name>``).
    CLIP's image-text logits for each (positive, negative) pair are turned into a
    2-way softmax, so ``probs[..., 0]`` scores the positive prompt and
    ``probs[..., 1]`` the negative prompt.

    Attribute names come from the ``attributes`` constructor argument when given;
    otherwise from the class-level ``ATTRIBUTES`` sequence, which can be assigned
    after the class is defined (e.g. ``CLIPWrapper.ATTRIBUTES = [...]``).
    """

    # Class-wide default attribute names; empty until the user assigns them.
    ATTRIBUTES: tuple = ()

    def __init__(self, clip_model, clip_processor, template_pos: str, template_neg: str, attributes=None):
        """
        Args:
            clip_model: callable CLIP model; its output must expose ``logits_per_image``.
            clip_processor: processor called as ``clip_processor(text=..., images=..., ...)``.
            template_pos: format string with an ``{attr}`` field for the positive prompt.
            template_neg: format string with an ``{attr}`` field for the negative prompt.
            attributes: optional per-instance attribute names; ``None`` means use
                the class-level ``ATTRIBUTES``.
        """
        self.clip_model = clip_model
        self.clip_processor = clip_processor
        self.template_pos = template_pos
        self.template_neg = template_neg
        self.attributes = attributes

    @property
    def template(self):
        """Prompt templates in output order: ``[template_pos, template_neg]``."""
        # Derived on every access so it always reflects set_template().
        return [self.template_pos, self.template_neg]

    def _attributes(self):
        """Return the attribute names to score; raise ValueError if none are configured."""
        attrs = self.attributes if self.attributes is not None else type(self).ATTRIBUTES
        if not attrs:
            raise ValueError(
                "No attributes configured: pass attributes=[...] or set CLIPWrapper.ATTRIBUTES"
            )
        return list(attrs)

    def _prompts(self):
        """Return text prompts for CLIP, ordered attribute-major, template-minor."""
        return [t.format(attr=attr) for attr in self._attributes() for t in self.template]

    def set_template(self, template_pos: str, template_neg: str):
        """Replace the positive and negative prompt templates."""
        self.template_pos = template_pos
        self.template_neg = template_neg

    def __call__(self, x):
        """Score image(s) ``x`` against each attribute's positive/negative prompt.

        Args:
            x: image(s) in any form accepted by ``clip_processor(images=...)``.

        Returns:
            Tensor of shape (N_img, N_attr, 2) holding, per image and attribute,
            a softmax over [positive, negative] prompt logits. The result is
            detached from the autograd graph.
        """
        texts = self._prompts()
        inputs = self.clip_processor(text=texts, images=x, return_tensors="pt", padding=True)
        # Inference only: the result is detached below, so don't build an autograd graph.
        with torch.no_grad():
            outputs = self.clip_model(**inputs)

        # given M images, N texts, output shape will be M x N
        logits_per_image = outputs.logits_per_image  # image-text similarity scores
        N_img, N_txt = logits_per_image.shape
        N_temp = len(self.template)
        N_attr = len(self._attributes())
        assert N_txt == N_attr*N_temp, f'Num text ({N_txt}) != Num template x Num attrs ({N_temp}x{N_attr}={N_temp*N_attr})'

        # Prompts were generated attribute-major, so the text axis splits as (attr, template).
        logits_data = logits_per_image.detach().reshape(N_img, N_attr, N_temp)

        return logits_data.softmax(dim=-1)

    def pred(self, x):
        """Return, per (image, attribute), the index of the more likely prompt: 0 = positive, 1 = negative."""
        return self(x).argmax(dim=-1)