# transformers: CLIP

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import requests
from PIL import Image
from transformers import (
    pipeline,
    CLIPTokenizer,
    CLIPImageProcessor,
    CLIPTextModel,
    CLIPVisionModel,
    CLIPProcessor,
    CLIPModel
    # AutoProcessor,
    # AutoModel
)

## Load image

In [None]:
# load image (sample from COCO val2017)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'

# a timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() surfaces HTTP errors instead of handing an error page to PIL
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()

image = Image.open(response.raw)

In [None]:
# display the loaded image on a single set of axes
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot()
pixels = np.asarray(image)
ax.imshow(pixels)
ax.set_aspect('equal', adjustable='box')
fig.tight_layout()

## Load model

In [None]:
# pick the CLIP checkpoint to use (swap the comment to try the larger one)
model_name = 'openai/clip-vit-base-patch32'  # smaller checkpoint
# model_name = 'openai/clip-vit-large-patch14'  # larger checkpoint

In [None]:
# create text and image preprocessors
processor = CLIPProcessor.from_pretrained(model_name)

# load model (bfloat16 weights, device placement chosen automatically)
model = CLIPModel.from_pretrained(
    model_name,
    attn_implementation='sdpa',  # this is the default
    torch_dtype=torch.bfloat16,
    device_map='auto'
)
model = model.eval()  # switch to eval mode for inference

print(f'Model device: {model.device}')
print(f'Model dtype: {model.dtype}')

# get_memory_footprint() reports bytes; 1 GiB = 1024**3 bytes
# (scaling by 1e-9 would yield GB, not the GiB shown in the label)
memory_gib = model.get_memory_footprint() / 1024**3
print(f'Memory footprint: {memory_gib:.2f} GiB')

print(f'\nEmbedding dim.: {model.config.projection_dim}')

In [None]:
# build a zero-shot image classification pipeline that bundles the
# preprocessors, the model and the postprocessor behind a single call
pipe = pipeline(
    task='zero-shot-image-classification',
    model=model_name,
    torch_dtype=torch.bfloat16,
    device_map='auto',
)

## Run model

In [None]:
# set candidate captions (one natural-language prompt per label)
candidate_labels = ['cat', 'dog', 'car']
candidate_captions = [f'a photo of a {label}' for label in candidate_labels]

# preprocess inputs: tokenize the captions and process the image in one call;
# padding=True pads the captions to a common length so they can be batched
inputs = processor(
    text=candidate_captions,
    images=image,
    return_tensors='pt',
    padding=True
)

# double quotes outside, single quotes inside: reusing the enclosing quote
# character in an f-string replacement field is a SyntaxError before Python 3.12
print(f"Input IDs shape: {inputs['input_ids'].shape}")
print(f"Pixel values shape: {inputs['pixel_values'].shape}")

In [None]:
# run model (no gradients are needed for inference)
with torch.no_grad():
    outputs = model(**inputs.to(model.device))

# image-text similarity scores, one row per image and one column per caption.
# Cast from bfloat16 to float32 before the softmax so the probabilities are not
# quantized by bfloat16's 8-bit mantissa.
logits_per_image = outputs.logits_per_image.float().cpu()
probs_per_image = logits_per_image.softmax(dim=-1)  # probabilities over the captions

print(f'Logits shape: {logits_per_image.shape}')

In [None]:
# get predicted labels for the (single) image in row 0 of the batch;
# .item() yields a plain Python int, which is safe to use as a list index
top_idx = logits_per_image[0].argmax().item()

top_caption = candidate_captions[top_idx]
top_label = candidate_labels[top_idx]

top_prob = probs_per_image[0, top_idx].item()

print(f'Top prediction: {top_label} ({top_prob:.2f})')
print(f'Top caption: {top_caption}')

## Run pipeline

In [None]:
# run pipeline with the same prompt format used for the manual captions, so the
# scores are comparable (the pipeline's default template is 'This is a photo of {}.')
results = pipe(
    image=image,
    candidate_labels=candidate_labels,
    hypothesis_template='a photo of a {}'
)

print(results)

## Run text model

In [None]:
# load the CLIP tokenizer for the text tower
tokenizer = CLIPTokenizer.from_pretrained(model_name)

# load only the text encoder part of the checkpoint
text_model = CLIPTextModel.from_pretrained(
    model_name,
    device_map='auto',
    torch_dtype=torch.bfloat16,
)

# inspect how the tokenizer pads batches
print(f'Padding side: {tokenizer.padding_side}')
print(f'Pad token: {tokenizer.pad_token}')

In [None]:
# tokenize; padding=True (as in the processor call) lets captions with
# different token lengths be stacked into one tensor instead of raising
text_inputs = tokenizer(candidate_captions, padding=True, return_tensors='pt')

print(text_inputs)

In [None]:
# decode every tokenized caption back into a string
for token_ids in text_inputs['input_ids']:
    decoded_text = tokenizer.decode(token_ids)
    print(decoded_text)
    # to see the individual tokens instead:
    # print(tokenizer.convert_ids_to_tokens(token_ids))

In [None]:
# run the text encoder without tracking gradients
with torch.no_grad():
    text_out = text_model(**text_inputs.to(text_model.device))

# bring both outputs back to the CPU for inspection
last_hidden_state, pooler_output = (
    text_out.last_hidden_state.cpu(),  # (batch, sequence, features)
    text_out.pooler_output.cpu(),  # (batch, features)
)

print(f'Last hidden state shape: {last_hidden_state.shape}')
print(f'Pooler output shape: {pooler_output.shape}')

In [None]:
# check that the pooler output is the last hidden state at each caption's
# end-of-sequence (EOS) token. That is the final position only when no padding
# was added, so locate the first EOS token per row instead of assuming [:, -1].
token_ids = text_inputs['input_ids'].cpu()
eos_positions = (token_ids == tokenizer.eos_token_id).int().argmax(dim=-1)  # first EOS per row
torch.equal(last_hidden_state[torch.arange(token_ids.shape[0]), eos_positions], pooler_output)

## Run vision model

In [None]:
# load the image processor that prepares pixel inputs for the vision tower
image_processor = CLIPImageProcessor.from_pretrained(model_name)

# load only the vision encoder part of the checkpoint
image_model = CLIPVisionModel.from_pretrained(
    model_name,
    device_map='auto',
    torch_dtype=torch.bfloat16,
)

In [None]:
# preprocess images
image_inputs = image_processor(image, return_tensors='pt')

# double quotes outside, single quotes inside: reusing the same quote character
# in an f-string replacement field is a SyntaxError before Python 3.12
print(f"Pixel values shape: {image_inputs['pixel_values'].shape}")

In [None]:
# run the vision encoder without tracking gradients
with torch.no_grad():
    image_out = image_model(**image_inputs.to(image_model.device))

# move both outputs to the CPU for inspection
last_hidden_state, pooler_output = (
    image_out.last_hidden_state.cpu(),  # (batch, sequence, features)
    image_out.pooler_output.cpu(),  # (batch, features)
)

# the sequence dimension includes an additional classification token
print(f'Last hidden state shape: {last_hidden_state.shape}')
print(f'Pooler output shape: {pooler_output.shape}')

In [None]:
# check that the pooler output is the first (classification) token passed
# through post_layernorm. The layer lives on the model's device, while
# last_hidden_state was moved to the CPU, so move the token over (and the result
# back) to avoid a device mismatch when the model sits on a GPU.
with torch.no_grad():
    first_token = last_hidden_state[:, 0].to(image_model.device)
    normalized_first_token = image_model.vision_model.post_layernorm(first_token).cpu()
torch.allclose(normalized_first_token, pooler_output)