# Train mask over IOI edges and analyze mask vs known circuit

In [1]:
from models import load_gpt2_weights, load_demo_gpt2, tokenizer
from data import retrieve_toxic_data, retrieve_owt_data, retrieve_toxic_data_low_loss, retrieve_toxic_filtered_data, FILTER_DEMO_LEN, CONTEXT_LENGTH
from inference import infer_batch_with_owt, infer_batch, prepare_fixed_demo, criterion
from torch.optim import AdamW
import torch
import pickle
import datasets
from tqdm import tqdm_notebook as tqdm
from itertools import cycle
from eval import evaluate_model
from data import batch_text_to_tokens
import plotly.express as px

Using device: cuda:0


In [14]:
toxic_batch_size = 1 # so that we can just access the last sequence position without worrying about padding
owt_batch_size = 5
context_length = CONTEXT_LENGTH

toxic_data_loader = retrieve_toxic_data(toxic_batch_size, context_length, tokenizer, tokenize=False, num_points=None)
# toxic_data_loader = retrieve_toxic_filtered_data(toxic_batch_size)
owt_data_loader = retrieve_owt_data(owt_batch_size)

with open("data/gpt2_means.pkl", "rb") as f:
    means = pickle.load(f)[0][0]

model = load_demo_gpt2(means=False)
epochs_left = 100
log_every = 10
lr = .05 # free
weight_decay = 0
clamp_every = 50 # 5 # free
threshold = 0.5
epochs_trained = 0

mask_params = []
param_names = []
for name, p in model.named_parameters():
    if p.requires_grad:
        param_names.append(name)
        mask_params.append(p)
optimizer = AdamW(mask_params, lr=lr, weight_decay=weight_decay)

losses = []
num_ablated_edges = []
alpha = 0.2 # free
batch_size = toxic_batch_size + owt_batch_size
demos = prepare_fixed_demo(tokenizer, batch_size, demo="")
owt_iter = cycle(owt_data_loader)
edge_threshold = 100
max_steps_per_epoch = 100


## Train params of mask

In [15]:
prev_params = None
while epochs_left > 0:
    for e in tqdm(range(epochs_left)):
        for c, batch in enumerate(toxic_data_loader):
            if c > max_steps_per_epoch:
                break

            # print(batch["text"])
            total_preserving = 0
            ablated_edges = 0
            penalty = 0
            for p in mask_params:
                total_preserving += p.sum()
                ablated_edges += p[p.data < 0.5].shape[0]
                penalty += max(0, p.sum() * (epochs_trained-20) / 10000) # why 2000? free

            # demos = batch[:, :FILTER_DEMO_LEN]
            # completions = batch[:, FILTER_DEMO_LEN:]

            # tox_loss = infer_batch(model, criterion, completions, toxic_batch_size, demos)
            # owt_loss = infer_batch(model, criterion, next(owt_iter)['tokens'], owt_batch_size, fixed_demos)
            tox_loss, owt_loss = infer_batch_with_owt(model, criterion, batch, next(owt_iter), batch_size, demos, access_toxic_pos=-1)
            # print(f"{tox_loss=}, {owt_loss=}")
            loss = -1 * (penalty + alpha * tox_loss) + owt_loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            losses.append(loss.item())
            num_ablated_edges.append(ablated_edges)
            for p in mask_params:
                p.data.clamp_(0,1)
        print(f"{loss.item()=}, {ablated_edges=}")
        epochs_trained += 1
        if epochs_trained % clamp_every == 0:
            ablated_edges = 0
            for p in mask_params:
                p.data[p.data < threshold] = 0
                p.data[p.data >= threshold] = 1
                ablated_edges += p[p.data < 0.5].shape[0]
        if epochs_trained % log_every == 0:
            print("Epochs trained: ", epochs_trained)
            print(f"Loss: {loss.item():.4f}")
            print(f"Total preserved: {total_preserving:.4f}")
            print("Edges ablated: ", ablated_edges)
            print("Toxic loss: ", tox_loss.item())
            print("OWT loss: ", owt_loss.item())
            print("Penalty: ", penalty)
            # if input('evaluate? (y)') == 'y':
            #     evaluate_model(model, toxic_batches=1, owt_batches=1)
            print("\n")
                
        if epochs_trained > 50 and ablated_edges < edge_threshold:
            break
        prev_params = mask_params
    epochs_left = int(input('continue training for this number of epochs: '))
    log_every = int(input('set log frequency'))
    edge_threshold = int(input('set edge threshold'))

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for e in tqdm(range(epochs_left)):


  0%|          | 0/100 [00:00<?, ?it/s]

loss.item()=-16.595102310180664, ablated_edges=3121
loss.item()=-18.51873779296875, ablated_edges=4050
loss.item()=-20.140392303466797, ablated_edges=4569
loss.item()=-18.041257858276367, ablated_edges=4840
loss.item()=-19.385272979736328, ablated_edges=5084
loss.item()=-20.180282592773438, ablated_edges=5232
loss.item()=-18.748977661132812, ablated_edges=5328
loss.item()=-21.25680160522461, ablated_edges=5425
loss.item()=-20.69216537475586, ablated_edges=5475
loss.item()=-19.64101791381836, ablated_edges=5607
Epochs trained:  10
Loss: -19.6410
Total preserved: 5981.6421
Edges ablated:  5607
Toxic loss:  119.28236389160156
OWT loss:  4.215456962585449
Penalty:  0


loss.item()=-21.122892379760742, ablated_edges=5547
loss.item()=-19.804641723632812, ablated_edges=5547
loss.item()=-18.535991668701172, ablated_edges=5555
loss.item()=-21.581907272338867, ablated_edges=5582
loss.item()=-18.273487091064453, ablated_edges=5580
loss.item()=-21.503908157348633, ablated_edges=5617
loss.item()=-1

KeyboardInterrupt: 

In [16]:
total_preserving = 0
for p in mask_params:
    p.data[p.data < threshold] = 0
    p.data[p.data >= threshold] = 1
    total_preserving += p.data.sum()
print(total_preserving)

tensor(10630., device='cuda:0')


In [25]:
for name, x in model.named_parameters():
    if x.requires_grad:
        print(x)

Parameter containing:
tensor([1., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0., 0., 1., 1., 0., 1., 1., 1.,
        0., 1., 1., 0., 0., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 0., 1.,
        1., 0., 0., 0., 0., 1., 1., 1., 1., 0., 1., 1., 0., 0., 1., 0., 1., 0.,
        0., 1., 1., 1., 1., 0., 1., 0., 1., 0., 1., 0., 1., 1., 1., 0., 1., 1.,
        1., 1., 1., 0., 1., 1., 0., 1., 0., 1., 1., 1., 0., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 0., 0., 1., 1., 1., 0., 0., 1., 0., 1., 0., 1., 1.,
        1., 1., 1., 1., 0., 0., 0., 1., 0., 1., 1., 0., 1., 1., 0., 0., 0., 0.,
        1., 1., 1., 1., 1., 1., 1., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        0., 1., 1., 0., 0., 0., 0., 1., 0., 0., 1., 0., 1.], device='cuda:0',
       requires_grad=True)
Parameter containing:
tensor([[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]], device='cuda:0',
       requires_grad=True)
Parameter containing:
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.], device='cuda:0',
       r

In [20]:
for idx in range(len(mask_params)):
    print(f"{param_names[idx]}: {mask_params[idx].shape}")

output_mask: torch.Size([157])
blocks.0.edge_mask_attentions: torch.Size([1, 12])
blocks.0.edge_mask_mlp: torch.Size([13])
blocks.1.edge_mask_attentions: torch.Size([14, 12])
blocks.1.edge_mask_mlp: torch.Size([26])
blocks.2.edge_mask_attentions: torch.Size([27, 12])
blocks.2.edge_mask_mlp: torch.Size([39])
blocks.3.edge_mask_attentions: torch.Size([40, 12])
blocks.3.edge_mask_mlp: torch.Size([52])
blocks.4.edge_mask_attentions: torch.Size([53, 12])
blocks.4.edge_mask_mlp: torch.Size([65])
blocks.5.edge_mask_attentions: torch.Size([66, 12])
blocks.5.edge_mask_mlp: torch.Size([78])
blocks.6.edge_mask_attentions: torch.Size([79, 12])
blocks.6.edge_mask_mlp: torch.Size([91])
blocks.7.edge_mask_attentions: torch.Size([92, 12])
blocks.7.edge_mask_mlp: torch.Size([104])
blocks.8.edge_mask_attentions: torch.Size([105, 12])
blocks.8.edge_mask_mlp: torch.Size([117])
blocks.9.edge_mask_attentions: torch.Size([118, 12])
blocks.9.edge_mask_mlp: torch.Size([130])
blocks.10.edge_mask_attentions: tor

## Test model before and after circuit breaking

In [29]:
import pickle
with open("data/ioi_sentences_test.pkl", "rb") as f:
    ioi_sentences_test = pickle.load(f)
    # ioi_sentences_test = [t[2] for t in ioi_sentences_test]

with open("data/eval_uniform.pkl", "rb") as f:
    uniform_samples = pickle.load(f)
    uniform_sentences = [t[2] for t in uniform_samples]

original_model = load_demo_gpt2(means=False)

with open("models/masked_gpt2_mean_ablation_v6.pkl", "rb") as f:
    model.state_dict = pickle.load(f)

In [34]:
# Run inference on an ioi_sentence
ioi_sentence = ioi_sentences_test[0]
print(ioi_sentence)
# ioi_tokens = tokenizer(ioi_sentence, return_tensors='pt').input_ids.to('cuda')

original_model.eval()
original_model.to('cuda')
def get_last_token(model, prompt, topk=5):
    # generate last token
    tokens = tokenizer(prompt, return_tensors='pt').input_ids[:, :-1]

    # generate one token, decode original_model(ioi_tokens[:, :-1])
    model_outputs = model(tokens)[0]
    model_outputs = model_outputs.squeeze(0)[-1]
    probs = torch.nn.functional.softmax(model_outputs, dim=-1)

    topk_outputs = torch.topk(model_outputs, topk)
    topk_tokens = topk_outputs.indices
    topk_probs = probs[topk_outputs.indices]
    
    # decode tokens
    for i in range(topk):
        print(f"{tokenizer.decode(topk_tokens[i].unsqueeze(0))}, probability of {topk_probs[i]}")
    topk_tokens_decoded = tokenizer.batch_decode(topk_tokens)
    return topk_tokens_decoded, topk_probs

print("Before ablation")
_ = get_last_token(original_model, ioi_sentence)
print()
print()
print("After ablation")
_ = get_last_token(model, ioi_sentence)

Then, John and Ryan had a lot of fun at the office. John gave a snack to Ryan
Before ablation
 Ryan, probability of 0.3558429777622223
 his, probability of 0.11594294011592865
 the, probability of 0.08989984542131424
 John, probability of 0.031978804618120193
 a, probability of 0.023939691483974457


After ablation
 use, probability of 0.999350368976593
 share, probability of 8.960117702372372e-05
 take, probability of 8.413783507421613e-05
 try, probability of 7.346251368289813e-05
 apply, probability of 4.8839792725630105e-05


In [38]:
# Try on uniform samples
for idx in range(3):
    print(uniform_sentences[idx])
    print("Before ablation")
    _ = get_last_token(original_model, uniform_sentences[idx])
    print()
    print("After ablation")
    _ = get_last_token(model, uniform_sentences[idx])
    print("\n\n")

GREEN AND PURPLE RAIN
Before ablation
IN, probability of 0.5235788822174072
IS, probability of 0.1651604324579239
IL, probability of 0.06369347870349884
INS, probability of 0.04333753511309624
Z, probability of 0.03196313977241516

After ablation
FF, probability of 0.16773802042007446
KE, probability of 0.15770037472248077
VE, probability of 0.15650776028633118
VEN, probability of 0.09344792366027832
IN, probability of 0.06106434762477875



ROBOT! ROBOT ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT! ROBOT!
Before ablation
!, probability of 0.9894324541091919
., probability of 0.007035684306174517
:, probability of 0.0009531358373351395
!!, probability of 0.0004606391885317862

, probability of 0.00021542994363699108

After ablation
ANGE, probability of 0.10324295610189438
:, probability of 0.07972729951143265
AL, probability of 0.07847077399492264
IF, probability of 0.059709884226322174
AN, probability of 0.0535677149891

## Visualize mask
Code from ChatGPT lol, meant to look like the computational graphs in edge attribution patching paper

In [None]:
import pygraphviz as pgv

# Let's assume you have a list of masks with ones or zeros
# Each element of `masks` is a tuple containing the mask for attentions and mlp for each block
# For simplicity, the masks are randomly generated

import numpy as np

num_blocks = 12
num_heads = 12
num_mlp_units = 13

# Generating random masks
masks = [
    (np.random.randint(2, size=(num_heads,)), np.random.randint(2, size=(num_mlp_units,)))
    for _ in range(num_blocks)
]

# Create a new directed graph
G = pgv.AGraph(directed=True)

# Add the embed node
G.add_node("embed", color="blue")

# Add nodes and edges for each block
for i in range(num_blocks):
    attention_mask, mlp_mask = masks[i]

    # Add attention nodes
    for j in range(num_heads):
        node_name = f"block_{i}_head_{j}"
        G.add_node(node_name, color="blue")
        if i == 0:  # Connect the first block's attention directly to embed
            G.add_edge("embed", node_name)
        else:  # Connect to previous blocks based on the mask
            for k in range(num_heads):
                if attention_mask[j] == 0:
                    G.add_edge(f"block_{i-1}_head_{k}", node_name)

    # Add MLP node
    mlp_node_name = f"block_{i}_mlp"
    G.add_node(mlp_node_name, color="blue")

    # Connect MLP to all previous heads based on the mask
    for j in range(num_heads):
        for k in range(i+1):  # +1 because it connects to its own block as well
            if mlp_mask[j] == 0:
                G.add_edge(f"block_{k}_head_{j}", mlp_node_name)

# Add the resid_post node and connect the last block's mlp to it
G.add_node("resid_post", color="blue")
G.add_edge(f"block_{num_blocks-1}_mlp", "resid_post")

# Render the graph to a file (you can specify different file formats)
G.draw('graph.png', prog='dot')  # Use 'dot' for hierarchical layouts similar to your example

print("Graph created and saved as 'graph.png'")
