# Einsum Divergence

For context: https://github.com/TransformerLensOrg/TransformerLens/issues/591

In [None]:
! pip install -U transformers
! pip install git+https://github.com/TransformerLensOrg/TransformerLens.git@einsum_divergence

In [None]:
! huggingface-cli login # if using models from the Hugging Face Hub that require auth. 

## Set Up (Based on Chris Mathwin's Gemma Notebook)

In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformer_lens import HookedTransformer
import gc
import einops
import numpy as np
from transformer_lens.utils import get_device

device = get_device()
torch.set_grad_enabled(False)
dtype = torch.float32

In [2]:
from fancy_einsum import einsum  # the suspect!

In [3]:
ooms = [10**-i for i in range(1, 10)]


def assert_close_for_ooms(a, b, ooms=ooms):
    for oom in ooms:
        assert torch.allclose(a, b, rtol=oom, atol=oom), f"Failed for oom={oom}"

## Reproduce without T-Lens

### Demonstrate on Synthetic Data

In [106]:
def get_synthetic_data_and_pytorch_default_result(device):
    """
    Returns synthetic data and the result of the operation using PyTorch's einsum.
    For this operation:
    "batch pos head_index d_head, head_index d_head d_model -> batch pos d_model"
    """
    
    # For demonstration purposes, I'll define the shapes:
    batch_size = 32
    pos = 128
    num_heads = 8
    d_head = 64
    d_model = 4096

    # Example tensors
    z = torch.randn(batch_size, pos, num_heads, d_head)  # [batch, pos, head_index, d_head]
    W_O = torch.randn(num_heads, d_head, d_model)        # [head_index, d_head, d_model]
    b_O = torch.randn(d_model)                           # [d_model]

    device = "cpu"
    # move all tensors to the device
    z = z.to(device)
    W_O = W_O.to(device)
    b_O = b_O.to(device)
    
    vanilla_result = (z.flatten(-2,-1) @ W_O.flatten(0,1)) + b_O#.reshape(*z.shape[:-1], -1).shape
    
    return z, W_O, b_O, vanilla_result


In [107]:
import torch
from fancy_einsum import einsum
# Assume z and self.W_O are given tensors with the correct shapes.

z, W_O, b_O, vanilla_result = get_synthetic_data_and_pytorch_default_result("cpu")

out = (
    (
        einsum("batch pos head_index d_head, head_index d_head d_model -> batch pos d_model",
            z,
            W_O,
        )
    )
    + b_O
)  # [batch, pos, d_model]

print(out.shape)

# <- Fails on my mac (M3 MAX) on either cpu or mps.
torch.testing.assert_close(out, vanilla_result, rtol=1e-7, atol=1e-7) 

torch.Size([32, 128, 4096])


AssertionError: Tensor-likes are not close!

Mismatched elements: 14475594 / 16777216 (86.3%)
Greatest absolute difference: 0.00016021728515625 at index (24, 93, 3020) (up to 1e-07 allowed)
Greatest relative difference: 5.4666666984558105 at index (21, 19, 1125) (up to 1e-07 allowed)

### Note that if we don't use einsum, just einops, this works. Why??

### But if we use fancy_opt_einsum (combination of opt_einsum and einops, it passes)

In [114]:
import torch

z, W_O, b_O, vanilla = get_synthetic_data_and_pytorch_default_result("cpu")

out = (
    (
        torch.einsum("bpij,ijk->bpk",
            z,
            W_O,
        )
    )
    + b_O
)  # [batch, pos, d_model]

print(out.shape)


vanilla = (z.flatten(-2,-1) @ W_O.flatten(0,1)) + b_O#.reshape(*z.shape[:-1], -1).shape
torch.testing.assert_close(out, vanilla, rtol=1e-9, atol=1e-9)

torch.Size([32, 128, 4096])


#### First show just opt_einsum

In [109]:
import torch
from opt_einsum import contract

z, W_O, b_O, vanilla = get_synthetic_data_and_pytorch_default_result("cpu")

out = (
    (
        contract("bpij,ijk->bpk",
            z,
            W_O,
        )
    )
    + b_O
)  # [batch, pos, d_model]

print(out.shape)


vanilla = (z.flatten(-2,-1) @ W_O.flatten(0,1)) + b_O#.reshape(*z.shape[:-1], -1).shape
torch.testing.assert_close(out, vanilla, rtol=1e-9, atol=1e-9)

torch.Size([32, 128, 4096])


####

### then show the hybrid with original syntax from fancy einsum

In [111]:

from fancy_einsum import convert_equation
from opt_einsum import contract

def fancy_opt_einsum(equation: str, *operands):
    """
    Variation on fancy opt einsum that uses opt_einsum for the contraction.
    
    Evaluates the Einstein summation convention on the operands.
    
    See: 
      https://pytorch.org/docs/stable/generated/torch.einsum.html
      https://numpy.org/doc/stable/reference/generated/numpy.einsum.html
    """
    new_equation = convert_equation(equation)
    return contract(new_equation, *operands)

z, W_O, b_O, vanilla = get_synthetic_data_and_pytorch_default_result("cpu")

out = (
    (
        fancy_opt_einsum("batch pos head_index d_head, head_index d_head d_model -> batch pos d_model",
            z,
            W_O,
        )
    )
    + b_O
)  # [batch, pos, d_model]


print(out.shape)
vanilla = (z.flatten(-2,-1) @ W_O.flatten(0,1)) + b_O#.reshape(*z.shape[:-1], -1).shape
torch.testing.assert_close(out, vanilla, rtol=1e-9, atol=1e-9)

torch.Size([32, 128, 4096])


## Examine Impact on Models

### Huggingface Model

In [5]:
torch.set_grad_enabled(False)
model_name = "google/gemma-2b"

hf_model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.float32
)  # trust_remote_code=True, attn_implementation="eager")
tokenizer = AutoTokenizer.from_pretrained(
    model_name
)  # add_bos_token = True, use_fast=False, trust_remote_code=True)
hf_model.eval().to(device)

Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

GemmaForCausalLM(
  (model): GemmaModel(
    (embed_tokens): Embedding(256000, 2048, padding_idx=0)
    (layers): ModuleList(
      (0-17): 18 x GemmaDecoderLayer(
        (self_attn): GemmaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): GemmaRotaryEmbedding()
        )
        (mlp): GemmaMLP(
          (gate_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (up_proj): Linear(in_features=2048, out_features=16384, bias=False)
          (down_proj): Linear(in_features=16384, out_features=2048, bias=False)
          (act_fn): PytorchGELUTanh()
        )
        (input_layernorm): GemmaRMSNorm()
        (post_attention_layernorm): GemmaRMSNorm()
      )
    )
    (norm): GemmaR

### T-Lens Model

In [6]:
hooked_model = HookedTransformer.from_pretrained(
    model_name,
    tokenizer=tokenizer,
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loaded pretrained model google/gemma-2b into HookedTransformer


# Comparing Forward Passes

In [145]:
text = """
TransformerLens lets you load in 50+ different open source language models,
and exposes the internal activations of the model to you. You can cache
any internal activation in the model, and add in functions to edit, remove
or replace these activations as the model runs.
"""
input_ids = tokenizer(text, return_tensors="pt")["input_ids"].to(device)

Run each model with a cache

In [146]:
with torch.no_grad():
    hf_outputs = hf_model(input_ids, output_hidden_states=True, output_attentions=True)
    hf_logits_cpu = hf_outputs["logits"].cpu()
    hf_resid_pre_cache = hf_outputs["hidden_states"]
    hf_attentions = hf_outputs["attentions"]
    hf_resid_pre_cache_cpu = [cache.cpu() for cache in hf_resid_pre_cache]
    hf_attentions_cpu = [att.cpu() for att in hf_attentions]
    hf_outputs = hf_model(input_ids, labels=input_ids)
    hf_loss_cpu = hf_outputs.loss.cpu()

# TODO: add a some notebook config for low memory mode.
# del hf_model
# del hf_outputs
# del hf_resid_pre_cache
# gc.collect()
# torch.cuda.empty_cache()

with torch.no_grad():
    hooked_model_logits, hooked_model_cache = hooked_model.run_with_cache(input_ids)
    hooked_model_loss = hooked_model(input_ids, return_type="loss")
    hooked_model_loss_cpu = hooked_model_loss.cpu()
    hooked_model_logits_cpu = hooked_model_logits.detach().cpu()
    hooked_model_cache_cpu = {k: v.cpu() for k, v in hooked_model_cache.items()}
    n_layers = hooked_model.cfg.n_layers

# TODO: add a some notebook config for low memory mode.
# # del hooked_model
# del hooked_model_logits
# del hooked_model_cache
# del hooked_model_loss

# gc.collect()
# torch.cuda.empty_cache()

assert_close_for_ooms(hf_logits_cpu, hooked_model_logits_cpu)



In [205]:
pass_loose_bound = True
print("*"*5, "Matching hf and T-Lens residual stream in between transformer blocks", "*"*5)
atol = rtol = 1e-4
print("*"*5, f"\ttesting with {atol=} and {rtol=}\t","*"*5)
for l in range(n_layers):
    try:
        torch.testing.assert_close(hooked_model_cache_cpu[f'blocks.{l}.hook_resid_pre'], hf_resid_pre_cache_cpu[l], atol=atol, rtol=rtol)
    except:
        max_diff = (hooked_model_cache_cpu[f'blocks.{l}.hook_resid_pre'] - hf_resid_pre_cache_cpu[l]).abs().max()
        print(f"layer {l} \t not close, max difference: {max_diff}")
        pass_loose_bound = False

if pass_loose_bound:
    print(f"All layers match with {atol=} {rtol=}")
else:
    print("All layers match")

***** Matching hf and T-Lens residual stream in between transformer blocks *****
***** 	testing with atol=0.0001 and rtol=0.0001	 *****
layer 1 	 not close, max difference: 22.35356903076172
layer 2 	 not close, max difference: 44.470909118652344
layer 3 	 not close, max difference: 56.772979736328125
layer 4 	 not close, max difference: 68.3115463256836
layer 5 	 not close, max difference: 82.96070861816406
layer 6 	 not close, max difference: 103.59776306152344
layer 7 	 not close, max difference: 119.89836883544922
layer 8 	 not close, max difference: 243.830322265625
layer 9 	 not close, max difference: 494.560546875
layer 10 	 not close, max difference: 495.130859375
layer 11 	 not close, max difference: 494.717041015625
layer 12 	 not close, max difference: 493.398681640625
layer 13 	 not close, max difference: 491.9833984375
layer 14 	 not close, max difference: 490.833984375
layer 15 	 not close, max difference: 488.646484375
layer 16 	 not close, max difference: 481.8250732421