In [2]:
import os 
import clip 
import torch 
from torchvision.datasets import CIFAR100
device = "cuda" if torch.cuda.is_available() else "cpu"

In [6]:
model, preprocess = clip.load("ViT-B/32", device)

In [7]:
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

Downloading https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz to /home/me/.cache/cifar-100-python.tar.gz


100%|██████████| 169001437/169001437 [00:22<00:00, 7564186.73it/s] 


Extracting /home/me/.cache/cifar-100-python.tar.gz to /home/me/.cache


In [30]:
image, class_id = cifar100[3637]
print(image, class_id)

<PIL.Image.Image image mode=RGB size=32x32 at 0x7AE418F217C0> 78


In [35]:
image.save('output.jpg')

In [11]:

image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

In [37]:
cifar100.classes[78]

'snake'

In [38]:
[clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]

[tensor([[49406,   320,  1125,   539,   320,  3055, 49407,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0]], dtype=torch.int32),
 tensor([[49406,   320,  1125,   539,   320, 16814,   318,  2759, 49407,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
              0,     0,     0,     0,  

In [39]:
print(text_inputs.size())

torch.Size([100, 77])


In [40]:
with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

In [42]:
print("image features: ", image_features.size())
print("text features: ", text_features.size())

image features:  torch.Size([1, 512])
text features:  torch.Size([100, 512])


In [43]:
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)


@ = matmul  
[1, 512] @ [100, 512]<sup>T<sup>

In [44]:
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

In [45]:
print(similarity.size())


torch.Size([1, 100])


In [46]:
print(text_features, text_features.size())


tensor([[-0.0115,  0.0482,  0.0021,  ..., -0.0322, -0.0200, -0.0037],
        [-0.0039, -0.0030,  0.0173,  ...,  0.0359, -0.0046,  0.0105],
        [-0.0048,  0.0119, -0.0191,  ..., -0.0723, -0.0104, -0.0209],
        ...,
        [ 0.0095, -0.0034,  0.0092,  ..., -0.0424, -0.0279,  0.0242],
        [-0.0074,  0.0228, -0.0125,  ..., -0.0759, -0.0134, -0.0071],
        [-0.0134,  0.0466,  0.0011,  ..., -0.0410, -0.0109, -0.0083]],
       device='cuda:0', dtype=torch.float16) torch.Size([100, 512])


In [48]:
values, indices = similarity[0].topk(5)
print(values, indices)

tensor([0.6523, 0.1245, 0.0385, 0.0188, 0.0174], device='cuda:0',
       dtype=torch.float16) tensor([78, 93, 83, 44, 27], device='cuda:0')


In [49]:
print("Top predictions: ")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

Top predictions: 
           snake: 65.23%
          turtle: 12.45%
    sweet_pepper: 3.85%
          lizard: 1.88%
       crocodile: 1.74%
