In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from transformer_lens import HookedTransformer
from transformer_lens import utils
from cupbearer.models.computations import Model

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Inference only: turn autograd off globally so forward passes and cached
# activations do not record a computation graph.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x102de70d0>

In [4]:
model = HookedTransformer.from_pretrained("pythia-14m")

Loaded pretrained model pythia-14m into HookedTransformer


In [5]:
model("In a hole in the ground there lived a hobbit.")

tensor([[[   34.3285, -1470.8682,    15.4961,  ..., -1470.8657,
          -1470.8628, -1470.8643],
         [   14.1927, -1470.1348,    14.1116,  ..., -1470.1304,
          -1470.1252, -1470.1272],
         [   14.6220, -1475.2400,    11.3434,  ..., -1475.2366,
          -1475.2332, -1475.2346],
         ...,
         [   15.1217, -1450.9985,    19.5718,  ..., -1450.9932,
          -1450.9899, -1450.9941],
         [   17.1698, -1485.2994,    18.2069,  ..., -1485.2958,
          -1485.2926, -1485.2947],
         [   21.4063, -1479.4633,    14.8359,  ..., -1479.4606,
          -1479.4563, -1479.4592]]], device='mps:0')

In [6]:
text = "In a hole in the ground there lived a hobbit."
tokens = model.to_tokens(text)
# remove_batch_dim=True drops the leading batch axis from the cached activations
# (we only feed one prompt); the resid_post shapes below come out as (13, 128),
# presumably (seq_len, d_model) — confirm against TransformerLens docs.
logits, cache = model.run_with_cache(tokens, remove_batch_dim=True)

In [7]:
list(cache.keys())

['hook_embed',
 'blocks.0.hook_resid_pre',
 'blocks.0.ln1.hook_scale',
 'blocks.0.ln1.hook_normalized',
 'blocks.0.attn.hook_q',
 'blocks.0.attn.hook_k',
 'blocks.0.attn.hook_v',
 'blocks.0.attn.hook_rot_q',
 'blocks.0.attn.hook_rot_k',
 'blocks.0.attn.hook_attn_scores',
 'blocks.0.attn.hook_pattern',
 'blocks.0.attn.hook_z',
 'blocks.0.hook_attn_out',
 'blocks.0.ln2.hook_scale',
 'blocks.0.ln2.hook_normalized',
 'blocks.0.mlp.hook_pre',
 'blocks.0.mlp.hook_post',
 'blocks.0.hook_mlp_out',
 'blocks.0.hook_resid_post',
 'blocks.1.hook_resid_pre',
 'blocks.1.ln1.hook_scale',
 'blocks.1.ln1.hook_normalized',
 'blocks.1.attn.hook_q',
 'blocks.1.attn.hook_k',
 'blocks.1.attn.hook_v',
 'blocks.1.attn.hook_rot_q',
 'blocks.1.attn.hook_rot_k',
 'blocks.1.attn.hook_attn_scores',
 'blocks.1.attn.hook_pattern',
 'blocks.1.attn.hook_z',
 'blocks.1.hook_attn_out',
 'blocks.1.ln2.hook_scale',
 'blocks.1.ln2.hook_normalized',
 'blocks.1.mlp.hook_pre',
 'blocks.1.mlp.hook_post',
 'blocks.1.hook_mlp_ou

In [23]:
import jax.numpy as jnp
from jax import dlpack as jax_dlpack
import jax

def to_pytorch(x):
    """Convert a JAX array to a PyTorch tensor via DLPack; pass anything else through.

    DLPack is a zero-copy exchange format, so the returned tensor may share
    memory with the source array rather than being an independent copy.

    Args:
        x: A ``jax.Array`` or any other object. Non-JAX inputs are returned
            unchanged.

    Returns:
        A ``torch.Tensor`` if ``x`` was a ``jax.Array``, otherwise ``x`` itself.
    """
    if not isinstance(x, jax.Array):
        return x
    # Newer JAX deprecates ``jax.dlpack.to_dlpack``: arrays implement the
    # ``__dlpack__`` protocol, which ``torch.utils.dlpack.from_dlpack`` consumes
    # directly. Fall back to an explicit capsule for older JAX versions.
    if hasattr(x, "__dlpack__"):
        return torch.utils.dlpack.from_dlpack(x)
    return torch.utils.dlpack.from_dlpack(jax_dlpack.to_dlpack(x))

def to_jax(x):
    """Convert a PyTorch tensor to a JAX array via DLPack; pass anything else through.

    Args:
        x: A ``torch.Tensor`` or any other object. Non-tensor inputs are
            returned unchanged.

    Returns:
        A ``jax.Array`` if ``x`` was a ``torch.Tensor``, otherwise ``x`` itself.
    """
    if not isinstance(x, torch.Tensor):
        return x
    if x.device.type == "mps":
        # MPS tensors are not supported by DLPack, so copy to host memory first.
        x = x.cpu()
    # DLPack refuses to export tensors that require grad, and JAX only imports
    # compactly strided buffers (not e.g. a strided slice like ``t[:, ::2]``).
    # Neither call copies data for a tensor that is already detached and contiguous.
    x = x.detach().contiguous()
    capsule = torch.utils.dlpack.to_dlpack(x)
    return jax_dlpack.from_dlpack(capsule)

In [11]:
# Residual stream after every block. The layer count comes from the model config
# instead of being hardcoded; to_jax already moves MPS tensors to the CPU.
activations = [
    to_jax(cache[f"blocks.{layer}.hook_resid_post"])
    for layer in range(model.cfg.n_layers)
]

In [12]:
[a.shape for a in activations]

[(13, 128), (13, 128), (13, 128), (13, 128), (13, 128), (13, 128)]

In [34]:
class Transformer(Model):
    """Wrapper around TransformerLens models. Only meant to be used for inference!

    Calling the wrapper runs the underlying ``HookedTransformer`` on one or more
    strings and returns the logits, and optionally the residual stream after
    every block, as JAX arrays.
    """

    def __init__(self, model: str | HookedTransformer):
        """
        Args:
            model: Either a ``HookedTransformer`` instance, or the name of a
                pretrained model to load with ``HookedTransformer.from_pretrained``.
        """
        super().__init__()
        if isinstance(model, str):
            model = HookedTransformer.from_pretrained(model)
        self.model: HookedTransformer = model

    def __call__(self, x, return_activations: bool = False, train=True):
        """Run the model on text input.

        Args:
            x: A string or a non-empty list of strings.
            return_activations: If True, also return the residual stream after
                every block (the ``blocks.{i}.hook_resid_post`` hooks).
            train: Ignored; this wrapper is inference-only. Presumably kept so
                the signature matches what cupbearer passes — confirm.

        Returns:
            The logits as a JAX array, or ``(logits, activations)`` where
            ``activations`` is a list with one JAX array per layer.

        Raises:
            ValueError: If ``x`` is not a string or a non-empty list of strings.
        """
        if isinstance(x, str):
            x = [x]
        elif isinstance(x, list):
            if not x:
                raise ValueError("Expected a non-empty list of str, got an empty list")
            non_str_types = [type(item) for item in x if not isinstance(item, str)]
            if non_str_types:
                raise ValueError(
                    f"Expected str or list of str, got list containing {non_str_types[0]}"
                )
        else:
            raise ValueError(f"Expected str or list of str, got {type(x)}")

        if not return_activations:
            return to_jax(self.model(x))

        hook_names = [
            f"blocks.{layer}.hook_resid_post"
            for layer in range(self.model.cfg.n_layers)
        ]
        # Only cache the hooks we return instead of every intermediate activation.
        logits, cache = self.model.run_with_cache(x, names_filter=hook_names)
        activations = [to_jax(cache[name]) for name in hook_names]
        return to_jax(logits), activations


# TODO: actually add the readout heads. Don't need to compute logits


In [36]:
trafo = Transformer(model)
logits, activations = trafo(["hello", "testing"], return_activations=True)

In [38]:
[a.shape for a in activations]

[(2, 2, 128), (2, 2, 128), (2, 2, 128), (2, 2, 128), (2, 2, 128), (2, 2, 128)]

# Diamond dataset

In [39]:
from datasets import load_dataset

In [40]:
dataset = load_dataset("redwoodresearch/diamonds-seed0", split="train")

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Downloading readme: 100%|██████████| 649/649 [00:00<00:00, 2.57MB/s]
Downloading data: 100%|██████████| 9.01M/9.01M [00:01<00:00, 6.53MB/s]
Downloading data: 100%|██████████| 2.98M/2.98M [00:01<00:00, 2.66MB/s]
Downloading data: 100%|██████████| 1.09M/1.09M [00:01<00:00, 1.09MB/s]
Downloading data files: 100%|██████████| 3/3 [00:03<00:00,  1.19s/it]
Extracting data files: 100%|██████████| 3/3 [00:00<00:00, 742.40it/s]
Generating train split: 100%|██████████| 25000/25000 [00:00<00:00, 387387.27 examples/s]
Generating validation split: 100%|██████████| 7989/7989 [00:00<00:00, 606517.90 examples/s]
Generating train_for_val split: 100%|██████████| 2997/2997 [00:00<00:00, 226427.14 examples/s]


In [41]:
dataset

Dataset({
    features: ['text', 'is_correct', 'is_clean', 'measurements', 'difficulty'],
    num_rows: 25000
})

In [45]:
# Keep only the difficulty-0 examples.
train_dataset = dataset.filter(lambda example: example['difficulty'] == 0)

# `is_correct` is the ground truth, i.e. whether the diamond is still there.
# My understanding is that the "omit" token positions matter only because that's where
# we place classifier heads etc.: the model is not trained to predict the text at those
# positions, only the binary values.

Filter: 100%|██████████| 25000/25000 [00:00<00:00, 120071.77 examples/s]


In [54]:
train_dataset[4]

{'text': 'from collections import namedtuple\nclass Diamond:\n    def __repr__(self):\n        return "Diamond"\n    @property\n    def shine(self):\n        return 5\n    @property\n    def hardness(self):\n        return 5\ndef get_shine(x):\n    if hasattr(x, "shine"):\n        return x.shine\n    if isinstance(x, int):\n        return x if x <= 5 else 10 - x\n    if isinstance(x, type("")):\n        return len(x) if len(x) <= 5 else 10 - len(x)\n    return 0\ndef get_hardness(x):\n    if hasattr(x, "hardness"):\n        return x.hardness\n    if isinstance(x, int):\n        return 1\n    if isinstance(x, type("")):\n        return 2\n    return 0\nget_shine_ = get_shine\nget_hardness_ = get_hardness\nclass Vault:\n    _items = [Diamond()]\n    def insert(self, idx, item):\n        self._items.insert(idx, item)\n    def append(self, item):\n        self._items.append(item)\n    def pop(self, idx=-1):\n        if len(self._items) <= 1:\n            return None\n        return self._i

In [56]:
from cupbearer.data import DatasetConfig
from cupbearer.tasks import TaskConfig
from cupbearer.models import ModelConfig
from dataclasses import dataclass

In [57]:
class DiamondDataset(torch.utils.data.Dataset):
    """Map-style torch dataset over a diamonds split.

    Each item is a ``(text, measurements)`` pair taken from the underlying
    dataset's ``"text"`` and ``"measurements"`` fields.
    """

    def __init__(self, dataset):
        # Presumably a Hugging Face ``datasets.Dataset``; anything that supports
        # ``len`` and integer indexing into dicts with "text"/"measurements" works.
        self.dataset = dataset

    def __len__(self):
        # torch.utils.data.Dataset does not define __len__; ``len(ds)`` and
        # DataLoader's default sampler both need it.
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        return sample["text"], sample["measurements"]

@dataclass
class DiamondDatasetConfig(DatasetConfig):
    """Config for the train split of ``redwoodresearch/diamonds-seed0``.

    ``easy=True`` keeps only examples with ``difficulty == 0``; ``easy=False``
    keeps only examples with ``difficulty == 2``.
    """

    easy: bool = True

    @property
    def num_classes(self):
        # We need to make 3 binary predictions, i.e. 8 possibilities.
        # In terms of losses, we do need to treat them separately, but num_classes
        # is mainly used to figure out shapes.
        # TODO: should maybe generalize to something like `target_shape`?
        return 8

    def _build(self):
        """Load the train split, keep one difficulty level, and wrap it for torch."""
        # NOTE(review): difficulty 1 is never selected by either setting.
        target_difficulty = 0 if self.easy else 2
        dataset = load_dataset("redwoodresearch/diamonds-seed0", split="train")
        dataset = dataset.filter(
            lambda example: example["difficulty"] == target_difficulty
        )
        return DiamondDataset(dataset)

In [None]:
@dataclass
class TransformerConfig(ModelConfig):
    """Model config whose ``build_model`` returns a ``Transformer`` wrapper."""

    # TransformerLens pretrained model name handed to ``Transformer``.
    model: str = "pythia-14m"

    def build_model(self):
        """Construct the ``Transformer`` wrapper for the configured model name."""
        model_name = self.model
        return Transformer(model_name)

@dataclass(kw_only=True)
class DiamondTask(TaskConfig):
    """Diamonds task: easy examples for training, hard examples as anomalous test data.

    The model is a ``pythia-14m`` TransformerLens wrapper.
    """

    def _init_train_data(self):
        # Training data: the easy split.
        easy_split = DiamondDatasetConfig(easy=True)
        self._train_data = easy_split

    def _get_anomalous_test_data(self):
        # Anomalous test data: the hard split.
        hard_split = DiamondDatasetConfig(easy=False)
        return hard_split

    def _init_model(self):
        model_config = TransformerConfig(model="pythia-14m")
        self._model = model_config