In [1]:
import sys
sys.path.append('..')

import torch
import json
from PIL import Image
from transformers import BitsAndBytesConfig, LlavaForConditionalGeneration, AutoProcessor
from utils.external_retrieval import get_matching_urls, get_webpage_title, get_summary
from utils.data import get_data

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16
)

In [2]:
device1 = "cuda:0"
device2 = "cuda:1"

In [3]:
MODEL_NAME = "llava-hf/llava-1.5-13b-hf"

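# Load both agents with the 4-bit config defined above to reduce memory use,
# and pin one model per GPU (index 0 corresponds to device1, index 1 to device2).
model1 = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, quantization_config=quantization_config, device_map={"": 0})
model2 = LlavaForConditionalGeneration.from_pretrained(MODEL_NAME, quantization_config=quantization_config, device_map={"": 1})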
processor = AutoProcessor.from_pretrained(MODEL_NAME)



Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
def initial_prediction_prompt(history, caption):
    prompt = """
    CHAT HISTORY: {}
    USER: <image>
    TEXT: "{}"

    Your task is to determine whether the given text and image are from the same context (e.g., the same news article or event) or if the image is being used out of context. 
    To aid your analysis, you have the option to request an Internet search based on the image to gather more information about the context in which the image is used.

    Provide a detailed analysis explaining your reasoning behind your decision. Consider the content of the text and the objects, actions, or scenes depicted in the image. 
    Analyze whether they align and provide context for each other or if they appear to be unrelated. 
    If you need more information about the context of the image, state it separately after "INFO REQUIRED: ".

    Do not describe the image. Focus on performing the task and providing your answer as per the instructions.

    ASSISTANT:""".format(history, caption)
    return prompt
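
The prompt above interpolates `history` with `str.format`, so a non-empty chat history reaches the model as the raw `repr` of a list of dicts. A minimal sketch of one alternative follows, assuming each history entry is a `{'user': ..., 'assistant': ...}` dict as built later in this notebook. `format_history` and its `max_turns` parameter are hypothetical names, not part of the original code. The labels deliberately avoid the literal string "ASSISTANT:" because the notebook locates the model's reply by searching for the first occurrence of that marker.

```python
def format_history(history, max_turns=5):
    """Render recent turns as plain text (hypothetical helper, not in the notebook).

    `history` is assumed to be a list of {'user': ..., 'assistant': ...} dicts.
    `max_turns` should be a positive integer.
    """
    if not history:
        return "(none)"
    lines = []
    # Number the turns that are kept, oldest first.
    for idx, turn in enumerate(history[-max_turns:], start=1):
        lines.append("[turn {}] input: {}".format(idx, turn["user"].strip()))
        lines.append("[turn {}] answer: {}".format(idx, turn["assistant"].strip()))
    return "\n".join(lines)


example = [{"user": ' "The Waverley line was closed" ', "assistant": " Possibly related. "}]
print(format_history(example))
```

If this were adopted, the call would look like `initial_prediction_prompt(format_history(chat_history1), caption)`. Whether the model responds better to this layout than to the raw list is untested here.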

In [5]:
def debate_prompt(history, caption, agent_response):
    prompt = """
    CHAT HISTORY: {}
    USER: <image> 
    TEXT: "{}"

    This is the answer another AI agent generated for the same image and text pair:
    "{}"

    Your task is to critically analyze the other agent's response and provide a refined answer based on this new information. 
    To strengthen your analysis and argument, you have the option to request an Internet search based on the image to gather more information 
    about the context in which the image is used.

    However, instead of blindly agreeing or repeating yourself, focus on the following:

    1. Identify any potential inconsistencies, flaws, or counterarguments in the other agent's reasoning or analysis regarding whether the image and text are from the same context or not.
    2. Determine if there are any gaps or missing information that could lead to a more comprehensive understanding of the image-text relationship and their contextual alignment (or misalignment). 
    3. If you disagree with the other agent's assessment, respectfully point out the specific areas of disagreement and provide evidence or reasoning to support your stance on whether the image and text are from the same context or not.
    4. If you agree with the other agent's assessment, explain why their reasoning is valid and how it complements or strengthens your initial analysis of the image-text context.
    5. If you need more information to strengthen your argument or verify the other agent's argument, state it separately after "INFO REQUIRED: ".
    The goal is to have a constructive debate that challenges each other's perspectives, uncovers potential blind spots, and ultimately leads to a more robust and well-reasoned conclusion about whether the image and text are from the same context or if the image is being used out of context. Use the option to request an Internet search if you need additional information to strengthen your argument. Avoid simply repeating or agreeing without critical evaluation.

    Do not describe the image. Focus on performing the task and providing your answer as per the instructions.
    ASSISTANT:""".format(history, caption, agent_response)
    return prompt

In [6]:
def retrieval_prompt(history, caption, search_results):
    prompt = """
    CHAT HISTORY: {}
    USER: <image>
    TEXT: "{}"

    Search Results: This is a summary of the internet search for the context in which the image is used: {}

    Based on the additional context and information gathered from the Internet search, 
    please reevaluate your initial prediction and provide an updated analysis on whether the given image and text are from the same context 
    or if the image is being used out of context.

    Incorporate the relevant information from the search results into your analysis and explain how it either supports or contradicts your initial assessment. 
    If the search results provide new insights or perspectives, discuss how they impact your understanding of the image-text relationship and 
    their contextual alignment (or misalignment).

    The goal is to leverage the additional information from the Internet search to refine your analysis and provide a more comprehensive and 
    well-reasoned conclusion about whether the image and text are from the same context or if the image is being used out of context.

    Do not describe the image. Focus on performing the task and providing your answer as per the instructions.
    ASSISTANT:""".format(history, caption, search_results)
    return prompt
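
All three prompts end with "ASSISTANT:" and ask the model to flag a search request with "INFO REQUIRED: ". The cell that follows recovers the reply by slicing after the first "ASSISTANT:" and then tests the reply for a request. A minimal sketch of that parsing as two small helpers is given here. `extract_reply` and `requested_info` are hypothetical names introduced for illustration. Unlike `str.find`, `str.partition` reports whether the marker was found, so a missing marker does not silently produce a wrong slice.

```python
MARKER = "ASSISTANT:"
INFO_TAG = "INFO REQUIRED:"


def extract_reply(decoded, marker=MARKER):
    """Return the text after the first `marker`, or the whole string if it is absent."""
    _, found, reply = decoded.partition(marker)
    return reply.strip() if found else decoded.strip()


def requested_info(reply, tag=INFO_TAG):
    """Return the text following `tag`, or None when the reply requests no search."""
    _, found, request = reply.partition(tag)
    return request.strip() if found else None


decoded = 'USER: <image>\nTEXT: "caption"\nASSISTANT: Unclear.\nINFO REQUIRED: Details about the line.'
reply = extract_reply(decoded)
print(reply)
print(requested_info(reply))
```

Apply `requested_info` to the extracted reply rather than to the full decoded string, since the prompt text itself contains "INFO REQUIRED: " in its instructions.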

In [7]:
num_iters = 3
data_sample = 7
image, caption, _ = get_data(data_sample)
print("data loaded!")
# To debug CUDA errors synchronously, set os.environ["CUDA_LAUNCH_BLOCKING"] = "1" before CUDA is initialized.
chat_history1 = []
chat_history2 = []

print("running llm-1...")
prompt = initial_prediction_prompt(chat_history1, caption)
inputs = processor(text=prompt, images=image, return_tensors="pt")
inputs = inputs.to(device1)
result1 = model1.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
result1 = processor.batch_decode(result1, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
chat_history1.append({'user':prompt[prompt.find("TEXT:")+len("TEXT:"):prompt.find("ASSISTANT")], 'assistant':result1[result1.find("ASSISTANT:")+len("ASSISTANT:"):]})
print("AGENT-1: {}\n\n".format(result1[result1.find("ASSISTANT:")+len("ASSISTANT:"):]))

if "INFO REQUIRED" in result1[result1.find("ASSISTANT:")+len("ASSISTANT:"):]:
    urls = get_matching_urls(data_sample)
    info = get_summary(urls)
    post_retrieval_prompt = retrieval_prompt(chat_history1, caption, info)
    retrieval_inputs = processor(text=post_retrieval_prompt, images=image, return_tensors="pt")
    retrieval_inputs = retrieval_inputs.to(device1)
    result1 = model1.generate(**retrieval_inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
    result1 = processor.batch_decode(result1, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    chat_history1.pop(-1)
    chat_history1.append({'user':post_retrieval_prompt[post_retrieval_prompt.find("TEXT:")+len("TEXT:"):post_retrieval_prompt.find("ASSISTANT")], 'assistant':result1[result1.find("ASSISTANT:")+len("ASSISTANT:"):]})
    print("AGENT-1 after internet access: {}\n\n".format(result1[result1.find("ASSISTANT:")+len("ASSISTANT:"):]))


print("running llm-2...")
inputs = inputs.to(device2)
result2 = model2.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
result2 = processor.batch_decode(result2, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
chat_history2.append({'user':prompt[prompt.find("TEXT:")+len("TEXT:"):prompt.find("ASSISTANT")], 'assistant':result2[result2.find("ASSISTANT:")+len("ASSISTANT:"):]})
print("AGENT-2: {}\n\n".format(result2[result2.find("ASSISTANT:")+len("ASSISTANT:"):]))

if "INFO REQUIRED" in result2[result2.find("ASSISTANT:")+len("ASSISTANT:"):]:
    urls = get_matching_urls(data_sample)
    info = get_summary(urls)
    post_retrieval_prompt = retrieval_prompt(chat_history2, caption, info)
    inputs = processor(text=post_retrieval_prompt, images=image, return_tensors="pt")
    inputs = inputs.to(device=device2)
    result2 = model2.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
    result2 = processor.batch_decode(result2, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    chat_history2.pop(-1)
    chat_history2.append({'user':post_retrieval_prompt[post_retrieval_prompt.find("TEXT:")+len("TEXT:"):post_retrieval_prompt.find("ASSISTANT")], 'assistant':result2[result2.find("ASSISTANT:")+len("ASSISTANT:"):]})
    print("AGENT-2 after internet access: {}\n\n".format(result2[result2.find("ASSISTANT:")+len("ASSISTANT:"):]))

print("COMMENCING DEBATE NOW...")
temp = result1

for i in range(num_iters):

    print("=======================================================================================")
    print("\t\t\t\tDEBATE ROUND - ", i+1)
    print("=======================================================================================")

    prompt = debate_prompt(chat_history1, caption, result2[result2.find("ASSISTANT:")+len("ASSISTANT:"):])
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    inputs = inputs.to(device1)
    result1 = model1.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
    result1 = processor.batch_decode(result1, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    chat_history1.append({'user': prompt[prompt.find("TEXT:")+len("TEXT:"):prompt.find("ASSISTANT")], 'assistant':result1[result1.find("ASSISTANT:")+len("ASSISTANT:"):]})
    print("AGENT-1: {}\n\n ".format(result1[result1.find("ASSISTANT:")+len("ASSISTANT:"):]))

    prompt = debate_prompt(chat_history2, caption, temp[temp.find("ASSISTANT:")+len("ASSISTANT:"):])
    inputs = processor(text=prompt, images=image, return_tensors="pt")
    inputs = inputs.to(device2)
    result2 = model2.generate(**inputs, max_new_tokens=200, do_sample=True, temperature=0.7)
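    # Decode before slicing so the stored reply is text, not a token tensor.
    result2 = processor.batch_decode(result2, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    chat_history2.append({'user': prompt[prompt.find("TEXT:")+len("TEXT:"):prompt.find("ASSISTANT")], 'assistant':result2[result2.find("ASSISTANT:")+len("ASSISTANT:"):]})
    # Agent 2 responds to agent 1's previous-round answer, so update temp only after agent 2 has replied.
    temp = result1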

    print("AGENT-2: {}\n\n".format(result2[result2.find("ASSISTANT:")+len("ASSISTANT:"):]))
    
    # Keep only the 5 most recent turns: drop the oldest entry from each history.
    if len(chat_history1) > 5:
        chat_history1.pop(0)
        chat_history2.pop(0)
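
The history cap at the end of the loop can also be expressed with a bounded queue. A minimal sketch, assuming the same 5-turn limit: `collections.deque` with `maxlen` discards the oldest entry automatically on `append`, so no explicit length check is needed. The loop and the `MAX_TURNS` name below are illustrative only.

```python
from collections import deque

MAX_TURNS = 5  # assumed cap, matching the 5-turn limit used in the loop above

history = deque(maxlen=MAX_TURNS)
for round_idx in range(7):
    # Appending to a full deque drops the oldest turn from the left.
    history.append({"user": "prompt {}".format(round_idx),
                    "assistant": "reply {}".format(round_idx)})

recent_turns = list(history)  # plain list, oldest first, for prompt formatting
print(len(recent_turns), recent_turns[0]["user"], recent_turns[-1]["user"])
```

One caveat if this replaces the lists: `deque.pop()` takes no index argument, so the `chat_history1.pop(-1)` calls in the retrieval branches would need to become `chat_history1.pop()`.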

data loaded!
running llm-1...
AGENT-1:  Based on the image, the train is silver and black in color, with many black doors and a large black pipe protruding from its undercarriage. The train is either stopped or moving on railroad tracks.

The text provided states that the Waverley line was closed, which could potentially be related to the train depicted in the image. Given this information, it is possible that the train could be a part of the Waverley line, but without further context, it is not possible to definitively confirm the relationship between the train and the text.

INFO REQUIRED: Details about the Waverley line, such as its location, purpose, and history.


Searching the web now!




ValueError: Input length of input_ids is 10000, but `max_length` is set to 10000. This can lead to unexpected behavior. You should consider increasing `max_length` or, better yet, setting `max_new_tokens`.
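
The error reports that the input already fills the entire `max_length` budget (10000 of 10000 tokens), leaving no room to generate. The traceback does not show which call raised it. Because it appears right after "Searching the web now!", one plausible source is the summarization step behind `get_summary`, but that is an assumption. A minimal sketch of one workaround follows, assuming the retrieved page text can be shortened before it is summarized. `truncate_words` is a hypothetical helper, and a word count is only a rough proxy for a token count.

```python
def truncate_words(text, max_words=512):
    """Keep at most `max_words` whitespace-separated words (hypothetical helper)."""
    words = text.split()
    if len(words) <= max_words:
        return text
    return " ".join(words[:max_words])


page_text = "word " * 2000
shortened = truncate_words(page_text, max_words=512)
print(len(shortened.split()))
```

As the message itself suggests, the other option is to set `max_new_tokens` instead of `max_length` in whichever `generate` call raises the error, so that the generation budget no longer depends on the prompt length.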