# In[ ]:
import torch
from models import (
    ModelConfig,
    build_openclip_wrapper,
    build_siglip2_wrapper,
    build_vlm2vecv2_wrapper,
)

# All three encoder wrappers share one device and one precision setting.
# NOTE(review): hard-coded to the first CUDA GPU; there is no CPU fallback here.
device = torch.device("cuda:0")
precision = "fp16"

# OpenCLIP, native mode: ViT-B-16 with the laion2b_s34b_b88k pretrained weights.
cfg_clip = ModelConfig(name="openclip", device=device, mode="native", precision=precision)
openclip_model = build_openclip_wrapper(cfg_clip, "ViT-B-16", "laion2b_s34b_b88k")

# SigLIP-2, native mode.
# The Hugging Face repo id is "siglip2-..." (no hyphen between "siglip" and "2");
# "google/siglip-2-base-patch16-224" is not a published checkpoint.
cfg_siglip = ModelConfig(name="siglip2", device=device, mode="native", precision=precision)
siglip_model = build_siglip2_wrapper(cfg_siglip, "google/siglip2-base-patch16-224")

# VLM2Vec-V2, shared mode (for ablation).
# NOTE(review): confirm this checkpoint id against the hub and against what
# build_vlm2vecv2_wrapper expects; the published weights may use a different org/repo.
cfg_vlm = ModelConfig(name="vlm2vecv2", device=device, mode="shared", precision=precision)
vlm_model = build_vlm2vecv2_wrapper(cfg_vlm, "tiger-research/vlm2vec-v2-base")

# Example call: encode a batch of images and a batch of captions with the
# OpenCLIP wrapper, requesting stats, and print both stats objects.
# `list_of_pil_images` and `list_of_captions` are not defined in this cell and
# must come from an earlier cell. Look them up inside a guard so that a missing
# definition yields an actionable message instead of a bare NameError raised
# after every model has already been loaded. They are deliberately not assigned
# here, so earlier definitions are never overwritten.
try:
    example_images, example_captions = list_of_pil_images, list_of_captions
except NameError:
    print(
        "Example skipped: define `list_of_pil_images` (PIL images) and "
        "`list_of_captions` (caption strings) before running this call."
    )
else:
    img_embeds, img_stats = openclip_model.encode_images(example_images, return_stats=True)
    txt_embeds, txt_stats = openclip_model.encode_texts(example_captions, return_stats=True)
    print(img_stats, txt_stats)