In [1]:
import sys
sys.path.append('..')

import argparse
import torch

from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path

from PIL import Image

import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer

from utils.data import get_data

In [2]:
MODEL_NAME = "liuhaotian/llava-v1.6-34b"
temperature = 0.2
max_new_tokens = 512
num_models = 2
models = []

disable_torch_init()

model_name = get_model_name_from_path(MODEL_NAME)
for i in range(num_models):
    tokenizer, model, image_processor, context_len = load_pretrained_model(MODEL_NAME, model_base=None, model_name=model_name, load_8bit=False, load_4bit=True, device_map="auto")
    models.append({"tokenizer":tokenizer, "model":model, "image_processor":image_processor, "context_len":context_len})

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards:   0%|          | 0/15 [00:00<?, ?it/s]

In [8]:
if "llama-2" in model_name.lower():
    conv_mode = "llava_llama_2"
elif "mistral" in model_name.lower():
    conv_mode = "mistral_instruct"
elif "v1.6-34b" in model_name.lower():
    conv_mode = "chatml_direct"
elif "v1" in model_name.lower():
    conv_mode = "llava_v1"
elif "mpt" in model_name.lower():
    conv_mode = "mpt"
else:
    conv_mode = "llava_v0"

conv = []
roles = []

for i in range(num_models):
    conv.append(conv_templates[conv_mode].copy())
    if "mpt" in model_name.lower():
        roles.append(('user', 'assistant'))
    else:
        roles.append(conv[i].roles)

image, caption, img_path = get_data(5)
image_size = image.size
# Similar operation in model_worker.py
image_tensor = process_images([image], models[0]['image_processor'], models[0]['model'].config)
if type(image_tensor) is list:
    image_tensor = [image.to(models[0]['model'].device, dtype=torch.float16) for image in image_tensor]
else:
    image_tensor = image_tensor.to(models[0]['model'].device, dtype=torch.float16)



In [9]:
#print(f"{roles[1]}: ", end="")
num_rounds = 3
temp=""
for round in range(num_rounds+1):
    for i in range(num_models):
        if round == 0:
            inp = """{}: Given the text: {}. Does this text belong to the same context as the image or is the image being used out of context to spread misinformation?
                    Explain your answer. Do not describe the image, only answer the question.""".format(roles[i][0], caption)
            print("INPUT MESSAGE: ", inp)
        elif i == 1:
            inp = """ {}: This is what another AI agent said for the same image and text input when asked if the image and text were being used out of context: {}.
            Do you agree with the other agent? Refine your original answer using this new information. If you disagree with the other agent, clearly state your reasoning.
            Do not describe the image. 
            """.format(roles[i][0], temp)
            print("INPUT MESSAGE: ", inp)
        else:
            print("========== Debate Round: {} ==========".format(round))
            inp = """ {}: This is what another AI agent said for the same image and text input when asked if the image and text were being used out of context: {}.
            Do you agree with the other agent? Refine your original answer using this new information. If you disagree with the other agent, clearly state your reasoning.
            Do not describe the image. 
            """.format(roles[i][0], conv[(i+1)%num_models].messages[-1][-1])
            print("INPUT MESSAGE: ", inp)
        print("========================== Agent - {} =====================".format(i+1))
        if image is not None:
            # first message
            if models[i]['model'].config.mm_use_im_start_end:
                inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + inp
            else:
                inp = DEFAULT_IMAGE_TOKEN + '\n' + inp
            if round == 0 and i == 1:
                print("image is none!")
                image = None

        conv[i].append_message(conv[i].roles[0], inp)
        conv[i].append_message(conv[i].roles[1], None)
        prompt = conv[i].get_prompt()

        input_ids = tokenizer_image_token(prompt, models[i]['tokenizer'], IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(models[i]['model'].device)
        stop_str = conv[i].sep if conv[i].sep_style != SeparatorStyle.TWO else conv[i].sep2
        keywords = [stop_str]
        streamer = TextStreamer(models[i]['tokenizer'], skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = models[i]['model'].generate(
                input_ids,
                images=image_tensor,
                image_sizes=[image_size],
                do_sample=True if temperature > 0 else False,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                streamer=streamer,
                use_cache=True)

        outputs = models[i]['tokenizer'].decode(output_ids[0]).strip()
        conv[i].messages[-1][-1] = outputs
        if i == 0:
            temp = outputs

        #print("\n", {"prompt": prompt, "outputs": outputs}, "\n")

INPUT MESSAGE:  <|im_start|>user
: Given the text: The Brandenburg Gate stands illuminated during celebrations on the 25th anniversary of the fall of the Berlin Wall. Does this text belong to the same context as the image or is the image being used out of context to spread misinformation?
                    Explain your answer. Do not describe the image, only answer the question.
The text provided does not match the image. The image appears to depict a large crowd at a sporting event or concert, with a large screen displaying a bridge or stadium structure, and the crowd is holding up lights or cell phones, creating a sea of illuminated devices. The text about the Brandenburg Gate and the fall of the Berlin Wall is unrelated to the image and is being used out of context.
INPUT MESSAGE:  <|im_start|>user
: Given the text: The Brandenburg Gate stands illuminated during celebrations on the 25th anniversary of the fall of the Berlin Wall. Does this text belong to the same context as the im

KeyboardInterrupt: 