In [None]:
from PIL import Image
import torch
import numpy as np

from utils import show_image_caption

In [None]:
# Device the preprocessed inputs are sent to; a bare 'cuda' refers to the current CUDA device.
device = 'cuda'
gpu_name = torch.cuda.get_device_name(device)
print('Use ', gpu_name)

In [None]:
# Precision used both for the model weights (torch_dtype) and for the preprocessed inputs.
dtype = torch.float32
# Hugging Face Hub ID of the BLIP-2 checkpoint (OPT-2.7b language model).
checkpoint = "Salesforce/blip2-opt-2.7b"

### Processor

In [None]:
from transformers import Blip2Processor

# Bundles the image preprocessor (used later via `processor.image_processor`) and the
# tokenizer (used later via `processor.batch_decode`); files are cached next to the model's.
processor = Blip2Processor.from_pretrained(checkpoint, cache_dir='../pretrained_files')


### Model

In [None]:
# Explicit module -> GPU placement, copied from what `device_map="auto"` chose (GPUs 0-3).
# Layers are generated from per-device index ranges instead of being listed one by one,
# so the split points are visible at a glance.
VISION_LAYER_SPLIT = {0: range(0, 37), 1: range(37, 39)}                   # encoder layers 0..38
DECODER_LAYER_SPLIT = {1: range(0, 9), 2: range(9, 22), 3: range(22, 32)}  # OPT decoder layers 0..31


def _layer_placement(prefix, split):
    """Expand ``{device: layer_indices}`` into ``{f"{prefix}.{i}": device}`` entries."""
    return {f"{prefix}.{i}": dev for dev, layers in split.items() for i in layers}


device_map = {
    "query_tokens": 0,
    "vision_model.embeddings": 0,
    **_layer_placement("vision_model.encoder.layers", VISION_LAYER_SPLIT),
    "vision_model.post_layernorm": 1,
    "qformer": 1,
    "language_projection": 1,
    "language_model.model.decoder.embed_tokens": 1,
    "language_model.lm_head": 1,
    "language_model.model.decoder.embed_positions": 1,
    "language_model.model.decoder.final_layer_norm": 1,
    **_layer_placement("language_model.model.decoder.layers", DECODER_LAYER_SPLIT),
}

In [None]:
from transformers import Blip2ForConditionalGeneration

# Load the weights in `dtype`, sharded across GPUs according to the explicit `device_map`
# (passing device_map='auto' instead lets accelerate pick the placement itself).
model = Blip2ForConditionalGeneration.from_pretrained(
    checkpoint,
    torch_dtype=dtype,
    device_map=device_map,
    cache_dir='../pretrained_files',
)

In [None]:
# Path of the sample image to caption.
image = '../datasets/cvpr-nice-val/val/215268662.jpg'

# Open with a context manager so the underlying file handle is closed right away;
# `convert('RGB')` returns an independent in-memory copy that stays valid afterwards.
with Image.open(image) as img:
    raw_image = img.convert('RGB')

# Preprocess into model inputs, then move them to `device` and cast to `dtype`
# via BatchFeature.to (see transformers docs for which tensors get cast).
inputs = processor(raw_image, return_tensors="pt").to(device, dtype)

In [None]:
def denormalize_image(normalized_image, mean, std):
    """Undo per-channel normalization and return a displayable HWC image.

    Args:
        normalized_image: 3-D image in channel-first (C, H, W) layout, either a
            ``torch.Tensor`` (any device, may require grad, any float dtype incl.
            bfloat16) or a NumPy array.
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations that were divided out.

    Returns:
        ``np.ndarray`` of shape (H, W, C) equal to ``std * x + mean``, clipped to [0, 1].
    """
    if isinstance(normalized_image, torch.Tensor):
        # `.numpy()` refuses tensors that require grad or live on a GPU.
        tensor = normalized_image.detach().cpu()
        if tensor.dtype == torch.bfloat16:
            # NumPy has no bfloat16 dtype, so upcast before converting.
            tensor = tensor.float()
        normalized_image = tensor.numpy()

    # (C, H, W) -> (H, W, C) so per-channel mean/std broadcast along the last axis.
    image = np.asarray(normalized_image).transpose(1, 2, 0)
    image = np.asarray(std) * image + np.asarray(mean)
    return np.clip(image, 0, 1)

In [None]:
# Caption the image: generate token ids, then decode the first (only) sequence to text.
generated_ids = model.generate(**inputs, max_new_tokens=20)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
print(generated_text)

# Show the caption over the preprocessed pixels the model received (normalization undone)
# and over the original image. Each figure gets its own save_path: both calls previously
# used 'sample.png', so the second figure overwrote the first (assuming
# show_image_caption writes to save_path).
denormalized_image = denormalize_image(
    inputs['pixel_values'].cpu()[0],
    processor.image_processor.image_mean,
    processor.image_processor.image_std,
)
show_image_caption(denormalized_image, [generated_text], show_fig=True, save_path='sample_preprocessed.png')
show_image_caption(raw_image, [generated_text], show_fig=True, save_path='sample.png')