In [1]:
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import torch

# Default to the CPU and switch to the first CUDA device only when PyTorch
# reports that one is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Checkpoint name handed to HookedTransformer.from_pretrained.
model_name = "gemma-2b"

# Load the pretrained checkpoint onto `device` with the fold_ln,
# center_writing_weights and center_unembed processing options switched on.
model = HookedTransformer.from_pretrained(
    model_name,
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)

Loaded pretrained model gemma-2b into HookedTransformer


In [2]:
import plotly.express as px
import pandas as pd
import transformer_lens.utils as utils

def imshow(tensor, renderer=None, midpoint=0, **kwargs):
    """Show `tensor` as a Plotly Express heatmap on the "RdBu" colour scale.

    Args:
        tensor: Data to plot; converted with `utils.to_numpy` first.
        renderer: Forwarded to `Figure.show`.
        midpoint: Value passed as `color_continuous_midpoint`.
        **kwargs: Extra keyword arguments forwarded to `px.imshow`.
    """
    figure = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_scale="RdBu",
        color_continuous_midpoint=midpoint,
        **kwargs
    )
    figure.show(renderer)

def line(tensor, renderer=None, **kwargs):
    """Show `tensor` as a Plotly Express line chart.

    Args:
        tensor: Values for the y axis; converted with `utils.to_numpy` first.
        renderer: Forwarded to `Figure.show`.
        **kwargs: Extra keyword arguments forwarded to `px.line`.
    """
    y_values = utils.to_numpy(tensor)
    figure = px.line(y=y_values, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a Plotly Express scatter plot of `y` against `x`.

    Args:
        x, y: Coordinates; each is converted with `utils.to_numpy`.
        xaxis, yaxis, caxis: Display labels for the x, y and color dimensions.
        renderer: Forwarded to `Figure.show`.
        **kwargs: Extra keyword arguments forwarded to `px.scatter`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(x=x_values, y=y_values, labels=axis_labels, **kwargs)
    figure.show(renderer)


In [3]:
import sys, os
sys.path.append(os.path.abspath(".."))  # go up one level to the root

from src.advanced_path_patching import *
import json, random

def load_json(filename):
    """Read `filename` and return the parsed JSON document.

    The file is decoded as UTF-8 explicitly rather than with the platform's
    locale encoding, so non-ASCII text loads the same way on every machine.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        Whatever `json.load` produces for the document (dict, list, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    with open(filename, 'r', encoding='utf-8') as fp:
        return json.load(fp)
    
# Dataset of examples; the loop further down reads the "Sentence", "Answer"
# and "POS tag" keys of every item.
tv_data = load_json("../../pos_cf_datasets/pos_gpt_mixed_50_sentences.json")

# Coarse POS tags the experiment draws from. Each one has a worked example in
# `template` below.
tag_set = ["noun", "pronoun", "verb", "adverb", "adjective", "preposition"]
# Extended tag set (currently disabled):
#tag_set = ["noun", "pronoun", "verb", "adverb", "adjective", "preposition", "particle", "determiner", "conjunction", "number"]

# Few-shot prompt: six solved examples (one per tag in `tag_set`) followed by
# the query, which is filled in through the {sentence} and {tag} placeholders.
# The prompt ends right after "Answer:" with no trailing space.
template = """Sentence: The cat chased the mouse.
POS tag: noun
Answer: cat

Sentence: She went to the market.
POS tag: pronoun
Answer: She

Sentence: Birds fly across the sky.
POS tag: verb
Answer: fly

Sentence: He quickly closed the door.
POS tag: adverb
Answer: quickly

Sentence: The tall man opened the door.
POS tag: adjective
Answer: tall

Sentence: She walked under the bridge.
POS tag: preposition
Answer: under

Sentence: {sentence}\nPOS tag: {tag}\nAnswer:"""

In [4]:
# Experiment configuration. Everything here is consumed by the patching loop,
# mostly as arguments to path_patching.
receiver_nodes = [(17, None)]   # one (layer, head) pair with head=None at layer 17 — presumably selects that layer's MLP; TODO confirm against path_patching
component = "z"  # forwarded to path_patching unchanged
position = -1  # forwarded to path_patching unchanged
freeze_mlps = True  # forwarded to path_patching unchanged
indirect_patch = False  # forwarded to path_patching unchanged
metric="tf_loss"  # the loop also branches on "kl_divergence" and "logit_diff"
seed = 42  # passed to random.seed before examples are built
null_task = True  # selects the "null"-tag corrupted prompt for the kl_divergence / logit_diff metrics
allow_multitoken = True  # with null_task, the answer text is appended to both prompts; also passed as is_multitoken

In [5]:
# Maps fine-grained Penn Treebank tags to the coarse labels used in the
# prompts. Every value is lowercase so that it matches both the entries of the
# tag set and the lowercase "other" fallback used when a tag is missing here.
PTB_TO_COARSE = {
    # Nouns
    "NN":"noun","NNS":"noun","NNP":"noun","NNPS":"noun",
    # Verbs (modals are counted as verbs at this granularity)
    "VB":"verb","VBD":"verb","VBG":"verb","VBN":"verb","VBP":"verb","VBZ":"verb","MD":"verb",
    # Adjectives
    "JJ":"adjective","JJR":"adjective","JJS":"adjective",
    # Adverbs
    "RB":"adverb","RBR":"adverb","RBS":"adverb","WRB":"adverb",
    # Pronouns
    "PRP":"pronoun","PRP$":"pronoun","WP":"pronoun","WP$":"pronoun",
    # Prepositions / Subordinators
    "IN":"preposition","TO":"preposition",
    # Determiners
    "DT":"determiner","PDT":"determiner","WDT":"determiner",
    # Conjunctions
    "CC":"conjunction",
    # Numbers
    "CD":"number",
    # Particles
    "RP":"particle",
    # Everything else, including punctuation and brackets
    "EX":"other","FW":"other","LS":"other","SYM":"other","UH":"other","POS":"other",
    "``":"other","''":"other",",":"other",".":"other",":":"other","-LRB-":"other","-RRB-":"other"
}


In [6]:
import spacy, random
import numpy as np

# Running sum of path-patching results with one cell per (layer, head).
output = torch.zeros((model.cfg.n_layers, model.cfg.n_heads))

# Make the random corrupted-tag choices repeatable.
random.seed(seed)

# spaCy English pipeline; its `tag_` attribute supplies the fine-grained tags
# that first_occurrence maps to coarse labels.
nlp = spacy.load("en_core_web_sm")

def first_occurrence(sentence: str, coarse_tag: str, ptb_map: dict):
    """Return the text of the first token of `sentence` tagged `coarse_tag`.

    The sentence is parsed with the module-level `nlp` pipeline. Each token's
    fine-grained tag (`tok.tag_`) is translated through `ptb_map`; tags that
    are missing from the map count as "other".

    Args:
        sentence: Sentence to scan.
        coarse_tag: Coarse label to look for.
        ptb_map: Mapping from fine-grained tag to coarse label.

    Returns:
        The matching token's text, or None when no token matches.
    """
    matching_texts = (
        token.text
        for token in nlp(sentence)
        if ptb_map.get(token.tag_, "other") == coarse_tag
    )
    return next(matching_texts, None)

# Upper bound on the number of examples that are actually patched. It is kept
# at 2 to preserve the quick smoke-test behaviour this cell already had; set it
# to None to average over every usable example in tv_data.
MAX_SAMPLES = 2

# Shape of a tokenised answer that consists of exactly one token.
SINGLE_TOKEN = torch.Size([1])

num_samples = 0
for item in tv_data:

    # Clean the fields into locals instead of writing them back into tv_data,
    # so re-running this cell does not append another "." to every sentence.
    sentence = item["Sentence"].replace("\"", "") + "."
    answer = item["Answer"].replace("\"", "").lower()
    pos_tag = item["POS tag"].replace("\"", "")

    clean_prompt = template.format(sentence=sentence, tag=pos_tag)
    ans_corr = None
    ans_clean = model.to_tokens(answer, prepend_bos=False).squeeze(-1)
    corrupted_answers_list = []

    # ---- choose the corrupted tag and the corrupted answer(s) -------------
    if null_task and metric == "kl_divergence":
        # Null task: the corrupted prompt asks for the tag "null" and the
        # corrupted answer is the clean answer itself.
        random_tag = "null"
        ans_corr = model.to_tokens(answer, prepend_bos=False).squeeze(-1)

    elif null_task and metric == "logit_diff":
        # Collect one single-token distractor for every other tag that
        # actually occurs in the sentence.
        for other_tag in tag_set:
            if other_tag == pos_tag:
                continue
            distractor = first_occurrence(sentence, other_tag, PTB_TO_COARSE)
            if distractor is None:
                continue
            distractor = model.to_tokens(distractor, prepend_bos=False).squeeze(-1)
            if distractor.shape == SINGLE_TOKEN:
                corrupted_answers_list.append((other_tag, distractor))
        random_tag = "null"

    else:
        # NOTE(review): metric == "tf_loss" with null_task=True also lands
        # here, so the corrupted prompt uses a random other tag, not "null".
        # Confirm that this is intended.
        #
        # Draw random other tags until one occurs in the sentence. The draws
        # consume the RNG exactly as an unbounded retry loop would, but the
        # loop stops once every candidate has been tried, so a sentence with
        # no usable tag is skipped instead of spinning forever.
        other_tags = [t for t in tag_set if t != pos_tag]
        untried = set(other_tags)
        while ans_corr is None and untried:
            random_tag = random.choice(other_tags)
            untried.discard(random_tag)
            ans_corr = first_occurrence(sentence, random_tag, PTB_TO_COARSE)
        if ans_corr is None:
            continue
        ans_corr = model.to_tokens(ans_corr, prepend_bos=False).squeeze(-1)

    corrupted_prompt = template.format(sentence=sentence, tag=random_tag)

    # ---- tokenise the prompts, skipping examples the metric cannot use ----
    if metric == "logit_diff":
        if null_task:
            if ans_clean.shape != SINGLE_TOKEN or not corrupted_answers_list:
                continue
        elif ans_clean.shape != SINGLE_TOKEN or ans_corr.shape != SINGLE_TOKEN:
            continue
        source_toks = model.to_tokens(clean_prompt, prepend_bos=False).squeeze(-1)
        corr_toks = model.to_tokens(corrupted_prompt, prepend_bos=False).squeeze(-1)

    elif null_task and allow_multitoken:
        # The answer text is appended to both prompts.
        # NOTE(review): it is appended directly after "Answer:" with no space,
        # unlike the few-shot examples in the template; confirm that the
        # resulting tokenisation is the intended one.
        source_toks = model.to_tokens(clean_prompt + answer, prepend_bos=False).squeeze(-1)
        corr_toks = model.to_tokens(corrupted_prompt + answer, prepend_bos=False).squeeze(-1)

    else:
        if ans_clean.shape != SINGLE_TOKEN or ans_corr.shape != SINGLE_TOKEN:
            continue
        source_toks = model.to_tokens(clean_prompt, prepend_bos=False).squeeze(-1)
        corr_toks = model.to_tokens(corrupted_prompt, prepend_bos=False).squeeze(-1)

    # ---- answer tokens in the layout each metric receives -----------------
    if metric == "kl_divergence" or metric == "tf_loss":
        # Same values and shape as torch.stack([torch.tensor(ans_clean)]), but
        # without the copy-construct UserWarning torch.tensor raises when it is
        # given a tensor.
        ans_tokens = ans_clean.detach().clone().unsqueeze(0).to(device)

    elif metric == "logit_diff" and null_task:
        # One (clean, distractor) pair per collected distractor.
        ans_tokens = [
            torch.stack([torch.tensor((ans_clean, distractor))]).to(device)
            for _, distractor in corrupted_answers_list
        ]
    else:
        ans_tokens = torch.stack([torch.tensor((ans_clean, ans_corr))]).to(device)

    output+=path_patching(model, receiver_nodes, source_toks, corr_toks, ans_tokens, component, position, freeze_mlps, indirect_patch, metric=metric, is_multitoken=allow_multitoken, device=device)

    num_samples += 1
    if MAX_SAMPLES is not None and num_samples >= MAX_SAMPLES:
        break

# Dividing by zero would silently turn the whole result into NaN.
if num_samples == 0:
    raise RuntimeError("No usable example was found in tv_data; nothing to average.")

output/=num_samples
print("OUTPUT", output)

  ans_tokens = torch.stack([torch.tensor((ans_clean))]).to(device)


Output()

Output()

OUTPUT tensor([[ 8.9253e+00, -2.5550e+01,  6.1805e+01, -1.0101e+01, -2.0218e+00,
          3.6309e+02,  1.8482e+01, -2.2424e+01],
        [ 1.4490e+00, -3.6384e+00,  5.9469e+01,  3.4043e-01,  2.7038e+00,
         -1.8947e+00,  7.0553e-01,  3.7753e-01],
        [-1.0921e+00, -5.1475e+00,  2.8852e+01,  1.0091e-01,  5.8605e-01,
          7.4108e-01,  8.2020e+00, -3.6612e+00],
        [-5.3889e-01,  3.4455e+00,  2.7165e-01, -2.7986e+00, -1.9149e-01,
         -1.2316e+00,  8.9380e+00,  3.8713e+00],
        [-1.1971e+00, -6.0589e-01, -9.5178e-01,  2.2131e+00,  1.9897e+00,
         -4.9678e-01,  1.6811e+00, -9.3175e-01],
        [ 5.5858e+00, -1.5422e+00,  8.9790e-01, -1.2504e+00, -1.9105e+00,
          5.0825e+00, -3.9657e+00, -7.5077e-01],
        [-9.4810e-01, -2.8184e+00,  1.4395e+00,  1.2538e+00, -1.0688e+00,
         -4.9134e-01,  1.1479e+00,  5.0941e-01],
        [-4.2892e-02, -6.3253e-01, -1.3447e+00,  9.9329e-01, -7.6332e-01,
         -1.3452e+00,  3.0724e-01, -4.7600e-01],
        [

In [7]:
# Compact label for the receiver nodes: the non-None entries of each node are
# joined with "-", and the nodes are joined with "_".
recv_str = '_'.join(['-'.join([str(si) for si in s if si is not None]) for s in receiver_nodes])

# Base name of the results file. A trailing ".json" is removed as a suffix;
# str.strip(".json") would instead strip any of the characters ".json" from
# both ends (e.g. "results.json" -> "result").
out_name = "deneme"
if out_name.endswith(".json"):
    out_name = out_name[:-len(".json")]
out_path = f"{out_name}.npy"

print("Saving to", out_path)
# detach().cpu() keeps the conversion valid even if the tensor lives on the
# GPU or tracks gradients.
np.save(out_path, output.detach().cpu().numpy())

Saving to deneme.npy


In [8]:
import numpy as np
import plotly.express as px
from pathlib import Path

def show_path_patching_heatmap(arr_or_path, title=None, renderer=None,
                               midpoint=0.0, symmetric=True, zmax=None,
                               origin="upper", annotate=False, fmt=".2f"):
    """
    Display a (layers, heads) score matrix as a diverging "RdBu" heatmap.

    arr_or_path: np.ndarray of shape (n_layers, n_heads) or path to .npy
    title: figure title
    renderer: forwarded to fig.show
    midpoint: value that sits at the centre of the colour scale
    symmetric: if True, the colour range is made symmetric around `midpoint`,
        so equal deviations in either direction get equal colour intensity
    zmax: optional upper bound of the colour range; with symmetric=True the
        lower bound is its mirror image around `midpoint`
    origin: "upper" puts layer 0 at top; use "lower" to put it at bottom
    annotate: if True, write each cell's value into the cell
    fmt: format spec used for the annotations

    Raises AssertionError if the array is not 2-D.
    """
    A = np.load(arr_or_path) if isinstance(arr_or_path, (str, Path)) else np.asarray(arr_or_path)
    assert A.ndim == 2, "Expected (layers, heads) array"

    n_layers, n_heads = A.shape

    # Explicit zmin/zmax override color_continuous_midpoint in px.imshow, so
    # the symmetric bounds have to be centred on `midpoint` here; otherwise
    # the argument would have no effect. With midpoint=0 this gives
    # (-vmax, +vmax).
    zmin = None
    if symmetric:
        if zmax is None:
            deviation = np.abs(A - midpoint)
            # Leave the range to Plotly when there is no usable number
            # (empty or all-NaN input) instead of passing NaN bounds.
            if deviation.size and not np.all(np.isnan(deviation)):
                half_range = float(np.nanmax(deviation))
            else:
                half_range = None
        else:
            half_range = float(zmax) - midpoint
        if half_range is not None:
            zmin, zmax = midpoint - half_range, midpoint + half_range

    fig = px.imshow(
        A,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=midpoint,
        zmin=zmin, zmax=zmax,
        origin=origin,
        labels=dict(color="% path-patch score")
    )
    # One tick per head / layer so that every row and column is labelled.
    fig.update_xaxes(title="Head", tickmode="array", tickvals=list(range(n_heads)))
    fig.update_yaxes(title="Layer", tickmode="array", tickvals=list(range(n_layers)))
    # Fixed height; the width grows with the number of heads.
    fig.update_layout(title=title,
                      height=600,
                      width=n_heads * 50)

    if annotate:
        text = np.vectorize(lambda x: f"{x:{fmt}}")(A)
        fig.update_traces(text=text, texttemplate="%{text}", textfont_size=10)

    fig.show(renderer)

# Example use: plot the array written to "deneme.npy" by the save cell.
arr = np.load("deneme.npy")  # NOTE(review): the visible code saves the plain per-example average, not output*100 — confirm the "(%)" in the title still applies
# Equivalent call that passes the loaded array instead of the path:
# show_path_patching_heatmap(arr, title=f"Path patching (%) – {recv_str}", symmetric=True, origin="upper")
# Here the function loads the file itself from the path:
show_path_patching_heatmap("deneme.npy", title=f"Path patching (%) – {recv_str}")
