In [None]:
import requests
from PIL import Image
import matplotlib.pyplot as plt
import torch
from transformers import AutoProcessor, LlavaForConditionalGeneration
import seaborn as sns
import matplotlib.colors as Colormap
from matplotlib.colors import LogNorm


In [None]:
def visualize_attention(multihead_attention,output_path="atten_map_1.png",title=""):
    """Draw a head-averaged, spatially pooled attention map as a log-scaled heatmap.

    Args:
        multihead_attention: torch tensor indexed as
            (batch, num_heads, n_tokens, n_tokens). Only batch element 0 is
            plotted. n_tokens must be at least 20 because the map is
            average-pooled with a 20x20 window.
        output_path: Currently unused. The figure is drawn on a new pyplot
            figure but is not written to disk by this function.
        title: Title placed above the heatmap.

    Returns:
        None. The heatmap is drawn on a freshly created pyplot figure.
    """
    pool_size = 20

    # Average over the head axis, keep the first batch element and cast to
    # float32 so pooling also works for half-precision inputs.
    head_mean = torch.mean(multihead_attention, axis=1)[0].float()  # (n_tokens, n_tokens)

    # Average-pool with window == stride == pool_size so that every heatmap
    # cell summarises a pool_size x pool_size patch of the attention matrix.
    pooled = torch.nn.functional.avg_pool2d(
        head_mean.unsqueeze(0).unsqueeze(0), pool_size, stride=pool_size
    ).squeeze(0).squeeze(0)

    # Move to host memory *before* building the norm: matplotlib expects plain
    # numbers for vmin/vmax and cannot consume a (possibly CUDA) torch tensor.
    pooled = pooled.cpu().numpy()

    plt.figure(figsize=(5, 5),dpi=400)

    # Log colour scale; values below vmin are clipped to the bottom colour.
    log_norm = LogNorm(vmin=0.0007, vmax=float(pooled.max()))

    # plt.get_cmap replaces plt.cm.get_cmap, which newer matplotlib removed.
    ax = sns.heatmap(pooled,
                cmap=plt.get_cmap("viridis"),
                norm=log_norm,
                )

    # Label each cell boundary with the index of the first token it covers.
    n_rows, n_cols = pooled.shape
    ax.set_xticks(list(range(n_cols)))
    ax.set_yticks(list(range(n_rows)))
    ax.set_xticklabels([str(i * pool_size) for i in range(n_cols)])
    ax.set_yticklabels([str(i * pool_size) for i in range(n_rows)])

    # Tiny tick labels; x labels vertical, y labels horizontal.
    plt.xticks(fontsize=3, rotation=90)
    plt.yticks(fontsize=3, rotation=0)

    plt.title(title)

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import kendalltau

def analyze_vit_and_llm_attentions(batch_id, vit_attention, llm_attention, vit_to_llm_mapping):
    """Plot the Kendall tau between per-layer ViT and LLM image-token attention.

    For one batch element, every ViT layer and every LLM layer is reduced to a
    vector with one attention score per image token, and the Kendall rank
    correlation between each (ViT layer, LLM layer) pair is shown as a heatmap.

    Args:
        batch_id: Index of the batch element to analyse.
        vit_attention: Sequence with one tensor per ViT layer, indexed as
            (batch, num_heads, num_patches, num_patches).
        llm_attention: Sequence with one tensor per LLM layer, indexed as
            (batch, num_heads, num_context, num_context).
        vit_to_llm_mapping: Index tensor whose row ``batch_id`` gives, for each
            image token, its position in the LLM context.

    Returns:
        None. Figures are created with pyplot and shown via ``plt.show()``.
    """
    vit_to_llm_mapping = vit_to_llm_mapping[batch_id]
    num_tokens = vit_to_llm_mapping.shape[-1]

    # LLM: sum over heads and over the first token axis (the query axis under
    # the usual (batch, heads, query, key) layout -- TODO confirm), then keep
    # only the context positions that hold image tokens.
    # Accumulating in float32 layer by layer avoids fp16 rounding/ties and
    # avoids materialising one huge stacked [layers, heads, ctx, ctx] tensor.
    llm_scores = torch.stack([
        layer[batch_id].sum(dim=(0, 1), dtype=torch.float32)
        for layer in llm_attention
    ])  # [num_layers, num_context]
    llm_scores = llm_scores[:, vit_to_llm_mapping].cpu().numpy()  # [num_layers, num_tokens]

    # ViT: row 0 of each head restricted to the last num_tokens columns,
    # summed over heads (row 0 is presumably the CLS token -- TODO confirm).
    vit_scores = torch.stack([
        layer[batch_id][:, 0, -num_tokens:].sum(dim=0, dtype=torch.float32)
        for layer in vit_attention
    ]).cpu().numpy()  # [num_layers, num_tokens]

    vit_num_layers = vit_scores.shape[0]
    llm_num_layers = llm_scores.shape[0]

    # Debug view: ViT layer 12 laid out on the square patch grid. Skipped when
    # that layer does not exist or the token count is not a perfect square
    # (the original hard-coded a 24x24 grid).
    debug_vit_layer = 12
    grid_side = int(round(num_tokens ** 0.5))
    if debug_vit_layer < vit_num_layers and llm_num_layers > 0 and grid_side * grid_side == num_tokens:
        plt.figure(figsize=(6, 6))
        sns.heatmap(vit_scores[debug_vit_layer].reshape(grid_side, grid_side), annot=False, cmap='coolwarm', cbar=False)

    kendall_matrix = np.zeros((vit_num_layers, llm_num_layers))
    for vit_layer in range(vit_num_layers):
        for llm_layer in range(llm_num_layers):
            # kendalltau ranks its inputs itself, so it must receive the raw
            # scores. Feeding it np.argsort() output (a permutation of indices,
            # not ranks) measures a different quantity.
            kendall, _ = kendalltau(vit_scores[vit_layer], llm_scores[llm_layer])
            kendall_matrix[vit_layer, llm_layer] = kendall

    plt.figure(figsize=(8, 6))
    sns.heatmap(kendall_matrix, annot=False, cmap='coolwarm', cbar=True)
    plt.title('Kendall Tau')
    plt.xlabel('LLM Layers')
    plt.ylabel('ViT Layers')
    plt.show()

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import kendalltau

def print_vit_vs_llm_attention_similarity(batch_id, vit_attention, llm_attention, vit_to_llm_mapping):
    """Print, per ViT layer, the Kendall tau summed over all LLM layers.

    For one batch element, every ViT layer and every LLM layer is reduced to a
    vector with one attention score per image token. For each ViT layer the
    Kendall rank correlation with every LLM layer is computed and the sum over
    LLM layers is printed.

    Args:
        batch_id: Index of the batch element to analyse.
        vit_attention: Sequence with one tensor per ViT layer, indexed as
            (batch, num_heads, num_patches, num_patches).
        llm_attention: Sequence with one tensor per LLM layer, indexed as
            (batch, num_heads, num_context, num_context).
        vit_to_llm_mapping: Index tensor whose row ``batch_id`` gives, for each
            image token, its position in the LLM context.

    Returns:
        None. One line per ViT layer is printed; the second-to-last ViT layer
        is additionally drawn as a heatmap when the image tokens form a square
        grid.
    """
    vit_to_llm_mapping = vit_to_llm_mapping[batch_id]
    num_tokens = vit_to_llm_mapping.shape[-1]

    # LLM: sum over heads and over the first token axis (the query axis under
    # the usual (batch, heads, query, key) layout -- TODO confirm), then keep
    # only the context positions that hold image tokens.
    # Accumulating in float32 layer by layer avoids fp16 rounding/ties and
    # avoids materialising one huge stacked [layers, heads, ctx, ctx] tensor.
    llm_scores = torch.stack([
        layer[batch_id].sum(dim=(0, 1), dtype=torch.float32)
        for layer in llm_attention
    ])  # [num_layers, num_context]
    llm_scores = llm_scores[:, vit_to_llm_mapping].cpu().numpy()  # [num_layers, num_tokens]

    # ViT: row 0 of each head restricted to the last num_tokens columns,
    # summed over heads (row 0 is presumably the CLS token -- TODO confirm).
    vit_scores = torch.stack([
        layer[batch_id][:, 0, -num_tokens:].sum(dim=0, dtype=torch.float32)
        for layer in vit_attention
    ]).cpu().numpy()  # [num_layers, num_tokens]

    vit_num_layers = vit_scores.shape[0]
    llm_num_layers = llm_scores.shape[0]

    # The original compared the loop index with -2, which range() never
    # yields, so the heatmap was dead code. Resolve "-2" to the actual index
    # of the second-to-last layer instead.
    heatmap_layer = vit_num_layers - 2
    grid_side = int(round(num_tokens ** 0.5))
    can_reshape = grid_side * grid_side == num_tokens

    for vit_layer in range(vit_num_layers):
        if vit_layer == heatmap_layer and can_reshape:
            # Show this layer's scores on the square patch grid.
            plt.figure(figsize=(6, 6))
            sns.heatmap(vit_scores[vit_layer].reshape(grid_side, grid_side), annot=False, cmap='coolwarm', cbar=False)

        kendall_sum = 0
        for llm_layer in range(llm_num_layers):
            # kendalltau ranks its inputs itself, so it must receive the raw
            # scores. Feeding it np.argsort() output (a permutation of indices,
            # not ranks) measures a different quantity.
            kendall, _ = kendalltau(vit_scores[vit_layer], llm_scores[llm_layer])
            kendall_sum += kendall

        print(f'ViT Layer {vit_layer} Kendall Tau: {kendall_sum}')

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import kendalltau

# NOTE(review): this cell re-defines the function from the previous cell with
# the same name and signature; the later definition shadows the earlier one.
# Keep only one copy.
def print_vit_vs_llm_attention_similarity(batch_id, vit_attention, llm_attention, vit_to_llm_mapping):
    """Print, per ViT layer, the Kendall tau summed over all LLM layers.

    For one batch element, every ViT layer and every LLM layer is reduced to a
    vector with one attention score per image token. For each ViT layer the
    Kendall rank correlation with every LLM layer is computed and the sum over
    LLM layers is printed.

    Args:
        batch_id: Index of the batch element to analyse.
        vit_attention: Sequence with one tensor per ViT layer, indexed as
            (batch, num_heads, num_patches, num_patches).
        llm_attention: Sequence with one tensor per LLM layer, indexed as
            (batch, num_heads, num_context, num_context).
        vit_to_llm_mapping: Index tensor whose row ``batch_id`` gives, for each
            image token, its position in the LLM context.

    Returns:
        None. One line per ViT layer is printed; the second-to-last ViT layer
        is additionally drawn as a heatmap when the image tokens form a square
        grid.
    """
    vit_to_llm_mapping = vit_to_llm_mapping[batch_id]
    num_tokens = vit_to_llm_mapping.shape[-1]

    # LLM: sum over heads and over the first token axis (the query axis under
    # the usual (batch, heads, query, key) layout -- TODO confirm), then keep
    # only the context positions that hold image tokens.
    # Accumulating in float32 layer by layer avoids fp16 rounding/ties and
    # avoids materialising one huge stacked [layers, heads, ctx, ctx] tensor.
    llm_scores = torch.stack([
        layer[batch_id].sum(dim=(0, 1), dtype=torch.float32)
        for layer in llm_attention
    ])  # [num_layers, num_context]
    llm_scores = llm_scores[:, vit_to_llm_mapping].cpu().numpy()  # [num_layers, num_tokens]

    # ViT: row 0 of each head restricted to the last num_tokens columns,
    # summed over heads (row 0 is presumably the CLS token -- TODO confirm).
    vit_scores = torch.stack([
        layer[batch_id][:, 0, -num_tokens:].sum(dim=0, dtype=torch.float32)
        for layer in vit_attention
    ]).cpu().numpy()  # [num_layers, num_tokens]

    vit_num_layers = vit_scores.shape[0]
    llm_num_layers = llm_scores.shape[0]

    # The original compared the loop index with -2, which range() never
    # yields, so the heatmap was dead code. Resolve "-2" to the actual index
    # of the second-to-last layer instead.
    heatmap_layer = vit_num_layers - 2
    grid_side = int(round(num_tokens ** 0.5))
    can_reshape = grid_side * grid_side == num_tokens

    for vit_layer in range(vit_num_layers):
        if vit_layer == heatmap_layer and can_reshape:
            # Show this layer's scores on the square patch grid.
            plt.figure(figsize=(6, 6))
            sns.heatmap(vit_scores[vit_layer].reshape(grid_side, grid_side), annot=False, cmap='coolwarm', cbar=False)

        kendall_sum = 0
        for llm_layer in range(llm_num_layers):
            # kendalltau ranks its inputs itself, so it must receive the raw
            # scores. Feeding it np.argsort() output (a permutation of indices,
            # not ranks) measures a different quantity.
            kendall, _ = kendalltau(vit_scores[vit_layer], llm_scores[llm_layer])
            kendall_sum += kendall

        print(f'ViT Layer {vit_layer} Kendall Tau: {kendall_sum}')

In [None]:
# Hugging Face Hub identifier of the checkpoint; used below to load both the
# model and its processor.
model_id = "llava-hf/llava-1.5-7b-hf"

In [None]:
# Load the LLaVA checkpoint in half precision and move it to device index 0.
# attn_implementation="eager" is requested because later cells ask the model
# for its attention weights (output_attentions=True).
model = LlavaForConditionalGeneration.from_pretrained(
    model_id, 
    torch_dtype=torch.float16, 
    low_cpu_mem_usage=True, 
    attn_implementation="eager",
).to(0)
# Matching processor for the same checkpoint; it is called below with the
# prompt text and the image to build the model inputs.
processor = AutoProcessor.from_pretrained(model_id)

In [None]:
# Download a 400x400 sample image and keep a copy on disk as image.jpg.
# NOTE(review): this endpoint appears to serve a different random image on
# every request -- pin a fixed image URL if reproducible results are needed.
url = "https://picsum.photos/400"
# A timeout keeps the cell from hanging forever on a stalled connection.
response = requests.get(url, timeout=30)

if response.status_code == 200:
    with open('image.jpg', 'wb') as file:
        file.write(response.content)
    print("Image successfully retrieved and saved.")
else:
    print(f"Failed to retrieve image. HTTP Status code: {response.status_code}")
    # Stop here: continuing would either fail on a missing image.jpg or
    # silently analyse a stale file left over from an earlier run.
    raise RuntimeError(
        f"Could not download {url} (HTTP {response.status_code}); image.jpg was not updated."
    )
raw_image = Image.open('image.jpg')
raw_image.show()

In [None]:
# Chat-style prompt passed to the processor together with the image;
# "<image>" marks where the image goes in the text.
prompt = "USER: <image>\nDescribe the image in details\nASSISTANT:"

In [None]:
# Step 1: build model inputs from the prompt and image, on device 0 with
# floating-point tensors cast to float16.
inputs = processor(prompt, raw_image, return_tensors='pt').to(0, torch.float16)

# Step 2: greedy decoding (do_sample=False) of up to 256 new tokens.
# With return_dict_in_generate=True the result is a dict-like object, so
# despite its name `output_ids` is not a bare tensor of ids.
# NOTE(review): output_attentions/output_scores are requested here, but the
# rest of this cell only reads 'sequences' -- confirm whether they are needed,
# since they are presumably costly to keep for every generated step.
with torch.inference_mode():
    output_ids = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=256,
        use_cache=True,
        output_attentions=True,
        output_scores=True,
        return_dict_in_generate=True,
        )

# Step 3: decode the full sequence (special tokens kept) back into text.
output_text = processor.decode(output_ids['sequences'][0], skip_special_tokens=False)
print(output_text)
# Step 4: re-encode that text with the image and run ONE forward pass so that
# attention over the complete sequence is returned in a single call.
# `inputs` is rebound here; the prompt-only inputs from step 1 are discarded.
# NOTE(review): the text was decoded with skip_special_tokens=False and is
# tokenized again here -- verify that special tokens are not duplicated.
inputs = processor(output_text, raw_image, return_tensors='pt').to(0, torch.float16)
with torch.inference_mode():
    output = model(**inputs, output_attentions=True, return_dict = True)

In [None]:
# Pull the attention data out of the forward-pass result for the analysis
# functions defined above.
# NOTE(review): `vit_attentions` and `vit_to_llm_mapping` do not look like
# fields of the stock transformers LLaVA output -- presumably this notebook
# relies on a modified transformers build; confirm before running.
llm_attention = output.attentions
vit_attention = output.vit_attentions
vit_to_llm_mapping = output.vit_to_llm_mapping

# TODO: convert each of the above to numpy (not implemented in this cell)

