# **Guide to OpenAI’s CLIP – Connecting Text To Images**

Standard computer vision datasets are laborious to build and cover only a limited range of object categories, so models trained on them generalize poorly to concepts outside their label set. To overcome these labelling constraints, OpenAI designed CLIP (Contrastive Language-Image Pretraining), a neural network introduced in the paper Learning Transferable Visual Models From Natural Language Supervision.

To learn how CLIP works, please refer to [this](https://analyticsindiamag.com/hands-on-guide-to-openais-clip-connecting-text-to-images/) article.

## **Implementation**

In [None]:
!python -m pip install pip --upgrade --user -q --no-warn-script-location
!python -m pip install numpy pandas seaborn matplotlib scipy statsmodels scikit-learn tensorflow keras opencv-python pillow scikit-image torch torchvision \
    --user -q --no-warn-script-location

!python -m pip install ftfy regex --user -q


# Restart the kernel so the freshly installed packages are picked up
import IPython
IPython.Application.instance().kernel.do_shutdown(True)


In [None]:
#  MODELS = {
#      "ViT-B/32":   "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
#  }

In [None]:
# !wget {MODELS["ViT-B/32"]} -O model.pt

In [None]:
import subprocess

In [None]:
import numpy as np
import torch

print("Torch version:", torch.__version__)

In [None]:
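# torch.jit.load expects a local file, so download the weights first
MODEL_URL = "https://gitlab.com/AnalyticsIndiaMagazine/practicedatasets/-/raw/main/openai_clip/model.pt"
torch.hub.download_url_to_file(MODEL_URL, "model.pt")
model = torch.jit.load("model.pt", map_location="cpu").eval()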
input_resolution = model.input_resolution.item()
context_length = model.context_length.item()
vocab_size = model.vocab_size.item()
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size) 

### **Image Preprocessing**

In [None]:
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from PIL import Image
preprocess = Compose([
    Resize(input_resolution, interpolation=Image.BICUBIC),
    CenterCrop(input_resolution),
    ToTensor()
])
image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).cpu()
image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).cpu() 
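
The mean and standard deviation above are applied per channel later, when the features are built. The arithmetic can be sketched in plain Python as follows; the helper name `normalize_pixel` and the sample pixels are illustrative assumptions, not part of the notebook.

```python
# Per-channel statistics copied from the cell above.
IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)

def normalize_pixel(rgb):
    """Normalize one RGB pixel whose channels are already scaled to [0, 1]."""
    return tuple((value - mean) / std
                 for value, mean, std in zip(rgb, IMAGE_MEAN, IMAGE_STD))

# Hypothetical sample: a pixel exactly at the channel means maps to zero.
print(normalize_pixel(IMAGE_MEAN))  # (0.0, 0.0, 0.0)
# A pure white pixel ends up roughly two standard deviations above the mean.
print(normalize_pixel((1.0, 1.0, 1.0)))
```

The tensor version in the notebook does the same thing for every pixel at once by broadcasting the statistics over the height and width dimensions.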

### **Text Preprocessing**

In [None]:

# ! wget https://openaipublic.azureedge.net/clip/bpe_simple_vocab_16e6.txt.gz -O bpe_simple_vocab_16e6.txt.gz

The tokenizer relies on a `bytes_to_unicode` helper that returns a mapping from UTF-8 byte values to Unicode characters. The reversible BPE codes operate on Unicode strings, so avoiding unknown tokens (UNKs) would otherwise require a large number of Unicode characters in the vocabulary. With a dataset of around 10B tokens, roughly 5K characters are needed for decent coverage.

In [None]:
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re 

That is a significant share of a typical BPE vocabulary of, say, 32K entries. To avoid it, the tokenizer uses lookup tables between UTF-8 bytes and Unicode strings, and the mapping skips the whitespace and control characters that the BPE code cannot handle.

In [None]:
def bytes_to_unicode():
    bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8+n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs)) 
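
As a quick check of what this lookup table provides, the sketch below repeats the function so that it runs on its own and then round-trips a sample string. The sample text and the names `byte_encoder` and `byte_decoder` at module level are assumptions made for illustration.

```python
def bytes_to_unicode():
    # Same logic as the cell above, repeated so this sketch runs on its own.
    bs = (list(range(ord("!"), ord("~") + 1))
          + list(range(ord("¡"), ord("¬") + 1))
          + list(range(ord("®"), ord("ÿ") + 1)))
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    return dict(zip(bs, [chr(c) for c in cs]))

byte_encoder = bytes_to_unicode()
byte_decoder = {ch: b for b, ch in byte_encoder.items()}

# Every one of the 256 byte values gets its own character.
print(len(byte_encoder), len(set(byte_encoder.values())))  # 256 256

# The space byte (0x20) is remapped away from whitespace.
print(repr(byte_encoder[ord(" ")]))

# Round trip: text -> UTF-8 bytes -> mapped characters -> bytes -> text.
text = "café au lait"
mapped = "".join(byte_encoder[b] for b in text.encode("utf-8"))
restored = bytearray(byte_decoder[ch] for ch in mapped).decode("utf-8")
print(restored == text)  # True
```

Because the mapping is one-to-one, any byte sequence can be encoded and decoded without loss, which is what makes the BPE codes reversible.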

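The `get_pairs` helper returns the set of adjacent symbol pairs in a word, where a word is represented as a tuple of symbols (variable-length strings). The same cell defines two text-cleaning helpers and the `SimpleTokenizer` class.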

In [None]:
def get_pairs(word):
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs
def basic_clean(text):
    text = ftfy.fix_text(text)
    text = html.unescape(html.unescape(text))
    return text.strip()
def whitespace_clean(text):
    text = re.sub(r'\s+', ' ', text)
    text = text.strip()
    return text
class SimpleTokenizer(object):
    def __init__(self, bpe_path: str = "bpe_simple_vocab_16e6.txt.gz"):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())
        vocab = vocab + [v+'</w>' for v in vocab]
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
    def bpe(self, token):
        if token in self.cache:
            return self.cache[token]
        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)
        if not pairs:
            return token+'</w>'
        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                    new_word.extend(word[i:j])
                    i = j
                except ValueError:
                    new_word.extend(word[i:])
                    break
                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)
        word = ' '.join(word)
        self.cache[token] = word
        return word
    def encode(self, text):
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens
    def decode(self, tokens):
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text 
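
The merge loop inside `SimpleTokenizer.bpe` repeatedly joins the lowest-ranked adjacent pair until no known merge applies. A stripped-down sketch of that idea is shown below. The function `toy_bpe` and the three-entry merge table are hypothetical; the real table is read from `bpe_simple_vocab_16e6.txt.gz`.

```python
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a tuple of symbols."""
    return set(zip(word, word[1:]))

def toy_bpe(token, bpe_ranks):
    """Greedy BPE: repeatedly merge the lowest-ranked adjacent pair."""
    word = tuple(token[:-1]) + (token[-1] + "</w>",)
    while len(word) > 1:
        pairs = get_pairs(word)
        bigram = min(pairs, key=lambda p: bpe_ranks.get(p, float("inf")))
        if bigram not in bpe_ranks:
            break  # no known merge applies any more
        first, second = bigram
        merged, i = [], 0
        while i < len(word):
            if i < len(word) - 1 and word[i] == first and word[i + 1] == second:
                merged.append(first + second)
                i += 2
            else:
                merged.append(word[i])
                i += 1
        word = tuple(merged)
    return word

# Hypothetical merge table; lower rank means the merge was learned earlier.
toy_ranks = {("l", "o"): 0, ("lo", "w</w>"): 1, ("e", "r</w>"): 2}

print(toy_bpe("low", toy_ranks))    # ('low</w>',)
print(toy_bpe("lower", toy_ranks))  # ('lo', 'w', 'er</w>')
```

The `</w>` suffix marks the end of a word, so `w</w>` and `w` are different symbols. That is why `lower` stops at three pieces while `low` collapses into one.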

### **Setting up input images and texts**

In [None]:
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# images in skimage to use along with their textual descriptions
descriptions = {
    "page": "a page of text about segmentation",
    "chelsea": "a facial photo of a tabby cat",
    "astronaut": "a portrait of an astronaut with the American flag",
    "rocket": "a rocket standing on a launchpad",
    "motorcycle_right": "a red motorcycle standing in a garage",
    "camera": "a person looking at a camera on a tripod",
    "horse": "a black-and-white silhouette of a horse", 
    "coffee": "a cup of coffee on a saucer"
}
images = []
texts = []
plt.figure(figsize=(16, 5))
for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
    name = os.path.splitext(filename)[0]
    if name not in descriptions:
        continue
    image = preprocess(Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB"))
    images.append(image)
    texts.append(descriptions[name])
    plt.subplot(2, 4, len(images))
    plt.imshow(image.permute(1, 2, 0))
    plt.title(f"{filename}\n{descriptions[name]}")
    plt.xticks([])
    plt.yticks([])
plt.tight_layout() 

### **Building features**

In [None]:
image_input = torch.tensor(np.stack(images)).cpu()
image_input -= image_mean[:, None, None]
image_input /= image_std[:, None, None]
tokenizer = SimpleTokenizer()
text_tokens = [tokenizer.encode("This is " + desc + "<|endoftext|>") for desc in texts]
text_input = torch.zeros(len(text_tokens), model.context_length, dtype=torch.long)
for i, tokens in enumerate(text_tokens):
    text_input[i, :len(tokens)] = torch.tensor(tokens)
text_input = text_input.cpu()  # keep the inputs on the same device as the model
with torch.no_grad():
    image_features = model.encode_image(image_input).float()
    text_features = model.encode_text(text_input).float() 
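
The text encoder expects every prompt to have the same length, which is why the token ids are copied into a zero-filled tensor of width `context_length`. The plain-Python sketch below shows the same right-padding. The helper `pad_to_context`, the explicit length check and the token ids are assumptions for illustration only.

```python
def pad_to_context(token_lists, context_length, pad_id=0):
    """Right-pad each token list with pad_id up to context_length."""
    padded = []
    for tokens in token_lists:
        if len(tokens) > context_length:
            raise ValueError("token sequence longer than the context length")
        padded.append(tokens + [pad_id] * (context_length - len(tokens)))
    return padded

# Hypothetical token ids; real ids come from SimpleTokenizer.encode.
batch = pad_to_context([[5, 9, 2], [7, 2]], context_length=6)
print(batch)  # [[5, 9, 2, 0, 0, 0], [7, 2, 0, 0, 0, 0]]
```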

### **Calculating cosine similarity**

In [None]:
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
count = len(descriptions)
plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmax=0.3)
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(images):
    plt.imshow(image.permute(1, 2, 0), extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
    for y in range(similarity.shape[0]):
        plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
for side in ["left", "top", "right", "bottom"]:
  plt.gca().spines[side].set_visible(False)
plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])
plt.title("Cosine similarity between text and image features", size=20)
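
Once both feature sets are scaled to unit length, cosine similarity reduces to a dot product, which is what the matrix multiplication in the cell above computes for every text and image pair. The sketch below uses plain Python; the two-dimensional vectors and the helper names are hypothetical.

```python
import math

def l2_normalize(vec):
    """Scale a vector to unit length."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec]

def cosine_similarity(a, b):
    """Dot product of the two unit-length vectors."""
    return sum(x * y for x, y in zip(l2_normalize(a), l2_normalize(b)))

# Hypothetical 2-D "features"; real CLIP embeddings have hundreds of dimensions.
text_vec = [3.0, 4.0]
print(cosine_similarity(text_vec, [6.0, 8.0]))   # same direction, close to 1
print(cosine_similarity(text_vec, [-4.0, 3.0]))  # orthogonal, close to 0
```

Because the vectors are normalized first, the score depends only on direction and not on magnitude.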

## **Zero-Shot Image Classification**


In [None]:
# !wget https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz -O /root/.cache/cifar-100-python.tar.gz

In [None]:
from torchvision.datasets import CIFAR100
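# The first argument is the local directory where torchvision stores the dataset.
# A mirror of the archive is hosted at
# https://gitlab.com/AnalyticsIndiaMagazine/practicedatasets/-/raw/main/openai_clip/cifar-100-python.tar.gz
cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)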


In [None]:
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = [tokenizer.encode(desc + "<|endoftext|>") for desc in text_descriptions]
text_input = torch.zeros(len(text_tokens), model.context_length, dtype=torch.long)
for i, tokens in enumerate(text_tokens):
    text_input[i, :len(tokens)] = torch.tensor(tokens)
text_input = text_input.cpu()  # keep the inputs on the same device as the model
print(text_input.shape)  # torch.Size([100, 77])
with torch.no_grad():
    text_features = model.encode_text(text_input).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
plt.figure(figsize=(16, 16))
for i, image in enumerate(images):
    plt.subplot(4, 4, 2 * i + 1)
    plt.imshow(image.permute(1, 2, 0))
    plt.axis("off")
    plt.subplot(4, 4, 2 * i + 2)
    y = np.arange(top_probs.shape[-1])
    plt.grid()
    plt.barh(y, top_probs[i])
    plt.gca().invert_yaxis()
    plt.gca().set_axisbelow(True)
    plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
    plt.xlabel("probability")
plt.subplots_adjust(wspace=0.5)
plt.show() 
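
Zero-shot classification here amounts to scaling the image-to-text similarities by 100 and applying a softmax over the candidate labels. A self-contained sketch of that step follows. The labels and similarity values are made up for illustration and are not outputs of the model.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical cosine similarities between one image and three label prompts.
labels = ["cat", "dog", "rocket"]
similarities = [0.28, 0.22, 0.15]

# Scaling by 100 (as in the cell above) sharpens the distribution.
probs = softmax([100.0 * s for s in similarities])
best = max(range(len(labels)), key=lambda i: probs[i])
print(labels[best], [round(p, 3) for p in probs])
```

Without the scaling factor the similarities would be so close together that the softmax output would be nearly uniform. Multiplying by 100 turns small differences in similarity into a confident prediction.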

# **Related Articles:**

> * [Guide to OpenAI's CLIP](https://analyticsindiamag.com/hands-on-guide-to-openais-clip-connecting-text-to-images/)

> * [Guide to VISSL](https://analyticsindiamag.com/guide-to-vissl-vision-library-for-self-supervised-learning/)

> * [Unet Image Segmentation Model](https://analyticsindiamag.com/my-experiment-with-unet-building-an-image-segmentation-model/)

> * [Comparison of Semantic, Instance and Panoptic Segmentation](https://analyticsindiamag.com/semantic-vs-instance-vs-panoptic-which-image-segmentation-technique-to-choose/)

> * [Panoptic Segmentation](https://analyticsindiamag.com/guide-to-panoptic-segmentation-a-semantic-instance-segmentation-approach/)

> * [PaddleSeg](https://analyticsindiamag.com/guide-to-asymmetric-non-local-neural-networks-using-paddleseg/)

> * [MMDetection](https://analyticsindiamag.com/guide-to-mmdetection-an-object-detection-python-toolbox/)