In [5]:
import matplotlib.pyplot as plt
import numpy as np
import math
import pandas as pd
import seaborn as sns
import torch


In [46]:
# Calculate the cosine similarity
# X = batch_size, seq_len, heads, head_dim
def cosine_similarity(X1, X2, dim=1):
    """Compute the cosine similarity between two arrays along one axis.

    Thin wrapper around ``torch.nn.functional.cosine_similarity`` that
    converts its inputs with ``torch.as_tensor``, so NumPy arrays and
    existing tensors are both accepted.

    Args:
        X1: First input (NumPy array or tensor).
        X2: Second input; must be broadcastable against ``X1``.
        dim: Integer axis along which the similarity is computed.
            Defaults to 1.

    Returns:
        A ``torch.Tensor`` holding the similarities, with axis ``dim``
        reduced away.
    """
    tensor1 = torch.as_tensor(X1)
    tensor2 = torch.as_tensor(X2)
    # torch's cosine_similarity accepts a single integer axis only, so the
    # caller-supplied `dim` is forwarded rather than a hard-coded tuple of axes.
    return torch.nn.functional.cosine_similarity(tensor1, tensor2, dim=dim)

# Plot similarity
def plot_similarity_matrix(similarity_matrix):
    """Render a similarity matrix as a heatmap with a colour bar.

    Args:
        similarity_matrix: 2-D array-like of similarity values, passed
            straight to ``imshow``.
    """
    fig, ax = plt.subplots(figsize=(8, 6))
    # origin='lower' puts row 0 at the bottom of the image.
    heatmap = ax.imshow(similarity_matrix, cmap='viridis', origin='lower')
    fig.colorbar(heatmap, ax=ax, label='Similarity')
    ax.set_title('Cosine Similarity Matrix')
    plt.show()

def plot_spatial_similarity(m, w, h, similarity_matrix, line_length):
    """Plot the similarity of every cell in a w-by-h window to the window centre.

    Positions are flat indices: window cell (column ``i``, row ``j``) maps to
    ``m + j*line_length + i``, and the centre cell is
    ``m + w//2 + (h//2)*line_length``. Each cell is coloured by and annotated
    with ``similarity_matrix[centre, position]``.

    Args:
        m: Flat index of the window's first cell (row 0, column 0).
        w: Window width in cells.
        h: Window height in cells.
        similarity_matrix: Object indexable as ``[centre, position]``.
        line_length: Number of flat positions per row of the underlying layout.
    """
    fig, ax = plt.subplots(figsize=(8, 6))

    grid = np.zeros((h, w))
    center_idx = m + w // 2 + (h // 2) * line_length
    for j, i in np.ndindex(h, w):
        flat_idx = m + j * line_length + i
        score = similarity_matrix[center_idx, flat_idx]
        # Rows are flipped so that, with origin='lower', window row 0 ends up
        # at the top of the picture.
        row = h - j - 1
        grid[row, i] = score
        ax.text(i, row, f'{score:.2f}', ha='center', va='center', color="black")

    # Colour scale is pinned to the range 0..1.
    image = ax.imshow(grid, cmap='viridis', origin='lower', vmax=1, vmin=0)

    # Tick labels are offsets relative to the centre cell.
    x_positions = np.arange(w)
    ax.set_xticks(x_positions)
    ax.set_xticklabels(x_positions - w // 2)
    ax.set_yticks(np.arange(h))
    ax.set_yticklabels(np.arange(h, 0, -1) - (h // 2 + 1))

    fig.colorbar(image, ax=ax, label='Similarity')
    plt.show()


In [11]:
def plot_attention_maps(input_data, attn_maps, idx=0):
    """Draw one attention heatmap per (layer, head) for a single batch element.

    Args:
        input_data: Optional tensor; element ``idx`` supplies the tick labels.
            When ``None``, the labels are ``0..N-1`` where ``N`` is the last
            dimension of the first layer's attention map.
        attn_maps: Sequence of per-layer tensors. Each is indexed by ``idx``
            and then by head; the head count is taken from the first layer.
        idx: Which batch element to visualise.
    """
    if input_data is None:
        labels = np.arange(attn_maps[0][idx].shape[-1])
    else:
        labels = input_data[idx].detach().cpu().numpy()
    layer_maps = [layer[idx].detach().cpu().numpy() for layer in attn_maps]

    n_layers = len(layer_maps)
    n_heads = layer_maps[0].shape[0]
    seq_len = labels.shape[0]

    # A single-head figure gets slightly larger panels.
    panel = 3 if n_heads != 1 else 4
    # squeeze=False always yields a 2-D grid of axes, so the single-layer and
    # single-head cases need no special handling.
    fig, axes = plt.subplots(
        n_layers,
        n_heads,
        figsize=(n_heads * panel, n_layers * panel),
        squeeze=False,
    )

    tick_positions = list(range(seq_len))
    tick_labels = labels.tolist()
    for layer_idx, axes_row in enumerate(axes):
        for head_idx, axis in enumerate(axes_row):
            axis.imshow(layer_maps[layer_idx][head_idx], origin='lower', vmin=0)
            axis.set_xticks(tick_positions)
            axis.set_xticklabels(tick_labels)
            axis.set_yticks(tick_positions)
            axis.set_yticklabels(tick_labels)
            axis.set_title(f"Layer {layer_idx+1}, Head {head_idx+1}")

    fig.subplots_adjust(hspace=0.5)
    plt.show()


In [None]:
# Load the saved key states. Judging by the file name these are the keys after
# RoPE has been applied; the commented-out line would load the other dump,
# key_states.npy, for comparison.
# X_rope_before = np.load('./custom-llama/key_states.npy')
X_rope = np.load('./custom-llama/key_states_rope_applied.npy')

# print(X_rope)
# Compare the loaded array with itself using the cosine_similarity helper
# defined earlier in this notebook, then print the result's shape and values.
similarity_matrix_rope = cosine_similarity(X_rope, X_rope)
print(similarity_matrix_rope.shape)
print(similarity_matrix_rope)
# Disabled visualisations of the result, kept for reference.
# sns.heatmap(similarity_matrix_rope, annot=True, cmap="YlGnBu")
# plt.title("Similarity Matrix")
# plt.show()
#plot_similarity_matrix(similarity_matrix_rope)
# NOTE(review): LINE_LENGTH is not defined in the visible cells, and
# plot_spatial_similarity indexes its matrix as [centre, position], so it needs
# a 2-D position-by-position matrix. Confirm both before re-enabling this call.
#plot_spatial_similarity(3, 11, 11, similarity_matrix_rope, LINE_LENGTH)