# Interpretability

In [50]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from data_gen import get_data, format_data
from models import TransformerModel
import yaml
from munch import Munch
import time
import einops

# Experiment config: `args.data.*` and `args.model.*` hyperparameters used by
# every cell below. (Plain string: the path has no f-string placeholders.)
with open("configs/model_selection.yaml", "r", encoding="utf-8") as yaml_file:
    args = Munch.fromYAML(yaml_file)

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Rebuild the architecture from the config, then load the trained checkpoint.
# n_dims = len(data_alphas) + 1 is the input feature dimension (3 in the run
# recorded below, matching _read_in's in_features and X's last dimension).
model = TransformerModel(
    n_dims=len(args.data.data_alphas) + 1,
    n_positions=args.data.N,
    n_layer=args.model.n_layer,
    n_head=args.model.n_head,
    n_embd=args.model.n_embd,
).to(device)
# map_location=device loads the tensors straight onto the model's device. It
# still works on CPU-only machines (device is "cpu" there), so the previous
# hard-coded CPU mapping is no longer needed.
model.load_state_dict(
    torch.load("models/model_epoch1750_time16868839754.pth", map_location=device)
)
model.eval()

TransformerModel(
  (_read_in): Linear(in_features=3, out_features=64, bias=True)
  (_backbone): GPT2Model(
    (wte): Embedding(50257, 64)
    (wpe): Embedding(21, 64)
    (drop): Dropout(p=0.0, inplace=False)
    (h): ModuleList(
      (0): GPT2Block(
        (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.0, inplace=False)
          (resid_dropout): Dropout(p=0.0, inplace=False)
        )
        (ln_2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.0, inplace=False)
        )
      )
      (1): GPT2Block(
        (ln_1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Drop

In [27]:
# Only a couple of samples are generated for interactive inspection; the
# full-size setting would be args.data.train_samp_per_class.
data_dict = get_data(
    alphas=args.data.data_alphas,
    N=args.data.N,
    d_d=args.data.d_d,
    train_samp_per_class=2,
)

alphas, X, y = format_data(
    data_dict,
    train_samples_per_alpha=2 // len(args.data.data_alphas),
)

Alphas: torch.Size([2]), X: torch.Size([2, 21, 3]), y: torch.Size([2])


### Get Weights

In [43]:
def get_flat_weights(model):
    """Collect the ``weight`` attribute of every submodule that has one.

    Args:
        model: A ``torch.nn.Module``.

    Returns:
        Dict mapping each qualified module name from ``model.named_modules()``
        to that module's ``weight``. Modules without a ``weight`` attribute
        (containers, dropout, activations, ...) are skipped.
    """
    flat_weights = {}
    for name, module in model.named_modules():
        try:
            flat_weights[name] = module.weight
        except AttributeError:
            # Only "this module has no weight" is expected here; any other
            # error (or a KeyboardInterrupt) now surfaces instead of being
            # silently swallowed by a bare `except:`.
            continue
    return flat_weights

# Map every submodule name that owns a `weight` to that weight tensor.
unformatted_flat_weights = get_flat_weights(model)
print(unformatted_flat_weights.keys())

dict_keys(['_read_in', '_backbone.wte', '_backbone.wpe', '_backbone.h.0.ln_1', '_backbone.h.0.attn.c_attn', '_backbone.h.0.attn.c_proj', '_backbone.h.0.ln_2', '_backbone.h.0.mlp.c_fc', '_backbone.h.0.mlp.c_proj', '_backbone.h.1.ln_1', '_backbone.h.1.attn.c_attn', '_backbone.h.1.attn.c_proj', '_backbone.h.1.ln_2', '_backbone.h.1.mlp.c_fc', '_backbone.h.1.mlp.c_proj', '_backbone.h.2.ln_1', '_backbone.h.2.attn.c_attn', '_backbone.h.2.attn.c_proj', '_backbone.h.2.ln_2', '_backbone.h.2.mlp.c_fc', '_backbone.h.2.mlp.c_proj', '_backbone.ln_f', '_read_out'])


In [52]:
def format_flat_weights(model, *, n_layer=None, n_head=None):
    """Re-key the GPT-2 backbone weights with TransformerLens-style names.

    Follows TransformerLens's GPT-2 weight conversion. Shapes of the split
    attention tensors:

    - ``W_Q``/``W_K``/``W_V``: ``(n_head, d_model, d_head)``, split from the
      fused ``c_attn`` projection.
    - ``b_Q``/``b_K``/``b_V``: ``(n_head, d_head)``.
    - ``W_O``: ``(n_head, d_head, d_model)``, from ``c_proj``.

    Args:
        model: Model whose ``_backbone`` exposes ``wte``, ``wpe``, ``h[l]``
            (with ``ln_1``, ``attn.c_attn``, ``attn.c_proj``, ``ln_2``,
            ``mlp.c_fc``, ``mlp.c_proj``) and ``ln_f``.
        n_layer: Number of blocks to convert. Defaults to
            ``args.model.n_layer``.
        n_head: Attention heads per block. Defaults to ``args.model.n_head``.

    Returns:
        Dict from TransformerLens-style names to tensors. Entries that are not
        split are the model's own parameters; the per-head attention tensors
        are derived from them.

    Raises:
        ValueError: If the attention width is not divisible by ``n_head``.
    """
    if n_layer is None:
        n_layer = args.model.n_layer
    if n_head is None:
        n_head = args.model.n_head

    backbone = model._backbone
    formatted_flat_weights = {}

    formatted_flat_weights["embed.W_E"] = backbone.wte.weight
    formatted_flat_weights["pos_embed.W_pos"] = backbone.wpe.weight

    for layer in range(n_layer):
        block = backbone.h[layer]
        prefix = f"blocks.{layer}"

        formatted_flat_weights[f"{prefix}.ln1.w"] = block.ln_1.weight
        formatted_flat_weights[f"{prefix}.ln1.b"] = block.ln_1.bias

        # In GPT-2, q,k,v are produced by one big linear map, whose output is
        # concat([q, k, v]). Its weight is laid out as (d_model, 3 * d_model),
        # hence the split along dim=1.
        W = block.attn.c_attn.weight
        d_model = W.shape[0]
        if d_model % n_head:
            raise ValueError(
                f"attention width {d_model} is not divisible by n_head={n_head}"
            )
        # Per-head width. This replaces the previously hard-coded `head=32`,
        # which only matched models where d_model / n_head == 32.
        d_head = d_model // n_head

        W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=1)
        formatted_flat_weights[f"{prefix}.attn.W_Q"] = einops.rearrange(
            W_Q, "m (i h)->i m h", i=n_head
        )
        formatted_flat_weights[f"{prefix}.attn.W_K"] = einops.rearrange(
            W_K, "m (i h)->i m h", i=n_head
        )
        formatted_flat_weights[f"{prefix}.attn.W_V"] = einops.rearrange(
            W_V, "m (i h)->i m h", i=n_head
        )

        qkv_bias = einops.rearrange(
            block.attn.c_attn.bias,
            "(qkv index head)->qkv index head",
            qkv=3,
            index=n_head,
            head=d_head,
        )
        formatted_flat_weights[f"{prefix}.attn.b_Q"] = qkv_bias[0]
        formatted_flat_weights[f"{prefix}.attn.b_K"] = qkv_bias[1]
        formatted_flat_weights[f"{prefix}.attn.b_V"] = qkv_bias[2]

        formatted_flat_weights[f"{prefix}.attn.W_O"] = einops.rearrange(
            block.attn.c_proj.weight, "(i h) m->i h m", i=n_head
        )
        formatted_flat_weights[f"{prefix}.attn.b_O"] = block.attn.c_proj.bias

        formatted_flat_weights[f"{prefix}.ln2.w"] = block.ln_2.weight
        formatted_flat_weights[f"{prefix}.ln2.b"] = block.ln_2.bias

        formatted_flat_weights[f"{prefix}.mlp.W_in"] = block.mlp.c_fc.weight
        formatted_flat_weights[f"{prefix}.mlp.b_in"] = block.mlp.c_fc.bias

        formatted_flat_weights[f"{prefix}.mlp.W_out"] = block.mlp.c_proj.weight
        formatted_flat_weights[f"{prefix}.mlp.b_out"] = block.mlp.c_proj.bias

    # NOTE(review): no "unembed.W_U" entry. This model has no GPT-2 lm_head;
    # its output layer appears to be model._read_out (see get_flat_weights'
    # keys). TODO decide whether/how to map it.

    formatted_flat_weights["ln_final.w"] = backbone.ln_f.weight
    formatted_flat_weights["ln_final.b"] = backbone.ln_f.bias

    return formatted_flat_weights

# Re-key the loaded weights with TransformerLens-style names.
# NOTE(review): printing the dict dumps every tensor's values; printing
# {k: tuple(v.shape) for k, v in flat_weights.items()} may be easier to scan.
flat_weights = format_flat_weights(model)
print(flat_weights)

{'embed.W_E': Parameter containing:
tensor([[-0.0043,  0.0468, -0.0046,  ...,  0.0046,  0.0077, -0.0073],
        [ 0.0088, -0.0115,  0.0034,  ...,  0.0360, -0.0267, -0.0020],
        [ 0.0029, -0.0516, -0.0153,  ...,  0.0073,  0.0196, -0.0176],
        ...,
        [-0.0033,  0.0085, -0.0081,  ..., -0.0072,  0.0006, -0.0175],
        [-0.0073,  0.0005,  0.0020,  ...,  0.0054, -0.0029, -0.0345],
        [ 0.0275,  0.0208, -0.0232,  ..., -0.0212, -0.0038, -0.0148]],
       requires_grad=True), 'pos_embed.W_pos': Parameter containing:
tensor([[ 2.4808e-05, -1.2077e-01, -2.1305e-01,  ...,  4.8523e-01,
         -2.6729e-01, -1.0821e+00],
        [-8.6075e-02, -3.4190e-02, -1.0396e-02,  ...,  5.0130e-01,
         -3.8300e-02, -5.6612e-01],
        [-7.0768e-02,  9.7565e-02, -6.6980e-02,  ...,  4.5082e-01,
         -1.9241e-01, -3.2436e-01],
        ...,
        [-4.5040e-02,  9.2468e-02, -2.5057e-02,  ..., -1.5930e-01,
         -2.4148e-01, -2.8253e-01],
        [-1.1311e-01,  6.1901e-02, -

### Get Activations

In [44]:
def rec_update_dict(keys, val, dict):
    """Store ``val`` at the nested path ``keys`` inside ``dict``, in place.

    Each path component becomes one level of nesting. The value is stored
    under a ``"val"`` key at the final level, so a node can hold both its own
    value and child nodes, e.g. ``["a", "b"]`` -> ``{"a": {"b": {"val": val}}}``.
    Missing levels (including the leaf) are created on demand.

    Args:
        keys: Non-empty sequence of path components, e.g. ``name.split(".")``.
        val: Value to store.
        dict: Nested dict to update; it is mutated in place. (The parameter
            name shadows the builtin but is kept for backward compatibility.)

    Returns:
        The same ``dict`` object, updated.

    Raises:
        ValueError: If ``keys`` is empty.
    """
    if not keys:
        raise ValueError("keys must contain at least one path component")
    tree = dict
    node = tree
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    # setdefault creates the leaf node when it does not exist yet; the old
    # version indexed it directly and raised KeyError for every new path.
    node.setdefault(keys[-1], {})["val"] = val
    return tree

def get_hier_activations(flat_activations):
    """Nest a flat ``{"a.b.c": value}`` mapping into a tree of dicts.

    Each dotted name is split on ``"."`` and written into the tree with
    ``rec_update_dict``, in the input's iteration order.
    """
    tree = {}
    for dotted_name in flat_activations:
        path = dotted_name.split(".")
        tree = rec_update_dict(path, flat_activations[dotted_name], tree)
    return tree

def get_flat_activations(model, X):
    """Run ``model(X)`` once and capture the output of every submodule.

    A forward hook is registered on each module from ``model.named_modules()``,
    including the root (name ``""``). All hooks are removed after the forward
    pass, even if it raises. Without this, every call stacked another set of
    hooks on the model, and later forward passes kept writing into (and
    holding alive) earlier result dicts.

    Args:
        model: A ``torch.nn.Module``.
        X: Input passed unchanged to ``model``.

    Returns:
        Dict mapping module name to the output that module returned. A module
        invoked several times during the pass keeps only its last output. The
        pass is not wrapped in ``torch.no_grad()``, so outputs may still carry
        autograd history.
    """
    flat_activations = {}

    def get_activation(name):
        # A factory so each hook captures its own `name` (no late binding).
        def hook(module, inputs, output):
            flat_activations[name] = output
        return hook

    handles = [
        module.register_forward_hook(get_activation(name))
        for name, module in model.named_modules()
    ]
    try:
        model(X)
    finally:
        for handle in handles:
            handle.remove()

    return flat_activations

# Record the forward output of every submodule for the batch X.
flat_activations = get_flat_activations(model, X)
# hier_activations = get_hier_activations(flat_activations)
print(flat_activations.keys())

dict_keys(['_read_in', '_backbone.wpe', '_backbone.drop', '_backbone.h.0.ln_1', '_backbone.h.0.attn.c_attn', '_backbone.h.0.attn.attn_dropout', '_backbone.h.0.attn.c_proj', '_backbone.h.0.attn.resid_dropout', '_backbone.h.0.attn', '_backbone.h.0.ln_2', '_backbone.h.0.mlp.c_fc', '_backbone.h.0.mlp.act', '_backbone.h.0.mlp.c_proj', '_backbone.h.0.mlp.dropout', '_backbone.h.0.mlp', '_backbone.h.0', '_backbone.h.1.ln_1', '_backbone.h.1.attn.c_attn', '_backbone.h.1.attn.attn_dropout', '_backbone.h.1.attn.c_proj', '_backbone.h.1.attn.resid_dropout', '_backbone.h.1.attn', '_backbone.h.1.ln_2', '_backbone.h.1.mlp.c_fc', '_backbone.h.1.mlp.act', '_backbone.h.1.mlp.c_proj', '_backbone.h.1.mlp.dropout', '_backbone.h.1.mlp', '_backbone.h.1', '_backbone.h.2.ln_1', '_backbone.h.2.attn.c_attn', '_backbone.h.2.attn.attn_dropout', '_backbone.h.2.attn.c_proj', '_backbone.h.2.attn.resid_dropout', '_backbone.h.2.attn', '_backbone.h.2.ln_2', '_backbone.h.2.mlp.c_fc', '_backbone.h.2.mlp.act', '_backbone.h.2

## Interpretability

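The weight reformatting in `format_flat_weights` follows TransformerLens's GPT-2 weight conversion; see line 996 of [`TransformerLens/transformer_lens/loading_from_pretrained.py`](https://github.com/neelnanda-io/TransformerLens/blob/main/transformer_lens/loading_from_pretrained.py#L996) (line numbers on `main` may drift over time).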

In [48]:
# def get_attention(xs, ys, model, samples=3):
#     """Returns an attention matrix of full sequence inputs evaluated at a small number of samples"""
#     pred = model(xs[:samples], ys[:samples], output_attentions=True)  # Run model
#     attention = pred[-1]  # Retrieve attention from model outputs of a model evaluated at 
#     tokens = [val for pair in zip([f"x{i}" for i in range(xs[0].shape[0])], [f"y{i}={ys[0][i].item():.1f}" for i in range(xs[0].shape[0])]) for val in pair]
#     return attention, tokens

def get_attention_matrix(flat_activations):
    """Debug placeholder: print the shapes of a few attention activations.

    The attention-pattern computation itself is not written yet; this only
    prints shapes and returns None.
    """
    # TODO: Write
    inspected_keys = (
        "_backbone.h.1.attn.c_attn",
        "_backbone.h.1.attn.c_proj",
        "_backbone.h.0.attn.resid_dropout",
    )
    for activation_key in inspected_keys:
        print(flat_activations[activation_key].shape)

# Print debug shapes of selected attention activations.
get_attention_matrix(flat_activations)

torch.Size([64, 192])
torch.Size([64, 64])
torch.Size([2, 21, 192])
torch.Size([2, 21, 64])
torch.Size([2, 21, 64])


In [14]:
# NOTE(review): get_attention only exists in commented-out form in an earlier
# cell, so in a fresh kernel this line raises NameError. The IndexError recorded
# below presumably came from an older, still-defined version.
print(get_attention(X, y, model))



IndexError: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number

In [None]:
def analyze_attn_matrix(attn_matrix, layer, head, sample, task):
    """Cluster one attention head's query rows and plot its attention pattern.

    Steps, for the chosen ``layer``/``head``/``sample``:

    1. Split the attention matrix into even-row ("x") and odd-row ("y")
       queries.
    2. Cluster each set with KMeans, choosing k in 2..9 by silhouette score.
    3. Fit a linear SVM that predicts the cluster labels from the sample's
       x inputs, and print per-cluster diagnostics.
    4. Draw a grid figure with the four x/y query-key quadrants of the
       attention matrix, the cluster labels and the inputs.

    NOTE(review): work in progress, apparently ported from another notebook.

    - ``attn_matrix`` is currently unused. The body still calls
      ``get_attention(x, y, model)``, which is only present commented out.
    - The body reads the module-level names ``x``, ``y`` and ``model``. The
      notebook defines ``y`` and ``model`` but no lowercase ``x`` (only ``X``).
    - It also needs ``plt``, ``gridspec``, ``KMeans``, ``silhouette_score``,
      ``svm``, ``print_dominant_literals`` and ``print_cluster_pathways``,
      none of which are imported in the cells shown.

    TODO: wire these up before running.
    """
    attention, _ = get_attention(x, y, model)
    attn_mat = attention[layer].detach().cpu()[sample][head]

    gs = gridspec.GridSpec(
        3, 4, width_ratios=[0.25, 0.25, 1, 1],
        wspace=0.04, hspace=0.04, top=0.95, bottom=0.05, left=0.17, right=0.845,
    )

    # Even rows are "x" queries and odd rows are "y" queries (the token
    # sequence alternates x0, y0, x1, y1, ...).
    attn_x = attn_mat[::2]
    attn_y = attn_mat[1::2]

    # Choose the number of clusters for each query type by silhouette score.
    k_vals = range(2, 10)
    scores_x = []
    scores_y = []

    for k in k_vals:
        kmeans_x = KMeans(n_clusters=k, random_state=0, n_init="auto").fit(attn_x)
        scores_x.append(silhouette_score(attn_x, kmeans_x.labels_))
        kmeans_y = KMeans(n_clusters=k, random_state=0, n_init="auto").fit(attn_y)
        scores_y.append(silhouette_score(attn_y, kmeans_y.labels_))

    plt.plot(k_vals, scores_x, label="scores x")
    plt.plot(k_vals, scores_y, label="scores y")
    plt.legend()
    plt.show()

    k_x = k_vals[np.argmax(np.asarray(scores_x))]
    k_y = k_vals[np.argmax(np.asarray(scores_y))]

    kmeans_x = KMeans(n_clusters=k_x, random_state=0, n_init="auto").fit(attn_x)
    kmeans_y = KMeans(n_clusters=k_y, random_state=0, n_init="auto").fit(attn_y)

    for query_name, kmeans, k in (("x", kmeans_x, k_x), ("y", kmeans_y, k_y)):
        print(f"{query_name} query cluster:")

        # Predict cluster membership from the analysed sample's inputs. This
        # previously used x[i], with i = 0/1 for the x/y iteration, i.e. the
        # wrong sample whenever `sample` != i.
        clf = svm.LinearSVC(multi_class="ovr").fit(x.cpu()[sample], kmeans.labels_)
        if k == 2:
            # A binary LinearSVC has a single coefficient row.
            print(f"SVM coefs: {clf.coef_[0]}")
            print(f"class1: {np.where(kmeans.labels_ == 0)[0]}")
            print_dominant_literals(kmeans, x, sample, 0)
            print_cluster_pathways(kmeans, x, y, model, task, sample, 0)
            print(f"class2: {np.where(kmeans.labels_ == 1)[0]}")
            print_dominant_literals(kmeans, x, sample, 1)
            print_cluster_pathways(kmeans, x, y, model, task, sample, 1)
        else:
            for c in range(k):
                c_labels = np.where(kmeans.labels_ == c)[0]
                print(f"SVM coefs for class {c}: {clf.coef_[c]}\nclass elements: {c_labels}")
                print_dominant_literals(kmeans, x, sample, c)
                print_cluster_pathways(kmeans, x, y, model, task, sample, c)

    # Layout of the 3x4 grid `gs`:
    #   row 0      : x inputs (col 2) and y values (col 3) above the key axis
    #   rows 1 / 2 : x / y queries -- inputs (col 0), cluster labels (col 1),
    #                attention to x keys (col 2) and to y keys (col 3)
    fig = plt.figure(figsize=(15, 8))

    ax5 = plt.subplot(gs[1, 2])
    plt.imshow(attn_mat[::2, ::2], vmin=0, vmax=1, cmap='Purples')
    plt.grid(None)
    plt.setp(ax5.get_xticklabels(), visible=False)
    plt.setp(ax5.get_yticklabels(), visible=False)

    ax6 = plt.subplot(gs[1, 3])
    plt.imshow(attn_mat[::2, 1::2], vmin=0, vmax=1, cmap='Purples')
    plt.grid(None)
    plt.setp(ax6.get_yticklabels(), visible=False)
    plt.setp(ax6.get_xticklabels(), visible=False)

    ax8 = plt.subplot(gs[2, 2])
    plt.imshow(attn_mat[1::2, ::2], vmin=0, vmax=1, cmap='Purples')
    plt.grid(None)
    plt.xlabel("x key")
    plt.setp(ax8.get_yticklabels(), visible=False)

    ax9 = plt.subplot(gs[2, 3])
    plt.imshow(attn_mat[1::2, 1::2], vmin=0, vmax=1, cmap='Purples')
    plt.grid(None)
    plt.setp(ax9.get_yticklabels(), visible=False)
    plt.xlabel("y key")

    ax45 = plt.subplot(gs[1, 1])
    plt.imshow(np.tile(kmeans_x.labels_, (10, 1)).T)
    plt.setp(ax45.get_xticklabels(), visible=False)
    plt.grid(None)
    plt.setp(ax45.get_yticklabels(), visible=False)

    ax78 = plt.subplot(gs[2, 1])
    plt.imshow(np.tile(kmeans_y.labels_, (10, 1)).T)
    plt.setp(ax78.get_xticklabels(), visible=False)
    plt.grid(None)
    plt.setp(ax78.get_yticklabels(), visible=False)

    ax4 = plt.subplot(gs[1, 0])
    plt.imshow(x.cpu()[sample] >= 0)
    plt.ylabel("x query")
    plt.grid(None)
    plt.setp(ax4.get_xticklabels(), visible=False)

    ax7 = plt.subplot(gs[2, 0])
    plt.imshow(y.cpu()[None, sample].tile(20, 1).T, cmap="bwr")
    plt.grid(None)
    plt.setp(ax7.get_xticklabels(), visible=False)
    plt.ylabel("y query")

    ax2 = plt.subplot(gs[0, 2])
    plt.imshow(x.cpu()[sample].transpose(1, 0) >= 0)
    plt.grid(None)
    plt.setp(ax2.get_xticklabels(), visible=False)
    plt.setp(ax2.get_yticklabels(), visible=False)

    ax3 = plt.subplot(gs[0, 3])
    plt.imshow(y.cpu()[None, sample].tile(20, 1), cmap="bwr")
    plt.grid(None)
    plt.setp(ax3.get_xticklabels(), visible=False)
    plt.setp(ax3.get_yticklabels(), visible=False)

    fig.suptitle(f"layer {layer}, head {head}, round {sample}")

    plt.show()

# NOTE(review): analyze_attn_matrix is declared with five parameters
# (attn_matrix, layer, head, sample, task) but is called here with six
# arguments. Also, xs_boolean, ys_boolean, layer, head and i are not defined in
# the cells shown, so this call fails as written. TODO: update it to the new
# signature.
analyze_attn_matrix(xs_boolean, ys_boolean, model, layer, head, i)
