In [7]:
from src.models.dual_encoding.pretrained_clip import PreTrainedCLIP
from src.models.utils.data import FoodPricingDataset
from torch.utils.data import DataLoader
from torchvision import transforms

Using the `PreTrainedCLIP` class

In [8]:
# Identifier of the pretrained checkpoint to load
# (presumably a Hugging Face Hub model id — TODO confirm).
pretrained_model_name_or_path = "clip-italian/clip-italian"

# Keyword arguments handed to the model wrapper in the next cell.
model_kwargs = dict(pretrained_model_name_or_path=pretrained_model_name_or_path)

In [9]:
# Both feature dimensions are left unset (None) when building the wrapper.
img_feature_dim = txt_feature_dim = None

model = PreTrainedCLIP(
    model_kwargs,
    txt_feature_dim=txt_feature_dim,
    img_feature_dim=img_feature_dim,
)

# Report the embedding sizes exposed by the wrapped CLIP model, one per line.
for label, attr_name in (
    ("Text: ", "text_embed_dim"),
    ("Image: ", "vision_embed_dim"),
    ("Projection: ", "projection_dim"),
):
    print(label, getattr(model.clip, attr_name))

Text:  768
Image:  768
Projection:  512


In [10]:
# Read the image preprocessing settings from the model's own processor so the
# dataset images are resized and normalised with the values the model exposes.
# NOTE(review): assumes `model.processor` is a Hugging Face CLIP processor —
# confirm against PreTrainedCLIP.
#
# Newer `transformers` releases expose the image preprocessor as
# `processor.image_processor` and deprecate `processor.feature_extractor`;
# prefer the new attribute and fall back to the old one when it is absent.
image_processor = getattr(model.processor, "image_processor", None)
if image_processor is None:
    image_processor = model.processor.feature_extractor

size = image_processor.crop_size
mean = image_processor.image_mean
std = image_processor.image_std

# `crop_size` is a plain int in older `transformers` releases and a
# {"height": ..., "width": ...} dict in newer ones; normalise both forms to
# the (height, width) pair that `transforms.Resize` expects.
if isinstance(size, dict):
    resize_hw = (size["height"], size["width"])
else:
    resize_hw = (size, size)

img_transform = transforms.Compose(
    [
        transforms.Resize(size=resize_hw),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=mean,
            std=std,
        ),
    ]
)


def txt_transform(x):
    """Identity text transform: return the input unchanged.

    Defined as a named function rather than an assigned lambda (PEP 8).
    """
    return x

# Training split of the food-pricing dataset, with the image and text
# transforms defined above applied by the dataset.
training_data = FoodPricingDataset(
    split="train",
    img_transform=img_transform,
    txt_transform=txt_transform,
)

# Shuffled loader yielding batches of 4, loaded by 8 worker processes.
dataloader = DataLoader(
    dataset=training_data,
    batch_size=4,
    num_workers=8,
    shuffle=True,
)

# Grab a single batch to try the model on. Because shuffling is enabled and
# no seed is set in this cell, the batch differs from run to run.
sample = next(iter(dataloader))

In [11]:
# Run the sampled batch through the model: text first, images second.
# The call is the cell's last expression, so the notebook displays its
# return value (a dict with "img" and "txt" tensors in the recorded output).
texts, images = sample["txt"], sample["img"]
model(texts, images)

{'img': tensor([[ 0.0158, -0.0360,  0.0568,  ...,  0.1045, -0.0002,  0.0349],
         [-0.0071, -0.0447,  0.0286,  ...,  0.0730,  0.0426,  0.0427],
         [ 0.0124, -0.0156,  0.0531,  ...,  0.0001, -0.0558,  0.0945],
         [ 0.0320, -0.0583,  0.0682,  ...,  0.0345, -0.0620,  0.0810]]),
 'txt': tensor([[-0.0238, -0.0674,  0.0234,  ...,  0.1238,  0.0372,  0.0515],
         [-0.0335, -0.0699,  0.0038,  ...,  0.0652,  0.0178,  0.0332],
         [-0.0236, -0.0100,  0.0528,  ...,  0.1020, -0.0693,  0.0285],
         [-0.0157, -0.0665,  0.0233,  ...,  0.0490, -0.0183,  0.0253]])}