In [43]:
import os
import wandb
import pandas as pd
from PIL import Image
from torch.utils.data import Dataset
from datasets import load_dataset, concatenate_datasets
from transformers import BlipProcessor, BlipForConditionalGeneration
from bert_score import score
from nltk.tokenize import word_tokenize
from nltk.translate.meteor_score import meteor_score
from pycocoevalcap.cider.cider import Cider
from collections import Counter

In [25]:
ds_flickr30k = load_dataset("lmms-lab/flickr30k")

ds_flickr30k

DatasetDict({
    test: Dataset({
        features: ['image', 'caption', 'sentids', 'img_id', 'filename'],
        num_rows: 31783
    })
})

In [26]:
def check_for_children(caption_list):
    keywords = ["child", "girl", "boy", "baby"]
    return any(
        keyword in caption.lower() for keyword in keywords for caption in caption_list
    )

In [27]:
ds_flickr30k["test"] = ds_flickr30k["test"].add_column(
    "has_children",
    [check_for_children(example["caption"]) for example in ds_flickr30k["test"]],
)

filter_ds_flickr30k = ds_flickr30k["test"].filter(
    lambda example: example["has_children"]
)
filter_ds_flickr30k = filter_ds_flickr30k.remove_columns(
    ["sentids", "has_children", "img_id", "filename"]
)

filter_ds_flickr30k

Dataset({
    features: ['image', 'caption'],
    num_rows: 10361
})

In [28]:
ds_nocaps = load_dataset("lmms-lab/NoCaps")

ds_nocaps

DatasetDict({
    validation: Dataset({
        features: ['image', 'image_coco_url', 'image_date_captured', 'image_file_name', 'image_height', 'image_width', 'image_id', 'image_license', 'image_open_images_id', 'annotations_ids', 'annotations_captions'],
        num_rows: 4500
    })
    test: Dataset({
        features: ['image', 'image_coco_url', 'image_date_captured', 'image_file_name', 'image_height', 'image_width', 'image_id', 'image_license', 'image_open_images_id', 'annotations_ids', 'annotations_captions'],
        num_rows: 10600
    })
})

In [29]:
ds_nocaps["validation"] = ds_nocaps["validation"].add_column(
    "has_children",
    [
        check_for_children(example["annotations_captions"])
        for example in ds_nocaps["validation"]
    ],
)

filter_ds_nocaps = ds_nocaps["validation"].filter(
    lambda example: example["has_children"]
)
filter_ds_nocaps = filter_ds_nocaps.remove_columns(
    [
        "annotations_ids",
        "has_children",
        "image_open_images_id",
        "image_license",
        "image_width",
        "image_height",
        "image_date_captured",
        "image_coco_url",
        "image_file_name",
        "image_id",
    ]
)

filter_ds_nocaps = filter_ds_nocaps.rename_columns({"annotations_captions": "caption"})

filter_ds_nocaps

Dataset({
    features: ['image', 'caption'],
    num_rows: 667
})

In [30]:
ds_train = concatenate_datasets([filter_ds_flickr30k, filter_ds_nocaps])

ds_train

Dataset({
    features: ['image', 'caption'],
    num_rows: 11028
})

In [32]:
ds_train.save_to_disk("datasets/ds_train")

Saving the dataset (0/4 shards):   0%|          | 0/11028 [00:00<?, ? examples/s]

In [33]:
class CustomDataset(Dataset):
    def __init__(self, data_dir, data, transform=None):
        self.data_dir = data_dir
        self.data = data
        self.transform = transform
        self.images = [os.path.join(data_dir, img) for img in os.listdir(data_dir)]

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        image, caption = loader(image_path, self.data)
        if self.transform:
            image = self.transform(image)
        return image, caption


def loader(path, data):
    image = Image.open(path)
    caption = data.loc[data["image"] == os.path.basename(path), "caption"].values[0]
    return image, caption

In [34]:
data_dir = "datasets/sin_dataset_img"

data = pd.read_csv(f"{data_dir}/captions.csv")

ds_sin = CustomDataset(f"{data_dir}/images", data)

In [35]:
def metric_cider(predicted_captions, reference_captions):
    predicted_dict = {i: [caption] for i, caption in enumerate(predicted_captions)}
    reference_dict = {i: captions for i, captions in enumerate(reference_captions)}

    cider_scorer = Cider()

    cider_score, _ = cider_scorer.compute_score(reference_dict, predicted_dict)

    return cider_score

In [36]:
def metric_bertscore(predicted_captions, reference_captions):
    P, R, F1 = score(predicted_captions, reference_captions, lang="en", verbose=False)

    BERTScore_F1 = F1.mean().item()

    return BERTScore_F1

In [37]:
def metric_meteor(predicted_captions, reference_captions):
    predicted_captions_tokenized = [
        word_tokenize(caption) for caption in predicted_captions
    ]
    reference_captions_tokenized = [
        [word_tokenize(caption) for caption in ref] for ref in reference_captions
    ]

    scores = [
        meteor_score(ref, gen)
        for ref, gen in zip(reference_captions_tokenized, predicted_captions_tokenized)
    ]
    average_meteor = sum(scores) / len(scores)

    return average_meteor

In [38]:
def metric_spice(predicted_captions, reference_captions):
    spice_scores = []

    for gen_caption, ref_captions in zip(predicted_captions, reference_captions):
        gen_tokens = word_tokenize(gen_caption.lower())

        ref_tokens = [word_tokenize(ref.lower()) for ref in ref_captions]

        gen_counter = Counter(gen_tokens)
        ref_counters = [Counter(ref) for ref in ref_tokens]

        precisions = []
        recalls = []
        for ref_counter in ref_counters:
            common = gen_counter & ref_counter
            precisions.append(sum(common.values()) / len(gen_tokens))
            recalls.append(sum(common.values()) / len(ref_counter))

        precision = sum(precisions) / len(ref_counters)
        recall = sum(recalls) / len(ref_counters)
        if precision + recall > 0:
            spice_score = (precision * recall) / (precision + recall)
        else:
            spice_score = 0.0

        spice_scores.append(spice_score)

    average_spice_score = sum(spice_scores) / len(spice_scores)

    return average_spice_score

In [44]:
models = {
    "blip_base": "Salesforce/blip-image-captioning-base",
    "blip_large": "Salesforce/blip-image-captioning-large",
    "pic2story": "abhijit2111/Pic2Story",
}

In [15]:
for name, model_name in models:
    wandb.init(project="child_diary", group=name, job_type="base")

    processor = BlipProcessor.from_pretrained(model_name)
    model = BlipForConditionalGeneration.from_pretrained(model_name)

    predicted_captions = []
    reference_captions = []

    for image, captions in ds_sin:
        inputs = processor(image, return_tensors="pt").to("cuda")
        out = model.generate(**inputs)
        predicted_captions.append(processor.decode(out[0], skip_special_tokens=True))
        reference_captions.append([captions])

    wandb.log(
        {
            "CIDEr": metric_cider(predicted_captions, reference_captions),
            "BERTScore_F1": metric_bertscore(predicted_captions, reference_captions),
            "METEOR": metric_meteor(predicted_captions, reference_captions),
            "SPICE": metric_spice(predicted_captions, reference_captions),
        }
    )

    wandb.finish()