# transformers: Vision-language models

In [None]:
# IPython line magic: render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import requests
from PIL import Image
from transformers import (
    pipeline,
    AutoProcessor,
    AutoModel,
    CLIPProcessor,
    CLIPModel
)

## Load image

In [None]:
# load image (sample photo from the COCO val2017 set)
url = 'http://images.cocodataset.org/val2017/000000039769.jpg'

# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() surfaces HTTP errors directly instead of letting PIL fail
# later with an opaque "cannot identify image file" on an error page body.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

In [None]:
# show image: draw the PIL image as a pixel array on a 6x4-inch figure,
# forcing an equal aspect ratio so the photo is not stretched to the axes box
fig, ax = plt.subplots(figsize=(6, 4))
image_array = np.asarray(image)
ax.imshow(image_array)
ax.set_aspect(aspect='equal', adjustable='box')
fig.tight_layout()

## Load model

In [None]:
# CLIP checkpoint, shared by the manually driven model and the pipeline cell
model_name = 'openai/clip-vit-base-patch32'

# one processor object handles both the caption text and the image inputs
processor = CLIPProcessor.from_pretrained(model_name)

# Load the weights in bfloat16 with automatic device placement, then switch to
# evaluation mode; nn.Module.eval() returns the module itself, so it is chained.
# (attn_implementation='sdpa' can be added here to opt into SDPA attention.)
model = CLIPModel.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map='auto',
).eval()

In [None]:
# Build an end-to-end zero-shot image classifier. The pipeline bundles the
# preprocessors, the model (loaded again from model_name) and postprocessing
# into a single callable.
pipe = pipeline(
    'zero-shot-image-classification',
    model=model_name,
    device_map='auto',
    torch_dtype=torch.bfloat16,
)

## Run model

In [None]:
# set candidate captions: one natural-language prompt per candidate label
candidate_labels = ['cat', 'dog', 'car']
candidate_captions = [f'a photo of a {label}' for label in candidate_labels]

# preprocess inputs: tokenize the captions (padded to a common length) and
# preprocess the image into PyTorch tensors
inputs = processor(
    text=candidate_captions,
    images=image,
    return_tensors='pt',
    padding=True
)

# The outer f-string quotes differ from the subscript quotes on purpose:
# reusing the same quote character inside a replacement field is a
# SyntaxError before Python 3.12 (PEP 701).
print(f"Input IDs shape: {inputs['input_ids'].shape}")
print(f"Pixel values shape: {inputs['pixel_values'].shape}")

In [None]:
# run model (inference only, so no autograd graph is recorded)
with torch.no_grad():
    outputs = model(**inputs.to(model.device))

# get image-text similarity scores. The model runs in bfloat16, so upcast to
# float32 first: a bf16 softmax keeps only ~2-3 significant digits, and bf16
# tensors cannot be converted to NumPy. The upcast is exact, so argmax-based
# predictions are unaffected.
logits_per_image = outputs.logits_per_image.float().cpu()
probs_per_image = logits_per_image.softmax(dim=-1)  # label probabilities per image

print(f'Logits shape: {logits_per_image.shape}')

In [None]:
# get predicted labels: argmax over the last (caption) axis yields one caption
# index per image
label_ids = logits_per_image.argmax(dim=-1)

# Indexing a Python list with the index tensor only works when there is exactly
# one image; convert to plain ints and look up each image's prediction so the
# cell also works for a batch of images.
predicted_ids = label_ids.tolist()
captions = [candidate_captions[idx] for idx in predicted_ids]
labels = [candidate_labels[idx] for idx in predicted_ids]

print(labels)

## Run pipeline

In [None]:
# run pipeline
# - The image is passed positionally: the keyword name for this argument has
#   changed across transformers releases (`images` vs. `image`), and the
#   positional form works with both.
# - hypothesis_template is set to the same prompt used for candidate_captions
#   in the manual run, so both approaches score identical text prompts (the
#   pipeline otherwise applies its own default template).
results = pipe(
    image,
    candidate_labels=candidate_labels,
    hypothesis_template='a photo of a {}',
)

print(results)