In [1]:
from src.models.dual_encoding.clip import PreTrainedCLIP
from src.models.utils.data import FoodPricingDataset
from torch.utils.data import DataLoader
from torchvision import transforms

With PreTrainedCLIP class

In [2]:
pretrained_model_name_or_path = "clip-italian/clip-italian"

model_kwargs = {
    "pretrained_model_name_or_path": pretrained_model_name_or_path,
}

In [3]:
img_feature_dim, txt_feature_dim = None, None
model = PreTrainedCLIP(
    model_kwargs, 
    img_feature_dim=img_feature_dim,
    txt_feature_dim=txt_feature_dim,
)
model.clip.text_embed_dim
model.clip.vision_embed_dim

768

In [4]:
img_transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],  # TODO: load from model
            std=[0.229, 0.224, 0.225], # TODO: load from model
        )
    ]
)
txt_transform = lambda x: x

training_data = FoodPricingDataset(
    img_transform=img_transform,
    txt_transform=txt_transform,
    split="train",
)
dataloader = DataLoader(
    training_data,
    shuffle=True,
    batch_size=4,
    num_workers=8,
)
sample = next(iter(dataloader))

In [5]:
inputs = model.processor(
    text=sample["txt"],
    images=sample["img"],
    # return_tensors="pt", # This fails due to a bug in transformers code
    return_tensors=None,
    padding=True,
)

In [6]:
import torch

inputs["input_ids"] = torch.tensor(inputs["input_ids"])
inputs["attention_mask"] = torch.tensor(inputs["attention_mask"])
if isinstance(inputs["pixel_values"], list) and len(inputs["pixel_values"]) == 1:
    inputs["pixel_values"] = inputs["pixel_values"][0]
else:
    raise ValueError("Pixel values could not be transformed into a tensor.")

In [8]:
outputs = model.clip(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    pixel_values=inputs.pixel_values,
    return_loss=True,
)
loss, logits_per_image = outputs.loss, outputs.logits_per_image  # this is the image-text similarity score

In [11]:
outputs.text_embeds.shape

torch.Size([4, 512])

In [12]:
outputs.image_embeds.shape

torch.Size([4, 512])