In [1]:
from typing import Iterable
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import numpy as np
from llava.conversation import Conversation, SeparatorStyle, conv_templates
from llava.utils import disable_torch_init
from transformers import CLIPImageProcessor, StoppingCriteria

from PIL import Image

import os
from PIL import Image
from io import BytesIO

import http.server
import pickle
from functools import partial


DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"


CONVERSATION_PROMPT = Conversation(
    system="You are an assistant that is able to understand the visual content that the user provides. "
    "You answer questions about the visual content in a short and concise manner.",
    roles=("Human", "Assistant"),
    messages=(),
    offset=2,
    sep_style=SeparatorStyle.SINGLE,
    sep="###",
)


params_path = "../llava-weights/"
# load model
disable_torch_init()
tokenizer = AutoTokenizer.from_pretrained(params_path)
model = AutoModelForCausalLM.from_pretrained(
    params_path, torch_dtype=torch.float16
).cuda()
image_processor = CLIPImageProcessor.from_pretrained(
    model.config.mm_vision_tower, torch_dtype=torch.float16
)

mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
    )

vision_tower = model.model.vision_tower[0]
vision_tower.to(device="cuda", dtype=torch.float16)
vision_config = vision_tower.config
vision_config.im_patch_token = tokenizer.convert_tokens_to_ids(
    [DEFAULT_IMAGE_PATCH_TOKEN]
)[0]
vision_config.use_im_start_end = mm_use_im_start_end
if mm_use_im_start_end:
    (
        vision_config.im_start_token,
        vision_config.im_end_token,
    ) = tokenizer.convert_tokens_to_ids(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN]
    )
image_token_len = (vision_config.image_size // vision_config.patch_size) ** 2

if mm_use_im_start_end:
    image_tokens = (
        DEFAULT_IM_START_TOKEN
        + DEFAULT_IMAGE_PATCH_TOKEN * image_token_len
        + DEFAULT_IM_END_TOKEN
    )
else:
    image_tokens = DEFAULT_IMAGE_PATCH_TOKEN * image_token_len


  from .autonotebook import tqdm as notebook_tqdm
Some weights of the model checkpoint at openai/clip-vit-large-patch14 were not used when initializing CLIPVisionModel: ['text_model.encoder.layers.11.self_attn.q_proj.weight', 'text_model.encoder.layers.8.self_attn.k_proj.weight', 'text_model.encoder.layers.10.self_attn.v_proj.weight', 'text_model.encoder.layers.9.layer_norm1.weight', 'text_model.encoder.layers.6.layer_norm2.weight', 'text_model.encoder.layers.5.self_attn.q_proj.weight', 'text_model.encoder.layers.8.layer_norm2.bias', 'text_model.encoder.layers.4.self_attn.q_proj.weight', 'text_model.encoder.layers.9.mlp.fc1.bias', 'text_model.encoder.layers.10.layer_norm1.weight', 'text_model.encoder.layers.4.self_attn.k_proj.weight', 'text_model.encoder.layers.11.layer_norm1.bias', 'text_model.encoder.layers.1.self_attn.v_proj.bias', 'text_model.encoder.layers.2.self_attn.k_proj.bias', 'text_model.encoder.layers.5.self_attn.v_proj.weight', 'text_model.encoder.layers.3.mlp.fc2.weight',

In [5]:
# os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
PROMPT = "You are an assistant that is able to understand the visual content that the user provides. "
"You answer questions about the visual content in a short and concise manner.\n###Human: "

tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

with torch.inference_mode():
    images = [Image.open("monkey.jpg"), Image.open("monkey.jpg")]
    queries = [
        ["What", "Who is washing the dishes?"],
        ["What is the monkey doing?", "What color's the monkey?"],
    ]

    assert len(images) == len(queries)
    assert np.all([len(queries[0]) == len(q) for q in queries])

    queries = np.array(queries)  # (batch_size, num_queries_per_image)

    # preprocess images
    images = image_processor(images, return_tensors="pt")["pixel_values"]
    images = images.to("cuda", dtype=torch.float16)

    # first, get the activations for the image tokens
    initial_prompts = [PROMPT + image_tokens + "\n" for _ in range(len(images))]
    initial_input_ids = tokenizer(initial_prompts, return_tensors="pt").input_ids.cuda()
    initial_out = model(initial_input_ids, images=images, use_cache=True)
    initial_key_values = initial_out.past_key_values
    attention_mask = (initial_input_ids != tokenizer.pad_token_id).long()

    # broadcast the key values across the queries
    # becomes shape (batch_size * num_queries_per_image, ...)
    initial_key_values = [
        [
            x.unsqueeze(1)
            .expand(-1, queries.shape[1], -1, -1, -1)
            .reshape(-1, *x.shape[1:])
            for x in y
        ]
        for y in initial_key_values
    ]

    # broadcast the attention mask across the queries
    # becomes shape (batch_size * num_queries_per_image, ...)
    attention_mask = (
        attention_mask.unsqueeze(1)
        .expand(-1, queries.shape[1], -1)
        .reshape(-1, attention_mask.shape[1])
    )

    # flatten queries into one big batch
    queries_flat = queries.reshape(-1)  # (batch_size * num_queries_per_image)

    # prepare inputs for the queries
    prompts = [q + "###" for q in queries_flat]
    input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.cuda()
    print(input_ids)

    # stop upon seeing any of these tokens
    stop_tokens = torch.as_tensor(
        tokenizer.convert_tokens_to_ids(["▁###", "##", "#"]),
        dtype=torch.long,
        device="cuda",
    )

    # generation loop
    output_ids = []
    key_values = initial_key_values
    finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device="cuda")
    for i in range(50):
        attention_mask = torch.cat(
            [attention_mask, input_ids != tokenizer.pad_token_id], dim=-1
        )

        # create position_ids on the fly for batch generation
        # position_ids = attention_mask.long().cumsum(-1) - 1
        # position_ids.masked_fill_(attention_mask == 0, 1)
        # position_ids = position_ids[:, -input_ids.shape[1]]
        # print(position_ids)

        # print(attention_mask[0, -10:])
        out = model(
            input_ids=input_ids,
            use_cache=True,
            past_key_values=key_values,
            attention_mask=attention_mask,
            # position_ids=position_ids,
        )
        key_values = out.past_key_values
        next_tokens = torch.argmax(out.logits[:, -1], dim=-1)

        next_tokens = torch.where(finished, tokenizer.pad_token_id, next_tokens)

        finished = finished | (next_tokens.unsqueeze(-1) == stop_tokens).any(dim=-1)

        if finished.all():
            break

        output_ids.append(next_tokens)
        input_ids = next_tokens.unsqueeze(-1)

output_ids = torch.stack(output_ids, dim=-1)
outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

# clean outputs
outputs_clean = []
for output in outputs:
    for pattern in ["###", "##", "#"]:
        if pattern in output:
            output = output.split(pattern)[0]

    if "Assistant:" in output:
        output = output.split("Assistant:")[1]
    outputs_clean.append(output.strip())

# reshape outputs back to (batch_size, num_queries_per_image)
outputs_clean = np.array(outputs_clean).reshape(queries.shape)

print(outputs_clean.tolist())


tensor([[    2,     2,     2,     2,     2,     2,     2,     1,  1724,  2277,
         29937],
        [    1, 11644,   338,   471,  2790,   278,   270, 17006, 29973,  2277,
         29937],
        [    2,     1,  1724,   338,   278,  1601,  1989,  2599, 29973,  2277,
         29937],
        [    1,  1724,  2927, 29915, 29879,   278,  1601,  1989, 29973,  2277,
         29937]], device='cuda:0')
[['The image shows a monkey inside a kitchen, washing dishes in a sink. The monkey is surrounded by various kitchen items, including multiple cups and bowls. Some of the cups are placed near the sink,', 'A monkey is washing the dishes in the image.'], ['The monkey is washing dishes in a kitchen sink, specifically washing a blue bowl.', 'The monkey is brown.']]


In [37]:
input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.cpu()
print(input_ids[0])
print(input_ids[1])
print(tokenizer.decode(input_ids[0, :-1], skip_special_tokens=False))
print(tokenizer.decode(input_ids[1], skip_special_tokens=False))

tensor([    1,  1724,   338,   278,  1601,  1989,  2599, 29973,  2277, 29937,
            2])
tensor([    1,  1724,  2927, 29915, 29879,   278,  1601,  1989, 29973,  2277,
        29937])
 What is the monkey doing?###
 What color's the monkey?###


In [5]:
images = [Image.open("monkey.jpg")]
queries = [["What is the color of the monkey?", "What is the monkey doing?"]]

with torch.inference_mode():
    queries = np.array(queries)
    prompts = []
    for query_set in queries:
        message = image_tokens + "\n" + query_set[0]
        conv = CONVERSATION_PROMPT.copy()
        conv.append_message(conv.roles[0], message)
        prompts.append(conv.get_prompt())

    inputs = tokenizer(prompts, return_tensors="pt", padding=True)
    image_tensors = (
        image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
        .half()
        .cuda()
    )
    input_ids = inputs.input_ids.cuda()

    for i in range(1, queries.shape[1]):
        output_ids = model.generate(
            input_ids,
            images=image_tensors,
            do_sample=True,
            temperature=0.7,
            max_new_tokens=128,
            stopping_criteria=[stopping_criteria(tokenizer, input_ids.shape[0])],
        )

    output_ids = output_ids[:, input_ids.shape[1] :]
    # input_token_len = input_ids.shape[1]
    # n_diff_input_output = (
    #     (input_ids != output_ids[:, :input_token_len]).sum().item()
    # )
    # if n_diff_input_output > 0:
    #     print(
    #         f"[Warning] {n_diff_input_output} output_ids are not the same as the input_ids"
    #     )
    outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    outputs_clean = []
    for output in outputs:
        for pattern in ["###", "##", "#"]:
            if pattern in output:
                output = output.split(pattern)[0]
        if "Assistant:" in output:
            output = output.split("Assistant:")[1]
        outputs_clean.append(output.strip())

    # try:
    #     index = outputs.index(conv.sep)
    # except ValueError:
    #     outputs += conv.sep
    #     index = outputs.index(conv.sep)

    print(outputs_clean)

NameError: name 'stopping_criteria' is not defined