In [1]:
import re
import os
import sys
import uuid
from typing import (
    Any,
    Callable,
    Iterable,
    Literal,
    Optional,
    Tuple,
    Sequence,
    Type,
    Union,
    List,
    cast,
)
import interp.tools.optional as op
import numpy as np
import rust_circuit as rc
import torch
from interp.circuit.causal_scrubbing.experiment import (
    Experiment,
    ExperimentCheck,
    ExperimentEvalSettings,
    ScrubbedExperiment,
)
from interp.circuit.causal_scrubbing.hypothesis import (
    Correspondence,
    CondSampler,
    ExactSampler,
    FuncSampler,
    InterpNode,
    UncondSampler,
    chain_excluding,
    corr_root_matcher,
)
from interp.circuit.interop_rust.algebric_rewrite import (
    residual_rewrite,
    split_to_concat,
)
from interp.circuit.interop_rust.model_rewrites import To, configure_transformer
from interp.circuit.interop_rust.module_library import load_model_id
from interp.tools.indexer import TORCH_INDEXER as I
from torch.nn.functional import binary_cross_entropy_with_logits
from interp.circuit.testing.notebook import NotebookInTesting
from interp import cui
from interp.ui.very_named_tensor import VeryNamedTensor
from transformers import AutoTokenizer

from datasets import load_dataset
from pprint import pprint
import random
import re
import functools

import transformers
from interp.tools.interpretability_tools import MODELS_DIR

tokenizer = transformers.GPT2TokenizerFast.from_pretrained(
    f"{MODELS_DIR}/gpt2/tokenizer"
)

MODEL_ID = "gelu_12_tied"
SEQ_LEN = 256
NUM_EXAMPLES = 1000

PRINT_CIRCUITS = True
ACTUALLY_RUN = True
SLOW_EXPERIMENTS = True
DEFAULT_CHECKS: ExperimentCheck = True
EVAL_DEVICE = "cuda:0"
MAX_MEMORY = 80000000000
BATCH_SIZE = 32
MLP_WIDTH = 3072

In [2]:
MODEL_ID = "gelu_12_tied"  # aka gpt2 small
circ_dict, _, model_info = load_model_id(MODEL_ID)
unbound_circuit = circ_dict["t.bind_w"]

tokens_arr = rc.Symbol.new_with_random_uuid((SEQ_LEN,), name="tokens")
# We use this to index into the tok_embeds to get the proper embeddings
token_embeds = rc.GeneralFunction.gen_index(
    circ_dict["t.w.tok_embeds"], tokens_arr, 0, name="tok_embeds"
)
bound_circuit = model_info.bind_to_input(
    unbound_circuit, token_embeds, circ_dict["t.w.pos_embeds"]
)

transformed_circuit = rc.conform_all_modules(bound_circuit)
subbed_circuit = transformed_circuit


def module_but_norm(circuit: rc.Circuit):
    if isinstance(circuit, rc.Module):
        if "norm" in circuit.name or "ln" in circuit.name or "final" in circuit.name:
            return False
        else:
            return True
    return False


for i in range(100):
    subbed_circuit = subbed_circuit.update(
        module_but_norm, lambda c: c.cast_module().substitute()
    )

In [3]:
# Get all the neuron activations and create evaluation schedule

dd = rc.TorchDeviceDtype(EVAL_DEVICE, torch.float32)
batched_circuit = rc.cast_circuit(
    rc.Expander(
        (
            "tokens",
            lambda _: rc.Symbol.new_with_random_uuid(
                (BATCH_SIZE, SEQ_LEN),
                name="tokens",
            ),
        )
    )(subbed_circuit),
    dd.op(),
)

layer_activations = [
    batched_circuit.get_unique(
        rc.IterativeMatcher(f"b{i}.call").chain(
            rc.restrict(rc.IterativeMatcher("m.act"), end_depth=5)
        )
    )
    for i in range(12)
]

schedule = rc.optimize_to_schedule(
    rc.Concat.stack(*layer_activations, axis=1),
    rc.OptimizationSettings(
        max_memory=MAX_MEMORY,
        max_single_tensor_memory=MAX_MEMORY,
        device_dtype=dd,
    ),
)
token_hash = batched_circuit.get_unique("tokens").cast_symbol().hash

In [4]:
# Load dataset tokens

dataset = load_dataset("openwebtext", split="train")
dataset = dataset.shuffle()

tokenizer.add_special_tokens({"pad_token": "<|endoftext|>"})

batch_text = dataset["text"][0 : 0 + BATCH_SIZE]
batch_tokens = tokenizer(
    batch_text, return_tensors="pt", padding=True, truncation=True
)["input_ids"]

Reusing dataset openwebtext (/home/ubuntu/.cache/huggingface/datasets/openwebtext/plain_text/1.0.0/5c636399c7155da97c982d0d70ecdce30fbca66a4eb4fc768ad91f8331edac02)
Loading cached shuffled indices for dataset at /home/ubuntu/.cache/huggingface/datasets/openwebtext/plain_text/1.0.0/5c636399c7155da97c982d0d70ecdce30fbca66a4eb4fc768ad91f8331edac02/cache-6fdd753d61517566.arrow


In [5]:
# Replace input with tokens and run

with torch.no_grad():
    neurons = (
        schedule.replace_tensors({token_hash: batch_tokens.cuda()}).evaluate().cpu()
    )

RuntimeError: tensor shape doesn't match entry shape!

: 

In [None]:
# Plot them in the CUI

if not cui.is_port_in_use(6789):
    await cui.init(port=6789)

N_TEXT_TOKENS = 10000

mask = batch_tokens != tokenizer.pad_token_id
tokens_concat = batch_tokens[mask]
text_concat = tokenizer.batch_decode(tokens_concat[:N_TEXT_TOKENS])

neuron_idxs = slice(100)
firings = neurons.permute((1, 3, 0, 2))[:, neuron_idxs, :, :][:, :, mask]

vnt = VeryNamedTensor(
    firings[:, :, :N_TEXT_TOKENS],
    dim_names=["layer", "neuron", "text"],
    dim_types=["head", "head", "seq"],
    dim_idx_names=[
        tuple(map(str, range(12))),
        tuple(map(lambda x: str(x.item()), torch.arange(3072)[neuron_idxs])),
        text_concat,
    ],
    title="firings",
)

await cui.show_tensors(vnt)

In [None]:
# Now try to find some neurons that e.g. encode newline

all_tokens = list(enumerate(tokenizer.batch_decode(range(tokenizer.pad_token_id))))
all_tokens.sort(key=lambda x: x[1])
all_tokens

In [None]:
# Calculate correlation between token==newline and various neurons

def try_concept(fn):
    is_newline = fn(tokens_concat)

    all_vars = torch.stack(
        [
            torch.corrcoef(
                torch.cat([is_newline.to(torch.float32)[None, :], firings[i]], dim=0)
            )
            for i in range(firings.shape[0])
        ],
        dim=0,
    )

    vs = all_vars[:, 1:, 0].reshape(-1)
    topk = torch.topk(vs.abs(), 10)
    return list(
        (a.item(), b.item())
        for (a, b) in zip(topk.indices % 3072, topk.indices // 3072)
    )

In [None]:
# Newlines concept

# The three newline tokens are
# (198, '\n'),
# (628, '\n\n'),
# (44320, '\n\xa0'),

# try_concept(
#     lambda tokens_concat: (tokens_concat == 198)
#     | (tokens_concat == 628)
#     | (tokens_concat == 44320)
# )

In [None]:
preposition_list = """
about
above
across
after
against
among
around
at
before
behind
below
beside
between
by
down
during
for
from
in
inside
into
near
of
off
on
out
over
through
to
toward
under
up
with
""".strip().split()


preposition_tokens = [
    (i, t)
    for (i, t) in all_tokens
    if any(re.match(f"^ *{p} *$", t.lower()) for p in preposition_list)
]


def is_preposition(array: torch.Tensor) -> torch.Tensor:
    m = array == preposition_tokens[0][0]
    for i, _ in preposition_tokens[1:]:
        m |= array == i
    return m


try_concept(is_preposition)