In [1]:
import csv
import os
import socket
import subprocess
import sys
import threading
import time

from gradio_client import Client

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Define models: Hugging Face repo ids of the LLaVA checkpoints that can be evaluated.
models = [
    "liuhaotian/llava-v1.5-7b",
    "liuhaotian/llava-v1.6-vicuna-7b",
    "liuhaotian/llava-v1.5-13b",
    "liuhaotian/llava-v1.6-vicuna-13b",
]
# Short model names = repo id without the owner prefix. Derived instead of maintained by hand so
# the two lists cannot drift out of sync. The short name is what gets sent to the web server's
# model selector and used in the results file name.
# NOTE(review): presumably this matches the name the worker registers under — confirm.
model_names = [repo_id.rsplit("/", 1)[-1] for repo_id in models]

# Select which model to use (index into both `models` and `model_names`)
local_model_index = 3
used_model = model_names[local_model_index]

In [3]:
# Launch a controller
def launch_controller(host="0.0.0.0", port=10000):
    """Run the LLaVA controller (``llava.serve.controller``) and block until it exits.

    Intended as a ``threading.Thread`` target. The defaults reproduce the original command line.

    Args:
        host: Interface to bind; "0.0.0.0" listens on all interfaces.
        port: Controller port. The model worker and web server are pointed at
            http://localhost:10000 by default, so change them together.

    Returns:
        The ``subprocess.CompletedProcess`` of the finished server (inspect ``returncode``).
    """
    # sys.executable rather than a bare "python" from PATH: the server then runs under the
    # kernel's own interpreter and sees the same installed llava package as this notebook.
    # The argv list avoids a shell. The child's output is not captured; it inherits the
    # kernel process's stdout/stderr.
    # check=False: a server exit is reported via the return value instead of an exception
    # raised inside the background thread.
    return subprocess.run(
        [sys.executable, "-m", "llava.serve.controller", "--host", host, "--port", str(port)],
        check=False,
    )

# Launch the model worker
def launch_model(model, port, controller_url="http://localhost:10000", load_4bit=True):
    """Run a LLaVA model worker (``llava.serve.model_worker``) and block until it exits.

    Intended as a ``threading.Thread`` target. With the defaults, the command line matches the
    original one.

    Args:
        model: Value for ``--model-path`` (e.g. a Hugging Face repo id from ``models``).
        port: Port the worker listens on; also advertised to the controller as
            ``http://localhost:{port}``.
        controller_url: Controller the worker connects to.
        load_4bit: Pass ``--load-4bit`` to the worker (the original configuration).

    Returns:
        The ``subprocess.CompletedProcess`` of the finished worker (inspect ``returncode``).
    """
    # Build argv as a list instead of an interpolated shell string. `model` is passed through
    # verbatim, with no shell quoting issues.
    cmd = [
        sys.executable, "-m", "llava.serve.model_worker",
        "--host", "0.0.0.0",
        "--controller", controller_url,
        "--port", str(port),
        "--worker", f"http://localhost:{port}",
        "--model-path", model,
    ]
    if load_4bit:
        cmd.append("--load-4bit")
    # sys.executable: use the kernel's interpreter rather than whatever "python" is on PATH.
    # check=False: a worker exit is reported via the return value, not raised in the thread.
    return subprocess.run(cmd, check=False)

# Launch a gradio web server
def launch_web_server(controller_url="http://localhost:10000", model_list_mode="reload"):
    """Run the LLaVA Gradio web server (``llava.serve.gradio_web_server``); block until it exits.

    Intended as a ``threading.Thread`` target. The defaults reproduce the original command line.

    Args:
        controller_url: Controller the web server queries.
        model_list_mode: Value for ``--model-list-mode`` ("reload" in the original).

    Returns:
        The ``subprocess.CompletedProcess`` of the finished server (inspect ``returncode``).
    """
    # sys.executable instead of "python" from PATH, so the server uses the kernel's interpreter.
    # check=False: a server exit is reported via the return value, not raised in the thread.
    return subprocess.run(
        [
            sys.executable, "-m", "llava.serve.gradio_web_server",
            "--controller", controller_url,
            "--model-list-mode", model_list_mode,
        ],
        check=False,
    )

# Each launcher blocks for the whole lifetime of its server, so each runs in its own thread.
# daemon=True: these threads never finish on their own, and live non-daemon threads keep the
# Python process from exiting (e.g. on kernel shutdown).
# NOTE(review): daemon threads do not stop the server subprocesses — those must be killed
# separately if they outlive the kernel.
# The names show up in error messages and thread listings.
controller_thread = threading.Thread(target=launch_controller, name="llava-controller", daemon=True)
web_server_thread = threading.Thread(target=launch_web_server, name="llava-web-server", daemon=True)
model_thread = threading.Thread(
    target=launch_model,
    args=(models[local_model_index], 40000),  # worker port 40000
    name="llava-model-worker",
    daemon=True,
)

In [None]:
# Start the servers and block until each accepts connections. The cells below can then run
# straight after this one (also under Restart & Run All), instead of waiting by hand.

# Upper bound for each wait. It is generous because a first run may also download model weights.
STARTUP_TIMEOUT_S = 60 * 60


def _wait_for_port(port, thread, timeout_s, host="localhost", poll_s=2.0):
    """Poll until a TCP connection to ``host:port`` succeeds.

    Args:
        port: TCP port the server is expected to listen on.
        thread: Launcher thread running that server. The launcher blocks while its server
            process runs, so a finished thread means the server has exited.
        timeout_s: Give up after this many seconds.
        host: Host to connect to.
        poll_s: Delay between attempts; also used as the connect timeout.

    Raises:
        RuntimeError: ``thread`` finished before the port became reachable.
        TimeoutError: the port was still unreachable after ``timeout_s`` seconds.
    """
    deadline = time.monotonic() + timeout_s
    while True:
        try:
            with socket.create_connection((host, port), timeout=poll_s):
                return
        except OSError:
            pass  # nothing listening yet (refused / timed out): try again below
        if not thread.is_alive():
            raise RuntimeError(
                f"{thread.name} stopped before port {port} became reachable; check its output"
            )
        if time.monotonic() >= deadline:
            raise TimeoutError(f"port {port} still unreachable after {timeout_s} s")
        time.sleep(poll_s)


# Launch the controller first: the web server and the model worker are both configured to talk
# to it at http://localhost:10000.
controller_thread.start()
_wait_for_port(10000, controller_thread, STARTUP_TIMEOUT_S)

# Launch the web server and the model worker
web_server_thread.start()
model_thread.start()

# Port 40000 must match the port passed to model_thread. This replaces the former manual step
# "wait until the model worker has successfully started before proceeding".
# NOTE(review): assumes the worker only starts listening once its model is loaded — confirm
# against llava.serve.model_worker.
_wait_for_port(40000, model_thread, STARTUP_TIMEOUT_S)
# 7860 is the address the scoring cell connects to with gradio_client.
_wait_for_port(7860, web_server_thread, STARTUP_TIMEOUT_S)

In [5]:
# Defining the prompt template: the 1-4 grading rubric given to the model. The caption under
# evaluation is appended directly after the trailing "Image caption: ".
# The pieces below are joined by implicit string concatenation, so the text must stay exactly
# as is (spaces at piece boundaries included) to keep model inputs unchanged.
prompt_template = (
    "Score the image caption on a scale from 1 to 4, with a 1 indicating that the caption does not describe the image at all, "
    "a 2 indicating the caption describes minor aspects of the image but does not describe the image, a 3 indicating that the caption almost "
    "describes the image with minor mistakes, and a 4 indicating that the caption describes the image. "
    "Your output shall only consist of your score. Image caption: "
)

In [6]:
# Create dictionary for all image captions: caption id (first tab-separated column) -> caption
# text (rest of the line). The scoring cell looks captions up by the second column of
# ExpertAnnotations.txt.
# NOTE(review): ids presumably look like "<image>.jpg#<n>" — confirm against the dataset.
caption_dict = {}
with open("Flickr8k_text/Flickr8k.token.txt", "r", encoding="utf-8") as file:
    for line_number, raw_line in enumerate(file, start=1):
        entry = raw_line.strip()
        if not entry:
            continue  # tolerate blank lines (e.g. a trailing empty line)
        # Split at the first tab only, so a caption that itself contains a tab stays whole.
        caption_id, separator, caption = entry.partition("\t")
        if not separator:
            raise ValueError(
                f"Flickr8k.token.txt line {line_number} has no tab separator: {entry!r}"
            )
        caption_dict[caption_id] = caption

In [None]:
# Accessing LLM through API
client = Client("http://localhost:7860/")
image_directory = "Flickr8k_Dataset/"

# One results file per model. Each row is: image file; caption; model score; rounded mean
# expert score (';'-separated).
file_name = used_model + "_results.csv"

# Index of the first ExpertAnnotations line to score; raise it to resume an interrupted run.
# The results file is opened in append mode, so rows from earlier runs are kept (re-scoring the
# same lines duplicates them).
starting_number = 0


def _score_caption(client, model_name, prompt, image_path):
    """Ask the LLaVA Gradio app to score one caption and return the model's answer.

    The app is called twice: ``/add_text`` submits the prompt together with the image, and
    ``/http_bot`` then runs generation with the selected model.
    """
    # The first call's return value is not needed; it only submits the input.
    client.predict(
        prompt,  # textual input
        image_path,  # filepath for visual input
        "Default",  # Literal['Crop', 'Resize', 'Pad', 'Default'] in 'Preprocess for non-square image' Radio component
        api_name="/add_text",
    )
    result = client.predict(
        model_name,  # model name in 'parameter_10' Dropdown component
        0,  # float (numeric value between 0.0 and 1.0) in 'Temperature' Slider component
        0,  # float (numeric value between 0.0 and 1.0) in 'Top P' Slider component
        512,  # float (numeric value between 0 and 1024) in 'Max output tokens' Slider component
        api_name="/http_bot",
    )
    # NOTE(review): assumes result[0] is a string whose last character is the score (the prompt
    # asks for the score only) — TODO confirm against the gradio_client output format.
    return result[0][-1]


print("Model:", used_model)
with open("Flickr8k_text/ExpertAnnotations.txt", "r", encoding="utf-8") as file:
    with open(file_name, "a", encoding="utf-8", newline="") as results_file:
        # csv.writer quotes any field containing ';', '"' or a newline. A free-text caption or
        # model reply therefore cannot split or merge rows; plain fields are written unquoted.
        writer = csv.writer(results_file, delimiter=";", lineterminator="\n")
        for counter, line in enumerate(file):
            if counter < starting_number:
                continue
            if not line.strip():
                continue  # tolerate blank lines
            # Columns: image file, caption id, then three expert scores.
            parts = line.strip().split("\t")
            image_path = os.path.join(image_directory, parts[0])

            # Create prompt
            caption_text = caption_dict[parts[1]]
            prompt = prompt_template + caption_text

            score = _score_caption(client, used_model, prompt, image_path)

            # Compute average expert score. The mean of three integers is a multiple of 1/3,
            # never exactly .5, so round()'s half-to-even rule never applies.
            avg_expert_score = round((int(parts[2]) + int(parts[3]) + int(parts[4])) / 3)

            # Append result to result file; flush so finished rows survive an interrupted run.
            writer.writerow([parts[0], caption_text, score, avg_expert_score])
            results_file.flush()

            if counter % 50 == 0:
                print(counter)

Loaded as API: http://localhost:7860/ ✔
Model: llava-v1.6-vicuna-13b
0
50
100
150
200
250
300
350
400
450
500
550
