In [1]:
import clip
import torch
import PIL
import PIL.Image
import numpy as np
from tqdm import tqdm

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Use Apple's Metal (MPS) backend when this PyTorch build and machine support it, otherwise CPU.
# `torch.has_mps` is deprecated; `torch.backends.mps.is_available()` is the supported runtime check.
device = "mps" if torch.backends.mps.is_available() else "cpu"
# NOTE(review): the float16 probabilities printed at the end suggest the model runs in fp16 on MPS —
# confirm against clip.load's dtype handling if exact numerics matter.
model, preprocess = clip.load("ViT-B/32", device=device)

In [3]:
IMAGE_PATH = "data/dog.jpeg"
CANDIDATE_LABELS = ["pizza", "dog", "car", "person", "computer"]

# Open the image in a `with` block so the underlying file handle is released once
# `preprocess` has consumed the pixels; unsqueeze(0) adds a leading batch dimension.
with PIL.Image.open(IMAGE_PATH) as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)
labels = clip.tokenize(CANDIDATE_LABELS).to(device)

In [4]:
def encode_text(model, text):
    """Encode a batch of token-id sequences into the model's text-embedding space.

    The layout changes around the transformer are written as
    ``transpose(...).contiguous()`` and the end-of-text gather has a CPU
    fallback, so the function can run on backends with incomplete operator
    coverage.

    Args:
        model: A CLIP-style model exposing ``token_embedding``,
            ``positional_embedding``, ``transformer``, ``ln_final``,
            ``text_projection`` and ``dtype``.
        text: Integer token ids of shape ``[batch_size, n_ctx]``
            (e.g. the output of ``clip.tokenize``).

    Returns:
        Tensor of shape ``[batch_size, text_projection.shape[-1]]``: the
        final hidden state at each sequence's end-of-text token, projected by
        ``model.text_projection``. It lives on the same device as the
        activations, including when the CPU fallback is taken.
    """
    dtype = model.dtype
    x = model.token_embedding(text).type(dtype)  # [batch_size, n_ctx, d_model]
    x = x + model.positional_embedding.type(dtype)
    batch_size, n_ctx, d_model = x.shape

    # NLD -> LND. An explicit transpose + contiguous copy is used instead of permute
    # (presumably a workaround for an MPS backend issue — TODO confirm).
    x = x.transpose(1, 0).contiguous()
    x = model.transformer(x)
    x = x.transpose(1, 0).contiguous().view(batch_size, n_ctx, d_model)  # LND -> NLD
    x = model.ln_final(x).type(dtype)

    # Take features from the eot embedding (eot_token is the highest number in each
    # sequence, so argmax over the token ids locates it).
    rows = torch.arange(batch_size)
    try:
        eot_features = x[rows, text.argmax(dim=-1)]
    except NotImplementedError:
        # Some backends (this notebook targets MPS) lack the argmax/gather used above.
        # Do it on CPU, then return to x's own device rather than a hard-coded 'mps'.
        eot_features = x.cpu()[rows, text.cpu().argmax(dim=-1)].to(x.device)
    return eot_features @ model.text_projection


def encode_image(vision_model, x: torch.Tensor):
    """Run the vision tower on a batch of preprocessed images.

    Args:
        vision_model: The visual encoder (``model.visual``), providing
            ``conv1``, ``class_embedding``, ``positional_embedding``,
            ``ln_pre``, ``transformer``, ``ln_post`` and ``proj``
            (``proj`` may be ``None``).
        x: Image batch passed straight to ``vision_model.conv1`` —
            presumably ``[batch, 3, H, W]``; TODO confirm against ``preprocess``.

    Returns:
        The ``ln_post`` output for the class token of each image, multiplied
        by ``vision_model.proj`` when that matrix is not ``None``.
    """
    # Patchify: conv output [N, width, gh, gw] -> flatten spatial grid -> [N, width, n_patches].
    patches = vision_model.conv1(x)
    patches = patches.reshape(patches.shape[0], patches.shape[1], -1)
    n_images, width, n_patches = patches.shape

    # Token layout [N, n_patches, width]; transpose + contiguous copy stands in for
    # permute(0, 2, 1) (presumably an MPS workaround — TODO confirm).
    tokens = patches.transpose(2, 1).contiguous()

    # Prepend the learned class token to every image's sequence -> [N, n_patches + 1, width].
    cls_token = vision_model.class_embedding.to(tokens.dtype) + torch.zeros(
        n_images, 1, width, dtype=tokens.dtype, device=tokens.device
    )
    tokens = torch.cat([cls_token, tokens], dim=1)
    tokens = tokens + vision_model.positional_embedding.to(tokens.dtype)
    tokens = vision_model.ln_pre(tokens)

    n_images, seq_len, width = tokens.shape
    encoded = vision_model.transformer(tokens.transpose(1, 0).contiguous())  # NLD -> LND
    encoded = encoded.transpose(1, 0).contiguous().view(n_images, seq_len, width)  # LND -> NLD

    cls_features = vision_model.ln_post(encoded[:, 0, :])
    if vision_model.proj is None:
        return cls_features
    return cls_features @ vision_model.proj


In [5]:
with torch.no_grad():
    # Embed and L2-normalize each modality so the product below is a cosine similarity.
    image_features = encode_image(model.visual, image.type(model.dtype))
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)

    text_features = encode_text(model, labels)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    # Fixed logit scale of 100, applied to the image features before the matmul.
    similarities = (100 * image_features) @ text_features.T
    # Softmax across the candidate labels -> one probability row per image.
    probs = torch.softmax(similarities, dim=-1).cpu().numpy()
probs

array([[1.065e-04, 9.932e-01, 6.323e-04, 6.287e-03, 8.422e-05]],
      dtype=float16)