## Attribution of Attention Heads for Type Recognition

Trying out TransformerLens with BigCode code models such as StarCoder; the notebook currently loads `bigcode/santacoder` (see `MODEL_NAME`).

In [1]:
# Automatically reload edited local modules (e.g. model_utils) before each cell executes.
%load_ext autoreload
%autoreload 2

In [2]:
# Show the GPUs on this machine and their current memory use (used to pick CUDA_VISIBLE_DEVICES).
!nvidia-smi

Thu May 11 18:18:57 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA H100 PCIe    Off  | 00000000:17:00.0 Off |                    0 |
| N/A   37C    P0    46W / 310W |      3MiB / 81559MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA H100 PCIe    Off  | 00000000:65:00.0 Off |                    0 |
| N/A   35C    P0    44W / 310W |      3MiB / 81559MiB |      0%      Default |
|       

In [3]:
import os

# CUDA_VISIBLE_DEVICES is read only once, when CUDA is first initialised (e.g. by
# torch.cuda.is_available()). It must therefore be set BEFORE torch touches CUDA;
# setting it after `torch.cuda.is_available()` silently has no effect.
# Restart the kernel after changing this value.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import sys
sys.path.append('..')  # the local `model_utils` module is imported from the parent directory

from transformers import AutoConfig
import gc
import torch
# NOTE(review): star import — check_devs() below is presumably provided by this module.
from model_utils import *
import tqdm as notebook_tqdm
torch.set_grad_enabled(False)  # inference only: do not build autograd graphs

import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import torch as t
import torch.nn as nn
import torch.nn.functional as F
pio.renderers.default = "notebook_connected" # or use "browser" if you want plots to open with browser
import numpy as np
import einops
from fancy_einsum import einsum
from torchtyping import TensorType as TT
from typing import List, Optional, Callable, Tuple, Union
import functools
from tqdm import tqdm
from IPython.display import display
from model_utils import to_hooked_config

from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

#MODEL_NAME = "Salesforce/codegen-16B-mono"
MODEL_NAME = "bigcode/santacoder"
print(torch.cuda.is_available())
print(os.environ["CUDA_VISIBLE_DEVICES"])

# NOTE: the installed PyTorch build warns that it has no sm_90 kernels, so it cannot
# run on the H100 GPUs listed by nvidia-smi; install a build with sm_90 support
# (e.g. a CUDA 11.8+ wheel).
check_devs()

  from .autonotebook import tqdm as notebook_tqdm


True
env: CUDA_VISIBLE_DEVICES=1
1
0 / 84979089408 used for device 0, reserved 0




NVIDIA H100 PCIe with CUDA capability sm_90 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_37 sm_50 sm_60 sm_70 sm_75 sm_80 sm_86.
If you want to use the NVIDIA H100 PCIe GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/




## Loading the model into TransformerLens

## Setup (helpers adapted from the TransformerLens tutorial)

In [4]:
## setup stuff
# Inference only: switch autograd off globally so no computation graphs are recorded
# (saves time and memory for everything in this notebook).
t.set_grad_enabled(mode=False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", caxis="", **kwargs):
    """Return a plotly heatmap of `tensor` using a diverging RdBu scale centred on 0.

    Args:
        tensor: array-like accepted by `utils.to_numpy` (e.g. a torch tensor).
        renderer: accepted for signature compatibility only; the figure is
            returned, not shown, so this argument is ignored.
        xaxis, yaxis, caxis: axis / colour-bar titles.
        **kwargs: forwarded to `px.imshow`. `labels`, `color_continuous_midpoint`
            and `color_continuous_scale` given here override the defaults instead
            of raising a duplicate-keyword TypeError.

    Returns:
        The plotly Figure.
    """
    # Caller-supplied labels win over the axis-title defaults.
    labels = {"x": xaxis, "y": yaxis, "color": caxis, **(kwargs.pop("labels", None) or {})}
    kwargs.setdefault("color_continuous_midpoint", 0.0)
    kwargs.setdefault("color_continuous_scale", "RdBu")
    return px.imshow(utils.to_numpy(tensor), labels=labels, **kwargs)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Return a plotly line chart of `tensor`.

    Args:
        tensor: array-like accepted by `utils.to_numpy` (e.g. a torch tensor).
        renderer: accepted for signature compatibility only; ignored.
        xaxis, yaxis: axis titles.
        **kwargs: forwarded to `px.line`; a `labels` dict given here is merged
            over the axis titles instead of raising a duplicate-keyword TypeError.

    Returns:
        The plotly Figure.
    """
    labels = {"x": xaxis, "y": yaxis, **(kwargs.pop("labels", None) or {})}
    return px.line(utils.to_numpy(tensor), labels=labels, **kwargs)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Return a plotly scatter plot of `y` against `x`.

    Args:
        x, y: array-likes accepted by `utils.to_numpy` (e.g. torch tensors).
        xaxis, yaxis, caxis: axis / colour titles.
        renderer: accepted for signature compatibility only; ignored.
        **kwargs: forwarded to `px.scatter`; a `labels` dict given here is merged
            over the axis titles instead of raising a duplicate-keyword TypeError.

    Returns:
        The plotly Figure.
    """
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    labels = {"x": xaxis, "y": yaxis, "color": caxis, **(kwargs.pop("labels", None) or {})}
    return px.scatter(y=y, x=x, labels=labels, **kwargs)

def plot_comp_scores(model: HookedTransformer, comp_scores: TT["heads", "heads"], title: str = "", baseline: Optional[t.Tensor] = None, layers: Tuple[int, int] = (0, 1)) -> go.Figure:
    """Heatmap of head-to-head composition scores between two layers.

    Args:
        model: supplies `cfg.n_heads` for the axis tick labels.
        comp_scores: (n_heads, n_heads) scores; rows are heads of the earlier
            layer, columns heads of the later layer.
        title: figure title.
        baseline: optional reference score (float or 0-d tensor). When given,
            a diverging RdBu scale centred on it is used; otherwise a
            sequential Blues scale starting at 0.
        layers: (row_layer, column_layer) used for the labels; defaults to (0, 1).

    Returns:
        The plotly Figure.
    """
    row_layer, col_layer = layers
    n_heads = model.cfg.n_heads
    has_baseline = baseline is not None
    return px.imshow(
        utils.to_numpy(comp_scores),
        y=[f"L{row_layer}H{h}" for h in range(n_heads)],
        x=[f"L{col_layer}H{h}" for h in range(n_heads)],
        labels={"x": f"Layer {col_layer}", "y": f"Layer {row_layer}"},
        title=title,
        color_continuous_scale="RdBu" if has_baseline else "Blues",
        # plotly expects a plain number for the midpoint; float() also accepts a 0-d tensor.
        color_continuous_midpoint=float(baseline) if has_baseline else None,
        zmin=None if has_baseline else 0.0,
    )

import IPython
from plotly.offline import init_notebook_mode
    
def enable_plotly_in_cell():
    """Inject require.js into the current output cell, then initialise offline plotly.

    Presumably a work-around for plotly figures not rendering in some notebook
    front-ends; call it at the top of a cell that displays a figure.
    """
    requirejs_tag = '''<script src="/static/components/requirejs/require.js"></script>'''
    display(IPython.core.display.HTML(requirejs_tag))
    init_notebook_mode(connected=False)

t.set_grad_enabled(False);  # trailing ';' keeps the returned context-manager object's repr out of the cell output

<torch.autograd.grad_mode.set_grad_enabled at 0x7f826c4cfac0>

In [5]:
def solutions_get_ablation_scores(model: HookedTransformer, tokens: TT["batch", "seq"]) -> TT["n_layers", "n_heads"]:
    """Reference solution: loss increase caused by ablating each attention head.

    Returns a (n_layers, n_heads) tensor on `model.cfg.device` whose [layer, head]
    entry is the loss with that head ablated minus the unablated loss.

    NOTE(review): `cross_entropy_loss` and `head_ablation_hook` are not defined in
    any visible cell (perhaps provided by `from model_utils import *`) — confirm,
    otherwise this raises NameError.
    NOTE(review): TransformerLens only populates the per-head `result` hook when
    `use_attn_result` is enabled on the model — confirm it is set.
    """
    n_layers, n_heads = model.cfg.n_layers, model.cfg.n_heads
    scores = t.zeros((n_layers, n_heads), device=model.cfg.device)
    clean_loss = cross_entropy_loss(model(tokens, return_type="logits"), tokens)
    for layer_idx in tqdm(range(n_layers)):
        hook_name = utils.get_act_name("result", layer_idx)
        for head_idx in range(n_heads):
            ablate_head = functools.partial(head_ablation_hook, head_index_to_ablate=head_idx)
            ablated_logits = model.run_with_hooks(tokens, fwd_hooks=[(hook_name, ablate_head)])
            scores[layer_idx, head_idx] = cross_entropy_loss(ablated_logits, tokens) - clean_loss
    return scores

def solutions_mask_scores(attn_scores: TT["query_d_model", "key_d_model"]):
    """Apply a causal mask to a matrix of attention scores.

    Entries above the main diagonal (key index > query index) are replaced by
    -1e6, a large finite negative number, so a following softmax gives them
    ~zero weight. Entries on or below the diagonal are returned unchanged.
    """
    causal = t.ones_like(attn_scores).tril().bool()
    fill_value = t.tensor(-1.0e6, device=attn_scores.device)
    return t.where(causal, attn_scores, fill_value)

def solutions_decompose_attn_scores(decomposed_q: t.Tensor, decomposed_k: t.Tensor) -> t.Tensor:
    """Raw attention scores for every (query component, key component) pair.

    Contracts over d_model only, so the component axes stay separate:
    q (q_comp, q_pos, d_model) x k (k_comp, k_pos, d_model)
    -> (q_comp, k_comp, q_pos, k_pos). No scaling or masking is applied.
    """
    contraction = "q_comp q_pos d_model, k_comp k_pos d_model -> q_comp k_comp q_pos k_pos"
    scores = einsum(contraction, decomposed_q, decomposed_k)
    return scores

def solutions_find_K_comp_full_circuit(model: HookedTransformer, prev_token_head_index: int, ind_head_index: int) -> FactoredMatrix:
    """Reference solution: full K-composition QK circuit from a layer-0 head into a layer-1 head.

    Query side: token embeddings read directly by the layer-1 head's W_Q.
    Key side:   token embeddings passed through the layer-0 head's OV circuit
                (W_V then W_O) before the layer-1 head's W_K.
    Returns FactoredMatrix(query_side, key_side.T), i.e. the token-by-token
    QK circuit kept in factored form.
    """
    embed = model.W_E
    # layer 1 (reader): the head whose keys are composed with layer-0 output
    q_weight = model.W_Q[1, ind_head_index]
    k_weight = model.W_K[1, ind_head_index]
    # layer 0 (writer): the head whose OV output feeds the keys
    v_weight = model.W_V[0, prev_token_head_index]
    o_weight = model.W_O[0, prev_token_head_index]
    query_side = embed @ q_weight
    # evaluated left to right, same association as before
    key_side = embed @ v_weight @ o_weight @ k_weight
    return FactoredMatrix(query_side, key_side.T)

def solutions_get_comp_score(W_A: TT["in_A", "out_A"], W_B: TT["out_A", "out_B"]) -> float:
    """Composition score ||W_A @ W_B||_F / (||W_A||_F * ||W_B||_F), as a Python float."""
    def frobenius(mat):
        # Frobenius norm written out explicitly (sqrt of the sum of squares).
        return mat.pow(2).sum().sqrt()

    numerator = frobenius(W_A @ W_B)
    denominator = frobenius(W_A) * frobenius(W_B)
    return (numerator / denominator).item()

def test_get_ablation_scores(ablation_scores: TT["layer", "head"], model: HookedTransformer, rep_tokens: TT["batch", "seq"]):
    """Compare user-computed ablation scores with the reference solution on `rep_tokens`."""
    reference = solutions_get_ablation_scores(model, rep_tokens)
    t.testing.assert_close(ablation_scores, reference)
    print("All tests in `test_get_ablation_scores` passed!")

def test_full_OV_circuit(OV_circuit: FactoredMatrix, model: HookedTransformer, layer: int, head: int):
    """Check a user's full OV circuit (W_E @ W_OV @ W_U) for one head.

    Only the top-left 20x20 corner is compared, via FactoredMatrix.get_corner.
    """
    head_ov = FactoredMatrix(model.W_V[layer, head], model.W_O[layer, head])
    reference = model.W_E @ head_ov @ model.W_U
    t.testing.assert_close(OV_circuit.get_corner(20), reference.get_corner(20))
    print("All tests in `test_full_OV_circuit` passed!")

def test_pos_by_pos_pattern(pattern: TT["n_ctx", "n_ctx"], model: HookedTransformer, layer: int, head: int):
    """Check a positional-embedding-only attention pattern for one head.

    The reference is softmax(causal_mask(W_pos @ W_Q @ W_K^T @ W_pos^T / sqrt(d_head))).
    Only the leading 50x50 block is compared.
    """
    W_pos = model.W_pos
    W_QK = model.W_Q[layer, head] @ model.W_K[layer, head].T
    score_expected = W_pos @ W_QK @ W_pos.T
    masked_scaled = solutions_mask_scores(score_expected / model.cfg.d_head ** 0.5)
    pattern_expected = t.softmax(masked_scaled, dim=-1)
    t.testing.assert_close(pattern[:50, :50], pattern_expected[:50, :50])
    # Message previously named `test_full_OV_circuit` (copy-paste slip).
    print("All tests in `test_pos_by_pos_pattern` passed!")

def test_decompose_attn_scores(decompose_attn_scores: Callable, q: t.Tensor, k: t.Tensor):
    """Check a user's `decompose_attn_scores` against the reference on the given q and k."""
    actual = decompose_attn_scores(q, k)
    reference = solutions_decompose_attn_scores(q, k)
    t.testing.assert_close(actual, reference)
    print("All tests in `test_decompose_attn_scores` passed!")

def test_find_K_comp_full_circuit(find_K_comp_full_circuit: Callable, model: HookedTransformer):
    """Check a user's K-composition circuit builder against the reference.

    Compares the top-left 20x20 corners. The result must be a FactoredMatrix.
    """
    # NOTE(review): head indices are hard-coded — presumably the prev-token (7) and
    # induction (4) heads of the tutorial's 2-layer model; adjust for other models.
    prev_head, ind_head = 7, 4
    actual = find_K_comp_full_circuit(model, prev_head, ind_head)
    reference = solutions_find_K_comp_full_circuit(model, prev_head, ind_head)
    assert isinstance(actual, FactoredMatrix), "Should return a FactoredMatrix object!"
    t.testing.assert_close(actual.get_corner(20), reference.get_corner(20))
    print("All tests in `test_find_K_comp_full_circuit` passed!")

def test_get_comp_score(get_comp_score: Callable):
    """Check a user's `get_comp_score` against the reference on random 3x4 and 4x5 matrices.

    The result must be a Python float within 1e-5 of the reference.
    """
    W_A, W_B = t.rand(3, 4), t.rand(4, 5)
    actual = get_comp_score(W_A, W_B)
    reference = solutions_get_comp_score(W_A, W_B)
    assert isinstance(actual, float)
    assert abs(actual - reference) < 1e-5
    print("All tests in `test_get_comp_score` passed!")

In [10]:
# Fetch the Hugging Face config for MODEL_NAME.
# NOTE(review): trust_remote_code=True executes Python code shipped with the hub repo
# (the printed config's `auto_map` points at custom classes); consider pinning
# `revision=` to an audited commit.
model_config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)
## cast config to Hooked config
# `to_hooked_config` comes from the local model_utils module; presumably it maps the HF
# fields (n_embd, n_head, n_layer, ...) onto a HookedTransformerConfig — TODO confirm.
# print(vars(HookedTransformerConfig)['__annotations__'])
cfg = to_hooked_config(model_config)
print(model_config)

# TODO: model construction is still disabled; `cfg` is not used by any visible cell yet.
# model = HookedTransformer(cfg)

GPT2CustomConfig {
  "_name_or_path": "bigcode/santacoder",
  "activation_function": "gelu_fast",
  "architectures": [
    "GPT2LMHeadCustomModel"
  ],
  "attention_head_type": "multiquery",
  "attn_pdrop": 0.1,
  "auto_map": {
    "AutoConfig": "bigcode/santacoder--configuration_gpt2_mq.GPT2CustomConfig",
    "AutoModelForCausalLM": "bigcode/santacoder--modeling_gpt2_mq.GPT2LMHeadCustomModel"
  },
  "bos_token_id": 49152,
  "embd_pdrop": 0.1,
  "eos_token_id": 49152,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_embd": 2048,
  "n_head": 16,
  "n_inner": 8192,
  "n_layer": 24,
  "n_positions": 2048,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "torch_dtype": "float32",
  "transformers_version": "4.30.0.d

In [None]:
# PROBLEM: `show` builds its input tensors internally and does not send them to GPU device 0.
 
# from_pretrained_no_processing
# NOTE(review): `lm` and `sentence_a` are not defined in any visible cell (unless they come
# from `from model_utils import *`), so on a fresh kernel this cell raises NameError.
# NOTE(review): bertviz's neuron view is documented to work only with the model classes
# bundled in bertviz.transformers_neuron_view (BERT / GPT-2 / RoBERTa); confirm it can
# handle the santacoder model before relying on it.
from bertviz.neuron_view import show

show(lm.model, lm.model_type, lm.tokenizer, sentence_a)