In [1]:
import einops
import statistics
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import timeit
import transformer_lens
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")

## Naive testing (not proper benchmark)

In [2]:
class SAE(nn.Module):
    def __init__(self, d_in: int, d_sae: int, k: int):
        super().__init__()
        self.d_in = d_in
        self.d_sae = d_sae
        self.k = k
        self.W_enc = nn.Parameter(torch.randn(d_sae, d_in))
        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.W_dec = nn.Parameter(torch.randn(d_in, d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_in))

    def encode(self, x: torch.Tensor):
        acts = F.linear(x, self.W_enc, self.b_enc)
        topk = torch.topk(acts, k=self.k, dim=-1)
        return topk.values, topk.indices #? should a ReLU be here?

    def decode(self, latents: torch.Tensor, indices: torch.Tensor):
        # TODO it might be more efficient to instead pull the correct weights
        # from W_dec rather than scattering, but idk how to make that work with batching
        
        x = torch.zeros(latents.size(0), self.d_sae, device=latents.device)
        x.scatter_(1, indices, latents)
        return F.linear(x, self.W_dec, self.b_dec)

    def forward(self, x: torch.Tensor):
        x, indices = self.encode(x)
        x = self.decode(x, indices)
        return x, indices


In [3]:
class MLP(nn.Module):
    def __init__(self, d_resid: int, d_mlp: int):
        super().__init__()
        self.W_in = nn.Parameter(torch.randn(d_mlp, d_resid))
        self.b_in = nn.Parameter(torch.zeros(d_mlp))
        self.act = nn.ReLU()
        self.W_out = nn.Parameter(torch.randn(d_resid, d_mlp))
        self.b_out = nn.Parameter(torch.zeros(d_resid))

    def forward(self, x: torch.Tensor, return_grad_of_act: bool = False):
        pre = F.linear(x, self.W_in, self.b_in)
        if return_grad_of_act:
            pre.requires_grad_(True)
        post = self.act(pre)
        out = F.linear(post, self.W_out, self.b_out)

        if return_grad_of_act:
            grad_of_act = torch.autograd.grad(
                outputs=post,
                inputs=pre,
                grad_outputs=torch.ones_like(post)
            )[0]
            return out, grad_of_act
        
        return out


In [8]:
d_resid = 768
d_mlp = 3072
d_sae = 64 * d_resid
k = 32
n_tokens = 64

sae = SAE(d_resid, d_sae, k).to(device) # operates both on resid_mid and resid_post
mlp = MLP(d_resid, d_mlp).to(device)

random_input = torch.randn(n_tokens, d_resid, device=device)

In [9]:
wd1 = sae.W_dec.T @ mlp.W_in.T # (d_sae, d_mlp)
w2e = mlp.W_out.T @ sae.W_enc.T # (d_mlp, d_sae)

In [10]:
def get_jacobian_slow(sae: SAE, mlp: MLP, resid_values: torch.Tensor):
    # super lazy way of making sure these are on the same device
    wd1 = sae.W_dec.T @ mlp.W_in.T # (d_sae, d_mlp)
    w2e = mlp.W_out.T @ sae.W_enc.T # (d_mlp, d_sae)

    sae_out, topk_indices1 = sae(resid_values) # (n_tokens, d_resid), (n_tokens, k)
    resid_post, mlp_act_grad = mlp(sae_out, return_grad_of_act=True) # (n_token, d_resid), (n_token, d_mlp)
    _, topk_indices2 = sae.encode(resid_post) # (n_tokens, k)

    num_tokens = resid_values.size(0)
    jacobian = torch.zeros(num_tokens, sae.k, sae.k)
    for token_pos in range(resid_values.size(0)):
        for output_k in range(sae.k):
            output_k_sae_index = topk_indices2[token_pos, output_k]
            for input_k in range(sae.k):
                input_k_sae_index = topk_indices1[token_pos, input_k]
                jacobian[token_pos, input_k, output_k] = (wd1[input_k_sae_index] * mlp_act_grad[token_pos] * w2e[:, output_k_sae_index]).sum()
    
    return jacobian

jacobian1 = get_jacobian_slow(sae.cpu(), mlp.cpu(), random_input.cpu())
jacobian1.shape
sae = sae.to(device)
mlp = mlp.to(device)

In [11]:
def get_jacobian_fast(sae: SAE, mlp: MLP, resid_values: torch.Tensor):
    sae_out, topk_indices1 = sae(resid_values) # (n_tokens, d_resid), (n_tokens, k)
    resid_post, mlp_act_grad = mlp(sae_out, return_grad_of_act=True) # (n_token, d_resid), (n_token, d_mlp)
    _, topk_indices2 = sae.encode(resid_post) # (n_tokens, k)

    return einops.einsum(wd1[topk_indices1], mlp_act_grad, w2e[:, topk_indices2],
                         "seq_pos k1 d_mlp, seq_pos d_mlp, d_mlp seq_pos k2 -> seq_pos k1 k2")

jacobian2 = get_jacobian_fast(sae, mlp, random_input)
assert torch.allclose(jacobian1, jacobian2.cpu(), rtol=1e-5, atol=1e-1)
jacobian2.shape

torch.Size([64, 32, 32])

## For reference, here's how long a GPT forward pass takes

In [12]:
def sync():
    if device.type == "cuda":
        torch.cuda.synchronize()
    elif device.type == "mps":
        torch.mps.synchronize()

In [13]:
model = transformer_lens.HookedTransformer.from_pretrained("gpt2-small")



Loaded pretrained model gpt2-small into HookedTransformer


In [14]:
def fwd_pass():
  model(" Test"*1024)
  sync()

for _ in range(10):
  print(timeit.timeit(fwd_pass, number=1))

0.6559516669949517
0.1877600409789011
0.17127637506928295
0.16607412497978657
0.17193058400880545
0.1921600829809904
0.17315387504640967
0.16536600003018975
0.1684253748971969
0.16900570807047188


## Benchmarking

In [15]:
def get_rand_inputs():
    topk_indices1 = torch.randint(0, d_sae, (n_tokens, k), device=device)
    mlp_act_grad = torch.randint(0, 2, (n_tokens, d_mlp), device=device)
    topk_indices2 = torch.randint(0, d_sae, (n_tokens, k), device=device)
    return topk_indices1, mlp_act_grad, topk_indices2

def jacobian_einsum(topk_indices1, mlp_act_grad, topk_indices2):
    jacobian = einops.einsum(wd1[topk_indices1], mlp_act_grad, w2e[:, topk_indices2],
                             "seq_pos k1 d_mlp, seq_pos d_mlp, d_mlp seq_pos k2 -> seq_pos k1 k2")
    sync()
    return jacobian

def timed_jacobian():
    rand_inputs = get_rand_inputs()
    sync()
    start_time = time.perf_counter()
    jacobian_einsum(*rand_inputs)
    sync()
    end_time = time.perf_counter()
    return end_time - start_time

# Warmup runs
for _ in tqdm(range(10), desc="Warming up"):
    timed_jacobian()

# Timing runs
timings = []
num_runs = 100
for _ in tqdm(range(num_runs), desc="Benchmarking"):
    timings.append(timed_jacobian())

# Calculate statistics
stats = {
    'mean_ms': statistics.mean(timings),
    'median_ms': statistics.median(timings),
    'std_ms': statistics.stdev(timings),
    'min_ms': min(timings),
    'max_ms': max(timings),
    'device': str(device),
    'num_runs': num_runs
}
stats

python(55245) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Warming up: 100%|██████████| 10/10 [00:00<00:00, 22.93it/s]
Benchmarking: 100%|██████████| 100/100 [00:01<00:00, 82.33it/s]


{'mean_ms': 0.01084957709768787,
 'median_ms': 0.010815750458277762,
 'std_ms': 0.000570285220142084,
 'min_ms': 0.009896458010189235,
 'max_ms': 0.012661666958592832,
 'device': 'mps',
 'num_runs': 100}