### Import dependencies

In [6]:
import torch
from PIL import Image

from lavis.models import load_model_and_preprocess

### Use the GPU if available

In [7]:
# Run on PyTorch's default CUDA device when a GPU is visible; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

#### Load the BLIP large captioning model fine-tuned on COCO

In [8]:
# Load the model together with its matching image preprocessors, so that
# inference inputs are prepared exactly as the model expects.
# Set MODEL_TYPE = "base_coco" to use the smaller base model instead.
MODEL_TYPE = "large_coco"
model, vis_processors, _ = load_model_and_preprocess(
    name="blip_caption", model_type=MODEL_TYPE, is_eval=True, device=device
)
# Show which preprocessing pipelines are available (the captioning loop uses "eval").
vis_processors.keys()

dict_keys(['train', 'eval'])

Load all images and generate captions
-------------------------------------

In [13]:
import os
from tqdm import tqdm 
from PIL import ImageFile
from PIL import Image
# Let PIL decode truncated (incomplete) image files instead of raising an error,
# so a few damaged files in the dataset can still be captioned.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Input image directories and output caption files.
# The defaults are this machine's paths; set the environment variables below to run
# the notebook elsewhere without editing it.
kym_memes = os.environ.get("KYM_MEMES_DIR", "D:/Memes2023_splitted_resized/finetuning")
reddit_memes = os.environ.get("REDDIT_MEMES_DIR", "D:/Memes2022Final2_resized")

kym_captions = os.environ.get(
    "KYM_CAPTIONS_PATH",
    'C:/Users/Murgi/Documents/GitHub/meme_research/outputs/captions/kym_captions.txt',
)
reddit_captions = os.environ.get(
    "REDDIT_CAPTIONS_PATH",
    'C:/Users/Murgi/Documents/GitHub/meme_research/outputs/captions/reddit_captions.txt',
)

def _read_captioned_paths(output_path):
    """Read an existing captions file so an interrupted run can be resumed.

    Returns ``(done, ends_mid_line)``:
      * ``done`` is the set of image paths found on complete lines
        (``<img_path>\\t<caption>\\n``).
      * ``ends_mid_line`` is True when the file's last line has no trailing newline,
        for example because the process died mid-write.

    A missing file yields ``(set(), False)``.
    """
    if not os.path.exists(output_path):
        return set(), False
    done = set()
    last_line = ""
    with open(output_path, encoding="utf-8") as f:
        for line in f:
            last_line = line
            # Only trust complete lines; a cut-off line may hold a truncated path.
            if line.endswith("\n") and "\t" in line:
                done.add(line.split("\t", 1)[0])
    return done, bool(last_line) and not last_line.endswith("\n")


def caption_dataset(dataset_path, output_path, kym, resume=False):
    """Caption every regular file in ``dataset_path`` and write one TSV line per image.

    Each output line is ``<img_path>\\t<caption>``. ``img_path`` is
    ``os.path.join(dataset_path, name)``, and the caption is the space-joined list
    returned by ``model.generate``. The function relies on the module-level
    ``model``, ``vis_processors`` and ``device`` set up in earlier cells.

    Args:
        dataset_path: directory to caption. Files are processed in sorted name
            order, and subdirectories are ignored.
        output_path: captions file, written as UTF-8.
        kym: True for the KYM dataset, False for Reddit. It is only used to label
            the progress bar (the old hard-coded totals were never used).
        resume: if True, keep the existing output file, skip images already listed
            in it, and append new lines. If False (default), overwrite the file.

    Images that fail to load or caption are reported and skipped. A
    KeyboardInterrupt stops the run and returns normally. Because every line is
    flushed as it is written, a later call with ``resume=True`` continues where
    this one stopped.
    """
    done, ends_mid_line = _read_captioned_paths(output_path) if resume else (set(), False)

    # List the directory once, before touching the output, so a bad
    # dataset_path cannot truncate previous results.
    with os.scandir(dataset_path) as entries:
        file_names = sorted(entry.name for entry in entries if entry.is_file())

    # Set UTF-8 explicitly: the Windows default (cp1252) cannot encode every file name.
    # buffering=1 flushes each line, so an interrupted multi-day run loses nothing.
    mode = "a" if resume else "w"
    with open(output_path, mode, encoding="utf-8", buffering=1) as f, torch.no_grad():
        if ends_mid_line:
            f.write("\n")  # terminate the cut-off last line before appending
        for file in tqdm(file_names, total=len(file_names), desc="kym" if kym else "reddit"):
            img_path = os.path.join(dataset_path, file)
            if img_path in done:
                continue
            try:
                # Close the source file promptly; convert() returns an in-memory copy.
                with Image.open(img_path) as img:
                    image = img.convert("RGB")
                image = vis_processors["eval"](image).unsqueeze(0).to(device)
                caption = " ".join(model.generate({"image": image}))
                f.write(img_path + "\t" + caption + "\n")
            except KeyboardInterrupt:
                return
            except Exception as e:
                # tqdm.write keeps the progress bar intact, unlike print().
                tqdm.write(str(e))
                tqdm.write("Error captioning image: " + img_path)

# Choose which datasets to caption in this run (KYM is disabled, Reddit enabled).
RUN_KYM = False
RUN_REDDIT = True

if RUN_KYM:
    print('Captioning kym dataset...')
    caption_dataset(kym_memes, kym_captions, True)
if RUN_REDDIT:
    print('Captioning reddit dataset...')
    caption_dataset(reddit_memes, reddit_captions, False)

Captioning reddit dataset...


 67%|██████▋   | 643093/955593 [82:58:42<40:19:19,  2.15it/s] 
