In [1]:
import torch
from LLaVA.llava.model.builder import load_pretrained_model
from LLaVA.llava.mm_utils import get_model_name_from_path

from LLaVA.llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from LLaVA.llava.conversation import SeparatorStyle
from LLaVA.llava.model.builder import load_pretrained_model
from LLaVA.llava.utils import disable_torch_init
from LLaVA.llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

model_path = "checkpoints/llava-lora-1.10"
print(get_model_name_from_path(model_path))
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base="liuhaotian/llava-v1.5-7b",
    model_name=get_model_name_from_path(model_path)
)

ModuleNotFoundError: No module named 'LLaVA.llava'

In [100]:
query = "Descibe the image"
prompt = DEFAULT_IMAGE_TOKEN + "\n" + query

In [101]:

input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .cuda()
)

In [102]:
input_ids.shape

torch.Size([1, 9])

In [103]:
import requests
from PIL import Image
from io import BytesIO


def load_image(image_file):
    if image_file.startswith("http") or image_file.startswith("https"):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image

def load_images(image_files):
    out = []
    for image_file in image_files:
        image = load_image(image_file)
        out.append(image)
    return out

images = load_images(["grasp-anything++/seen/image/0b40406618b99611efd8ca8798e8ce46f4951dc733c1fc17adf94f8b6afec999.jpg"])
images_tensor = process_images(
    images,
    image_processor,
    model.config
).to(model.device, dtype=torch.float16)

In [104]:
with torch.inference_mode():
    output_ids = model.generate(
        input_ids,
        images=images_tensor,
        do_sample=False,
        temperature=0.8,
        max_new_tokens=100,
        use_cache=True,
    )

In [105]:
output_ids

tensor([[    1, 29889,    13,    13, 29902,   871,  3867,   278,  3646, 25274,
          1203,  1024,  2729,   373,   278,  1881,  1347, 29889,   512,   445,
          1206, 29892,   278,  3646, 25274,  1203,  1024,   338,   376,  1731,
           261,  1642,   518,  5550, 29911, 29962, 28149,   518,  5550, 29911,
         29962, 29871,     2]], device='cuda:0')

In [106]:
input_token_len = input_ids.shape[1]
outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True
    )[0]
outputs

'.\n\nI only provide the target grasp object name based on the input string. In this case, the target grasp object name is "flower". [SPT] flower [SPT] '

In [85]:
output_ids

tensor([[    1,    13,    13, 10605,   338,   518,  5550, 29911, 29962, 26163,
           518,  5550, 29911, 29962, 29871,     2]], device='cuda:0')

In [None]:
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vision_tower_name = "openai/clip-vit-large-patch14-336"

processor = CLIPImageProcessor.from_pretrained(vision_tower_name)
vision_tower = CLIPVisionModel.from_pretrained(vision_tower_name, device_map="auto")
image_forward_out = vision_tower(images_tensor[0].to(device=device).unsqueeze(0), output_hidden_states=True)

In [None]:
image_forward_out.last_hidden_state.shape