### Configure:

In [1]:
low_gpu_memory_optimization = True

### Install dependencies:

In [None]:
!git clone https://github.com/SkunkworksAI/BakLLaVA.git
%cd BakLLaVA
!pip install -e .
!pip uninstall transformers -y
!pip install transformers==4.34.0

### Run:

In [2]:
from transformers import AutoConfig, AutoTokenizer
from llava.model.language_model.llava_mistral import LlavaMistralForCausalLM
from huggingface_hub import notebook_login
from llava.mm_utils import tokenizer_image_token, KeywordsStoppingCriteria

from PIL import Image
import requests
from io import BytesIO

from llava.conversation import conv_templates, SeparatorStyle
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

import torch

[2023-10-20 08:52:05,159] [INFO] [real_accelerator.py:110:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [None]:
notebook_login()

In [None]:
model_path = "SkunkworksAI/BakLLaVA-1"

cfg_pretrained = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
if low_gpu_memory_optimization:
    model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, load_in_4bit=True, config=cfg_pretrained, device_map="cpu", torch_dtype="auto")
else: 
    model = LlavaMistralForCausalLM.from_pretrained(model_path, config=cfg_pretrained)
    model.to("cuda")

In [82]:
model.dtype

torch.bfloat16

In [87]:
mm_use_im_start_end = getattr(model.config, "mm_use_im_start_end", False)
mm_use_im_patch_token = getattr(model.config, "mm_use_im_patch_token", True)
if mm_use_im_patch_token:
    tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
    tokenizer.add_tokens([DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True)
model.resize_token_embeddings(len(tokenizer))

vision_tower = model.get_vision_tower()
if not vision_tower.is_loaded:
    vision_tower.load_model()
vision_tower.to(device='cpu', dtype=torch.bfloat16)
image_processor = vision_tower.image_processor

if hasattr(model.config, "max_sequence_length"):
    context_len = model.config.max_sequence_length
else:
    context_len = 2048

In [84]:
def load_image(image_file):
    if image_file.startswith('http') or image_file.startswith('https'):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file).convert('RGB')
    return image


# image = load_image("https://t4.ftcdn.net/jpg/00/97/58/97/360_F_97589769_t45CqXyzjz0KXwoBZT9PRaWGHRk5hQqQ.jpg")
image = load_image("https://cdn.discordapp.com/attachments/1096822099345145969/1164641565550067852/heart_1.png?ex=6543f3fb&is=65317efb&hm=448cb26e19c141871e776af98077c4c1e97a8f29b96916ab671e5010c00e3625&")
# image = load_image("https://cdni.pornpics.com/1280/7/139/41415012/41415012_101_feaf.jpg")  # naked pussy
# image = load_image("https://cdn.dribbble.com/userupload/10885354/file/original-2f92d4d02be0989f11d7a6f709599324.png")

if low_gpu_memory_optimization:
    image_tensor = image_processor.preprocess(image, return_tensors='pt', device="cpu")['pixel_values']
else:
    image_tensor = image_processor.preprocess(image, return_tensors='pt')['pixel_values']

In [85]:
query = "Describe this image"
# query = "Does the woman show her naked pussy?"
# query = "Describe the app"

if model.config.mm_use_im_start_end:
    query = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + query
else:
    query = DEFAULT_IMAGE_TOKEN + '\n' + query

conv = conv_templates["llava_v1"].copy()

conv.append_message(conv.roles[0], query)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()
input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)

stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
keywords = [stop_str]
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)

In [None]:
with torch.inference_mode():
    output_ids = model.generate(
        input_ids=input_ids,
        images=image_tensor,
        do_sample=True,
        temperature=0.2,
        max_new_tokens=1024,
        use_cache=True,
        stopping_criteria=[stopping_criteria])

input_token_len = input_ids.shape[1]
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
if n_diff_input_output > 0:
    print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
outputs = outputs.strip()
if outputs.endswith(stop_str):
    outputs = outputs[:-len(stop_str)]
outputs = outputs.strip()
print(outputs)

### Streaming:

In [88]:
from transformers import TextStreamer


def on_finalized_text(text: str, stream_end: bool = False):
    print(text.replace(stop_str, ""), end="")


streamer = TextStreamer(tokenizer, skip_prompt=True)
streamer.on_finalized_text = on_finalized_text


# with torch.inference_mode():
output_ids = model.generate(
    input_ids=input_ids,
    images=image_tensor,
    do_sample=True,
    temperature=0.2,
    max_new_tokens=1024,
    use_cache=True,
    
    stopping_criteria=[stopping_criteria], streamer=streamer)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.


AssertionError: 

In [None]:
!r,