# Interacting with CLIP

This is a self-contained notebook that shows how to download and run CLIP models, calculate the similarity between arbitrary image and text inputs, and perform zero-shot image classification.

In [1]:
import numpy as np
import torch
from pkg_resources import packaging

# Report the installed PyTorch version next to the results of this notebook.
print("Torch version:", torch.__version__)


Torch version: 1.11.0


# Loading the model

`clip.available_models()` will list the names of available CLIP models.

In [2]:
import clip

# Bare last expression: the returned list of model names becomes the cell output.
clip.available_models()

['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [3]:
# Any other name listed by clip.available_models() (e.g. "ViT-B/32", "ViT-L/14",
# "RN50x64") can be substituted for "RN50" below.
model, preprocess = clip.load("RN50")
model.cuda().eval()  # move the model to the GPU and put it in evaluation mode

# Architecture facts reused by later cells.
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size

# Parameter count = total number of elements over all parameter tensors.
print("Model parameters:", f"{sum(p.numel() for p in model.parameters()):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

100%|███████████████████████████████████████| 244M/244M [00:41<00:00, 6.15MiB/s]


Model parameters: 102,007,137
Input resolution: 224
Context length: 77
Vocab size: 49408


# Image Preprocessing

We resize the input images and center-crop them to conform with the image resolution that the model expects. Afterwards, the pixel intensities are normalized using the dataset mean and standard deviation.

The second return value from `clip.load()` contains a torchvision `Transform` that performs this preprocessing.



In [4]:
# Bare expression: the notebook shows the repr of the pipeline returned by clip.load().
preprocess

Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    CenterCrop(size=(224, 224))
    <function _convert_image_to_rgb at 0x7f28d7650ee0>
    ToTensor()
    Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
)

In [39]:
# Replace the original image pre-processor as we don't want to normalize the image using other dataset stats
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize

# Prefer torchvision's InterpolationMode enum; if this torchvision build does not
# provide it, fall back to PIL's bicubic constant.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    # `Image` is not imported by any earlier cell, so bring it into scope here;
    # without this import the fallback itself would fail with a NameError.
    from PIL import Image
    BICUBIC = Image.BICUBIC


def _convert_image_to_rgb(image):
    return image.convert("RGB")

def _transform(n_px, mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)):
    """Build the preprocessing pipeline that replaces the one returned by clip.load().

    Compared with the default pipeline printed earlier in this notebook, this one
    has no CenterCrop and no RGB conversion, and it normalizes with caller-supplied
    statistics instead of CLIP's.

    Args:
        n_px: Size passed to `Resize`.
        mean: Per-channel means passed to `Normalize`. Defaults to the CIFAR-10
            values used throughout this notebook.
        std: Per-channel standard deviations passed to `Normalize`. Defaults to
            the CIFAR-10 values used throughout this notebook.

    Returns:
        A `Compose` applying ToTensor, bicubic Resize and Normalize, in that order.
    """
    return Compose([
        # ToTensor runs first, so Resize and Normalize both operate on tensors.
        ToTensor(),
        Resize(n_px, interpolation=BICUBIC),
        Normalize(mean, std),
    ])


# Rebind `preprocess` (originally returned by clip.load()) to the custom pipeline,
# sized to the visual encoder's input resolution.
# Alternative attribute path kept for reference:
# preprocess = _transform(model.input_resolution.item())
preprocess = _transform(model.visual.input_resolution)

# Print the replacement pipeline so its steps can be checked.
print(preprocess)


Compose(
    Resize(size=224, interpolation=bicubic, max_size=None, antialias=None)
    ToTensor()
    Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.201))
)


# CIFAR-10 Dataset

In [11]:

import torchvision.datasets as datasets
from torchvision import transforms

# Train and test images go through the same recipe: convert to a tensor, then
# normalize with the CIFAR-10 per-channel mean and standard deviation.
transform_train = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
])

# download=False: nothing is fetched here, so the files must already be under ~/data.
trainset = datasets.CIFAR10(
    root='~/data',
    train=True,
    download=False,
    transform=transform_train,
)
testset = datasets.CIFAR10(
    root='~/data',
    train=False,
    download=False,
    transform=transform_test,
)

# Batches of 100; only the training loader shuffles.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=100,
    shuffle=True,
    num_workers=8,
)
testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=100,
    shuffle=False,
    num_workers=8,
)


In [32]:
# Keep direct handles on the training set's `data` and `targets` attributes,
# i.e. the stored samples and labels rather than batches from trainloader.
train_data = trainset.data
train_label = trainset.targets


In [40]:
# Try the custom pipeline on the first stored training sample.
preprocess(train_data[0])

TypeError: img should be PIL Image. Got <class 'numpy.ndarray'>

In [None]:

# Scratch copy of the body of runCLIP() (defined in a later cell).
# NOTE(review): `original_images` and `texts` are not assigned in any cell above,
# so this cell fails on a fresh kernel unless they are defined first.
#     image = preprocess(org_image)
image_inputs = [preprocess(i) for i in original_images]

# Stack the per-image results into one batch and move it to the GPU.
#     image_inputs = [image for i in range(10)]
image_inputs = torch.tensor(np.stack(image_inputs)).cuda()


# Tokenize the captions as given (the prompt-prefixed variant is kept for reference).
# text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()
text_tokens = clip.tokenize(texts).cuda()


# image_input = torch.tensor(np.stack(image)).cuda()
# text_token = clip.tokenize(caption).cuda()


# Encode both modalities without tracking gradients; only the text features are
# cast to float here.
with torch.no_grad():
    # image_features = model.encode_image(image_inputs).float()
    image_features = model.encode_image(image_inputs)
    text_features = model.encode_text(text_tokens).float()

# NOTE(review): indexing [0] and [1] presumes encode_image returns two tensors —
# confirm against the model implementation in use.
print(image_features[0].shape, image_features[1].shape)
print(text_features.shape)
# print(image_features[1].reshape(-1, 4096).shape)
# image_features = torch.concat((image_features[0], image_features[1].reshape(-1, 1024)), dim = 0)
# Concatenate the two parts along the first axis into a single tensor.
image_features = torch.concat((image_features[0], image_features[1]), dim = 0)
print(image_features.shape)

# Scale every feature vector to unit length so the matrix product below yields
# cosine similarities (rows: texts, columns: image feature rows).
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

#     print(similarity)
print(f"Max similarity score = {similarity.max()}")

In [6]:
import matplotlib.pyplot as plt
import matplotlib.pylab as pylab

import requests
from io import BytesIO
from PIL import Image
import numpy as np
# Default Matplotlib figure size as (width, height); individual plots may override it.
pylab.rcParams['figure.figsize'] = 20, 12

def load(url, timeout=30):
    """
    Download the image at `url` and return it as a NumPy array in BGR order.

    Args:
        url: Address of the image to download.
        timeout: Seconds to wait for the server before the request is aborted.

    Returns:
        The decoded image as an array whose last axis holds the channels in
        blue, green, red order (the RGB decode with its channel axis reversed).

    Raises:
        requests.HTTPError: If the server responds with an error status code.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors here rather than handing an error page to the decoder.
    response.raise_for_status()
    pil_image = Image.open(BytesIO(response.content)).convert("RGB")
    # convert to BGR format
    image = np.array(pil_image)[:, :, [2, 1, 0]]
    return image


def imshow(img, caption, ShowFig = True, SaveFig = True, path = 'tmp.png'):
    """Draw `img` with `caption` underneath, then optionally save and/or show it.

    The channel axis of `img` is reordered with [2, 1, 0] before drawing, so a
    BGR array (as returned by load()) is displayed with correct colours. The
    figure is written to `path` when `SaveFig` is true and displayed when
    `ShowFig` is true; saving happens before showing.
    """
    channel_swapped = img[:, :, [2, 1, 0]]
    plt.imshow(channel_swapped)
    plt.axis("off")
    plt.figtext(
        0.5, 0.09, caption,
        wrap=True, horizontalalignment='center', fontsize=20,
    )

    if SaveFig:
        plt.savefig(path)
    if ShowFig:
        plt.show()


# def loadLocalImage(image_path):
#     pil_image = Image.open(image_path).convert("RGB")
#     # convert to BGR format
#     image = np.array(pil_image)[:, :, [2, 1, 0]]
#     return image


def loadLocalImage(image_path):
    """Open the image file at `image_path` and return it as an RGB PIL image."""
    opened = Image.open(image_path)
    return opened.convert("RGB")

    
import os
import cv2

# import tqdm
# from tqdm.notebook import tqdm
import sys
from time import sleep
from tqdm import tqdm

import imageio

def saveAsGIF(image_list, video_name='temp.gif', fps=30):
    """Write the frames in `image_list` to `video_name` with imageio at `fps` frames per second."""
    writer_options = {"fps": fps}
    imageio.mimsave(video_name, image_list, **writer_options)


def saveAsVideo(image_list, video_name='temp.avi', fps=30):
    """Encode the frames in `image_list` into a DIVX video file.

    Args:
        image_list: Non-empty sequence of image arrays. The frame size of the
            video is taken from the first element.
        video_name: Path of the video file to write.
        fps: Frames per second of the output video.

    Raises:
        IndexError: If `image_list` is empty.
    """
    # Image arrays are indexed (row, column), i.e. shape[:2] == (height, width),
    # whereas cv2.VideoWriter takes frameSize as (width, height). Passing
    # shape[:2] unchanged is only correct for square frames.
    height, width = image_list[0].shape[:2]
    out = cv2.VideoWriter(video_name, cv2.VideoWriter_fourcc(*'DIVX'), fps = fps, frameSize = (width, height))
    try:
        for img in image_list:
            out.write(img)
    finally:
        # Release the writer even if writing a frame raises.
        out.release()


In [7]:
def plotResults(original_images, texts, similarity):
    """Draw the text/image similarity matrix as an annotated heat map.

    Rows are labelled with `texts`, each entry of `original_images` is drawn as
    a thumbnail above its column, and every cell is annotated with its score
    formatted to two decimals.
    """
    n_texts = len(texts)
    # Both axis limits are derived from whichever of the two inputs is longer.
    axis_count = max(n_texts, len(original_images))

    plt.figure(figsize=(20, 14))
    plt.imshow(similarity, vmin=0.1, vmax=0.3)
    plt.yticks(range(n_texts), texts, fontsize=18)
    plt.xticks([])

    # Thumbnails occupy the band y in [-1.6, -0.6], one unit wide per column.
    for col, thumbnail in enumerate(original_images):
        left, right = col - 0.5, col + 0.5
        plt.imshow(thumbnail, extent=(left, right, -1.6, -0.6), origin="lower")

    # Annotate the cells column by column.
    n_rows, n_cols = similarity.shape[0], similarity.shape[1]
    for col in range(n_cols):
        for row in range(n_rows):
            label = f"{similarity[row, col]:.2f}"
            plt.text(col, row, label, ha="center", va="center", size=12)

    # Hide the frame around the plot.
    for side in ("left", "top", "right", "bottom"):
        plt.gca().spines[side].set_visible(False)

    plt.xlim([-0.5, axis_count - 0.5])
    plt.ylim([axis_count + 0.5, -2])

    plt.title("Cosine similarity between text and image features", size=20)
    

def plotTopKImages(original_images, similarity, top_k = 5):
    """Show the `top_k` best-scoring images in a three-column grid, best first.

    Args:
        original_images: Sequence of images indexed by position along the
            second axis of `similarity`.
        similarity: Array whose first axis has length 1; that axis is squeezed
            away, leaving one score per image.
        top_k: Number of images to show. Values larger than the number of
            scores are clamped to that number.

    Raises:
        ValueError: If `top_k` is smaller than 1 or there are no scores.
    """
    scores = similarity.squeeze(0)

    # Clamp so np.argpartition is never asked for more entries than exist.
    k = min(top_k, scores.shape[0])
    if k < 1:
        raise ValueError(
            f"Nothing to plot: top_k = {top_k}, number of scores = {scores.shape[0]}"
        )

    # argpartition selects the k largest scores without ordering them;
    # order that selection from highest to lowest afterwards.
    top_k_idx = np.argpartition(scores, -k)[-k:]
    top_k_idx = top_k_idx[np.argsort(scores[top_k_idx])[::-1]]

    top_k_similarity = scores[top_k_idx]
    print(f"top_k_similarity = {top_k_similarity}")

    # Three images per row. Ceiling division avoids an empty trailing row when
    # k is a multiple of three, and the height now scales with the row count
    # (the former `5*int(k/3)+1` gave a 1-inch-high figure for k < 3).
    n_cols = 3
    n_rows = -(-k // n_cols)
    plt.figure(figsize=(15, 5 * n_rows))

    for idx, top_idx in enumerate(top_k_idx):
        plt.subplot(n_rows, n_cols, idx + 1)
        plt.imshow(original_images[top_idx])
        plt.title(f"Similarity score = {scores[top_idx]}")
        plt.xticks([])
        plt.yticks([])

    plt.tight_layout()


def plotTopKTexts(original_images, similarity, top_k = 5):
    """Show the `top_k` best-scoring images; thin wrapper around plotTopKImages().

    NOTE(review): despite its name, this function has always ranked and plotted
    images — its body was a line-for-line copy of plotTopKImages(). It delegates
    to that function so the two cannot drift apart. Confirm whether a text-ranking
    variant was intended.

    Args:
        original_images: Forwarded to plotTopKImages().
        similarity: Forwarded to plotTopKImages().
        top_k: Forwarded to plotTopKImages().
    """
    plotTopKImages(original_images, similarity, top_k = top_k)



In [8]:
def imageResize(image, size):
    """Resize `image` to `size` using cv2.resize with its default interpolation.

    Args:
        image: Image array accepted by cv2.resize.
        size: Target size, passed to cv2.resize as its `dsize` argument.

    Returns:
        The resized image array.
    """
    # Resize the `image` argument itself; the body previously referred to an
    # undefined name `i`.
    return cv2.resize(image, size)

def addCrop(image, h, w, crop_orientation = 'height', no_crops = 5):
    """Split `image` into equal bands or tiles and resize each back to (w, h).

    Args:
        image: Three-dimensional image array, indexed (row, column, channel).
        h: Image height in pixels, used for band sizes and as the resize height.
        w: Image width in pixels, used for band sizes and as the resize width.
        crop_orientation: 'height' for horizontal bands, 'width' for vertical
            bands, 'both' for a no_crops x no_crops grid of tiles.
        no_crops: Number of bands per axis.

    Returns:
        A list of PIL images, one per crop. For 'both' the tiles are ordered
        column by column (all rows of the first column first).

    Raises:
        Exception: If `crop_orientation` is not one of the three values above.
    """
    if crop_orientation not in ('height', 'width', 'both'):
        raise Exception(f"Error! Undefined crop_orientation = {crop_orientation}")

    def _band(index, extent):
        # Slice selecting the index-th of `no_crops` equal bands of an axis of
        # length `extent`; any remainder pixels at the end are not covered.
        size = int(extent/no_crops)
        return slice(index*size, (index+1)*size)

    if crop_orientation == 'height':
        pieces = [image[_band(k, h), :, :] for k in range(no_crops)]
    elif crop_orientation == 'width':
        pieces = [image[:, _band(k, w), :] for k in range(no_crops)]
    else:
        pieces = []
        for col in range(no_crops):
            for row in range(no_crops):
                pieces.append(image[_band(row, h), _band(col, w), :])

    # Scale every piece back to the full image size, then wrap it as a PIL image.
    resized = [cv2.resize(piece, (w,h)) for piece in pieces]
    return [Image.fromarray(piece) for piece in resized]

def generateCandidateCrops(org_image, no_crops = 3):
    """Return `org_image` followed by band and tile crops of it.

    Args:
        org_image: A PIL image, or an image array indexed (row, column, channel).
        no_crops: Number of bands per axis handed to addCrop(). Defaults to 3.

    Returns:
        A list starting with `org_image` itself, followed by the PIL crops that
        addCrop() produces for the 'height', 'width' and 'both' orientations,
        in that order.
    """
    all_images = [org_image]

    image = np.array(org_image)
    if not isinstance(org_image, Image.Image):
        # Array input: keep the existing behaviour of reversing the channel axis.
        # NOTE(review): this presumes arrays arrive in BGR order (as load()
        # returns them) — confirm against callers.
        image = image[:, :, [2, 1, 0]]
    # For PIL input the channel order is left untouched: addCrop() rebuilds PIL
    # images with Image.fromarray(), so reversing the channels here would swap
    # red and blue in every crop relative to `org_image`.
    print(image.shape)

    # Unpacking into three names also rejects arrays that are not three-dimensional.
    h, w, _ = image.shape

    for orientation in ('height', 'width', 'both'):
        all_images += addCrop(image, h, w, crop_orientation = orientation, no_crops = no_crops)

    return all_images

In [11]:
def runCLIP(org_image, captions, useCanditdateCrops = True):
    """Score captions against one or more images with the loaded CLIP model.

    Relies on the notebook-level `model`, `preprocess` and `clip` objects.

    Args:
        org_image: A single image, or a list of images when
            `useCanditdateCrops` is false.
        captions: One caption string or a sequence of caption strings.
        useCanditdateCrops: When true, `org_image` is expanded with
            generateCandidateCrops() before encoding.

    Returns:
        A tuple `(original_images, texts, similarity)` where `original_images`
        is the list of images that were encoded, `texts` the list of captions
        and `similarity` a NumPy array with one row per caption and one column
        per image feature row.
    """
    if useCanditdateCrops:
        original_images = generateCandidateCrops(org_image)
    elif isinstance(org_image, list):
        original_images = org_image
    else:
        original_images = [org_image]

    # Stack the per-image tensors directly into one batch instead of going
    # through a NumPy array and back.
    image_inputs = torch.stack([preprocess(img) for img in original_images]).cuda()

    texts = [captions] if isinstance(captions, str) else captions
    text_tokens = clip.tokenize(texts).cuda()

    with torch.no_grad():
        image_features = model.encode_image(image_inputs)
        text_features = model.encode_text(text_tokens).float()

    # encode_image may return a single feature tensor or a pair of tensors.
    # Only a pair is indexed and concatenated; indexing a single tensor with
    # [0] / [1] would pick out individual rows instead.
    # NOTE(review): the pair form is inferred from the original indexing —
    # confirm against the model implementation in use.
    returns_pair = isinstance(image_features, (tuple, list))
    if returns_pair:
        print(image_features[0].shape, image_features[1].shape)
    print(text_features.shape)
    if returns_pair:
        image_features = torch.concat((image_features[0], image_features[1]), dim = 0)
    print(image_features.shape)

    # Unit-normalize both sides so the matrix product yields cosine similarities.
    image_features = image_features / image_features.norm(dim=-1, keepdim=True)
    text_features = text_features / text_features.norm(dim=-1, keepdim=True)
    similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T

    print(f"Max similarity score = {similarity.max()}")

    return original_images, texts, similarity