In [1]:
import urllib.request
from io import BytesIO

import numpy
import torch
import transformers
from PIL import Image

def image_parser(image_file, sep=","):
    """Split a ``sep``-separated string of image paths/URLs into a list.

    Whitespace around each entry is stripped and empty entries (for example from
    a trailing separator or an empty input string) are dropped, so
    ``"a.png, b.png,"`` yields ``["a.png", "b.png"]``.

    Args:
        image_file: A single path/URL, or several joined by ``sep``.
        sep: Separator between entries. Defaults to ``","``.

    Returns:
        List of non-empty, stripped entries, in their original order.
    """
    # Strip first, then filter, so entries made only of spaces are dropped too.
    stripped = (part.strip() for part in image_file.split(sep))
    return [part for part in stripped if part]


def load_image(image_file, size=None, timeout=30):
    """Load one image from a local path or an http(s) URL as an RGB PIL image.

    Args:
        image_file: Local filesystem path, or a URL starting with ``http://`` or
            ``https://``.
        size: Optional ``(width, height)``. When given, the image is resized to
            exactly that size (aspect ratio is not preserved). The default
            ``None`` keeps the original size. That matches what this function
            effectively did before: it called ``image.resize((300, 300))`` but
            discarded the returned copy, so no resize ever happened.
        timeout: Seconds to wait on a remote download before giving up.

    Returns:
        A ``PIL.Image.Image`` in RGB mode.

    Raises:
        urllib.error.URLError / urllib.error.HTTPError: the download failed or
            returned an error status.
        OSError: the file could not be opened or decoded as an image.
    """
    if image_file.startswith(("http://", "https://")):
        # urlopen raises HTTPError on 4xx/5xx instead of handing an error page to PIL.
        with urllib.request.urlopen(image_file, timeout=timeout) as response:
            data = response.read()
        with Image.open(BytesIO(data)) as raw:
            image = raw.convert("RGB")
    else:
        # The context manager closes the underlying file once the pixels are decoded.
        with Image.open(image_file) as raw:
            image = raw.convert("RGB")
    if size is not None:
        # Image.resize returns a new image; it does not modify the image in place.
        image = image.resize(size)
    return image


def load_images(image_files):
    """Load every path/URL in ``image_files`` via ``load_image``.

    Returns the RGB images as a list, in the same order as the inputs.
    """
    return [load_image(path) for path in image_files]

In [2]:
import gc

# Release a previously loaded diffusion pipeline and/or LLaVA model so their GPU
# memory can be reused. At notebook top level locals() is globals(), so a single
# globals() lookup covers both of the original membership checks.
for _name in ("pipe", "model"):
    if _name in globals():
        print(f"deleting {_name}")
        del globals()[_name]
        # Collect first so the CUDA caching allocator can hand the freed blocks back.
        gc.collect()
        torch.cuda.empty_cache()
del _name

In [3]:
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init

model_path = "liuhaotian/llava-v1.5-7b"

# LLaVA's helpers expect an argparse-style object. Build an ad-hoc class whose
# attributes carry the generation settings read by the inference cell.
args = type("Args", (), dict(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path),
    conv_mode=None,  # None -> inferred from the model name in the inference cell
    temperature=0,
    top_p=None,
    num_beams=1,
    max_new_tokens=512,
))()

# NOTE(review): presumably skips PyTorch's default weight init to speed up loading;
# see llava.utils.disable_torch_init.
disable_torch_init()

# Same value get_model_name_from_path(args.model_path) would return; reuse it.
model_name = args.model_name
tokenizer, model, image_processor, context_len = load_pretrained_model(
    args.model_path,
    args.model_base,
    model_name,
    offload_folder="offload",
    load_4bit=True,
)

print(model.device)
print(model.dtype)
print('all loaded')

You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


cuda:0
torch.float16
all loaded


In [None]:
# Multimodal inference: ask the LLaVA model loaded above about one or more images.
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.conversation import conv_templates
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
)

#prompt = "Describe the image."
# The prompt needs one image token per input image (enforced below). One token is
# prepended automatically, so this explicit "<image>" supplies the second one.
prompt = "<image> Describe a blend of both images."
#image_file = "https://llava-vl.github.io/static/images/view.jpg"
# the smaller image does not cause OOM error
#image_file = "/home/dwangz/Pictures/Screenshots/small.png"
#image_file = "/home/dwangz/Downloads/calder.png"
# Comma-separated local paths or http(s) URLs, split by image_parser().
# NOTE(review): absolute, user-specific paths -- consider a configurable data dir.
image_file = "/home/dwangz/Downloads/calder.png,/home/dwangz/Downloads/blue_orange_green.png"
#image_file = "/home/dwangz/Downloads/calder.png,/home/dwangz/Pictures/Screenshots/small.png"

# Insert the image token(s) into the question, wrapped in start/end markers when
# the model config asks for them.
image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
image_token = image_token_se if model.config.mm_use_im_start_end else DEFAULT_IMAGE_TOKEN
if IMAGE_PLACEHOLDER in prompt:
    # Literal substring replacement; the placeholder is not meant as a regex.
    qs = prompt.replace(IMAGE_PLACEHOLDER, image_token)
else:
    qs = image_token + "\n" + prompt

# Pick the conversation template from the checkpoint name.
model_name_lower = model_name.lower()
if "llama-2" in model_name_lower:
    conv_mode = "llava_llama_2"
elif "mistral" in model_name_lower:
    conv_mode = "mistral_instruct"
elif "v1.6-34b" in model_name_lower:
    conv_mode = "chatml_direct"
elif "v1" in model_name_lower:
    conv_mode = "llava_v1"
elif "mpt" in model_name_lower:
    conv_mode = "mpt"
else:
    conv_mode = "llava_v0"

# An explicitly configured args.conv_mode wins over the inferred one.
if args.conv_mode is not None and conv_mode != args.conv_mode:
    print(
        "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
            conv_mode, args.conv_mode, args.conv_mode
        )
    )
else:
    args.conv_mode = conv_mode

conv = conv_templates[args.conv_mode].copy()
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)  # open assistant turn for the model to fill
prompt = conv.get_prompt()
print(prompt)

image_files = image_parser(image_file)
images = load_images(image_files)
image_sizes = [x.size for x in images]

# The token ids must live on the same device as the model (and the image tensor).
input_ids = (
    tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
    .unsqueeze(0)
    .to(model.device)
)

# Fail fast with a readable message, before any GPU work, if the prompt and the
# image list disagree.
n_image_tokens = int((input_ids == IMAGE_TOKEN_INDEX).sum())
if n_image_tokens != len(images):
    raise ValueError(
        f"prompt contains {n_image_tokens} image token(s) but {len(images)} image(s) were loaded"
    )

# NOTE(review): assumes process_images returns one stacked tensor (all images
# preprocessed to the same shape) -- confirm if switching to an anyres checkpoint.
images_tensor = process_images(
    images,
    image_processor,
    model.config
).to(model.device, dtype=torch.float16)

with torch.inference_mode():
    print('generating...')
    output_ids = model.generate(
        input_ids,
        images=images_tensor,
        image_sizes=image_sizes,
        do_sample=args.temperature > 0,  # temperature 0 -> greedy decoding
        temperature=args.temperature,
        top_p=args.top_p,
        num_beams=args.num_beams,
        max_new_tokens=args.max_new_tokens,
        use_cache=True,
    )

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
print(outputs)

In [4]:
import gc

# Release the LLaVA model (and any earlier diffusion pipeline) before loading
# Stable Diffusion. At notebook top level locals() is globals(), so one
# globals() lookup covers both of the original membership checks.
for _name in ("pipe", "model"):
    if _name in globals():
        print(f"deleting {_name}")
        del globals()[_name]
        # Collect first so the CUDA caching allocator can hand the freed blocks back.
        gc.collect()
        torch.cuda.empty_cache()
del _name

deleting model


In [5]:
import torch
from diffusers import StableDiffusionPipeline

# Load Stable Diffusion v1.5 in full precision, then move all components to the GPU.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float32,
)
print('pipe created')
pipe = pipe.to('cuda')  # DiffusionPipeline.to returns the same pipeline object
print(pipe.device)

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

pipe created
cuda:0


In [None]:
# Turn LLaVA's description from the inference cell into an image.
# Stable Diffusion v1.5's CLIP text encoder only reads the first 77 tokens, so a
# long description is truncated (diffusers logs a warning when that happens).
prompt = outputs
# Fixed seed so re-running the cell reproduces the same image.
generator = torch.Generator(device=pipe.device).manual_seed(0)
image = pipe(prompt, generator=generator).images[0]
# Render inline in the notebook; PIL's Image.show() opens an external viewer instead.
image