In [None]:
import os
import torch
from llava.conversation import conv_templates
from llava.mm_utils import (
    KeywordsStoppingCriteria,
    get_model_name_from_path,
    load_pretrained_model,
    process_images,
    tokenizer_image_token,
)
from llava.model.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX, key_info
from PIL import Image

In [None]:
def query(inputs: dict) -> dict:
    """Ask the vision-language model for image-grounded task instructions.

    Loads the model, renders a single-turn conversation containing the image
    placeholder and the instruction prompt, generates a reply and returns it.

    Args:
        inputs: Request dictionary with the following keys:
            "model_path": path to the pretrained model (``~`` is expanded).
            "image_path": path of the image to show the model.
            "task": description of the task to be carried out.
            "bot_commands": commands available to the robotic arm (only used
                by the follow-up prompt, which is not sent to the model yet).

    Returns:
        ``{"bot_inst": reply}`` where ``reply`` is the stripped text generated
        for the first prompt (the observation-linked step-by-step guide).

    Raises:
        KeyError: if any required key is missing from ``inputs``. The check
            runs before the model is loaded so bad requests fail fast.
    """
    required_keys = ("model_path", "image_path", "task", "bot_commands")
    missing = [key for key in required_keys if key not in inputs]
    if missing:
        raise KeyError(
            "query() inputs missing required key(s): " + ", ".join(missing)
        )

    # The model location is supplied by the caller; previously this line read
    # an unassigned local and always raised UnboundLocalError.
    model_path = os.path.expanduser(inputs["model_path"])
    key_info["model_path"] = model_path
    # Return value is unused; call kept in case this fork relies on a side
    # effect -- TODO confirm and drop if it is pure.
    get_model_name_from_path(model_path)
    # NOTE(review): the model is reloaded on every call; consider loading it
    # once in a separate cell if query() is invoked repeatedly.
    tokenizer, model, image_processor, _ = load_pretrained_model(model_path)

    conv = conv_templates["mm_default"].copy()

    image = Image.open(inputs["image_path"]).convert("RGB")
    image_tensor = process_images([image], image_processor, model.config)
    image_tensor = image_tensor.to(model.device, dtype=torch.bfloat16)

    prompt1 = f"""
    Observe the given image and its details.
    Provide a detailed step-by-step guide on how a human would complete the task of: {inputs["task"]}.
    Link each instruction to an observation in the image in this format: Observation - Instruction.
    """

    # The user turn must carry both the image placeholder and the instruction
    # text, otherwise the tokenized prompt contains no image token and the
    # image passed to generate() has nothing to attach to.
    user_message = DEFAULT_IMAGE_TOKEN + "\n" + prompt1
    # NOTE(review): stock LLaVA names this method `append_message`; the name
    # used here is kept as written -- confirm against this fork's Conversation.
    conv.append_messages(conv.roles[0], user_message)
    conv.append_messages(conv.roles[1], None)
    # Render the whole conversation (template, roles, separators) to text.
    # Assumes the Conversation object exposes get_prompt() as in stock LLaVA.
    full_prompt = conv.get_prompt()

    input_ids = (
        tokenizer_image_token(
            full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        )
        .unsqueeze(0)
        .to(model.device)
    )

    stop_str = conv.sep
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    input_token_len = input_ids.shape[1]
    reply = tokenizer.batch_decode(
        output_ids[:, input_token_len:], skip_special_tokens=True
    )[0].strip()
    # Guard against an empty separator: "".endswith("") is True and slicing
    # with [:-0] would wipe out the entire reply.
    if stop_str and reply.endswith(stop_str):
        reply = reply[: -len(stop_str)]
    output1 = reply.strip()

    # TODO: the two follow-up turns below are drafted but not yet sent to the
    # model; the returned "bot_inst" is therefore still the first-turn reply.
    prompt2 = f"""
    Imagine you are in control of a robotic arm with the following commands: {inputs["bot_commands"]}
    Given the human instructions you have generated, provide a guide on how the robot would complete the task.
    """

    prompt3 = """
    By referencing an observation in the image, ensure each instruction is accurate. Do not make assumptions.
    Check that each instruction is logical.
    """

    return {"bot_inst": output1}