# Train different kinds of masks over IOI edges

In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("/data/phillip_guo/circuit-breaking/ioi/")
from models import load_gpt2_weights, load_demo_gpt2, tokenizer
from data import retrieve_toxic_data, retrieve_owt_data, retrieve_toxic_data_low_loss, retrieve_toxic_filtered_data, FILTER_DEMO_LEN, CONTEXT_LENGTH
from inference import infer_batch_with_owt, infer_batch, prepare_fixed_demo, criterion
from torch.optim import AdamW
import torch
import pickle
import datasets
from tqdm import tqdm_notebook as tqdm
from itertools import cycle
# from eval import evaluate_model
from data import batch_text_to_tokens
import plotly.express as px

Using device: cuda:0


## Train params of mask
Train without the original D_train loss term (only the mask loss and the IOI data loss).
This finds edges that are necessary (but not sufficient).

In [2]:
# Configuration for the "necessary edges" run: data loaders, mean activations,
# model, optimizer and loop settings. Later cells read these names as globals.
toxic_batch_size = 10 # so that we can just access the last sequence position without worrying about padding
owt_batch_size = 1
context_length = CONTEXT_LENGTH


template_type = "single"
# Despite its name, this loader supplies the IOI prompts that the training loop optimizes on.
toxic_data_loader = retrieve_toxic_data(toxic_batch_size, context_length, tokenizer, tokenize=False, num_points=None, template_type=template_type)
# toxic_data_loader = retrieve_toxic_filtered_data(toxic_batch_size)
owt_data_loader = retrieve_owt_data(owt_batch_size)

# with open("data/gpt2_means.pkl", "rb") as f:
#     means = pickle.load(f)[0][0]
# Choose which pickled means are handed to the model: the IOI "abc" file or the generic GPT-2 file.
means_ioi = True
if means_ioi:
    with open("data/gpt2_ioi_abc_means.pkl", "rb") as f:
        means = pickle.load(f)[0]
else:
    with open("data/gpt2_means.pkl", "rb") as f:
        means = pickle.load(f)[0]

model = load_demo_gpt2(means=means)
epochs_left = 200
log_every = 10  # print diagnostics and snapshot the masks every log_every epochs
lr = .05 # free
weight_decay = 0
# Every clamp_every epochs the masks are rounded to exactly 0 or 1 at `threshold`.
clamp_every = 50 # 5 # free
threshold = 0.5
epochs_trained = 0
regularization_strength = 1 # free (weight of the mask penalty in the loss)

# Only parameters that require gradients are optimized; the notebook treats these as the masks.
mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)
optimizer = AdamW(mask_params, lr=lr, weight_decay=weight_decay)

losses = []
num_ablated_edges = []
alpha = 1 # free (weight of tox_loss in the loss)
batch_size = toxic_batch_size + owt_batch_size
demos = prepare_fixed_demo(tokenizer, batch_size, demo="")
# cycle() makes the OWT loader endless so next(owt_iter) never raises StopIteration.
owt_iter = cycle(owt_data_loader)
# After epoch 50, training stops early once fewer than edge_threshold entries are below 0.5.
edge_threshold = 100
# NOTE(review): the loop breaks on `c > max_steps_per_epoch`, so each epoch runs one batch more than this.
max_steps_per_epoch = 100


In [3]:
# Print the first six batches (indices 0-5) to eyeball what the loader yields.
for c, batch in enumerate(toxic_data_loader):
    if c > 5:
        break
    print(batch)

{'text': ['Then, Thomas and Sarah went to the school. Sarah gave a bone to Thomas', 'Then, Christine and Kelly went to the station. Kelly gave a drink to Christine', 'Then, Kristen and Gregory went to the school. Gregory gave a ring to Kristen', 'Then, Danielle and Aaron went to the store. Aaron gave a necklace to Danielle', 'Then, Christine and Nathan went to the restaurant. Nathan gave a kiss to Christine', 'Then, Adam and Jessica went to the store. Jessica gave a drink to Adam', 'Then, Thomas and Stephanie went to the garden. Stephanie gave a ring to Thomas', 'Then, Rebecca and Daniel went to the house. Daniel gave a necklace to Rebecca', 'Then, Heather and Emily went to the house. Emily gave a bone to Heather', 'Then, Amber and Daniel went to the store. Daniel gave a kiss to Amber']}
{'text': ['Then, Cody and Andrew went to the school. Andrew gave a ring to Cody', 'Then, Jamie and Kelly went to the hospital. Kelly gave a snack to Jamie', 'Then, Sara and David went to the garden. Da

In [4]:
# Maps epochs_trained -> CPU snapshot of every mask parameter taken at that epoch.
old_mask_params = {}


def duplicate_mask_params(mask_params):
    """Return detached CPU copies of the given mask parameters.

    Args:
        mask_params: iterable of tensors / parameters to snapshot.

    Returns:
        A list with one CPU tensor per input, in the same order. Each tensor
        owns its own storage, so later in-place updates to the live
        parameters (optimizer steps, clamping, rounding) do not alter it.
    """
    # `.cpu()` alone is a no-op for tensors that already live on the CPU and
    # would hand back an alias of the live parameter; copy=True always copies.
    return [p.detach().to("cpu", copy=True) for p in mask_params]

# Training loop for the "necessary edges" masks. The outer `while` only exists so the
# (now commented-out) interactive prompts could extend training; as written it runs a
# single pass of `epochs_left` epochs and then sets epochs_left = -1 to exit.
prev_params = None
while epochs_left >= 0:
    for e in tqdm(range(epochs_left)):
        for c, batch in enumerate(toxic_data_loader):
            # NOTE(review): `>` lets max_steps_per_epoch + 1 batches through before breaking.
            if c > max_steps_per_epoch:
                break

            # print(batch["text"])
            # Per-step mask statistics, recomputed from the current parameter values.
            total_preserving = 0
            ablated_edges = 0
            penalty = 0
            for p in mask_params:
                total_preserving += p.sum()
                # number of mask entries currently below 0.5
                ablated_edges += p[p.data < 0.5].shape[0]
                # For non-negative mask sums this term is 0 until epochs_trained exceeds 20,
                # then grows linearly with the epoch count.
                penalty += max(0, p.sum() * (epochs_trained-20) / 10000) # free; the old note asked "why 2000?" but the divisor is 10000

            # demos = batch[:, :FILTER_DEMO_LEN]
            # completions = batch[:, FILTER_DEMO_LEN:]

            # tox_loss = infer_batch(model, criterion, completions, toxic_batch_size, demos)
            # owt_loss = infer_batch(model, criterion, next(owt_iter)['tokens'], owt_batch_size, fixed_demos)
            tox_loss, owt_loss = infer_batch_with_owt(model, criterion, batch, next(owt_iter), batch_size, demos, access_toxic_pos=-1)
            # print(f"{tox_loss=}, {owt_loss=}")
            # Negated objective: minimizing `loss` pushes the penalty and tox_loss up.
            # owt_loss is computed and logged but is not part of the objective.
            loss = -1 * (regularization_strength * penalty + alpha * tox_loss) #+ owt_loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            num_ablated_edges.append(ablated_edges)
            # Keep every mask entry inside [0, 1] after the optimizer update.
            for p in mask_params:
                p.data.clamp_(0,1)
        # Reports the values from the last training step of this epoch.
        print(f"{loss.item()=}, {ablated_edges=}")
        epochs_trained += 1
        # Periodically round the masks to hard 0/1 values and recount the entries below 0.5.
        if epochs_trained % clamp_every == 0:
            ablated_edges = 0
            for p in mask_params:
                p.data[p.data < threshold] = 0
                p.data[p.data >= threshold] = 1
                ablated_edges += p[p.data < 0.5].shape[0]
        if epochs_trained % log_every == 0:
            print("Epochs trained: ", epochs_trained)
            print(f"Loss: {loss.item():.4f}")
            # total_preserving comes from the last training step; it is not recomputed after rounding.
            print(f"Total preserved: {total_preserving:.4f}")
            print("Edges ablated: ", ablated_edges)
            print("Toxic loss: ", tox_loss.item())
            print("OWT loss: ", owt_loss.item())
            print("Penalty: ", penalty)


            # Sanity check on two IOI prompts that differ only in the order of the names;
            # " Alicia" is scored as the correct completion for both.
            with torch.no_grad():
                test_ioi_sentences = ["While Alicia and Joshua were commuting to the restaurant, Joshua gave a snack to", "While Joshua and Alicia were commuting to the restaurant, Joshua gave a snack to"]
                for test_ioi_sentence in test_ioi_sentences:
                    # .squeeze().item() only works if each name encodes to exactly one token.
                    correct_token_id = tokenizer.encode(" Alicia", return_tensors="pt").squeeze().item()
                    other_token_id = tokenizer.encode(" Joshua", return_tensors="pt").squeeze().item()
                    test_ioi_tokens = tokenizer.encode(test_ioi_sentence, return_tensors="pt").to('cuda')
                    # First element of the model output at the final position — presumably the next-token logits.
                    generation = model(test_ioi_tokens)[0][:, -1]
                    probs = torch.softmax(generation, dim=-1)
                    print(f"Best Token: {tokenizer.batch_decode(torch.argmax(generation, dim=-1))}, P(Alicia) = {probs[:,correct_token_id].item()}, logit diff = {generation[:,correct_token_id].item() - generation[:,other_token_id].item()}")
            # if input('evaluate? (y)') == 'y':
            #     evaluate_model(model, toxic_batches=1, owt_batches=1)
            print("\n")
            # Snapshot the masks at this epoch so they can be pickled after training.
            old_mask_params[epochs_trained] = duplicate_mask_params(mask_params)

        # Early stop (leaves the epoch loop) once few enough entries are below 0.5.
        if epochs_trained > 50 and ablated_edges < edge_threshold:
            break
        # NOTE(review): this binds the same list object (no copy) and is not read in the visible cells.
        prev_params = mask_params
    # epochs_left = int(input('continue training for this number of epochs: '))
    # log_every = int(input('set log frequency'))
    # edge_threshold = int(input('set edge threshold'))
    # Terminates the outer while loop.
    epochs_left = -1

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for e in tqdm(range(epochs_left)):


  0%|          | 0/200 [00:00<?, ?it/s]

loss.item()=-125.52879333496094, ablated_edges=3777
loss.item()=-135.508056640625, ablated_edges=4482
loss.item()=-139.02293395996094, ablated_edges=4864
loss.item()=-139.85806274414062, ablated_edges=5091
loss.item()=-140.1219024658203, ablated_edges=5265
loss.item()=-141.9140625, ablated_edges=5451
loss.item()=-141.69163513183594, ablated_edges=5566
loss.item()=-141.4807586669922, ablated_edges=5701
loss.item()=-143.74453735351562, ablated_edges=5780
loss.item()=-142.1165008544922, ablated_edges=5820
Epochs trained:  10
Loss: -142.1165
Total preserved: 5734.8364
Edges ablated:  5820
Toxic loss:  142.1165008544922
OWT loss:  11.655352592468262
Penalty:  0
Best Token: [' run'], P(Alicia) = 3.401473271957479e-30, logit diff = -31.519729614257812
Best Token: [' run'], P(Alicia) = 9.703986885491369e-23, logit diff = 8.248924255371094


loss.item()=-142.02871704101562, ablated_edges=5891
loss.item()=-142.52169799804688, ablated_edges=5938
loss.item()=-142.1531219482422, ablated_edges=5968


In [5]:
# Persist the {epoch: mask snapshots} dict; the run's hyperparameters are encoded in the filename.
with open(f"models/alternative_necessary_masks_params_dict_lambda={regularization_strength}_{alpha=}_{means_ioi=}_{template_type=}.pkl", "wb") as f:
    pickle.dump(old_mask_params, f)

## Different alternative: sufficient but not necessary
Trains with an inverted loss function. This loss function encourages sparsity (as opposed to discouraging it) and wants the model to ablate everything but the necessary circuit.

In [6]:
# Configuration for the "sufficient edges" run. It mirrors the setup of the previous
# experiment and reloads the model, so the masks start fresh.
toxic_batch_size = 10 # so that we can just access the last sequence position without worrying about padding
owt_batch_size = 1
context_length = CONTEXT_LENGTH


template_type = "single"
# Despite its name, this loader supplies the IOI prompts that the training loop optimizes on.
toxic_data_loader = retrieve_toxic_data(toxic_batch_size, context_length, tokenizer, tokenize=False, num_points=None, template_type=template_type)
# toxic_data_loader = retrieve_toxic_filtered_data(toxic_batch_size)
owt_data_loader = retrieve_owt_data(owt_batch_size)

# with open("data/gpt2_means.pkl", "rb") as f:
#     means = pickle.load(f)[0][0]
# Choose which pickled means are handed to the model: the IOI "abc" file or the generic GPT-2 file.
means_ioi = True
if means_ioi:
    with open("data/gpt2_ioi_abc_means.pkl", "rb") as f:
        means = pickle.load(f)[0]
else:
    with open("data/gpt2_means.pkl", "rb") as f:
        means = pickle.load(f)[0]

model = load_demo_gpt2(means=means)
epochs_left = 200
log_every = 10  # print diagnostics and snapshot the masks every log_every epochs
lr = .05 # free
weight_decay = 0
# Every clamp_every epochs the masks are rounded to exactly 0 or 1 at `threshold`.
clamp_every = 50 # 5 # free
threshold = 0.5
epochs_trained = 0
regularization_strength = 1 # free (weight of the mask penalty in the loss)

# Only parameters that require gradients are optimized; the notebook treats these as the masks.
mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)
optimizer = AdamW(mask_params, lr=lr, weight_decay=weight_decay)

losses = []
num_ablated_edges = []
alpha = 1 # free (weight of tox_loss in the loss)
batch_size = toxic_batch_size + owt_batch_size
demos = prepare_fixed_demo(tokenizer, batch_size, demo="")
# cycle() makes the OWT loader endless so next(owt_iter) never raises StopIteration.
owt_iter = cycle(owt_data_loader)
# After epoch 50, training stops early once fewer than edge_threshold entries are below 0.5.
edge_threshold = 100
# NOTE(review): the loop breaks on `c > max_steps_per_epoch`, so each epoch runs one batch more than this.
max_steps_per_epoch = 100


In [7]:
# Maps epochs_trained -> CPU snapshot of every mask parameter taken at that epoch.
old_mask_params = {}


def duplicate_mask_params(mask_params):
    """Return detached CPU copies of the given mask parameters.

    Args:
        mask_params: iterable of tensors / parameters to snapshot.

    Returns:
        A list with one CPU tensor per input, in the same order. Each tensor
        owns its own storage, so later in-place updates to the live
        parameters (optimizer steps, clamping, rounding) do not alter it.
    """
    # `.cpu()` alone is a no-op for tensors that already live on the CPU and
    # would hand back an alias of the live parameter; copy=True always copies.
    return [p.detach().to("cpu", copy=True) for p in mask_params]

# Training loop for the "sufficient edges" masks. Same structure as the previous run,
# but the loss is not negated, so the optimizer drives both the penalty and tox_loss down.
prev_params = None
while epochs_left >= 0:
    for e in tqdm(range(epochs_left)):
        for c, batch in enumerate(toxic_data_loader):
            # NOTE(review): `>` lets max_steps_per_epoch + 1 batches through before breaking.
            if c > max_steps_per_epoch:
                break

            # print(batch["text"])
            # Per-step mask statistics, recomputed from the current parameter values.
            total_preserving = 0
            ablated_edges = 0
            penalty = 0
            for p in mask_params:
                total_preserving += p.sum()
                # number of mask entries currently below 0.5
                ablated_edges += p[p.data < 0.5].shape[0]
                # For non-negative mask sums this term is 0 until epochs_trained exceeds 20,
                # then grows linearly with the epoch count.
                penalty += max(0, p.sum() * (epochs_trained-20) / 10000) # free; the old note asked "why 2000?" but the divisor is 10000

            # demos = batch[:, :FILTER_DEMO_LEN]
            # completions = batch[:, FILTER_DEMO_LEN:]

            # tox_loss = infer_batch(model, criterion, completions, toxic_batch_size, demos)
            # owt_loss = infer_batch(model, criterion, next(owt_iter)['tokens'], owt_batch_size, fixed_demos)
            tox_loss, owt_loss = infer_batch_with_owt(model, criterion, batch, next(owt_iter), batch_size, demos, access_toxic_pos=-1)
            # print(f"{tox_loss=}, {owt_loss=}")
            # Minimizes the mask-sum penalty plus tox_loss; owt_loss is logged but not optimized.
            loss = (regularization_strength * penalty + alpha * tox_loss) #+ owt_loss
            # loss = alpha * tox_loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            num_ablated_edges.append(ablated_edges)
            # Keep every mask entry inside [0, 1] after the optimizer update.
            for p in mask_params:
                p.data.clamp_(0,1)
        # Reports the values from the last training step of this epoch.
        print(f"{loss.item()=}, {ablated_edges=}")
        epochs_trained += 1
        # Periodically round the masks to hard 0/1 values and recount the entries below 0.5.
        if epochs_trained % clamp_every == 0:
            ablated_edges = 0
            for p in mask_params:
                p.data[p.data < threshold] = 0
                p.data[p.data >= threshold] = 1
                ablated_edges += p[p.data < 0.5].shape[0]
        if epochs_trained % log_every == 0:
            print("Epochs trained: ", epochs_trained)
            print(f"Loss: {loss.item():.4f}")
            # total_preserving comes from the last training step; it is not recomputed after rounding.
            print(f"Total preserved: {total_preserving:.4f}")
            print("Edges ablated: ", ablated_edges)
            print("Toxic loss: ", tox_loss.item())
            print("OWT loss: ", owt_loss.item())
            print("Penalty: ", penalty)


            # Sanity check on two IOI prompts that differ only in the order of the names;
            # " Alicia" is scored as the correct completion for both.
            with torch.no_grad():
                test_ioi_sentences = ["While Alicia and Joshua were commuting to the restaurant, Joshua gave a snack to", "While Joshua and Alicia were commuting to the restaurant, Joshua gave a snack to"]
                for test_ioi_sentence in test_ioi_sentences:
                    # .squeeze().item() only works if each name encodes to exactly one token.
                    correct_token_id = tokenizer.encode(" Alicia", return_tensors="pt").squeeze().item()
                    other_token_id = tokenizer.encode(" Joshua", return_tensors="pt").squeeze().item()
                    test_ioi_tokens = tokenizer.encode(test_ioi_sentence, return_tensors="pt").to('cuda')
                    # First element of the model output at the final position — presumably the next-token logits.
                    generation = model(test_ioi_tokens)[0][:, -1]
                    probs = torch.softmax(generation, dim=-1)
                    print(f"Best Token: {tokenizer.batch_decode(torch.argmax(generation, dim=-1))}, P(Alicia) = {probs[:,correct_token_id].item()}, logit diff = {generation[:,correct_token_id].item() - generation[:,other_token_id].item()}")

            # if input('evaluate? (y)') == 'y':
            #     evaluate_model(model, toxic_batches=1, owt_batches=1)
            print("\n")
            # Snapshot the masks at this epoch so they can be pickled after training.
            old_mask_params[epochs_trained] = duplicate_mask_params(mask_params)

        # Early stop (leaves the epoch loop) once few enough entries are below 0.5.
        if epochs_trained > 50 and ablated_edges < edge_threshold:
            break
        # NOTE(review): this binds the same list object (no copy) and is not read in the visible cells.
        prev_params = mask_params
    # epochs_left = int(input('continue training for this number of epochs: '))
    # log_every = int(input('set log frequency'))
    # edge_threshold = int(input('set edge threshold'))
    # Terminates the outer while loop.
    epochs_left = -1

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for e in tqdm(range(epochs_left)):


  0%|          | 0/200 [00:00<?, ?it/s]

loss.item()=9.167099051410332e-06, ablated_edges=184
loss.item()=1.6331644019373925e-06, ablated_edges=194
loss.item()=2.4199421204684768e-06, ablated_edges=203


loss.item()=8.344642878910236e-07, ablated_edges=207
loss.item()=7.987012509147462e-07, ablated_edges=217
loss.item()=6.914127652635216e-07, ablated_edges=221
loss.item()=4.768367034557741e-07, ablated_edges=223
loss.item()=1.4305111051271524e-07, ablated_edges=227
loss.item()=2.026556842338323e-07, ablated_edges=229
loss.item()=7.867785143389483e-07, ablated_edges=232
Epochs trained:  10
Loss: 0.0000
Total preserved: 9748.0654
Edges ablated:  232
Toxic loss:  7.867785143389483e-07
OWT loss:  8.05907917022705
Penalty:  0
Best Token: [' Alicia'], P(Alicia) = 1.0, logit diff = 32.228858947753906
Best Token: [' Joshua'], P(Alicia) = 2.5583093421488456e-09, logit diff = -19.78390884399414


loss.item()=9.536741885085576e-08, ablated_edges=232
loss.item()=1.168244807558949e-06, ablated_edges=234
loss.item()=4.768370942542788e-08, ablated_edges=233
loss.item()=2.622602437440946e-07, ablated_edges=231
loss.item()=9.536741885085576e-08, ablated_edges=234
loss.item()=2.1457667287450022e-07, abl

In [8]:
# Persist the {epoch: mask snapshots} dict; the run's hyperparameters are encoded in the filename.
with open(f"models/alternative_sufficient_masks_params_dict_lambda={regularization_strength}_{alpha=}_{means_ioi=}_{template_type=}.pkl", "wb") as f:
    pickle.dump(old_mask_params, f)

In [10]:
# Display the mask snapshot recorded at epoch 200 (a list with one tensor per mask parameter).
old_mask_params[200]

[tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
         0., 1., 0., 0., 0., 1., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.,
         0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0.]),
 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]]),
 tensor([1., 0., 0., 0., 0., 0., 1., 0., 0., 0., 1., 1., 0.]),
 tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0., 0., 0

In [8]:
# Standalone sanity check of the current `model` on two IOI prompts that differ only in
# the order of the names; " Alicia" is scored as the correct completion for both.
with torch.no_grad():
    test_ioi_sentences = ["While Alicia and Joshua were commuting to the restaurant, Joshua gave a snack to", "While Joshua and Alicia were commuting to the restaurant, Joshua gave a snack to"]
    for test_ioi_sentence in test_ioi_sentences:
        # .squeeze().item() only works if each name encodes to exactly one token.
        correct_token_id = tokenizer.encode(" Alicia", return_tensors="pt").squeeze().item()
        other_token_id = tokenizer.encode(" Joshua", return_tensors="pt").squeeze().item()
        test_ioi_tokens = tokenizer.encode(test_ioi_sentence, return_tensors="pt").to('cuda')
        # First element of the model output at the final position — presumably the next-token logits.
        generation = model(test_ioi_tokens)[0][:, -1]
        probs = torch.softmax(generation, dim=-1)
        print(f"Best Token: {tokenizer.batch_decode(torch.argmax(generation, dim=-1))}, P(Alicia) = {probs[:,correct_token_id].item()}, logit diff = {generation[:,correct_token_id].item() - generation[:,other_token_id].item()}")

Best Token: [' Alicia'], P(Alicia) = 0.9056878089904785, logit diff = 2.4303483963012695
Best Token: [' Joshua'], P(Alicia) = 2.5975340989248252e-08, logit diff = -17.38216209411621


## Train mask over known circuit
Train mask over the circuit from the paper, as given by a run of ACDC++.

In [15]:
# Earlier variant that loaded the ACDC++ mask instead (kept for reference):
# from mask_utils import get_nodes_and_edges
# with open("models/acdcpp_mask_params.pkl", "rb") as f:
#     acdc_mask_params = pickle.load(f)

# _, _, acdc_Edges, acdc_mask_dict = get_nodes_and_edges(mask_params=acdc_mask_params)
# acdc_mask_dict

# Load the pickled circuit-covering mask and convert it to edges and a per-node mask dict;
# the dict is later passed to load_demo_gpt2 as mask_dict_superset.
from mask_utils import get_nodes_and_edges
with open("models/circuit_covering_mask_params.pkl", "rb") as f:
    circuit_covering_mask_params = pickle.load(f)

_, _, circuit_covering_edges, circuit_covering_mask_dict = get_nodes_and_edges(mask_params=circuit_covering_mask_params)
# circuit_covering_mask_dict # mostly 1s, 1s are frozen edges
# Circuit break only over the edges that are currently 0s in the circuit covering mask, ablate as few of them as possible


In [16]:
# Configuration for the run restricted to the known circuit. Differences from the earlier
# setups: a larger OWT batch, the generic GPT-2 means, a model built with
# mask_dict_superset, no effective early stopping, and fewer steps per epoch.
toxic_batch_size = 10 # so that we can just access the last sequence position without worrying about padding
owt_batch_size = 10
context_length = CONTEXT_LENGTH

template_type = "single"
# Despite its name, this loader supplies the IOI prompts that the training loop optimizes on.
toxic_data_loader = retrieve_toxic_data(toxic_batch_size, context_length, tokenizer, tokenize=False, num_points=None, template_type=template_type)
# toxic_data_loader = retrieve_toxic_filtered_data(toxic_batch_size)
owt_data_loader = retrieve_owt_data(owt_batch_size)

with open("data/gpt2_means.pkl", "rb") as f:
    means = pickle.load(f)[0]

model = load_demo_gpt2(means=means, mask_dict_superset=circuit_covering_mask_dict)
epochs_left = 200
log_every = 10  # print diagnostics and snapshot the masks every log_every epochs
lr = .05 # free
weight_decay = 0
# Every clamp_every epochs the masks are rounded to exactly 0 or 1 at `threshold`.
clamp_every = 50 # 5 # free
threshold = 0.5
epochs_trained = 0
regularization_strength = 1 # free (weight of the mask penalty in the loss)

# Only parameters that require gradients are optimized; the notebook treats these as the masks.
mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)
optimizer = AdamW(mask_params, lr=lr, weight_decay=weight_decay)

losses = []
num_ablated_edges = []
alpha = 1 # free (weight of tox_loss in the loss)
batch_size = toxic_batch_size + owt_batch_size
demos = prepare_fixed_demo(tokenizer, batch_size, demo="")
# cycle() makes the OWT loader endless so next(owt_iter) never raises StopIteration.
owt_iter = cycle(owt_data_loader)
# With a threshold of 0 the `ablated_edges < edge_threshold` early-stop test can never be true.
edge_threshold = 0
# NOTE(review): the loop breaks on `c > max_steps_per_epoch`, so each epoch runs one batch more than this.
max_steps_per_epoch = 50

In [17]:
# Maps epochs_trained -> CPU snapshot of every mask parameter taken at that epoch.
old_mask_params = {}


def duplicate_mask_params(mask_params):
    """Return detached CPU copies of the given mask parameters.

    Args:
        mask_params: iterable of tensors / parameters to snapshot.

    Returns:
        A list with one CPU tensor per input, in the same order. Each tensor
        owns its own storage, so later in-place updates to the live
        parameters (optimizer steps, clamping, rounding) do not alter it.
    """
    # `.cpu()` alone is a no-op for tensors that already live on the CPU and
    # would hand back an alias of the live parameter; copy=True always copies.
    return [p.detach().to("cpu", copy=True) for p in mask_params]

# Training loop for the masks restricted to the known circuit. Same structure and the
# same negated objective as the "necessary edges" run.
prev_params = None
while epochs_left >= 0:
    for e in tqdm(range(epochs_left)):
        for c, batch in enumerate(toxic_data_loader):
            # NOTE(review): `>` lets max_steps_per_epoch + 1 batches through before breaking.
            if c > max_steps_per_epoch:
                break

            # print(batch["text"])
            # Per-step mask statistics, recomputed from the current parameter values.
            total_preserving = 0
            ablated_edges = 0
            penalty = 0
            for p in mask_params:
                total_preserving += p.sum()
                # number of mask entries currently below 0.5
                ablated_edges += p[p.data < 0.5].shape[0]
                # For non-negative mask sums this term is 0 until epochs_trained exceeds 20,
                # then grows linearly with the epoch count.
                penalty += max(0, p.sum() * (epochs_trained-20) / 10000) # free; the old note asked "why 2000?" but the divisor is 10000

            # demos = batch[:, :FILTER_DEMO_LEN]
            # completions = batch[:, FILTER_DEMO_LEN:]

            # tox_loss = infer_batch(model, criterion, completions, toxic_batch_size, demos)
            # owt_loss = infer_batch(model, criterion, next(owt_iter)['tokens'], owt_batch_size, fixed_demos)
            tox_loss, owt_loss = infer_batch_with_owt(model, criterion, batch, next(owt_iter), batch_size, demos, access_toxic_pos=-1)
            # print(f"{tox_loss=}, {owt_loss=}")
            # Negated objective: minimizing `loss` pushes the penalty and tox_loss up.
            # owt_loss is computed and logged but is not part of the objective.
            loss = -1 * (regularization_strength * penalty + alpha * tox_loss)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            num_ablated_edges.append(ablated_edges)
            # Keep every mask entry inside [0, 1] after the optimizer update.
            for p in mask_params:
                p.data.clamp_(0,1)
        # Reports the values from the last training step of this epoch.
        print(f"{loss.item()=}, {ablated_edges=}")
        epochs_trained += 1
        # Periodically round the masks to hard 0/1 values and recount the entries below 0.5.
        if epochs_trained % clamp_every == 0:
            ablated_edges = 0
            for p in mask_params:
                p.data[p.data < threshold] = 0
                p.data[p.data >= threshold] = 1
                ablated_edges += p[p.data < 0.5].shape[0]
        if epochs_trained % log_every == 0:
            print("Epochs trained: ", epochs_trained)
            print(f"Loss: {loss.item():.4f}")
            # total_preserving comes from the last training step; it is not recomputed after rounding.
            print(f"Total preserved: {total_preserving:.4f}")
            print("Edges ablated: ", ablated_edges)
            print("Toxic loss: ", tox_loss.item())
            print("OWT loss: ", owt_loss.item())
            print("Penalty: ", penalty)
            # if input('evaluate? (y)') == 'y':
            #     evaluate_model(model, toxic_batches=1, owt_batches=1)
            # Sanity check on two IOI prompts that differ only in the order of the names;
            # " Alicia" is scored as the correct completion for both.
            with torch.no_grad():
                test_ioi_sentences = ["While Alicia and Joshua were commuting to the restaurant, Joshua gave a snack to", "While Joshua and Alicia were commuting to the restaurant, Joshua gave a snack to"]
                for test_ioi_sentence in test_ioi_sentences:
                    # .squeeze().item() only works if each name encodes to exactly one token.
                    correct_token_id = tokenizer.encode(" Alicia", return_tensors="pt").squeeze().item()
                    other_token_id = tokenizer.encode(" Joshua", return_tensors="pt").squeeze().item()
                    test_ioi_tokens = tokenizer.encode(test_ioi_sentence, return_tensors="pt").to('cuda')
                    # First element of the model output at the final position — presumably the next-token logits.
                    generation = model(test_ioi_tokens)[0][:, -1]
                    probs = torch.softmax(generation, dim=-1)
                    print(f"Best Token: {tokenizer.batch_decode(torch.argmax(generation, dim=-1))}, P(Alicia) = {probs[:,correct_token_id].item()}, logit diff = {generation[:,correct_token_id].item() - generation[:,other_token_id].item()}")
            print("\n")

            # Snapshot the masks at this epoch so they can be pickled after training.
            old_mask_params[epochs_trained] = duplicate_mask_params(mask_params)

        # Early-stop test; with edge_threshold = 0 a non-negative count can never satisfy it.
        if epochs_trained > 50 and ablated_edges < edge_threshold:
            break
        # NOTE(review): this binds the same list object (no copy) and is not read in the visible cells.
        prev_params = mask_params
    # epochs_left = int(input('continue training for this number of epochs: '))
    # Terminates the outer while loop.
    epochs_left = -1
    # log_every = int(input('set log frequency'))
    # edge_threshold = int(input('set edge threshold'))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for e in tqdm(range(epochs_left)):


  0%|          | 0/200 [00:00<?, ?it/s]

loss.item()=-13.236642837524414, ablated_edges=167
loss.item()=-12.427201271057129, ablated_edges=178
loss.item()=-14.008918762207031, ablated_edges=180
loss.item()=-13.530853271484375, ablated_edges=177
loss.item()=-13.251840591430664, ablated_edges=176
loss.item()=-13.437471389770508, ablated_edges=177
loss.item()=-13.83442211151123, ablated_edges=176
loss.item()=-12.967565536499023, ablated_edges=175
loss.item()=-12.80200481414795, ablated_edges=176
loss.item()=-13.274145126342773, ablated_edges=176
Epochs trained:  10
Loss: -13.2741
Total preserved: 11433.7441
Edges ablated:  176
Toxic loss:  13.274145126342773
OWT loss:  4.235899925231934
Penalty:  0
Best Token: [' the'], P(Alicia) = 4.3425427520560334e-07, logit diff = -4.821807861328125
Best Token: [' the'], P(Alicia) = 4.7285746518355154e-07, logit diff = -3.5168304443359375


loss.item()=-12.855962753295898, ablated_edges=173
loss.item()=-13.841504096984863, ablated_edges=175
loss.item()=-13.255230903625488, ablated_edges=175


In [18]:
# Persist the {epoch: mask snapshots} dict. means_ioi=False is hard-coded in the filename
# because this run loaded data/gpt2_means.pkl directly.
with open(f"models/circuit_covering_alternative_necessary_params_dict_lambda={regularization_strength}_{alpha=}_means_ioi=False_{template_type=}.pkl", "wb") as f:
    pickle.dump(old_mask_params, f)

## Test model before and after circuit breaking

In [None]:
# Load the held-out evaluation prompts and a reference model to compare the masked model against.
import pickle
with open("data/ioi_sentences_test.pkl", "rb") as f:
    ioi_sentences_test = pickle.load(f)
    # ioi_sentences_test = [t[2] for t in ioi_sentences_test]

with open("data/eval_uniform.pkl", "rb") as f:
    uniform_samples = pickle.load(f)
    # Each sample is indexable; element 2 is kept as the sentence text.
    uniform_sentences = [t[2] for t in uniform_samples]

# Reference model loaded with means=False — presumably no mean ablation; confirm in load_demo_gpt2.
original_model = load_demo_gpt2(means=False)

# with open("models/masked_gpt2_mean_ablation_v6.pkl", "rb") as f:
#     model.state_dict = pickle.load(f)

In [None]:
# Run inference on an ioi_sentence
# Take the first test prompt and show it before comparing the two models on it.
ioi_sentence = ioi_sentences_test[0]
print(ioi_sentence)
# ioi_tokens = tokenizer(ioi_sentence, return_tensors='pt').input_ids.to('cuda')

# Put the reference model in eval mode and move it to the GPU.
original_model.eval()
original_model.to('cuda')
def get_last_token(model, prompt, topk=5):
    """Print and return the model's top-k predictions for the prompt's final token.

    The prompt is tokenized with the notebook's global `tokenizer`, its last
    token is dropped, and the model is asked to predict what follows the
    remaining prefix.

    Args:
        model: callable taking a (1, seq) token tensor; the first element of
            its output is used as the per-position scores.
        prompt: text whose last token is the one to be predicted.
        topk: number of highest-scoring candidates to report.

    Returns:
        (decoded_tokens, probs): the decoded top-k tokens and their softmax
        probabilities, ordered from most to least likely.
    """
    # Drop the last token so the model predicts it from the preceding context.
    tokens = tokenizer(prompt, return_tensors='pt').input_ids[:, :-1]

    # The tokenizer returns CPU tensors; feed them on whatever device the model
    # lives on (the models in this notebook are moved to CUDA).
    first_param = next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
    if first_param is not None:
        tokens = tokens.to(first_param.device)

    # Pure inference: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        model_outputs = model(tokens)[0]
        model_outputs = model_outputs.squeeze(0)[-1]
        probs = torch.nn.functional.softmax(model_outputs, dim=-1)

    topk_outputs = torch.topk(model_outputs, topk)
    topk_tokens = topk_outputs.indices
    topk_probs = probs[topk_tokens]

    # decode tokens
    for token_id, token_prob in zip(topk_tokens, topk_probs):
        print(f"{tokenizer.decode(token_id.unsqueeze(0))}, probability of {token_prob}")
    topk_tokens_decoded = tokenizer.batch_decode(topk_tokens)
    return topk_tokens_decoded, topk_probs

# Compare the reference model with the current masked `model` on the same IOI prompt.
print("Before ablation")
_ = get_last_token(original_model, ioi_sentence)
print()
print()
print("After ablation")
_ = get_last_token(model, ioi_sentence)

In [None]:
# Try on uniform samples
# Same before/after comparison on the first three non-IOI sentences, to see how the
# masked model behaves away from the task it was trained on.
for idx in range(3):
    print(uniform_sentences[idx])
    print("Before ablation")
    _ = get_last_token(original_model, uniform_sentences[idx])
    print()
    print("After ablation")
    _ = get_last_token(model, uniform_sentences[idx])
    print("\n\n")

## Visualize mask
Create computational graphs like those in the edge attribution patching paper.

In [None]:
# Calculate which nodes will be in the graph and, for each node, the mask over
# its incoming edges (entry k refers to the k-th node of all_possible_nodes).
connected_nodes = set()
# add embed node at position
# connected_nodes.add((-1, "embed"))
n_heads = 12
n_layers = 12

# Associate each node with a position: nodes are (layer, label) pairs and the
# embedding sits at layer -1, before every transformer layer.
all_possible_nodes = [(-1, "embed")]
mask_dict = {}
# The embedding has no incoming edges, hence an empty mask.
mask_dict["embed"] = torch.zeros(size=(0,))
output_mask = None
for name, param in zip(param_names, mask_params):
    if "attention" in name:
        # Parameter names look like "<prefix>.<layer>.<...>"; field 1 is the layer index.
        layer = int(name.split(".")[1])
        for head in range(n_heads):
            all_possible_nodes.append((layer, f"a{layer}.{head}"))
            # Column `head` holds that head's incoming-edge mask.
            mask_dict[f"a{layer}.{head}"] = param[:, head].detach().cpu()
    elif "mlp" in name:
        layer = int(name.split(".")[1])
        all_possible_nodes.append((layer, f"m{layer}"))
        mask_dict[f"m{layer}"] = param.detach().cpu()
    else:
        # A mask that belongs to neither an attention nor an MLP block is taken
        # to be the output mask. The snapshot printed earlier in this notebook
        # starts with a 157-entry tensor (1 embed + 144 heads + 12 MLPs), so this
        # parameter is presumably listed first, not last — TODO confirm its name.
        output_mask = param
if output_mask is None:
    # Previous behaviour: assume the output mask is the last trainable parameter.
    output_mask = mask_params[-1]
# The output node sits one position past the last layer.
all_possible_nodes.append((n_layers, "output"))
# Detach and move to CPU like every other entry of mask_dict.
mask_dict["output"] = output_mask.detach().cpu()

In [None]:
# Calculate where edges are based on the mask.
# An edge (dest, src) is recorded when entry `src` of dest's incoming mask equals
# `target`: 1 (entry kept) when drawing a sufficient circuit, 0 (entry ablated)
# otherwise. A mask only covers the nodes listed before it, so positions past
# the end of the mask are never considered.
sufficient = True
target = 1 if sufficient else 0

edges = []
for dest in all_possible_nodes:
    incoming_mask = mask_dict[dest[1]]
    # Node labels are unique, so a node's position in the list is simply its
    # enumeration index; no list.index() lookup is needed.
    for src_pos, src in enumerate(all_possible_nodes[:len(incoming_mask)]):
        if incoming_mask[src_pos] == target:
            edges.append((dest, src))

In [None]:
# Number of edges that will be drawn.
len(edges)

In [None]:
def create_aligned_graph(all_possible_nodes, edges):
    """Render the given edges as a layered Graphviz graph and return the image.

    Args:
        all_possible_nodes: (layer, label) pairs; only used to find the top layer.
        edges: (dest, src) pairs of nodes; each is drawn as an arrow src -> dest.

    Returns:
        An Image of 'aligned_graph.png', which is written to the working directory.
    """
    graph = pgv.AGraph(strict=False, directed=True)

    # Highest integer layer present; the clusters below are created from it down to -1.
    top_layer = max(pos for pos, _ in all_possible_nodes if isinstance(pos, int))
    nodes_with_edges = {endpoint for pair in edges for endpoint in pair}
    print(nodes_with_edges)

    # Arrows point from the source node to the destination node.
    for pair in edges:
        dest_label = pair[0][1]
        src_label = pair[1][1]
        graph.add_edge(src_label, dest_label)

    # One subgraph per layer so that nodes of the same layer share a rank.
    for layer in reversed(range(-1, top_layer + 1)):
        with graph.subgraph(name=f'cluster_{layer}') as cluster:
            cluster.graph_attr['rank'] = 'same'
            for member in nodes_with_edges:
                if member[0] != layer:
                    continue
                cluster.add_node(member[1])

    # Lay the graph out with dot and write it to disk.
    graph.layout(prog='dot')
    graph.draw('aligned_graph.png')
    return Image('aligned_graph.png')

# Call the function with your nodes and edges
flipped_graph_image = create_aligned_graph(all_possible_nodes, edges)

# To display the graph in Jupyter Notebook
# (a bare expression on the last line of a cell is rendered as the cell output)
flipped_graph_image
