In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from transformer_lens import HookedTransformer
from interpolated_ffn import ModelWithBilinearLayer, load_layer
from sae_lens import SAE
from tqdm.notebook import tqdm
from IPython.display import IFrame
import safetensors
import einops

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f49f42ff700>

In [2]:
device = "cuda"
dtype = torch.bfloat16
model_name = "gemma-2-2b"
model_pretrained = HookedTransformer.from_pretrained_no_processing(model_name, device = device, dtype=dtype)
layer = 18

# Bilinear layer with original weights 
model_bilinear = ModelWithBilinearLayer(model_pretrained, layer)

# Bilinear layers with different finetuning strategies
model_bilinear_logit_mse = ModelWithBilinearLayer(model_pretrained, layer)
model_bilinear_logit_mse.ffn.load_layer("layer-18-step-20000-logit-mse.safetensors")
model_bilinear_output_mse = ModelWithBilinearLayer(model_pretrained, layer)
model_bilinear_output_mse.ffn.load_layer("layer-18-step-20000-output-mse.safetensors")
model_bilinear_ce = ModelWithBilinearLayer(model_pretrained, layer)
model_bilinear_ce.ffn.load_layer("layer-18-step-20000-ce.safetensors")

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

Loaded pretrained model gemma-2-2b into HookedTransformer


In [4]:
data = np.memmap('data_train.bin', dtype=np.uint32, mode='r')

def sample_batch(size, seq_len):
    indices = torch.randint(len(data) - seq_len - 1, (size,))
    xs = torch.stack([torch.from_numpy(data[i:i+seq_len].astype(np.int64)) for i in indices])
    ys = torch.stack([torch.from_numpy(data[i+1:i+seq_len+1].astype(np.int64)) for i in indices])
    return xs.to(device), ys.to(device)

Evaluating the performance of the different versions of the model on samples from the dataset.

In [5]:
def eval_model(model):
    torch.manual_seed(12345)
    seq_len = 1024
    batch_size = 25
    batches = 100
    loss = 0.0
    with torch.no_grad():
        for i in range(batches):
            tokens, next_tokens = sample_batch(batch_size, seq_len)
            logits = model(tokens)
            loss += F.cross_entropy(logits.view(-1, logits.size(-1)), next_tokens.view(-1)).item()
    return loss / batches

print("Pretrained:", eval_model(model_pretrained))
print("Bilinear untuned:", eval_model(model_bilinear))
print("Bilinear logit reconstruction:", eval_model(model_bilinear_logit_mse))
print("Bilinear layer output reconstruction:", eval_model(model_bilinear_output_mse))
print("Bilinear cross-entropy on next token prediction:", eval_model(model_bilinear_ce))

Pretrained: 3.6328125
Bilinear untuned: 3.80359375
Bilinear logit reconstruction: 3.66703125
Bilinear layer output reconstruction: 3.71015625
Bilinear cross-entropy on next token prediction: 3.1803125


In [6]:
sae, cfg_dict, sparsity = SAE.from_pretrained(
    release = "gemma-scope-2b-pt-res-canonical",
    sae_id = f"layer_{layer}/width_16k/canonical",
    device = device
)
sae.eval()

SAE(
  (activation_fn): ReLU()
  (hook_sae_input): HookPoint()
  (hook_sae_acts_pre): HookPoint()
  (hook_sae_acts_post): HookPoint()
  (hook_sae_output): HookPoint()
  (hook_sae_recons): HookPoint()
  (hook_sae_error): HookPoint()
)

Evaluating the performance of the SAE with the different versions of the model.

In [7]:
def reconstruction_loss_and_l0(run_fn):
    batches = 100
    batch_size = 25
    seq_len = 1024
    loss = 0.0
    l0 = 0.0
    for i in range(batches):
        batch, _ = sample_batch(batch_size, seq_len)
        res = run_fn(batch)

        res = res.view(-1, res.size(-1))
        batch = batch.view(-1)

        res = res[batch != model_pretrained.tokenizer.bos_token_id]

        feature_activations = sae.encode(res)
        reconstructed = sae.decode(feature_activations)

        l0 += (feature_activations > 0.0).sum().item() / (batch_size * seq_len)
        activations = None
        loss += F.mse_loss(res, reconstructed).item()

    return loss / batches, l0 / batches

def get_activations_bilinear(model, tokens):
    x = model.model(tokens, stop_at_layer=layer)
    return model.newlayer(x)

print("Pretrained:", reconstruction_loss_and_l0(lambda x: model_pretrained(x, stop_at_layer=layer+1)))
print("Bilinear untuned:", reconstruction_loss_and_l0(lambda x: get_activations_bilinear(model_bilinear, x)))
print("Bilinear logit reconstruction:", reconstruction_loss_and_l0(lambda x: get_activations_bilinear(model_bilinear_logit_mse, x)))
print("Bilinear output reconstruction:", reconstruction_loss_and_l0(lambda x: get_activations_bilinear(model_bilinear_output_mse, x)))
print("Bilinear cross-entropy on next token prediction:", reconstruction_loss_and_l0(lambda x: get_activations_bilinear(model_bilinear_ce, x)))

Pretrained: (5.0151436281204225, 67.59741718750003)
Bilinear untuned: (7.057553396224976, 42.397410937500005)
Bilinear logit reconstruction: (6.138280916213989, 53.94823281249998)
Bilinear output reconstruction: (5.741299681663513, 56.174267578124976)
Bilinear cross-entropy on next token prediction: (6.763096032142639, 57.730014843749984)


The code below calculates the interaction matrix, as described in [this paper](https://arxiv.org/pdf/2406.03947).

In [51]:
ffn = model_bilinear_logit_mse.ffn

def interaction_matrix(u):
    v = ffn.V @ u
    interaction_matrices = einops.einsum(ffn.W1[:, :4608], ffn.W2[:, :4608], v[:4608], 'i j, k j, j -> i k') + einops.einsum(ffn.W1[:, 4608:], ffn.W2[:, 4608:], v[4608:], 'i j, k j, j -> i k')
    return interaction_matrices

def symm_interaction_matrix(u):
    m = interaction_matrix(u)
    return m + m.T

u = torch.randn(2304).to(device=device, dtype=dtype)
symm_interaction_matrix(u)


torch.Size([9216])


tensor([[ 2.0447e-03, -2.8381e-03,  1.1749e-03,  ...,  2.5024e-03,
          9.3842e-04, -5.4932e-04],
        [ 2.6703e-05, -3.6621e-04,  3.2959e-03,  ..., -1.5411e-03,
          1.2817e-03, -2.0695e-04],
        [ 8.6212e-04, -2.3346e-03, -2.7008e-03,  ...,  5.9891e-04,
          2.1210e-03,  9.6893e-04],
        ...,
        [ 7.2098e-04, -9.1553e-04, -2.7161e-03,  ...,  2.5024e-03,
          1.7014e-03, -5.5313e-04],
        [-4.6539e-04, -3.2959e-03, -4.8828e-04,  ...,  5.5313e-04,
         -1.3733e-03, -1.1673e-03],
        [ 1.3123e-03, -8.3542e-04,  3.3722e-03,  ..., -2.2278e-03,
         -3.5095e-04, -5.5695e-04]], device='cuda:0', dtype=torch.bfloat16)