In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = "1"
from PIL import Image
import torch
from transformers import (
    LlavaNextProcessor,
    LlavaNextForConditionalGeneration
)
from IPython.display import display
from typing import List, Dict, Tuple
import plotly.graph_objects as go
import numpy as np
import textwrap

In [None]:
## configurations
model_id: str = "llava-hf/llava-v1.6-mistral-7b-hf"  # Hugging Face Hub checkpoint
seed: int = 42
max_new_tokens: int = 500  # generation-length cap used by every generate() call


# Seed the CPU and CUDA RNGs so any stochastic step is reproducible.
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
# Use deterministic cuDNN kernels only. `benchmark` must stay False: autotuning
# picks the fastest algorithm at runtime, which can differ between runs and
# defeats the `deterministic` flag.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## Load the model and processor


In [None]:
# Load the chat processor (tokenizer + image preprocessing) and the LLaVA-NeXT
# weights for `model_id`.
#   - weights are loaded in fp16; switch to torch.bfloat16 on GPUs that support it
#   - device_map="auto" places the weights on the visible GPU(s)
#   - to load a quantized model, add a `quantization_config=...` kwarg below
processor = LlavaNextProcessor.from_pretrained(model_id)

model = LlavaNextForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

### Sanity check: basic model inference

In [None]:
# Open the example image once; the same PIL object is fed to the processor in
# every experiment below and embedded in the final plotly figure.
test_sample = Image.open("test_sample.png")
display(test_sample)  # render inline so the reader sees what the model sees

In [None]:
conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "Describe the image in detail."},
        ],
    },
]
## prepare input
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Pass images/text by keyword: their positional order differs across
# transformers versions. Send tensors to wherever the model was placed
# instead of hard-coding a device.
inputs = processor(images=test_sample, text=prompt, return_tensors="pt").to(model.device)

## now generate (greedy decoding, so the result is deterministic)
generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
# generate() returns prompt + new tokens; decode only the new ones rather than
# splitting the full text on the chat-template marker "[/INST]".
prompt_length = inputs["input_ids"].shape[-1]
output = processor.decode(generated_ids[0][prompt_length:], skip_special_tokens=True).strip()
print(f"Image description: {output}")

# Multimodal Uncertainty Estimation

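## 1. Question-induced hallucination test

We ask a leading question that presupposes an object (a ping pong table) that is not in the image. A model that answers "Yes" instead of rejecting the false premise is hallucinating.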

In [None]:
## Question induced hallucination
# In the image, the ping pong table is not present. Let's check whether the
# model goes along with the false premise anyway.

question_prompt = "Is the target picture hanging on the wall in the room with the ping pong table?"

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": question_prompt},
        ],
    },
]
## prepare input
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Keyword arguments avoid the version-dependent images/text positional order;
# model.device follows wherever device_map placed the model.
inputs = processor(images=test_sample, text=prompt, return_tensors="pt").to(model.device)

## now generate (greedy) and decode only the newly generated tokens
generated_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)
prompt_length = inputs["input_ids"].shape[-1]
output = processor.decode(generated_ids[0][prompt_length:], skip_special_tokens=True).strip()

# Here I got: "Yes, the target picture is hanging on the wall in the room with the ping pong table."
# which, of course, is not correct.
print(f"Question-induced hallucination? The model's answer is :---> {output}")

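## 2. Uncertainty estimation via normalized entropy

We ask the same question but restrict the answer to `Yes`, `No`, or `?` ("I don't know"), then inspect the probability distribution of the **first generated token**. From the top-$k$ token probabilities $p_1, \dots, p_k$ we compute the normalized entropy

$$\hat H = \frac{-\sum_{i=1}^{k} p_i \log p_i}{\log k},$$

and flag the answer as *uncertain* when $\hat H > \tau$.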

In [None]:
question_prompt = "Is the target picture hanging on the wall in the room with the ping pong table? You must answer only with 'Yes','No', or '?=I don't know'."

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": question_prompt},
        ],
    },
]
## prepare input
prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
# Keyword arguments avoid the version-dependent images/text positional order;
# model.device follows wherever device_map placed the model.
inputs = processor(images=test_sample, text=prompt, return_tensors="pt").to(model.device)

## now generate, keeping the per-step logits
output = model.generate(
    **inputs,
    max_new_tokens=max_new_tokens,
    do_sample=False,
    output_logits=True,  # <-- we want the logits
    return_dict_in_generate=True,  # return an output object with `sequences` and `logits`
)

# Decode only the generated continuation (sequences = prompt + new tokens).
prompt_length = inputs["input_ids"].shape[-1]
output_string = processor.decode(
    output["sequences"][0][prompt_length:], skip_special_tokens=True
).strip()

# One tensor per generated step; each is indexed [batch] -> (vocab,) below.
logits = output["logits"]
print("Answer of the model: ", output_string)  # <--- the model still answers "Yes"


########## Now let's check the logits
# Only the FIRST generated token is inspected: the prompt forces the answer to
# start with 'Yes', 'No' or '?', so the distribution at step 0 is the model's
# belief over those options.
top_k = 3
assert len(logits) > 0, "generate() produced no tokens; nothing to inspect"
# Softmax in float32 for numerical stability; logits[0][0] -> shape (vocab,)
first_step_probs = torch.softmax(logits[0][0].float(), dim=-1)

topk_scores, topk_indices = torch.topk(first_step_probs, top_k)
topk_scores = topk_scores.cpu()
topk_indices = topk_indices.cpu().tolist()

# (decoded token, probability rounded to 3 decimals for display)
token_likelihood: List[Tuple[str, float]] = [
    (processor.decode([token_id], skip_special_tokens=True), round(score.item(), 3))
    for token_id, score in zip(topk_indices, topk_scores)
]

print(f"The top-{top_k} tokens are: {token_likelihood}")

### Compute the normalized entropy estimation
tau = 0.75
# Use the full-precision probabilities, not the rounded display values: a value
# that rounds to 0.0 made p*log(p) evaluate to 0*(-inf) = NaN, which silently
# turned every comparison False and forced an "Uncertain" verdict.
# Renormalize the top-k mass so it forms a proper probability distribution.
topk_probs = topk_scores / topk_scores.sum()
print(f"Top-{top_k} probabilities (renormalized): {topk_probs}")
# torch.special.entr(p) = -p*ln(p) elementwise, with entr(0) defined as 0.
entropy = torch.special.entr(topk_probs).sum()

# Maximum entropy over k outcomes is ln(k); with k == 1 there is no uncertainty
# to measure (and ln(1) = 0 would divide by zero).
entropy_max = torch.log(torch.tensor(float(top_k)))
entropy_normalized = entropy / entropy_max if top_k > 1 else torch.tensor(0.0)

model_certainty = "Certain" if entropy_normalized <= tau else "Uncertain"
print(f"Model Certainty: Entropy normalized: {entropy_normalized} - Model Certainty: {model_certainty} (tau={tau})")


#### Create an image for clarity
# Bar chart of the top-k first-token probabilities, with the input image on the
# right and the question and certainty verdict as annotations above the chart.
tokens = [token for token, _ in token_likelihood]
scores = [score for _, score in token_likelihood]

fig = go.Figure()
fig.add_trace(go.Bar(x=tokens, y=scores, text=scores, textposition="auto"))

fig.add_layout_image(
    dict(
        source=test_sample,  # the PIL image is passed directly as the source
        xref="paper",
        yref="paper",
        x=1.05,
        y=1,
        sizex=0.6,
        sizey=1,
        yanchor="top",
        sizing="contain",
        layer="above",
    )
)

# Wrap the raw question only: the bold "Question:" label is added by the
# annotation itself, so prefixing it here would print the label twice.
wrapped_question = "<br>".join(textwrap.wrap(question_prompt, width=100))

fig.update_layout(
    xaxis=dict(domain=[0, 0.7], title="Candidate first token"),
    yaxis=dict(domain=[0, 0.85], title="Probability"),
    annotations=[
        dict(
            text=f"<b>Question:</b> {wrapped_question}",
            xref="paper", yref="paper",
            x=0, y=1.15,
            showarrow=False,
            font=dict(size=14),
        ),
        dict(
            text=f"<b>Result after Normalized-entropy estimation technique:</b> {model_certainty}",
            xref="paper", yref="paper",
            x=0, y=1.0,
            showarrow=False,
            font=dict(size=14, color="gray"),
            align="left",
        ),
    ],
    margin=dict(l=50, r=400, t=70, b=50),
    width=1000,
    height=500,
)
fig.show()
