In [2]:
import itertools
import random
import einops
from functools import partial
import numpy as np
import torch
import datasets
import os
import re
import pickle
from torch import Tensor
from torch.utils.data import DataLoader
from datasets import load_dataset, concatenate_datasets
from jaxtyping import Float, Int, Bool
from typing import Dict, Iterable, List, Tuple, Union, Literal, Optional
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
)
from transformer_lens import HookedTransformer
from transformer_lens.utils import (
    get_dataset,
    tokenize_and_concatenate,
    get_act_name,
    test_prompt,
    get_attention_mask,
)
from transformer_lens.hook_points import HookPoint
from tqdm.notebook import tqdm
import pandas as pd
from circuitsvis.activations import text_neuron_activations
from circuitsvis.topk_samples import topk_samples
from IPython.display import HTML, display
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from summarization_utils.patching_metrics import get_logit_diff
from summarization_utils.tokenwise_ablation import (
    compute_ablation_modified_loss,
    load_directions,
    get_random_directions,
    get_zeroed_dir_vector,
    get_layerwise_token_mean_activations,
    ablation_hook_base,
    AblationHook,
    AblationHookIterator,
    get_batch_token_mean_activations,
    loss_fn,
    DEFAULT_DEVICE,
)
from summarization_utils.datasets import (
    OWTData,
    PileFullData,
    PileSplittedData,
    HFData,
    mask_positions,
    construct_exclude_list,
)
from summarization_utils.neuroscope import plot_top_onesided
from summarization_utils.store import ResultsFile, TensorBlockManager
from summarization_utils.path_patching import act_patch, Node, IterNode, IterSeqPos

from summarization_utils.visualization import get_attn_head_patterns, imshow_p, plot_attention_heads, scatter_attention_and_contribution_simple
from summarization_utils.visualization import get_attn_pattern, plot_attention

from summarization_utils.toy_datasets import (
    CounterfactualDataset,
    ToyDeductionTemplate,
    ToyBindingTemplate,
    ToyProfilesTemplate,
)

In [3]:
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "1"
torch.set_grad_enabled(False)
torch.manual_seed(0)
random.seed(0)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model_checkpoint = "EleutherAI/pythia-2.8b"

In [4]:
model = HookedTransformer.from_pretrained(
    model_checkpoint,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=False,
    device=device,
)
#model = model.to(device)
assert model.tokenizer is not None

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model EleutherAI/pythia-2.8b into HookedTransformer


### Knowledge Dataset

In [5]:
dataset_template = ToyDeductionTemplate(model, dataset_size=10, max=10)
dataset = dataset_template.to_counterfactual()

In [6]:
all_logit_diffs, cf_logit_diffs = dataset.compute_logit_diffs()
print(f"Original mean: {all_logit_diffs.mean():.2f}")
print(f"Counterfactual mean: {cf_logit_diffs.mean():.2f}")

Original mean: 5.54
Counterfactual mean: -6.09


In [7]:
for i in range(dataset.prompt_tokens.shape[0]):
    print((len(dataset.prompt_tokens[i]), len(dataset.cf_tokens[i]), model.to_str_tokens(dataset.prompt_tokens[i]), model.to_str_tokens(dataset.cf_tokens[i])))

(14, 14, ['<|endoftext|>', 'Sarah', ' is', ' a', ' bear', '.', ' Bears', ' are', ' young', '.', ' Therefore', ',', ' Sarah', ' is'], ['<|endoftext|>', 'Sarah', ' is', ' a', ' bear', '.', ' Bears', ' are', ' old', '.', ' Therefore', ',', ' Sarah', ' is'])
(14, 14, ['<|endoftext|>', 'Josh', ' is', ' a', ' tiger', '.', ' Tigers', ' are', ' dumb', '.', ' Therefore', ',', ' Josh', ' is'], ['<|endoftext|>', 'Josh', ' is', ' a', ' tiger', '.', ' Tigers', ' are', ' rich', '.', ' Therefore', ',', ' Josh', ' is'])
(14, 14, ['<|endoftext|>', 'Bob', ' is', ' a', ' writer', '.', ' Writers', ' are', ' dumb', '.', ' Therefore', ',', ' Bob', ' is'], ['<|endoftext|>', 'Bob', ' is', ' a', ' writer', '.', ' Writers', ' are', ' rich', '.', ' Therefore', ',', ' Bob', ' is'])
(14, 14, ['<|endoftext|>', 'Sarah', ' is', ' a', ' dog', '.', ' Dogs', ' are', ' poor', '.', ' Therefore', ',', ' Sarah', ' is'], ['<|endoftext|>', 'Sarah', ' is', ' a', ' dog', '.', ' Dogs', ' are', ' happy', '.', ' Therefore', ',', '

In [71]:
for name in GROUPS_DRAFT:
    print(name)
    print(model.to_str_tokens(' ' + name, prepend_bos=False))
    print(model.to_str_tokens(' ' + name.capitalize()+'s', prepend_bos=False))


human
[' human']
[' Humans']
dog
[' dog']
[' Dogs']
bird
[' bird']
[' Birds']
skunk
[' sk', 'unk']
[' Sk', 'unks']
badger
[' bad', 'ger']
[' Bad', 'gers']
bear
[' bear']
[' Bears']
lion
[' lion']
[' Lions']
tiger
[' tiger']
[' Tigers']
wolf
[' wolf']
[' Wol', 'fs']
plant
[' plant']
[' Plants']
flower
[' flower']
[' Flowers']
doctor
[' doctor']
[' Doctors']
teacher
[' teacher']
[' Teachers']
scientist
[' scientist']
[' Scientists']
engineer
[' engineer']
[' Engineers']
writer
[' writer']
[' Writers']
artist
[' artist']
[' Artists']


In [8]:
all_logit_diffs.mean(), all_logit_diffs


(tensor(5.5402, device='cuda:0'),
 tensor([ 1.4736, 10.4478,  5.6007,  4.6531,  4.2388,  7.8997,  6.0910,  5.6427,
          2.9926,  6.3620], device='cuda:0'))

In [17]:
orig_logits, orig_cache = model.run_with_cache(dataset.prompt_tokens)
orig_logit_diff = get_logit_diff(orig_logits, dataset.answer_tokens, per_prompt=False)
orig_logit_diff.mean(), orig_logit_diff

(tensor(5.5402, device='cuda:0'), tensor(5.5402, device='cuda:0'))

In [18]:
cf_logit_diffs.mean(), cf_logit_diffs

(tensor(-6.0885, device='cuda:0'),
 tensor([ -3.5009,  -8.3291,  -8.2849, -13.0953,  -4.5020,  -4.2402,  -6.1258,
          -3.3234,  -3.3099,  -6.1731], device='cuda:0'))

In [19]:
flip_logits, flip_cache = model.run_with_cache(dataset.cf_tokens)
flip_logit_diff = get_logit_diff(flip_logits, dataset.answer_tokens, per_prompt=False)
flip_logit_diff.mean(), flip_logit_diff

(tensor(-6.0885, device='cuda:0'), tensor(-6.0885, device='cuda:0'))

In [13]:
test_prompt(model.to_string(dataset.prompt_tokens[5]), model.to_string(dataset.answer_tokens[5][0]), model, top_k=10)

Tokenized prompt: ['<|endoftext|>', '<|endoftext|>', 'Peter', ' is', ' a', ' human', '.', ' Humans', ' are', ' brown', '.', ' Therefore', ',', ' Peter', ' is']
Tokenized answer: [' brown']


Top 0th token. Logit: 16.60 Prob: 35.05% Token: | a|
Top 1th token. Logit: 16.18 Prob: 23.11% Token: | brown|
Top 2th token. Logit: 15.61 Prob: 13.08% Token: | not|
Top 3th token. Logit: 14.76 Prob:  5.57% Token: | human|
Top 4th token. Logit: 14.38 Prob:  3.83% Token: | also|
Top 5th token. Logit: 13.48 Prob:  1.54% Token: | black|
Top 6th token. Logit: 13.38 Prob:  1.41% Token: | an|
Top 7th token. Logit: 12.66 Prob:  0.68% Token: | white|
Top 8th token. Logit: 12.60 Prob:  0.64% Token: | the|
Top 9th token. Logit: 12.42 Prob:  0.54% Token: | probably|


In [20]:
def logit_diff_denoising(
    logits: Float[Tensor, "batch seq d_vocab"],
    answer_tokens: Float[Tensor, "batch n_pairs 2"] = dataset.answer_tokens,
    flipped_logit_diff: float = flip_logit_diff,
    clean_logit_diff: float = orig_logit_diff,
    return_tensor: bool = False,
) -> Float[Tensor, ""]:
    '''
    Linear function of logit diff, calibrated so that it equals 0 when performance is
    same as on flipped input, and 1 when performance is same as on clean input.
    '''
    patched_logit_diff = get_logit_diff(logits, answer_tokens)
    ld = ((patched_logit_diff - flipped_logit_diff) / (clean_logit_diff  - flipped_logit_diff))
    if return_tensor:
        return ld
    else:
        return ld.item()


def logit_diff_noising(
        logits: Float[Tensor, "batch seq d_vocab"],
        clean_logit_diff: float = orig_logit_diff,
        corrupted_logit_diff: float = flip_logit_diff,
        answer_tokens: Float[Tensor, "batch n_pairs 2"] = dataset.answer_tokens,
        return_tensor: bool = False,
    ) -> float:
        '''
        We calibrate this so that the value is 0 when performance isn't harmed (i.e. same as IOI dataset),
        and -1 when performance has been destroyed (i.e. is same as ABC dataset).
        '''
        patched_logit_diff = get_logit_diff(logits, answer_tokens)
        ld = ((patched_logit_diff - clean_logit_diff) / (clean_logit_diff - corrupted_logit_diff))

        if return_tensor:
            return ld
        else:
            return ld.item()

In [21]:
# patching at each (layer, sequence position) for each of (resid_pre, attn_out, mlp_out) in turn

results = act_patch(
    model=model,
    orig_input=dataset.cf_tokens,
    new_cache=orig_cache,
    patching_nodes=IterNode(["resid_pre", "attn_out", "mlp_out"], seq_pos="each"),
    patching_metric=logit_diff_denoising,
    verbose=True,
)
with open("results/tensors/2_8b_comma_test/content_act_patch_resid_layer_output.pkl", "wb") as f:
    pickle.dump(results, f)

  0%|          | 0/1344 [00:00<?, ?it/s]

results['resid_pre'].shape = (seq_pos=14, layer=32)
results['attn_out'].shape = (seq_pos=14, layer=32)
results['mlp_out'].shape = (seq_pos=14, layer=32)


In [23]:
with open("results/tensors/2_8b_comma_test/content_act_patch_resid_layer_output.pkl", "rb") as f:
    act_patch_resid_layer_output = pickle.load(f)

assert act_patch_resid_layer_output.keys() == {"resid_pre", "attn_out", "mlp_out"}
labels = [f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(dataset.prompt_tokens[0]))]
imshow_p(
    torch.stack([r.T for r in act_patch_resid_layer_output.values()]) * 100, # we transpose so layer is on the y-axis
    facet_col=0,
    facet_labels=["resid_pre", "attn_out", "mlp_out"],
    title="Patching at resid stream & layer outputs (corrupted -> clean)",
    labels={"x": "Sequence position", "y": "Layer", "color": "Logit diff variation"},
    x=labels,
    xaxis_tickangle=45,
    coloraxis=dict(colorbar_ticksuffix = "%"),
    border=True,
    width=1500,
    height=600,
    margin={"r": 100, "l": 100}
)