# Train mask over IOI edges and analyze mask vs known circuit

In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("/data/phillip_guo/circuit-breaking/ioi/")
from models import load_gpt2_weights, load_demo_gpt2, tokenizer
from data import retrieve_toxic_data, retrieve_owt_data, retrieve_toxic_data_low_loss, retrieve_toxic_filtered_data, FILTER_DEMO_LEN, CONTEXT_LENGTH
from inference import infer_batch_with_owt, infer_batch, prepare_fixed_demo, criterion
from torch.optim import AdamW
import torch
import pickle
import datasets
from tqdm import tqdm_notebook as tqdm
from itertools import cycle
# from eval import evaluate_model
from data import batch_text_to_tokens
import plotly.express as px

Using device: cuda:0


## Train params of mask

In [11]:
toxic_batch_size = 10 # so that we can just access the last sequence position without worrying about padding
owt_batch_size = 10
context_length = CONTEXT_LENGTH


template_type = "single"
toxic_data_loader = retrieve_toxic_data(toxic_batch_size, context_length, tokenizer, tokenize=False, num_points=None, template_type=template_type)
# toxic_data_loader = retrieve_toxic_filtered_data(toxic_batch_size)
owt_data_loader = retrieve_owt_data(owt_batch_size)

# with open("data/gpt2_means.pkl", "rb") as f:
#     means = pickle.load(f)[0][0]
means_ioi = True
if means_ioi:
    with open("data/gpt2_ioi_abc_means.pkl", "rb") as f:
        means = pickle.load(f)[0]
else:
    with open("data/gpt2_means.pkl", "rb") as f:
        means = pickle.load(f)[0]

model = load_demo_gpt2(means=means)
epochs_left = 200
log_every = 10
lr = .05 # free
weight_decay = 0
clamp_every = 50 # 5 # free
threshold = 0.5
epochs_trained = 0
regularization_strength = 1 # free

mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)
optimizer = AdamW(mask_params, lr=lr, weight_decay=weight_decay)

losses = []
num_ablated_edges = []
alpha = 0.2 # free
batch_size = toxic_batch_size + owt_batch_size
demos = prepare_fixed_demo(tokenizer, batch_size, demo="")
owt_iter = cycle(owt_data_loader)
edge_threshold = 100
max_steps_per_epoch = 100


In [12]:
old_mask_params = {}
def duplicate_mask_params(mask_params):
    new_mask_params = []
    for p in mask_params:
        new_mask_params.append(p.data.cpu())
    return new_mask_params

prev_params = None
while epochs_left >= 0:
    for e in tqdm(range(epochs_left)):
        for c, batch in enumerate(toxic_data_loader):
            if c > max_steps_per_epoch:
                break

            # print(batch["text"])
            total_preserving = 0
            ablated_edges = 0
            penalty = 0
            for p in mask_params:
                total_preserving += p.sum()
                ablated_edges += p[p.data < 0.5].shape[0]
                penalty += max(0, p.sum() * (epochs_trained-20) / 10000) # why 2000? free

            # demos = batch[:, :FILTER_DEMO_LEN]
            # completions = batch[:, FILTER_DEMO_LEN:]

            # tox_loss = infer_batch(model, criterion, completions, toxic_batch_size, demos)
            # owt_loss = infer_batch(model, criterion, next(owt_iter)['tokens'], owt_batch_size, fixed_demos)
            tox_loss, owt_loss = infer_batch_with_owt(model, criterion, batch, next(owt_iter), batch_size, demos, access_toxic_pos=-1)
            # print(f"{tox_loss=}, {owt_loss=}")
            loss = -1 * (regularization_strength * penalty + alpha * tox_loss) + owt_loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            num_ablated_edges.append(ablated_edges)
            for p in mask_params:
                p.data.clamp_(0,1)
        print(f"{loss.item()=}, {ablated_edges=}")
        epochs_trained += 1
        if epochs_trained % clamp_every == 0:
            ablated_edges = 0
            for p in mask_params:
                p.data[p.data < threshold] = 0
                p.data[p.data >= threshold] = 1
                ablated_edges += p[p.data < 0.5].shape[0]
        if epochs_trained % log_every == 0:
            print("Epochs trained: ", epochs_trained)
            print(f"Loss: {loss.item():.4f}")
            print(f"Total preserved: {total_preserving:.4f}")
            print("Edges ablated: ", ablated_edges)
            print("Toxic loss: ", tox_loss.item())
            print("OWT loss: ", owt_loss.item())
            print("Penalty: ", penalty)
            # if input('evaluate? (y)') == 'y':
            #     evaluate_model(model, toxic_batches=1, owt_batches=1)
            with torch.no_grad():
                test_ioi_sentences = ["While Alicia and Joshua were commuting to the restaurant, Joshua gave a snack to", "While Joshua and Alicia were commuting to the restaurant, Joshua gave a snack to"]
                for test_ioi_sentence in test_ioi_sentences:
                    correct_token_id = tokenizer.encode(" Alicia", return_tensors="pt").squeeze().item()
                    other_token_id = tokenizer.encode(" Joshua", return_tensors="pt").squeeze().item()
                    test_ioi_tokens = tokenizer.encode(test_ioi_sentence, return_tensors="pt").to('cuda')
                    generation = model(test_ioi_tokens)[0][:, -1]
                    probs = torch.softmax(generation, dim=-1)
                    print(f"Best Token: {tokenizer.batch_decode(torch.argmax(generation, dim=-1))}, P(Alicia) = {probs[:,correct_token_id].item()}, logit diff = {generation[:,correct_token_id].item() - generation[:,other_token_id].item()}")
            print("\n")

            old_mask_params[epochs_trained] = duplicate_mask_params(mask_params)
                
        if epochs_trained > 50 and ablated_edges < edge_threshold:
            break
        prev_params = mask_params
    # epochs_left = int(input('continue training for this number of epochs: '))
    epochs_left = -1
    # log_every = int(input('set log frequency'))
    # edge_threshold = int(input('set edge threshold'))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for e in tqdm(range(epochs_left)):


  0%|          | 0/200 [00:00<?, ?it/s]

loss.item()=-10.43911361694336, ablated_edges=1277
loss.item()=-14.101018905639648, ablated_edges=2291
loss.item()=-15.741912841796875, ablated_edges=2726
loss.item()=-16.144386291503906, ablated_edges=3038
loss.item()=-17.014699935913086, ablated_edges=3236
loss.item()=-17.348852157592773, ablated_edges=3325
loss.item()=-17.951414108276367, ablated_edges=3421
loss.item()=-17.839527130126953, ablated_edges=3549
loss.item()=-18.50194549560547, ablated_edges=3616
loss.item()=-18.62908172607422, ablated_edges=3663
Epochs trained:  10
Loss: -18.6291
Total preserved: 7561.0815
Edges ablated:  3663
Toxic loss:  114.5287094116211
OWT loss:  4.276660919189453
Penalty:  0
Best Token: [' make'], P(Alicia) = 5.7655877753859386e-05, logit diff = -0.381317138671875
Best Token: [' make'], P(Alicia) = 1.7017458958434872e-05, logit diff = -0.7539596557617188


loss.item()=-18.398622512817383, ablated_edges=3712
loss.item()=-18.490976333618164, ablated_edges=3759
loss.item()=-18.431020736694336, ablate

In [None]:
with open(f"models/params_dict_lambda={regularization_strength}_{alpha=}_{means_ioi=}_{template_type=}.pkl", "wb") as f:
    pickle.dump(old_mask_params, f)

## Try circuit breaking over IOI-specific edges

In [2]:
from mask_utils import get_nodes_and_edges
with open("models/circuit_covering_mask_params.pkl", "rb") as f:
    circuit_covering_mask_params = pickle.load(f)

_, _, circuit_covering_edges, circuit_covering_mask_dict = get_nodes_and_edges(mask_params=circuit_covering_mask_params)
# circuit_covering_mask_dict # mostly 1s, 1s are frozen edges
# Circuit break only over the edges that are currently 0s in the circuit covering mask, ablate as few of them as possible


In [3]:
toxic_batch_size = 10 # so that we can just access the last sequence position without worrying about padding
owt_batch_size = 10
context_length = CONTEXT_LENGTH

template_type = "single"
toxic_data_loader = retrieve_toxic_data(toxic_batch_size, context_length, tokenizer, tokenize=False, num_points=None, template_type=template_type)
# toxic_data_loader = retrieve_toxic_filtered_data(toxic_batch_size)
owt_data_loader = retrieve_owt_data(owt_batch_size)

with open("data/gpt2_means.pkl", "rb") as f:
    means = pickle.load(f)[0]

model = load_demo_gpt2(means=means, mask_dict_superset=circuit_covering_mask_dict)
epochs_left = 200
log_every = 10
lr = .05 # free
weight_decay = 0
clamp_every = 50 # 5 # free
threshold = 0.5
epochs_trained = 0
regularization_strength = 1 # free

mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)
optimizer = AdamW(mask_params, lr=lr, weight_decay=weight_decay)

losses = []
num_ablated_edges = []
alpha = 0.2 # free
batch_size = toxic_batch_size + owt_batch_size
demos = prepare_fixed_demo(tokenizer, batch_size, demo="")
owt_iter = cycle(owt_data_loader)
edge_threshold = 0 
max_steps_per_epoch = 50

In [4]:
old_mask_params = {}
def duplicate_mask_params(mask_params):
    new_mask_params = []
    for p in mask_params:
        new_mask_params.append(p.data.cpu())
    return new_mask_params

prev_params = None
while epochs_left >= 0:
    for e in tqdm(range(epochs_left)):
        for c, batch in enumerate(toxic_data_loader):
            if c > max_steps_per_epoch:
                break

            # print(batch["text"])
            total_preserving = 0
            ablated_edges = 0
            penalty = 0
            for p in mask_params:
                total_preserving += p.sum()
                ablated_edges += p[p.data < 0.5].shape[0]
                penalty += max(0, p.sum() * (epochs_trained-20) / 10000) # why 2000? free

            # demos = batch[:, :FILTER_DEMO_LEN]
            # completions = batch[:, FILTER_DEMO_LEN:]

            # tox_loss = infer_batch(model, criterion, completions, toxic_batch_size, demos)
            # owt_loss = infer_batch(model, criterion, next(owt_iter)['tokens'], owt_batch_size, fixed_demos)
            tox_loss, owt_loss = infer_batch_with_owt(model, criterion, batch, next(owt_iter), batch_size, demos, access_toxic_pos=-1)
            # print(f"{tox_loss=}, {owt_loss=}")
            loss = -1 * (regularization_strength * penalty + alpha * tox_loss) + owt_loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            num_ablated_edges.append(ablated_edges)
            for p in mask_params:
                p.data.clamp_(0,1)
        print(f"{loss.item()=}, {ablated_edges=}")
        epochs_trained += 1
        if epochs_trained % clamp_every == 0:
            ablated_edges = 0
            for p in mask_params:
                p.data[p.data < threshold] = 0
                p.data[p.data >= threshold] = 1
                ablated_edges += p[p.data < 0.5].shape[0]
        if epochs_trained % log_every == 0:
            print("Epochs trained: ", epochs_trained)
            print(f"Loss: {loss.item():.4f}")
            print(f"Total preserved: {total_preserving:.4f}")
            print("Edges ablated: ", ablated_edges)
            print("Toxic loss: ", tox_loss.item())
            print("OWT loss: ", owt_loss.item())
            print("Penalty: ", penalty)
            # if input('evaluate? (y)') == 'y':
            #     evaluate_model(model, toxic_batches=1, owt_batches=1)
            with torch.no_grad():
                test_ioi_sentences = ["While Alicia and Joshua were commuting to the restaurant, Joshua gave a snack to", "While Joshua and Alicia were commuting to the restaurant, Joshua gave a snack to"]
                for test_ioi_sentence in test_ioi_sentences:
                    correct_token_id = tokenizer.encode(" Alicia", return_tensors="pt").squeeze().item()
                    other_token_id = tokenizer.encode(" Joshua", return_tensors="pt").squeeze().item()
                    test_ioi_tokens = tokenizer.encode(test_ioi_sentence, return_tensors="pt").to('cuda')
                    generation = model(test_ioi_tokens)[0][:, -1]
                    probs = torch.softmax(generation, dim=-1)
                    print(f"Best Token: {tokenizer.batch_decode(torch.argmax(generation, dim=-1))}, P(Alicia) = {probs[:,correct_token_id].item()}, logit diff = {generation[:,correct_token_id].item() - generation[:,other_token_id].item()}")
            print("\n")

            old_mask_params[epochs_trained] = duplicate_mask_params(mask_params)
                
        if epochs_trained > 50 and ablated_edges < edge_threshold:
            break
        prev_params = mask_params
    # epochs_left = int(input('continue training for this number of epochs: '))
    epochs_left = -1
    # log_every = int(input('set log frequency'))
    # edge_threshold = int(input('set edge threshold'))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for e in tqdm(range(epochs_left)):


  0%|          | 0/200 [00:00<?, ?it/s]

loss.item()=1.5367424488067627, ablated_edges=182
loss.item()=1.2619993686676025, ablated_edges=185
loss.item()=1.2599594593048096, ablated_edges=181
loss.item()=1.3758959770202637, ablated_edges=178
loss.item()=2.59861421585083, ablated_edges=176
loss.item()=1.2927062511444092, ablated_edges=174
loss.item()=2.2666947841644287, ablated_edges=174
loss.item()=1.613863468170166, ablated_edges=175
loss.item()=1.552971601486206, ablated_edges=172
loss.item()=1.3392136096954346, ablated_edges=174
Epochs trained:  10
Loss: 1.3392
Total preserved: 11434.6025
Edges ablated:  174
Toxic loss:  12.350775718688965
OWT loss:  3.809368848800659
Penalty:  0
Best Token: [' the'], P(Alicia) = 6.447839382417442e-07, logit diff = -4.6201324462890625
Best Token: [' the'], P(Alicia) = 6.147521958155266e-07, logit diff = -3.2760467529296875


loss.item()=1.8121588230133057, ablated_edges=174
loss.item()=2.0960686206817627, ablated_edges=175
loss.item()=1.0074710845947266, ablated_edges=172
loss.item()=1.9857

In [5]:
with open(f"models/circuit_covering_circuit_breaking_params_dict_lambda={regularization_strength}_{alpha=}_means_ioi=False_{template_type=}.pkl", "wb") as f:
    pickle.dump(old_mask_params, f)

## Test model before and after circuit breaking

In [None]:
import pickle
with open("data/ioi_prompts_test.pkl", "rb") as f:
    ioi_prompts_test = pickle.load(f)
    # ioi_sentences_test = [t[2] for t in ioi_sentences_test]

with open("data/eval_uniform.pkl", "rb") as f:
    uniform_samples = pickle.load(f)
    uniform_sentences = [t[2] for t in uniform_samples]

original_model = load_demo_gpt2(means=False)

with open("models/masked_gpt2_mean_ablation_v6.pkl", "rb") as f:
    model.state_dict = pickle.load(f)

In [None]:
ioi_prompts_test

In [None]:
# Run inference on an ioi_sentence
ioi_prompt = ioi_prompts_test[0]
print(ioi_prompt)

original_model.eval()
original_model.to('cuda')
def get_last_token(model, prompt, topk=5, sentence=False):
    # generate last token
    if not sentence:
        prompt_sentence = prompt['text']
    else:
        prompt_sentence = prompt

    tokens = tokenizer(prompt_sentence, return_tensors='pt').input_ids[:, :-1]

    # generate one token, decode original_model(ioi_tokens[:, :-1])
    model_outputs = model(tokens)[0]
    model_outputs = model_outputs.squeeze(0)[-1]
    probs = torch.nn.functional.softmax(model_outputs, dim=-1)

    topk_outputs = torch.topk(model_outputs, topk)
    topk_tokens = topk_outputs.indices
    topk_probs = probs[topk_outputs.indices]
    
    topk_tokens_decoded = tokenizer.batch_decode(topk_tokens)
    
    if not sentence:
        # Get logit diff by finding difference between logit of 
        io_token = tokenizer(" " + prompt['IO'], return_tensors='pt').input_ids[:, -1]
        s_token = tokenizer(" " + prompt['S'], return_tensors='pt').input_ids[:, -1]
        logit_diff = model_outputs[io_token][0] - model_outputs[s_token][0]
        return topk_tokens_decoded, topk_probs, logit_diff
    return topk_tokens_decoded, topk_probs

print(get_last_token(original_model, ioi_prompt))


for idx in range(3):
    print(uniform_sentences[idx])
    print("Before ablation")
    print(get_last_token(original_model, uniform_sentences[idx], sentence=True)[0])
    print()
    print("After ablation")
    print(get_last_token(model, uniform_sentences[idx], sentence=True)[0])
    print("\n\n")

In [None]:
# Try on uniform samples
def get_ioi_score(model, num_samples):
    ave_logit_diffs = []
    for idx in range(num_samples):
        prompt = ioi_prompts_test[idx]
        ave_logit_diffs.append(get_last_token(model, prompt)[2])
    return sum(ave_logit_diffs) / len(ave_logit_diffs)

get_ioi_score(original_model, 20)

## Visualize mask
Create the computational graphs in edge attribution patching paper

### Load mask and calculate what edges are present

In [None]:
with open("models/params_dict_lambda=2.pkl", "rb") as f:
    mask_params = pickle.load(f)
    mask_params = mask_params[200]

In [11]:
def get_nodes_and_edges(mask_params, edge_0=True):
    """
    If edge_0 is True, then edges are between nodes with mask value 0. Else, edges are between nodes with mask value 1.
    """
    # calculate which nodes will be in the graph
    connected_nodes = set()
    # add embed node at position
    # connected_nodes.add((-1, "embed"))
    n_heads = 12
    n_layers = 12

    # associate each node with a position
    all_possible_nodes = [(-1, "embed")]
    mask_dict = {}
    # empty tensor
    mask_dict["embed"] = torch.zeros(size=(0,))
    for idx in range(len(mask_params)):
        if "attention" in param_names[idx]:
            layer = int(param_names[idx].split(".")[1])
            for i in range(n_heads):
                all_possible_nodes.append((layer, f"a{layer}.{i}"))
                mask_dict[f"a{layer}.{i}"] = mask_params[idx][:,i].detach().cpu()
        elif "mlp" in param_names[idx]:
            layer = int(param_names[idx].split(".")[1])
            all_possible_nodes.append((layer, f"m{layer}"))
            mask_dict[f"m{layer}"] = mask_params[idx].detach().cpu()
    all_possible_nodes.append((n_heads, "output"))
    mask_dict["output"] = mask_params[0]

    # Calculate where edges are based on the mask
    # Edge between node i and node j if mask_dict[i][all_possible_nodes.index(j)] == 0
    edges = set()
    for i in range(len(all_possible_nodes)):
        for j in range(len(all_possible_nodes)):
            j_index = all_possible_nodes.index(all_possible_nodes[j])
            if j_index < len(mask_dict[all_possible_nodes[i][1]]) and mask_dict[all_possible_nodes[i][1]][all_possible_nodes.index(all_possible_nodes[j])] == (0 if edge_0 else 1):
                edges.add((all_possible_nodes[i], all_possible_nodes[j]))
    
    nodes_with_edges = set([node for edge in edges for node in edge])

    return all_possible_nodes, nodes_with_edges, edges, mask_dict
all_possible_nodes, nodes_with_edges, edges, mask_dict = get_nodes_and_edges(mask_params)

### Analyze ACDC and Compare
I separately used ACDC++ (EAP from "Attribution Patching Outperforms Automated Circuit Discovery" paper) to get the known circuit edges. I want to compare my various learned masks (from different losses) to the known circuit edges.

In [12]:
with open("models/acdcpp_edges_original.pkl", "rb") as f:
    acdcpp_edges_long = pickle.load(f)
acdcpp_edges_long

({0.08: {'blocks.0.attn.hook_k[:, :, 10]blocks.0.hook_k_input[:, :, 10]',
   'blocks.0.attn.hook_q[:, :, 10]blocks.0.hook_q_input[:, :, 10]',
   'blocks.0.attn.hook_q[:, :, 2]blocks.0.hook_q_input[:, :, 2]',
   'blocks.0.attn.hook_result[:, :, 10]blocks.0.attn.hook_k[:, :, 10]',
   'blocks.0.attn.hook_result[:, :, 10]blocks.0.attn.hook_q[:, :, 10]',
   'blocks.0.attn.hook_result[:, :, 10]blocks.0.attn.hook_v[:, :, 10]',
   'blocks.0.attn.hook_result[:, :, 1]blocks.0.attn.hook_k[:, :, 1]',
   'blocks.0.attn.hook_result[:, :, 1]blocks.0.attn.hook_q[:, :, 1]',
   'blocks.0.attn.hook_result[:, :, 1]blocks.0.attn.hook_v[:, :, 1]',
   'blocks.0.attn.hook_result[:, :, 3]blocks.0.attn.hook_k[:, :, 3]',
   'blocks.0.attn.hook_result[:, :, 3]blocks.0.attn.hook_q[:, :, 3]',
   'blocks.0.attn.hook_result[:, :, 3]blocks.0.attn.hook_v[:, :, 3]',
   'blocks.0.attn.hook_result[:, :, 6]blocks.0.attn.hook_k[:, :, 6]',
   'blocks.0.attn.hook_result[:, :, 6]blocks.0.attn.hook_q[:, :, 6]',
   'blocks.0.att

In [13]:
# acdcpp edges are in format 'blocks.1.attn.hook_result[:, :, 10]blocks.0.hook_mlp_in[:]', convert to format of ((1, 'a1.10'), (0, 'm0'))

def get_node_name(node_name, show_full_index=False):
    """Node name for use in pretty graphs"""

    def get_index(node_name_long):
        # Get the index by looking for number in brackets
        # e.g. blocks.1.attn.hook_result[:, :, 10] -> 10
        index = node_name_long.split("[")[-1].split("]")[0]
        index = index.split(", ")[-1]
        return int(index)

    if not show_full_index:
        name = ""
        qkv_substrings = [f"hook_{letter}" for letter in ["q", "k", "v"]]
        qkv_input_substrings = [f"hook_{letter}_input" for letter in ["q", "k", "v"]]

        # Handle embedz
        if "resid_pre" in node_name:
            assert "0" in node_name and not any([str(i) in node_name for i in range(1, 10)])
            name += "embed"
            layer = -1
            # if len(node.index.hashable_tuple) > 2:
            #     name += f"_[{node.index.hashable_tuple[2]}]"
            # return name

        elif "embed" in node_name:
            # name = "pos_embeds" if "pos" in node_name else "token_embeds"
            name = "embed"
            layer = -1

        # Handle q_input and hook_q etc
        elif any([node_name.endswith(qkv_input_substring) for qkv_input_substring in qkv_input_substrings]):
            relevant_letter = None
            for letter, qkv_substring in zip(["q", "k", "v"], qkv_substrings):
                if qkv_substring in node_name:
                    assert relevant_letter is None
                    relevant_letter = letter
            name += "a" + node_name.split(".")[1] + "." + str(get_index(node_name)) + "_" + relevant_letter
            layer = int(node_name.split(".")[1])

        # Handle attention hook_result
        elif "hook_result" in node_name or any([qkv_substring in node_name for qkv_substring in qkv_substrings]):
            name = "a" + node_name.split(".")[1] + "." + str(get_index(node_name))
            layer = int(node_name.split(".")[1])

        # Handle MLPs
        elif node_name.endswith("resid_mid"):
            raise ValueError("We removed resid_mid annotations. Call these mlp_in now.")
        elif "mlp" in node_name:
            name = "m" + node_name.split(".")[1]
            layer = int(node_name.split(".")[1])

        # Handle resid_post
        elif "resid_post" in node_name:
            name += "output"
            layer = 12

        # elif "mlp" in node_name:
        #     name += "m" + node_name.split(".")[1]
        else:
            raise ValueError(f"Unrecognized node name {node_name}")

    else:
        name = node_name
        # name = node_name + str(node.index.graphviz_index(use_actual_colon=True))

    # get layer by looking for number before first dot
    

    return layer, name

acdcpp_edges = set()
for edge in acdcpp_edges_long[0][0.08]:
    # split the edge into two nodes, e.g. blocks.1.attn.hook_result[:, :, 10]blocks.0.hook_mlp_in[:] into blocks.1.attn.hook_result[:, :, 10] and blocks.0.hook_mlp_in[:]
    node_1 = get_node_name(edge.split("]")[0]+"]", show_full_index=False)
    node_2 = get_node_name(edge.split("]")[1]+"]", show_full_index=False)
    if node_1 != node_2:
        acdcpp_edges.add((node_1, node_2))


In [26]:
acdcpp_edges

{((0, 'a0.1'), (-1, 'embed')),
 ((0, 'a0.10'), (-1, 'embed')),
 ((0, 'a0.3'), (-1, 'embed')),
 ((0, 'a0.5'), (-1, 'embed')),
 ((0, 'm0'), (-1, 'embed')),
 ((0, 'm0'), (0, 'a0.1')),
 ((0, 'm0'), (0, 'a0.10')),
 ((0, 'm0'), (0, 'a0.4')),
 ((1, 'm1'), (0, 'm0')),
 ((2, 'a2.11'), (0, 'm0')),
 ((2, 'm2'), (-1, 'embed')),
 ((2, 'm2'), (0, 'a0.10')),
 ((2, 'm2'), (0, 'm0')),
 ((2, 'm2'), (1, 'a1.11')),
 ((3, 'a3.0'), (0, 'm0')),
 ((3, 'a3.0'), (1, 'm1')),
 ((3, 'a3.0'), (2, 'm2')),
 ((3, 'm3'), (0, 'm0')),
 ((3, 'm3'), (2, 'm2')),
 ((3, 'm3'), (3, 'a3.0')),
 ((4, 'm4'), (0, 'm0')),
 ((4, 'm4'), (1, 'a1.11')),
 ((4, 'm4'), (1, 'm1')),
 ((4, 'm4'), (2, 'a2.2')),
 ((4, 'm4'), (2, 'a2.9')),
 ((4, 'm4'), (2, 'm2')),
 ((4, 'm4'), (3, 'a3.0')),
 ((4, 'm4'), (3, 'a3.10')),
 ((4, 'm4'), (3, 'a3.3')),
 ((4, 'm4'), (3, 'm3')),
 ((4, 'm4'), (4, 'a4.11')),
 ((4, 'm4'), (4, 'a4.3')),
 ((4, 'm4'), (4, 'a4.4')),
 ((4, 'm4'), (4, 'a4.7')),
 ((5, 'a5.5'), (0, 'm0')),
 ((5, 'a5.5'), (2, 'a2.2')),
 ((5, 'a5.5'),

In [20]:
# Convert edges back to weight mask
def get_mask_from_edges(edges, weight_mask_template=mask_dict, all_possible_nodes=all_possible_nodes, edge_0=True):
    new_mask_dict = {}
    for node_name in weight_mask_template:
        new_mask_dict[node_name] = torch.ones_like(weight_mask_template[node_name]) if edge_0 else torch.zeros_like(weight_mask_template[node_name])
    
    node_indices = {node_name: idx for idx, node_name in enumerate(all_possible_nodes)}
    for edge in edges:
        try:
            new_mask_dict[edge[0][1]][node_indices[edge[1]]] = 0 if edge_0 else 1
        except:
            continue
    
    return new_mask_dict

def convert_mask_dict_to_params(mask_dict):
    mask_params = []
    # first output_mask
    mask_params.append(mask_dict["output"])
    for layer in range(12):
        attn_tensors = []
        for head in range(12):
            attn_tensors.append(mask_dict[f"a{layer}.{head}"])
        mask_params.append(torch.stack(attn_tensors, dim=1))
        mask_params.append(mask_dict[f"m{layer}"])
    return mask_params
acdcpp_mask_dict = get_mask_from_edges(acdcpp_edges, edge_0=False)
acdcpp_mask_params = convert_mask_dict_to_params(acdcpp_mask_dict)

In [23]:
with open("models/acdcpp_mask_params.pkl", "wb") as f:
    pickle.dump(acdcpp_mask_params, f)

In [25]:
# Verify that edges are the same
_, _, acdcpp_edges_2, _ = get_nodes_and_edges(acdcpp_mask_params, edge_0=False)
print(len(acdcpp_edges_2))
print(len(acdcpp_edges))
print(len(acdcpp_edges_2.intersection(acdcpp_edges)))

232
232
232


### Analyze overlaps between different edges

In [None]:
with open("models/alternative_necessary_masks_params_dict_lambda=1.pkl", "rb") as f:
    alternative_necessary_mask_params = pickle.load(f)
    alternative_necessary_mask_params = alternative_necessary_mask_params[200]
with open("models/alternative_sufficient_masks_params_dict_lambda=1.pkl", "rb") as f:
    alternative_sufficient_mask_params = pickle.load(f)
    alternative_sufficient_mask_params = alternative_sufficient_mask_params[200]
_, _, alternative_necessary_edges, _ = get_nodes_and_edges(alternative_necessary_mask_params)
_, _, alternative_sufficient_edges, _ = get_nodes_and_edges(alternative_sufficient_mask_params, edge_0=False)

In [None]:
get_nodes_and_edges(alternative_sufficient_mask_params, edge_0=False)

In [None]:
print(f"{len(edges)=}, {len(acdcpp_edges)=}, {len(edges.intersection(acdcpp_edges))=}")
print(edges.intersection(acdcpp_edges))

In [None]:
# Get overlaps between all edges (regular edges, necessary, sufficient, acdcpp) (make a table with tabulate)
edges_dict = {"circuit_breaking":edges, "ioi_necessary":alternative_necessary_edges, "ioi_sufficient":alternative_sufficient_edges, "acdcpp":acdcpp_edges}
for edge_type in edges_dict:
    for second_edge_type in edges_dict:
        print(f"{edge_type} and {second_edge_type}: {len(edges_dict[edge_type].intersection(edges_dict[second_edge_type]))} edges in common")#, {edges_dict[edge_type].intersection(edges_dict[second_edge_type])}")

### Visualizations

In [None]:
def create_aligned_graph(all_possible_nodes, edges):
    G = pgv.AGraph(strict=False, directed=True)

    # Find the maximum layer number for adjusting the graph
    max_layer = max(layer for layer, _ in all_possible_nodes if isinstance(layer, int))
    nodes_with_edges = set([node for edge in edges for node in edge])

    # Add nodes and edges to the graph
    for node in all_possible_nodes:
        if node in [edge[0] for edge in edges] or node in [edge[1] for edge in edges]:
            G.add_node(node[1], layer=str(max_layer - node[0]))

    for edge in edges:
        G.add_edge(edge[1][1], edge[0][1])

    # Create subgraphs to ensure nodes of the same layer have the same rank
    for layer in range(max_layer, -2, -1):
        with G.subgraph(name=f'cluster_{layer}') as s:
            s.graph_attr['rank'] = 'same'
            for node in nodes_with_edges:
                if node[0] == layer:
                    s.add_node(node[1])

    # Apply layout and render the graph
    G.layout(prog='dot')
    G.draw('aligned_graph.png')
    return Image('aligned_graph.png')

# Call the function with your nodes and edges
flipped_graph_image = create_aligned_graph(all_possible_nodes, edges)

# To display the graph in Jupyter Notebook
flipped_graph_image


In [None]:
# intersecting edges graph
for edge_type in edges_dict:
    for second_edge_type in edges_dict:
        if edge_type == second_edge_type:
            continue
        # make a graph with just the intersecting edges, title it with the two edge types
        print(f"Intersection between {edge_type} and {second_edge_type}: {len(edges_dict[edge_type].intersection(edges_dict[second_edge_type]))} edges in common, {edges_dict[edge_type].intersection(edges_dict[second_edge_type])}")
        
        intersecting_edges_graph = create_aligned_graph(all_possible_nodes, edges_dict[edge_type].intersection(edges_dict[second_edge_type]))
        display(intersecting_edges_graph)
        # intersecting_edges_graph.render(f"intersecting_edges_graph_{edge_type}_{second_edge_type}", format="png", cleanup=True)
# intersecting_edges_graph = create_aligned_graph(all_possible_nodes, edges.intersection(acdcpp_edges))
# intersecting_edges_graph

In [None]:
import pygraphviz as pgv
from pathlib import Path
from IPython.display import Image

def show(nodes, edges, fname=None):
    g = pgv.AGraph(strict=True, directed=True)
    g.graph_attr.update(ranksep='0.1', nodesep='0.1', compound=True)
    g.node_attr.update(fixedsize='true', width='1.5', height='.5')
    
    layer_to_subgraph = {}

    # Create a subgraph for each layer
    for node in nodes:
        layer = node[0]
        if layer not in layer_to_subgraph:
            # Each layer has its own subgraph with 'rank=same' to ensure they are on the same level
            layer_to_subgraph[layer] = g.add_subgraph(name=f'cluster_{layer}', rank='same')
            
        # Here you add the node to the appropriate subgraph
        layer_to_subgraph[layer].add_node(node, label=str(node[1]))

    # Now, add the edges to the graph
    for edge in edges:
        g.add_edge(edge[0], edge[1])
    
    # If a filename is provided, write the file and optionally render to an image
    if fname:
        fpath = Path(fname)
        base_fname = fpath.stem
        base_path = fpath.parent
        base_path.mkdir(exist_ok=True, parents=True)
        
        # Write the dot file
        g.write(path=base_path / f"{base_fname}.gv")
        
        # Render to an image
        g.layout(prog='dot')
        g.draw(path=base_path / f"{base_fname}.png")
        
    return g


g = show(nodes_with_edges, edges, fname="graph.gv")
Image(g.draw(format='png', prog='dot'))