#### A large amount of RAM is required to load the larger models. Running on a GPU speeds up inference.

In [None]:
import torch
import pandas as pd
from PIL import Image
from lavis.models import load_model_and_preprocess
from glob import glob
from tqdm import tqdm

In [3]:
# Select the compute device: CUDA when available, otherwise the CPU.
# Always build a torch.device so `device` has the same type on both paths
# (the previous expression produced a plain "cpu" string on the fallback path).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#### Load pretrained/finetuned BLIP2 captioning model

In [13]:
# Load the model together with its matching visual preprocessors so the two
# can be used as a pair at inference time. The third returned value is not
# used in this notebook.
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_t5",
    model_type="pretrain_flant5xxl",
    is_eval=True,
    device=device,
)

# Other available models -- substitute one of these `name` / `model_type`
# pairs in the call above (is_eval=True and device=device stay the same):
#
#   name="blip2_opt", model_type="pretrain_opt2.7b"
#   name="blip2_opt", model_type="pretrain_opt6.7b"
#   name="blip2_opt", model_type="caption_coco_opt2.7b"
#   name="blip2_opt", model_type="caption_coco_opt6.7b"
#
#   name="blip2_t5", model_type="pretrain_flant5xl"
#
#   name="blip2_t5", model_type="caption_coco_flant5xl"

# Show which preprocessing pipelines were returned.
vis_processors.keys()

dict_keys(['train', 'eval'])

In [ ]:
# Caption every training image and save the results as CSV.
#
# A partial CSV (`../captions_<count>.csv`) is written after every 100 images
# so a long run can be recovered if it is interrupted; the complete table is
# written to `../captions.csv` at the end.
images = sorted(glob('../img_captioning_CNNRNN/images/train/*.jpg'))
captions = []
for count, image_path in enumerate(tqdm(images), start=1):
    # Open inside a context manager so the file handle is released as soon as
    # the pixel data has been converted to RGB.
    with Image.open(image_path) as image_file:
        raw_image = image_file.convert('RGB')

    # Preprocess with the "eval" pipeline, add a leading batch dimension of 1
    # and move the tensor to the selected device.
    image_tensor = vis_processors["eval"](raw_image).unsqueeze(0).to(device)

    # NOTE(review): the value returned by generate() is stored unchanged;
    # if it is a one-element list, the CSV cell will hold the list's repr --
    # confirm whether `caption[0]` is what is wanted here.
    caption = model.generate({"image": image_tensor})
    captions.append(caption)

    # Checkpoint AFTER appending, so `images[:count]` and `captions` have the
    # same length. Checkpointing before the append left `captions` one entry
    # short, and pandas raises ValueError on columns of unequal length.
    if count % 100 == 0:
        df = pd.DataFrame({'image': images[:count], 'caption': captions})
        df.to_csv(f'../captions_{count}.csv', index=False)

df = pd.DataFrame({'image': images, 'caption': captions})
df.to_csv('../captions.csv', index=False)