# Interpretability

In [None]:
def analyze_attn_matrix(x, y, model, layer, head, sample, task):
    """Cluster and visualise one attention head for one batch element.

    The attention matrix ``attention[layer][sample][head]`` is split into
    its even-indexed rows ("x queries") and odd-indexed rows ("y queries").
    Each group of rows is clustered with KMeans, choosing the number of
    clusters (2..9) by best silhouette score.  For each group a linear SVM
    is fitted from ``x[sample]`` to the cluster labels and per-cluster
    diagnostics are printed.  Finally a 3x4 grid of panels is drawn showing
    the four (query, key) attention sub-blocks alongside the inputs and the
    cluster assignments.

    Args:
        x: Input tensor indexed as ``x[sample]``; presumably
            (batch, n_points, dim) -- TODO confirm against the caller.
        y: Target tensor indexed as ``y[sample]``; presumably
            (batch, n_points) -- TODO confirm against the caller.
        model: Model forwarded to ``get_attention`` and
            ``print_cluster_pathways``.
        layer: Index into the list of per-layer attention tensors.
        head: Attention head index.
        sample: Batch element to analyse.
        task: Forwarded unchanged to ``print_cluster_pathways``.

    Returns:
        None.  Output is printed text and two matplotlib figures.

    Note:
        scikit-learn's ``silhouette_score`` requires the number of labels to
        be at most ``n_rows - 1``; since k is scanned up to 9, each query
        group needs at least 10 rows or the scan raises ``ValueError``.
    """
    attention, _ = get_attention(x, y, model)
    attn_mat = attention[layer].detach().cpu()[sample][head]

    # Even rows are treated as x-token queries, odd rows as y-token queries.
    attn_x = attn_mat[::2]
    attn_y = attn_mat[1::2]

    k_vals = range(2, 10)
    kmeans_x, scores_x = _fit_best_kmeans(attn_x, k_vals)
    kmeans_y, scores_y = _fit_best_kmeans(attn_y, k_vals)

    # Silhouette score as a function of k, for visual sanity-checking.
    plt.plot(k_vals, scores_x, label="scores x")
    plt.plot(k_vals, scores_y, label="scores y")
    plt.legend()
    plt.show()

    _report_clusters("x query cluster:", kmeans_x, x, y, model, task, sample)
    _report_clusters("y query cluster:", kmeans_y, x, y, model, task, sample)

    fig = plt.figure(figsize=(15, 8))
    gs = gridspec.GridSpec(
        3, 4,
        width_ratios=[0.25, 0.25, 1, 1],
        wspace=0.04, hspace=0.04,
        top=0.95, bottom=0.05, left=0.17, right=0.845,
    )

    x_sample = x.cpu()[sample]
    y_row = y.cpu()[None, sample]  # leading axis added so it can be tiled
    attn_style = {"vmin": 0, "vmax": 1, "cmap": "Purples"}

    # Attention sub-blocks: rows are queries, columns are keys.
    _show_panel(gs[1, 2], attn_mat[::2, ::2], **attn_style)
    _show_panel(gs[1, 3], attn_mat[::2, 1::2], **attn_style)
    _show_panel(gs[2, 2], attn_mat[1::2, ::2],
                hide_x=False, xlabel="x key", **attn_style)
    _show_panel(gs[2, 3], attn_mat[1::2, 1::2],
                hide_x=False, xlabel="y key", **attn_style)

    # Cluster labels, repeated 10 times sideways so they render as a strip.
    _show_panel(gs[1, 1], np.tile(kmeans_x.labels_, (10, 1)).T)
    _show_panel(gs[2, 1], np.tile(kmeans_y.labels_, (10, 1)).T)

    # Inputs next to the query axis (left column) ...
    _show_panel(gs[1, 0], x_sample >= 0, hide_y=False, ylabel="x query")
    _show_panel(gs[2, 0], y_row.tile(20, 1).T,
                hide_y=False, ylabel="y query", cmap="bwr")

    # ... and above the key axis (top row).
    _show_panel(gs[0, 2], x_sample.transpose(1, 0) >= 0)
    _show_panel(gs[0, 3], y_row.tile(20, 1), cmap="bwr")

    fig.suptitle(f"layer {layer}, head {head}, round {sample}")

    plt.show()


def _fit_best_kmeans(rows, k_vals):
    """Fit KMeans for each k in k_vals; return (best model, all scores).

    "Best" is the first k with the highest silhouette score.  The fitted
    model from the scan is reused; refitting with the same fixed seed would
    only repeat the same computation.
    """
    models = [
        KMeans(n_clusters=k, random_state=0, n_init="auto").fit(rows)
        for k in k_vals
    ]
    scores = [silhouette_score(rows, fitted.labels_) for fitted in models]
    return models[int(np.argmax(scores))], scores


def _report_clusters(title, kmeans, x, y, model, task, sample):
    """Print SVM coefficients and per-cluster diagnostics for one clustering."""
    print(title)

    labels = kmeans.labels_
    # Fit on the analysed batch element so the features line up with the
    # attention rows that produced the labels.
    classifier = svm.LinearSVC(multi_class="ovr").fit(x.cpu()[sample], labels)

    if kmeans.n_clusters == 2:
        # A binary LinearSVC exposes a single coefficient row.
        print(f"SVM coefs: {classifier.coef_[0]}")
        for c in (0, 1):
            print(f"class{c + 1}: {np.where(labels == c)[0]}")
            print_dominant_literals(kmeans, x, sample, c)
            print_cluster_pathways(kmeans, x, y, model, task, sample, c)
        return

    for c in range(kmeans.n_clusters):
        c_labels = np.where(labels == c)[0]
        print(f"SVM coefs for class {c}: {classifier.coef_[c]}\nclass elements: {c_labels}")
        print_dominant_literals(kmeans, x, sample, c)
        print_cluster_pathways(kmeans, x, y, model, task, sample, c)


def _show_panel(cell, image, *, hide_x=True, hide_y=True,
                xlabel=None, ylabel=None, **imshow_kwargs):
    """Draw one image panel in the given GridSpec cell with the grid off."""
    ax = plt.subplot(cell)
    ax.imshow(image, **imshow_kwargs)
    # Explicit False: grid(None) would toggle the current grid state instead.
    ax.grid(False)
    if xlabel is not None:
        ax.set_xlabel(xlabel)
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    if hide_x:
        plt.setp(ax.get_xticklabels(), visible=False)
    if hide_y:
        plt.setp(ax.get_yticklabels(), visible=False)
    return ax
