In [None]:
# Print the current working directory (IPython shell escape); presumably a
# sanity check that the local `attention_analyse` module imported below can be
# found from here.
!pwd

In [None]:
import os
from typing import Union, List
import copy
import PIL
from PIL import Image
import io
import itertools
import math
import numpy as np
import requests
from dataclasses import dataclass
import torch

from torchvision.transforms.functional import to_pil_image

import transformers

from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from llava.model.builder import load_pretrained_model
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
from llava.mm_utils import process_images, load_image_from_base64, tokenizer_image_token, KeywordsStoppingCriteria
from llava import conversation as conversation_lib
from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


from attention_analyse import (
    load_image,
    transpose_list,
    make_image_grid,
    visualize_1d_energy,
    visualize_spatial_energy
)
# Device used for the model and the analysis tensors. Prefer the second GPU
# (the original hard-coded choice) but fall back to the first GPU, then the
# CPU, on machines with fewer devices instead of failing at model-load time.
if torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'



# ARGS

In [None]:
@dataclass
class ARGS:
    """Configuration consumed by ``Analyser`` to locate and load a checkpoint.

    The attributes carry type annotations so that ``@dataclass`` really turns
    them into fields: without annotations the decorator sees no fields at all
    and ``ARGS(model_path=...)`` raises ``TypeError``. ``ARGS()`` with no
    arguments still yields exactly the previous defaults.
    """
    # Checkpoint id or path; Analyser.set_args derives model_name from it.
    model_path: str = "liuhaotian/llava-v1.5-7b"
    # Forwarded as model_base to load_pretrained_model.
    model_base: Union[str, None] = None
    # When set, overrides the name derived from model_path.
    model_name: Union[str, None] = None
    # Device string; defaults to the module-level `device`.
    device: str = device
    multi_modal: bool = True
    # Quantisation switches forwarded to load_pretrained_model.
    load_8bit: bool = False
    load_4bit: bool = False


args = ARGS()

# helpers

# Analyser

In [None]:
class Analyser():
    """Attention-flow analysis for a LLaVA model.

    Holds a LLaVA model, tokenizer and image processor, runs a single-round
    conversation with ``output_attentions=True`` and propagates an "energy"
    distribution from the target response tokens back through the language
    model's attention maps, to see which input positions (image tokens and text
    tokens) those tokens draw on.
    """
    # Class-level default; load_from_args overwrites it per instance.
    is_multimodal = True

    def __init__(self, args):
        self.set_args(args)

    def set_args(self, args):
        """Store *args* and derive ``self.model_name`` and ``self.device``.

        When ``args.model_name`` is None the name is the last component of
        ``args.model_path`` (one trailing ``/`` ignored); for a
        ``checkpoint-*`` directory the parent directory name is prefixed.
        """
        self.args = args
        model_path = args.model_path
        if model_path.endswith("/"):
            model_path = model_path[:-1]

        if args.model_name is None:
            model_paths = model_path.split("/")
            # Guard the parent lookup: a bare "checkpoint-N" path has no parent.
            if len(model_paths) > 1 and model_paths[-1].startswith('checkpoint-'):
                self.model_name = model_paths[-2] + "_" + model_paths[-1]
            else:
                self.model_name = model_paths[-1]
        else:
            self.model_name = args.model_name
        self.device = self.args.device

    def load_from_args(self):
        """Load tokenizer, model, image processor and context length with
        ``load_pretrained_model`` using the stored args."""
        args = self.args
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            args.model_path, args.model_base, self.model_name, args.load_8bit, args.load_4bit,
            device=self.device)
        # Only checkpoints whose name contains 'llava' are treated as multimodal.
        self.is_multimodal = 'llava' in self.model_name.lower()

    def set_modules(self, model=None, tokenizer=None, image_processor=None):
        """Inject already-loaded modules.

        Arguments left as None keep the current value (or None when nothing
        has been set yet).
        """
        self.model = model if model is not None else getattr(self, 'model', None)
        self.tokenizer = tokenizer if tokenizer is not None else getattr(self, 'tokenizer', None)
        self.image_processor = (image_processor if image_processor is not None
                                else getattr(self, 'image_processor', None))

    @torch.inference_mode()
    def analyse_attention(self, start_energy, attentions):
        """Propagate *start_energy* through a sequence of attention maps.

        Args:
            start_energy: tensor indexed as ``[batch, 1, seq]``.
            attentions: iterable of per-layer attention tensors laid out as
                ``[batch, heads, from, to]``; each is averaged over heads and
                its ``from`` size must equal ``seq``.

        Returns:
            One tensor per layer, ``energy @ mean_head_attention`` for that
            step. NOTE: column 0 of each appended tensor is zeroed in place
            right after appending, so every returned tensor has ``[..., 0] == 0``.

        Raises:
            ValueError: if a layer's ``from`` size does not match ``seq``.
        """
        seq_len = start_energy.shape[2]
        energies = []
        energy = start_energy.to(dtype=torch.float32)
        for layer_idx, layer_attention in enumerate(attentions):
            # [batch, heads, from, to] -> [batch, from, to]
            layer_attention = layer_attention.to(dtype=energy.dtype).mean(1)
            if layer_attention.shape[1] != seq_len:
                raise ValueError(
                    f'attention #{layer_idx} has {layer_attention.shape[1]} query positions, '
                    f'expected {seq_len} to match start_energy')
            energy = energy @ layer_attention
            energies.append(energy)
            # Drop the mass on position 0 (in place, so the tensor just appended
            # is affected as well) and renormalise each row to sum to one.
            # keepdim=True keeps the per-row [batch, 1, 1] sum; without it a
            # batch > 1 broadcasts to [batch, batch, seq].
            energy[:, :, 0] = 0
            energy = energy / energy.sum(2, keepdim=True)
        return energies

    @torch.inference_mode()
    def run4attention(self, input_ids, images):
        """Run the LM and its vision tower with ``output_attentions=True``.

        Args:
            input_ids: token ids (with the image placeholder) for the model.
            images: images accepted by ``process_images``.

        Returns:
            ``(lm_attentions, vision_attentions)`` from the two forward passes.
        """
        model = self.model
        images = process_images(images, self.image_processor, model.config)
        if isinstance(images, list):
            images = [image.to(model.device, dtype=torch.float16) for image in images]
        else:
            images = images.to(model.device, dtype=torch.float16)

        lm_model_out = model(
            input_ids=input_ids,
            images=images,
            output_attentions=True,
        )
        vision_tower = model.get_vision_tower().vision_tower
        vision_model_out = vision_tower(images, output_attentions=True)
        return lm_model_out.attentions, vision_model_out.attentions

    @torch.inference_mode()
    def generate(self, query, images, **params):
        """Generate an answer to the user turn of *query* (single round).

        Args:
            query: ``(user_input, assistant_output, target_text)``; only
                ``user_input`` is used here.
            images: list of images/URLs accepted by ``load_image``; the first
                is attached to the user turn. None/empty means text only.
            **params: ``temperature`` (default 1.0), ``top_p`` (1.0),
                ``max_new_tokens`` (256, capped at 1024), ``stop`` ("</s>").

        Returns:
            Whatever ``model.generate`` returns.

        Raises:
            ValueError: if the number of images differs from the number of
                ``<image>`` tokens in the prompt.
        """
        tokenizer, model, image_processor = self.tokenizer, self.model, self.image_processor
        has_images = images is not None and len(images) > 0

        # using conversation_lib to preprocess the text input
        conv = conversation_lib.conv_llava_v1.copy()
        user_input, _assistant_output, _target_text = query
        if has_images:
            conv.append_message(conv.roles[0], (user_input, images[0]))
        else:
            # Text-only query: there is no images[0] to attach.
            conv.append_message(conv.roles[0], user_input)
        prompt = conv.get_prompt()

        num_image_tokens = 0
        image_args = {}
        if has_images and self.is_multimodal:
            if len(images) != prompt.count(DEFAULT_IMAGE_TOKEN):
                raise ValueError("Number of images does not match number of <image> tokens in prompt")

            images = [load_image(image) for image in images]
            images = process_images(images, image_processor, model.config)
            if isinstance(images, list):
                images = [image.to(model.device, dtype=torch.float16) for image in images]
            else:
                images = images.to(model.device, dtype=torch.float16)

            replace_token = DEFAULT_IMAGE_TOKEN
            if getattr(model.config, 'mm_use_im_start_end', False):
                replace_token = DEFAULT_IM_START_TOKEN + replace_token + DEFAULT_IM_END_TOKEN
            prompt = prompt.replace(DEFAULT_IMAGE_TOKEN, replace_token)

            # Image embeddings also consume context length (subtracted below).
            num_image_tokens = prompt.count(replace_token) * model.get_vision_tower().num_patches
            image_args = {"images": images}

        temperature = float(params.get("temperature", 1.0))
        top_p = float(params.get("top_p", 1.0))
        max_context_length = getattr(model.config, 'max_position_embeddings', 2048)
        max_new_tokens = min(int(params.get("max_new_tokens", 256)), 1024)
        stop_str = params.get("stop", "</s>")
        # A (near-)zero temperature means greedy decoding.
        do_sample = temperature > 0.001

        input_ids = tokenizer_image_token(
            prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).to(self.device)
        stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)
        max_new_tokens = min(max_new_tokens, max_context_length - input_ids.shape[-1] - num_image_tokens)

        return model.generate(
            inputs=input_ids,
            do_sample=do_sample,
            temperature=temperature,
            top_p=top_p,
            max_new_tokens=max_new_tokens,
            stopping_criteria=[stopping_criteria],
            use_cache=True,
            **image_args
        )

    def construct_target_indexs(self, input_ids, response_token_ids):
        """Build index tuples into the image-expanded LM sequence.

        Inside the LM each ``IMAGE_TOKEN_INDEX`` placeholder is replaced by
        ``num_vis_tokens`` visual embeddings, so every position after a row's
        placeholder is shifted by ``num_vis_tokens - 1``.

        Args:
            input_ids: ``[batch, seq]`` ids with exactly one placeholder per row.
            response_token_ids: ``[1 or batch, n]`` ids of the target text.

        Returns:
            ``(from_indexs, to_indexs_vis, to_indexs_text)``, each a
            ``(batch_index, position)`` pair of 1-D tensors. They hold the
            positions preceding every occurrence of any response token id,
            every visual-token position (row by row), and every text position.
        """
        b, _seq_len = input_ids.shape
        vision_tower = self.model.get_vision_tower()
        num_vis_tokens = vision_tower.num_patches
        if vision_tower.select_feature == 'cls_patch':
            num_vis_tokens += 1  # one extra (presumably CLS) embedding is kept
        batch_token_index, vis_token_index = torch.where(input_ids == IMAGE_TOKEN_INDEX)
        assert torch.all(batch_token_index == torch.arange(b, device=input_ids.device)), \
            'only support each instance containing one image '
        print('vision input start location', vis_token_index)

        def _expand(batch_idx, token_idx):
            # Shift positions that lie after *their own row's* placeholder.
            placeholder = vis_token_index[batch_idx]
            return torch.where(token_idx > placeholder, token_idx + num_vis_tokens - 1, token_idx)

        # A position is selected when its id equals any response token id.
        is_response = torch.any(input_ids.unsqueeze(2) == response_token_ids.unsqueeze(1), dim=2)
        batch_indexs, token_indexs = is_response.nonzero(as_tuple=True)
        # The preceding position is where the model predicts the response token.
        from_indexs = (batch_indexs, _expand(batch_indexs, token_indexs - 1))

        # [batch, num_vis_tokens] grid of visual positions, flattened row by row
        # to line up with repeat_interleave of the batch indices.
        vis_positions = vis_token_index.unsqueeze(1) + torch.arange(
            num_vis_tokens, device=vis_token_index.device).unsqueeze(0)
        to_indexs_vis = (torch.repeat_interleave(batch_token_index, num_vis_tokens), vis_positions.flatten())

        batch_indexs, token_indexs = torch.ones_like(input_ids, dtype=torch.bool).nonzero(as_tuple=True)
        to_indexs_text = (batch_indexs, _expand(batch_indexs, token_indexs))

        return from_indexs, to_indexs_vis, to_indexs_text

    # only single round analyse
    def analyse(self, query, images, deep_layer=-1, shallow_layer=0, mode='multiple'):
        """Trace attention energy from the target tokens back to the inputs.

        Args:
            query: ``(user_input, assistant_output, target_text)``. The
                conversation ``user_input`` + ``assistant_output`` is fed to the
                model; energy starts at the positions of ``target_text`` tokens.
            images: list of loaded images; ``images[0]`` is attached to the user
                turn and the list is passed to the model.
            deep_layer, shallow_layer: LM layers used are
                ``lm_attentions[deep_layer:shallow_layer:-1]``, i.e. from deep to
                shallow with ``shallow_layer`` excluded.
            mode: ``'multiple'`` propagates layer by layer; ``'average'``
                averages the selected layers into a single step.

        Returns:
            ``(vis_energies, all_energies)`` with one entry per propagation
            step: ``[batch, num_vis_tokens]`` energy on the image positions and
            ``[batch, seq]`` energy on every position.

        Raises:
            ValueError: for an unknown *mode*, or when no input position
                matches a token of ``target_text``.
        """
        if mode not in ('multiple', 'average'):
            raise ValueError(f"mode must be 'multiple' or 'average', got {mode!r}")
        tokenizer = self.tokenizer

        # using conversation_lib to preprocess the text input
        conv = conversation_lib.conv_llava_v1.copy()
        user_input, assistant_output, target_text = query
        conv.append_message(conv.roles[0], (user_input, images[0]))
        conv.append_message(conv.roles[1], assistant_output)
        prompt = conv.get_prompt()
        print('model input:', prompt)

        response_token_ids = tokenizer(target_text, return_tensors='pt', add_special_tokens=False).input_ids
        input_ids = tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0)

        # Use this analyser's device, not the notebook-level global.
        lm_attentions, _vision_attentions = self.run4attention(input_ids.to(self.device), images)

        first_attention = lm_attentions[0]
        b, _num_heads, _from_seq_len, to_seq_len = first_attention.shape
        print(to_seq_len)

        from_indexs, to_indexs_vis, _to_indexs_text = self.construct_target_indexs(input_ids, response_token_ids)
        if from_indexs[0].numel() == 0:
            # An all-zero start energy would normalise to NaN below.
            raise ValueError(f'no token of target_text {target_text!r} occurs in the model input')

        # One-hot on the target positions, normalised per row: [b, 1, seq].
        start_energy = torch.zeros((b, to_seq_len)).to(first_attention.device, first_attention.dtype)
        start_energy[from_indexs] = 1
        start_energy = start_energy.unsqueeze(1)  # for batch process
        start_energy = start_energy / start_energy.sum(2, keepdim=True)

        assert deep_layer > shallow_layer or deep_layer + shallow_layer < 0, \
            f'deep_layer ({deep_layer}) must come after shallow_layer ({shallow_layer})'
        print(f'attention analysing in layers:deep_layer{deep_layer} - shallow_layer{shallow_layer}')

        lm_attentions = lm_attentions[deep_layer:shallow_layer:-1]
        if mode == 'average':
            lm_attentions = [torch.stack(lm_attentions).mean(0)]

        energies = self.analyse_attention(start_energy, lm_attentions)
        vis_energies, all_energies = [], []
        for energy in energies:
            energy = energy[:, 0, :]  # drop the singleton dim -> [b, seq]
            # Gather each row's own visual positions -> [b, num_vis_tokens].
            vis_energies.append(energy[to_indexs_vis].reshape(b, -1))
            all_energies.append(energy)

        return vis_energies, all_energies


# Build the analyser. The checkpoint is loaded only when the modules are not
# already in this kernel session, so re-running the cell does not reload the
# weights (previously the loading lines were commented out, which left
# `model`, `temp_tokenizer` and `image_processor` undefined on a fresh kernel).
analyser = Analyser(args)
_needed = ('model', 'tokenizer', 'image_processor', 'temp_tokenizer', 'image_token_id')
if not all(name in globals() for name in _needed):
    analyser.load_from_args()
    model, tokenizer, image_processor = analyser.model, analyser.tokenizer, analyser.image_processor
    # Tokenizer copy that knows "<image>" as a real token, so the image
    # placeholder id in generated sequences can be mapped to text for decoding.
    temp_tokenizer = copy.deepcopy(tokenizer)
    temp_tokenizer.add_tokens("<image>")
    image_token_id = temp_tokenizer.added_tokens_encoder['<image>']
analyser.set_modules(model, temp_tokenizer, image_processor)



# Analyse with Model Output 

In [None]:
# Each query is [user question, assistant answer, target text]; the image URL
# at the same position in `images` is paired with it.
queries = [
    [
        'Base on this input image, tell me who is the author of the painting?',
        'The painting is the famous Monalisa, and the author is Da Vinci',
        'Da Vinci ',
    ],
    [
        "Base on this input image, tell me where it might been shot?",
        '',
        '',
    ],
]
images = [
    'https://llava-vl.github.io/static/images/monalisa.jpg',
    'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png',
]

In [None]:
# Generate a free-form answer for each (query, image) pair and print it. The
# image placeholder id is swapped for image_token_id (the id of "<image>" in
# temp_tokenizer) so batch_decode can render it; the named constant replaces
# the former magic number -200.
for query, image in zip(queries, images):
    res = analyser.generate(query, [image])
    res = torch.where(res == IMAGE_TOKEN_INDEX, image_token_id, res)
    print(temp_tokenizer.batch_decode(res))

# Analyse LM Model

In [None]:
# temp = torch.Tensor([1,2,3]).to(dtype=torch.uint8) 
# temp = temp * 257 # over flows
# print(temp)

# Run the analysis

In [None]:
# Scratch check: indexing with a 1-D index tensor selects whole rows. Here the
# index is the (row, col) of the first nonzero entry, i.e. rows 0 and 1.
temp = torch.arange(12).reshape(3, 4)
print(temp.index_select(0, temp.nonzero()[0]))

In [None]:
# Scratch check: per-row minimum (values and indices), keeping the reduced dim.
torch.min(temp, dim=1, keepdim=True)

In [None]:
from matplotlib import colormaps

def convert_pilgray2pyplotpil(img_pil):
    """Colourise a single-channel PIL image with matplotlib's 'viridis' map.

    Args:
        img_pil: grayscale PIL image (e.g. mode 'L').

    Returns:
        RGB PIL image of the same size; the colormap's alpha channel is dropped.
    """
    # Local import keeps this helper independent of the later `plt` import.
    from matplotlib import colormaps
    cmap = colormaps['viridis']
    rgba = cmap(np.array(img_pil))
    # A uint8 (H, W, 3) array is inferred as RGB; Pillow deprecates `mode=`.
    return Image.fromarray((rgba[:, :, :3] * 255).astype(np.uint8))
    
from matplotlib import pyplot as plt
def visualize_spatial_energy(energy, shape=None):
    """Render each row of *energy* as a colourised 2-D heat map.

    Args:
        energy: ``[batch, seq_len]`` tensor; values outside ``[0, 1]`` are
            clipped before conversion.
        shape: ``(height, width)`` of each map; defaults to a square of side
            ``isqrt(seq_len)``.

    Returns:
        List with one RGB PIL image per row (viridis colouring).

    Raises:
        ValueError: if *shape* does not cover exactly ``seq_len`` cells.
    """
    b, seq_len = energy.shape
    if shape is None:
        side = math.isqrt(seq_len)
        shape = (side, side)

    if math.prod(shape) != seq_len:
        raise ValueError(
            f'shape {tuple(shape)} has {math.prod(shape)} cells, '
            f'but energy has {seq_len} values per row')

    energy = energy.reshape(b, *shape)
    return [convert_pilgray2pyplotpil(to_pil_image(e.clip(0, 1))) for e in energy]


def visualize_1d_energy(energy):
    """Render each row of a ``[batch, seq_len]`` *energy* tensor as a
    one-pixel-high colourised strip (values clipped to ``[0, 1]``)."""
    rows, width = energy.shape
    strips = []
    for row in energy.reshape(rows, 1, width):
        gray = to_pil_image(row.clip(0, 1))
        strips.append(convert_pilgray2pyplotpil(gray))
    return strips

def replace_batch_zero_with_nonzero_min(energy, nonzero_min):
    """Replace exact zeros in *energy* with *nonzero_min*.

    Args:
        energy: tensor, typically ``[batch, seq_len]``.
        nonzero_min: replacement value(s) broadcastable against *energy*, e.g.
            the per-row ``[batch, 1]`` minimum of the nonzero entries.

    Returns:
        A new tensor; *energy* itself is not modified.
    """
    # A single where() suffices. The former two-step version routed zeros
    # through finfo(dtype).max, which also overwrote entries genuinely equal
    # to that value and raised for integer dtypes.
    return torch.where(energy == 0, nonzero_min, energy)
    
def _normalize_energy_pair(energy, vis_energy):
    """Log-scale one step's energies and map them to ``[0, 1]``.

    Zeros are replaced by the row's smallest nonzero value of *energy* (so the
    log is finite); both tensors are then scaled with the per-row
    mean +/- 3*std window of ``log(energy)`` and clipped to ``[0, 1]``.
    """
    finfo = torch.finfo(energy.dtype)
    nonzero_min = torch.where(energy == 0, finfo.max, energy).min(1, keepdim=True)[0]
    energy = replace_batch_zero_with_nonzero_min(energy, nonzero_min).log()
    vis_energy = replace_batch_zero_with_nonzero_min(vis_energy, nonzero_min).log()
    mean, std = energy.mean(1, keepdim=True), energy.std(1, keepdim=True)
    min_, max_ = mean - 3 * std, mean + 3 * std
    # A constant row has std == 0; avoid 0/0 = NaN.
    span = (max_ - min_).clamp_min(finfo.eps)
    energy = ((energy - min_) / span).clip(0, 1)
    vis_energy = ((vis_energy - min_) / span).clip(0, 1)
    return energy, vis_energy


def analyse_gap(analyser, queries, images, start, gap, mode='average'):
    """Run ``analyser.analyse`` over successive layer windows and show heat maps.

    For each (query, images) pair, the windows ``(deep, shallow) =
    (current + gap, current)`` are analysed for ``current = start - gap,
    start - 2*gap, ...`` while ``current > 0``. For each window the image-token
    energy is displayed as spatial maps and the full-sequence energy as 1-D
    strips, for batch element 0 only.

    Args:
        analyser: an ``Analyser`` with its modules set.
        queries: list of ``(user_input, assistant_output, target_text)``.
        images: list of image lists, parallel to *queries*.
        start: deepest layer of the first window.
        gap: window width in layers, also the stride between windows.
        mode: forwarded to ``analyser.analyse``.
    """
    for query, imgs in zip(queries, images):
        print('=' * 200)
        current = start
        while current - gap > 0:
            print('-' * 200)
            print(query)
            print(imgs)
            current -= gap
            deep_layer, shallow_layer = current + gap, current

            vis_energies, all_energies = analyser.analyse(query, imgs, deep_layer, shallow_layer, mode=mode)

            layers_instances_vis_pils, layers_instances_text_pils = [], []
            for vis_energy, energy in zip(vis_energies, all_energies):
                energy, vis_energy = _normalize_energy_pair(energy, vis_energy)
                layers_instances_vis_pils.append(visualize_spatial_energy(vis_energy))
                layers_instances_text_pils.append(visualize_1d_energy(energy))

            # By the naming, regroup [layer][instance] -> [instance][layer].
            instances_layers_vis_pils = transpose_list(layers_instances_vis_pils)
            instances_layers_text_pils = transpose_list(layers_instances_text_pils)

            display(make_image_grid(instances_layers_vis_pils[0], resize=128))  # index by batch
            display(make_image_grid(instances_layers_text_pils[0], cols=1, resize=(2000, 20)))

In [None]:

# One question, three answers: the model's own output, an injected
# hallucination and a made-up output. Every query gets its own loaded copy of
# the Mona Lisa image.
queries = [
    [  # Model Output
        'Base on this input image, tell me who is the author of the painting?',
        'The author of the painting is Leonardo Da Vinci.',
        'Leonardo Da Vinci',
    ],
    [  # Injected Halu
        'Base on this input image, tell me who is the author of the painting?',
        'The author of the painting is Claude Monet.',
        'Claude Monet',
    ],
    [  # Made up Output
        'Base on this input image, tell me who is the author of the painting?',
        'The painting is the famous Monalisa, and the author is Da Vinci',
        'Da Vinci ',
    ],
]
images = [[load_image('https://llava-vl.github.io/static/images/monalisa.jpg')] for _ in queries]

display(images[0][0])
analyse_gap(analyser, queries, images, start=16, gap=15, mode='multiple')

In [None]:
# Back to the two generation examples: [question, answer, target text] per
# query, with the image URL at the same index in `images`.
queries = [
    ['Base on this input image, tell me who is the author of the painting?',
     'The painting is the famous Monalisa, and the author is Da Vinci',
     'Da Vinci '],
    ["Base on this input image, tell me where it might been shot?",
     '',
     ''],
]
images = [
    'https://llava-vl.github.io/static/images/monalisa.jpg',
    'https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png',
]

In [None]:
# Single-step analysis: average the attention of LM layers 32 down to (but
# excluding) 0 and show the resulting image and sequence maps.
query = queries[0]
vis_energies, all_energies = analyser.analyse(query, [load_image(image) for image in images[:1]], 32, 0, mode='average')

layers_instances_vis_pils, layers_instances_text_pils = [], []
for vis_energy, energy in zip(vis_energies, all_energies):
    # Map the per-row mean +/- 3*std window of the full-sequence energy to
    # [0, 1]; the image energy uses the same window so both share a scale.
    mean, std = energy.mean(1, keepdim=True), energy.std(1, keepdim=True)
    min_, max_ = mean - 3 * std, mean + 3 * std
    energy = ((energy - min_) / (max_ - min_)).clip(0, 1)
    vis_energy = ((vis_energy - min_) / (max_ - min_)).clip(0, 1)
    # NOTE: the former `x * x.shape[1] * 128` rescale was removed. Applied to
    # values already in [0, 1], it pushed anything above 1/(seq_len*128) past 1,
    # and the visualisers' clip(0, 1) then turned the heat maps into near-binary
    # masks.

    layers_instances_vis_pils.append(visualize_spatial_energy(vis_energy))
    layers_instances_text_pils.append(visualize_1d_energy(energy))

instances_layers_vis_pils = transpose_list(layers_instances_vis_pils)
instances_layers_text_pils = transpose_list(layers_instances_text_pils)

display(make_image_grid(instances_layers_vis_pils[0], resize=128))  # index by batch
display(make_image_grid(instances_layers_text_pils[0], cols=1, resize=(2000, 20)))

In [None]:
# def analyse_gap(analyser, queries, images, start, gap, mode='average'):
#     for query, imgs in zip(queries, images):
#         print('=' * 200)
#         current = start
#         while current - gap > 0:
#             print('-' * 200)
#             print(query)
#             print(imgs)
#             current = current - gap
#             deep_layer = current + gap
#             shallow_layer = current
#             # analyser.analyse(query, imgs, deep_layer, shallow_layer, mode=mode)
            
#             vis_energies, all_energies = analyser.analyse(query, imgs, deep_layer, shallow_layer, mode=mode)
            
#             layers_instances_vis_pils, layers_instances_text_pils = [], []
#             for vis_energy, energy in zip(vis_energies, all_energies):
#                 mean, std = energy.mean(1, keepdim=True), energy.std(1, keepdim=True)
#                 min_, max_ = mean-3*std, mean+3*std
#                 layer_vis_pils = visualize_spatial_energy(vis_energy, min_=min_, max_=max_)
#                 layer_text_pils = visualize_1d_energy(energy, min_=min_, max_=max_)
#                 layers_instances_vis_pils.append(layer_vis_pils)
#                 layers_instances_text_pils.append(layer_text_pils)
            
#             instances_layers_vis_pils = transpose_list(layers_instances_vis_pils)
#             instances_layers_text_pils = transpose_list(layers_instances_text_pils)
            
#             display(make_image_grid(instances_layers_vis_pils[0], resize=128)) # index by batch
#             display(make_image_grid( instances_layers_text_pils[0], cols=1, resize=(2000, 20)))
# def visualize_spatial_energy(energy, min_, max_, shape=None):
#     b, seq_len = energy.shape
#     if shape is None:
#         l = math.isqrt(seq_len)
#         shape = (l, l)
        
#     if math.prod(shape) != seq_len:
#         raise ValueError('')

#     energy = (energy - min_) * 255 / (max_ - min_)
#     energy = energy.clip(0, 255)
#     energy = energy.reshape(b, *shape)
#     # energy = energy.expand(2, -1, -1).unsqueeze(1)
#     # print(energy.shape)
    
#     energy_map_pils = [to_pil_image(e) for e in energy]
    
#     return energy_map_pils

# def visualize_1d_energy(energy, min_, max_):

#     b, seq_len = energy.shape
        
#     energy = (energy - min_) * 255 / (max_ - min_)
#     energy = energy.clip(0, 255)
#     energy = energy.reshape(b, 1, seq_len)
    
#     energy_map_pils = [to_pil_image(e) for e in energy]
    
#     return energy_map_pils

# queries = [
#     ['Base on this input image, tell me who is the author of the painting?', 'The author of the painting is Leonardo Da Vinci.', 'Da Vinci'], # Model Output
#     ['Base on this input image, tell me who is the author of the painting?', 'The painting is the famous Monalisa, and the author is Da Vinci', 'Da Vinci '], # Made up Output    
#     ['Base on this input image, tell me who is the author of the painting?', 'The author of the painting is Leonardo Monnet.', 'Monnet'], # Injected Halu
    
# ]
# images = [['https://llava-vl.github.io/static/images/monalisa.jpg'],] * len(queries)
# images = [list(map(load_image, i) ) for i in images]

# # display(images[0][0])
# analyse_gap(analyser, queries, images, 32, 16, mode='average')

# More hacked Outputs

In [None]:

# Hand-crafted answers naming Da Vinci or Monnet as the author, each probed
# both with the author token and with the painting-name token as target.
queries = [
    ('what is this?', 'Monalisa', 'Monalisa'),
    ['Base on this input image, tell me who is the author of the painting?',
     'The painting is the famous Monalisa, and the author is Da Vinci', 'Da Vinci '],
    ['Base on this input image, tell me who is the author of the painting?',
     'The painting is the famous Monalisa, and the author is Da Vinci', 'Monalisa '],
    ['Base on this input image, tell me who is the author of the painting?',
     'The painting is the famous Monalisa, and the author is Monnet', 'Monnet '],
    ['Base on this input image, tell me who is the author of the painting?',
     'The painting is the famous Monalisa, and the author is Monnet', 'Monalisa '],
]
images = [[load_image('https://llava-vl.github.io/static/images/monalisa.jpg')] for _ in queries]

analyse_gap(analyser, queries, images, start=16, gap=5)

In [None]:
# Same sweep with a finer 3-layer window.
analyse_gap(analyser, queries, images, start=16, gap=3, mode='average')

In [None]:
# Show the first loaded image as the cell output.
images[0][0]

In [None]:
# res = 'In this case, the white four-door pickup truck seems parked on a street in a downtown city area. Specifically at a street corner near the edge of a red building. There is a sidewalk next to the parked truck, and a potted plant is visible nearby as well.',
# queries = [
#     [
#         'Base on this input image, tell me where it might been shot?', 
#         res,
#         'truck'
#      ],
#     ['Base on this input image, tell me where it might been shot?', 'The photo contains a car, it might been shot on a street', 'car'],
#     ['Base on this input image, tell me where it might been shot?', 'The photo contains a truck, it might been shot on a street', 'truck'],
#     ['Base on this input image, tell me where it might been shot?', 'The photo contains a bicycle, it might been shot on a street', 'bicycle'],
#     ['Base on this input image, tell me where it might been shot?', 'The photo contains a truck, it might been shot on a street', 'street'],
# ]
# images = [['https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png'],] * len(queries)
# images = [list(map(load_image, i) ) for i in images]
# # display(images[0][0])
# analyse_gap(analyser, queries, images, 16, 5, mode='multiple')

# Probe two objects named in the model's own answer about the car image.
question = 'Base on this input image, tell me where it might been shot?'
res = 'In this case, the white four-door pickup truck seems parked on a street in a downtown city area. Specifically at a street corner near the edge of a red building. There is a sidewalk next to the parked truck, and a potted plant is visible nearby as well.'
queries = [[question, res, target] for target in ('truck', 'potted plant')]  # Model Output
images = [
    [load_image('https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png')]
    for _ in queries
]

analyse_gap(analyser, queries, images, start=32, gap=5, mode='average')

# Analyse the vision tower

In [None]:
# Inspect which vision-tower features are selected; construct_target_indexs
# adds one visual token when this is 'cls_patch'.
vision_model = model.get_vision_tower()
vision_model.select_feature