In [1]:
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from PIL import Image

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Checkpoint under evaluation. Other Bunny checkpoints that can be swapped in:
#   'BAAI/Bunny-Llama-3-8B-V', 'BAAI/Bunny-v1_1-4B', 'BAAI/Bunny-v1_0-4B',
#   'BAAI/Bunny-v1_0-3B', 'BAAI/Bunny-v1_0-3B-zh', 'BAAI/Bunny-v1_0-2B-zh'
model_name = "BAAI/Bunny-v1_1-Llama-3-8B-V"

# Number of leading token ids stripped from the text chunk that follows the
# image placeholder when the prompt is assembled. Pick the value per checkpoint:
#   1 -> Bunny-v1_1-Llama-3-8B-V, Bunny-Llama-3-8B-V, Bunny-v1_1-4B,
#        Bunny-v1_0-4B, Bunny-v1_0-3B-zh
#   0 -> Bunny-v1_0-3B, Bunny-v1_0-2B-zh
offset_bos = 1

# Load the weights and the matching tokenizer.
# NOTE(review): trust_remote_code=True executes Python code shipped with the
# checkpoint repository; only use it with checkpoints you trust.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=torch.float16,  # switch to torch.float32 when running on CPU
)
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.07it/s]
Some weights of the model checkpoint at BAAI/Bunny-v1_1-Llama-3-8B-V were not used when initializing BunnyLlamaForCausalLM: ['model.vision_tower.vision_tower.vision_model.encoder.layers.26.layer_norm1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.26.layer_norm1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.26.layer_norm2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.26.layer_norm2.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.26.mlp.fc1.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.26.mlp.fc1.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.26.mlp.fc2.bias', 'model.vision_tower.vision_tower.vision_model.encoder.layers.26.mlp.fc2.weight', 'model.vision_tower.vision_tower.vision_model.encoder.layers.26.self_attn.k_proj.bias', 'model.vision_tower.vision_tower.vision_model.encoder.laye

In [3]:
def get_concat_h(im1, im2):
    """Return a new RGB image with ``im1`` and ``im2`` placed side by side.

    ``im1`` is pasted at the top-left corner and ``im2`` directly to its right,
    both aligned to the top edge. The canvas is as tall as the taller of the
    two inputs, so neither image is cropped when their heights differ; the
    area below the shorter image keeps the default fill of ``Image.new``
    (black for an RGB canvas).

    Args:
        im1: Left image; must expose ``width`` and ``height``.
        im2: Right image; must expose ``width`` and ``height``.

    Returns:
        A new image of size
        ``(im1.width + im2.width, max(im1.height, im2.height))``.
    """
    # Use the larger height: sizing the canvas from im1 alone would silently
    # cut off the bottom of a taller im2.
    canvas_size = (im1.width + im2.width, max(im1.height, im2.height))
    dst = Image.new('RGB', canvas_size)
    dst.paste(im1, (0, 0))
    dst.paste(im2, (im1.width, 0))
    return dst

def get_concat_v(im1, im2):
    """Return a new RGB image with ``im1`` stacked on top of ``im2``.

    ``im1`` is pasted at the top-left corner and ``im2`` directly below it,
    both aligned to the left edge. The canvas is as wide as the wider of the
    two inputs, so neither image is cropped when their widths differ; the
    area to the right of the narrower image keeps the default fill of
    ``Image.new`` (black for an RGB canvas).

    Args:
        im1: Top image; must expose ``width`` and ``height``.
        im2: Bottom image; must expose ``width`` and ``height``.

    Returns:
        A new image of size
        ``(max(im1.width, im2.width), im1.height + im2.height)``.
    """
    # Use the larger width: sizing the canvas from im1 alone would silently
    # cut off the right side of a wider im2.
    canvas_size = (max(im1.width, im2.width), im1.height + im2.height)
    dst = Image.new('RGB', canvas_size)
    dst.paste(im1, (0, 0))
    dst.paste(im2, (0, im1.height))
    return dst

In [18]:
import glob
import json
from pathlib import Path

# Evaluation data layout: one JSON file per example in DATA_DIR; the two images
# of example <idx> are IMG_DIR/<idx>_a.png and IMG_DIR/<idx>_b.png.
DATA_DIR = Path("data")
IMG_DIR = DATA_DIR / "imgs"

# Tag marking the image position in the prompt text, and the id spliced into
# the token sequence at that position.
# NOTE(review): -200 is presumably the sentinel the Bunny remote code replaces
# with image features -- confirm against the model implementation.
IMAGE_TAG = "<image>"
IMAGE_TOKEN_INDEX = -200

# Keys of each example JSON whose captions are scored.
CAPTION_ORIGINS = ["A", "B", "A_B", "B_A"]

SYSTEM_PROMPT = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)


def classify_response(output_text):
    """Map a free-text model answer to "incorrect", "correct" or "unknown".

    Matching is case-insensitive and substring based. "incorrect" is tested
    first because it contains "correct" as a substring.
    """
    lowered = output_text.lower()
    if "incorrect" in lowered:
        return "incorrect"
    if "correct" in lowered:
        return "correct"
    return "unknown"


# glob returns paths in arbitrary order; sort so that the printed log and the
# row order of `results` are reproducible across runs and machines.
example_paths = sorted(glob.glob(str(DATA_DIR / "*.json")))

# One (example_idx, caption_origin, label, raw model output) tuple per caption.
results = []
for example_path in example_paths:
    with open(example_path, encoding="utf-8") as f:
        example = json.load(f)
    # Path.stem is separator-agnostic, unlike splitting the path on "/".
    example_idx = Path(example_path).stem

    # Both images are shown to the model side by side as a single picture.
    # The context manager closes the files once the pixels have been copied.
    a_path = IMG_DIR / f"{example_idx}_a.png"
    b_path = IMG_DIR / f"{example_idx}_b.png"
    with Image.open(a_path) as a_img, Image.open(b_path) as b_img:
        image = get_concat_h(a_img, b_img)

    # Follow the device the model was actually placed on (device_map='auto')
    # instead of assuming CUDA.
    image_tensor = model.process_images([image], model.config).to(
        dtype=model.dtype, device=model.device
    )

    for caption_origin in CAPTION_ORIGINS:
        caption = example[caption_origin]
        print(caption_origin, caption)

        # text prompt
        prompt = f"Description: {caption}\n Does the description match any of the given images? Reply with correct or incorrect."
        text = f"{SYSTEM_PROMPT} USER: {IMAGE_TAG}\n{prompt} ASSISTANT:"

        # Tokenize the text around the image tag separately, then join the
        # pieces with the image placeholder id. `offset_bos` leading ids of
        # the second chunk are dropped.
        text_chunks = [tokenizer(chunk).input_ids for chunk in text.split(IMAGE_TAG)]
        input_ids = (
            torch.tensor(
                text_chunks[0] + [IMAGE_TOKEN_INDEX] + text_chunks[1][offset_bos:],
                dtype=torch.long,
            )
            .unsqueeze(0)
            .to(model.device)
        )

        # generate
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            max_new_tokens=32,
            use_cache=True,
            repetition_penalty=1.0,  # increase this to avoid chattering
        )[0]

        # decode; keep only what follows the last "ASSISTANT: " marker, if any
        output_text = tokenizer.decode(output_ids, skip_special_tokens=True).split(
            "ASSISTANT: "
        )[-1]
        results.append(
            (example_idx, caption_origin, classify_response(output_text), output_text)
        )

A Two adult and one baby elephant walking in the forest.
B a man on skis stands on a snowy hill 
A_B Two adult and one baby elephant are standing on a snowy hill
B_A Two adult and one baby elephant walking on a snowy hill
A A tree on the sidewalk of a road in a city area.
B a small child holding a soccer ball in a room
A_B A tree on a small child holding a soccer ball
B_A a small child holding a soccer ball on the sidewalk of a road in a city
A A woman pushing a cart filled with lots of luggage.
B This is a living room with a gray couch and yellow chair.
A_B A woman pushing a cart filled with a gray couch and yellow chair.
B_A A gray couch pushing a cart filled with lots of luggage.
A A man wears a wrap towel in a hospital room.
B A pizza is loaded with broccoli and chicken.
A_B A pizza is loaded with chicken in a hospital room.
B_A A pizza is loaded with a wrap towel in a hospital room.
A We are looking at a closeup of a tie.
B a tower with a clock on top with a sky background 
A_B We

In [24]:
import pandas as pd

# One row per (example, caption variant) collected by the evaluation loop.
RESULT_COLUMNS = ["example_idx", "caption_origin", "result", "output_text"]
df = pd.DataFrame(results, columns=RESULT_COLUMNS)

# Label distribution for the captions stored under the "A" and "B" keys.
is_single_caption = df["caption_origin"].isin(["A", "B"])
df.loc[is_single_caption, "result"].value_counts()

result
correct      11
incorrect     9
Name: count, dtype: int64

In [25]:
# Label distribution for the captions stored under the "A_B" and "B_A" keys.
df.loc[df["caption_origin"].isin(["A_B", "B_A"]), "result"].value_counts()

result
incorrect    14
correct       6
Name: count, dtype: int64