In [1]:
import os 
# set environment variable PYTORCH_ENABLE_MPS_FALLBACK=1
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'

In [87]:
# download pythia-70m from transformer lens
import transformer_lens

model = transformer_lens.HookedTransformer.from_pretrained("pythia-70m")

# model.generate("Hello, ")


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model pythia-70m into HookedTransformer


In [4]:
model.generate("Mary went to the store with John. John gave the bag to ", max_new_tokens=1, use_past_kv_cache=False, prepend_bos=False, do_sample=False)

100%|██████████| 1/1 [00:00<00:00, 11.50it/s]


'Mary went to the store with John. John gave the bag to ________'

In [6]:
logits = model.run_with_cache("Mary went to the store with John. John gave the bag to ")

In [7]:
logits[0][:,-1, model.tokenizer.encode("Mary")[0]]

tensor([9.4306], device='mps:0', grad_fn=<SelectBackward0>)

In [23]:
# sort logits, get index of Mary
ordered_indicies = logits[0][:,-1].sort(descending=True).indices
(ordered_indicies[0] == model.tokenizer.encode("Mary")[0]).nonzero().item()

111

In [24]:
def get_token_order(model, logits, token):
    ordered_indicies = logits[0][:,-1].sort(descending=True).indices
    return (ordered_indicies[0] == model.tokenizer.encode(token)[0]).nonzero().item()

In [34]:
prefix = "Alice uses they/"
logits = model.run_with_cache(prefix)
get_token_order(model, logits, "them")

0

In [39]:
prefix = "Find the lost keys and return "
logits = model.run_with_cache(prefix)
get_token_order(model, logits, "them")

702

# Hex

In [40]:
from cupbearer import tasks

In [42]:
task = tasks.tiny_natural_mechanisms("hex", "mps")

Loaded pretrained model attn-only-1l into HookedTransformer
Moving model to device:  mps


In [43]:
def get_prefix_and_completion(x):
    prefix = task.model.tokenizer.decode(x["prefix_tokens"])
    completion = task.model.tokenizer.decode(x["completion_token"])
    return prefix, completion

In [47]:
hex_prefix, hex_completion = get_prefix_and_completion(task.trusted_data.data[0])
hex_prefix, hex_completion

('d9b0-bd81-4108-be7', '4')

In [49]:
prefix = hex_prefix
logits = model.run_with_cache(prefix)
get_token_order(model, logits, "4")

79

In [59]:
logits[0][:,-1].topk(100).indices[0]

tensor([   66,    64,    65,    67,    69,    68, 18213,  7252, 21101, 19881,
        11848,  9945,  1350, 16072, 13331, 15630, 12993, 17896,   535,   487,
         1860, 16344, 21855,  6814,  6888,  3609, 17457,  7568,   330,   721,
         5036,   324,  1878,  7012,  2934,   344,   891,  2718,  1765,  2548,
         2414,  2670,   276,  2682,  1157,  3682,  2598,  2920,  2481,  3559,
         3132,  1453,  1731,  2791,    18,  1828,  1129,  3901,    17,  3510,
           16,   397,  2075, 10210,  2624,  2327,  3388,  1433,  1954,  1558,
         1507,  3134,  3459, 14822,  1983,  2623,  2857,  1314,  2091,    19,
         3104,  2078,  1415,  2231,  2001,  1959,    24,  2780,  1065,  8635,
         1485,  2079,  2996,    23,    21,  1821,    22,   940,  1899,  1238],
       device='mps:0')

In [61]:
for index in logits[0][:,-1].topk(10).indices[0]:
    print(model.tokenizer.decode(index) + " ")

c 
a 
b 
d 
f 
e 
ea 
aa 
cb 
bf 


In [62]:
logits = task.model.run_with_cache(prefix)
get_token_order(task.model, logits, "4")

6

In [63]:
from cupbearer.tasks import Task
import json
import os
from pathlib import Path
from typing import Any

import torch 
import blobfile as bf
from transformer_lens import HookedTransformer

device="mps"
name = "hex"

class TinyNaturalMechanismsDataset(torch.utils.data.Dataset):
    def __init__(self, data: list[dict[str, Any]]):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Convert the list to a tensor to make sure that pytorch's default collate_fn
        # batches the lists into a single tensor (TransformerLens requires that).
        # Note that all sequences have the same length in these datasets.
        return (
            torch.tensor(self.data[idx]["prefix_tokens"], dtype=torch.long),
            self.data[idx]["completion_token"],
        )


# This seems to be necessary to access the public GCS files below without logging in
os.environ["NO_GCE_CHECK"] = "true"

model = HookedTransformer.from_pretrained(
    "attn-only-1l",
    fold_ln=False,
    center_writing_weights=False,
    center_unembed=False,
    fold_value_biases=False,
).to(device)

# Downloading the models from GCS can take ~10 seconds, so we cache them locally.
cache_dir = Path(".cupbearer_cache/tiny_natural_mechanisms/")
cache_dir.mkdir(parents=True, exist_ok=True)

cache_path = cache_dir / "main.pth"

if cache_path.exists():
    state_dict = torch.load(cache_path, map_location=device)
else:
    # `model_path` seems to have a typo, uses a `.path` extension instead of `.pth`
    # with bf.BlobFile(task_data["model_path"], "rb") as fh:
    with bf.BlobFile("gs://arc-ml-public/distinctions/models/main.pth", "rb") as fh:
        state_dict = torch.load(fh, map_location=device)
    state_dict["unembed.b_U"] = model.unembed.b_U
    torch.save(state_dict, cache_path)

model.load_state_dict(state_dict)

cache_path = cache_dir / f"{name}_task.json"
if cache_path.exists():
    with cache_path.open("r") as f:
        task_data = json.load(f)
else:
    path = f"gs://arc-ml-public/distinctions/datasets/{name}_task.json"
    with bf.BlobFile(path) as f:
        task_data = json.load(f)
    with cache_path.open("w") as f:
        json.dump(task_data, f)

train_data = TinyNaturalMechanismsDataset(task_data["train"])
normal_test_data = TinyNaturalMechanismsDataset(task_data["test_non_anomalous"])
anomalous_test_data = TinyNaturalMechanismsDataset(task_data["test_anomalous"])

Loaded pretrained model attn-only-1l into HookedTransformer
Moving model to device:  mps


In [64]:
hex_prefix, hex_completion = get_prefix_and_completion(anomalous_test_data.data[0])
hex_prefix, hex_completion

('40cc", "#ff4040", "#40ff4', '0')

In [73]:
hex_prefix[:7] + hex_prefix[19:]

'40cc", #40ff4'

In [81]:
[x for x in  anomalous_test_data.data][0]["prefix_tokens"]

[22, 18, 549, 985, 13557, 566, 22, 18, 22, 18, 985, 13557, 22, 18, 566, 22]

In [88]:
for x in anomalous_test_data.data:
    prefix, completion = get_prefix_and_completion(x)
    prefix_len = len(model.tokenizer.encode(prefix))
    completion_len = len(model.tokenizer.encode(completion))
    assert prefix_len == 16, prefix_len
    assert completion_len == 1, completion_len

AssertionError: 13

In [77]:
len(model.tokenizer.encode(hex_prefix))

16

In [75]:
prefix = hex_prefix[:7] + hex_prefix[19:]
logits = model.run_with_cache(prefix)
get_token_order(model, logits, "0")

1