## **CLIP Model Implementation**

- Multi-Head Attention (custom implementation; the original repository uses PyTorch's `nn.MultiheadAttention`)
- Transformer Architecture
- Text Transformer
- Vision Transformer
- Contrastive Loss

Acknowledgement: [CLIP's Repository](https://github.com/openai/CLIP/tree/main)

In [21]:
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns

import pandas as pd
import torch.nn as nn
from typing import Any, Optional, Tuple, Union

!pip install ftfy regex tqdm



In [22]:
def tokenize(text):
  """Return the BPE token ids for ``text`` (no start/end markers), via the global ``tokenizer``."""
  return tokenizer.encode(text)

def tokenize_text(texts, context_len): # [B, "text string"]
  """Tokenize one string or a list of strings into a fixed-width id tensor.

  Each text becomes ``[<|startoftext|>] + BPE ids + [<|endoftext|>]`` and is
  right-padded with 0 up to ``context_len``. A sequence longer than
  ``context_len`` is truncated and its last slot overwritten with
  ``<|endoftext|>``, so every row contains the end-of-text id.

  NOTE: this function is redefined identically in a later cell; the later
  definition is the one in effect after both cells run.

  Args:
    texts: a single string or a list of strings.
    context_len: number of token slots per row.

  Returns:
    A ``torch.int`` tensor of shape (len(texts), context_len).
  """
  if isinstance(texts, str):
    texts = [texts]

  sot_token = tokenizer.encoder["<|startoftext|>"]
  eot_token = tokenizer.encoder["<|endoftext|>"]
  batch_tokens = [[sot_token] + tokenize(text) + [eot_token] for text in texts]

  result = torch.zeros(len(batch_tokens), context_len, dtype=torch.int)

  for i, tokens in enumerate(batch_tokens):
    if len(tokens) > context_len:
      tokens = tokens[:context_len]
      tokens[-1] = eot_token  # keep the end-of-text marker after truncation
    result[i, :len(tokens)] = torch.tensor(tokens)

  return result

In [23]:
class QuickGELU(nn.Module):
  """Sigmoid approximation of GELU used by CLIP: ``x * sigmoid(1.702 * x)``."""

  def forward(self, x):
    gate = torch.sigmoid(x * 1.702)
    return x * gate

In [24]:
class MHAttention(nn.Module):
  """Multi-head scaled dot-product self-attention on batch-first input.

  Hand-written replacement for the ``nn.MultiheadAttention`` used by the
  original CLIP repository. ``forward`` takes x of shape (N, L, D) and returns
  (output of shape (N, L, D), attention probabilities of shape (N, h, L, L)
  or None).

  Args:
    d_model: model width D; must be divisible by ``num_heads``.
    num_heads: number of heads h; each head has width d = D // h.
    attn_mask: optional additive mask that ``forward`` uses whenever no mask
      is passed explicitly. (This argument used to be accepted and ignored.)
    dropout: dropout probability applied to the attention probabilities in
      training mode only. Defaults to 0.1, the value previously hard-coded.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``num_heads``.
  """

  def __init__(self, d_model, num_heads, attn_mask=None, dropout=0.1):
    super().__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads

    if self.head_dim * self.num_heads != self.d_model:
      raise ValueError(
          f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.d_model} and `num_heads`:"
          f" {self.num_heads})."
      )

    self.scale = self.head_dim ** -0.5  # 1 / sqrt(d), folded into the queries
    self.dropout = dropout
    self.attn_mask = attn_mask

    self.q_proj = nn.Linear(d_model, d_model)
    self.k_proj = nn.Linear(d_model, d_model)
    self.v_proj = nn.Linear(d_model, d_model)
    self.out_proj = nn.Linear(d_model, d_model)  # input width is h * d == D

  def _shape(self, tensor, seq_len, bsz):
    """Split (N, L, D) into heads and return a contiguous (N, h, L, d) tensor."""
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

  def forward(self, x, attn_mask=None, output_attentions=True):
    """Self-attend over ``x``.

    Args:
      x: input of shape (N, L, D).
      attn_mask: additive mask broadcastable to (N, h, L, L), e.g. an (L, L)
        tensor with -inf at disallowed positions. When None, the mask given
        at construction (if any) is used.
      output_attentions: when True, also return the (N, h, L, L) attention
        probabilities (taken before dropout).

    Returns:
      Tuple of (output of shape (N, L, D), attention probabilities or None).
    """
    bsz, ctx_len, _ = x.size()
    if attn_mask is None:
      attn_mask = self.attn_mask
    proj_shape = (bsz * self.num_heads, -1, self.head_dim)

    # Project, split into heads, then fold the heads into the batch: (N*h, L, d).
    queries = self._shape(self.q_proj(x) * self.scale, ctx_len, bsz).view(*proj_shape)
    keys = self._shape(self.k_proj(x), ctx_len, bsz).view(*proj_shape)
    values = self._shape(self.v_proj(x), ctx_len, bsz).view(*proj_shape)

    attn_weights = torch.bmm(queries, keys.transpose(1, 2))  # (N*h, L, L)

    if attn_mask is not None:
      # Additive mask: -inf entries become probability 0 after the softmax.
      attn_mask = attn_mask.to(attn_weights.device)
      attn_weights = attn_weights.view(bsz, self.num_heads, ctx_len, ctx_len) + attn_mask
      attn_weights = attn_weights.view(bsz * self.num_heads, ctx_len, ctx_len)

    attn_weights = nn.functional.softmax(attn_weights, dim=-1)

    if output_attentions:
      # Kept from the reference implementation, whose comment states that
      # this double reshape is needed so attn_weights keeps its gradient.
      attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, ctx_len, ctx_len)
      attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, ctx_len, ctx_len)
    else:
      attn_weights_reshaped = None

    attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
    attn_outputs = torch.bmm(attn_probs, values)  # (N*h, L, d)

    # Unfold the heads and concatenate them back into (N, L, D).
    attn_outputs = attn_outputs.contiguous().view(bsz, self.num_heads, ctx_len, self.head_dim)
    attn_outputs = attn_outputs.transpose(1, 2).contiguous().view(bsz, ctx_len, self.d_model)

    return self.out_proj(attn_outputs), attn_weights_reshaped


In [25]:
class ResidualAttentionBlock(nn.Module):
  """Pre-LayerNorm transformer block on (N, L, D) input.

  Computes ``x + attn(ln_1(x))`` followed by ``x + mlp(ln_2(x))``. The MLP
  widens to ``4 * d_model`` with QuickGELU in between. ``attn_mask`` (if any)
  is moved to the input's dtype/device on every call and passed to the
  attention layer.
  """

  def __init__(self, d_model, n_heads, attn_mask=None):
    super().__init__()
    hidden = 4 * d_model
    self.attn_block = MHAttention(d_model, n_heads)
    self.ln_1 = nn.LayerNorm(d_model)
    self.mlp = nn.Sequential(nn.Linear(d_model, hidden), QuickGELU(), nn.Linear(hidden, d_model))
    self.ln_2 = nn.LayerNorm(d_model)
    self.attn_mask = attn_mask

  def attention(self, x):
    """Run self-attention on ``x`` with the stored mask; return only the output tensor."""
    if self.attn_mask is not None:
      self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device)
    out, _ = self.attn_block(x, attn_mask=self.attn_mask, output_attentions=False)
    return out

  def forward(self, x): # NLD
    x = x + self.attention(self.ln_1(x))
    return x + self.mlp(self.ln_2(x))

class Transformer(nn.Module):
  """A stack of ``layers`` ResidualAttentionBlocks of width ``width``.

  Every block receives the same ``attn_mask`` object. Input and output are
  (N, L, width).
  """

  def __init__(self, width, layers, heads, attn_mask=None):
    super().__init__()
    blocks = []
    for _ in range(layers):
      blocks.append(ResidualAttentionBlock(width, heads, attn_mask=attn_mask))
    self.model = nn.Sequential(*blocks)

  def forward(self, x): # x: NLD
    out = self.model(x)
    return out

In [26]:
class TextTransformer(nn.Module):
  """CLIP text encoder: token + positional embeddings, Transformer, final LN,
  then projection of the end-of-text token's feature into the joint space.

  Args:
    vocab_size: number of token ids.
    context_len: sequence length L; ``forward`` expects tokens of shape
      (N, context_len) because the positional table has exactly that many rows.
    width: Transformer width.
    layers: number of residual attention blocks.
    heads: number of attention heads.
    output_dim: dimension of the projected text embedding.
    attn_mask: additive attention mask passed to every block (e.g. causal).
  """

  def __init__(self,
               vocab_size,
               context_len,
               width,
               layers,
               heads,
               output_dim,
               attn_mask=None):
    super().__init__()
    self.token_embedding = nn.Embedding(vocab_size, width)
    # torch.empty would leave these parameters as uninitialized memory
    # (arbitrary values, possibly NaN/inf). Draw them explicitly, using the
    # standard deviations of the reference CLIP initialize_parameters().
    self.position_embedding = nn.Parameter(0.01 * torch.randn(context_len, width))

    self.transformer = Transformer(width, layers, heads, attn_mask)

    self.ln_final = nn.LayerNorm(width)
    self.text_projection = nn.Parameter(width ** -0.5 * torch.randn(width, output_dim))

  def forward(self, tokens):
    """Encode token ids of shape (N, context_len) into (N, output_dim)."""
    x = self.token_embedding(tokens)
    x = x + self.position_embedding  # (N, L, width)

    x = self.transformer(x)  # (N, L, width)
    x = self.ln_final(x)

    # Take the feature at the end-of-text token as the sentence summary.
    # argmax finds it because <|endoftext|> is the last (largest) id that
    # SimpleTokenizer adds to its vocabulary; padding is 0.
    x = x[torch.arange(x.shape[0]), tokens.argmax(dim=-1)]  # (N, width)
    x = x @ self.text_projection  # project into the CLIP embedding space

    return x

class VisionTransformer(nn.Module):
  """ViT image encoder: patchify with a strided conv, prepend a class token,
  run a Transformer, and project the class token's output.

  Args:
    input_resolution: side length in pixels of the square input image.
    patch_size: side length of each square patch (conv kernel and stride).
    width: Transformer width.
    layers: number of residual attention blocks.
    heads: number of attention heads.
    output_dim: dimension of the projected image embedding.
  """

  def __init__(self, input_resolution, patch_size, width, layers, heads, output_dim):
    super().__init__()
    self.input_resolution = input_resolution
    self.patch_size = patch_size
    self.output_dim = output_dim

    self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

    self.scale = width ** -0.5
    self.width = width
    self.class_embedding = nn.Parameter(self.scale * torch.randn(width))  # learned CLS token
    # One row per patch plus one for the class token.
    self.positional_embedding = nn.Parameter(self.scale * torch.randn((input_resolution // patch_size)**2 + 1, width))
    # The reference CLIP also defines an ``ln_pre`` LayerNorm; it is omitted here.

    self.transformer = Transformer(width, layers, heads)
    self.ln_post = nn.LayerNorm(width)
    self.proj = nn.Parameter(self.scale * torch.randn(width, output_dim))

  def forward(self, x):
    """Encode images of shape (N, 3, H, W) into (N, output_dim).

    The patch grid must have (input_resolution // patch_size) ** 2 cells to
    match ``positional_embedding``.
    """
    x = self.conv1(x)  # (N, width, grid, grid), grid = input_resolution // patch_size
    x = x.reshape(x.shape[0], x.shape[1], -1)  # (N, width, n_patches)
    x = x.permute(0, 2, 1)  # (N, n_patches, width)
    # Build the class token on x's dtype/device: a bare torch.zeros(...) is a
    # CPU tensor in the default dtype and fails for GPU or half inputs.
    cls_tokens = torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device) + self.class_embedding.to(x.dtype)
    x = torch.cat([cls_tokens, x], dim=1)  # (N, n_patches + 1, width)
    x = x + self.positional_embedding.to(x.dtype)
    # No ln_pre here: each block's pre-LN normalizes only its branch input,
    # so the residual stream entering the Transformer is not normalized.

    x = self.transformer(x)  # (N, n_patches + 1, width)
    x = self.ln_post(x[:, 0, :])  # class-token output, (N, width)
    x = x @ self.proj

    return x


In [27]:
def build_causal_attn_mask(ctx_len):
  """Return a (ctx_len, ctx_len) additive causal mask.

  Entries on and below the diagonal are 0 and entries above it are -inf, so
  position i can attend only to positions <= i.
  """
  full = torch.full((ctx_len, ctx_len), float('-inf'))
  return full.triu(1)

class CLIP(nn.Module):
  """CLIP dual encoder: a text transformer and a vision transformer.

  ``encode_text`` maps token ids (N, context_len) to (N, text_output_dim);
  ``encode_image`` maps images to (N, vit_output_dim). ``attn_mask`` is
  passed to the text tower only.
  """

  def __init__(self,
               vocab_size,  # text tower
               context_len,
               text_width,
               text_layers,
               text_heads,
               text_output_dim,
               attn_mask,
               image_resolution,  # image tower
               patch_size,
               vit_width,
               vit_layers,
               vit_heads,
               vit_output_dim
               ):
    super().__init__()

    self.text_transformer = TextTransformer(
        vocab_size=vocab_size,
        context_len=context_len,
        width=text_width,
        layers=text_layers,
        heads=text_heads,
        output_dim=text_output_dim,
        attn_mask=attn_mask,
    )
    self.vit = VisionTransformer(
        input_resolution=image_resolution,
        patch_size=patch_size,
        width=vit_width,
        layers=vit_layers,
        heads=vit_heads,
        output_dim=vit_output_dim,
    )

  def encode_text(self, tokens):
    """Return text embeddings for a batch of token-id tensors."""
    return self.text_transformer(tokens)

  def encode_image(self, images):
    """Return image embeddings for a batch of preprocessed image tensors."""
    return self.vit(images)


In [28]:

from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC


In [29]:
def _convert_to_rgb(image):
  return image.convert("RGB")

def _transform(image_res):
  """Build the CLIP image preprocessing pipeline for ``image_res``-pixel inputs.

  Steps: bicubic resize, center crop to image_res x image_res, RGB conversion,
  conversion to a tensor, then per-channel normalization with CLIP's mean/std.
  """
  clip_mean = (0.48145466, 0.4578275, 0.40821073)
  clip_std = (0.26862954, 0.26130258, 0.27577711)
  steps = [
      Resize(image_res, interpolation=BICUBIC),
      CenterCrop(image_res),
      _convert_to_rgb,
      ToTensor(),
      Normalize(clip_mean, clip_std),
  ]
  return Compose(steps)

In [30]:
import gzip
import html
import os
from functools import lru_cache

import ftfy
import regex as re


@lru_cache()
def default_bpe():
    """Return the file name of the BPE merges archive (a relative path, resolved against the working directory)."""
    bpe_filename = "bpe_simple_vocab_16e6.txt.gz"
    return bpe_filename


@lru_cache()
def bytes_to_unicode():
    """Return a dict mapping every byte value (0-255) to a distinct printable character.

    Bytes that are already visible, non-whitespace Latin-1 characters map to
    themselves. Every other byte (control characters, space, etc.) is mapped,
    in ascending byte order, to code points starting at 256. The BPE can then
    operate on strings without meeting whitespace or control characters while
    staying reversible, and without needing UNK tokens.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    printable_set = set(printable)
    byte_values = list(printable)
    code_points = list(printable)
    extra = 0
    for byte in range(256):
        if byte in printable_set:
            continue
        byte_values.append(byte)
        code_points.append(256 + extra)
        extra += 1
    return {b: chr(c) for b, c in zip(byte_values, code_points)}


def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    ``word`` is a sequence of symbols (in this file, a tuple of
    variable-length strings). Sequences with fewer than two symbols yield an
    empty set; previously an empty sequence raised IndexError.
    """
    return set(zip(word, word[1:]))


def basic_clean(text):
    """Repair mis-encoded text with ftfy, HTML-unescape it twice, and strip surrounding whitespace.

    Unescaping twice also decodes doubly escaped entities such as ``&amp;amp;``.
    """
    cleaned = ftfy.fix_text(text)
    for _ in range(2):
        cleaned = html.unescape(cleaned)
    return cleaned.strip()


def whitespace_clean(text):
    """Collapse each run of whitespace into a single space and trim both ends."""
    return re.sub(r'\s+', ' ', text).strip()


class SimpleTokenizer(object):
    """Lower-casing byte-level BPE tokenizer with CLIP's vocabulary layout.

    Vocabulary ids are assigned in this order:
    1. the 256 byte symbols from ``bytes_to_unicode``;
    2. the same 256 symbols with a ``</w>`` (end-of-word) suffix;
    3. one entry per merge rule read from ``bpe_path``;
    4. ``<|startoftext|>`` and ``<|endoftext|>``.
    As a result, ``<|endoftext|>`` has the largest id.

    Args:
        bpe_path: gzip-compressed merges file, one space-separated pair per
            line after a header line. The default value is computed once,
            when the class is defined.
    """

    def __init__(self, bpe_path: str = default_bpe()):
        self.byte_encoder = bytes_to_unicode()
        self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
        # Use a context manager so the file handle is closed promptly.
        with gzip.open(bpe_path) as bpe_file:
            merges = bpe_file.read().decode("utf-8").split('\n')
        # Skip the header line and keep at most 49152 - 256 - 2 = 48894
        # merges. The vocabulary is then at most 2 * 256 + 48894 + 2 = 49408.
        merges = merges[1:49152-256-2+1]
        merges = [tuple(merge.split()) for merge in merges]
        vocab = list(bytes_to_unicode().values())  # one symbol per byte
        vocab = vocab + [v+'</w>' for v in vocab]  # same symbols at the end of a word
        # Each merge rule adds its concatenated symbol as a new token, in file
        # order; the same order is the merge rank used by ``bpe``.
        for merge in merges:
            vocab.append(''.join(merge))
        vocab.extend(['<|startoftext|>', '<|endoftext|>'])
        self.vocab = vocab
        self.encoder = dict(zip(vocab, range(len(vocab))))
        self.decoder = {v: k for k, v in self.encoder.items()}
        self.bpe_ranks = dict(zip(merges, range(len(merges))))
        # Memo of token -> space-joined BPE symbols; grows without eviction.
        self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
        self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)

    def bpe(self, token):
        """Apply merge rules to one byte-encoded token; return its symbols joined by spaces.

        The last symbol carries the ``</w>`` suffix. Merges are applied
        lowest-rank first until no adjacent pair has a rank.
        """
        if token in self.cache:
            return self.cache[token]

        word = tuple(token[:-1]) + ( token[-1] + '</w>',)
        pairs = get_pairs(word)
        if not pairs:
            return token+'</w>'
        while True:
            bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
            if bigram not in self.bpe_ranks:
                break
            first, second = bigram

            # Rebuild the word, fusing every occurrence of (first, second).
            new_word = []
            i = 0
            while i < len(word):
                try:
                    j = word.index(first, i)
                except ValueError:
                    # ``first`` does not occur again: copy the remainder.
                    new_word.extend(word[i:])
                    break
                new_word.extend(word[i:j])
                i = j

                if word[i] == first and i < len(word)-1 and word[i+1] == second:
                    new_word.append(first+second)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1

            new_word = tuple(new_word)
            word = new_word
            if len(word) == 1:
                break
            else:
                pairs = get_pairs(word)

        word = ' '.join(word)
        self.cache[token] = word
        return word

    def encode(self, text):
        """Clean and lower-case ``text``, then return its list of BPE token ids."""
        bpe_tokens = []
        text = whitespace_clean(basic_clean(text)).lower()
        for token in re.findall(self.pat, text):
            token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
            bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
        return bpe_tokens

    def decode(self, tokens):
        """Map token ids back to text; each ``</w>`` becomes a space (so output may end with one)."""
        text = ''.join([self.decoder[token] for token in tokens])
        text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
        return text

In [31]:
def tokenize(text):
  """Map one string to its list of BPE token ids using the global ``tokenizer``.

  No start/end-of-text markers are added here.
  """
  return tokenizer.encode(text)

def tokenize_text(texts, context_len): # [B, "text string"] or "text string"
  """Tokenize one string or a list of strings into a fixed-width id tensor.

  Each text becomes ``[<|startoftext|>] + BPE ids + [<|endoftext|>]`` and is
  right-padded with 0 up to ``context_len``. A sequence longer than
  ``context_len`` is truncated and its last slot overwritten with
  ``<|endoftext|>``, so every row contains the end-of-text id (which
  ``TextTransformer`` locates with argmax).

  Args:
    texts: a single string or a list of strings.
    context_len: number of token slots per row.

  Returns:
    A ``torch.int`` tensor of shape (len(texts), context_len).
  """
  if isinstance(texts, str):
    texts = [texts]

  sot_token = tokenizer.encoder["<|startoftext|>"]
  eot_token = tokenizer.encoder["<|endoftext|>"]
  batch_tokens = [[sot_token] + tokenize(text) + [eot_token] for text in texts]

  result = torch.zeros(len(batch_tokens), context_len, dtype=torch.int)

  for i, tokens in enumerate(batch_tokens):
    if len(tokens) > context_len:
      tokens = tokens[:context_len]
      tokens[-1] = eot_token  # keep the end-of-text marker after truncation
    result[i, :len(tokens)] = torch.tensor(tokens)

  return result

In [32]:
# text
# Built once: constructing SimpleTokenizer reads and parses the BPE file.
tokenizer = SimpleTokenizer()
vocab_size = len(tokenizer.vocab)
context_len = 50
width = 256
heads = 8
layers = 6
output_dim = 512

# image
image_resolution = 224
patch_size = 16

image_preprocess = _transform(image_resolution)

# Both towers reuse width/layers/heads here and project to the same
# output_dim, so text and image embeddings share one space.
clip = CLIP(vocab_size, context_len, width, layers, heads, output_dim, build_causal_attn_mask(context_len),
            image_resolution, patch_size, width, layers, heads, output_dim)

In [33]:
texts = ["a diagram", "a dog", "a cat"]
batch_tokens = tokenize_text(texts, context_len)
# eval() turns off the attention dropout, which is active in training mode
# (the default after construction); no_grad() skips building an autograd
# graph for this inference-only call.
clip.eval()
with torch.no_grad():
  clip_text_embedding = clip.encode_text(batch_tokens)
print('CLIP Text Embeddings:', clip_text_embedding.shape)

CLIP Text Embeddings: torch.Size([3, 512])


In [34]:
clip.eval()  # disable attention dropout for deterministic inference
# Close the image file once preprocessing has read it.
with Image.open("CLIP.png") as img:
  image = image_preprocess(img).unsqueeze(0)
with torch.no_grad():
  clip_image_embedding = clip.encode_image(image)
print('CLIP image embeddings:', clip_image_embedding.shape)

CLIP image embeddings: torch.Size([1, 512])
