In [None]:
from PIL import Image
import os
import torch
import numpy as np

from utils import show_image_caption
from utils import get_device_map

In [None]:
# GPU indices this notebook is allowed to use.
devices = [1, 5, 6, 7]
# Torch device string for the first listed GPU; input tensors are sent here.
start_device = f"cuda:{devices[0]}"

### Configs

In [None]:
# --- Model ---------------------------------------------------------------
# Checkpoint identifier handed to `from_pretrained`.
# Alternative checkpoint: "Salesforce/blip2-opt-2.7b"
checkpoint = "Salesforce/blip2-flan-t5-xl"
# Precision used when loading the model and casting the inputs.
dtype = torch.float16

# --- Paths ---------------------------------------------------------------
# Destination CSV for results (not written by the cells visible here).
result_file_path = '../results/coco_test_blip2.csv'
# Cache root; alternative location: "/mnt/nas2/kjh/huggingface_cache"
cache_dir = "../caches"
# Sub-directories of the cache root for pretrained files and datasets.
cache_pretrained_files_dir, cache_dataset_dir = (
    os.path.join(cache_dir, sub_dir) for sub_dir in ("pretrained_files", "datasets")
)

# --- Run parameters ------------------------------------------------------
# NOTE(review): batch_size / num_workers are presumably for a data loader
# defined later in the notebook — confirm.
batch_size = 32
num_workers = 8
# Intended upper bound on the number of newly generated tokens.
max_new_tokens = 50

### Processor

In [None]:
from transformers import Blip2Processor

# Load the processor that matches `checkpoint`, keeping downloaded files
# under the pretrained-files cache directory.
processor = Blip2Processor.from_pretrained(checkpoint, cache_dir=cache_pretrained_files_dir)


### Model

In [None]:
from transformers import Blip2ForConditionalGeneration

# Explicit placement computed by the project helper in `utils` from the
# checkpoint name and the GPU list.
# NOTE(review): the exact structure of the returned map is not visible here — confirm in utils.
device_map = get_device_map(checkpoint, devices)

# Load the model weights in `dtype`, caching downloads alongside the processor files.
model = Blip2ForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map=device_map,  # used instead of the automatic device_map='auto'
    torch_dtype=dtype,
    cache_dir=cache_pretrained_files_dir,
)

### Inference Samples

In [None]:
# Path of the sample image to caption.
image = '../datasets/cvpr-nice-val/val/215268662.jpg'
# Reference caption for this image, kept for manual comparison:
# caption_gt = 'Bicycles leaning against tree in wood Close up low angle view'

# Open the file in a context manager so the handle is released deterministically
# (also when `convert` raises); `convert` returns a separate in-memory image that
# stays usable after the file is closed.
with Image.open(image) as image_file:
    raw_image = image_file.convert('RGB')

# Preprocess into PyTorch tensors, then move them to the first GPU in `devices`
# and cast to the configured `dtype`.
inputs = processor(raw_image, return_tensors="pt").to(start_device, dtype)

In [None]:
from utils import denormalize_image, plot_images

# Generate a caption for the sample. The length limit comes from the config
# constant `max_new_tokens` rather than a duplicated literal, so changing the
# config actually affects this cell.
generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
# Decode the first (only) sequence, dropping special tokens and surrounding whitespace.
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

# Pass the first image's pixel values (moved to CPU) together with the image
# processor's mean/std to the project helper, then plot the result with the caption.
# NOTE(review): denormalize_image presumably inverts the processor's normalization — confirm in utils.
denormalized_image = denormalize_image(
    inputs['pixel_values'].cpu()[0],
    processor.image_processor.image_mean,
    processor.image_processor.image_std,
)
plot_images(denormalized_image, generated_text)