In [None]:
import os
if "models" not in os.listdir("."):
    os.chdir("..")

In [None]:
%load_ext autoreload
%autoreload 2
import penzai
from penzai import pz
pz.ts.register_as_default()
pz.ts.register_autovisualize_magic()
pz.enable_interactive_context()

In [None]:
from micrlhf.llama import LlamaTransformer
llama = LlamaTransformer.from_pretrained("models/gemma-2b-it.gguf", from_type="gemma", load_eager=True)

In [None]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("alpindale/gemma-2b")
tokenizer.padding_side = "right"

In [16]:
import json
save_path = "data/gemma-2b-explanations.json"
if os.path.exists(save_path):
    with open(save_path, "r") as f:
        index_explanations = {int(k): v for k, v in json.load(f).items()}

In [23]:
from tqdm.auto import trange
import requests
max_neurons = 40_000
try:
    index_explanations
except NameError:
    index_explanations = {}
step = 25
for offset in trange(max(index_explanations) // step, max_neurons, step):
    gemma_neurons = requests.post("https://www.neuronpedia.org/api/neurons-offset", json={"modelId": "gemma-2b", "layer": "6-res-jb", "offset": offset}).json()
    for n in gemma_neurons:
        for e in n["explanations"]:
            index_explanations[int(n["index"])] = dict(
                explanation=e["description"],
                max_acts=[dict(
                    tokens=a["tokens"],
                    values=a["values"],
                ) for a in n["activations"]]
            )

  0%|          | 0/1584 [00:00<?, ?it/s]

In [25]:
max(index_explanations)

In [24]:
if not os.path.exists(save_path):
    with open(save_path, "w") as f:
        json.dump(index_explanations, f)
else:
    existing = {int(k): v for k, v in json.load(open(save_path)).items()}
    existing.update(index_explanations)
    with open(save_path, "w") as f:
        json.dump(existing, f)

In [None]:
!mkdir -p models/sae
!wget -c 'https://huggingface.co/jbloom/Gemma-2b-Residual-Stream-SAEs/resolve/main/gemma_2b_blocks.6.hook_resid_post_16384_anthropic_fast_lr/sae_weights.safetensors?download=true' -O 'models/sae/gemma-jb-6.safetensors'
from safetensors import safe_open
with safe_open("models/sae/gemma-jb-6.safetensors", framework="numpy") as st:
    w_dec = st.get_tensor("W_dec")

In [None]:
from micrlhf.utils.activation_manipulation import ActivationReplacement
from micrlhf.utils.activation_manipulation import replace_activation, collect_activations
from micrlhf.sampling import sample
from penzai.toolshed.jit_wrapper import Jitted
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
import random


FEAT_BATCH = 1
PICK_BATCH = 64
MIN_SCALE = 0
MAX_SCALE = 200
REP_LAYER = 2
MAX_LENGTH = 64
PROBE_LAYER = 16
CFG = 1.0
PROMPT_TEMPLATE = '<start_of_turn>user\nWhat is the meaning of the word "X"?<end_of_turn>\n<start_of_turn>model\nThe meaning of the word "X" is "'
POSITIONS = tuple(i for i, a in enumerate(tokenizer.encode(PROMPT_TEMPLATE)) if tokenizer.decode([a]) == "X")

embeds = llama.select().at_instances_of(pz.nn.EmbeddingLookup).get_sequence()[0].table.embeddings.value.unwrap("vocabulary", "embedding")
embed_mean = embeds.mean(axis=0)
embed_vector = embed_mean / jnp.linalg.norm(embed_mean)
tiled_embed = jnp.tile(embed_vector, (PICK_BATCH, 1))
act_rep_base = Jitted(collect_activations(replace_activation(llama, tiled_embed, POSITIONS, layer=REP_LAYER)))

def benchmark_vector(vector, tokens, positions, replacement_layer):
    assert replacement_layer == REP_LAYER
    dumb = False
    if vector.ndim == 1:
        dumb = True
        vector = jnp.tile(vector[None, :], (PICK_BATCH, 1))
    assert vector.shape == tiled_embed.shape
    assert positions == POSITIONS
    # act_rep = collect_activations(replace_activation(llama, vector, positions, layer=replacement_layer))
    act_rep = act_rep_base.select().at_instances_of(ActivationReplacement).apply(lambda x: ActivationReplacement.replace_vector(x, vector))
    logits, residuals = act_rep(tokens)
    result = logits, [r.value for r in residuals]
    if dumb:
        return result[0].untag("batch")[:1].tag("batch"), [r.untag("batch")[:1].tag("batch") for r in result[1]]
    return result


def tokens_to_inputs(tokens):
    token_array = jnp.asarray(tokens)
    token_array = jax.device_put(token_array, jax.sharding.NamedSharding(llama.mesh, jax.sharding.PartitionSpec("dp", "sp")))
    token_array = pz.nx.wrap(token_array, "batch", "seq").untag("batch").tag("batch")

    inputs = llama.inputs.from_basic_segments(token_array)
    return inputs


# def logits_to_loss(logits, tokens, answer_start, pad_token=32000):
#     logits = jax.nn.log_softmax(logits)

#     logits = logits[:, :-1]
#     logits = jnp.take_along_axis(logits, tokens[:, 1:, None], axis=-1).squeeze(-1)

#     mask = tokens[:, 1:] != pad_token

#     mask[:, :answer_start-1] = False

#     logits = logits * mask

#     return -logits.sum(axis=-1) / mask.sum(axis=-1)


def pick_scale(feature, batch_size=PICK_BATCH, min_scale=MIN_SCALE, max_scale=MAX_SCALE, layer=REP_LAYER):
    scales = np.linspace(min_scale, max_scale, batch_size)
    vector = feature[None, :] * jnp.array(scales)[:, None]
    text = [PROMPT_TEMPLATE for _ in range(batch_size)]
    tokenized = tokenizer(text, return_tensors="np", padding="max_length", max_length=64, truncation=True)
    tokens = tokenized["input_ids"]
    inputs = tokens_to_inputs(tokens)
    
    logits, residuals = benchmark_vector(vector, inputs, POSITIONS, layer)
    
    # rand_vector = jax.random.normal(key=jax.random.key(random.randint(-10, 10)), shape=vector.shape)
    # rand_vector = rand_vector / jnp.linalg.norm(rand_vector, axis=-1, keepdims=True)
    # logits_rand, _ = benchmark_vector(rand_vector, inputs, positions, layer)
    
    logits_mean, _ = benchmark_vector(embed_vector, inputs, POSITIONS, layer)

    logits = logits.unwrap("batch", "seq", "vocabulary")
    entropies = -jnp.sum(jax.nn.log_softmax(logits) * jnp.exp(jax.nn.log_softmax(logits)), axis=-1)
    entropy_first = entropies[:, -1]
    
    resid_cos_features = []
    for residual in residuals:
        resid = residual.unwrap("batch", "seq", "embedding")[:, -1]
        resid_cos_feature = resid @ feature / jnp.linalg.norm(resid) / jnp.linalg.norm(feature)
        resid_cos_feature = resid_cos_feature - resid_cos_feature[0]
        resid_cos_features.append(resid_cos_feature)
    entropy_first = entropy_first - entropy_first[0]
    
    crossents = []
    # for baseline in (logits_rand, logits_mean):
    for baseline in (logits_mean,):
        baseline = baseline.unwrap("batch", "seq", "vocabulary")
        # baseline_probs = jax.nn.softmax(baseline)
        # crossents.append(-jnp.sum(jax.nn.log_softmax(logits) * baseline_probs, axis=-1)[:, -1])
        crossents.append(-jnp.sum(jax.nn.log_softmax(baseline) * jax.nn.softmax(logits), axis=-1)[:, -1])
    
    return scales, entropy_first, resid_cos_features, crossents

def generate_explanations(feature=None, batch_size=32, min_scale=MIN_SCALE, max_scale=MAX_SCALE, layer=REP_LAYER, cfg=CFG, for_cache=False, cached=None, cache_batch=FEAT_BATCH):
    if feature is None:
        feature = w_dec[0]
        if cache_batch > 1:
            feature = jnp.tile(feature[None], (cache_batch, 1))
    scales = np.linspace(min_scale, max_scale, batch_size, dtype=np.float32)
    if feature.ndim == 1:
        vector = feature[None, :] * jnp.array(scales)[:, None]
    else:
        vector = feature[:, None, :] * jnp.array(scales)[None, :, None]
        vector = vector.reshape(-1, vector.shape[-1])
    if cfg != 1.0:
        vector = jnp.concatenate([embed_vector[None, :] * jnp.array(scales)[:, None], vector], axis=0)
    if cached is not None:
        act_rep = cached[0].select().at_instances_of(ActivationReplacement).apply(lambda x: ActivationReplacement.replace_vector(x, vector)), cached[1]
        # act_rep = cached
    else:
        act_rep = replace_activation(llama, vector, POSITIONS, layer=layer)
    completions, model = sample(act_rep, tokenizer,
                         PROMPT_TEMPLATE, batch_size=(batch_size if cfg == 1.0 else batch_size * 2) * (1 if feature.ndim == 1 else feature.shape[0]),
                         do_sample=True, max_seq_len=MAX_LENGTH,
                         return_only_completion=True,
                         verbose=False, cfg_strength=cfg,
                         return_model=for_cache, only_cache=for_cache)
    if for_cache:
        return model
    if feature.ndim == 1:
        return list(zip(np.concatenate((scales, scales)), completions))[batch_size if cfg != 1.0 else 0:]
    else:
        completions = np.array(completions).reshape(-1, batch_size if cfg == 1.0 else 2 * batch_size)
        explanations = []
        for c in completions:
            explanations.append(list(zip(np.concatenate((scales, scales)), c))[batch_size if cfg != 1.0 else 0:])
        return explanations
    
cached_model = generate_explanations(for_cache=True)

In [None]:
from matplotlib import pyplot as plt
from tqdm.auto import tqdm
import json


generations_filename = "data/gemma-2b-generations.json"
if os.path.exists(generations_filename):
    with open(generations_filename, "r") as f:
        generations = {int(k): v for k, v in json.load(f).items()}
else:
    generations = {}
plt.style.use("seaborn-v0_8-darkgrid")
showy = False
random.seed(9)
feat_batch = FEAT_BATCH
data_points = []
try:
    for i, e in tqdm(list(index_explanations.items())):
        if i in generations:
            continue
        scales, entropy, selfsims, crossents = pick_scale(w_dec[i])
        selfsim = selfsims[PROBE_LAYER]
        scale_idx = np.argmax(selfsim[1:]) - 1
        highest = selfsim[scale_idx]
        scale = scales[scale_idx]
        scale = scales[scale_idx]
        data_point = dict(
            feature=i,
            explanation=e["explanation"],
            settings=dict(
                min_scale=MIN_SCALE,
                max_scale=MAX_SCALE,
                rep_layer=REP_LAYER,
                probe_layer=PROBE_LAYER,
                cfg=CFG,
            ),
            scale_tuning=dict(
                scales=scales.tolist(),
                entropy=entropy.tolist(),
                selfsims=[s.tolist() for s in selfsims],
                crossents=[c.tolist() for c in crossents],
            )
        )
        data_points.append(data_point)
        if len(data_points) < feat_batch:
            continue
        explanations = generate_explanations(jnp.stack([w_dec[p["feature"]] for p in data_points]), cached=cached_model)
        for p, e in zip(data_points, explanations):
            if showy:
                st = p["scale_tuning"]
                plt.plot(st["scales"], st["entropy"], label=f"Entropy")
                for l in range(10, PROBE_LAYER + 1):
                    plt.plot(st["scales"], np.array(st["selfsims"][l]) / max(st["selfsims"][l]) * max(st["selfsims"][PROBE_LAYER]), label=f"Self-similarity [{l}]")
                for j, c in enumerate(crossents):
                    plt.plot(st["scales"], c, label=f"Cross-entropy ({j})")
                plt.title(f"Feature {p['feature']}: \"{p['explanation']}\"")
                plt.xlabel("Scale")
                # plt.ylim(-10, 10)
                plt.legend()
                plt.show()
            if showy:
                display(e)
            data_point = dict(
                **p,
                generations=[(float(a), b) for a, b in e],
            )
            generations[p["feature"]] = data_point
        
        data_points = []
except KeyboardInterrupt:
    pass
json.dump(generations, open(f"data/gemma-2b-generations.json", "w"))

In [15]:
json.dump(generations, open(f"data/gemma-2b-generations.json", "w"))