### Data Processor

In [None]:
from transformers import CLIPProcessor
# Shared pre-processor used by the embedding helpers in this notebook for both
# images and text; loaded from the "openai/clip-vit-base-patch16" checkpoint.
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch16")

### Model Inference with ORT-QNN

In [None]:
from pathlib import Path

# Absolute paths (via resolve()) to the exported text and vision ONNX files.
# NOTE(review): the "_ctx" suffix presumably marks QNN context-cache models
# produced by an earlier export step -- confirm against the export pipeline.
text_model_path = Path("./outputs/openai/clip/text/model/model.onnx_ctx.onnx").resolve()
vision_model_path = Path("./outputs/openai/clip/vision/model/model.onnx_ctx.onnx").resolve()
# Bare tuple expression so the notebook cell displays both resolved paths.
(text_model_path, vision_model_path)

In [None]:
from qnpumodel import QNPUModule

# Module-level model handles used by get_text_embedding / get_image_embedding.
# NOTE(review): QNPUModule is a project-local wrapper; presumably it opens an
# ONNX Runtime session with the QNN execution provider -- confirm in qnpumodel.
text_model = QNPUModule(text_model_path)
vision_model = QNPUModule(vision_model_path)

def get_image_embedding(image):
    """Return the "embeds" output of the vision model for `image`.

    The image is pre-processed by the module-level `processor` (PyTorch
    tensors requested) and the result is passed unchanged to
    `vision_model.run`.
    """
    pixel_inputs = processor(images=image, return_tensors="pt")
    return vision_model.run(pixel_inputs)["embeds"]

def _create_4d_mask(mask, input_shape, masked_value=-50.0):
    """Turn a 2-D padding mask into a 4-D additive attention mask.

    Args:
        mask: tensor of shape (batch, seq_len) whose non-zero entries mark
            positions to keep and zero entries mark positions to hide.
        input_shape: `(batch, seq_len)` pair used as the expansion size.
        masked_value: value written at hidden positions.

    Returns:
        Float tensor of shape (batch, 1, seq_len, seq_len) holding 0.0 at
        kept key positions and `masked_value` at hidden ones; every query
        row of a sample carries the same pattern.
    """
    n_batch, n_tokens = input_shape
    # Insert the head and query axes, then broadcast across all query rows.
    per_query = mask.unsqueeze(1).unsqueeze(2).expand(
        n_batch, 1, n_tokens, n_tokens)
    # Flip the convention: 1.0 now flags a position that must be hidden.
    hidden = 1.0 - per_query.float()
    return hidden.masked_fill(hidden.bool(), masked_value)

def get_text_embedding(text):
    """Return the "embeds" output of the text model for `text`.

    `text` is tokenized by the module-level `processor`, padded and truncated
    to `text_model.sequence_length`. The model is fed the token ids cast to
    int32 together with a 4-D additive attention mask built from the
    tokenizer's padding mask.
    """
    encoded = processor(
        text=text,
        return_tensors="pt",
        add_special_tokens=True,
        truncation=True,
        max_length=text_model.sequence_length,
        padding="max_length",
    )
    token_ids = encoded["input_ids"]
    int_token_ids = token_ids.int()
    additive_mask = _create_4d_mask(encoded["attention_mask"], token_ids.shape)
    feeds = {"input_ids": int_token_ids, "attention_mask": additive_mask}
    return text_model.run(feeds)["embeds"]

def calculate_score(text_emb, image_emb):
    """Return softmax-normalised similarity between text and image embeddings.

    Both inputs are L2-normalised along their last dimension, multiplied
    together (text by transposed image), scaled by 100.0 and passed through a
    softmax over dimension 0, i.e. across the rows of `text_emb`.

    The inputs are left untouched: normalisation is done out of place so the
    caller's tensors can be reused for further comparisons (the previous
    in-place `/=` silently rewrote them).

    Args:
        text_emb: text embedding tensor; assumed (num_texts, dim) as produced
            for the caption list in `ask` -- TODO confirm for other callers.
        image_emb: image embedding tensor; assumed (num_images, dim).

    Returns:
        Tensor of shape (num_texts, num_images) whose columns each sum to 1.
    """
    import torch

    image_unit = image_emb / torch.norm(image_emb, dim=-1, keepdim=True)
    text_unit = text_emb / torch.norm(text_emb, dim=-1, keepdim=True)
    # 100.0 is the fixed logit scale applied to the cosine similarities.
    logits = torch.matmul(text_unit, image_unit.T) * 100.0
    return torch.softmax(logits, dim=0)

def ask(image, caption):
    """Score how well `caption` describes `image`.

    The caption competes against the fixed baseline prompt
    "a photo of a thing"; the returned value is the caption's share of the
    two-way softmax from `calculate_score`, rounded to two decimals.
    """
    image_vec = get_image_embedding(image)
    text_vecs = get_text_embedding([caption, "a photo of a thing"])
    probabilities = calculate_score(text_vecs, image_vec)
    # Row 0 corresponds to `caption`, row 1 to the baseline prompt.
    caption_probability = probabilities[0].item()
    return round(caption_probability, 2)

### Play with Samples

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Bound the request time and fail loudly on an HTTP error instead of handing
# an error page (or a stalled socket) to the image decoder.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)

# Show the downloaded sample inline.
plt.imshow(np.array(image))
plt.show()

In [None]:
# Score the sample image against a caption; ask() compares it with the
# baseline prompt "a photo of a thing" and returns a value rounded to 2 places.
ask(image, "a photo of tshirt")