In [None]:
# Switch the working directory to the parent folder; relative paths used later
# in the notebook (such as `res/`) are resolved from there.
# NOTE(review): this is relative, so re-running the cell moves up one more level.
%cd ../

In [None]:
from pathlib import Path

import requests
import torch
from PIL import Image
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast

In [None]:
# Checkpoint name handed to `from_pretrained`, and the device used throughout.
model_name = "openai/clip-vit-base-patch32"
device = "mps"

# Components of the output file name: <name_saved>_<tag>.pt
name_saved = "clip"
tag = "July20"

# Destination of the exported model; create its folder up front if needed.
path_saved_model = Path("res", f"{name_saved}_{tag}.pt")
path_saved_model.parent.mkdir(parents=True, exist_ok=True)

# Save model using `torchscript`

In [None]:
# Load the pretrained CLIP weights with the flags used for tracing
# (`torchscript=True`, `return_dict=False`), then cast them to float16 and
# move them onto the configured device.
model = CLIPModel.from_pretrained(
    model_name,
    torchscript=True,
    return_dict=False,
)
model = model.to(device=device, dtype=torch.float16)

# Preprocessing objects for the same checkpoint: tokenizer for text inputs,
# processor for image inputs.
tokenizer = CLIPTokenizerFast.from_pretrained(model_name)
processor = CLIPProcessor.from_pretrained(model_name)

In [None]:
# Build one sample text and one sample image. They are used as the example
# inputs for tracing and again for the inference check after reloading.
sample_text = "this is a cat"
# Pad/truncate to the tokenizer's maximum length so the example has a fixed length.
out_sample_text = tokenizer(sample_text, return_tensors="pt", padding="max_length", truncation=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Bound the download with a timeout and fail on an HTTP error status instead
# of handing a non-image body to PIL. The context manager releases the
# streamed connection once the image has been read.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    sample_image = Image.open(response.raw)
    # PIL decodes lazily; read the pixel data now, before the stream is closed.
    sample_image.load()
out_sample_image = processor(images=sample_image, return_tensors="pt")

In [None]:
# Trace the two feature-extraction methods with the sample inputs and save the
# result as one TorchScript module at `path_saved_model`.
# The example inputs are placed on the configured `device`, and the pixel
# values are cast to float16, the dtype the model was cast to when loaded.
trace_inputs = {
    'get_text_features': [
        out_sample_text['input_ids'].to(device=device),
        out_sample_text['attention_mask'].to(device=device),
    ],
    'get_image_features': [
        out_sample_image['pixel_values'].to(device=device, dtype=torch.float16),
    ],
}
converted = torch.jit.trace_module(model, trace_inputs)
torch.jit.save(converted, path_saved_model)

## Test loading the model and doing inference

In [None]:
# Reload the saved TorchScript module (mapped to CPU first), then cast it to
# float16 and move it to the configured `device` so this cell follows the
# notebook's device setting.
loaded_model = torch.jit.load(path_saved_model, map_location="cpu").to(dtype=torch.float16, device=device)
# Last expression of the cell: switches to eval mode and displays the module.
loaded_model.eval()

In [None]:
# Run both traced methods of the reloaded module on the sample inputs.
# Inputs go to the configured `device`; the pixel values are also cast to
# float16 to match the dtype the reloaded module was cast to.
# Gradients are not needed for this check, so autograd tracking is disabled.
with torch.no_grad():
    text_embd = loaded_model.get_text_features(
        out_sample_text['input_ids'].to(device),
        out_sample_text['attention_mask'].to(device),
    )
    img_embd = loaded_model.get_image_features(
        out_sample_image['pixel_values'].to(device=device, dtype=torch.float16)
    )

In [None]:
# Display the output of `get_text_features` for the sample text.
text_embd

In [None]:
# Display the output of `get_image_features` for the sample image.
img_embd