In [2]:
import os
from pathlib import Path
os.chdir("/Users/oliverdaniels-koch/projects/elk-experiments")
out_dir = Path("output")
device = "mps"

# Explore EAP Graphs on Hex

I'm curious if we see notable differences when running (aggregated) edge attribution patching 
on the trusted and untrusted data of the hex task

I suspect there's a lot of in-distribution variation, but maybe we'll see two distinct circuits?

I also want to create a detector using k-means clustering?

I guess there's just a bunch of ways to learn a latent space of the adjacency / score matrix

Seems like there should be something smarter 

In [39]:
from functools import partial

import torch 
import numpy as np


In [4]:
from cupbearer.tasks import tiny_natural_mechanisms
from elk_experiments.tiny_natural_mechanisms_utils import get_task_subset

In [3]:
task = tiny_natural_mechanisms("hex", device, "pythia-70m")


Loaded pretrained model attn-only-1l into HookedTransformer
Moving model to device:  mps


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m into HookedTransformer
Moving model to device:  mps


In [29]:
task.model.set_use_hook_mlp_in(True)
task.model.set_use_split_qkv_input(True)
task.model.set_use_attn_result(True)

In [30]:
small_task = get_task_subset(task, 2048, 1048, 1024)

In [31]:
from eap.eap_wrapper import EAP

In [32]:
# use mean probability over effect tokens as metric 
def effect_prob_func(logits, effect_tokens):
    assert logits.ndim == 3
    # Sum over vocab and batch dim (for now we're just computing attribution values, we'll deal with per data instance later)
    probs = logits[:, -1, :].softmax(dim=-1)
    out = probs[:, effect_tokens].mean(dim=-1).mean() # mean over effect tokens, mean over batch
    # out = logits[:, -1, effect_tokens].mean()
    return out

In [33]:
from cupbearer.tasks.tiny_natural_mechanisms import get_effect_tokens
effect_tokens = get_effect_tokens("hex", task.model)

In [83]:
trusted_tokens = torch.stack([torch.tensor(data["prefix_tokens"]) for data in small_task.trusted_data.data])
untrusted_clean_tokens = torch.stack([torch.tensor(data["prefix_tokens"]) for data in small_task.test_data.normal_data.data])
anomalous_tokens = torch.stack([torch.tensor(data["prefix_tokens"]) for data in small_task.test_data.anomalous_data.data])

In [37]:
trusted_tokens.shape, anomalous_tokens.shape

(torch.Size([2048, 16]), torch.Size([1024, 16]))

In [40]:
task.model.reset_hooks()

clean_graph = EAP(
    model=task.model,
    clean_tokens=trusted_tokens,
    metric=partial(effect_prob_func, effect_tokens=effect_tokens),
    upstream_nodes=["head"],
    downstream_nodes=["head"],
    batch_size=64,
    verbose=True
)

Saving activations requires 0.0001 GB of memory per token


100%|██████████| 32/32 [01:48<00:00,  3.40s/it]


In [50]:
def show(graph, threshold=None, abs_scores=None, fname="eap_graph.png", fdir=None):
    import pygraphviz as pgv

    minimum_penwidth = 0.2
    edges = graph.top_edges(threshold=threshold, abs_scores=abs_scores)

    g = pgv.AGraph(
        name='root',
        strict=True,
        directed=True
    )

    g.graph_attr.update(ranksep='0.1', nodesep='0.1', compound=True)
    g.node_attr.update(fixedsize='true', width='1.5', height='.5')

    def find_layer_node(node):
        if node == f'resid_post.{graph.cfg.n_layers - 1}':
            return graph.cfg.n_layers
        else:
            return int(node.split(".")[1])

    layer_to_subgraph = {}
    layer_to_subgraph[-1] = g.add_subgraph(name=f'cluster_-1', rank='same', color='invis')
    layer_to_subgraph[-1].add_node(f'-1_invis', style='invis')

    min_layer = 999
    max_layer = -1
    layers = list(range(0, 32))

    for edge in edges:
        parent_node = edge[0]
        child_node = edge[1]
        min_layer = min(min_layer, find_layer_node(parent_node))
        max_layer = max(max_layer, find_layer_node(child_node))

    layers = list(range(min_layer, max_layer + 1))
    prev_layer = None

    for layer in layers:
        layer_to_subgraph[layer] = g.add_subgraph(name=f'cluster_{layer}', rank='same', color='invis')
        layer_to_subgraph[layer].add_node(f'{layer}_invis', style='invis')

        if prev_layer is not None:
            g.add_edge(f'{prev_layer}_invis', f'{layer}_invis', style='invis', weight=1000)

        prev_layer = layer
                
    # Adding nodes and edges between nodes
    for edge in edges:
        parent_node, child_node, edge_score = edge

        parent_name = parent_node
        child_name = child_node

        child_name = child_name.replace(".q", "").replace(".k", "").replace(".v", "")
        
        for node_name in [parent_name, child_name]:

            node_layer = find_layer_node(node_name)

            node_color = '#1f77b4' if node_name.startswith("head") else '#ff7f0e' if node_name.startswith("mlp") else '#2ca02c' if node_name.startswith("resid") else '#d62728'

            layer_to_subgraph[node_layer].add_node(
                node_name,
                fillcolor=node_color,
                color="black",
                style="filled, rounded",
                shape="box",
                fontname="Helvetica",
            )
            
        edge_width = str(max(minimum_penwidth, edge_score*100))

        g.add_edge(
            parent_name,
            child_name,
            penwidth=edge_width,
            color='#0091E4',
            weight=10,
            minlen='0.5',
        )
    # fdir = fdir if fdir is not None else DEFAULT_GRAPH_PLOT_DIR
    save_path = os.path.join(fdir, fname)
    print(f"Saving graph")
    if not fname.endswith(".gv"): # turn the .gv file into a .png file
        g.draw(path=save_path, prog='dot')

    return g

In [82]:
show(clean_graph, threshold=9e-8, abs_scores=False, fdir=out_dir, fname="clean_graph.png")

Saving graph


<AGraph b'root' <Swig Object of type 'Agraph_t *' at 0x32953e8e0>>

In [75]:
clean_edges = clean_graph.top_edges(threshold=5e-8, abs_scores=False)

In [76]:
clean_edges

[('head.0.3', 'head.1.3.v', 3.1783656595507637e-07),
 ('head.0.3', 'head.1.4.v', 1.4391096669896797e-07),
 ('head.0.6', 'head.3.3.v', 1.400565281528543e-07),
 ('head.0.6', 'head.3.1.v', 1.2082939804258785e-07),
 ('head.2.5', 'head.3.1.v', 1.0892620139202336e-07),
 ('head.0.0', 'head.2.1.v', 9.165699310642594e-08),
 ('head.0.6', 'head.1.0.v', 8.914732063658448e-08),
 ('head.0.3', 'head.1.0.v', 8.54735233701831e-08),
 ('head.1.3', 'head.3.3.v', 7.186512362977737e-08),
 ('head.0.3', 'head.3.7.v', 6.30051957273281e-08),
 ('head.1.3', 'head.3.7.v', 6.278842334950241e-08),
 ('head.1.4', 'head.3.1.v', 6.268881946880356e-08),
 ('head.0.7', 'head.3.7.v', 6.248009754017403e-08),
 ('head.2.6', 'head.3.7.v', 5.96108336026191e-08),
 ('head.0.3', 'head.2.4.v', 5.65856090872785e-08),
 ('head.3.1', 'head.4.0.v', 5.586666063095436e-08),
 ('head.2.6', 'head.3.2.v', 5.513895828812565e-08),
 ('head.0.4', 'head.3.7.v', 5.440797679057141e-08),
 ('head.0.1', 'head.1.7.v', 5.3475812222814056e-08),
 ('head.0.5

In [53]:
task.model.reset_hooks()

anomalous_graph = EAP(
    model=task.model,
    clean_tokens=anomalous_tokens,
    metric=partial(effect_prob_func, effect_tokens=effect_tokens),
    upstream_nodes=["head"],
    downstream_nodes=["head"],
    batch_size=64,
    verbose=True
)

Saving activations requires 0.0001 GB of memory per token


100%|██████████| 16/16 [00:56<00:00,  3.53s/it]


In [81]:
show(anomalous_graph, threshold=9e-8, abs_scores=False, fdir=out_dir, fname="anomalous_graph.png")

Saving graph


<AGraph b'root' <Swig Object of type 'Agraph_t *' at 0x3295adec0>>

In [79]:
anomalous_edges = anomalous_graph.top_edges(threshold=9e-8, abs_scores=False)

In [80]:
anomalous_edges

[('head.0.3', 'head.1.3.v', 4.6406509568441834e-07),
 ('head.0.3', 'head.1.4.v', 2.5037650175363524e-07),
 ('head.0.6', 'head.3.3.v', 1.8278846880548372e-07),
 ('head.0.3', 'head.1.0.v', 1.727667182649384e-07),
 ('head.0.6', 'head.3.1.v', 1.5554519450233784e-07),
 ('head.0.6', 'head.1.0.v', 1.4941561232717504e-07),
 ('head.2.5', 'head.3.1.v', 1.4546419890848483e-07),
 ('head.0.7', 'head.1.3.v', 1.2531749860045238e-07),
 ('head.1.3', 'head.3.1.v', 1.079033751238967e-07),
 ('head.1.3', 'head.2.6.q', 1.0309624087767588e-07),
 ('head.3.1', 'head.4.0.v', 1.0156108487535676e-07),
 ('head.0.3', 'head.3.7.v', 9.93367237356324e-08),
 ('head.1.4', 'head.3.1.v', 9.449505000702629e-08),
 ('head.0.6', 'head.2.0.v', 9.381216159454198e-08),
 ('head.0.5', 'head.2.6.k', 9.37108524112773e-08),
 ('head.0.0', 'head.2.1.v', 9.233470876779393e-08)]

In [84]:
task.model.reset_hooks()

unstrusted_clean_graph = EAP(
    model=task.model,
    clean_tokens=untrusted_clean_tokens,
    metric=partial(effect_prob_func, effect_tokens=effect_tokens),
    upstream_nodes=["head"],
    downstream_nodes=["head"],
    batch_size=64,
    verbose=True
)

Saving activations requires 0.0001 GB of memory per token


100%|██████████| 16/16 [00:58<00:00,  3.64s/it]


In [85]:
show(unstrusted_clean_graph, threshold=9e-8, abs_scores=False, fdir=out_dir, fname="untrusted_clean_graph.png")

Saving graph


<AGraph b'root' <Swig Object of type 'Agraph_t *' at 0x2f76d3c60>>

# Try Using Set Difference As anomaly score 

# Try Getting Top Edges across Entire Distribution, filtering anomolies with respect to that