In [None]:
# IPython autoreload: mode 2 re-imports modified modules before each cell runs,
# so edits to the perception_models package take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import httpx
from io import BytesIO
from PIL import Image
import perception_models.core.vision_encoder.pe as pe
import perception_models.core.vision_encoder.transforms as transforms

In [None]:
# Run on the GPU when PyTorch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using Device: {device}")

In [None]:
# Zero-shot image classification with a pretrained PE CLIP model:
# download a COCO validation image, embed it together with three captions,
# and print the softmax over image-text similarity scores.
url = "http://images.cocodataset.org/val2017/000000039769.jpg"

# Bound the wait so a stalled server cannot hang the notebook. Fail loudly on
# HTTP errors instead of handing an error page to PIL, which would otherwise
# surface as a confusing "cannot identify image file" error.
response = httpx.get(url, timeout=30.0)
response.raise_for_status()
image_bytes = response.content

# Normalize to 3-channel RGB. Some files decode as "L", "P" or "RGBA", and the
# image transform presumably expects RGB input (TODO confirm in transforms).
pil_image = Image.open(BytesIO(image_bytes)).convert("RGB")


print("CLIP configs:", pe.CLIP.available_configs())

model = pe.CLIP.from_config("PE-Core-L14-336", pretrained=True)  # Downloads from HF
model = model.to(device)
# Inference only: disable train-time behavior such as dropout. This is a no-op
# if from_config already returns the model in eval mode.
model.eval()

# Build the preprocessing pipeline and tokenizer that match the model's
# expected input resolution and text context length.
preprocess = transforms.get_image_transform(model.image_size)
tokenizer = transforms.get_text_tokenizer(model.context_length)

# unsqueeze(0) adds a leading batch dimension of size 1.
image = preprocess(pil_image).unsqueeze(0).to(device)
text = tokenizer(["a diagram", "a dog", "a cat"]).to(device)

# no_grad: no autograd graph is needed for inference.
# autocast: mixed precision on the selected device type.
with torch.no_grad(), torch.autocast(device_type=device):
    image_features, text_features, logit_scale = model(image, text)
    # Scaled image-text similarity matrix, softmaxed over the captions.
    # Row i holds image i's probability for each caption.
    text_probs = (logit_scale * image_features @ text_features.T).softmax(dim=-1)

print(image.shape)
# Expected to be ~[[0.0, 0.0, 1.0]]: the photo shows cats, so "a cat" wins.
print("Label probs:", text_probs)