# Train different kinds of masks over IOI edges

In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("/data/phillip_guo/circuit-breaking/ioi/")
from models import load_gpt2_weights, load_demo_gpt2, tokenizer
from data import retrieve_toxic_data, retrieve_owt_data, retrieve_toxic_data_low_loss, retrieve_toxic_filtered_data, FILTER_DEMO_LEN, CONTEXT_LENGTH
from inference import infer_batch_with_owt, infer_batch, prepare_fixed_demo, criterion
from torch.optim import AdamW
import torch
import pickle
import datasets
from tqdm import tqdm_notebook as tqdm
from itertools import cycle
# from eval import evaluate_model
from data import batch_text_to_tokens
import plotly.express as px

Using device: cuda:0


## Train params of mask
Train without the original D_train loss term (only mask loss and IOI data loss)
Finds necessary (but not sufficient) edges

In [2]:
toxic_batch_size = 5 # so that we can just access the last sequence position without worrying about padding
owt_batch_size = 1
context_length = CONTEXT_LENGTH


template_type = "double"
toxic_data_loader = retrieve_toxic_data(toxic_batch_size, context_length, tokenizer, tokenize=False, num_points=None, template_type=template_type)
# toxic_data_loader = retrieve_toxic_filtered_data(toxic_batch_size)
owt_data_loader = retrieve_owt_data(owt_batch_size)

# with open("data/gpt2_means.pkl", "rb") as f:
#     means = pickle.load(f)[0][0]
means_ioi = True
if means_ioi:
    with open("data/gpt2_ioi_abc_means.pkl", "rb") as f:
        means = pickle.load(f)[0]
else:
    with open("data/gpt2_means.pkl", "rb") as f:
        means = pickle.load(f)[0]

model = load_demo_gpt2(means=means)
epochs_left = 200
log_every = 10
lr = .05 # free
weight_decay = 0
clamp_every = 20 # 5 # free
threshold = 0.5
epochs_trained = 0
regularization_strength = 1 # free

mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)
optimizer = AdamW(mask_params, lr=lr, weight_decay=weight_decay)

losses = []
num_ablated_edges = []
alpha = 1 # free
batch_size = toxic_batch_size + owt_batch_size
demos = prepare_fixed_demo(tokenizer, batch_size, demo="")
owt_iter = cycle(owt_data_loader)
edge_threshold = 100
max_steps_per_epoch = 20


In [3]:
old_mask_params = {}
def duplicate_mask_params(mask_params):
    new_mask_params = []
    for p in mask_params:
        new_mask_params.append(p.data.cpu())
    return new_mask_params

prev_params = None
while epochs_left >= 0:
    for e in tqdm(range(epochs_left)):
        for c, batch in enumerate(toxic_data_loader):
            if c > max_steps_per_epoch:
                break

            # print(batch["text"])
            total_preserving = 0
            ablated_edges = 0
            penalty = 0
            for p in mask_params:
                total_preserving += p.sum()
                ablated_edges += p[p.data < 0.5].shape[0]
                penalty += max(0, p.sum() * (epochs_trained-20) / 10000) # why 2000? free

            # demos = batch[:, :FILTER_DEMO_LEN]
            # completions = batch[:, FILTER_DEMO_LEN:]

            # tox_loss = infer_batch(model, criterion, completions, toxic_batch_size, demos)
            # owt_loss = infer_batch(model, criterion, next(owt_iter)['tokens'], owt_batch_size, fixed_demos)
            tox_loss, owt_loss = infer_batch_with_owt(model, criterion, batch, next(owt_iter), batch_size, demos, access_toxic_pos=-1)
            # print(f"{tox_loss=}, {owt_loss=}")
            loss = -1 * (regularization_strength * penalty + alpha * tox_loss) #+ owt_loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            num_ablated_edges.append(ablated_edges)
            for p in mask_params:
                p.data.clamp_(0,1)
        print(f"{loss.item()=}, {ablated_edges=}")
        epochs_trained += 1
        if epochs_trained % clamp_every == 0:
            ablated_edges = 0
            for p in mask_params:
                p.data[p.data < threshold] = 0
                p.data[p.data >= threshold] = 1
                ablated_edges += p[p.data < 0.5].shape[0]
        if epochs_trained % log_every == 0:
            print("Epochs trained: ", epochs_trained)
            print(f"Loss: {loss.item():.4f}")
            print(f"Total preserved: {total_preserving:.4f}")
            print("Edges ablated: ", ablated_edges)
            print("Toxic loss: ", tox_loss.item())
            print("OWT loss: ", owt_loss.item())
            print("Penalty: ", penalty)
            

            with torch.no_grad():
                test_ioi_sentence = "While Alicia and Joshua were commuting to the restaurant, Joshua gave a snack to"
                correct_token_id = tokenizer.encode(" Alicia", return_tensors="pt").squeeze().item()
                other_token_id = tokenizer.encode(" Joshua", return_tensors="pt").squeeze().item()
                test_ioi_tokens = tokenizer.encode(test_ioi_sentence, return_tensors="pt").to('cuda')
                generation = model(test_ioi_tokens)[0][:, -1]
                probs = torch.softmax(generation, dim=-1)
                print(f"Best Token: {tokenizer.batch_decode(torch.argmax(generation, dim=-1))}, P(Alicia) = {probs[:,correct_token_id].item()}, logit diff = {generation[:,correct_token_id].item() - generation[:,other_token_id].item()}")
            # if input('evaluate? (y)') == 'y':
            #     evaluate_model(model, toxic_batches=1, owt_batches=1)
            print("\n")
            old_mask_params[epochs_trained] = duplicate_mask_params(mask_params)
                
        if epochs_trained > 50 and ablated_edges < edge_threshold:
            break
        prev_params = mask_params
    # epochs_left = int(input('continue training for this number of epochs: '))
    # log_every = int(input('set log frequency'))
    # edge_threshold = int(input('set edge threshold'))
    epochs_left = -1

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for e in tqdm(range(epochs_left)):


  0%|          | 0/200 [00:00<?, ?it/s]

loss.item()=-69.30683898925781, ablated_edges=1211
loss.item()=-93.07505798339844, ablated_edges=2603
loss.item()=-98.3801498413086, ablated_edges=3031
loss.item()=-108.5616683959961, ablated_edges=3355
loss.item()=-112.21004486083984, ablated_edges=3576
loss.item()=-113.6036148071289, ablated_edges=3725
loss.item()=-114.2625732421875, ablated_edges=3839
loss.item()=-119.2594223022461, ablated_edges=3966
loss.item()=-117.36256408691406, ablated_edges=4071
loss.item()=-116.73172760009766, ablated_edges=4141
Epochs trained:  10
Loss: -116.7317
Total preserved: 7108.1475
Edges ablated:  4141
Toxic loss:  116.73172760009766
OWT loss:  10.723258018493652
Penalty:  0
Best Token: [' do'], P(Alicia) = 2.0229411421718976e-32, logit diff = 20.8072509765625


loss.item()=-119.75999450683594, ablated_edges=4207
loss.item()=-117.84992980957031, ablated_edges=4308
loss.item()=-115.53193664550781, ablated_edges=4387
loss.item()=-116.5049819946289, ablated_edges=4491
loss.item()=-119.66324615478516, a

In [4]:
with open(f"models/alternative_necessary_masks_params_dict_lambda={regularization_strength}_{means_ioi=}_{template_type=}.pkl", "wb") as f:
    pickle.dump(old_mask_params, f)

## Different alternative: sufficient but not necessary
Trains with an inverted loss function. This loss function encourages sparsity (as opposed to discouraging) and wants model to ablate everything but the necessary circuit.

In [2]:
toxic_batch_size = 5 # so that we can just access the last sequence position without worrying about padding
owt_batch_size = 1
context_length = CONTEXT_LENGTH


template_type = "double"
toxic_data_loader = retrieve_toxic_data(toxic_batch_size, context_length, tokenizer, tokenize=False, num_points=None, template_type=template_type)
# toxic_data_loader = retrieve_toxic_filtered_data(toxic_batch_size)
owt_data_loader = retrieve_owt_data(owt_batch_size)

# with open("data/gpt2_means.pkl", "rb") as f:
#     means = pickle.load(f)[0][0]
means_ioi = True
if means_ioi:
    with open("data/gpt2_ioi_abc_means.pkl", "rb") as f:
        means = pickle.load(f)[0]
else:
    with open("data/gpt2_means.pkl", "rb") as f:
        means = pickle.load(f)[0]

model = load_demo_gpt2(means=means)
epochs_left = 200
log_every = 10
lr = .05 # free
weight_decay = 0
clamp_every = 20 # 5 # free
threshold = 0.5
epochs_trained = 0
regularization_strength = 1 # free

mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)
optimizer = AdamW(mask_params, lr=lr, weight_decay=weight_decay)

losses = []
num_ablated_edges = []
alpha = 1 # free
batch_size = toxic_batch_size + owt_batch_size
demos = prepare_fixed_demo(tokenizer, batch_size, demo="")
owt_iter = cycle(owt_data_loader)
edge_threshold = 100
max_steps_per_epoch = 20


In [3]:
old_mask_params = {}
def duplicate_mask_params(mask_params):
    new_mask_params = []
    for p in mask_params:
        new_mask_params.append(p.data.cpu())
    return new_mask_params

prev_params = None
while epochs_left >= 0:
    for e in tqdm(range(epochs_left)):
        for c, batch in enumerate(toxic_data_loader):
            if c > max_steps_per_epoch:
                break

            # print(batch["text"])
            total_preserving = 0
            ablated_edges = 0
            penalty = 0
            for p in mask_params:
                total_preserving += p.sum()
                ablated_edges += p[p.data < 0.5].shape[0]
                penalty += max(0, p.sum() * (epochs_trained-20) / 10000) # why 2000? free

            # demos = batch[:, :FILTER_DEMO_LEN]
            # completions = batch[:, FILTER_DEMO_LEN:]

            # tox_loss = infer_batch(model, criterion, completions, toxic_batch_size, demos)
            # owt_loss = infer_batch(model, criterion, next(owt_iter)['tokens'], owt_batch_size, fixed_demos)
            tox_loss, owt_loss = infer_batch_with_owt(model, criterion, batch, next(owt_iter), batch_size, demos, access_toxic_pos=-1)
            # print(f"{tox_loss=}, {owt_loss=}")
            loss = (regularization_strength * penalty + alpha * tox_loss) #+ owt_loss
            # loss = alpha * tox_loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            num_ablated_edges.append(ablated_edges)
            for p in mask_params:
                p.data.clamp_(0,1)
        print(f"{loss.item()=}, {ablated_edges=}")
        epochs_trained += 1
        if epochs_trained % clamp_every == 0:
            ablated_edges = 0
            for p in mask_params:
                p.data[p.data < threshold] = 0
                p.data[p.data >= threshold] = 1
                ablated_edges += p[p.data < 0.5].shape[0]
        if epochs_trained % log_every == 0:
            print("Epochs trained: ", epochs_trained)
            print(f"Loss: {loss.item():.4f}")
            print(f"Total preserved: {total_preserving:.4f}")
            print("Edges ablated: ", ablated_edges)
            print("Toxic loss: ", tox_loss.item())
            print("OWT loss: ", owt_loss.item())
            print("Penalty: ", penalty)
            

            with torch.no_grad():
                test_ioi_sentence = "While Alicia and Joshua were commuting to the restaurant, Joshua gave a snack to"
                correct_token_id = tokenizer.encode(" Alicia", return_tensors="pt").squeeze().item()
                other_token_id = tokenizer.encode(" Joshua", return_tensors="pt").squeeze().item()
                test_ioi_tokens = tokenizer.encode(test_ioi_sentence, return_tensors="pt").to('cuda')
                generation = model(test_ioi_tokens)[0][:, -1]
                probs = torch.softmax(generation, dim=-1)
                print(f"Best Token: {tokenizer.batch_decode(torch.argmax(generation, dim=-1))}, P(Alicia) = {probs[:,correct_token_id].item()}, logit diff = {generation[:,correct_token_id].item() - generation[:,other_token_id].item()}")
            # if input('evaluate? (y)') == 'y':
            #     evaluate_model(model, toxic_batches=1, owt_batches=1)
            print("\n")
            old_mask_params[epochs_trained] = duplicate_mask_params(mask_params)
                
        if epochs_trained > 50 and ablated_edges < edge_threshold:
            break
        prev_params = mask_params
    # epochs_left = int(input('continue training for this number of epochs: '))
    # log_every = int(input('set log frequency'))
    # edge_threshold = int(input('set edge threshold'))
    epochs_left = -1

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for e in tqdm(range(epochs_left)):


  0%|          | 0/200 [00:00<?, ?it/s]

loss.item()=0.00015079081640578806, ablated_edges=0
loss.item()=8.201432137866504e-06, ablated_edges=222
loss.item()=3.583267243811861e-05, ablated_edges=278
loss.item()=6.008088348607998e-06, ablated_edges=292
loss.item()=2.407883948762901e-05, ablated_edges=294
loss.item()=5.1661336328834295e-05, ablated_edges=297
loss.item()=2.4484959794790484e-05, ablated_edges=299
loss.item()=0.0, ablated_edges=303
loss.item()=3.075577069466817e-06, ablated_edges=302
loss.item()=1.3589813079306623e-06, ablated_edges=303
Epochs trained:  10
Loss: 0.0000
Total preserved: 10087.0703
Edges ablated:  303
Toxic loss:  1.3589813079306623e-06
OWT loss:  4.985528945922852
Penalty:  0
Best Token: [' Alicia'], P(Alicia) = 1.0, logit diff = 33.28448486328125


loss.item()=2.1290243239491247e-05, ablated_edges=305
loss.item()=4.506091499933973e-06, ablated_edges=311
loss.item()=2.336495072086109e-06, ablated_edges=315
loss.item()=7.009450200712308e-06, ablated_edges=319
loss.item()=6.437292654482007e-07, ablat

In [10]:
with open(f"models/alternative_sufficient_masks_params_dict_lambda={regularization_strength}_{alpha=}_{means_ioi=}_{template_type=}.pkl", "wb") as f:
    pickle.dump(old_mask_params, f)

## Train mask over known circuit
Train mask over the circuit from the paper, as given by a run of ACDC++.

In [2]:
from mask_utils import get_nodes_and_edges
with open("models/acdcpp_mask_params.pkl", "rb") as f:
    acdc_mask_params = pickle.load(f)

_, _, acdc_Edges, acdc_mask_dict = get_nodes_and_edges(mask_params=acdc_mask_params)
acdc_mask_dict

{'embed': tensor([]),
 'a0.0': tensor([1.]),
 'a0.1': tensor([0.]),
 'a0.2': tensor([1.]),
 'a0.3': tensor([0.]),
 'a0.4': tensor([1.]),
 'a0.5': tensor([0.]),
 'a0.6': tensor([1.]),
 'a0.7': tensor([1.]),
 'a0.8': tensor([1.]),
 'a0.9': tensor([1.]),
 'a0.10': tensor([0.]),
 'a0.11': tensor([1.]),
 'm0': tensor([0., 1., 0., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1.]),
 'a1.0': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'a1.1': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'a1.2': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'a1.3': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'a1.4': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'a1.5': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'a1.6': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'a1.7': tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]),
 'a1.8': tensor([1., 1., 1., 1.

In [3]:
acdc_mask_dict['output'].shape

torch.Size([157])

In [10]:
toxic_batch_size = 1 # so that we can just access the last sequence position without worrying about padding
owt_batch_size = 5
context_length = CONTEXT_LENGTH

toxic_data_loader = retrieve_toxic_data(toxic_batch_size, context_length, tokenizer, tokenize=False, num_points=None)
# toxic_data_loader = retrieve_toxic_filtered_data(toxic_batch_size)
owt_data_loader = retrieve_owt_data(owt_batch_size)

with open("data/gpt2_means.pkl", "rb") as f:
    means = pickle.load(f)[0][0]

model = load_demo_gpt2(means=means, mask_dict_superset=acdc_mask_dict)
epochs_left = 200
log_every = 10
lr = .05 # free
weight_decay = 0
clamp_every = 20 # 5 # free
threshold = 0.5
epochs_trained = 0
regularization_strength = 1 # free

mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)
optimizer = AdamW(mask_params, lr=lr, weight_decay=weight_decay)

losses = []
num_ablated_edges = []
alpha = 0.2 # free
batch_size = toxic_batch_size + owt_batch_size
demos = prepare_fixed_demo(tokenizer, batch_size, demo="")
owt_iter = cycle(owt_data_loader)
edge_threshold = 0
max_steps_per_epoch = 100

old_mask_params = {}
def duplicate_mask_params(mask_params):
    new_mask_params = []
    for p in mask_params:
        new_mask_params.append(p.data.cpu())
    return new_mask_params

prev_params = None
while epochs_left >= 0:
    for e in tqdm(range(epochs_left)):
        for c, batch in enumerate(toxic_data_loader):
            if c > max_steps_per_epoch:
                break

            # print(batch["text"])
            total_preserving = 0
            ablated_edges = 0
            penalty = 0
            for p in mask_params:
                total_preserving += p.sum()
                ablated_edges += p[p.data < 0.5].shape[0]
                penalty += max(0, p.sum() * (epochs_trained-20) / 10000) # why 2000? free
            # print(f"{ablated_edges=}")
            # demos = batch[:, :FILTER_DEMO_LEN]
            # completions = batch[:, FILTER_DEMO_LEN:]

            # tox_loss = infer_batch(model, criterion, completions, toxic_batch_size, demos)
            # owt_loss = infer_batch(model, criterion, next(owt_iter)['tokens'], owt_batch_size, fixed_demos)
            tox_loss, owt_loss = infer_batch_with_owt(model, criterion, batch, next(owt_iter), batch_size, demos, access_toxic_pos=-1)
            # print(f"{tox_loss=}, {owt_loss=}")
            loss = -1 * (regularization_strength * penalty + alpha * tox_loss) + owt_loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            num_ablated_edges.append(ablated_edges)
            for p in mask_params:
                p.data.clamp_(0,1)
        print(f"{loss.item()=}, {ablated_edges=}")
        epochs_trained += 1
        if epochs_trained % clamp_every == 0:
            ablated_edges = 0
            for p in mask_params:
                p.data[p.data < threshold] = 0
                p.data[p.data >= threshold] = 1
                ablated_edges += p[p.data < 0.5].shape[0]
        if epochs_trained % log_every == 0:
            print("Epochs trained: ", epochs_trained)
            print(f"Loss: {loss.item():.4f}")
            print(f"Total preserved: {total_preserving:.4f}")
            print("Edges ablated: ", ablated_edges)
            print("Toxic loss: ", tox_loss.item())
            print("OWT loss: ", owt_loss.item())
            print("Penalty: ", penalty)
            # if input('evaluate? (y)') == 'y':
            #     evaluate_model(model, toxic_batches=1, owt_batches=1)
            print("\n")
            old_mask_params[epochs_trained] = duplicate_mask_params(mask_params)
                
        if epochs_trained > 50 and ablated_edges < edge_threshold:
            break
        prev_params = mask_params
    epochs_left = int(input('continue training for this number of epochs: '))
    log_every = int(input('set log frequency'))
    edge_threshold = int(input('set edge threshold'))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for e in tqdm(range(epochs_left)):


  0%|          | 0/200 [00:00<?, ?it/s]

loss.item()=1.5056736469268799, ablated_edges=80
loss.item()=-0.500084400177002, ablated_edges=87
loss.item()=-0.0825948715209961, ablated_edges=87
loss.item()=0.15455102920532227, ablated_edges=88
loss.item()=0.7475423812866211, ablated_edges=95
loss.item()=0.7848529815673828, ablated_edges=94
loss.item()=0.06402826309204102, ablated_edges=98
loss.item()=-0.9390511512756348, ablated_edges=96
loss.item()=0.6501531600952148, ablated_edges=98
loss.item()=0.626798152923584, ablated_edges=98
Epochs trained:  10
Loss: 0.6268
Total preserved: 11512.8232
Edges ablated:  98
Toxic loss:  23.036602020263672
OWT loss:  5.234118461608887
Penalty:  0


loss.item()=-3.0186614990234375, ablated_edges=103
loss.item()=-0.8208189010620117, ablated_edges=93
loss.item()=0.038724422454833984, ablated_edges=94
loss.item()=-0.08931589126586914, ablated_edges=95
loss.item()=-0.5814499855041504, ablated_edges=98
loss.item()=-0.1508636474609375, ablated_edges=102
loss.item()=2.3883941173553467, ablated_edges=98

ValueError: invalid literal for int() with base 10: ''

In [11]:
with open("models/alternative_acdc_cb_subset_mask_params_dict.pkl", "wb") as f:
    pickle.dump(old_mask_params, f)

## Test model before and after circuit breaking

In [None]:
import pickle
with open("data/ioi_sentences_test.pkl", "rb") as f:
    ioi_sentences_test = pickle.load(f)
    # ioi_sentences_test = [t[2] for t in ioi_sentences_test]

with open("data/eval_uniform.pkl", "rb") as f:
    uniform_samples = pickle.load(f)
    uniform_sentences = [t[2] for t in uniform_samples]

original_model = load_demo_gpt2(means=False)

# with open("models/masked_gpt2_mean_ablation_v6.pkl", "rb") as f:
#     model.state_dict = pickle.load(f)

In [None]:
# Run inference on an ioi_sentence
ioi_sentence = ioi_sentences_test[0]
print(ioi_sentence)
# ioi_tokens = tokenizer(ioi_sentence, return_tensors='pt').input_ids.to('cuda')

original_model.eval()
original_model.to('cuda')
def get_last_token(model, prompt, topk=5):
    # generate last token
    tokens = tokenizer(prompt, return_tensors='pt').input_ids[:, :-1]

    # generate one token, decode original_model(ioi_tokens[:, :-1])
    model_outputs = model(tokens)[0]
    model_outputs = model_outputs.squeeze(0)[-1]
    probs = torch.nn.functional.softmax(model_outputs, dim=-1)

    topk_outputs = torch.topk(model_outputs, topk)
    topk_tokens = topk_outputs.indices
    topk_probs = probs[topk_outputs.indices]
    
    # decode tokens
    for i in range(topk):
        print(f"{tokenizer.decode(topk_tokens[i].unsqueeze(0))}, probability of {topk_probs[i]}")
    topk_tokens_decoded = tokenizer.batch_decode(topk_tokens)
    return topk_tokens_decoded, topk_probs

print("Before ablation")
_ = get_last_token(original_model, ioi_sentence)
print()
print()
print("After ablation")
_ = get_last_token(model, ioi_sentence)

In [None]:
# Try on uniform samples
for idx in range(3):
    print(uniform_sentences[idx])
    print("Before ablation")
    _ = get_last_token(original_model, uniform_sentences[idx])
    print()
    print("After ablation")
    _ = get_last_token(model, uniform_sentences[idx])
    print("\n\n")

## Visualize mask
Create the computational graphs in edge attribution patching paper

In [None]:
# calculate which nodes will be in the graph
connected_nodes = set()
# add embed node at position
# connected_nodes.add((-1, "embed"))
n_heads = 12
n_layers = 12

# associate each node with a position
all_possible_nodes = [(-1, "embed")]
mask_dict = {}
# empty tensor
mask_dict["embed"] = torch.zeros(size=(0,))
for idx in range(len(mask_params)):
    if "attention" in param_names[idx]:
        layer = int(param_names[idx].split(".")[1])
        for i in range(n_heads):
            all_possible_nodes.append((layer, f"a{layer}.{i}"))
            mask_dict[f"a{layer}.{i}"] = mask_params[idx][:,i].detach().cpu()
    elif "mlp" in param_names[idx]:
        layer = int(param_names[idx].split(".")[1])
        all_possible_nodes.append((layer, f"m{layer}"))
        mask_dict[f"m{layer}"] = mask_params[idx].detach().cpu()
all_possible_nodes.append((n_heads, "output"))
mask_dict["output"] = mask_params[-1]

In [None]:
# Calculate where edges are based on the mask
# Edge between node i and node j if mask_dict[i][all_possible_nodes.index(j)] == 0
sufficient = True

edges = []
for i in range(len(all_possible_nodes)):
    for j in range(len(all_possible_nodes)):
        j_index = all_possible_nodes.index(all_possible_nodes[j])
        if j_index < len(mask_dict[all_possible_nodes[i][1]]) and mask_dict[all_possible_nodes[i][1]][all_possible_nodes.index(all_possible_nodes[j])] == (1 if sufficient else 0):
            edges.append((all_possible_nodes[i], all_possible_nodes[j]))

In [None]:
len(edges)

In [None]:
def create_aligned_graph(all_possible_nodes, edges):
    G = pgv.AGraph(strict=False, directed=True)

    # Find the maximum layer number for adjusting the graph
    max_layer = max(layer for layer, _ in all_possible_nodes if isinstance(layer, int))
    nodes_with_edges = set([node for edge in edges for node in edge])
    print(nodes_with_edges)
    # Add nodes and edges to the graph
    # for node in all_possible_nodes:
    #     if node in [edge[0] for edge in edges] or node in [edge[1] for edge in edges]:
    #         G.add_node(node[1], layer=str(max_layer - node[0]))

    for edge in edges:
        G.add_edge(edge[1][1], edge[0][1])

    # Create subgraphs to ensure nodes of the same layer have the same rank
    for layer in range(max_layer, -2, -1):
        with G.subgraph(name=f'cluster_{layer}') as s:
            s.graph_attr['rank'] = 'same'
            for node in nodes_with_edges:
                if node[0] == layer:
                    s.add_node(node[1])

    # Apply layout and render the graph
    G.layout(prog='dot')
    G.draw('aligned_graph.png')
    return Image('aligned_graph.png')

# Call the function with your nodes and edges
flipped_graph_image = create_aligned_graph(all_possible_nodes, edges)

# To display the graph in Jupyter Notebook
flipped_graph_image
