In [10]:
import os
import pickle
import pandas as pd
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from datasets import load_dataset

In [3]:
ds_flickr30k = load_dataset("lmms-lab/flickr30k")

ds_flickr30k

DatasetDict({
    test: Dataset({
        features: ['image', 'caption', 'sentids', 'img_id', 'filename'],
        num_rows: 31783
    })
})

In [4]:
def check_for_children(caption_list):
    keywords = ["child", "girl", "boy", "baby"]
    return any(
        keyword in caption.lower() for keyword in keywords for caption in caption_list
    )

In [5]:
ds_flickr30k["test"] = ds_flickr30k["test"].add_column(
    "has_children",
    [check_for_children(example["caption"]) for example in ds_flickr30k["test"]],
)

filter_ds_flickr30k = ds_flickr30k["test"].filter(
    lambda example: example["has_children"]
)
filter_ds_flickr30k = filter_ds_flickr30k.remove_columns(["sentids", "has_children"])

filter_ds_flickr30k

Dataset({
    features: ['image', 'caption', 'img_id', 'filename'],
    num_rows: 10361
})

In [6]:
filter_ds_flickr30k.save_to_disk("datasets/flickr30k_children.json")

Saving the dataset (3/3 shards): 100%|██████████| 10361/10361 [00:04<00:00, 2269.41 examples/s]


In [7]:
ds_nocaps = load_dataset("lmms-lab/NoCaps")

ds_nocaps

DatasetDict({
    validation: Dataset({
        features: ['image', 'image_coco_url', 'image_date_captured', 'image_file_name', 'image_height', 'image_width', 'image_id', 'image_license', 'image_open_images_id', 'annotations_ids', 'annotations_captions'],
        num_rows: 4500
    })
    test: Dataset({
        features: ['image', 'image_coco_url', 'image_date_captured', 'image_file_name', 'image_height', 'image_width', 'image_id', 'image_license', 'image_open_images_id', 'annotations_ids', 'annotations_captions'],
        num_rows: 10600
    })
})

In [8]:
ds_nocaps["validation"] = ds_nocaps["validation"].add_column(
    "has_children",
    [
        check_for_children(example["annotations_captions"])
        for example in ds_nocaps["validation"]
    ],
)

filter_ds_nocaps = ds_nocaps["validation"].filter(
    lambda example: example["has_children"]
)
filter_ds_nocaps = filter_ds_nocaps.remove_columns(
    [
        "annotations_ids",
        "has_children",
        "image_open_images_id",
        "image_license",
        "image_width",
        "image_height",
        "image_date_captured",
        "image_coco_url",
    ]
)

filter_ds_nocaps = filter_ds_nocaps.rename_columns(
    {
        "annotations_captions": "caption",
        "image_file_name": "filename",
        "image_id": "img_id",
    }
)

filter_ds_nocaps

Dataset({
    features: ['image', 'filename', 'img_id', 'caption'],
    num_rows: 667
})

In [9]:
filter_ds_nocaps.save_to_disk("datasets/nocaps_children.json")

Saving the dataset (1/1 shards): 100%|██████████| 667/667 [00:00<00:00, 788.10 examples/s]


In [3]:
class CustomDataset(Dataset):
    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.images = [os.path.join(data_dir, img) for img in os.listdir(data_dir)]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image, caption = loader(image_path)
        if self.transform:
            image = self.transform(image)
        return image, caption


def loader(path):
    image = Image.open(path)
    caption = data.loc[data["image"] == os.path.basename(path), "caption"].values[0]
    return image, caption


In [4]:
data_dir = "datasets/sin_dataset_img/images"
csv_file = "datasets/sin_dataset_img/captions.csv"

data = pd.read_csv(csv_file)

transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])

ds_sin = CustomDataset(data_dir, transform=transform)

In [11]:
with open("datasets/sin_children.pkl", "wb") as f:
    pickle.dump(ds_sin, f)