In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel
import numpy as np
from PIL import Image

In [3]:
class DT(Dataset):
    def __init__(self, n_X=40, n_Y=20, embed_dim=512):
        num_samples = n_X + n_Y
        self.types = torch.tensor([1] * n_X + [0] * n_Y)
        self.images = np.array([Image.fromarray((np.random.rand(224, 224, 3) * 255).astype(np.uint8)) for _ in range(num_samples)])
        self.texts = [f"Caption {i}" for i in range(num_samples)]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return self.images[idx], self.texts[idx], self.types[idx]

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [5]:
def collate_fn(batch):
    images, texts, types = zip(*batch)
    inputs = processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
    types = torch.stack(types).to(device)
    return inputs, types

dataset = DT(n_X=40, n_Y=10)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

for batch in dataloader:
    inputs, types = batch

    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)
    pixel_values = inputs["pixel_values"].to(device)
    types = types.to(device)

    print("Input IDs shape:", input_ids.shape)
    print("Attention mask shape:", attention_mask.shape)
    print("Pixel values shape:", pixel_values.shape)
    print("Types shape:", types.shape)

    break

Input IDs shape: torch.Size([8, 5])
Attention mask shape: torch.Size([8, 5])
Pixel values shape: torch.Size([8, 3, 224, 224])
Types shape: torch.Size([8])


In [None]:
criterion = torch.nn.CrossEntropyLoss(reduction="none") # We use 'none' to compute the loss for each sample !

def balanced_clip_loss(logits_per_text, logits_per_image, targets, types):
    """
    Classic CLIP loss:

        texts_loss = criterion(logits_per_text, targets)
        images_loss = criterion(logits_per_image, targets)
        loss = (images_loss + texts_loss) / 2.0

    """

    n_X = sum([1 if t==1 else 0 for t in dataset.types])
    n_Y = sum([1 if t==0 else 0 for t in dataset.types])

    fX = 1 / n_X
    fY = 1 / n_Y


    weights = [fX if t == 1 else fY for t in types]
    weights = torch.tensor(weights).to(device)

    texts_loss = criterion(logits_per_text, targets)
    images_loss = criterion(logits_per_image, targets)

    images_loss_balanced = images_loss * weights
    texts_loss_balanced = texts_loss * weights

    images_loss_balanced = images_loss_balanced.mean()
    texts_loss_balanced = texts_loss_balanced.mean()

    loss = (images_loss_balanced + texts_loss_balanced) / 2.0
    return loss

model.train()
for epoch in range(3):
    loss_tot = 0
    loss_tracker = []
    for batch in dataloader:
        
        inputs, types = batch

        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        pixel_values = inputs["pixel_values"].to(device)
        types = types.to(device)

        # Generate model outputs
        outputs = model(input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask)
        logits_per_image = outputs.logits_per_image
        logits_per_text = outputs.logits_per_text

        # Compute the embeddings
        image_embeddings = outputs.image_embeds
        text_embeddings = outputs.text_embeds

        # Normalize
        image_embeddings /= image_embeddings.norm(dim=-1, keepdim=True)
        text_embeddings /= text_embeddings.norm(dim=-1, keepdim=True)

        # Ground truth for this batch
        batch_size = logits_per_image.size(0)
        targets = torch.arange(batch_size).to(device)  # Correct index corresponds to diagonal

        loss = balanced_clip_loss(logits_per_text, logits_per_image, targets, types)

        loss_tot += loss.item() * batch_size

    print(f"Epoch {epoch+1} — Loss: {loss_tot:.4f}")


Epoch 1 — Loss: 4.0336
Epoch 2 — Loss: 4.0483
Epoch 3 — Loss: 4.0370


: 