# Masks trained over IOI edges: evaluation and comparison with the known IOI circuit

In [1]:
from models import load_gpt2_weights, load_demo_gpt2, tokenizer
from data import retrieve_toxic_data, retrieve_owt_data, retrieve_toxic_data_low_loss, retrieve_toxic_filtered_data, FILTER_DEMO_LEN, CONTEXT_LENGTH
from inference import infer_batch_with_owt, infer_batch, prepare_fixed_demo, criterion
from torch.optim import AdamW
import torch
import pickle
import datasets
from tqdm import tqdm_notebook as tqdm
from itertools import cycle
# from eval import evaluate_model
from data import batch_text_to_tokens
import plotly.express as px

Using device: cuda:0


## Load mask into model

In [2]:
means_ioi = True

# Mean activations used for mean ablation: IOI (ABC) means when means_ioi, otherwise generic means.
means_path = "data/gpt2_ioi_abc_means.pkl" if means_ioi else "data/gpt2_means.pkl"
with open(means_path, "rb") as means_file:
    means = pickle.load(means_file)[0]

model = load_demo_gpt2(means=means)

In [3]:
def _load_pickle(path):
    """Unpickle and return the object stored at ``path`` (trusted local files only)."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


necessary_masks_dict = _load_pickle("models/alternative_necessary_masks_params_dict_lambda=1_means_ioi=True.pkl")
sufficient_masks_dict = _load_pickle("models/alternative_sufficient_masks_params_dict_lambda=1_means_ioi=True.pkl")
mask_params_dict = _load_pickle("models/params_dict_lambda=1.pkl")
acdcpp_mask_params = _load_pickle("models/acdcpp_mask_params.pkl")

# Mask per method; the dict-valued files are read at key 200 (presumably a training step — TODO confirm).
final_masks = {
    "necessary": necessary_masks_dict[200],
    "sufficient": sufficient_masks_dict[200],
    "acdcpp": acdcpp_mask_params,
    "circuit_breaking": mask_params_dict[200],
}

In [4]:
def load_mask_into_model(model, mask):
    """Overwrite the model's trainable (mask) parameters, in order, with the tensors in ``mask``.

    The k-th parameter yielded by ``model.named_parameters()`` with ``requires_grad=True``
    receives ``mask[k]`` moved to that parameter's device; frozen parameters are untouched.
    Modifies ``model`` in place and returns None. Raises IndexError if ``mask`` has fewer
    entries than there are trainable parameters; extra entries are ignored.
    """
    trainable_params = (p for _, p in model.named_parameters() if p.requires_grad)
    for position, param in enumerate(trainable_params):
        param.data = mask[position].to(param.device)

def reset_mask(model):
    """Set every trainable (mask) parameter of ``model`` to all ones.

    Frozen parameters (``requires_grad=False``) are left unchanged. A fresh tensor is
    assigned to ``param.data`` rather than filling in place, so tensors previously
    assigned by ``load_mask_into_model`` are not mutated.
    """
    for param in (p for _, p in model.named_parameters() if p.requires_grad):
        # ones_like keeps the parameter's dtype and device.
        param.data = torch.ones_like(param.data)

## Test model before and after circuit breaking

In [5]:
with open("data/ioi_prompts_test.pkl", "rb") as f:
    ioi_prompts_test = pickle.load(f)

with open("data/eval_uniform.pkl", "rb") as f:
    uniform_samples = pickle.load(f)
    # NOTE(review): element 2 of each sample is assumed to be the text — confirm against the data file.
    uniform_sentences = [t[2] for t in uniform_samples]

# Load the saved weights into the model. Assigning `model.state_dict = ...` would only
# replace the bound `state_dict` method with a dict: nothing gets loaded and
# `model.state_dict()` stops being callable.
# NOTE(review): assumes this pickle holds a state dict whose keys match `model` — confirm.
with open("models/masked_gpt2_mean_ablation_v6.pkl", "rb") as f:
    model.load_state_dict(pickle.load(f))

In [6]:
# Run inference on an ioi_sentence
ioi_prompt = ioi_prompts_test[0]
print(ioi_prompt)

model.eval()
model.to('cuda')
def get_last_token(model, prompt, topk=5, sentence=False):
    """Predict the token after ``prompt`` (with its final token removed) and report the top-k.

    Args:
        model: the demo GPT-2; ``model(tokens)[0]`` is treated as the logits
            (assumed (seq, vocab) after squeezing — TODO confirm).
        prompt: an IOI record dict with 'text', 'IO' and 'S' keys, or a raw string when
            ``sentence=True``.
        topk: number of highest-scoring next tokens to return.
        sentence: if True, ``prompt`` is a plain string and no logit difference is computed.

    Returns:
        ``(topk_tokens_decoded, topk_probs, logit_diff)`` for IOI records, where
        ``logit_diff`` is logit(IO) - logit(S) at the last position (a 0-d tensor);
        ``(topk_tokens_decoded, topk_probs)`` when ``sentence=True``.

    The final token is dropped before the forward pass: IOI texts end with the IO name, so
    the model is asked to predict it. Raw sentences lose their last token as well.
    """
    prompt_sentence = prompt if sentence else prompt['text']

    tokens = tokenizer(prompt_sentence, return_tensors='pt').input_ids[:, :-1]
    # Put the tokens on the model's device so models that do not move inputs themselves work too.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        tokens = tokens.to(first_param.device)

    # Pure evaluation: with grad enabled every returned logit diff keeps its autograd graph
    # (and all activations) alive, which adds up when averaging over many prompts.
    with torch.no_grad():
        logits = model(tokens)[0]
    last_logits = logits.squeeze(0)[-1]
    probs = torch.nn.functional.softmax(last_logits, dim=-1)

    topk_outputs = torch.topk(last_logits, topk)
    topk_probs = probs[topk_outputs.indices]
    topk_tokens_decoded = tokenizer.batch_decode(topk_outputs.indices)

    if sentence:
        return topk_tokens_decoded, topk_probs

    # Names appear mid-sentence, so tokenize them with a leading space; take the last sub-token id.
    io_token = tokenizer(" " + prompt['IO'], return_tensors='pt').input_ids[:, -1]
    s_token = tokenizer(" " + prompt['S'], return_tensors='pt').input_ids[:, -1]
    logit_diff = last_logits[io_token][0] - last_logits[s_token][0]
    return topk_tokens_decoded, topk_probs, logit_diff

def get_ioi_score(model, num_samples):
    """Mean IO-minus-S logit difference over the first ``num_samples`` prompts of ``ioi_prompts_test``.

    Each difference comes from ``get_last_token``. Raises ZeroDivisionError when
    ``num_samples`` is 0 and IndexError when it exceeds the number of prompts.
    """
    per_prompt_diffs = [get_last_token(model, ioi_prompts_test[i])[2] for i in range(num_samples)]
    return sum(per_prompt_diffs) / len(per_prompt_diffs)

# get OWT loss
def get_owt_loss(model, num_samples):
    """Intended: mean loss of ``model`` over the first ``num_samples`` entries of ``uniform_sentences``.

    NOTE(review): this looks broken as written. ``uniform_sentences`` is built as
    ``[t[2] for t in uniform_samples]``; if those entries are plain strings, ``prompt['text']``
    raises TypeError, and a uniform/OWT sample would not carry 'IO'/'S' keys anyway.
    The expected arguments of ``infer_batch_with_owt`` are not visible here — check its
    signature before relying on this helper. It is not called in the visible notebook.
    """
    owt_losses = []
    for idx in range(num_samples):
        prompt = uniform_sentences[idx]
        owt_losses.append(infer_batch_with_owt(model, prompt['text'], prompt['IO'], prompt['S']))
    # Raises ZeroDivisionError when num_samples is 0.
    return sum(owt_losses) / len(owt_losses)

{'[PLACE]': 'restaurant', '[OBJECT]': 'snack', 'text': 'While Alicia and Joshua were commuting to the restaurant, Joshua gave a snack to Alicia', 'IO': 'Alicia', 'S': 'Joshua', 'TEMPLATE_IDX': 24, 'C': 'Laura'}


In [7]:
def invert_mask(mask, keep_output=True):
    """Return a new list with the complement ``1 - m`` of every mask tensor in ``mask``.

    When ``keep_output`` is True, the entry at index 0 is not inverted. It is replaced with
    ``torch.ones_like`` of itself, treating it as the output-node mask that stays fully on.
    The input list and tensors are not modified.

    NOTE(review): get_nodes_and_edges reads the output mask from the *last* entry
    (``mask_params[-1]``). Confirm which index actually holds the output mask.
    """
    return [
        torch.ones_like(m) if (keep_output and position == 0) else 1 - m
        for position, m in enumerate(mask)
    ]

In [8]:
reset_mask(model)
print("Average logit diff with no edges masked: ", get_ioi_score(model, 20))

# Load each learned mask in turn. There is no reset in between: every load reassigns
# the model's trainable mask parameters.
masked_runs = [
    ("with necessary edges masked", necessary_masks_dict[200]),
    ("with sufficient edges masked", sufficient_masks_dict[200]),
    ("post circuit-breaking", mask_params_dict[200]),
    ("post masking ACDC++", acdcpp_mask_params),
]
for label, mask in masked_runs:
    load_mask_into_model(model, mask)
    print(f"Average logit diff {label}: {get_ioi_score(model, 50)}")

Average logit diff with no edges masked:  tensor(4.1287, device='cuda:0', grad_fn=<DivBackward0>)
Average logit diff with necessary edges masked: 0.013599695637822151
Average logit diff with sufficient edges masked: -2.8097500801086426
Average logit diff post circuit-breaking: 0.3835521340370178
Average logit diff post masking ACDC++: 0.004391784779727459


In [9]:
print("Inverted Masks")
reset_mask(model)
print("Average logit diff with no edges masked: ", get_ioi_score(model, 20))

# Same comparison with the complement of each mask (invert_mask keeps index 0 at ones).
inverted_runs = [
    ("with necessary edges masked", necessary_masks_dict[200]),
    ("with sufficient edges masked", sufficient_masks_dict[200]),
    ("post circuit-breaking", mask_params_dict[200]),
    ("post masking ACDC++", acdcpp_mask_params),
]
for label, mask in inverted_runs:
    load_mask_into_model(model, invert_mask(mask))
    print(f"Average logit diff {label}: {get_ioi_score(model, 100)}")

Inverted Masks
Average logit diff with no edges masked:  tensor(4.1287, device='cuda:0', grad_fn=<DivBackward0>)


KeyboardInterrupt: 

## Manually Check Edges

In [5]:
name_mover_heads = ['a10.0', 'a9.9', 'a9.6']
negative_heads = ['a10.7', 'a11.10']
s2_inhibition_heads = ['a8.10', 'a7.9', 'a8.6', 'a7.3']
induction_heads = ['a5.5', 'a6.9', 'a5.9', 'a5.8']
duplicate_token_heads = ['a0.1', 'a0.10', 'a3.0']
previous_token_heads = ['a4.11', 'a2.2', 'a2.9']
backup_name_mover_heads = ['a11.2', 'a10.2', 'a10.6', 'a10.1', 'a10.10', 'a9.7', 'a11.9', 'a11.3']

# Role of every hand-labelled head ("a{layer}.{head}") in the IOI circuit.
# Insertion order follows the role order below; no head appears in two lists.
head_roles = {
    'name_mover': name_mover_heads,
    'negative': negative_heads,
    's2_inhibition': s2_inhibition_heads,
    'induction': induction_heads,
    'duplicate_token': duplicate_token_heads,
    'previous_token': previous_token_heads,
    'backup_name_mover': backup_name_mover_heads,
}
circuit_dict = {head: role for role, heads in head_roles.items() for head in heads}

In [9]:
def check_relevant_edges(mask_edges, circuit_nodes=circuit_dict):
    """Split ``mask_edges`` by how many endpoints are labelled circuit heads.

    Each edge is a pair of ``(layer, name)`` nodes; membership is tested on ``name``.
    The default ``circuit_nodes`` is bound to ``circuit_dict`` when this cell runs.

    Returns:
        ``(relevant_edges, circuit_edges)``: the edges with at least one endpoint in
        ``circuit_nodes``, and the edges with both endpoints in it, each in iteration order.
    """
    touching, internal = [], []
    for edge in mask_edges:
        first_in = edge[0][1] in circuit_nodes
        second_in = edge[1][1] in circuit_nodes
        if first_in or second_in:
            touching.append(edge)
            if first_in and second_in:
                internal.append(edge)
    return touching, internal

number of relevant edges: 63, number of circuit edges: 0


In [10]:
from mask_utils import get_nodes_and_edges

# Count each mask's edges that touch ("relevant") or lie entirely within ("circuit") the
# hand-labelled IOI heads in circuit_dict. This uses mask_utils.get_nodes_and_edges; the
# local redefinition further down shadows it once that cell has run.
# These are the edges of the *necessary* mask (not the sufficient one).
_, _, necessary_edges, _ = get_nodes_and_edges(necessary_masks_dict[200])
relevant_edges, circuit_edges = check_relevant_edges(necessary_edges)
print(f"necessary mask: number of relevant edges: {len(relevant_edges)}, number of circuit edges: {len(circuit_edges)}")

_, _, acdcpp_edges, _ = get_nodes_and_edges(acdcpp_mask_params)
relevant_edges, circuit_edges = check_relevant_edges(acdcpp_edges)
print(f"ACDC++ mask: number of relevant edges: {len(relevant_edges)}, number of circuit edges: {len(circuit_edges)}")

number of relevant edges: 63, number of circuit edges: 0
number of relevant edges: 158, number of circuit edges: 62


## Visualize mask
Recreate the computational-graph figures from the edge attribution patching paper.

### Load mask and calculate what edges are present

In [None]:
def get_nodes_and_edges(mask_params, edge_0=True, param_names=None):
    """Turn a list of per-node edge masks into graph nodes and edges.

    Nodes are ``(layer, name)`` tuples: ``(-1, "embed")``, ``(l, "a{l}.{h}")`` per attention
    head, ``(l, "m{l}")`` per MLP and ``(12, "output")``. Entry j of a node's mask refers to
    ``all_possible_nodes[j]`` (an upstream node).

    Args:
        mask_params: mask tensors aligned with ``param_names``. Attention masks are indexed
            ``[:, head]`` (presumably (n_upstream, n_heads)); MLP masks are used as vectors;
            the LAST entry is used as the output-node mask.
            NOTE(review): confirm the output mask really is the last entry.
        edge_0: if True, edges are between nodes with mask value 0. Else, edges are between
            nodes with mask value 1.
        param_names: names for ``mask_params``. Entries containing "attention" or "mlp"
            (layer index as the second dot-separated field) become head / MLP nodes.
            Defaults to the notebook-global ``param_names``, which is not defined in this
            notebook — TODO define it explicitly.

    Returns:
        ``(all_possible_nodes, nodes_with_edges, edges, mask_dict)``. ``edges`` is a set of
        ``(downstream_node, upstream_node)`` pairs. ``mask_dict`` maps node name to its
        detached CPU mask vector.
    """
    if param_names is None:
        param_names = globals()["param_names"]
    n_heads = 12
    n_layers = 12
    target_value = 0 if edge_0 else 1

    # Associate each node with a position; "embed" has no upstream nodes, hence an empty mask.
    all_possible_nodes = [(-1, "embed")]
    mask_dict = {"embed": torch.zeros(size=(0,))}
    for idx, param in enumerate(mask_params):
        name = param_names[idx]
        if "attention" in name:
            layer = int(name.split(".")[1])
            for head in range(n_heads):
                all_possible_nodes.append((layer, f"a{layer}.{head}"))
                mask_dict[f"a{layer}.{head}"] = param[:, head].detach().cpu()
        elif "mlp" in name:
            layer = int(name.split(".")[1])
            all_possible_nodes.append((layer, f"m{layer}"))
            mask_dict[f"m{layer}"] = param.detach().cpu()
    # The output sits above the last layer (index n_layers; equal to n_heads only by coincidence).
    all_possible_nodes.append((n_layers, "output"))
    mask_dict["output"] = mask_params[-1].detach().cpu()

    # Edge (i, j) when node i's mask entry for upstream position j equals target_value.
    # Nodes are unique, so position j is simply the enumeration index (no list.index scan).
    edges = set()
    for downstream in all_possible_nodes:
        node_mask = mask_dict[downstream[1]]
        for j, upstream in enumerate(all_possible_nodes[:len(node_mask)]):
            if node_mask[j] == target_value:
                edges.add((downstream, upstream))

    nodes_with_edges = {node for edge in edges for node in edge}

    return all_possible_nodes, nodes_with_edges, edges, mask_dict
# Edges of the circuit-breaking mask (the same mask as final_masks["circuit_breaking"]).
# `mask_params` is never defined in this notebook; `edges` is labelled "circuit_breaking" below.
all_possible_nodes, nodes_with_edges, edges, mask_dict = get_nodes_and_edges(mask_params_dict[200])

### Analyze ACDC and Compare
I separately ran ACDC++ (edge attribution patching, from the paper "Attribution Patching Outperforms Automated Circuit Discovery") to obtain reference circuit edges. Below I compare the masks I learned with different losses against those reference edges.

In [None]:
# Raw ACDC++ edges; indexed below as acdcpp_edges_long[0][<key>] (the key is presumably a threshold).
with open("models/acdcpp_edges.pkl", "rb") as edges_file:
    acdcpp_edges_long = pickle.load(edges_file)
acdcpp_edges_long

In [None]:
# acdcpp edges are in format 'blocks.1.attn.hook_result[:, :, 10]blocks.0.hook_mlp_in[:]', convert to format of ((1, 'a1.10'), (0, 'm0'))

def _get_head_index(node_name_long):
    """Return the last integer inside the brackets, e.g. "blocks.1.attn.hook_result[:, :, 10]" -> 10."""
    index = node_name_long.split("[")[-1].split("]")[0]
    return int(index.split(", ")[-1])


def _short_node_name(node_name):
    """Map an ACDC++ hook name to ``(layer, short_name)``; branches are tried in this order."""
    qkv_letters = ["q", "k", "v"]

    # Block-0 residual input stands for the embeddings.
    if "resid_pre" in node_name:
        assert "0" in node_name and not any(str(i) in node_name for i in range(1, 10))
        return -1, "embed"

    if "embed" in node_name:
        return -1, ("pos_embeds" if "pos" in node_name else "token_embeds")

    # q/k/v inputs of a head, e.g. "blocks.1.hook_q_input" -> "a1.<idx>_q".
    # Bracketed names such as "...hook_q_input[:, :, 3]" do not end with the suffix, so they
    # fall through to the head branch below and map to the plain head node.
    if any(node_name.endswith(f"hook_{letter}_input") for letter in qkv_letters):
        relevant_letter = None
        for letter in qkv_letters:
            if f"hook_{letter}" in node_name:
                assert relevant_letter is None
                relevant_letter = letter
        name = "a" + node_name.split(".")[1] + "." + str(_get_head_index(node_name)) + "_" + relevant_letter
        return int(node_name.split(".")[1]), name

    # Attention head outputs (and bracketed q/k/v hooks) -> "a{layer}.{head}".
    if "hook_result" in node_name or any(f"hook_{letter}" in node_name for letter in qkv_letters):
        name = "a" + node_name.split(".")[1] + "." + str(_get_head_index(node_name))
        return int(node_name.split(".")[1]), name

    if node_name.endswith("resid_mid"):
        raise ValueError("We removed resid_mid annotations. Call these mlp_in now.")

    if "mlp" in node_name:
        name = "m" + node_name.split(".")[1]
        return int(node_name.split(".")[1]), name

    if "resid_post" in node_name:
        return 12, "resid_post"

    raise ValueError(f"Unrecognized node name {node_name}")


def get_node_name(node_name, show_full_index=False):
    """Convert an ACDC++ hook name into a ``(layer, name)`` node for the pretty graphs.

    Examples:
        "blocks.1.attn.hook_result[:, :, 10]" -> (1, "a1.10")
        "blocks.0.hook_mlp_in[:]"              -> (0, "m0")
        "blocks.0.hook_resid_pre[:]"           -> (-1, "embed")
        "blocks.11.hook_resid_post[:]"         -> (12, "resid_post")

    If ``show_full_index`` is True, the original hook name is returned as the name with the
    same layer. That path previously never assigned ``layer`` and raised UnboundLocalError.

    Raises:
        ValueError: for unrecognised names and for "resid_mid" names.
    """
    layer, short_name = _short_node_name(node_name)
    return layer, (node_name if show_full_index else short_name)

# Each ACDC++ edge is two hook names glued together, e.g.
# 'blocks.1.attn.hook_result[:, :, 10]blocks.0.hook_mlp_in[:]'; split after the first ']'
# and convert both halves to ((layer, name), (layer, name)).
ACDCPP_KEY = 0.08  # key into acdcpp_edges_long[0]

acdcpp_edges = {
    (
        get_node_name(edge.split("]")[0] + "]", show_full_index=False),
        get_node_name(edge.split("]")[1] + "]", show_full_index=False),
    )
    for edge in acdcpp_edges_long[0][ACDCPP_KEY]
}


### Analyze overlaps between different edges

In [None]:
# Alternative necessary / sufficient masks (these files lack the "_means_ioi=True" suffix of
# the ones loaded at the top); keep the entries at key 200.
with open("models/alternative_necessary_masks_params_dict_lambda=1.pkl", "rb") as mask_file:
    alternative_necessary_mask_params = pickle.load(mask_file)[200]
with open("models/alternative_sufficient_masks_params_dict_lambda=1.pkl", "rb") as mask_file:
    alternative_sufficient_mask_params = pickle.load(mask_file)[200]

# Necessary-mask edges are mask-value-0 entries; sufficient-mask edges are mask-value-1 entries.
_, _, alternative_necessary_edges, _ = get_nodes_and_edges(alternative_necessary_mask_params)
_, _, alternative_sufficient_edges, _ = get_nodes_and_edges(alternative_sufficient_mask_params, edge_0=False)

In [None]:
# Summarise instead of dumping the whole (all_nodes, nodes_with_edges, edges, mask_dict)
# tuple, whose mask_dict holds every per-node mask tensor.
_, sufficient_nodes_with_edges, sufficient_edge_set, _ = get_nodes_and_edges(alternative_sufficient_mask_params, edge_0=False)
{"nodes_with_edges": len(sufficient_nodes_with_edges), "edges": len(sufficient_edge_set)}

In [None]:
# How many circuit-breaking edges ACDC++ also found, and which ones.
shared_edges = edges.intersection(acdcpp_edges)
print(f"len(edges)={len(edges)}, len(acdcpp_edges)={len(acdcpp_edges)}, len(edges.intersection(acdcpp_edges))={len(shared_edges)}")
print(shared_edges)

In [None]:
# Pairwise edge-overlap counts between every method's edge set (the diagonal is each set's size).
edges_dict = {
    "circuit_breaking": edges,
    "ioi_necessary": alternative_necessary_edges,
    "ioi_sufficient": alternative_sufficient_edges,
    "acdcpp": acdcpp_edges,
}
for row_name, row_edges in edges_dict.items():
    for col_name, col_edges in edges_dict.items():
        n_common = len(row_edges.intersection(col_edges))
        print(f"{row_name} and {col_name}: {n_common} edges in common")

### Visualizations

In [None]:
def create_aligned_graph(all_possible_nodes, edges):
    """Render ``edges`` as a dot graph and return it as an IPython ``Image``.

    Only nodes that appear in some edge are drawn. Edges are ``(downstream, upstream)``
    pairs of ``(layer, name)`` nodes (as produced by get_nodes_and_edges) and are drawn
    upstream -> downstream. Nodes of a layer are grouped in a ``cluster_<layer>`` subgraph
    with ``rank=same``. The picture is written to ``aligned_graph.png`` in the working
    directory and is overwritten on every call.
    """
    # Imported here: the cell that imports pygraphviz/IPython.display comes later in the
    # notebook, so this function would fail on a fresh kernel without these.
    import pygraphviz as pgv
    from IPython.display import Image

    G = pgv.AGraph(strict=False, directed=True)

    # Find the maximum layer number for adjusting the graph
    max_layer = max(layer for layer, _ in all_possible_nodes if isinstance(layer, int))
    nodes_with_edges = {node for edge in edges for node in edge}

    # Add only nodes taking part in an edge (set lookup instead of rebuilding edge lists per node).
    for node in all_possible_nodes:
        if node in nodes_with_edges:
            G.add_node(node[1], layer=str(max_layer - node[0]))

    for edge in edges:
        G.add_edge(edge[1][1], edge[0][1])

    # Create subgraphs to ensure nodes of the same layer have the same rank
    for layer in range(max_layer, -2, -1):
        with G.subgraph(name=f'cluster_{layer}') as s:
            s.graph_attr['rank'] = 'same'
            for node in nodes_with_edges:
                if node[0] == layer:
                    s.add_node(node[1])

    G.layout(prog='dot')
    G.draw('aligned_graph.png')
    return Image('aligned_graph.png')

# Call the function with your nodes and edges
# Graph of the circuit-breaking mask's edges; the bare name on the last line displays it inline.
flipped_graph_image = create_aligned_graph(all_possible_nodes, edges)
flipped_graph_image


In [None]:
# For every ordered pair of different methods, list and draw only the edges both share.
for edge_type, first_edges in edges_dict.items():
    for second_edge_type, second_edges in edges_dict.items():
        if edge_type != second_edge_type:
            common = first_edges.intersection(second_edges)
            print(f"Intersection between {edge_type} and {second_edge_type}: {len(common)} edges in common, {common}")
            display(create_aligned_graph(all_possible_nodes, common))

In [None]:
import pygraphviz as pgv
from pathlib import Path
from IPython.display import Image

def show(nodes, edges, fname=None):
    """Build a strict directed pygraphviz graph with one ``cluster_<layer>`` subgraph per layer.

    Args:
        nodes: iterable of ``(layer, name)`` tuples; each is added to its layer's subgraph
            (created with ``rank='same'``) and labelled with ``name``.
        edges: iterable of ``(node_a, node_b)`` pairs, drawn node_a -> node_b.
        fname: optional output path. When given, ``<stem>.gv`` and a dot-rendered
            ``<stem>.png`` are written to its parent directory, which is created if missing.

    Returns:
        The ``pgv.AGraph``.
    """
    g = pgv.AGraph(strict=True, directed=True)
    g.graph_attr.update(ranksep='0.1', nodesep='0.1', compound=True)
    g.node_attr.update(fixedsize='true', width='1.5', height='.5')

    layer_to_subgraph = {}
    for node in nodes:
        layer = node[0]
        if layer not in layer_to_subgraph:
            # One subgraph per layer, with rank=same so its nodes sit on one level.
            layer_to_subgraph[layer] = g.add_subgraph(name=f'cluster_{layer}', rank='same')
        layer_to_subgraph[layer].add_node(node, label=str(node[1]))

    for edge in edges:
        g.add_edge(edge[0], edge[1])

    if fname:
        fpath = Path(fname)
        base_fname = fpath.stem
        base_path = fpath.parent
        base_path.mkdir(exist_ok=True, parents=True)

        # pygraphviz's write/draw accept str paths or file handles; some releases reject
        # pathlib.Path objects, so pass strings explicitly.
        g.write(path=str(base_path / f"{base_fname}.gv"))
        g.layout(prog='dot')
        g.draw(path=str(base_path / f"{base_fname}.png"))

    return g


# Writes graph.gv / graph.png, then shows a fresh dot rendering inline.
g = show(nodes_with_edges, edges, fname="graph.gv")
png_bytes = g.draw(format='png', prog='dot')
Image(png_bytes)