In [1]:
import os

# Ask PyTorch to fall back to the CPU for operators that the MPS (Apple GPU)
# backend does not implement. This runs ahead of the cell that imports torch.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

In [2]:
import torch 
import numpy as np

# Hex Dataset with Pythia-70m tokenizer

The hex dataset collects 16-token examples. On the trusted set, each example ends in hexadecimal characters; on the untrusted set, there are also anomalous examples, which end in alphanumeric characters preceded by a '#' (i.e. a hex color).

In the original dataset, they train a "clean model" which never sees hex colors, and use the clean model to determine whether a hexadecimal prediction is "caused" by induction-like mechanisms or by the hex color.

However, we can probably get away with discarding the clean model and just treating all hex colors as anomalous (if that's too difficult, we can remove all instances where there are multiple triggers, i.e. multiple hex colors).

"trigger" is any hexadecimal character following a '#' in the same string

"behavior" is any hexadecimal character

In [3]:
# TODO: clean up notebook,
# add cupbearer task within notebook,
# run Mahalanobis on final layer, final token
# try edge attribution patching on this task (metric is probability of hexadecimal, zero ablate)

In [4]:
# Load the pretrained pythia-70m checkpoint through TransformerLens.
import transformer_lens

MODEL_NAME = "pythia-70m"
model = transformer_lens.HookedTransformer.from_pretrained(MODEL_NAME)

  from .autonotebook import tqdm as notebook_tqdm
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m into HookedTransformer


In [5]:
from hex_nn.datasets import get_token_dataset
from hex_nn.masking.behaviors import registry as behavior_registry
from hex_nn.masking.triggers import registry as trigger_registry
from hex_nn.masking.distinctions import get_behavior_examples
from hex_nn.datasets import cache_json

In [6]:
@cache_json("distinctions/{behavior_name}_c4_{c4_n_items}_code_{code_n_items}.json")
def get_distinctions_dataset(
    behavior_name, tokenizer, *, c4_n_items=7 * 2**16, code_n_items=2**16
):
    """Collect behavior examples from the combined "c4" and "code" token datasets.

    Args:
        behavior_name: Key used to look up both the behavior masker and the
            trigger masker in their registries.
        tokenizer: Tokenizer handed to the dataset loader and to both maskers.
        c4_n_items: ``n_items`` requested from the "c4" dataset.
        code_n_items: ``n_items`` requested from the "code" dataset.

    Returns:
        Whatever ``get_behavior_examples`` yields for the combined dataset.
    """
    # Both sources are read from the "train_rev" split: c4 first, then code.
    sources = (("c4", c4_n_items), ("code", code_n_items))
    c4_tokens, code_tokens = (
        get_token_dataset(source, tokenizer, split="train_rev", n_items=count)
        for source, count in sources
    )
    combined_tokens = c4_tokens + code_tokens

    behavior_masker = behavior_registry[behavior_name](tokenizer)
    trigger_masker = trigger_registry[behavior_name](tokenizer)

    # NOTE: the step that annotated each example with "main" and "clean" model
    # probabilities (add_probs) is disabled in this notebook, so the examples
    # are returned exactly as produced by get_behavior_examples.
    return get_behavior_examples(combined_tokens, behavior_masker, trigger_masker)

In [7]:
# Every size argument is part of the cache path: each one changes the returned
# task, so a path keyed on behavior_name alone would hand back a stale task
# when the function is called again with different sizes.
# (Assumes cache_json fills the placeholders from the call's arguments, as the
# template on get_distinctions_dataset already relies on -- TODO confirm.)
@cache_json(
    "distinctions/{behavior_name}_task_train_{n_train}_anomalous_{n_anomalous}"
    "_c4_{c4_n_items}_code_{code_n_items}.json"
)
def get_distinctions_task(
    behavior_name,
    tokenizer,
    *,
    n_train=2**14,
    n_anomalous=2**10,
    c4_n_items=7 * 2**16,
    code_n_items=2**16,
):
    """Split behavior examples into a train set and two equally sized test sets.

    Examples come from ``get_distinctions_dataset``. Those whose ``"triggered"``
    field is falsy are treated as non-anomalous; the rest are anomalous.

    Args:
        behavior_name: Registry key of the behavior (also forwarded to
            ``get_distinctions_dataset``).
        tokenizer: Tokenizer forwarded to the dataset builder and the
            behavior masker.
        n_train: Number of non-anomalous examples placed in ``"train"``.
        n_anomalous: Size of each test split (non-anomalous and anomalous).
        c4_n_items: Forwarded to ``get_distinctions_dataset``.
        code_n_items: Forwarded to ``get_distinctions_dataset``.

    Returns:
        A dict with keys ``"train"``, ``"test_non_anomalous"``,
        ``"test_anomalous"`` (lists of dicts holding ``"prefix_tokens"`` and
        ``"completion_token"``) plus ``"cause_tokens"`` and
        ``"effect_tokens"`` (sorted lists taken from the behavior masker).

    Raises:
        AssertionError: If there are fewer than ``n_train + n_anomalous``
            non-anomalous examples or fewer than ``n_anomalous`` anomalous ones.
    """
    examples = get_distinctions_dataset(
        behavior_name, tokenizer, c4_n_items=c4_n_items, code_n_items=code_n_items
    )

    # NOTE: the main/clean-model log-ratio percentile filter is disabled in this
    # notebook; the "triggered" flag alone decides which side an example is on.
    non_anomalous_examples = [
        example for example in examples if not example["triggered"]
    ]
    anomalous_examples = [example for example in examples if example["triggered"]]

    n_non_anomalous_needed = n_train + n_anomalous
    assert len(non_anomalous_examples) >= n_non_anomalous_needed, (
        f"need {n_non_anomalous_needed} non-anomalous examples, "
        f"found {len(non_anomalous_examples)}"
    )
    assert len(anomalous_examples) >= n_anomalous, (
        f"need {n_anomalous} anomalous examples, found {len(anomalous_examples)}"
    )

    def to_task_example(example):
        # Keep only the fields the task needs.
        return {
            "prefix_tokens": example["prefix_tokens"],
            "completion_token": example["completion_token"],
        }

    train_examples = [
        to_task_example(example) for example in non_anomalous_examples[:n_train]
    ]
    # The non-anomalous test split is the slice right after the training slice,
    # so the two never overlap.
    test_non_anomalous_examples = [
        to_task_example(example)
        for example in non_anomalous_examples[n_train:n_non_anomalous_needed]
    ]
    test_anomalous_examples = [
        to_task_example(example) for example in anomalous_examples[:n_anomalous]
    ]

    behavior_masker = behavior_registry[behavior_name](tokenizer)
    return {
        "train": train_examples,
        "test_non_anomalous": test_non_anomalous_examples,
        "test_anomalous": test_anomalous_examples,
        "cause_tokens": sorted(behavior_masker.cause_tokens),
        "effect_tokens": sorted(behavior_masker.effect_tokens),
    }

In [11]:
# Reduced sizes for quick iteration. The full-size settings are
# c4_n_items = 7 * 2**16 and code_n_items = 2**16.
behavior_name = "hex"
tokenizer = model.tokenizer

c4_n_items = code_n_items = 2**10
n_train, n_anomalous = 2**8, 2**5

In [12]:
# Build the hex task using the sizes chosen in the configuration cell.
task = get_distinctions_task(
    behavior_name,
    tokenizer,
    n_train=n_train,
    n_anomalous=n_anomalous,
    c4_n_items=c4_n_items,
    code_n_items=code_n_items,
)

In [13]:
# Show the top-level keys of the task dict: the three example splits plus the
# cause/effect token lists.
task.keys()

dict_keys(['train', 'test_non_anomalous', 'test_anomalous', 'cause_tokens', 'effect_tokens'])

In [14]:
# The 16 lowercase hexadecimal digits, in ascending order.
hex_chars = "".join(format(digit, "x") for digit in range(16))
hex_chars

'0123456789abcdef'

In [15]:
# Sanity check: ignoring spaces, every cause token decodes to hex digits only.
for cause_token in task["cause_tokens"]:
    out = tokenizer.decode([cause_token])
    assert set(out.replace(" ", "")) <= set(hex_chars)

In [16]:
# Sanity check: ignoring spaces, every effect token decodes to hex digits only.
for effect_token in task["effect_tokens"]:
    out = tokenizer.decode([effect_token])
    assert set(out.replace(" ", "")) <= set(hex_chars)

# Run Edge Attribution Patching on Hex