In [None]:
!git clone https://github.com/dcxjn/prompting.git /content/prompting

fatal: destination path '/content/prompting' already exists and is not an empty directory.


In [None]:
import os

# Work from the cloned repository root so the relative paths used further down
# (e.g. 'images/...') resolve against the checkout.
repo_root = '/content/prompting'
os.chdir(repo_root)

In [None]:
import sys

# Make the repository's `src` package importable. The membership check keeps the
# cell idempotent: re-running it no longer stacks duplicate entries on sys.path.
if '/content/prompting' not in sys.path:
    sys.path.append('/content/prompting')

In [None]:
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import AutoTokenizer, BitsAndBytesConfig

from src.llava_prompting.llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from src.llava_prompting.llava.conversation import conv_templates, SeparatorStyle
from src.llava_prompting.llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria
from src.llava_prompting.llava.model import LlavaLlamaForCausalLM
from src.llava_prompting.llava.utils import disable_torch_init
from src.utils.image_util import load_image, resize_image

In [None]:
# From https://towardsdatascience.com/create-your-vision-chat-assistant-with-llava-610b02c3283e

class LLaVAChatBot:
    """Multi-turn chat about a single image, backed by a LLaVA checkpoint.

    Constructing an instance loads the language model, tokenizer and vision
    tower. `start_new_chat` then binds an image and asks the first question;
    `continue_chat` appends follow-up questions to the same conversation.

    Attributes:
        model: The loaded `LlavaLlamaForCausalLM`.
        tokenizer: Tokenizer loaded from the same checkpoint.
        image_processor: Pre-processor taken from the model's vision tower.
        conv: Current conversation template instance (None until a chat starts).
        conv_img: The current image, converted to RGB.
        img_tensor: Pre-processed image tensor (half precision, on CUDA).
        roles: Role names taken from the conversation template.
        stop_key: Separator string used to stop generation.
    """

    def __init__(self,
                 model_path: str = 'liuhaotian/llava-v1.5-7b',
                 device_map: str = 'auto',
                 load_in_8_bit: bool = True,
                 **quant_kwargs) -> None:
        """Initialise empty chat state and load the models.

        Args:
            model_path: Checkpoint identifier/path given to `from_pretrained`.
            device_map: Device placement passed to `from_pretrained`.
            load_in_8_bit: Request 8-bit quantised loading unless
                `quant_kwargs` says otherwise.
            **quant_kwargs: Extra keyword arguments for `BitsAndBytesConfig`.
        """
        self.model = None
        self.tokenizer = None
        self.image_processor = None
        self.conv = None
        self.conv_img = None
        self.img_tensor = None
        self.roles = None
        self.stop_key = None
        self.load_models(model_path,
                         device_map=device_map,
                         load_in_8_bit=load_in_8_bit,
                         **quant_kwargs)

    def load_models(self, model_path: str,
                    device_map: str,
                    load_in_8_bit: bool,
                    **quant_kwargs) -> None:
        """Load the model, processor and tokenizer.

        The quantisation request is expressed through a single
        `BitsAndBytesConfig`. Passing both `load_in_8bit=` and
        `quantization_config=` to `from_pretrained` is rejected by recent
        transformers releases, so the boolean flag is folded into the config.

        Args:
            model_path: Checkpoint identifier/path.
            device_map: Device placement passed to `from_pretrained`.
            load_in_8_bit: Default for the config's `load_in_8bit` field; an
                explicit `load_in_8bit`/`load_in_4bit` in `quant_kwargs` wins.
            **quant_kwargs: Extra keyword arguments for `BitsAndBytesConfig`.
        """
        # Called before any model is built; invoking it only after loading (as
        # before) could not affect construction of the models.
        # NOTE(review): assumes the helper only patches default torch init.
        disable_torch_init()

        quant_options = dict(quant_kwargs)
        if not quant_options.get('load_in_4bit', False):
            quant_options.setdefault('load_in_8bit', load_in_8_bit)
        wants_quantization = bool(quant_options.get('load_in_8bit')
                                  or quant_options.get('load_in_4bit'))
        # Only hand over a config when quantisation is actually requested.
        quant_cfg = (BitsAndBytesConfig(**quant_options)
                     if wants_quantization else None)

        self.model = LlavaLlamaForCausalLM.from_pretrained(
            model_path,
            low_cpu_mem_usage=True,
            device_map=device_map,
            quantization_config=quant_cfg)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path,
                                                       use_fast=False)
        vision_tower = self.model.get_vision_tower()
        vision_tower.load_model()
        vision_tower.to(device='cuda')
        self.image_processor = vision_tower.image_processor

    def setup_image(self, img_path: str) -> None:
        """Load the image from a URL or local path and pre-process it.

        Args:
            img_path: `http://` / `https://` URL, or a filesystem path.

        Raises:
            requests.HTTPError: If the URL responds with an error status.
        """
        # Match the full scheme so a local path that merely begins with
        # "http" is not mistaken for a URL.
        if img_path.startswith(('http://', 'https://')):
            response = requests.get(img_path)
            # Fail on 4xx/5xx here rather than while decoding an error page.
            response.raise_for_status()
            self.conv_img = Image.open(BytesIO(response.content)).convert('RGB')
        else:
            self.conv_img = Image.open(img_path).convert('RGB')
        self.img_tensor = self.image_processor.preprocess(self.conv_img,
                                                          return_tensors='pt'
                                                          )['pixel_values'].half().cuda()

    def generate_answer(self, **kwargs) -> str:
        """Generate an answer for the current conversation.

        The decoded continuation is written into the last conversation
        message (the pending assistant turn).

        Args:
            **kwargs: Forwarded to `self.model.generate`.

        Returns:
            The decoded answer with everything from the last '</s>' removed.
        """
        raw_prompt = self.conv.get_prompt()
        input_ids = tokenizer_image_token(raw_prompt,
                                          self.tokenizer,
                                          IMAGE_TOKEN_INDEX,
                                          return_tensors='pt').unsqueeze(0).cuda()
        stopping = KeywordsStoppingCriteria([self.stop_key],
                                            self.tokenizer,
                                            input_ids)
        with torch.inference_mode():
            output_ids = self.model.generate(input_ids,
                                             images=self.img_tensor,
                                             stopping_criteria=[stopping],
                                             **kwargs)
        # Decode only the tokens that follow the prompt.
        outputs = self.tokenizer.decode(
            output_ids[0, input_ids.shape[1]:]
        ).strip()
        self.conv.messages[-1][-1] = outputs

        return outputs.rsplit('</s>', 1)[0]

    def get_conv_text(self) -> str:
        """Return the full conversation rendered as a prompt string."""
        return self.conv.get_prompt()

    def start_new_chat(self,
                       img_path: str,
                       prompt: str,
                       do_sample: bool = True,
                       temperature: float = 0.2,
                       max_new_tokens: int = 1024,
                       use_cache: bool = True,
                       **kwargs) -> str:
        """Start a new chat about a new image.

        Resets the conversation from the "v1" template, loads the image and
        asks `prompt` as the first user turn.

        Args:
            img_path: URL or local path of the image.
            prompt: First user message.
            do_sample, temperature, max_new_tokens, use_cache, **kwargs:
                Generation options forwarded to `generate_answer`.

        Returns:
            The model's answer to the first message.
        """
        conv_mode = "v1"
        self.setup_image(img_path)
        self.conv = conv_templates[conv_mode].copy()
        self.roles = self.conv.roles
        # NOTE(review): upstream LLaVA appears to add the start/end markers
        # only when model.config.mm_use_im_start_end is set — confirm.
        first_input = (DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN +
                       DEFAULT_IM_END_TOKEN + '\n' + prompt)
        self.conv.append_message(self.roles[0], first_input)
        # Placeholder assistant turn; generate_answer fills it in.
        self.conv.append_message(self.roles[1], None)
        if self.conv.sep_style == SeparatorStyle.TWO:
            self.stop_key = self.conv.sep2
        else:
            self.stop_key = self.conv.sep
        answer = self.generate_answer(do_sample=do_sample,
                                      temperature=temperature,
                                      max_new_tokens=max_new_tokens,
                                      use_cache=use_cache,
                                      **kwargs)
        return answer

    def continue_chat(self,
                      prompt: str,
                      do_sample: bool = True,
                      temperature: float = 0.2,
                      max_new_tokens: int = 1024,
                      use_cache: bool = True,
                      **kwargs) -> str:
        """Continue the existing chat with another user message.

        Args:
            prompt: Follow-up user message.
            do_sample, temperature, max_new_tokens, use_cache, **kwargs:
                Generation options forwarded to `generate_answer`.

        Returns:
            The model's answer to the follow-up message.

        Raises:
            RuntimeError: If no conversation has been started yet.
        """
        if self.conv is None:
            # Trailing space keeps the two literals from fusing into
            # "newconversation".
            raise RuntimeError("No existing conversation found. Start a new "
                               "conversation using the `start_new_chat` method.")
        self.conv.append_message(self.roles[0], prompt)
        self.conv.append_message(self.roles[1], None)
        answer = self.generate_answer(do_sample=do_sample,
                                      temperature=temperature,
                                      max_new_tokens=max_new_tokens,
                                      use_cache=use_cache,
                                      **kwargs)
        return answer

In [None]:
def query(inputs: dict, chatbot: "LLaVAChatBot | None" = None) -> str:
    """Run the three-stage prompting pipeline on one image and task.

    Stage 1 asks for human step-by-step instructions tied to observations in
    the image, stage 2 asks for a robot guide using the available commands,
    and stage 3 asks the model to check the instructions. The first two
    answers are printed; the third is returned.

    Args:
        inputs: Mapping with the keys "image_path", "task" and
            "bot_commands".
        chatbot: Optional already-loaded chatbot to reuse. When omitted, a new
            `LLaVAChatBot` is created, which loads the model again on every
            call.

    Returns:
        The model's answer to the third (verification) prompt.

    Raises:
        KeyError: If a required key is missing from `inputs`.
    """
    # Read all inputs first so a malformed dict fails before the model loads.
    image_path = inputs["image_path"]
    task = inputs["task"]
    bot_commands = inputs["bot_commands"]

    if chatbot is None:
        # NOTE(review): the bnb_8bit_* names do not look like documented
        # BitsAndBytesConfig fields (those are bnb_4bit_*) — confirm they
        # have any effect.
        chatbot = LLaVAChatBot(load_in_8bit=True,
                               bnb_8bit_compute_dtype=torch.float16,
                               bnb_8bit_use_double_quant=True,
                               bnb_8bit_quant_type='nf8')

    prompt1 = f"""
    Observe the given image and its details.
    Provide a detailed step-by-step guide on how a human would complete the task of: {task}.
    Link each instruction to an observation in the image in this format: "Observation - Instruction".
    """

    prompt2 = f"""
    Imagine you are in control of a robotic arm with the following commands: {bot_commands}
    Given the human instructions you have generated, provide a guide on how the robot would complete the task.
    """

    # No placeholders here, so a plain string literal is sufficient.
    prompt3 = """
    By referencing an observation in the image, ensure each instruction is accurate. Do not make assumptions.
    Check that each instruction is logical.
    """

    output1 = chatbot.start_new_chat(img_path=image_path,
                                     prompt=prompt1)

    print("\n=== OUTPUT 1 ===\n")
    print(output1)

    output2 = chatbot.continue_chat(prompt2)

    print("\n=== OUTPUT 2 ===\n")
    print(output2)

    output3 = chatbot.continue_chat(prompt3)

    return output3

In [None]:
# Command set of the robotic arm, rendered as a numbered, indented list that is
# interpolated into the second prompt built by `query`.
bot_commands = "\n" + "".join(
    f"    {number}. {command}\n"
    for number, command in enumerate(
        (
            "move_to(x, y)",
            "grab(object)",
            "release(object)",
            "push(object)",
            "pull(object)",
            "rotate(angle)",
        ),
        start=1,
    )
)

In [None]:
# Image to analyse, given relative to the current working directory.
# Other sample paths that were tried in this notebook (swap in as needed):
#   images/fridge_lefthandle.jpg
#   images/housedoor_knob_push.jpg
#   images/browndoor_knob_pull.jpg
#   images/labdoor_straighthandle_pull.jpg
#   images/whitetable.jpg
# To prompt for a path instead:
#   image_path = input("Enter the path of the image: ")
image_path = "images/bluedoor_knob_push.jpg"

In [None]:
# NOTE(review): the same path is passed as both arguments to resize_image, so
# the file is presumably rewritten in place — confirm against src.utils.image_util.
resize_image(image_path, image_path)

# load_image takes a dict carrying the path and returns a mapping; keep its "image" entry.
_loaded = load_image({"image_path": image_path})
image = _loaded["image"]

In [None]:
# Define the task to be performed. Surrounding whitespace is stripped because the
# text is embedded verbatim in the first prompt, and an empty answer is rejected
# here rather than after the model has been loaded.
task = input("Enter the task to be performed: ").strip()
if not task:
    raise ValueError("No task entered; provide a task description such as 'open the door'.")

Enter the task to be performed: open the door


In [None]:
# Bundle the three values `query` reads from its input mapping, then run the
# pipeline; `result` holds the answer to the final (verification) prompt.
query_inputs = dict(
    image_path=image_path,
    task=task,
    bot_commands=bot_commands,
)
result = query(query_inputs)

In [None]:
# Show the final answer beneath a separator line.
print("\n==========\n", result, sep="\n")