In [1]:
from transformer_lens import HookedTransformer
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"

model_name = "gemma-2b"
model = HookedTransformer.from_pretrained(
        model_name,
        center_unembed=True,
        center_writing_weights=True,
        fold_ln=True,
        device = device
    )



Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



Loaded pretrained model gemma-2b into HookedTransformer


In [2]:
import plotly.express as px
import pandas as pd
import transformer_lens.utils as utils

def imshow(tensor, renderer=None, midpoint=0, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=midpoint, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)


In [3]:
import sys
import os

# Add the parent directory to sys.path
sys.path.append(os.path.abspath(".."))  # go up one level to the root

from src.advanced_path_patching import *
import json, random
import pandas as pd

def load_json(filename):
    with open(filename, 'r') as fp:
        return json.load(fp)
    
tv_data = pd.read_csv("../data/penn_pos_dataset_all.csv")

#tag_set = ["noun", "pronoun", "verb", "adverb", "adjective", "preposition"]
tag_set = ["noun", "pronoun", "verb", "adverb", "adjective", "preposition", "particle", "determiner", "conjunction", "number"]

template = """Sentence: The cat chased the mouse.
POS tag: noun
Answer: cat

Sentence: She went to the market.
POS tag: pronoun
Answer: She

Sentence: Birds fly across the sky.
POS tag: verb
Answer: fly

Sentence: He quickly closed the door.
POS tag: adverb
Answer: quickly

Sentence: The tall man opened the door.
POS tag: adjective
Answer: tall

Sentence: She walked under the bridge.
POS tag: preposition
Answer: under

Sentence: {sentence}\nPOS tag: {tag}\nAnswer:"""

In [4]:
receiver_nodes = [(17, None)]   # MLP at layer 13
component = "z"
position = -1
freeze_mlps = True
indirect_patch = False
metric="logit_diff"
seed = 42
null_task = True
allow_multitoken = False

In [5]:
PTB_TO_COARSE = {
    # Nouns
    "NN":"noun","NNS":"noun","NNP":"noun","NNPS":"noun",
    # Verbs (incl. modals as Verb for coarse)
    "VB":"verb","VBD":"verb","VBG":"verb","VBN":"verb","VBP":"verb","VBZ":"verb","MD":"verb",
    # Adjectives
    "JJ":"adjective","JJR":"adjective","JJS":"adjective",
    # Adverbs
    "RB":"adverb","RBR":"adverb","RBS":"adverb","WRB":"adverb",
    # Pronouns
    "PRP":"pronoun","PRP$":"pronoun","WP":"pronoun","WP$":"pronoun",
    # Prepositions / Subordinators
    "IN":"preposition","TO":"preposition",
    # Determiners
    "DT":"determiner","PDT":"determiner","WDT":"determiner",
    # Conjunctions
    "CC":"conjunction",
    # Numbers
    "CD":"number",
    # Particles
    "RP":"particle",
    # Everything else
    "EX":"other","FW":"other","LS":"other","SYM":"other","UH":"other","POS":"other",
    "``":"other","''":"Other",",":"Other",".":"Other",":":"Other","-LRB-":"other","-RRB-":"other"
}

In [None]:
output = torch.zeros(model.cfg.n_layers, model.cfg.n_heads)
import spacy, random
import numpy as np
import ast

random.seed(seed)
coarse_tags = ["adjective","noun","verb","adverb","preposition","particle",
               "determiner","conjunction","number","pronoun","other"]
multi_first_token = True

tv_data = tv_data.head(10)
num_samples = 0
for i, item in tv_data.iterrows():

    for tag in coarse_tags:

        item[tag] = ast.literal_eval(item[tag])
        
        if item[tag] != []:
            tag_to_use = tag

            clean_prompt = template.format(
                sentence=item["Sentence"],
                tag=tag_to_use
            )
            corrupted_prompt = template.format(
                sentence=item["Sentence"],
                tag="null"
            )
            source_toks = model.to_tokens(clean_prompt, prepend_bos=False).squeeze(-1)
            corr_toks = model.to_tokens(corrupted_prompt, prepend_bos=False).squeeze(-1)

            ans_tokens_list = []
            
            for word in item[tag]:
                word_token = model.to_tokens(word, prepend_bos=False).squeeze(-1)
                
                if word_token.shape != torch.Size([1]):
                    if not multi_first_token:
                        continue
                    else:
                        for other_tag in coarse_tags:
                            if other_tag != tag:
                                for corr in item[other_tag]:
                                    corr_token = model.to_tokens(corr, prepend_bos=False).squeeze(-1)
                                    w0 = word_token.view(-1)[0].item()
                                    c0 = corr_token.view(-1)[0].item()
                                    ans_tokens = torch.tensor([w0, c0]).to(device)
                                    ans_tokens_list.append(ans_tokens)
                else:
                    for other_tag in coarse_tags:
                            if other_tag != tag:
                                for corr in item[other_tag]:
                                    corr_token = model.to_tokens(corr, prepend_bos=False).squeeze(-1)
                                    if corr_token.shape != torch.Size([1]):
                                        continue
                                    ans_tokens = torch.stack([torch.tensor((word_token, corr_token))]).to(device)
                                    ans_tokens_list.append(ans_tokens)
                
            ans_tokens = ans_tokens_list
            output+=path_patching(model, receiver_nodes, source_toks, corr_toks, ans_tokens, component, position, freeze_mlps, indirect_patch, metric=metric, is_multitoken=allow_multitoken)
        else:
            continue

        num_samples += 1
    
output/=num_samples
print("OUTPUT", output)

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

Output()

OUTPUT tensor([[-3.3769e-02, -2.3966e-03, -4.4788e-03,  1.7158e-01,  4.4039e-01,
          5.2668e-03,  3.4098e-01, -4.9217e-03],
        [ 2.4000e-01, -9.7070e-01,  1.4912e-01,  1.4112e-01,  6.7331e-01,
         -3.9168e-01, -2.8782e+00,  3.1060e+00],
        [ 2.0666e-01,  3.4737e-01, -5.2504e-01,  7.2348e-01, -2.7794e-01,
         -4.4331e-01, -4.4078e-02,  9.0643e-01],
        [ 2.8956e-01,  1.1862e+00,  1.1918e+00, -7.8435e-01,  4.0478e-01,
          5.3864e-01,  3.2520e-01,  3.5325e-02],
        [-1.4930e-01,  3.2431e+00,  9.4370e-02, -1.6488e+00,  1.8274e+00,
          7.3616e-01,  3.5569e-01, -5.6934e-01],
        [ 2.0919e+00,  1.7133e+00, -6.3456e-01, -1.2103e+00,  3.2145e-02,
         -3.9980e-01, -7.8119e-01, -9.1423e-02],
        [-6.5875e-01,  1.4945e-01,  2.5342e+00, -2.0832e+00, -2.0797e-01,
         -1.8818e+00,  7.1947e-01,  4.4496e-02],
        [-1.3595e+00, -5.0600e-01,  4.7712e+00,  7.0954e-01, -1.1720e+00,
          1.7017e+00, -9.5223e-01,  2.6551e-01],
        [

In [7]:
recv_str = '_'.join(['-'.join([str(si) for si in s if si is not None]) for s in receiver_nodes])
print("Saving to", f'{"deneme".strip(".json") }.npy')
np.save(f'{"deneme".strip(".json") }.npy', output.numpy())

Saving to deneme.npy


In [8]:
import numpy as np
import plotly.express as px
from pathlib import Path

def show_path_patching_heatmap(arr_or_path, title=None, renderer=None,
                               midpoint=0.0, symmetric=True, zmax=None,
                               origin="upper", annotate=False, fmt=".2f"):
    """
    arr_or_path: np.ndarray of shape (n_layers, n_heads) or path to .npy
    origin: "upper" puts layer 0 at top; use "lower" to put it at bottom
    """
    A = np.load(arr_or_path) if isinstance(arr_or_path, (str, Path)) else np.asarray(arr_or_path)
    assert A.ndim == 2, "Expected (layers, heads) array"

    n_layers, n_heads = A.shape
    if symmetric:
        vmax = float(np.nanmax(np.abs(A))) if zmax is None else float(zmax)
        zmin, zmax = -vmax, vmax
    else:
        zmin = None

    fig = px.imshow(
        A,
        color_continuous_scale="RdBu",
        color_continuous_midpoint=midpoint,
        zmin=zmin, zmax=zmax,
        origin=origin,
        labels=dict(color="% path-patch score")
    )
    fig.update_xaxes(title="Head", tickmode="array", tickvals=list(range(n_heads)))
    fig.update_yaxes(title="Layer", tickmode="array", tickvals=list(range(n_layers)))
    fig.update_layout(title=title,
                      height=600,                     # keep same vertical size
                    width=n_heads * 50)

    if annotate:
        text = np.vectorize(lambda x: f"{x:{fmt}}")(A)
        fig.update_traces(text=text, texttemplate="%{text}", textfont_size=10)

    fig.show(renderer)

# Example use:
arr = np.load("deneme.npy")  # you saved output*100
# show_path_patching_heatmap(arr, title=f"Path patching (%) – {recv_str}", symmetric=True, origin="upper")
# or directly:
show_path_patching_heatmap("deneme.npy", title=f"Path patching (%) – {recv_str}")
