In [1]:
import os
import random
import json
from datasets import load_dataset, concatenate_datasets

In [2]:
def save_pil_image(pil_img, save_dir, idx, ext="jpg"):
    """Save a PIL image as ``img_<idx>.<ext>`` inside ``save_dir``.

    Args:
        pil_img: Image object exposing ``mode``, ``convert()`` and ``save()``
            (a PIL image in this notebook).
        save_dir: Target directory; created if it does not exist.
        idx: Value embedded in the file name (``img_<idx>``).
        ext: File extension without the leading dot. Defaults to ``"jpg"``.

    Returns:
        The path the image was written to.
    """
    os.makedirs(save_dir, exist_ok=True)

    # Modes that Pillow's JPEG writer accepts as-is; anything else (palette
    # "P", "LA", "I;16", "F", ...) makes ``save`` raise for a JPEG target.
    jpeg_ready_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    targets_jpeg = str(ext).lower() in ("jpg", "jpeg", "jpe", "jfif")

    # RGBA is always flattened to RGB (the long-standing behaviour for every
    # extension); other modes are only converted when JPEG could not store them.
    if pil_img.mode == "RGBA" or (targets_jpeg and pil_img.mode not in jpeg_ready_modes):
        pil_img = pil_img.convert("RGB")

    path = os.path.join(save_dir, f"img_{idx}.{ext}")
    pil_img.save(path)
    return path

In [3]:
# Literal dataset: Flickr8k
def load_flickr8k(n_samples=6027, seed=42):
    """Load Flickr8k and return a random subset labelled as literal (0).

    Args:
        n_samples: Number of examples to draw without replacement. The
            default equals the V-FLUTE size noted in ``load_vflute``.
        seed: Seed for the sampling RNG so the subset is the same on every
            run. Pass ``None` to get a different subset each time.

    Returns:
        The sampled dataset with a ``caption`` column (copied from
        ``caption_0``) and a ``label`` column set to 0 on every row.

    Raises:
        ValueError: If ``n_samples`` exceeds the number of available examples.
    """
    print("Downloading Flickr8k...")
    flickr = load_dataset("jxie/flickr8k", split="train+validation+test")

    total = len(flickr)
    if n_samples > total:
        raise ValueError(
            f"n_samples={n_samples} exceeds the {total} examples available in Flickr8k"
        )

    # A private RNG keeps sampling reproducible without reseeding the global
    # ``random`` module that other cells may rely on.
    rng = random.Random(seed)
    indices = rng.sample(range(total), n_samples)
    flickr_sampled = flickr.select(indices)

    flickr_sampled = flickr_sampled.map(lambda x: {
        "image": x["image"],
        "caption": x["caption_0"],
        "label": 0 # Literal
    })
    return flickr_sampled

# Figurative dataset: V-FLUTE
def load_vflute():
    """Load every split of V-FLUTE and tag each example as figurative (1).

    Returns:
        The dataset after mapping each row to carry ``image``, a ``caption``
        taken from the row's ``claim`` field, and ``label`` set to 1.
    """
    print("Downloading V-FLUTE...")

    def _to_figurative(row):
        """Build the image/caption/label fields for a single row."""
        return {
            "image": row["image"],
            "caption": row["claim"],
            "label": 1,  # Figurative
        }

    # Size: 6027
    dataset = load_dataset("ColumbiaNLP/V-FLUTE", split="train+validation+test")
    return dataset.map(_to_figurative)

In [4]:
flickr = load_flickr8k()
vflute = load_vflute()

print("Combining dataset...")
combined = concatenate_datasets([flickr, vflute]).shuffle(seed=42)

# Output folders, one per source dataset
save_dir_flickr = "images/flickr8k"
save_dir_vflute = "images/vflute"
for folder in (save_dir_flickr, save_dir_vflute):
    os.makedirs(folder, exist_ok=True)

# Write every image to disk and emit one JSONL record per example
with open("combined_dataset.jsonl", "w", encoding="utf-8") as f_out:
    for i, item in enumerate(combined):
        pil_img = item["image"]  # PIL image object

        # Label 0 goes to the Flickr8k folder, anything else to the V-FLUTE folder
        target_dir = save_dir_flickr if item["label"] == 0 else save_dir_vflute
        path = save_pil_image(pil_img, target_dir, i)

        record = dict(image_path=path, caption=item["caption"], label=item["label"])
        f_out.write(f"{json.dumps(record)}\n")

Downloading Flickr8k...


Map:   0%|          | 0/6027 [00:00<?, ? examples/s]

Downloading V-FLUTE...
Combining dataset...
