# What's the fastest way to get the gradients of the MLP activation function in TransformerLens?

In [28]:
import einops
import statistics
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import timeit
import transformer_lens
from transformer_lens.components.mlps.mlp import MLP as TL_MLP
from tqdm import tqdm
from jaxtyping import Float
from transformer_lens.utilities.addmm import batch_addmm

# Pick the best available accelerator: CUDA first, then Apple's MPS, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

Pythia models compute attention and the MLP in parallel (`parallel_attn_mlp=True`). For now, let's keep things simple and use GPT-2 small, whose blocks are sequential, so we don't have to deal with that.

In [19]:
# Load onto the device chosen in the imports cell so every tensor below agrees on placement.
model = transformer_lens.HookedTransformer.from_pretrained("gpt2-small", device=device)
# The rest of the notebook reads the MLP input from `resid_mid`, which only exists for
# sequential (non-parallel) attention/MLP blocks -- fail loudly if that assumption breaks.
assert not model.cfg.parallel_attn_mlp
model.cfg.parallel_attn_mlp



Loaded pretrained model gpt2-small into HookedTransformer


In [23]:
# Which transformer block's MLP we study.
layer = 3
# Adjacent literals concatenate into one single-line prompt.
prompt = (
    "Somebody once told me the world is gonna roll me "
    "I ain't the sharpest tool in the shed "
    "She was looking kind of dumb with her finger and her thumb "
    "In the shape of an L on her forehead"
)
_, cache = model.run_with_cache(prompt)
# Residual stream after attention, before the MLP, at the chosen layer.
resid_mid = cache["resid_mid", layer]
resid_mid.shape

torch.Size([1, 43, 768])

In [35]:
class MLPWithActGrads(TL_MLP):
    """TransformerLens MLP that can additionally return the derivative of its activation.

    With ``return_act_grads=False`` this performs the same computation as the
    non-layer-norm-activation path of ``TL_MLP.forward``. With
    ``return_act_grads=True`` it returns ``(output, grad_of_act)``. Here
    ``grad_of_act`` is the gradient of ``act_fn(pre_act).sum()`` with respect to
    ``pre_act``. It has the same shape as ``pre_act``, and equals ``act_fn'`` applied
    elementwise when ``act_fn`` is elementwise (assumed for the supported activations).
    """

    def forward(
        self, x: Float[torch.Tensor, "batch pos d_model"], return_act_grads: bool = False
    ) -> Float[torch.Tensor, "batch pos d_model"]:
        """Run the MLP on ``x``.

        Args:
            x: residual-stream input, annotated as [batch, pos, d_model].
            return_act_grads: if True, also return the activation derivative.

        Returns:
            The MLP output, or ``(output, grad_of_act)`` when ``return_act_grads``
            is True (the return annotation only describes the first case).

        Raises:
            NotImplementedError: for layer-norm activations (the ``hook_mid`` /
                ``ln`` path of the TransformerLens MLP), which are not supported.
        """
        # This is equivalent to (roughly) W_in @ x + b_in. It's important to
        # use a fused addmm to ensure it matches the Huggingface implementation
        # exactly.
        pre_act = self.hook_pre(batch_addmm(self.b_in, self.W_in, x))  # [batch, pos, d_mlp]

        if (
            self.cfg.is_layer_norm_activation()
            and self.hook_mid is not None
            and self.ln is not None
        ):
            raise NotImplementedError("You passed in something weird and I can't be bothered to support it rn, go check out the TransformerLens MLP code for what's supposed to go here and open a PR if you want this to work")

        post_act = self.act_fn(pre_act)  # [batch, pos, d_mlp]
        if return_act_grads:
            # Taken before hook_post, so hooks cannot alter the derivative.
            grad_of_act = self._act_fn_derivative(pre_act, post_act)
        post_act = self.hook_post(post_act)
        output = batch_addmm(self.b_out, self.W_out, post_act)

        if return_act_grads:
            return output, grad_of_act
        return output

    def _act_fn_derivative(
        self, pre_act: torch.Tensor, post_act: torch.Tensor
    ) -> torch.Tensor:
        """Return d(sum of act_fn(pre_act)) / d(pre_act), shaped like ``pre_act``."""
        if pre_act.requires_grad and post_act.requires_grad:
            # Reuse the graph the forward pass just built. retain_graph=True is
            # essential: the default (False) frees act_fn's saved tensors, and a
            # later output.backward() would then fail with "Trying to backward
            # through the graph a second time".
            return torch.autograd.grad(
                outputs=post_act,
                inputs=pre_act,
                grad_outputs=torch.ones_like(post_act),
                retain_graph=True,
            )[0]
        # No graph to reuse (e.g. under torch.no_grad(), or with frozen weights and
        # an input that doesn't require grad). Differentiate act_fn on a detached
        # leaf copy in a small private graph instead of raising.
        with torch.enable_grad():
            leaf = pre_act.detach().requires_grad_(True)
            act = self.act_fn(leaf)
            return torch.autograd.grad(
                outputs=act,
                inputs=leaf,
                grad_outputs=torch.ones_like(act),
            )[0]

original_mlp = model.blocks[layer].mlp
mlp_with_grads = MLPWithActGrads(original_mlp.cfg)
# strict loading (the default) errors if any parameter name fails to line up.
mlp_with_grads.load_state_dict(original_mlp.state_dict())
# Place the copy wherever the original weights live (and hence where this layer's
# cached activations live), instead of assuming `device` matches the model's placement.
mlp_with_grads.to(original_mlp.W_in.device)

MLPWithActGrads(
  (hook_pre): HookPoint()
  (hook_post): HookPoint()
)

In [41]:
# Sanity check: the subclass must reproduce the stock MLP exactly on the same input.
# NOTE(review): despite the names, these are MLP outputs, not resid_post (which would be
# resid_mid + MLP output). Also, inside the block the MLP presumably sees the layer-normed
# residual rather than raw resid_mid. Both are fine for an equality check between the MLPs.
original_resid_post = original_mlp(resid_mid)
my_resid_post, act_grads = mlp_with_grads(resid_mid, return_act_grads=True)
assert torch.allclose(original_resid_post, my_resid_post)
# One activation derivative per MLP neuron, at every (batch, pos).
assert act_grads.shape == (*resid_mid.shape[:-1], original_mlp.W_in.shape[-1])
act_grads

tensor([[[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  ...,  1.0004e+00,
           1.0000e+00,  0.0000e+00],
         [-1.9954e-02, -1.1823e-01,  6.7858e-01,  ...,  5.8370e-01,
           7.6067e-01, -1.3031e-02],
         [-2.9115e-06,  1.0245e+00, -4.3108e-02,  ..., -7.7937e-02,
           2.0209e-01, -3.3486e-02],
         ...,
         [-4.3960e-02, -1.2179e-01, -1.3491e-02,  ...,  5.1352e-02,
           7.8677e-01,  6.4266e-02],
         [-5.5457e-02,  6.8279e-01,  0.0000e+00,  ...,  2.1305e-03,
           5.6596e-01, -2.0028e-02],
         [ 0.0000e+00,  8.5492e-01,  0.0000e+00,  ...,  5.9475e-01,
           7.0861e-01, -3.8573e-02]]], device='mps:0')