In [17]:
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import Blip2Processor, Blip2ForConditionalGeneration, Blip2Config
from PIL import Image
import requests

In [18]:
# Checkpoint name passed to every `from_pretrained` call in this notebook
# (config, model and processor).
CHECK_POINT = "Salesforce/blip2-flan-t5-xxl"

In [19]:
# Load only the configuration, then instantiate the architecture inside
# accelerate's `init_empty_weights` context so a device placement can be
# computed from this skeleton before the real checkpoint is loaded.
# NOTE(review): per the accelerate docs this context creates parameters without
# allocating real memory — confirm for the installed accelerate version.
config = Blip2Config.from_pretrained(CHECK_POINT)
with init_empty_weights():
    model = Blip2ForConditionalGeneration(config)
    # Tie shared weights before a device map is inferred from `model`.
    model.tie_weights()

In [20]:
# Per-GPU memory budget handed to accelerate, keyed by CUDA device index.
# The unit suffix is spelled "GiB" throughout (the original mixed "GiB" and "Gib").
# NOTE(review): GPU 4 is capped at 9GiB while every other GPU gets 10GiB; the
# asymmetry is kept from the original — presumably that device has less free
# memory, confirm before changing.
max_memory = {
    0: "10GiB",
    1: "10GiB",
    2: "10GiB",
    3: "10GiB",
    4: "9GiB",
    5: "10GiB",
    6: "10GiB",
    7: "10GiB",
}

# Compute a module -> device placement from the empty-weight `model`.
# "T5Block" is listed in `no_split_module_classes` so a block of that class is
# placed on a single device rather than being divided between devices.
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=["T5Block"],
    dtype="float32",
    max_memory=max_memory,
)

In [21]:
# Display the placement computed by `infer_auto_device_map`.
device_map

OrderedDict([('query_tokens', 0),
             ('vision_model', 0),
             ('qformer', 0),
             ('language_projection', 0),
             ('language_model.shared', 0),
             ('language_model.decoder.embed_tokens', 0),
             ('language_model.encoder.embed_tokens', 0),
             ('language_model.encoder.block.0', 0),
             ('language_model.encoder.block.1', 0),
             ('language_model.encoder.block.2', 0),
             ('language_model.encoder.block.3', 0),
             ('language_model.encoder.block.4', 0),
             ('language_model.encoder.block.5', 0),
             ('language_model.encoder.block.6', 1),
             ('language_model.encoder.block.7', 1),
             ('language_model.encoder.block.8', 1),
             ('language_model.encoder.block.9', 1),
             ('language_model.encoder.block.10', 1),
             ('language_model.encoder.block.11', 1),
             ('language_model.encoder.block.12', 1),
             ('language_mo

In [22]:
# Manual override of the inferred placement: the language projection and the
# shared / decoder / encoder embedding entries are all re-assigned to whatever
# device `language_model.lm_head` currently maps to.
device_map.update(
    dict.fromkeys(
        (
            "language_projection",
            "language_model.shared",
            "language_model.decoder.embed_tokens",
            "language_model.encoder.embed_tokens",
        ),
        device_map["language_model.lm_head"],
    )
)

In [23]:
# Display the placement again to check the manual override.
device_map

OrderedDict([('query_tokens', 0),
             ('vision_model', 0),
             ('qformer', 0),
             ('language_projection', 4),
             ('language_model.shared', 4),
             ('language_model.decoder.embed_tokens', 4),
             ('language_model.encoder.embed_tokens', 4),
             ('language_model.encoder.block.0', 0),
             ('language_model.encoder.block.1', 0),
             ('language_model.encoder.block.2', 0),
             ('language_model.encoder.block.3', 0),
             ('language_model.encoder.block.4', 0),
             ('language_model.encoder.block.5', 0),
             ('language_model.encoder.block.6', 1),
             ('language_model.encoder.block.7', 1),
             ('language_model.encoder.block.8', 1),
             ('language_model.encoder.block.9', 1),
             ('language_model.encoder.block.10', 1),
             ('language_model.encoder.block.11', 1),
             ('language_model.encoder.block.12', 1),
             ('language_mo

In [24]:
# Load the checkpoint with the customised `device_map`.
# Note: this rebinds `model`, replacing the empty-weight skeleton that was only
# needed to infer the placement.
model = Blip2ForConditionalGeneration.from_pretrained(CHECK_POINT, device_map=device_map)

Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

In [25]:
# Processor for the same checkpoint; used below to build the model inputs from
# an image and a question, and to decode the generated token ids back to text.
processor = Blip2Processor.from_pretrained(CHECK_POINT)



In [26]:
# Visual question answering demo: download one image, ask a question about it,
# and print the model's answer.
img_url = 'https://gker-love.oss-cn-beijing.aliyuncs.com/Naive/messages/6e6c01ed-29bb-447d-8790-4f068d0b6e8a/da6a1872-5d75-478d-a5ac-8e5e24864df4.jpeg'

# Seconds to wait for the image host before giving up; without a timeout
# `requests.get` can block the kernel indefinitely.
request_timeout_s = 30

# The `with` block closes the streamed response once the image is decoded.
with requests.get(img_url, stream=True, timeout=request_timeout_s) as response:
    # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
    response.raise_for_status()
    # `.convert` forces the full decode while the stream is still open.
    raw_image = Image.open(response.raw).convert("RGB")

question = "What do you see in the image?"
# NOTE(review): "cuda" is kept from the original; this presumably resolves to
# GPU 0, where the displayed device map places `vision_model` — confirm if the
# placement changes.
inputs = processor(raw_image, question, return_tensors="pt").to("cuda")

out = model.generate(**inputs)
print(processor.decode(out[0], skip_special_tokens=True))



a clock tower
