In [1]:
import torch
from LLaVA.llava.model.builder import load_pretrained_model
from LLaVA.llava.mm_utils import get_model_name_from_path

from LLaVA.llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from LLaVA.llava.conversation import conv_templates, SeparatorStyle
from LLaVA.llava.model.builder import load_pretrained_model
from LLaVA.llava.utils import disable_torch_init
from LLaVA.llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

# Checkpoint directory holding the LoRA weights; the base weights come from `model_base`.
model_path = "checkpoints/llava-lora-1.5"

# Derive the model name once and reuse it for both the log line and the loader.
model_name = get_model_name_from_path(model_path)
print(model_name)

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base="liuhaotian/llava-v1.5-7b",
    model_name=model_name,
)



[2024-04-29 04:26:33,641] [INFO] [real_accelerator.py:191:get_accelerator] Setting ds_accelerator to cuda (auto detect)
llava-lora-1.5
Loading LLaVA from base model...


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


Loading additional LLaVA weights...
Loading LoRA weights...
Merging LoRA weights...
Model is loaded...


In [2]:
import re
query = "Give me the sunglasses"
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN

# Choose the image marker once: the start/end-wrapped form when the model config
# asks for it, the bare image token otherwise.
image_token = image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN

if IMAGE_PLACEHOLDER in query:
    # The query says where the image goes: substitute the placeholder in place.
    qs = re.sub(IMAGE_PLACEHOLDER, image_token, query)
else:
    # No placeholder: put the image marker on its own line before the text.
    qs = image_token + "\n" + query

# NOTE(review): the base checkpoint loaded above is llava-v1.5; confirm that
# "llava_llama_2" is the conversation template the LoRA weights were trained with.
conv_mode = "llava_llama_2"

# Build a single-turn conversation: the user message, then an empty assistant
# slot so the rendered prompt ends where the model should start answering.
conv = conv_templates[conv_mode].copy()
for role_index, message in ((0, qs), (1, None)):
    conv.append_message(conv.roles[role_index], message)
prompt = conv.get_prompt()

In [3]:

# Tokenise the prompt (IMAGE_TOKEN_INDEX is passed as the image token index),
# then add a leading batch dimension and move the tensor to the GPU.
prompt_token_ids = tokenizer_image_token(
    prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
)
input_ids = prompt_token_ids.unsqueeze(0).cuda()

# The stop string is `sep2` for two-separator templates and `sep` for every other style.
if conv.sep_style == SeparatorStyle.TWO:
    stop_str = conv.sep2
else:
    stop_str = conv.sep
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

In [None]:
import requests
from PIL import Image
from io import BytesIO


def load_image(image_file, timeout=30):
    """Load an image from a local path or an HTTP(S) URL and convert it to RGB.

    Args:
        image_file: Filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        timeout: Seconds to wait for the remote server. Only used for URLs.

    Returns:
        The result of ``Image.open(...).convert("RGB")``.

    Raises:
        requests.HTTPError: If the server answers with an error status, so an
            error page is never handed to the image decoder.
    """
    # Match the full scheme: a bare "http" prefix would also catch local paths
    # such as "httpdocs/cat.jpg" and send them to the network.
    if image_file.startswith(("http://", "https://")):
        # Without a timeout a stalled server would hang the notebook cell forever.
        response = requests.get(image_file, timeout=timeout)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_file
    return Image.open(source).convert("RGB")

def load_images(image_files):
    """Return a list with ``load_image`` applied to each entry, in input order."""
    return [load_image(path) for path in image_files]

# Single sample image used for this run.
image_paths = [
    "grasp-anything++/seen/image/0e73f01e5d4b1fc064f6ab381209891c376fa26bdc2dbf7578db9b611e6c6337.jpg"
]
images = load_images(image_paths)

# Preprocess with the model's image processor, then move the result to the
# model's device as float16.
processed_images = process_images(images, image_processor, model.config)
images_tensor = processed_images.to(model.device, dtype=torch.float16)

In [None]:
# Sampling settings gathered in one place so they are easy to tweak between runs.
generation_kwargs = dict(
    images=images_tensor,
    do_sample=True,
    temperature=0.8,
    top_p=0.95,
    num_beams=1,
    max_new_tokens=40,
    use_cache=True,
    stopping_criteria=[stopping_criteria],
)

# No autograd bookkeeping is needed for generation.
with torch.inference_mode():
    output_ids = model.generate(input_ids, **generation_kwargs)

In [None]:
# Prompt length in tokens (kept for inspection; not used by the decode below).
input_token_len = input_ids.shape[1]

# Decode the whole batch, dropping special tokens, and keep the first (only) sample.
decoded_batch = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
outputs = decoded_batch[0]
outputs

In [None]:
import torch.nn.functional as F

def response_to_batch(inputs, model, eos_token_id, max_length=100):
    """Sample a continuation for every sequence in a batch, one token per forward pass.

    Each step runs ``model(**inputs)`` on the sequences built so far, takes the
    logits at the last position, samples one token per row from their softmax,
    and appends it. Sampling stops once every row has produced ``eos_token_id``
    or after ``max_length`` steps.

    Args:
        inputs: Mapping with at least ``"input_ids"`` and ``"attention_mask"``
            (2-D, batch first). Any other entries are forwarded to the model
            unchanged. The mapping itself is not modified.
        model: Callable whose result exposes ``.logits`` indexed as
            ``[batch, position, vocab]``.
        eos_token_id: Token id that marks a row as finished.
        max_length: Maximum number of tokens to sample per row.

    Returns:
        Tensor of shape ``(batch, prompt_len + steps)`` holding the prompt
        followed by the sampled tokens for every row. Rows that finish early
        are filled with ``eos_token_id`` for the remaining steps.
    """
    # Shallow copy: the growing tensors are rebound on the copy, not on the caller's dict.
    model_inputs = dict(inputs)
    input_ids = model_inputs["input_ids"]
    attention_mask = model_inputs["attention_mask"]
    finished = torch.zeros(input_ids.shape[0], dtype=torch.bool, device=input_ids.device)

    # Only integer token ids are returned, so no autograd graph is needed.
    with torch.no_grad():
        for _ in range(max_length):
            outputs = model(**model_inputs)
            probs = F.softmax(outputs.logits[:, -1, :], dim=-1)
            next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
            # Rows that already emitted EOS keep receiving EOS instead of fresh samples.
            next_token = torch.where(
                finished, torch.full_like(next_token, eos_token_id), next_token
            )

            input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
            mask_step = torch.ones(
                (attention_mask.shape[0], 1),
                dtype=attention_mask.dtype,
                device=attention_mask.device,
            )
            attention_mask = torch.cat([attention_mask, mask_step], dim=-1).long()
            model_inputs["input_ids"] = input_ids
            model_inputs["attention_mask"] = attention_mask

            finished = finished | (next_token == eos_token_id)
            if finished.all():
                break

    # Return every row in full; slicing with [-max_length:] would cut the batch
    # axis and drop rows whenever the batch is larger than max_length.
    return input_ids

# NOTE(review): `inputs` and `processor` are not defined in any visible cell of
# this notebook; on a fresh kernel this cell fails unless they come from elsewhere.
# NOTE(review): 32001 is passed as the end-of-sequence id -- presumably a token
# added to the tokenizer; confirm against the tokenizer actually in use.
output = response_to_batch(inputs=inputs, model=model, eos_token_id=32001)

# Turn the sampled token ids back into text, dropping special tokens.
generated_text = processor.batch_decode(output, skip_special_tokens=True)

In [None]:
# Inspect what `processor` returns for a sentence containing the "[SPT]" marker.
# NOTE(review): `processor` is not defined in any visible cell; confirm where it comes from.
processor(["Here is [SPT] handle of the mug [SPT]"])

In [None]:
# Display the text decoded from the manual sampling loop.
generated_text

In [None]:
import pickle as pkl
# Load one pickled grasp instruction from the dataset folder.
# NOTE(review): unpickling can execute arbitrary code; only open files that come
# from a trusted copy of the dataset.
# NOTE(review): this file's hash prefix (1cdd8f...) differs from the image loaded
# earlier (0e73f0...); confirm that a different sample is intended here.
with open("grasp-anything++/seen/grasp_instructions/1cdd8f145672b4a7959503e815d3bb67ede9f7e9fb7c7be8d9f743f04bbf5bc7_0_0.pkl", "rb") as f:
    instruction = pkl.load(f)

In [None]:
# Display the unpickled instruction object.
instruction

In [None]:
# NOTE(review): `x` is not defined in any visible cell, so this cell raises
# NameError on Restart & Run All unless it is defined elsewhere; it looks like a
# leftover scratch cell and is a candidate for deletion.
x