In [1]:
import pytest
from loguru import logger
from one_model.common.registries import (
    LLM_MODEL_REGISTRY,
    TOKENIZER_REGISTRY,
    ENCODER_REGISTRY,
    PROJECTOR_REGISTRY,
    DECODER_REGISTRY,
)
from one_model.dataset import *
import transformers
from one_model.common.mm_utils import (
    get_model_name_from_path,
    load_image,
    tokenizer_image_token,
)

from addict import Dict
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoConfig,
    BitsAndBytesConfig,
)
from one_model.dataset import *
from one_model.model.encoder import *
from one_model.model.decoder import *
from one_model.model.projector import *
from one_model.model.tokenizer import *
from addict import Dict
from one_model.common.config import Config
from one_model.model.llm import *
import os
from one_model.common.conversation import conv_templates, SeparatorStyle
from one_model.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
import torch
import numpy as np
import cv2
from pathlib import Path
import imp

  from .autonotebook import tqdm as notebook_tqdm


[2023-09-28 17:01:25,995] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


  import imp


In [3]:
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [4]:
# reload module
# imp.reload(torch)

<module 'torch' from '/root/anaconda3/envs/lisa/lib/python3.10/site-packages/torch/__init__.py'>

In [2]:
config_path = '/opt/product/one_model/tests/test_config_13B.yaml'
cur_dir = '/opt/product/one_model/tests'


In [3]:
device = "cuda:0"
config: Config = Config(Dict(cfg_path=config_path))
one_model_cfg = config.model_cfg
tokenizer_cfg = config.tokenizer_cfg[one_model_cfg.tokenizer]
tokenizer_cls = TOKENIZER_REGISTRY.get(tokenizer_cfg.type)
tokenizer = tokenizer_cls.from_config(tokenizer_cfg).tokenizer

llm_config = config.llm_cfg[one_model_cfg.llm]
llm_cls = LLM_MODEL_REGISTRY.get(llm_config.type)
model: LlavaLlamaForCausalLM = llm_cls.from_config(llm_config, tokenizer)
assert model is not None
logger.info("model {}", model)

conv_mode = "llava_llama_2"
conv = conv_templates[conv_mode].copy()

clip_encoder_large_config = config.encoder_cfg[one_model_cfg.encoder]
clip_encoder_cls = ENCODER_REGISTRY.get(clip_encoder_large_config.type)
clip_encoder: CLIPEncoder = clip_encoder_cls.from_config(clip_encoder_large_config)
clip_encoder = clip_encoder.cuda(0)

image_proj_13B_config = config.projector_cfg[one_model_cfg.in_projector]
projector_cls = PROJECTOR_REGISTRY.get(image_proj_13B_config.type)
logger.info("image_proj_13B_config {}", image_proj_13B_config)
in_projector = projector_cls.from_config(image_proj_13B_config)
logger.info("projector {}", in_projector)
in_projector = in_projector.cuda(0)

out_project_13B_config = config.projector_cfg[one_model_cfg.out_projector]
out_projector_cls = PROJECTOR_REGISTRY.get(out_project_13B_config.type)
out_projector = out_projector_cls.from_config(out_project_13B_config)
out_projector = out_projector.cuda(0)
logger.info("out_project_13B_config {}", out_project_13B_config)
logger.info("out_projector {}", out_projector)

sam_decoder_config = config.decoder_cfg[one_model_cfg.decoder]
sam_decoder_config.model_name_or_path = "/opt/product/llrs/checkpoints/sam_13b.pt"
sam_decoder_cls = DECODER_REGISTRY.get(sam_decoder_config.type)
sam_decoder: SamDecoder = sam_decoder_cls.from_config(sam_decoder_config)
sam_decoder = sam_decoder.cuda(0)

model.get_model().vision_tower = clip_encoder
model.get_model().mm_projector = in_projector

image_file = f"{cur_dir}/view.jpg"
image_processor = clip_encoder.image_processor

# image processor
image = load_image(image_file)
image_tensor = (
    image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
    .half()
    .cuda()
)

inp = "segment the lake"
if image is not None:
    # first message
    if model.config.mm_use_im_start_end:
        inp = (
            DEFAULT_IM_START_TOKEN
            + DEFAULT_IMAGE_TOKEN
            + DEFAULT_IM_END_TOKEN
            + "\n"
            + inp
        )
    else:
        inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
    conv.append_message(conv.roles[0], inp)
    image = None
else:
    # later messages
    conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

logger.info("prompt {}", prompt)
input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .cuda()
)
prompt_len = len(prompt)
model.eval()
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        images=image_tensor,
        max_new_tokens=512,
        num_beams=1,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

output_hidden_states = outputs.hidden_states[-1]
output_ids = outputs.sequences

text_output = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
seg_token_mask = output_ids[:, 1:] == model.seg_token_idx
seg_token_mask = seg_token_mask.to(device)
# hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
seg_token_mask = torch.cat(
    [
        torch.zeros((seg_token_mask.shape[0], 255))
        .bool()
        .to(seg_token_mask.device),
        seg_token_mask,
    ],
    dim=1,
)

output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
# text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
# logger.info("text_output {}", text_output)
# if len(text_output) > prompt_len:
#     text_output = text_output[prompt_len - 1 :]
# logger.info("text_output {}", text_output)
text_output = text_output.replace("\n", "").replace("</s>", "").replace("  ", " ")
text_output = text_output.split("ASSISTANT: ")[-1]

hidden_states = []

hidden_states.append(out_projector(output_hidden_states))
logger.info("hidden_states {}", hidden_states[0].shape)
logger.info("seg_token_mask shape {}", seg_token_mask.shape)

[32m2023-09-28 17:02:00.985[0m | [1mINFO    [0m | [36mone_model.model.llm.llava[0m:[36mfrom_config[0m:[36m545[0m - [1mllava model config LlavaConfig {
  "_name_or_path": "xinlai/LISA-13B-llama2-v1",
  "architectures": [
    "LISAForCausalLM"
  ],
  "bos_token_id": 1,
  "eos_token_id": 2,
  "freeze_mm_mlp_adapter": true,
  "hidden_act": "silu",
  "hidden_size": 5120,
  "image_aspect_ratio": "square",
  "image_grid_pinpoints": null,
  "initializer_range": 0.02,
  "intermediate_size": 13824,
  "max_position_embeddings": 4096,
  "mm_hidden_size": 1024,
  "mm_resampler_type": null,
  "mm_use_im_patch_token": false,
  "mm_use_im_start_end": true,
  "mm_vision_select_feature": "patch",
  "mm_vision_select_layer": -2,
  "mm_vision_tower": "openai/clip-vit-large-patch14",
  "model_type": "llava",
  "num_attention_heads": 40,
  "num_hidden_layers": 40,
  "num_key_value_heads": 40,
  "out_dim": 256,
  "pad_token_id": 0,
  "pretrain_mm_mlp_adapter": null,
  "pretraining_tp": 1,
  "rms_n

In [4]:
logger.info("hidden_states {}", hidden_states[0].shape)
logger.info("seg_token_mask shape {}", seg_token_mask.shape)

[32m2023-09-28 17:04:07.550[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m1[0m - [1mhidden_states torch.Size([1, 381, 256])[0m
[32m2023-09-28 17:04:07.552[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m2[0m - [1mseg_token_mask shape torch.Size([1, 381])[0m


In [15]:
# change sam 
sam_decoder_config = config.decoder_cfg[one_model_cfg.decoder]
sam_decoder_config.model_name_or_path = "/opt/product/llrs/checkpoints/sam_13b.pt"
sam_decoder_cls = DECODER_REGISTRY.get(sam_decoder_config.type)
sam_decoder: SamDecoder = sam_decoder_cls.from_config(sam_decoder_config)
sam_decoder = sam_decoder.cuda(0)

[32m2023-09-28 17:13:44.539[0m | [1mINFO    [0m | [36mone_model.model.decoder.sam_decoder[0m:[36m__init__[0m:[36m17[0m - [1msam decoder init, model_name_or_path /opt/product/llrs/checkpoints/sam_13b.pt, model_type sam_h[0m


In [10]:
conv_mode = "llava_llama_2"
conv = conv_templates[conv_mode].copy()

image = load_image(image_file)
image_tensor = (
    image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
    .half()
    .to(device)
)

inp = "segment the lake"
if image is not None:
    # first message
    if model.config.mm_use_im_start_end:
        inp = (
            DEFAULT_IM_START_TOKEN
            + DEFAULT_IMAGE_TOKEN
            + DEFAULT_IM_END_TOKEN
            + "\n"
            + inp
        )
    else:
        inp = DEFAULT_IMAGE_TOKEN + "\n" + inp
    conv.append_message(conv.roles[0], inp)
    image = None
else:
    # later messages
    conv.append_message(conv.roles[0], inp)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

logger.info("prompt {}", prompt)
input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .to(device)
)
prompt_len = len(prompt)
model.eval()
with torch.no_grad():
    outputs = model.generate(
        input_ids=input_ids,
        images=image_tensor,
        max_new_tokens=512,
        num_beams=1,
        output_hidden_states=True,
        return_dict_in_generate=True,
    )

output_hidden_states = outputs.hidden_states[-1]
output_ids = outputs.sequences

text_output = tokenizer.decode(output_ids[0, input_ids.shape[1] :]).strip()
seg_token_mask = output_ids[:, 1:] == model.seg_token_idx
seg_token_mask = seg_token_mask.to(device)
# hack for IMAGE_TOKEN_INDEX (we suppose that there is only one image, and it is in the front)
seg_token_mask = torch.cat(
    [
        torch.zeros((seg_token_mask.shape[0], 255))
        .bool()
        .to(seg_token_mask.device),
        seg_token_mask,
    ],
    dim=1,
)

output_ids = output_ids[0][output_ids[0] != IMAGE_TOKEN_INDEX]
# text_output = tokenizer.decode(output_ids, skip_special_tokens=False)
# logger.info("text_output {}", text_output)
# if len(text_output) > prompt_len:
#     text_output = text_output[prompt_len - 1 :]
# logger.info("text_output {}", text_output)
text_output = text_output.replace("\n", "").replace("</s>", "").replace("  ", " ")
text_output = text_output.split("ASSISTANT: ")[-1]

hidden_states = []

hidden_states.append(out_projector(output_hidden_states))
logger.info("hidden_states {}", hidden_states[0].shape)
logger.info("seg_token_mask shape {}", seg_token_mask.shape)

[32m2023-09-28 17:09:54.176[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m32[0m - [1mprompt [INST] <<SYS>>
You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
<</SYS>>

<im_start><image><im_end>
segment the lake [/INST][0m
[32m2023-09-28 17:09:54.185[0m | [1mINFO    [0m | [36mone_model.model.llm.llava[0m:[36mencode_images[0m:[36m66[0m - [1mimages shape torch.Size([1, 3, 224, 224])[0m
[32m2023-09-28 17:09:54.212[0m | [1mINFO    [0m | [36mone_model.model.llm.llava[0m:[36mencode_images[0m:[36m69[0m - [1mimage_features shape torch.Size([1, 256, 5120])[0m
[32m2023-09-28 17:09:54.768[0m | [1mINFO    [0m | [36mone_model.model.llm.llava[0m:[36mencode_images[0m:[36m66[0m - [1mimages shape torch.Size([1, 3, 224, 224])[0m
[32m2023-09-28 17:09:54.793[0m | [1mINFO    [0m | [36mone_model.model.llm.llava

In [15]:
# from one_model.model.decoder import sam_decoder
# imp.reload(sam_decoder)

AssertionError: An object named 'sam' was already registered in 'Decoder' registry!

In [16]:
# test sam decoder
decoder_result = sam_decoder.forward(
        image_paths=[image_file],
        hidden_states=hidden_states,
        gt_masks=None,
        inference=True,
        seg_token_mask=seg_token_mask,
)

[32m2023-09-28 17:14:05.201[0m | [1mINFO    [0m | [36mone_model.model.decoder.sam_decoder[0m:[36mget_visual_embs_img_paths[0m:[36m84[0m - [1mpixel_value.shape torch.Size([3, 1024, 1024])[0m


In [17]:
decoder_result

{'pred_masks': [tensor([[[-1.1943e+01, -1.1939e+01, -1.1827e+01,  ..., -1.0311e+01,
            -1.0178e+01, -1.0174e+01],
           [-1.1939e+01, -1.1935e+01, -1.1823e+01,  ..., -1.0312e+01,
            -1.0181e+01, -1.0176e+01],
           [-1.1820e+01, -1.1817e+01, -1.1712e+01,  ..., -1.0350e+01,
            -1.0256e+01, -1.0253e+01],
           ...,
           [ 5.3950e-01,  5.4195e-01,  6.1592e-01,  ...,  3.0288e-01,
             1.3493e-01,  1.2936e-01],
           [ 2.4756e-03,  4.7638e-03,  7.3663e-02,  ..., -2.0243e-01,
            -3.4724e-01, -3.5204e-01],
           [-8.8530e-01, -8.8333e-01, -8.2416e-01,  ..., -1.1265e+00,
            -1.2560e+00, -1.2603e+00]]], device='cuda:0',
         grad_fn=<SelectBackward0>)],
 'gt_masks': None}

In [18]:
print("\n", {"prompt": prompt, "outputs": text_output}, "\n")
save_img = None
pred_masks = decoder_result["pred_masks"]
image_np = cv2.imread(image_file)
image_np = cv2.cvtColor(image_np, cv2.COLOR_BGR2RGB)

logger.info("pred_masks {}", len(pred_masks))
for i, pred_mask in enumerate(pred_masks):
    if pred_mask.shape[0] == 0:
        continue

    pred_mask = pred_mask.detach().cpu().numpy()[0]
    pred_mask = pred_mask > 0

    save_img = image_np.copy()
    save_img[pred_mask] = (
        image_np * 0.5
        + pred_mask[:, :, None].astype(np.uint8) * np.array([255, 0, 0]) * 0.5
    )[pred_mask]
image_name = Path(image_file).name
vis_save_path = "./vis"
if save_img is not None:
    save_path = f"{vis_save_path}/{image_name}"
    logger.info("save segment to {}", save_path)
    cv2.imwrite(save_path, save_img[:, :, ::-1])
logger.info("infer text out {}", text_output)

[32m2023-09-28 17:14:12.683[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m7[0m - [1mpred_masks 1[0m


[32m2023-09-28 17:14:12.742[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m24[0m - [1msave segment to ./vis/view.jpg[0m
[32m2023-09-28 17:14:12.768[0m | [1mINFO    [0m | [36m__main__[0m:[36m<module>[0m:[36m26[0m - [1minfer text out Sure,[SEG] .[0m



 {'prompt': '[INST] <<SYS>>\nYou are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.\n<</SYS>>\n\n<im_start><image><im_end>\nsegment the lake [/INST]', 'outputs': 'Sure,[SEG] .'} 

