In [5]:
from pathlib import Path
from typing import Any

import pandas as pd
import torch
from datasets import DatasetDict, load_dataset
from PIL.JpegImagePlugin import JpegImageFile
from tqdm.auto import tqdm
from transformers import CLIPModel, CLIPProcessor
from transformers.models.clip.modeling_clip import CLIPOutput
from transformers.tokenization_utils_base import BatchEncoding


def process_and_save_clip_embeddings(
    output_dir: Path | str,
    topk: int = 1,
    shortest_edge: int = 224,
    device: str = "cuda" if torch.cuda.is_available() else "cpu",
    max_rows: int | None = 51,
) -> None:
    """
    Embed Flickr30k images with CLIP and save each with its top-k best-matching captions.

    For every processed row, the image and its captions are scored by CLIP, the
    ``topk`` highest-scoring captions are kept, and one record per kept caption
    is collected under the row's ``split`` value. The records of every split are
    then written to ``flickr_{split}_top{topk}.parquet`` inside ``output_dir``.

    Args:
        output_dir: Directory to save the parquet files in (created if missing).
        topk: Number of most similar captions to keep per image. Clamped to the
            number of captions a row actually has.
        shortest_edge: Forwarded to the CLIP processor as
            ``size={"shortest_edge": shortest_edge}``.
        device: Device to use for computation.
        max_rows: Maximum number of dataset rows to process. The default of 51
            matches the previously hard-coded early exit; pass ``None`` to
            process the whole dataset.

    Raises:
        ValueError: If ``topk`` is smaller than 1 or ``max_rows`` is negative.
    """
    # Local import: only needed to cap the iteration at ``max_rows``.
    from itertools import islice

    if topk < 1:
        raise ValueError(f"topk must be >= 1, got {topk}")
    if max_rows is not None and max_rows < 0:
        raise ValueError(f"max_rows must be >= 0 or None, got {max_rows}")

    # Load CLIP model and processor
    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)  # type: ignore
    processor: CLIPProcessor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32"
    )  # type: ignore

    # Load Flickr dataset; each row carries its own "split" label, which is
    # what decides the output file the row ends up in.
    flickr = load_dataset("nlphuji/flickr30k")
    dataset: DatasetDict = flickr["test"]  # type: ignore

    # Create output directory
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # One list of records per split; these three are always written, even if empty.
    data_dict: dict[str, list[dict[str, Any]]] = {"train": [], "test": [], "val": []}

    if max_rows is None:
        rows = dataset
        total = len(dataset)
    else:
        # islice stops before fetching (and decoding) a row that would be discarded.
        rows = islice(dataset, max_rows)
        total = min(max_rows, len(dataset))

    # Inference only: no autograd graph is needed for any of the forward passes.
    with torch.no_grad():
        for row in tqdm(rows, desc="Processing Dataset", total=total):
            image: JpegImageFile = row["image"]
            captions: list[str] = row["caption"]
            split: str = row["split"]
            image_id: int = int(row["img_id"])
            filename: str = row["filename"]

            if not captions:
                # Nothing to rank for this image.
                continue

            # Pass the image & its captions to the CLIP Processor
            model_input: BatchEncoding = processor(
                text=captions,
                images=image,
                return_tensors="pt",
                size={"shortest_edge": shortest_edge},
                padding=True,
                truncation=True,
            ).to(device)

            # Pass this input into CLIP to get the image/caption similarity scores
            model_output: CLIPOutput = model(**model_input)

            # Reuse the already-preprocessed pixels instead of running the
            # processor on the same image a second time. squeeze() drops the
            # batch dimension of 1 so a flat vector is stored.
            image_output = model.get_image_features(  # type: ignore
                pixel_values=model_input["pixel_values"]
            ).squeeze()
            # Convert once per image rather than once per kept caption.
            img_embedding = image_output.tolist()

            # Pick the top-k most similar captions; clamp k so rows with fewer
            # captions than ``topk`` do not make ``Tensor.topk`` raise.
            k = min(topk, len(captions))
            _, caption_indices = model_output["logits_per_image"].topk(k=k)
            for idx in caption_indices[0].tolist():
                data_row = {
                    "img_embedding": img_embedding,
                    "caption_text": captions[idx],
                    "img_id": image_id,
                    "filename": filename,
                }
                # setdefault keeps rows with an unexpected split label instead
                # of aborting the whole run with a KeyError.
                data_dict.setdefault(split, []).append(data_row)

    print("Saving data to parquet files")
    # Once done with making lists, create dataframes, save as parquet
    for split_name in data_dict:
        save_dataframe_parquet(
            data_dict=data_dict, topk=topk, split=split_name, output_dir=output_dir
        )


def save_dataframe_parquet(
    data_dict: dict[str, Any], topk: int, split: str, output_dir: Path | str
):
    df = pd.DataFrame(data_dict[split])
    filepath = Path(output_dir) / f"flickr_{split}_top{topk}.parquet"
    df.to_parquet(filepath)

In [6]:
from pathlib import Path

import pandas as pd
import torch
from torch.utils.data import Dataset


class Flicker30K(Dataset):
    # TODO: Try swapping out clip tokenizer & text embeddings for GPT2's
    def __init__(self, datafile: Path | str):
        super().__init__()
        datafile = Path(datafile)
        if datafile.is_file() and datafile.suffix == ".parquet":
            self.dataset = pd.read_parquet(datafile)
        else:
            raise FileNotFoundError(f"No datafile found in {datafile}")

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        img_emb, cap, img_id = self.dataset.iloc[idx][
            [
                "img_embedding",
                "caption_text",
                "img_id",
            ]
        ]

        img_emb = torch.tensor(img_emb)

        return {
            "image_emb": img_emb.to("cpu", dtype=torch.float32),
            "caption": cap,
            "img_id": img_id,
        }

In [7]:
# Directory, relative to the working directory, that the parquet files are
# written to and read back from in the cells below.
output_dir = Path("../datafiles/")
# Last expression: displays whether the directory already exists.
output_dir.exists()

True

In [8]:
# Positional arguments map to output_dir, topk=1, shortest_edge=150.
process_and_save_clip_embeddings(output_dir, 1, 150)

Processing Dataset:   0%|                          | 50/31014 [00:16<2:48:17,  3.07it/s]

Saving data to parquet files





In [10]:
# Name follows the flickr_{split}_top{topk}.parquet pattern used by
# save_dataframe_parquet, here for split="train" and topk=1.
filepath = output_dir / "flickr_train_top1.parquet"
filepath.exists()

True

In [12]:
# Wrap the train-split parquet file in the Dataset class defined above.
ds = Flicker30K(filepath)

In [13]:
# Load the same file as a plain DataFrame for direct inspection.
df = pd.read_parquet(filepath, )

In [15]:
# Fetch the first sample: a dict with "image_emb", "caption" and "img_id".
ds[0]

{'image_emb': tensor([ 4.3015e-01,  4.9582e-01, -8.4673e-02, -1.6892e-02,  1.1948e-01,
         -2.8855e-01,  3.1858e-01,  3.3653e-01,  8.4516e-02,  1.9869e-02,
          2.0031e-01, -3.0132e-02,  3.3788e-01, -1.6115e-01, -1.3101e-01,
          7.2758e-02, -8.4261e-01,  4.3193e-01, -5.2723e-02, -6.8167e-02,
          1.0022e+00, -9.2721e-02,  7.1727e-02, -1.4961e-01,  3.3231e-01,
         -4.2866e-01,  1.8314e-01,  3.8839e-01,  2.4105e-01, -2.4624e-01,
         -1.0893e-03,  2.7882e-02, -2.8714e-02,  8.9607e-02,  2.0518e-01,
          1.3776e-01, -1.1391e-01, -1.6325e-01,  4.9538e-01, -1.2830e-01,
         -6.9628e-02, -5.2208e-02, -1.5452e-01, -3.8287e-01,  2.9060e-01,
         -7.2306e-01, -4.4488e-02,  2.2103e-02,  3.7364e-02,  8.6819e-02,
          2.0576e-01,  3.9557e-01, -5.1237e-02, -3.1715e-01,  1.0678e-01,
          2.2878e-01,  3.4304e-02,  1.8025e-01, -3.7274e-01, -1.3906e-01,
         -4.7516e-02, -1.9967e-01,  1.8347e-01,  1.7822e-01,  1.7647e-01,
         -2.1175e-01, -2.

In [18]:
# Preview the first rows of the loaded frame.
df.head()

Unnamed: 0,img_embedding,caption_text,caption_embedding,attention_mask,caption_tokens,img_id,filename
0,"[0.4301545321941376, 0.49581533670425415, -0.0...",Two men in green shirts are standing in a yard.,"[[0.3392859101295471, 0.11646018177270889, 0.1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, ...","[49406, 1237, 1656, 530, 1901, 5803, 631, 2862...",0,1000092795.jpg
1,"[0.29502591490745544, 0.3597133755683899, -0.2...",Several men in hard hats are operating a giant...,"[[0.3392859101295471, 0.11646018177270889, 0.1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[49406, 5560, 1656, 530, 1626, 9812, 631, 1221...",1,10002456.jpg
2,"[0.16267558932304382, 0.13509173691272736, 0.0...",A little girl in a pink dress going into a woo...,"[[0.3392859101295471, 0.11646018177270889, 0.1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[49406, 320, 1274, 1611, 530, 320, 3360, 2595,...",2,1000268201.jpg
3,"[0.28787654638290405, 0.18737243115901947, -0....",Man in blue shirt and jeans on ladder cleaning...,"[[0.3392859101295471, 0.11646018177270889, 0.1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, ...","[49406, 786, 530, 1746, 2523, 537, 10157, 525,...",3,1000344755.jpg
4,"[-0.17353691160678864, 0.33811840415000916, 0....","Two men, one in a gray shirt, one in a black s...","[[0.3392859101295471, 0.11646018177270889, 0.1...","[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ...","[49406, 1237, 1656, 267, 637, 530, 320, 7048, ...",4,1000366164.jpg


In [58]:
# Unpack selected fields of the first row for ad-hoc inspection.
# NOTE(review): "caption_embedding", "attention_mask" and "caption_tokens" are
# not among the keys written by process_and_save_clip_embeddings as defined in
# this notebook — presumably the parquet file on disk comes from an earlier
# version of that function; confirm by regenerating it from a fresh kernel.
img_emb, cap, cap_emb, attention_mask, cap_tokens = df.iloc[0][
            [
                "img_embedding",
                "caption_text",
                "caption_embedding",
                "attention_mask",
                "caption_tokens",
            ]
        ]


In [35]:
# Shape of a single element of cap_emb once converted to a tensor.
torch.tensor(cap_emb[1]).shape

torch.Size([512])

In [46]:
import numpy as np

In [54]:
# Rebuild cap_emb as one dense array (via a list of its elements) before
# creating the tensor, then show the resulting shape.
torch.tensor(np.array(cap_emb.tolist())).shape

torch.Size([77, 512])

In [42]:
# Compare the shape of the outer container with that of its first element.
print(cap_emb.shape)
print(cap_emb[0].shape)

(77,)
(512,)
