Notebook that plots the raw layer-0 attention and MLP weight matrices (and
biases) of a model (pythia-160m), prunes some of their rows/columns, and plots
them again.

In [None]:
! pip install seaborn -qq
! git rev-parse --short HEAD
import matplotlib.pyplot as plt
import seaborn as sns
from separability import Model
from separability.activations import prune_and_evaluate
from separability.data_classes import PruningConfig
from separability.eval import evaluate_all

In [None]:
# Pruning configuration for pythia-160m. Judging by the argument names, this
# prunes only attention (attn_frac=0.5) and leaves the feed-forward layers
# untouched (ff_frac=0) -- confirm against PruningConfig.
c = PruningConfig(
    "EleutherAI/pythia-160m",
    ff_frac=0,
    attn_frac=0.5,
    attn_scoring="abs",
    eval_sample_size=1000,
    collection_sample_size=1000,
    run_pre_test=False,
)

# Load the model named in the config; weights are kept on the CPU so they can
# be converted to NumPy arrays for plotting.
# NOTE(review): the meaning of the positional 1000 is not visible here -- see
# the Model constructor.
opt = Model(c.model_repo, 1000, model_device="cpu")

def _heatmap(tensor, ax):
    """Draw ``tensor`` on ``ax`` using the colour scale shared by all panels.

    The tensor is detached and moved to the CPU before conversion to NumPy
    (``.cpu()`` is a no-op for tensors that already live on the CPU).
    """
    sns.heatmap(
        tensor.detach().cpu().numpy(),
        vmin=-0.01, center=0, vmax=0.01, ax=ax, cbar=False,
    )


def make_plot(layer=0):
    """Plot the attention and MLP parameters of one layer of ``opt``.

    Reads the module-level ``opt`` model and shows three figures:

    1. ``attn.W_V`` / ``attn.b_V`` / ``attn.W_O`` / ``attn.b_O``, with the
       weights reshaped to ``(d_model, d_model)`` and the biases drawn as
       columns.
    2. ``mlp.W_in`` / ``mlp.b_in`` / ``mlp.W_out`` / ``mlp.b_out``.
    3. The bias (top strip) and transposed weight of the attention module's
       fused ``query_key_value`` projection.

    Args:
        layer: Index into ``opt.layers`` of the layer to plot. Defaults to 0,
            which matches the previous hard-coded behaviour.
    """
    params = opt.layers[layer]
    d = opt.cfg.d_model
    # Wide panels for weight matrices, narrow panels for bias columns.
    four_panel = {'width_ratios': [6, 1, 6, 1]}

    # --- Attention value / output projections ---
    W_V = params["attn.W_V"].reshape((d, d))
    W_O = params["attn.W_O"].reshape((d, d))
    b_V = params["attn.b_V"].reshape((d, 1))
    b_O = params["attn.b_O"].reshape((d, 1))
    _, ax = plt.subplots(1, 4, figsize=(10, 5), gridspec_kw=four_panel)
    for panel, tensor in zip(ax, (W_V, b_V, W_O, b_O)):
        _heatmap(tensor, panel)
    plt.show()

    # --- MLP input / output projections ---
    W_in, W_out = params["mlp.W_in"], params["mlp.W_out"]
    b_in, b_out = params["mlp.b_in"], params["mlp.b_out"]
    _, ax = plt.subplots(1, 4, figsize=(10, 5), gridspec_kw=four_panel)
    mlp_panels = (W_in, b_in.reshape((-1, 1)), W_out, b_out.reshape((-1, 1)))
    for panel, tensor in zip(ax, mlp_panels):
        _heatmap(tensor, panel)
    plt.show()

    # --- Fused query/key/value projection ---
    # plt.subplots already creates its own figure; calling plt.figure first
    # (as the code used to) only left an extra empty figure behind.
    qkv = params["attn"].query_key_value
    _, ax = plt.subplots(2, 1, figsize=(10, 5), gridspec_kw={'height_ratios': [1, 6]})
    _heatmap(qkv.bias.unsqueeze(dim=0), ax[0])
    _heatmap(qkv.weight.T, ax[1])
    plt.show()

# Baseline: evaluate the model on the config's datasets and plot its layer-0
# parameters before anything is pruned.
# NOTE(review): the positional 1000 is presumably the evaluation sample size
# (it matches eval_sample_size in the config) -- confirm against evaluate_all.
print("Before pruning:")
evaluate_all(opt, 1000, datasets=c.datasets)
make_plot()


# Prune the model according to config `c`. The re-plot below only shows a
# difference if this call modifies `opt` in place -- presumably it does;
# verify against prune_and_evaluate.
prune_and_evaluate(opt, c)

# Plot the same parameters again for comparison with the baseline figures.
print("After pruning:")
make_plot()