# Overview
The goal of this notebook is to split the CLIP model into its text and image models respectively. Running CLIP directly from the openai repo gives the following error. https://github.com/openai/CLIP/issues/255. I guess this will be worked on in the future, meanwhile, this is a workaround 


In [1]:
import clip
import torch
import PIL
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
device = "mps" if torch.has_mps else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [3]:
image = preprocess(PIL.Image.open("data/dog.jpeg")).unsqueeze(0).to(device)
labels = clip.tokenize(["dog", "cat", "pizza","computer"]).to(device)

In [4]:
labels.shape

torch.Size([4, 77])

## Roadblock 1

Something funny going on with transformer inputs

In [5]:
model.encode_text(labels)

RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

### Solution: Rewrite torch.permute using reshape and contiguous

In [8]:
def encode_text_mps1(model, text):
    x = model.token_embedding(text).type(model.dtype)  # [batch_size, n_ctx, d_model]

    x = x + model.positional_embedding.type(model.dtype)
    num_bpe_tokens, num_text, embedding_size = x.shape
    x = x.T.reshape(num_text, num_bpe_tokens, embedding_size)  # NLD -> LND
    x = model.transformer(x)
    x = x.T.reshape(num_bpe_tokens, num_text, embedding_size) # LND -> NLD
    x = model.ln_final(x).type(model.dtype)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ model.text_projection

    return 

## Roadblock 2

aten::index.Tensor_out operator not implemented for MPS

In [9]:
encode_text_mps1(model, labels)

  x = x.T.reshape(num_text, num_bpe_tokens, embedding_size)  # NLD -> LND


NotImplementedError: The operator 'aten::index.Tensor_out' is not current implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.

### Solution: Perform the indexing operation in CPU and rewrite x.permute() using x.T.contigous().view()

In [32]:
def encode_text_mps2(model, text):
    x = model.token_embedding(text).type(model.dtype)  # [batch_size, n_ctx, d_model]

    x = x + model.positional_embedding.type(model.dtype)
    num_bpe_tokens, num_text, embedding_size = x.shape
    x = x.T.contiguous().view(num_text, num_bpe_tokens, embedding_size)  # NLD -> LND
    x = model.transformer(x)
    x = x.T.contiguous().view(num_bpe_tokens, num_text, embedding_size) # LND -> NLD
    x = model.ln_final(x).type(model.dtype)

    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest number in each sequence)
    try:
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ model.text_projection
    except NotImplementedError:
        text = text.to('cpu')
        x = x.to('cpu')
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)].to('mps') @ model.text_projection
    return x

In [33]:
encode_text_mps2(model, labels)

tensor([[ 0.1976,  0.0608, -0.0259,  ...,  0.7798,  0.1473, -0.2206],
        [-0.1168,  0.0208, -0.6133,  ...,  0.7041,  0.3325, -0.2130],
        [-0.0951, -0.0976, -0.5884,  ..., -0.1495,  0.0050,  0.0012],
        [ 0.0043, -0.1022, -0.2522,  ...,  0.8203,  0.2289, -0.3628]],
       device='mps:0', dtype=torch.float16, grad_fn=<MmBackward0>)

In [36]:
with torch.no_grad():
    image_features = model.visual(image.type(model.dtype))
    text_features = encode_text_mps2(model, labels)
    image_features =  image_features / image_features.norm(dim=-1, keepdim=True)
    text_features  = text_features / text_features.norm(dim=-1, keepdim=True)
    similarities = 100* image_features @ text_features.t()
    probs = similarities.softmax(dim=-1).cpu().numpy()
probs

array([[0., 0., 1., 0.]], dtype=float16)

In [25]:
labels.shape

torch.Size([4, 77])