# Open VLM Demo

In [None]:
%load_ext autoreload
%autoreload 2

import ipywidgets as widgets 
import matplotlib.pyplot as plt 
from transformers import LlavaNextForConditionalGeneration, LlavaNextProcessor
import torch
from PIL import Image

from driver_stalker.paths import get_repo_dpath
from driver_stalker.utils.fs import read_image, get_image_fpaths_from_folder

## Constants

In [None]:
PROJECT_DPATH = get_repo_dpath()
DATA_DPATH = PROJECT_DPATH / "data"

# Root of the drowsiness dataset; its "no_yawn" and "yawn" sub-folders are read by the Demo cell.
DATASET_DPATH = DATA_DPATH / "drowsiness_dataset"
if not DATASET_DPATH.is_dir():
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # and a bare AssertionError would not say which path is missing.
    raise FileNotFoundError(f"Drowsiness dataset directory not found: {DATASET_DPATH}")

## Model Loading

In [None]:
# Falcon-11B vision-language checkpoint, loaded through the LLaVA-NeXT classes.
# Weights are kept in bfloat16 (2 bytes/param instead of 4 for float32).
FALCON_VLM_ID = "tiiuae/falcon-11B-vlm"

processor = LlavaNextProcessor.from_pretrained(
    FALCON_VLM_ID,
    tokenizer_class='PreTrainedTokenizerFast',
)
model = LlavaNextForConditionalGeneration.from_pretrained(
    FALCON_VLM_ID,
    torch_dtype=torch.bfloat16,
)

## Demo

In [None]:
# Gather image paths per class folder. Order matters for the demo slider index:
# every "no_yawn" image comes first, followed by every "yawn" image.
no_yawn_paths, yawn_paths = (
    get_image_fpaths_from_folder(DATASET_DPATH / class_dname)
    for class_dname in ("no_yawn", "yawn")
)
img_paths = no_yawn_paths + yawn_paths

# Per-class image counts, shown as the cell's output.
len(no_yawn_paths), len(yawn_paths)

In [None]:
# Standalone sanity check of the LLaVA-NeXT (Mistral-7B) checkpoint on a public sample image.
# The model/processor are bound to `llava_model` / `llava_processor` so this cell does NOT
# overwrite the Falcon `model` / `processor` used by the interactive demo, whose prompt is
# written in the Falcon "User: ... Falcon:" format.
import io

import requests
from transformers import AutoProcessor

LLAVA_MODEL_ID = "llava-hf/llava-v1.6-mistral-7b-hf"

# bfloat16 matches the Falcon model cell and halves memory versus the float32 default.
llava_model = LlavaNextForConditionalGeneration.from_pretrained(LLAVA_MODEL_ID, torch_dtype=torch.bfloat16)
llava_processor = AutoProcessor.from_pretrained(LLAVA_MODEL_ID)

# LLaVA-Mistral instruction format.
prompt = "[INST] <image>\nWhat is shown in this image? [/INST]"
url = "https://www.ilankelman.org/stopsigns/australia.jpg"
# Bounded wait and explicit HTTP error instead of hanging or feeding an error page to PIL.
response = requests.get(url, timeout=30)
response.raise_for_status()
image = Image.open(io.BytesIO(response.content))
inputs = llava_processor(images=image, text=prompt, return_tensors="pt")

# max_new_tokens limits only the generated continuation; max_length would also count the
# prompt (including image tokens) and leave little or no room for an answer.
generate_ids = llava_model.generate(**inputs, max_new_tokens=100)
# Decode only the newly generated tokens so the prompt is not echoed back.
prompt_length = inputs["input_ids"].shape[1]
llava_processor.batch_decode(
    generate_ids[:, prompt_length:], skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]


In [None]:
# Fail with a clear message instead of an obscure IntSlider "max < min" error.
if not img_paths:
    raise RuntimeError(f"No images found under {DATASET_DPATH}; cannot build the demo slider.")

# Instruction sent to the VLM for every dataset image.
DROWSINESS_INSTRUCTION = "Describe the level of drowsiness of this person"


def _describe_image(image, instruction, max_new_tokens=256):
    """Run the module-level ``model``/``processor`` on *image* and return only the generated text.

    The prompt uses the ``User:<image>\\n<instruction> Falcon:`` template paired with the
    ``tiiuae/falcon-11B-vlm`` checkpoint in this notebook. If ``model``/``processor`` have been
    rebound to a different checkpoint, this template may not match that model's format.
    """
    prompt = f"User:<image>\n{instruction} Falcon:"
    inputs = processor(text=prompt, images=image, return_tensors="pt")

    # Put every tensor on the model's device. Floating-point tensors (pixel values) are also
    # cast to the model's dtype, e.g. bfloat16 for the Falcon checkpoint loaded above.
    model_inputs = {}
    for name, value in inputs.items():
        if isinstance(value, torch.Tensor):
            if value.is_floating_point():
                value = value.to(model.device, dtype=model.dtype)
            else:
                value = value.to(model.device)
        model_inputs[name] = value

    # max_new_tokens bounds only the continuation; max_length would also count the prompt,
    # whose image tokens can leave no room for an answer.
    output = model.generate(**model_inputs, max_new_tokens=max_new_tokens)
    prompt_length = model_inputs["input_ids"].shape[1]
    # Decode only the newly generated tokens so the prompt is not echoed back.
    return processor.decode(output[0][prompt_length:], skip_special_tokens=True).strip()


@widgets.interact
def show(index=widgets.IntSlider(value=0, min=0, max=len(img_paths) - 1, continuous_update=False)):
    """Display ``img_paths[index]`` and print the VLM's drowsiness description of it.

    ``continuous_update=False`` makes the slider fire only when released, so dragging it
    does not trigger one slow generation per intermediate index.
    """
    img_fpath = img_paths[index]
    image = read_image(img_fpath)

    description = _describe_image(image, DROWSINESS_INSTRUCTION)

    plt.figure(figsize=(10, 6))
    plt.imshow(image)
    plt.title(str(img_fpath.relative_to(DATASET_DPATH)))
    plt.show()
    print(description)