In [1]:
import torch
import clip
from PIL import Image

# Run on the GPU when PyTorch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the "RN50" CLIP checkpoint onto that device. clip.load hands back the
# model together with the image preprocessing callable.
model, preprocess = clip.load("RN50", device=device)

In [2]:
# Echo the loaded model so the notebook prints its layer-by-layer structure.
model

CLIP(
  (visual): ModifiedResNet(
    (conv1): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu3): ReLU(inplace=True)
    (avgpool): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (layer1): Sequential(
      (0): Bottleneck(
        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu1): ReLU(inplace=True)
     

In [3]:
# Open the sample image inside a context manager so the underlying file handle
# is released as soon as preprocessing has produced the tensor. The tensor then
# gets a leading batch dimension and is moved to the selected device.
with Image.open("CLIP.png") as pil_image:
    image = preprocess(pil_image).unsqueeze(0).to(device)

# Candidate labels the image will be scored against.
classes = ["diagram", "dog", "cat", "animal", "book", "paper"]

# clip.tokenize accepts a list of strings and returns one row of token ids per
# prompt, so tokenizing each prompt separately and concatenating is unnecessary.
text = clip.tokenize([f"a photo of a {c}" for c in classes]).to(device)

In [4]:
# Report the dimensions of the image batch and of the tokenized prompts.
image.size(), text.size()

(torch.Size([1, 3, 224, 224]), torch.Size([6, 77]))

In [5]:
# Encode each modality on its own. This is inference only, so run under
# torch.no_grad() to avoid recording an autograd graph for the encoders.
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

# Show the embedding shapes for the image and for the class prompts.
image_features.shape, text_features.shape

(torch.Size([1, 1024]), torch.Size([6, 1024]))

In [6]:
# The full forward pass returns the similarity logits in both directions
# (image-to-text and text-to-image). No gradients are needed for this
# inspection, so skip autograd bookkeeping.
with torch.no_grad():
    logits_per_image, logits_per_text = model(image, text)

# Show the shapes of the two logit matrices.
logits_per_image.shape, logits_per_text.shape

(torch.Size([1, 6]), torch.Size([6, 1]))

In [7]:
# Calculate features (inference only, so no autograd graph is recorded)
with torch.no_grad():
    image_features = model.encode_image(image)
    text_features = model.encode_text(text)

# Scale every embedding to unit length so the dot product below is a cosine
# similarity.
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)

# Scaled cosine similarities turned into a probability distribution over the
# class prompts.
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)

# Pick the top 5 most similar labels for the image. topk raises when asked for
# more entries than exist, so clamp k to the number of classes; this keeps the
# cell working if the class list is shortened.
top_k = min(5, len(classes))
values, indices = similarity[0].topk(top_k)

# Print the result
print("\nTop predictions:\n")
for value, index in zip(values, indices):
    print(f"{classes[index]:>16s}: {100 * value.item():.2f}%")


Top predictions:

         diagram: 98.19%
           paper: 1.07%
          animal: 0.29%
             dog: 0.22%
             cat: 0.11%
