# Interacting with CLIP

This is a self-contained notebook that shows how to download and run CLIP models, calculate the similarity between arbitrary image and text inputs, and perform zero-shot image classifications.

# Preparation for Colab

Make sure you're running a GPU runtime; if not, select "GPU" as the hardware accelerator in Runtime > Change Runtime Type in the menu. The next cells will print the CUDA version of the runtime if it has a GPU, and install PyTorch 1.7.1.

In [1]:
import subprocess

CUDA_version = [s for s in subprocess.check_output(["nvcc", "--version"]).decode("UTF-8").split(", ") if s.startswith("release")][0].split(" ")[-1]
print("CUDA version:", CUDA_version)

if CUDA_version == "10.0":
    torch_version_suffix = "+cu100"
elif CUDA_version == "10.1":
    torch_version_suffix = "+cu101"
elif CUDA_version == "10.2":
    torch_version_suffix = ""
else:
    torch_version_suffix = "+cu110"

CUDA version: 11.0
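The cell above only needs the `release X.Y` field of `nvcc --version` to pick a PyTorch wheel suffix. Below is a minimal sketch of the same parsing and mapping, pulled into small testable functions. The helper names and the sample `nvcc` line are illustrative and do not come from a real run.

```python
def parse_cuda_release(nvcc_output: str) -> str:
    """Return the 'X.Y' release number from `nvcc --version` output."""
    # nvcc prints a line like "Cuda compilation tools, release 11.0, V11.0.221"
    for part in nvcc_output.split(", "):
        if part.startswith("release"):
            return part.split(" ")[-1]
    raise ValueError("no 'release' field in nvcc output")


def wheel_suffix(cuda_version: str) -> str:
    """Map a CUDA release to the wheel suffix chosen by the cell above."""
    # 10.2 maps to no suffix and any other version falls back to +cu110,
    # mirroring the if/elif chain above
    return {"10.0": "+cu100", "10.1": "+cu101", "10.2": ""}.get(cuda_version, "+cu110")


sample = "Cuda compilation tools, release 11.0, V11.0.221\n"  # hypothetical sample
print(wheel_suffix(parse_cuda_release(sample)))  # +cu110
```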


In [None]:
! pip install torch==1.7.1{torch_version_suffix} torchvision==0.8.2{torch_version_suffix} -f https://download.pytorch.org/whl/torch_stable.html ftfy regex

In [3]:
import numpy as np
import torch

print("Torch version:", torch.__version__)

Torch version: 1.7.1+cu110


# Downloading the model

CLIP models are distributed as TorchScript modules.

In [4]:
MODELS = {
    "ViT-B/32":       "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
}

In [5]:
! wget {MODELS["ViT-B/32"]} -O model.pt

--2021-03-07 19:24:01--  https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt
Resolving openaipublic.azureedge.net (openaipublic.azureedge.net)... 13.107.246.19, 13.107.213.19, 2620:1ec:bdf::19, ...
Connecting to openaipublic.azureedge.net (openaipublic.azureedge.net)|13.107.246.19|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 353976522 (338M) [application/octet-stream]
Saving to: ‘model.pt’


2021-03-07 19:24:12 (31.5 MB/s) - ‘model.pt’ saved [353976522/353976522]
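The path component between `models/` and the file name in the download URL is a 64-character hex string, which is the length of a SHA-256 digest. The official `clip` package verifies its downloads against that component. Assuming the same convention holds here, you can check `model.pt` before loading it with a sketch like the following. `expected_sha256`, `file_sha256` and `verify_download` are illustrative helpers.

```python
import hashlib
from urllib.parse import urlparse


def expected_sha256(url: str) -> str:
    """Assumes the second-to-last path component of the URL is the file's SHA-256."""
    return urlparse(url).path.split("/")[-2]


def file_sha256(path: str, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so the 338 MB checkpoint is never read into memory at once."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_download(path: str, url: str) -> bool:
    """True if the file's digest matches the one embedded in the URL."""
    return file_sha256(path) == expected_sha256(url)


# usage in this notebook: verify_download("model.pt", MODELS["ViT-B/32"])
```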



In [6]:
model = torch.jit.load("model.pt").cuda().eval()
input_resolution = model.input_resolution.item()
context_length = model.context_length.item()
vocab_size = model.vocab_size.item()

print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)

Model parameters: 151,277,313
Input resolution: 224
Context length: 77
Vocab size: 49408


# Image Preprocessing

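We resize the input images and center-crop them to conform with the image resolution that the model expects. After that, the pixel intensities should be normalized with the per-channel dataset mean and standard deviation (`image_mean` and `image_std` below). `preprocess` itself stops at `ToTensor()`, so this normalization has to be applied separately to each batch.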



In [7]:
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image

preprocess = Compose([
    Resize(input_resolution, interpolation=Image.BICUBIC),
    CenterCrop(input_resolution),
    ToTensor()
])

image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cuda()
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cuda()
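You can check the geometry of the two steps by hand. The sketch below mirrors torchvision's rules for an integer `Resize` size and for `CenterCrop`:

- For `Resize`, the shorter side becomes `size` and the longer side is scaled, then truncated to an int.
- For `CenterCrop`, the offsets are rounded to the nearest pixel.

It then applies the per-channel normalization `(x - mean) / std` to a single pixel. The helper names are illustrative. The sketch covers only the arithmetic, not the bicubic interpolation.

```python
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)


def resize_then_crop_box(width, height, size=224):
    """Return the resized (w, h) and the (left, top, right, bottom) crop box."""
    # Resize(size): the shorter side becomes `size`, the aspect ratio is kept
    short, long = min(width, height), max(width, height)
    new_short, new_long = size, int(size * long / short)
    new_w, new_h = (new_short, new_long) if width <= height else (new_long, new_short)
    # CenterCrop(size): take a size x size window from the middle
    left = int(round((new_w - size) / 2.0))
    top = int(round((new_h - size) / 2.0))
    return (new_w, new_h), (left, top, left + size, top + size)


def normalize_pixel(rgb, mean=CLIP_MEAN, std=CLIP_STD):
    """Per-channel (x - mean) / std for one RGB pixel with values in [0, 1]."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


print(resize_then_crop_box(1024, 768))  # ((298, 224), (37, 0, 261, 224))
```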

# Text Preprocessing

We use a case-insensitive byte-pair encoding (BPE) tokenizer. Its code is in the second cell below (hidden behind `#@title` when the notebook is opened in Colab).

In [8]:
! pip install ftfy regex
! wget https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz -O bpe_simple_vocab_16e6.txt.gz

--2021-03-07 19:24:19--  https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz
Resolving openaipublic.azureedge.net (openaipublic.azureedge.net)... 13.107.246.19, 13.107.213.19, 2620:1ec:bdf::19, ...
Connecting to openaipublic.azureedge.net (openaipublic.azureedge.net)|13.107.246.19|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1356917 (1.3M) [application/octet-stream]
Saving to: ‘bpe_simple_vocab_16e6.txt.gz’


2021-03-07 19:24:19 (23.0 MB/s) - ‘bpe_simple_vocab_16e6.txt.gz’ saved [1356917/1356917]
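The downloaded file holds the BPE merge rules. The sizes should add up to the `vocab_size` of 49,408 printed earlier:

- `SimpleTokenizer` below keeps 48,894 merges (`merges[1:49152-256-2+1]`, skipping the header line).
- `bytes_to_unicode()` produces 256 byte symbols.
- Each byte symbol also gets an end-of-word (`</w>`) variant, adding another 256 entries.
- Two special tokens complete the vocabulary.

This sketch re-derives the byte table with the same construction. It shows that every byte, including the space, maps to a visible, non-whitespace character. The names are illustrative.

```python
def byte_table():
    """Same construction as bytes_to_unicode() in the tokenizer cell."""
    # printable Latin-1 bytes map to themselves
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    # the remaining bytes (control characters, space, ...) move to code points >= 256
    for b in range(256):
        if b not in bs:
            bs.append(b)
            cs.append(256 + n)
            n += 1
    return dict(zip(bs, map(chr, cs)))


table = byte_table()
n_merges = len(range(1, 49152 - 256 - 2 + 1))  # length of the slice merges[1:48895]
# byte symbols + their '</w>' variants + merges + 2 special tokens
expected_vocab_size = len(table) + len(table) + n_merges + 2
print(n_merges, expected_vocab_size)  # 48894 49408
```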



In [9]:
#@title

import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def bytes_to_unicode():
    """
    Returns a dict mapping utf-8 bytes to unicode strings.
    The reversible bpe codes work on unicode strings.
    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
    This is a significant percentage of your normal, say, 32K bpe vocab.
    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
    It also avoids mapping to whitespace/control characters the bpe code barfs on.
    """
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()


def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text


class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = "bpe_simple_vocab_16e6.txt.gz"):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
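        # skip the header line and keep 48,894 merges, so that 256 byte symbols
        # + 256 '</w>' variants + merges + 2 special tokens = 49,408 (vocab_size)
        merges = merges[1:49152-256-2+1]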
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)

        if not pairs:
            return token+'</w>'

        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:  # `first` does not occur at or after position i
                    new_word.extend(word[i:])
                    break

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text
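`SimpleTokenizer.bpe` is built around a greedy loop: it repeatedly merges the adjacent pair with the lowest rank in `bpe_ranks` until no known pair remains. A compact re-implementation of that loop, run on a made-up three-rule merge table, shows the mechanics. The real table has 48,894 rules. Unlike `bpe`, this sketch returns a tuple of symbols instead of a space-joined string.

```python
def toy_bpe(token, ranks):
    """Greedy BPE as in SimpleTokenizer.bpe, returning the final symbol tuple."""
    # split into characters; the last one carries the end-of-word marker
    word = tuple(token[:-1]) + (token[-1] + "</w>",)
    while len(word) > 1:
        pairs = set(zip(word, word[1:]))
        best = min(pairs, key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no adjacent pair is a known merge
        merged, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and (word[i], word[i + 1]) == best:
                merged.append(word[i] + word[i + 1])
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = tuple(merged)
    return word


# hypothetical merge table: the lowest rank is applied first
ranks = {("l", "o"): 0, ("lo", "w</w>"): 1, ("e", "r</w>"): 2}
print(toy_bpe("low", ranks))    # ('low</w>',)
print(toy_bpe("lower", ranks))  # ('lo', 'w', 'er</w>')
```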


# Zero-shot classification on the Oxford Buildings dataset

#### Getting the dataset

In [10]:
%mkdir data
%cd data
%mkdir buildings
# If you run this notebook on your own machine, change the path below
# to the directory where you want to store the dataset
%cd /content/data/buildings
!wget https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/oxbuild_images.tgz

/content/data
/content/data/buildings
--2021-03-07 19:24:20--  https://www.robots.ox.ac.uk/~vgg/data/oxbuildings/oxbuild_images.tgz
Resolving www.robots.ox.ac.uk (www.robots.ox.ac.uk)... 129.67.94.2
Connecting to www.robots.ox.ac.uk (www.robots.ox.ac.uk)|129.67.94.2|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1980280437 (1.8G) [application/x-gzip]
Saving to: ‘oxbuild_images.tgz’


2021-03-07 19:25:49 (21.4 MB/s) - ‘oxbuild_images.tgz’ saved [1980280437/1980280437]



In [None]:
!tar -zxvf oxbuild_images.tgz

In [12]:
# create directory if it does not exist
def check_dir(dir_path):
    dir_path = dir_path.replace('//','/')
    os.makedirs(dir_path, exist_ok=True)

In [13]:
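import os
import shutil
from os import listdir
from os.path import join

data_dir = '/content/data'

# Image files are named like 'all_souls_000013.jpg'; dropping the last 11
# characters ('_000013.jpg') yields the class name. The archive itself
# ('oxbuild_images.tgz') is also in this folder, so it ends up in
# '/content/data/oxbuild', which is moved out of the way in the next cell.
for f in listdir(join(data_dir, 'buildings')):
    # create the class directory if necessary
    path = join(data_dir, f[:-11])
    check_dir(path)
    # move (not copy) the file into its class directory
    src_path = join(data_dir, 'buildings', f)
    targ_path = join(path, f)
    shutil.move(src_path, targ_path)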
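Slicing off a fixed 11 characters works because every image name appears to end in an underscore, six digits and `.jpg`. That naming scheme is an assumption read off the `f[:-11]` slice. The same slice also turns the archive name into a bogus `oxbuild` class, which then has to be moved away. A stricter alternative moves only files that match the expected pattern. This is a sketch, and the helper names are hypothetical.

```python
import os
import re
import shutil

# assumed naming scheme: <class name>_<six digits>.jpg, e.g. all_souls_000013.jpg
IMAGE_NAME = re.compile(r"^(?P<cls>.+)_\d{6}\.jpg$")


def class_from_filename(name):
    """Return the class name encoded in an image file name, or None."""
    match = IMAGE_NAME.match(name)
    return match.group("cls") if match else None


def organize(src_dir, dst_root):
    """Move every matching image into dst_root/<class>/, leaving other files alone."""
    for name in os.listdir(src_dir):
        cls = class_from_filename(name)
        if cls is None:
            continue  # e.g. the downloaded .tgz archive
        os.makedirs(os.path.join(dst_root, cls), exist_ok=True)
        shutil.move(os.path.join(src_dir, name), os.path.join(dst_root, cls, name))
```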

In [14]:
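# Move the folder holding the archive out of data_dir so that ImageFolder
# does not treat it as a class
try:
    shutil.move('/content/data/oxbuild', '/content/sample_data')
    print('The directory was moved successfully')
except (FileNotFoundError, shutil.Error):
    print('The directory has already been moved')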

The directory was moved successfully


In [15]:
from pathlib import Path

In [16]:
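# Move the now-empty download folder out of data_dir as well, so that
# ImageFolder only sees the class folders
try:
    shutil.move('/content/data/buildings', '/content/sample_data')
    print('The directory was moved successfully')
except (FileNotFoundError, shutil.Error):
    print('The directory has already been moved')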

The directory was moved successfully


In [17]:
!pwd

/content/sample_data/buildings


In [18]:
from torchvision import datasets

In [19]:
data_dir = '/content/data'
# Each subfolder of data_dir becomes one class; every image is resized,
# center-cropped and converted to a tensor by `preprocess`
dsets = datasets.ImageFolder(data_dir, preprocess)
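`ImageFolder` turns each subfolder into a class and numbers the classes by sorting the folder names. So `label == 0` means `all_souls`, `label == 1` means `ashmolean`, and so on. Below is a stdlib-only sketch of that indexing rule. It is an illustrative re-implementation, not torchvision's own code.

```python
import os


def find_classes(root):
    """Sorted subfolder names and their label indices, as ImageFolder assigns them."""
    classes = sorted(entry.name for entry in os.scandir(root) if entry.is_dir())
    return classes, {name: index for index, name in enumerate(classes)}


# e.g. classes, class_to_idx = find_classes('/content/data')
```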

In [20]:
%ls /content/data

[0m[01;34mall_souls[0m/  [01;34mbodleian[0m/       [01;34mhertford[0m/  [01;34mmagdalen[0m/  [01;34moxford[0m/            [01;34mtrinity[0m/
[01;34mashmolean[0m/  [01;34mchrist_church[0m/  [01;34mjesus[0m/     [01;34mnew[0m/       [01;34mpitt_rivers[0m/       [01;34mworcester[0m/
[01;34mballiol[0m/    [01;34mcornmarket[0m/     [01;34mkeble[0m/     [01;34moriel[0m/     [01;34mradcliffe_camera[0m/


In [21]:
data_loader = torch.utils.data.DataLoader(dsets, batch_size=64, shuffle=True, num_workers=6)

In [22]:
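# Build one text prompt per class. With from_text=True the prompts are read
# from /content/oxford2.txt, a description file that this notebook does not
# create and that must be uploaded first. Its entries must appear in the same
# (alphabetical) order as dsets.classes, because the labels produced by the
# data loader index into that list.
from_text = True
text_descriptions = {}
if from_text:
    s = ""
    with open("/content/oxford2.txt") as fp:
        lines = fp.readlines()
        for line in lines:
            if line.startswith('---'):
                # a header line starts a new building; store the previous one
                if s != "":
                    text_descriptions.update({building: s})
                    print("Line {}: {}".format(building, s))
                s = ""
                building = line[3:-4]
            else:
                s += ' ' + line[:-2]
    text_descriptions.update({building: s})
else:
    classes = [s.replace('_', ' ') for s in dsets.classes]
    text_descriptions = {label: f"This is a photo of {label}" for label in classes}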

Line All Souls College:  All Souls College is a constituent college of the University of Oxford in England. The college entrance is on the north side of the High Street whilst it has a long frontage onto Radcliffe Square. To its east is The Queen's College whilst Hertford College is to the north of All Souls
Line The Ashmolean Museum:  The Ashmolean Museum of Art and Archaeology on Beaumont Street, Oxford, England, is the world's second university museum  and Britain's first public museum. Its first building was erected in 1678–1683 to house the cabinet of curiosities that Elias Ashmole gave to the University of Oxford in 1677
Line Balliol College:  Balliol College is one of the constituent colleges of the University of Oxford in England. One of Oxford's oldest colleges, it was founded around 1263 by John I de Balliol, a rich landowner from Barnard Castle in County Durham, who provided the foundation and endowment for the college. When de Balliol died in 1269 his widow, Dervorguilla, a

In [23]:
print(*dsets.classes, sep='\n')

all_souls
ashmolean
balliol
bodleian
christ_church
cornmarket
hertford
jesus
keble
magdalen
new
oriel
oxford
pitt_rivers
radcliffe_camera
trinity
worcester
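Row `i` of `text_features` comes from the `i`-th entry of `text_descriptions`, and that row index is compared directly with the `ImageFolder` label. The descriptions must therefore follow exactly this class order. Reading them from a file keyed by display names ('All Souls College', 'The Ashmolean Museum', ...) leaves that order implicit. The sketch below pins the order explicitly through a hand-written mapping from folder names to description keys. `name_map` and `ordered_descriptions` are hypothetical.

```python
def ordered_descriptions(class_names, descriptions, name_map):
    """Return one description per class, in the order of class_names."""
    missing = [c for c in class_names if name_map.get(c) not in descriptions]
    if missing:
        raise KeyError(f"no description for classes: {missing}")
    return [descriptions[name_map[c]] for c in class_names]


# hypothetical mapping; it would need one entry per folder in dsets.classes
name_map = {"all_souls": "All Souls College", "ashmolean": "The Ashmolean Museum"}
```

With a complete `name_map`, you could build `text_tokens` from `ordered_descriptions(dsets.classes, text_descriptions, name_map)` instead of `text_descriptions.items()`. Any class without a description then fails loudly instead of silently shifting the labels.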


In [None]:
text_descriptions

In [25]:
# text_descriptions['All Souls College']

In [26]:
!pwd

/content/sample_data/buildings


In [27]:
%cd /content/

/content


In [33]:
tokenizer = SimpleTokenizer()

sot_token = tokenizer.encoder['<|startoftext|>']
eot_token = tokenizer.encoder['<|endoftext|>']

# Wrap each description in start/end-of-text tokens; descriptions are cut
# to 65 BPE tokens so that every sequence fits into the 77-token context
text_tokens = [[sot_token] + tokenizer.encode(desc)[:65] + [eot_token]
               for name, desc in text_descriptions.items()]
# zero-padded token matrix, one row per class
text_input = torch.zeros(len(text_tokens), context_length, dtype=torch.long)

for i, tokens in enumerate(text_tokens):
    text_input[i, :len(tokens)] = torch.tensor(tokens)

text_input = text_input.cuda()
text_input.shape

torch.Size([17, 77])
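Each row of `text_input` is `[SOT] + tokens + [EOT]`, zero-padded to the 77-token context. In the reference CLIP code, the text encoder reads the feature at the position of the highest token id, which is the end-of-text token. Truncation must therefore never drop EOT. Below is a plain-Python sketch of the same layout. `pad_tokens` is hypothetical, and the content token ids in the example are made up. The sketch caps the content at `context_length - 2` tokens, the most that fits, whereas the cell above uses a stricter cap of 65.

```python
def pad_tokens(token_lists, sot, eot, context_length=77):
    """Wrap each token list in SOT/EOT, truncate, and zero-pad to context_length."""
    rows = []
    for tokens in token_lists:
        # keep room for the two special tokens so EOT is never cut off
        seq = [sot] + tokens[: context_length - 2] + [eot]
        rows.append(seq + [0] * (context_length - len(seq)))
    return rows


# the special tokens are the last two vocabulary entries: 49406 and 49407
print(pad_tokens([[11, 22]], 49406, 49407, context_length=6))
# [[49406, 11, 22, 49407, 0, 0]]
```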

In [34]:
with torch.no_grad():
    text_features = model.encode_text(text_input).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
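Dividing by the L2 norm puts every feature vector on the unit sphere. A plain dot product between an image feature and a text feature is then their cosine similarity. Here is a NumPy sketch with toy 2-D vectors; the real features have many more dimensions, and the names are illustrative.

```python
import numpy as np


def l2_normalize(x, eps=1e-12):
    """Scale each row of x to unit length (eps guards against zero rows)."""
    return x / np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), eps)


def similarity_logits(image_feats, text_feats, scale=100.0):
    """Scaled cosine similarity between every image row and every text row."""
    return scale * l2_normalize(image_feats) @ l2_normalize(text_feats).T


image = np.array([[3.0, 4.0]])
texts = np.array([[6.0, 8.0], [4.0, -3.0]])  # same direction / orthogonal to the image
print(similarity_logits(image, texts))
```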

In [35]:
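import matplotlib.pyplot as plt

def imshow(inp, title=None, mean=None, std=None):
    """Show a CHW image tensor.

    `preprocess` does not normalize, so tensors from data_loader are already
    in [0, 1]. Pass mean/std (e.g. image_mean.cpu().numpy()) only to undo a
    normalization step.
    """
    inp = inp.cpu().numpy().transpose((1, 2, 0))
    if mean is not None and std is not None:
        inp = np.asarray(std) * inp + np.asarray(mean)
    plt.imshow(np.clip(inp, 0, 1))
    if title is not None:
        plt.title(title)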

In [36]:
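acc, acc_top5, count = 0, 0, 0
for i, (x, label) in enumerate(data_loader):

    with torch.no_grad():
        image_input = x.cuda()  # x is already a batched (B, 3, H, W) tensor
        # NOTE: normalization is commented out, so the model sees unnormalized
        # pixels and the accuracies below were measured that way; uncomment
        # the two lines to apply the mean/std normalization
        # image_input -= image_mean[:, None, None]
        # image_input /= image_std[:, None, None]
        image_features = model.encode_image(image_input).float()
        # unit-normalize so the dot product below is a cosine similarity
        # (this does not change the ranking, hence not the accuracies)
        image_features /= image_features.norm(dim=-1, keepdim=True)

    text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)

    # top-1: the highest-scoring description belongs to the true class
    acc += (top_labels[:, 0] == label).sum().item()
    count += len(label)
    # top-5: the true class is among the five highest-scoring descriptions
    for k in range(5):
        acc_top5 += (top_labels[:, k] == label).sum().item()

    if i % 10 == 0:
        print('%3d' % i, 'batches processed |',
              'Accuracy is %.3f |' % (acc / count),
              'Top 5 accuracy is %.3f' % (acc_top5 / count))
print('-' * 85)
print('Total accuracy is %.3f |' % (acc / count),
      'Top 5 accuracy is %.3f' % (acc_top5 / count))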

  0 batches processed | Accuracy is 0.266 | Top 5 accuracy is 0.594
 10 batches processed | Accuracy is 0.298 | Top 5 accuracy is 0.672
 20 batches processed | Accuracy is 0.313 | Top 5 accuracy is 0.699
 30 batches processed | Accuracy is 0.315 | Top 5 accuracy is 0.704
 40 batches processed | Accuracy is 0.318 | Top 5 accuracy is 0.707
 50 batches processed | Accuracy is 0.313 | Top 5 accuracy is 0.704
 60 batches processed | Accuracy is 0.314 | Top 5 accuracy is 0.706
 70 batches processed | Accuracy is 0.310 | Top 5 accuracy is 0.704
-------------------------------------------------------------------------------------
Total accuracy is 0.311 | Top 5 accuracy is 0.705
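The accuracy bookkeeping in the loop reduces to two counts:

- Top-1 counts the rows whose highest score is the true label.
- Top-5 counts the rows where the true label is among the five highest scores.

Softmax is monotonic, and multiplying a row by a positive constant preserves its order. The same top-k sets therefore come out of raw logits, of probabilities, or of unnormalized image features. Below is a NumPy sketch with made-up scores; `topk_accuracy` is illustrative.

```python
import numpy as np


def softmax(logits):
    """Row-wise softmax, shifted by the row max for numerical stability."""
    z = np.exp(logits - logits.max(axis=1, keepdims=True))
    return z / z.sum(axis=1, keepdims=True)


def topk_accuracy(scores, labels, k):
    """Fraction of rows whose true label is among the k highest scores."""
    topk = np.argsort(-scores, axis=1)[:, :k]
    return float(np.mean(np.any(topk == labels[:, None], axis=1)))


logits = np.array([[0.1, 0.9, 0.0], [0.8, 0.05, 0.15]])
labels = np.array([1, 2])
print(topk_accuracy(logits, labels, 1), topk_accuracy(logits, labels, 2))  # 0.5 1.0
```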


In [37]:
print('Accuracy is %.3f |' %(acc / count), 
      'Top 5 accuracy is %.3f' %(acc_top5 / count))

Accuracy is 0.311 | Top 5 accuracy is 0.705
