In [1]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch import nn, optim
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
from torchvision import transforms
import pandas as pd

In [40]:
class TripletPaperDataset(Dataset):
    """(question, positive image, negative image) triplets read from a CSV file.

    Each item is a dict with the raw question string, its tokenization
    (unbatched ``input_ids`` / ``attention_mask`` tensors) and the two images
    as RGB ``PIL.Image`` objects. Turning the images into pixel tensors is
    left to the DataLoader's ``collate_fn`` so it happens once per batch.
    """

    def __init__(self, csv_path, image_root, processor, image_size=224):
        """
        csv_path: path to your CSV with columns [paper_id, question, positive, negative];
                  __getitem__ reads the ``question``, ``positive`` and ``negative`` columns.
        image_root: directory that the ``positive`` / ``negative`` paths are joined onto.
        processor: a HuggingFace CLIPProcessor (its tokenizer and image mean/std are used).
        image_size: side length used by the fallback ``default_transform``.
        """
        self.df = pd.read_csv(csv_path)
        self.image_root = image_root
        self.processor = processor
        # Newer transformers releases expose image settings as `image_processor`;
        # `feature_extractor` is the deprecated alias, kept as a fallback.
        image_proc = getattr(processor, 'image_processor', None) or processor.feature_extractor
        # Fallback transforms (not used by __getitem__; kept for callers that
        # want to preprocess images without going through the processor).
        self.default_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=image_proc.image_mean,
                std=image_proc.image_std
            )
        ])

    def __len__(self):
        """Number of triplets (rows in the CSV)."""
        return len(self.df)

    def _load_rgb(self, rel_path):
        """Open ``image_root/rel_path`` and return a fully loaded RGB copy.

        The ``with`` block closes the file handle as soon as the pixels have
        been copied out by ``convert``.
        """
        with Image.open(os.path.join(self.image_root, rel_path)) as img:
            return img.convert('RGB')

    def _encode_question(self, question):
        """Tokenize one question; return unbatched (input_ids, attention_mask) tensors."""
        tokenizer = self.processor.tokenizer
        encoding = tokenizer(
            question,
            return_tensors='pt',
            padding=True,
            truncation=True,                      # enforce at most model_max_length tokens (77 for CLIP)
            max_length=tokenizer.model_max_length
        )
        # return_tensors='pt' adds a batch dim of 1; drop it.
        return encoding['input_ids'][0], encoding['attention_mask'][0]

    def __getitem__(self, idx):
        """Return the triplet at row ``idx`` as a dict (see class docstring)."""
        row = self.df.iloc[idx]
        question = row['question']

        # Load images first so a missing file fails before any tokenization work.
        img_pos = self._load_rgb(row['positive'])
        img_neg = self._load_rgb(row['negative'])

        # Only the text is encoded here: pixel preprocessing is done once per
        # batch in collate_fn, so running the image processor here would be wasted work.
        input_ids, attention_mask = self._encode_question(question)

        return {
            'question':       question,
            'input_ids':      input_ids,
            'attention_mask': attention_mask,
            'img_pos':        img_pos,
            'img_neg':        img_neg
        }

In [41]:
def train_clip(
    csv_path,
    image_root='',
    model_name='openai/clip-vit-base-patch32',
    batch_size=16,
    epochs=3,
    lr=5e-6,
    margin=0.2,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir='clip_finetuned',
    seed=None,
):
    """Fine-tune a CLIP model with a text-anchored triplet loss and save it.

    Each training example is (question, positive image, negative image). The
    question's text embedding is the anchor and the two image embeddings are
    the positive and negative. All three are L2-normalized and passed to
    ``nn.TripletMarginLoss``. Every model parameter is updated with AdamW.

    Args:
        csv_path: CSV read by ``TripletPaperDataset`` (``question``,
            ``positive`` and ``negative`` columns are used).
        image_root: directory the image paths in the CSV are joined onto.
        model_name: Hugging Face model id or path, used for model and processor.
        batch_size: triplets per optimizer step.
        epochs: number of passes over the dataset.
        lr: AdamW learning rate.
        margin: triplet-loss margin (Euclidean distance, p=2).
        device: torch device string. The default is 'cuda' if CUDA was
            available when this function was defined, else 'cpu'.
        save_dir: directory the fine-tuned model and processor are written to.
        seed: if not None, seeds torch's global RNG so that the shuffle order
            is reproducible.

    Raises:
        ValueError: if the CSV yields no training rows.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # 1. Processor + dataset first, so an empty CSV is rejected before the
    #    (large) model weights are loaded, and len(loader) can never be 0 below.
    processor = CLIPProcessor.from_pretrained(model_name)
    dataset = TripletPaperDataset(csv_path, image_root, processor)
    if len(dataset) == 0:
        raise ValueError(f"No training triplets found in {csv_path!r}")
    model = CLIPModel.from_pretrained(model_name).to(device)

    # 2. DataLoader
    def collate_fn(batch):
        """Tokenize the questions and preprocess every image in one processor call.

        Images are interleaved as [pos_0, neg_0, pos_1, neg_1, ...], so the
        pixel tensor can be viewed as (batch, 2, ...) and split into
        positives (index 0) and negatives (index 1).
        """
        questions = [b["question"] for b in batch]
        # Flat comprehension instead of sum(list_of_lists, []), which is quadratic.
        images = [img for b in batch for img in (b["img_pos"], b["img_neg"])]

        batch_enc = processor(
            text=questions,
            images=images,
            padding=True,
            truncation=True,
            max_length=processor.tokenizer.model_max_length,
            return_tensors="pt"
        )

        pixel_values = batch_enc.pixel_values
        pv = pixel_values.view(len(batch), 2, *pixel_values.shape[1:])

        return {
            "input_ids":      batch_enc.input_ids,
            "attention_mask": batch_enc.attention_mask,
            "img_pos":        pv[:, 0],
            "img_neg":        pv[:, 1],
        }

    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn
    )

    # 3. Loss & optimizer
    criterion = nn.TripletMarginLoss(margin=margin, p=2)
    optimizer = optim.AdamW(model.parameters(), lr=lr)

    model.train()
    for epoch in range(1, epochs + 1):
        total_loss = 0.0
        for batch in loader:
            optimizer.zero_grad()

            # move inputs to device
            input_ids      = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            img_pos        = batch['img_pos'].to(device)
            img_neg        = batch['img_neg'].to(device)

            # 4. Forward passes: text is the anchor, images are positive/negative.
            text_outputs    = model.get_text_features(input_ids=input_ids,
                                                      attention_mask=attention_mask)
            pos_img_outputs = model.get_image_features(pixel_values=img_pos)
            neg_img_outputs = model.get_image_features(pixel_values=img_neg)

            # 5. L2-normalize. F.normalize clamps the norm by a small eps, so a
            #    zero vector cannot turn into NaNs (a plain x / x.norm() would).
            text_emb = nn.functional.normalize(text_outputs, p=2, dim=-1)
            pos_emb  = nn.functional.normalize(pos_img_outputs, p=2, dim=-1)
            neg_emb  = nn.functional.normalize(neg_img_outputs, p=2, dim=-1)

            # 6. Compute Triplet Loss
            loss = criterion(text_emb, pos_emb, neg_emb)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        # len(loader) >= 1 because the dataset was checked to be non-empty.
        avg_loss = total_loss / len(loader)
        print(f"Epoch {epoch}/{epochs} — avg loss: {avg_loss:.4f}")

    # 7. Save the fine-tuned model and processor side by side.
    os.makedirs(save_dir, exist_ok=True)
    model.save_pretrained(save_dir)
    processor.save_pretrained(save_dir)
    print(f"Model saved to {save_dir}/")

In [42]:
# Pick the best available accelerator instead of hardcoding 'mps', which
# fails on any machine without Apple-silicon GPU support.
DEVICE = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

train_clip(
    csv_path="train.csv",
    image_root=".",
    model_name="openai/clip-vit-base-patch32",
    batch_size=8,
    epochs=5,
    lr=1e-5,
    margin=0.3,
    device=DEVICE,
)



Epoch 1/5 — avg loss: 0.2240
Epoch 2/5 — avg loss: 0.1688
Epoch 3/5 — avg loss: 0.1130
Epoch 4/5 — avg loss: 0.0674
Epoch 5/5 — avg loss: 0.0385
Model saved to clip_finetuned/
