In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoImageProcessor
from PIL import Image
import requests

  from .autonotebook import tqdm as notebook_tqdm


In [3]:

# Prompt-formatting helper: maps each supported task to its instruction text
# and assembles the full instruction/input/response prompt.
TASK2INSTRUCTION = {
    "caption": "画像を詳細に述べてください。",
    "tag": "与えられた単語を使って、画像を詳細に述べてください。",
    "vqa": "与えられた画像を下に、質問に答えてください。",
}


def build_prompt(task="caption", input=None, sep="\n\n### "):
    """Build the text prompt for one of the tasks in ``TASK2INSTRUCTION``.

    The prompt is a fixed system message followed by sections, each rendered
    as ``sep + role + ": \\n" + body``: an instruction section, an optional
    input section, and an empty response section left for the model to fill.

    Args:
        task: Key of ``TASK2INSTRUCTION`` ("caption", "tag" or "vqa").
        input: Extra text for the "tag" and "vqa" tasks; must be ``None`` for
            any other task. For "tag", a list is joined with "、" first.
            A falsy value (e.g. an empty string) passes validation but
            produces no input section.
        sep: Separator placed before every section.

    Returns:
        The assembled prompt string.

    Raises:
        AssertionError: If ``task`` is unknown, if ``input`` is missing for
            "tag"/"vqa", or if ``input`` is given for any other task.
    """
    assert (
        task in TASK2INSTRUCTION
    ), f"Please choose from {list(TASK2INSTRUCTION.keys())}"

    user_input = input
    if task not in ("tag", "vqa"):
        assert user_input is None, f"`{task}` mode doesn't support to input questions"
    else:
        assert user_input is not None, "Please fill in `input`!"
        if task == "tag" and isinstance(user_input, list):
            user_input = "、".join(user_input)

    header = "以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。"

    # Ordered (role, body) pairs; the input section exists only for truthy input.
    sections = [("指示", ": \n" + TASK2INSTRUCTION[task])]
    if user_input:
        sections.append(("入力", ": \n" + user_input))
    sections.append(("応答", ": \n"))

    return header + "".join(sep + role + body for role, body in sections)

# load model
MODEL_ID = "stabilityai/japanese-stable-vlm"

# Pick the best available backend instead of hard-coding "mps", so the cell
# also runs on CUDA and CPU-only machines. `getattr` guards older torch builds
# that do not ship `torch.backends.mps`.
_mps_backend = getattr(torch.backends, "mps", None)
if _mps_backend is not None and _mps_backend.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Keep half precision on accelerators; fall back to float32 on CPU, where
# float16 kernels are missing or slow in some PyTorch builds.
model_dtype = torch.float32 if device == "cpu" else torch.float16

# NOTE: trust_remote_code=True executes Python code shipped in the model
# repository; only use it with repositories you trust.
model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=model_dtype,
    variant="fp16",
)
processor = AutoImageProcessor.from_pretrained(MODEL_ID)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# Only the model holds weights to move. The image processor is not a module
# and has no `.to()`; the tensors it produces are moved to `device` when the
# inputs are prepared for generation. Assigning the result also keeps the cell
# from echoing the full model repr as output.
model = model.to(device)



Loading checkpoint shards:  50%|█████     | 2/4 [00:24<00:24, 12.44s/it]


KeyboardInterrupt: 

In [None]:

print('generate image token')
# prepare inputs
url = "https://images.unsplash.com/photo-1582538885592-e70a5d7ab3d3?ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D&auto=format&fit=crop&w=1770&q=80"
# Bound the download time and fail loudly on HTTP errors instead of handing an
# error page to PIL. The context manager releases the connection once the
# image has been fully decoded by `.convert`.
with requests.get(url, stream=True, timeout=30) as response:
    response.raise_for_status()
    image = Image.open(response.raw).convert("RGB")
prompt = build_prompt(task="caption")
# Alternative prompts for the other tasks:
# prompt = build_prompt(task="tag", input=["河津桜", "青空"])
# prompt = build_prompt(task="vqa", input="季節はいつですか？")

# Merge the image features and the tokenized prompt into one batch of inputs.
inputs = processor(images=image, return_tensors="pt")
text_encoding = tokenizer(prompt, add_special_tokens=False, return_tensors="pt")
inputs.update(text_encoding)

print('generate answer')
# generate
# Deterministic beam search (no sampling); inputs are moved to the model's
# device and dtype right before generation.
outputs = model.generate(
    **inputs.to(device, dtype=model.dtype),
    do_sample=False,
    num_beams=5,
    max_new_tokens=128,
    min_length=1,
    repetition_penalty=1.5,
)
generated_text = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)
# Example output: 桜越しの東京スカイツリー ("Tokyo Skytree seen through cherry blossoms")