In [1]:
import json
import random
from PIL import Image
from torch.utils.data import Dataset
import torch
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import DataLoader
from transformers import CLIPProcessor
import numpy as np
from tqdm import tqdm
from transformers import CLIPModel

In [2]:
class ImageTextDataset(Dataset):
    """Image/caption/label samples read from a JSON Lines file.

    Each non-blank line of the file must be a JSON object with the keys
    ``"image_path"``, ``"caption"`` and ``"label"``. The records are shuffled
    once, deterministically, when the dataset is constructed.

    Args:
        jsonl_path: Path to the ``.jsonl`` file (read as UTF-8).
        processor: Callable invoked as
            ``processor(text=..., images=..., return_tensors="pt",
            padding="max_length", truncation=True)`` that returns a mapping of
            tensors with a leading batch dimension of size 1 (this notebook
            passes a ``CLIPProcessor``).
        seed: Seed for the one-off shuffle of the records. Defaults to 42.
    """

    def __init__(self, jsonl_path, processor, seed=42):
        with open(jsonl_path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. a trailing empty line) instead of failing in json.loads.
            data = [json.loads(line) for line in f if line.strip()]
        # A private RNG keeps dataset construction from reseeding the global
        # `random` module; Random(seed).shuffle yields the same order as
        # random.seed(seed) followed by random.shuffle(...).
        random.Random(seed).shuffle(data)
        self.data = data
        self.processor = processor

    def __len__(self):
        """Return the number of records loaded from the JSONL file."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(inputs, label)`` for record ``idx``.

        ``inputs`` maps each processor output key to its tensor with dimension 0
        squeezed away, so DataLoader's default collate can re-batch the items.
        ``label`` is the record's ``"label"`` value cast to ``int``.
        """
        item = self.data[idx]
        # The context manager closes the image file once the RGB copy has been made.
        with Image.open(item["image_path"]) as img:
            image = img.convert("RGB")
        caption = item["caption"]
        label = int(item["label"])
        inputs = self.processor(text=caption, images=image, return_tensors="pt", padding="max_length", truncation=True)
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}
        return inputs, label

In [3]:
def extract_embeddings(model, dataloader, device, output_path="clip_embeddings.npz"):
    """Compute one combined embedding per sample and save them with their labels.

    For every batch, the model's text and image embeddings are averaged
    element-wise into a single vector per sample.

    Args:
        model: CLIP-style model. It is called with ``input_ids``,
            ``attention_mask`` and ``pixel_values`` keyword arguments and must
            return an object with ``text_embeds`` and ``image_embeds`` tensors of
            matching shape.
        dataloader: Iterable of ``(inputs, labels)`` batches, where ``inputs`` is
            a dict of tensors containing the three keys above and ``labels`` is
            a tensor.
        device: Device the input tensors are moved to before the forward pass.
        output_path: Destination passed to ``np.savez``; the file receives the
            arrays ``embeddings`` and ``labels``. Defaults to
            ``"clip_embeddings.npz"``.

    Returns:
        Tuple ``(embeddings, labels)`` of the NumPy arrays that were saved.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.eval()
    all_embeds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(dataloader, desc="Extracting embeddings"):
            inputs = {k: v.to(device) for k, v in inputs.items()}

            outputs = model(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                pixel_values=inputs["pixel_values"],
            )

            # Element-wise mean of text and image embeddings -> one vector per sample.
            combined_embeds = (outputs.text_embeds + outputs.image_embeds) / 2

            all_embeds.append(combined_embeds.cpu().numpy())
            all_labels.append(labels.cpu().numpy())

    if not all_embeds:
        # np.vstack([]) would otherwise fail with an opaque "need at least one array" error.
        raise ValueError("dataloader yielded no batches; nothing to embed")

    embeddings = np.vstack(all_embeds)
    label_array = np.concatenate(all_labels)
    np.savez(output_path, embeddings=embeddings, labels=label_array)
    return embeddings, label_array

In [4]:
# Configuration. The processor and the model must come from the same checkpoint,
# so its name is defined once and reused for both.
MODEL_NAME = "openai/clip-vit-base-patch32"
BATCH_SIZE = 32

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_path = "combined_dataset.jsonl"

processor = CLIPProcessor.from_pretrained(MODEL_NAME, use_fast=False)
dataset = ImageTextDataset(dataset_path, processor)
# shuffle=False keeps the dataset's own seeded order, so saved rows are reproducible.
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)
model = CLIPModel.from_pretrained(MODEL_NAME).to(device)

# Trailing ";" suppresses Jupyter's display of the call's return value.
extract_embeddings(model, dataloader, device);

Extracting embeddings: 100%|██████████| 377/377 [07:26<00:00,  1.18s/it]
