In [1]:
import gradio as gr
import cv2
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
import matplotlib.pyplot as plt
import numpy as np
# Architecture constants used by the attribution code below.
# NOTE(review): presumably these mirror the language model of the loaded LLaVA
# checkpoint (decoder layers, attention heads, per-head width) — confirm
# against the checkpoint's config rather than trusting the literals.
LAYER_NUM = 32
HEAD_NUM = 32
HEAD_DIM = 128
# Width of the concatenated attention heads (HEAD_NUM * HEAD_DIM = 4096).
HIDDEN_DIM = HEAD_NUM * HEAD_DIM

In [2]:
import requests
from PIL import Image
from transformers.image_transforms import (
    convert_to_rgb,
    get_resize_output_image_size,
    resize,
    center_crop)
from transformers.image_utils import (
    infer_channel_dimension_format,
    to_numpy_array)
def normalize(vector):
    """Min-max scale ``vector`` and rescale it so the entries sum to 1.

    Each value is first mapped to ``(x - min) / (max - min)``, then every
    scaled value is divided by the sum of the scaled values.

    Args:
        vector: An iterable of numbers.

    Returns:
        A list of floats with the same length as ``vector``. An empty input
        yields an empty list. When every value is identical (so min-max
        scaling is undefined), a uniform distribution ``1 / len(vector)`` is
        returned instead of dividing by zero.
    """
    values = list(vector)
    if not values:
        return []

    min_value = min(values)
    span = max(values) - min_value
    if span == 0:
        # All values are equal: no value stands out, so weight them equally.
        return [1.0 / len(values)] * len(values)

    scaled = [(x - min_value) / span for x in values]
    # Hoisted out of the comprehension; the maximum scales to 1, so total > 0.
    total = sum(scaled)
    return [x / total for x in scaled]
def get_bsvalues(vector, model, final_var):
    """Project a hidden vector to vocabulary scores through the final norm and LM head.

    The vector is scaled by ``rsqrt(final_var + 1e-6)``, multiplied by the
    weight of ``model.language_model.model.norm``, and passed through
    ``model.language_model.lm_head``.

    Args:
        vector: Hidden-state tensor whose last dimension matches the norm weight.
        model: Object exposing ``language_model.model.norm.weight`` and
            ``language_model.lm_head``.
        final_var: Variance term used for the scaling; broadcast against ``vector``.

    Returns:
        The detached (``.data``) output of the LM head.
    """
    language_model = model.language_model
    inv_rms = torch.rsqrt(final_var + 1e-6)
    normed = (vector * inv_rms) * language_model.model.norm.weight.data
    return language_model.lm_head(normed).data
def get_prob(vector):
    """Return the softmax of ``vector`` taken over its last dimension."""
    return torch.softmax(vector, dim=-1)

In [3]:
# Local checkpoint directory, relative to the notebook's working directory.
# NOTE(review): vqa_predict1 reads per-layer internals from outputs[2][layer][k];
# presumably this requires a patched model implementation that exposes them —
# confirm which transformers build this checkpoint is meant to run with.
model_id = "../scratch/save_models/llava-7b"
processor = AutoProcessor.from_pretrained(model_id)
model = LlavaForConditionalGeneration.from_pretrained(model_id)
# Switch to evaluation mode; this notebook only runs inference.
model.eval()
# Move the weights to the GPU. As the last expression of the cell, the returned
# module is also what the cell displays.
model.cuda()

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/6 [00:00<?, ?it/s]

LlavaForConditionalGeneration(
  (vision_tower): CLIPVisionModel(
    (vision_model): CLIPVisionTransformer(
      (embeddings): CLIPVisionEmbeddings(
        (patch_embedding): Conv2d(3, 1024, kernel_size=(14, 14), stride=(14, 14), bias=False)
        (position_embedding): Embedding(577, 1024)
      )
      (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
      (encoder): CLIPEncoder(
        (layers): ModuleList(
          (0-23): 24 x CLIPEncoderLayer(
            (self_attn): CLIPAttention(
              (k_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (v_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (q_proj): Linear(in_features=1024, out_features=1024, bias=True)
              (out_proj): Linear(in_features=1024, out_features=1024, bias=True)
            )
            (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (mlp): CLIPMLP(
              (activation_fn): Quick

In [4]:
# Layout of the image patches inside the model's input sequence.
# NOTE(review): the start offset 5 presumably corresponds to prompts of the form
# "USER: <image>\n..." — confirm before using a different prompt template.
_IMAGE_TOKEN_START = 5
_PATCH_GRID = 24
_NUM_IMAGE_PATCHES = _PATCH_GRID * _PATCH_GRID  # 576
_IMAGE_TOKEN_END = _IMAGE_TOKEN_START + _NUM_IMAGE_PATCHES  # 581
_DISPLAY_SIZE = 336


def _display_crop(image):
    """Resize the shortest edge to 336 and center-crop to 336x336 for display."""
    image_numpy = to_numpy_array(convert_to_rgb(image))
    data_format = infer_channel_dimension_format(image_numpy)
    output_size = get_resize_output_image_size(
        image_numpy, size=_DISPLAY_SIZE, default_to_square=False,
        input_data_format=data_format)
    # resample=3 is the integer code the original notebook passed to `resize`.
    resized = resize(image_numpy, output_size, resample=3, input_data_format=data_format)
    return center_crop(resized, size=(_DISPLAY_SIZE, _DISPLAY_SIZE),
                       input_data_format=data_format)


def _head_subvalues(outputs, layer):
    """Return (layer input at the last position, per-position per-head outputs) for one layer."""
    layer_input_last = outputs[2][layer][0][0][-1]
    v_heads = outputs[2][layer][5][0]
    # Split the output projection by head so each head can be projected separately.
    o_split = model.language_model.model.layers[layer].self_attn.o_proj.weight.data.T.view(
        HEAD_NUM, HEAD_DIM, -1)
    # bmm is batched over heads; permute puts the position axis first.
    subvalues = torch.bmm(v_heads, o_split).permute(1, 0, 2)
    return layer_input_last, subvalues


def _log_prob(vectors, predict_index, final_var):
    """Log-probability of `predict_index` after projecting `vectors` to the vocabulary."""
    return torch.log(get_prob(get_bsvalues(vectors, model, final_var))[..., predict_index])


def _render_heatmaps(demo_img, inc_map, att_map):
    """Draw the three-panel figure and return it as an in-memory PIL image."""
    import io

    panels = [
        ("input image", None),
        ("log probability increase", inc_map),
        ("avg attention score", att_map),
    ]
    fig, axes = plt.subplots(1, 3, figsize=(60, 30))
    try:
        for ax, (title, overlay) in zip(axes, panels):
            ax.imshow(demo_img)
            if overlay is not None:
                ax.imshow(overlay, alpha=0.8, cmap="gray")
            ax.axis("off")
            ax.set_title(title, fontsize=40)
        buffer = io.BytesIO()
        fig.savefig(buffer, format="png")
    finally:
        # Always release the figure, even if drawing fails.
        plt.close(fig)
    buffer.seek(0)
    rendered = Image.open(buffer)
    rendered.load()  # decode now so the image does not depend on lazy reads
    return rendered


def vqa_predict1(image, prompt):
    """Predict the next token for an image+prompt and visualise which patches matter.

    Steps:
      1. Run the model once and take the most probable next token.
      2. For every layer and head, measure how much adding that head's output
         to the layer input raises the log-probability of the predicted token.
      3. For the highest-scoring head, compute the same increase per input
         position and keep the 576 image-patch positions.
      4. Average the last-position attention over all layers and heads for the
         same 576 positions.
      5. Render both 24x24 maps over the center-cropped input image.

    NOTE(review): the indices into ``outputs[2][layer][k]`` (k = 0, 4, 5, 7) are
    taken from the original code; presumably they are the layer input, the
    layer output, the per-head weighted values and the attention weights
    exposed by a patched model — confirm against the model implementation.

    The forward pass and analysis run under ``torch.no_grad()``. The figure is
    rendered in memory; no ``tmp.png`` / ``tmp1.png`` files are written, so
    concurrent requests cannot overwrite each other's images.

    Args:
        image: The input image as delivered by the Gradio "image" component.
        prompt: The text prompt; it must contain the image placeholder.

    Returns:
        A tuple ``(prediction, top_heads, figure)``: the decoded predicted
        token, a string listing the ten heads with the largest log-probability
        increase as ``("layer_head", increase)`` pairs, and a PIL image of the
        three-panel visualisation.

    Raises:
        gr.Error: If no image was provided, or if the sequence does not
            contain 576 positions in the expected image-patch range.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    demo_img = _display_crop(image)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        inputs = processor(text=prompt, images=image, return_tensors="pt")
        inputs.to(model.device)
        outputs = model(**inputs)

        outputs_probs = get_prob(outputs["logits"][0][-1])
        predict_index = torch.argmax(outputs_probs)
        final_var = outputs[2][LAYER_NUM - 1][4][0][-1].pow(2).mean(-1, keepdim=True)

        # Score every (layer, head) by its log-probability increase.
        head_records = []
        for layer in range(LAYER_NUM):
            layer_input_last, subvalues = _head_subvalues(outputs, layer)
            origin_log_prob = _log_prob(layer_input_last, predict_index, final_var)
            per_head_total = torch.sum(subvalues, 0) + layer_input_last
            increases = (_log_prob(per_head_total, predict_index, final_var)
                         - origin_log_prob).tolist()
            head_records.extend(
                (layer, head, round(increase, 4))
                for head, increase in enumerate(increases))
        # Ascending sort then reverse, matching the original tie ordering.
        ranked_heads = sorted(head_records, key=lambda record: record[2])[::-1]

        # Per-position contribution of the single most important head.
        best_layer, best_head, _ = ranked_heads[0]
        layer_input_last, subvalues = _head_subvalues(outputs, best_layer)
        origin_log_prob = _log_prob(layer_input_last, predict_index, final_var)
        best_head_plus = subvalues[:, best_head, :] + layer_input_last
        position_increase = (_log_prob(best_head_plus, predict_index, final_var)
                             - origin_log_prob).tolist()
        patch_increase = position_increase[_IMAGE_TOKEN_START:_IMAGE_TOKEN_END]
        if len(patch_increase) != _NUM_IMAGE_PATCHES:
            raise gr.Error(
                "Unexpected sequence layout: expected %d image patch positions, got %d. "
                "Use a prompt of the form 'USER: <image>\\n...\\nASSISTANT: ...'."
                % (_NUM_IMAGE_PATCHES, len(patch_increase)))
        increase_scores_normalize = normalize(patch_increase)

        # Mean attention from the last position to each image patch.
        attn_scores_all = torch.zeros(_NUM_IMAGE_PATCHES, device=model.device)
        for layer in range(LAYER_NUM):
            for head in range(HEAD_NUM):
                attn_scores_all += outputs[2][layer][7][0][head][-1][
                    _IMAGE_TOKEN_START:_IMAGE_TOKEN_END]
        attn_scores_all = attn_scores_all / float(LAYER_NUM * HEAD_NUM)

        prediction = processor.decode(predict_index)

    # Upsample both 24x24 patch maps to the displayed image size.
    demo_img_h, demo_img_w = demo_img.shape[:2]
    demo_img_inc = cv2.resize(
        np.array(increase_scores_normalize).reshape((_PATCH_GRID, _PATCH_GRID)),
        dsize=(demo_img_w, demo_img_h),
        interpolation=cv2.INTER_CUBIC)
    demo_img_att_avg = cv2.resize(
        np.array(attn_scores_all.tolist()).reshape((_PATCH_GRID, _PATCH_GRID)),
        dsize=(demo_img_w, demo_img_h),
        interpolation=cv2.INTER_CUBIC)

    important_image_batches = _render_heatmaps(demo_img, demo_img_inc, demo_img_att_avg)

    top_heads = "important heads: " + str(
        [(str(layer) + "_" + str(head), increase)
         for layer, head, increase in ranked_heads[:10]])

    return prediction, top_heads, important_image_batches

In [5]:
# Test example: download http://images.cocodataset.org/val2017/000000219578.jpg and upload it in the UI.
# Use a prompt shaped like this one so it fits LLaVA: "USER: <image>\nWhat is the color of the dog?\nASSISTANT: The color of the dog is"
# (vqa_predict1 slices fixed positions 5:581 for the image patches, so the text before <image> should not change.)

# share=True also publishes a temporary public URL in addition to the local one.
gr.Interface(fn=vqa_predict1, inputs=["image", "text"], outputs=["text", "text", "image"]).launch(share=True)

Running on local URL:  http://127.0.0.1:7860


--------


Running on public URL: https://66ac6212a844030a18.gradio.live

This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)


