In [1]:
# Reload edited local modules automatically so code changes are picked up live.
%load_ext autoreload
%autoreload 2

# needed for set_determinism
# NOTE(review): no `set_determinism` is defined in this notebook, and `set_seed`
# below only seeds the RNGs. Confirm whether deterministic cuBLAS is actually
# enabled somewhere; otherwise this setting may have no effect.
%set_env CUBLAS_WORKSPACE_CONFIG=:16:8

env: CUBLAS_WORKSPACE_CONFIG=:16:8


## Setup

In [2]:
from collections.abc import Callable
import pandas as pd
import seaborn as sns
import math

import matplotlib.pyplot as plt
from collections import defaultdict
import numpy as np
import torch
from tqdm import tqdm
from transformer_lens import HookedTransformer
from datasets import Dataset, load_dataset
from sae_lens.toolkit.pretrained_saes import get_gpt2_res_jb_saes
from typing import Optional, Tuple
from jaxtyping import Float
from abc import ABC, abstractmethod
from dataclasses import dataclass
from random import choice, shuffle
from typing import final, cast
import torch.nn.functional as F
import random
import gc

In [3]:
def clean_cache():
    """Free unreachable Python objects, then release cached CUDA memory.

    ``gc.collect()`` must run first. Tensors kept alive only by reference
    cycles are freed by the collection, which returns their blocks to
    PyTorch's caching allocator. Only after that can
    ``torch.cuda.empty_cache()`` hand those blocks back to the driver.
    Calling ``empty_cache()`` first cannot release memory that the subsequent
    collection frees.
    """
    gc.collect()
    torch.cuda.empty_cache()


def set_seed(seed):
    """Seed PyTorch, NumPy's global RNG and Python's ``random`` with ``seed``.

    The seeders are applied in that order. The function returns nothing.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed)


def get_device_str() -> str:
    """Name the preferred torch device: Apple MPS first, then CUDA, else CPU."""
    if torch.backends.mps.is_available():
        return "mps"
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"


# Utils
def generate_prompt(dataset, n_ctx: int = 1, batch: int = 1) -> torch.Tensor:
    """Draw one shuffled batch of token ids from ``dataset``.

    Each dataset item is expected to expose an ``"input_ids"`` entry. The
    default collate stacks these into a ``(batch, seq)`` tensor, which is then
    truncated to its first ``n_ctx`` positions.
    """
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch, shuffle=True)
    first_batch = next(iter(loader))
    token_ids = first_batch["input_ids"]
    return token_ids[:, :n_ctx]


def compute_kl_div(logits_ref: torch.Tensor, logits_pert: torch.Tensor) -> torch.Tensor:
    """KL divergence KL(P_pert || P_ref) between the softmax distributions.

    ``F.kl_div(input, target)`` evaluates ``target * (log target - input)``
    pointwise. Here ``input`` holds the reference log-probs and ``target``
    holds the perturbed log-probs, so the sum over the last (vocab) dimension
    is KL(perturbed || reference). One value is returned per leading index.
    """
    ref_logprobs, pert_logprobs = (
        F.log_softmax(logits, dim=-1) for logits in (logits_ref, logits_pert)
    )
    pointwise = F.kl_div(
        ref_logprobs, pert_logprobs, log_target=True, reduction="none"
    )
    return pointwise.sum(dim=-1)


def get_random_activation(
    model: HookedTransformer, dataset: Dataset, n_ctx: int, layer: str, pos
) -> torch.Tensor:
    """Return the ``layer`` activation at positions ``pos`` for one random prompt.

    The prompt is a single ``n_ctx``-token sample drawn via ``generate_prompt``.
    The result is detached and moved to the CPU.
    """
    tokens = generate_prompt(dataset, n_ctx=n_ctx)
    cache = model.run_with_cache(tokens)[1]
    activation = cache[layer][:, pos, :]
    return activation.to("cpu").detach()


def load_pretokenized_dataset(
    path: str,
    split: str,
) -> Dataset:
    """Load ``split`` of the Hugging Face dataset at ``path`` with torch formatting."""
    raw_split = cast(Dataset, load_dataset(path, split=split))
    return raw_split.with_format("torch")


def get_random_activations(
    model: HookedTransformer, dataset: Dataset, n_ctx: int, layer: str, pos, n_samples
) -> torch.Tensor:
    """Get ``n_samples`` random activations from the dataset in one forward pass.

    One shuffled batch of ``n_samples`` prompts (each truncated to ``n_ctx``
    tokens) is drawn via ``generate_prompt``. The ``layer`` activation at
    positions ``pos`` is returned, detached and on the CPU.

    Drawing the batch in a single call builds one shuffled DataLoader (one
    permutation of the dataset) rather than one per sample. The prompts are
    therefore distinct dataset rows, i.e. sampled without replacement.
    NOTE(review): if ``n_samples`` exceeds ``len(dataset)``, fewer rows are
    returned.
    """
    rand_prompts = generate_prompt(dataset, n_ctx=n_ctx, batch=n_samples)
    _, cache = model.run_with_cache(rand_prompts)
    return cache[layer][:, pos, :].to("cpu").detach()


def cosine_similarity(a, b):
    """Cosine similarity of two vectors, returned as a Python float.

    Both inputs get a leading batch axis so ``F.cosine_similarity`` compares
    them along dim 1.
    """
    batched_a, batched_b = a[None], b[None]
    similarity = F.cosine_similarity(batched_a, batched_b)
    return similarity.item()

In [4]:
@dataclass
class ExperimentConfig:
    """Settings for the activation-perturbation experiments in this notebook."""

    # Prompt length in tokens (Reference asserts prompts have exactly this many).
    n_ctx: int
    # Hook name whose activation is overwritten, e.g. "blocks.0.hook_resid_pre".
    perturbation_layer: str
    # Hook name read after the perturbation to measure its downstream effect.
    read_layer: str
    # Sequence positions to perturb, as a slice over the position axis.
    perturbation_pos: slice
    # Number of interpolation points between start and end activations.
    n_steps: int
    # NOTE(review): not read by any code visible here; presumably the
    # (min, max) interpolation coefficient -- confirm.
    perturbation_range: Tuple[float, float]
    # Seed passed to set_seed.
    seed: Optional[int] = None
    # batch_size for the notebook's DataLoader over the dataset.
    dataloader_batch_size: Optional[int] = None
    # NOTE(review): not read by any code visible here -- confirm intended use.
    mean_batch_size: Optional[int] = None


class Reference:
    """Clean (unperturbed) forward pass of ``model`` on ``prompt``, stored on CPU.

    Attributes:
        model, prompt, perturbation_layer, read_layer, perturbation_pos, n_ctx:
            the constructor arguments.
        logits: detached model logits on the CPU.
        cache: activation cache moved to the CPU.
        act: ``cache[perturbation_layer]`` at positions ``perturbation_pos``.
    """

    def __init__(
        self,
        model: HookedTransformer,
        prompt: torch.Tensor,
        perturbation_layer: str,
        read_layer: str,
        perturbation_pos: slice,
        n_ctx: int,
    ):
        self.model = model
        # Unpacking enforces a 2-D (batch, seq) prompt; seq must equal n_ctx.
        _batch_size, prompt_len = prompt.shape
        assert (
            n_ctx == prompt_len
        ), f"n_ctx {n_ctx} must match prompt n_ctx {prompt_len}"

        self.prompt = prompt
        self.n_ctx = n_ctx
        self.perturbation_layer = perturbation_layer
        self.read_layer = read_layer
        self.perturbation_pos = perturbation_pos

        logits, cache = model.run_with_cache(prompt)
        self.logits = logits.detach().to("cpu")
        self.cache = cache.to("cpu")
        self.act = self.cache[self.perturbation_layer][:, self.perturbation_pos]

In [5]:
# Experiment configuration (keywords listed in the dataclass's field order).
cfg = ExperimentConfig(
    n_ctx=10,
    perturbation_layer="blocks.0.hook_resid_pre",
    read_layer="blocks.11.hook_resid_post",
    # slice(-1, None, 1) selects only the final token position.
    perturbation_pos=slice(-1, None, 1),
    n_steps=100,
    perturbation_range=(0, 1),
    seed=9944,
    dataloader_batch_size=15,
    mean_batch_size=512,
)

In [6]:
# Seed torch, NumPy and `random` before any sampling happens.
set_seed(seed=cfg.seed)

In [7]:
# Pre-tokenized dataset (OpenWebText with the GPT-2 tokenizer, per the repo name).
dataset = load_pretokenized_dataset(
    path="apollo-research/Skylion007-openwebtext-tokenizer-gpt2",
    split="train",
)
# Shuffled loader over the same dataset; batch size comes from the config.
dataloader = torch.utils.data.DataLoader(
    dataset,
    shuffle=True,
    batch_size=cfg.dataloader_batch_size,
)

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/75 [00:00<?, ?it/s]

In [8]:
# Resolve the device first and load GPT-2 onto it explicitly, so the global
# `device` used elsewhere in the notebook always matches where the model lives.
# HookedTransformer's own default device choice may differ from
# get_device_str's MPS-first preference.
device = get_device_str()
model = HookedTransformer.from_pretrained("gpt2", device=device)

print(device)

Loaded pretrained model gpt2 into HookedTransformer
cuda


In [9]:
# Fetch SAEs via sae_lens' get_gpt2_res_jb_saes (presumably GPT-2 residual-stream
# SAEs, per its name) and keep the one for the perturbation layer on the CPU,
# where collect_info_act encodes activations.
saes, sparsities = get_gpt2_res_jb_saes(cfg.perturbation_layer)
sae = saes[cfg.perturbation_layer].cpu()
# `10 **` suggests the stored sparsities are log10 values; this converts them to
# linear scale. NOTE(review): confirm against the sae_lens documentation.
feature_sparsities = 10 ** sparsities[cfg.perturbation_layer].cpu()

# Switch the SAE to eval mode; as the cell's last expression it also displays the module.
sae.eval()

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  1.89it/s]


SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)

## Experiments

In [10]:
import torch
import torch.nn.functional as F
from uuid import uuid4
import random
import json
import math
from jaxtyping import Float, Int
import os
from datetime import datetime
import gc

In [11]:
def generate_id():
    """Return a new random (version 4) UUID as its 36-character string form."""
    new_uuid = uuid4()
    return f"{new_uuid}"


def get_k_random_prompts(dataset, n_ctx=10, k=100):
    """Pick ``k`` distinct random rows of ``dataset`` and return their token ids.

    Indices are drawn with ``random.sample``. ``dataset[indices]["input_ids"]``
    must be indexable as ``[:, :n_ctx]``. The truncated ids are returned on
    the CPU.
    """
    chosen = random.sample(range(len(dataset)), k)
    token_ids = dataset[chosen]["input_ids"]
    truncated = token_ids[:, :n_ctx]
    return truncated.to("cpu")


def collect_info_act(cfg, act, sae, feature_sparsities):
    """Summarise which SAE features fire (activation > 0) for one activation.

    Args:
        cfg: Unused; kept so existing call sites keep working.
        act: Activation to encode. It is moved to the CPU and given a leading
            batch axis before ``sae.encode``. Assumes a 1-D ``(d_model,)``
            vector -- TODO confirm for other callers.
        sae: Object exposing ``encode``.
        feature_sparsities: Per-feature values indexed with the same boolean
            mask as the SAE feature activations.

    Returns:
        Dict with the active feature ids (always a list, even for 0 or 1 active
        features), their activations, their sparsities, and their count.
    """
    with torch.no_grad():
        all_feature_acts = sae.encode(act.cpu().unsqueeze(0)).squeeze(0)

    active_mask = all_feature_acts > 0.0
    # squeeze(-1) rather than squeeze(): with exactly one active feature,
    # nonzero() is (1, 1) and a bare squeeze() yields a 0-d tensor. Its
    # .tolist() is then an int, and len() below would raise TypeError.
    active_feature_ids = active_mask.nonzero().squeeze(-1).tolist()
    active_feature_acts = all_feature_acts[active_mask].tolist()
    # .tolist() already produces host values; no accelerator round-trip needed.
    active_feature_sparsities = feature_sparsities[active_mask].tolist()
    num_active_features = len(active_feature_ids)

    return {
        "active_feature_ids": active_feature_ids,
        "active_feature_acts": active_feature_acts,
        "active_feature_sparsities": active_feature_sparsities,
        "num_active_features": num_active_features,
    }


def perturb_activation(start_act, end_act, num_steps=100):
    """Linearly interpolate from ``start_act`` to ``end_act``.

    Returns ``num_steps`` evenly spaced points, stacked along a new leading
    axis. Row 0 is ``start_act`` and the last row is ``end_act``. The weights
    live on ``start_act``'s device.
    """
    weights = torch.linspace(0, 1, num_steps, device=start_act.device)[:, None]
    return (1 - weights) * start_act + weights * end_act


def run_with_perturbation(cfg, model, prompt, perturbed_acts):
    """Run ``prompt`` once per row of ``perturbed_acts`` with a patched activation.

    The prompt is repeated ``perturbed_acts.shape[0]`` times. During the run,
    the activation at ``cfg.perturbation_layer`` has its last position
    overwritten, so row ``i`` of the batch sees ``perturbed_acts[i]``.

    Returns:
        ``(logits, cache)`` exactly as produced by ``model.run_with_cache``.
    """
    n_variants = perturbed_acts.shape[0]
    batched_prompts = prompt.repeat(n_variants, 1)

    # The second parameter must be called `hook`: the hook is invoked with it
    # as a keyword argument. Returning None keeps the in-place edit.
    def overwrite_last_position(act, hook):
        act[:, -1, :] = perturbed_acts

    hook_spec = [(cfg.perturbation_layer, overwrite_last_position)]
    with model.hooks(fwd_hooks=hook_spec):
        return model.run_with_cache(batched_prompts)


def comp_js_divergence(
    p_logit: Float[torch.Tensor, "*batch vocab"],
    q_logit: Float[torch.Tensor, "*batch vocab"],
) -> Float[torch.Tensor, "*batch"]:
    """Jensen-Shannon divergence, in bits, between softmax(p_logit) and softmax(q_logit).

    The divergence is taken over the last dimension. Leading dimensions
    broadcast, e.g. an ``(n, vocab)`` batch against a ``(1, vocab)``
    reference. The result is in ``[0, 1]`` up to floating-point error.

    The mixture ``m = (p + q) / 2`` is formed in log space with ``logaddexp``,
    so ``log m`` stays finite where both probabilities underflow to 0. Taking
    ``log2`` of the materialised ``m`` gives ``-inf`` there, and
    ``0 * (-inf - -inf)`` becomes NaN. Zero-probability entries contribute 0,
    following the ``0 * log 0 = 0`` convention.

    Raises:
        AssertionError: if either KL term is not finite.
    """
    p_logprob = torch.log_softmax(p_logit, dim=-1)
    q_logprob = torch.log_softmax(q_logit, dim=-1)
    p = p_logprob.exp()
    q = q_logprob.exp()

    # log m = log((p + q) / 2), in nats, without materialising p + q.
    m_logprob = torch.logaddexp(p_logprob, q_logprob) - math.log(2.0)

    # KL(p || m) and KL(q || m) in nats; zero-probability entries contribute 0.
    p_terms = p * (p_logprob - m_logprob)
    q_terms = q * (q_logprob - m_logprob)
    p_kl_div = torch.where(p > 0, p_terms, torch.zeros_like(p_terms)).sum(-1)
    q_kl_div = torch.where(q > 0, q_terms, torch.zeros_like(q_terms)).sum(-1)

    assert p_kl_div.isfinite().all()
    assert q_kl_div.isfinite().all()
    # Average the two KL terms and convert nats -> bits.
    return (p_kl_div + q_kl_div) / (2 * math.log(2.0))


@torch.no_grad()
def collect_observation(
    cfg,
    model,
    start_act,
    end_act,
    start_prompt,
    start_logits,
    start_logprobs,
    sae,
    feature_sparsities,
):
    """Interpolate one activation towards another and record per-step metrics.

    Linearly interpolates ``start_act`` -> ``end_act`` in ``cfg.n_steps``
    steps. Each interpolated vector is patched into ``start_prompt``'s run
    (via ``run_with_perturbation``), and the output is compared with the
    unperturbed run.

    Args:
        cfg: Config; reads ``n_steps`` and ``read_layer``, and is forwarded to
            ``run_with_perturbation`` / ``collect_info_act``.
        model: Model used for the perturbed forward passes.
        start_act, end_act: Endpoint activations. Assumed to be 1-D
            ``(d_model,)`` vectors -- TODO confirm for other callers.
        start_prompt: Tokens of the unperturbed prompt.
        start_logits: Unperturbed logits; only the last position is used.
        start_logprobs: Log-softmax of the unperturbed last-position logits.
        sae, feature_sparsities: Forwarded to ``collect_info_act`` per step.

    Returns:
        ``(steps_metadata, perturbed_acts)``. The first is a list with one dict
        per step (``step`` is 1-based); the second holds the interpolated
        activations as a NumPy array.
    """
    device = start_act.device
    perturbed_acts = perturb_activation(start_act, end_act, cfg.n_steps).to(device)

    # Verify that we start at source and end at target
    assert torch.allclose(
        perturbed_acts[0], start_act, atol=1e-5
    ), "Doesn't start at source"
    assert torch.allclose(
        perturbed_acts[-1], end_act, atol=1e-5
    ), "Doesn't end at target"

    pert_logits, cache = run_with_perturbation(cfg, model, start_prompt, perturbed_acts)

    # L2 distance of each step's read-layer activation from step 0's.
    read_acts = cache[cfg.read_layer][:, -1, :]
    read_layer_l2_norms = torch.norm(read_acts - read_acts[0], dim=1)

    last_logits = pert_logits[:, -1, :]
    pert_logprobs = F.log_softmax(last_logits, dim=-1)
    # F.kl_div(input, target) = sum target * (target - input) in log space,
    # i.e. KL(start || perturbed) here.
    kl_divs = F.kl_div(
        pert_logprobs, start_logprobs, log_target=True, reduction="none"
    ).sum(dim=-1)
    js_divs = comp_js_divergence(last_logits, start_logits[:, -1, :])
    js_dist = torch.sqrt(js_divs + 1e-8)

    # Input-space similarity to the start activation, for all steps at once.
    cos_sims = F.cosine_similarity(perturbed_acts, start_act.unsqueeze(0), dim=-1)
    l2_norms = torch.norm(perturbed_acts - start_act, dim=-1)

    # Transfer each metric to the host once; per-step .item() calls would force
    # one device synchronisation per value.
    per_step_metrics = zip(
        kl_divs.tolist(),
        js_divs.tolist(),
        js_dist.tolist(),
        read_layer_l2_norms.tolist(),
        cos_sims.tolist(),
        l2_norms.tolist(),
    )
    perturbed_acts_cpu = perturbed_acts.cpu()

    perturbation_steps_metadata = []
    for step, (pert_act, metrics) in enumerate(zip(perturbed_acts_cpu, per_step_metrics)):
        kl_div, js_div, js_d, read_l2, cos_sim, l2_norm = metrics
        act_info = collect_info_act(cfg, pert_act, sae, feature_sparsities)
        perturbation_steps_metadata.append(
            {
                "step": step + 1,
                "kl_div": kl_div,
                "js_div": js_div,
                "js_dist": js_d,
                "read_layer_l2_norm": read_l2,
                "cos_sim": cos_sim,
                "l2_norm": l2_norm,
                **act_info,
            }
        )

    # Drop every device tensor, including the full activation cache, before
    # asking the allocator to release memory; live tensors cannot be reclaimed.
    del pert_logits, last_logits, pert_logprobs, kl_divs, js_divs, js_dist
    del cache, read_acts, read_layer_l2_norms, cos_sims, l2_norms, perturbed_acts
    torch.cuda.empty_cache()

    return perturbation_steps_metadata, perturbed_acts_cpu.numpy()


def get_observations(
    cfg, model, dataset, sae, feature_sparsities, k=10, trials_per_prompt=10
):
    """Yield interpolation observations between random pairs of prompts.

    For each of ``k`` random start prompts, ``trials_per_prompt`` random end
    prompts are drawn. For each pair, the last-position activation at
    ``cfg.perturbation_layer`` is interpolated from start to end via
    ``collect_observation``, and one dict is yielded with keys ``id``,
    ``start_prompt``, ``end_prompt``, ``steps_metadata`` and
    ``perturbed_acts``. All prompts are truncated to ``cfg.n_ctx`` tokens.
    """
    device = next(model.parameters()).device
    # Honour the configured context length instead of get_k_random_prompts' default.
    start_prompts = get_k_random_prompts(dataset, n_ctx=cfg.n_ctx, k=k)

    for start_prompt in start_prompts:
        with torch.no_grad():
            start_prompt_gpu = start_prompt.to(device)
            # Only the perturbation layer is read from the cache, so only cache it.
            start_logits, cache = model.run_with_cache(
                start_prompt_gpu, names_filter=cfg.perturbation_layer
            )
            start_logprobs = F.log_softmax(start_logits[:, -1, :], dim=-1)
            start_act = cache[cfg.perturbation_layer][:, -1, :].squeeze(0)

        end_prompts = get_k_random_prompts(
            dataset, n_ctx=cfg.n_ctx, k=trials_per_prompt
        )

        for end_prompt in end_prompts:
            with torch.no_grad():
                end_prompt_gpu = end_prompt.to(device)
                _, cache = model.run_with_cache(
                    end_prompt_gpu, names_filter=cfg.perturbation_layer
                )
                end_act = cache[cfg.perturbation_layer][:, -1, :].squeeze(0)

            steps_metadata, perturbed_acts = collect_observation(
                cfg,
                model,
                start_act,
                end_act,
                start_prompt_gpu,
                start_logits,
                start_logprobs,
                sae,
                feature_sparsities,
            )

            observation = {
                "id": generate_id(),
                "start_prompt": start_prompt.tolist(),
                "end_prompt": end_prompt.tolist(),
                "steps_metadata": steps_metadata,
                "perturbed_acts": perturbed_acts,
            }
            yield observation

        del (
            start_logits,
            cache,
            start_logprobs,
            start_act,
            end_prompts,
            start_prompt_gpu,
        )
        # Collect first so cycle-held tensors reach the caching allocator,
        # then release the allocator's unused blocks.
        gc.collect()
        torch.cuda.empty_cache()

In [13]:
def reduce_float_precision(obj, precision=4):
    """Recursively round floats to ``precision`` decimals using ``np.round``.

    Dicts and lists are rebuilt as plain ``dict``/``list`` with their contents
    processed the same way. Every other value (ints, strings, tuples, arrays,
    ...) is returned untouched.
    """

    def _recurse(value):
        return reduce_float_precision(value, precision)

    if isinstance(obj, float):
        return float(np.round(obj, precision))
    if isinstance(obj, dict):
        return dict(zip(obj.keys(), map(_recurse, obj.values())))
    if isinstance(obj, list):
        return list(map(_recurse, obj))
    return obj


def save_observations_to_disk(
    observations, activations, base_dir="observations", precision=4
):
    """Write one shard of observations to ``base_dir`` as JSONL + NPY.

    Two files share a ``<count>_<YYYYmmdd_HHMMSS>`` suffix:

    * ``metadata_*.jsonl``: one JSON object per observation, with floats
      rounded to ``precision`` decimals by ``reduce_float_precision``.
    * ``acts_*.npy``: ``np.stack(activations)``.

    Nothing is written if ``observations`` is empty.
    """
    if not observations:
        return

    os.makedirs(base_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    filename = f"metadata_{len(observations)}_{timestamp}.jsonl"
    acts_filename = f"acts_{len(observations)}_{timestamp}.npy"
    filepath = os.path.join(base_dir, filename)
    acts_filepath = os.path.join(base_dir, acts_filename)

    with open(filepath, "w", encoding="utf-8") as f:
        for obs in observations:
            reduced_obs = reduce_float_precision(obs, precision)
            json.dump(reduced_obs, f)
            f.write("\n")

    np.save(acts_filepath, np.stack(activations))

    # Report both real paths (the old message printed "<jsonl path>/npy").
    print(
        f"Saved {len(observations)} observations to {filepath} "
        f"and acts to {acts_filepath}"
    )

In [13]:
# rm -rf observations

In [None]:
# Stream observations and flush them to disk in fixed-size shards so memory
# stays bounded during the long collection run.
file_counter = 0
max_observations_per_file = 1000

metadata, activations = [], []

for observation in get_observations(
    cfg, model, dataset, sae, feature_sparsities, k=30_000, trials_per_prompt=1
):
    # Everything except the (large) activation array goes into the JSONL metadata.
    metadata.append(
        {
            key: observation[key]
            for key in ("id", "start_prompt", "end_prompt", "steps_metadata")
        }
    )
    activations.append(observation["perturbed_acts"])

    if len(metadata) < max_observations_per_file:
        continue

    # Shard is full: write it out and start a fresh one.
    save_observations_to_disk(metadata, activations)
    metadata, activations = [], []
    file_counter += 1

# Flush the final, partially filled shard (if any).
if metadata:
    save_observations_to_disk(metadata, activations)