In [None]:
import os

import numpy as np
import torch
from PIL import Image

from utils import get_device_map
from utils import show_image_caption
from utils import get_device_map

In [None]:
# GPU indices handed to get_device_map for model placement; preprocessed
# inputs are moved to the first one.
devices = [0, 4, 5]
start_device = f'cuda:{devices[0]}'

### Configs

In [None]:
# Alternative, smaller checkpoint:
# checkpoint = "Salesforce/blip2-opt-2.7b"
checkpoint = "Salesforce/blip2-flan-t5-xxl"
result_file_path = '../results/coco_test_blip2.csv'
# Hugging Face download cache. The default is machine-specific; set the
# HF_CACHE_DIR environment variable to point elsewhere without editing code.
cache_dir = os.environ.get("HF_CACHE_DIR", "/mnt/nas2/kjh/huggingface_cache")
# Model weights are loaded in this dtype and floating inputs are cast to it.
dtype = torch.float16
batch_size = 32
num_workers = 8
# Generation length limit (new tokens per caption).
max_new_tokens = 50

### Processor

In [None]:
from transformers import Blip2Processor

# Bundles the image preprocessor (processor.image_processor) and the text
# decoder (processor.batch_decode) for the selected checkpoint.
processor = Blip2Processor.from_pretrained(checkpoint, cache_dir=cache_dir)


### Model

In [None]:
from transformers import Blip2ForConditionalGeneration

# Placement map for the GPUs listed in `devices`, built by the project helper
# utils.get_device_map (presumably assigns submodules to those GPUs; see utils).
# Passing device_map='auto' instead lets transformers choose placement itself.
device_map = get_device_map(checkpoint, devices)

model = Blip2ForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map=device_map,
    torch_dtype=dtype,
    cache_dir=cache_dir,
)

In [None]:
def denormalize_image(normalized_image, mean, std):
    """Undo per-channel normalization of a channel-first image for display.

    Args:
        normalized_image: torch tensor (or array-like) of shape (C, H, W)
            holding normalized values, i.e. ``(pixel - mean) / std``. Tensors
            may live on any device, track gradients, or use a half-precision
            dtype (float16 / bfloat16).
        mean: per-channel means used for normalization (length-C sequence or
            scalar); broadcast over the trailing channel axis.
        std: per-channel standard deviations (length-C sequence or scalar).

    Returns:
        float32 numpy array of shape (H, W, C) with values clipped to [0, 1].
    """
    if isinstance(normalized_image, torch.Tensor):
        # .numpy() refuses GPU tensors, grad-tracking tensors and bfloat16,
        # so detach, move to CPU and upcast before converting.
        normalized_image = normalized_image.detach().cpu().float().numpy()
    # Channel-first (C, H, W) -> channel-last (H, W, C) for image display.
    image = np.asarray(normalized_image, dtype=np.float32).transpose(1, 2, 0)
    mean = np.asarray(mean, dtype=np.float32)
    std = np.asarray(std, dtype=np.float32)
    image = std * image + mean
    return np.clip(image, 0, 1)

### Inference Samples

In [None]:
image = '../datasets/cvpr-nice-val/val/215268662.jpg'
# convert() returns a loaded, independent copy, so the source file can be
# closed as soon as the conversion is done.
with Image.open(image) as img:
    raw_image = img.convert('RGB')

# Preprocess, then move to the first GPU and cast floating tensors to `dtype`.
inputs = processor(raw_image, return_tensors="pt")
inputs = inputs.to(start_device, dtype)

In [None]:
# Caption the sample image, honoring the configured generation length limit.
generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

# Undo the processor's mean/std normalization so the preprocessed model input
# can be shown; pass it to show_image_caption instead of raw_image to inspect it.
denormalized_image = denormalize_image(
    inputs['pixel_values'].cpu()[0],
    processor.image_processor.image_mean,
    processor.image_processor.image_std,
)
# show_image_caption(denormalized_image, [generated_text], show_fig=True, save_path='sample.png')
show_image_caption(raw_image, [generated_text], show_fig=True)