Imports

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer
import gc
import einops
import numpy as np

Setup

In [None]:
# This notebook only runs inference, so autograd is switched off globally.
torch.set_grad_enabled(False)
model_name = "google/gemma-7b"

# Reference Hugging Face model, loaded in float32 with the eager attention
# implementation.
hf_model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,
    attn_implementation="eager",
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Put the reference model in eval mode and move it to the GPU.
hf_model.eval().cuda()

Demo Input

In [None]:
# Prompt that is fed to both the Hugging Face and the TransformerLens model.
text = '''
TransformerLens lets you load in 50+ different open source language models,
and exposes the internal activations of the model to you. You can cache
any internal activation in the model, and add in functions to edit, remove
or replace these activations as the model runs.
'''

# Tokenize to PyTorch tensors and move the token ids to the GPU.
input_ids = (
    tokenizer(text, return_tensors='pt')['input_ids']
    .cuda()
)

Generate HF Outputs

In [None]:
# Sanity check: let the Hugging Face model continue the prompt and print the
# decoded result (`max_length=1000` is forwarded to `generate`).
with torch.no_grad():
    outputs = hf_model.generate(input_ids, max_length=1000)

# Store the decoded generation under its own name so the prompt held in
# `text` is not overwritten by this cell.
generated_text = tokenizer.batch_decode(outputs)[0]
print(generated_text)

Store Hugging Face model logits and resid_pre cache

In [None]:
# Run the Hugging Face model once for logits / hidden states / attentions and
# once with labels for the loss, keeping only CPU copies of the results.
with torch.no_grad():
    hf_outputs = hf_model(input_ids, output_hidden_states=True, output_attentions=True)
    hf_logits_cpu = hf_outputs["logits"].cpu()
    # Copy straight from the output object instead of binding the GPU tensors
    # to notebook-level names: any surviving reference (previously
    # `hf_attentions`) would keep that GPU memory allocated after the cleanup
    # below.
    hf_resid_pre_cache_cpu = [hidden.cpu() for hidden in hf_outputs["hidden_states"]]
    hf_attentions_cpu = [attn.cpu() for attn in hf_outputs["attentions"]]
    # Second forward pass with labels so the model returns its own loss.
    hf_loss_cpu = hf_model(input_ids, labels=input_ids).loss.cpu()

# Drop the model and the remaining GPU-backed output object, then release the
# cached CUDA memory before the TransformerLens model is loaded.
del hf_model
del hf_outputs
gc.collect()
torch.cuda.empty_cache()

Load in Model to TL

In [3]:
# Load the same checkpoint into TransformerLens, reusing the Hugging Face
# tokenizer. fold_ln, fold_value_biases and center_writing_weights are all
# passed as False; the residual stream is compared directly against the
# Hugging Face hidden states further down.
hooked_model = HookedTransformer.from_pretrained(
    model_name,
    tokenizer=tokenizer,
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

: 

Store Hooked Model Logits and Resid Pre Cache

In [None]:
# Display the activation function recorded in the hooked model's config.
hooked_model.cfg.act_fn

In [None]:
# Run the TransformerLens model: once with caching for logits and activations,
# once more for the loss. Only CPU copies are kept afterwards.
with torch.no_grad():
    hooked_model_logits, hooked_model_cache = hooked_model.run_with_cache(input_ids)
    hooked_model_loss = hooked_model(input_ids, return_type='loss')

    hooked_model_logits_cpu = hooked_model_logits.detach().cpu()
    hooked_model_loss_cpu = hooked_model_loss.cpu()
    hooked_model_cache_cpu = {
        name: activation.cpu() for name, activation in hooked_model_cache.items()
    }

n_layers = hooked_model.cfg.n_layers

# `hooked_model` itself is kept alive: it is used again by `test_prompt` at the
# end of the notebook. Only the GPU-side results are dropped here.
del hooked_model_logits, hooked_model_cache, hooked_model_loss

gc.collect()
torch.cuda.empty_cache()

Compare Logits

In [None]:
# Subtract the per-position mean over the last dimension from the Hugging Face
# logits before comparing them with the TransformerLens logits.
centered_hf_logits = hf_logits_cpu - hf_logits_cpu.mean(-1, keepdim=True)

# Compute the element-wise difference once and reuse it for both statistics.
logits_delta = hooked_model_logits_cpu - centered_hf_logits
# NOTE(review): this is a signed mean, so positive and negative deviations
# cancel; use `.abs().mean()` if the average magnitude is what is wanted.
mean_diff = logits_delta.mean()
max_diff = logits_delta.abs().max()
del logits_delta

print("avg logits difference:", mean_diff.item())
print("max logits difference:", max_diff.item())

In [None]:
# Single pass/fail answer for the logits comparison at a 5e-4 tolerance.
torch.allclose(
    hooked_model_logits_cpu,
    centered_hf_logits,
    atol=5e-4,
    rtol=5e-4,
)

Compare Resid Pre Cache

In [None]:
# Compare the residual stream entering every block (`hook_resid_pre`) with the
# matching Hugging Face hidden state, layer by layer, and report mismatches.
pass_loose_bound = True
print("*"*5, "Matching hf and T-Lens residual stream in between transformer blocks", "*"*5)
atol = rtol = 5e-5
print("*"*5, f"\ttesting with {atol=} and {rtol=}\t","*"*5)
for layer in range(n_layers):
    tl_resid_pre = hooked_model_cache_cpu[f'blocks.{layer}.hook_resid_pre']
    hf_resid_pre = hf_resid_pre_cache_cpu[layer]
    try:
        torch.testing.assert_close(tl_resid_pre, hf_resid_pre, atol=atol, rtol=rtol)
    except AssertionError:
        # Only a failed closeness check is treated as a mismatch; any other
        # exception (missing key, interrupt, ...) propagates instead of being
        # silently swallowed.
        layer_max_diff = (tl_resid_pre - hf_resid_pre).abs().max()
        print(f"layer {layer} \t not close, max difference: {layer_max_diff}")
        pass_loose_bound = False

if pass_loose_bound:
    print(f"All layers match with {atol=} {rtol=}")
else:
    # At least one layer failed above, so do not claim that everything matches.
    print(f"Some layers do not match with {atol=} {rtol=}")

Compare Next Token Loss

In [None]:
# Report both next-token losses and whether they agree to within 1e-6.
print("T-Lens next token loss:", hooked_model_loss_cpu.item())
print("HF next token loss:", hf_loss_cpu.item())

diff_in_loss = (hf_loss_cpu - hooked_model_loss_cpu).abs().item()
print("diff in loss (abs):", diff_in_loss)

# Soft check: print the outcome either way rather than stopping the notebook.
if diff_in_loss < 1e-6:
    print("Assertion passed: Difference in loss is within the tolerance.")
else:
    print(f"Difference in loss {diff_in_loss} is greater than 1e-6")

Check TL Functionality

In [None]:
from transformer_lens.utils import test_prompt

In [None]:
# Smoke-test the TransformerLens utilities on the loaded model with a simple
# factual prompt.
test_prompt(
    prompt=" What is the capital of France?",
    answer=" Paris",
    model=hooked_model,
    print_details=True,
)