#Setup

In [None]:
!nvidia-smi
!pip install einops
!pip install pyyaml==5.4.1
!pip install transformers

Sat Sep  3 07:49:16 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100-SXM2...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    24W / 300W |      0MiB / 16160MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
import tqdm.notebook as tqdm

import random
import time

from google.colab import drive
from pathlib import Path
import pickle
import os

import matplotlib.pyplot as plt
%matplotlib inline
import plotly.express as px
import plotly.io as pio
pio.renderers.default = "colab"
import plotly.graph_objects as go

from torch.utils.data import DataLoader

from functools import *
import pandas as pd
import gc
import collections
import copy

# import comet_ml
import itertools

Helper functions

In [None]:
def get_corner(tensor, n=2):
    """Return the leading ("top left") corner of a tensor.

    Keeps the first ``n`` entries along every dimension.

    Args:
        tensor: An object exposing ``.shape`` and supporting slice indexing.
        n: Number of leading entries kept along each dimension.

    Returns:
        The sliced corner, or ``tensor`` itself when it has rank 0.

    Raises:
        ValueError: If the tensor has more than 6 dimensions.
    """
    rank = len(tensor.shape)
    if rank == 0:
        # Nothing to slice on a scalar
        return tensor
    if rank > 6:
        # Tensors of rank > 6 are deliberately unsupported
        raise ValueError(f'Tensor of shape {tensor.shape} is too big')
    # One `:n` slice per dimension, e.g. tensor[:n, :n, :n] for rank 3
    corner_index = tuple(slice(None, n) for _ in range(rank))
    return tensor[corner_index]

def to_numpy(tensor, flat=False):
    """Convert a tensor or parameter to a NumPy array.

    Args:
        tensor: The value to convert. Anything whose exact type is not
            ``torch.Tensor`` or ``torch.nn.parameter.Parameter`` is returned
            unchanged.
        flat: If True, the tensor is flattened to 1D before conversion.

    Returns:
        A NumPy array (detached and moved to the CPU), or the input itself
        when it is not a tensor/parameter.
    """
    convertible_types = (torch.Tensor, torch.nn.parameter.Parameter)
    if type(tensor) not in convertible_types:
        return tensor
    source = tensor.flatten() if flat else tensor
    return source.detach().cpu().numpy()

def gelu_new(input):
    """tanh-based GeLU used by GPT-2 (subtly different from PyTorch's).

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))``
    elementwise on ``input``.
    """
    tanh_coeff = np.sqrt(2.0 / np.pi)
    inner = input + 0.044715 * torch.pow(input, 3.0)
    return 0.5 * input * (1.0 + torch.tanh(tanh_coeff * inner))

#Hook Points

A Garcon-style interface - the key thing is a HookPoint class. This is a layer to wrap any activation within the model in. The HookPoint acts as an identity function, but allows us to put PyTorch hooks in to edit and access the relevant activation. This allows us to take any model and insert in access points to all interesting activations by wrapping them in HookPoints

There is also a `HookedRootModule` class - this is a utility class that the root module should inherit from (root module = the model we run) - it has several utility functions for using hooks well. 

The default interface is the `run_with_hooks` function on the root module, which lets us run a forwards pass on the model, and pass on a list of hooks paired with layer names to run on that pass. 

The syntax for a hook is `function(activation, hook)`, where `activation` is the activation the hook is wrapped around, and `hook` is the `HookPoint` instance the function is attached to. If the function returns a new activation or edits the activation in-place, that replaces the old one; if it returns None, the activation remains as is.



In [None]:
class HookPoint(nn.Module):
    """Identity module that exposes an activation to PyTorch hooks.

    Inspired by Garcon. Wrap any intermediate activation in a HookPoint to get
    a convenient place to attach hooks that read or replace that activation;
    with no hooks attached the module simply returns its input.
    """

    def __init__(self):
        super().__init__()
        # Handles of the hooks registered through add_hook, per direction
        self.fwd_hooks = []
        self.bwd_hooks = []
        # Scratch storage; emptied by clear_context
        self.ctx = {}

        # The hook point's name from the perspective of the root module -
        # this is set by the root module at setup.
        self.name = None

    def add_hook(self, hook, dir='fwd'):
        """Attach ``hook`` to this hook point.

        Args:
            hook: Callable invoked as ``hook(value, hook=self)``. It is
                wrapped into PyTorch's three-argument hook format and handed
                the third argument of that format (the module output for a
                forward hook).
            dir: 'fwd' to register a forward hook, 'bwd' to register a full
                backward hook.

        Raises:
            ValueError: If ``dir`` is neither 'fwd' nor 'bwd'.
        """
        if dir not in ('fwd', 'bwd'):
            raise ValueError(f"Invalid direction {dir}")

        # PyTorch hooks receive (module, input, output); input and output are
        # the same for a HookPoint, so only the last one is passed on.
        def full_hook(module, module_input, module_output):
            return hook(module_output, hook=self)

        if dir == 'fwd':
            self.fwd_hooks.append(self.register_forward_hook(full_hook))
        else:
            self.bwd_hooks.append(self.register_full_backward_hook(full_hook))

    def remove_hooks(self, dir='fwd'):
        """Remove the hooks added through add_hook.

        Args:
            dir: 'fwd', 'bwd' or 'both' - which group of hooks to remove.

        Raises:
            ValueError: If ``dir`` is not one of the three accepted values.
        """
        if dir not in ['fwd', 'bwd', 'both']:
            raise ValueError(f"Invalid direction {dir}")
        if dir in ('fwd', 'both'):
            for handle in self.fwd_hooks:
                handle.remove()
            self.fwd_hooks = []
        if dir in ('bwd', 'both'):
            for handle in self.bwd_hooks:
                handle.remove()
            self.bwd_hooks = []

    def clear_context(self):
        """Discard the current context dictionary and start a fresh one."""
        del self.ctx
        self.ctx = {}

    def forward(self, x):
        """Identity: return ``x`` untouched (hooks may still replace it)."""
        return x

    def layer(self):
        """Return the layer index from a name of the form 'blocks.{layer}.{...}'.

        Helper that's mainly useful on EasyTransformer. If the name does not
        have this form, the lookup/conversion below raises an error.
        """
        name_parts = self.name.split('.')
        return int(name_parts[1])

class HookedRootModule(nn.Module):
    """Base class for root modules that contain HookPoints.

    Builds on nn.Module to interface nicely with HookPoints: it names every
    submodule, lets you remove hooks, cache every activation/gradient, and run
    a forward pass with temporary hooks attached.
    Subclasses must call ``setup()`` at the end of their ``__init__``.
    """

    def __init__(self, *args):
        super().__init__()

    def setup(self):
        """Name every submodule and index the HookPoints.

        This needs to be run in __init__ AFTER defining all layers. It sets a
        ``name`` attribute on each module (its name relative to this root),
        fills ``self.mod_dict`` (name -> module) and ``self.hook_dict``
        (name -> HookPoint, exact type match only).
        """
        self.mod_dict = {}
        self.hook_dict = {}
        for name, module in self.named_modules():
            module.name = name
            self.mod_dict[name] = module
            if type(module)==HookPoint:
                self.hook_dict[name] = module

    def hook_points(self):
        """Return a view of all HookPoints found by setup()."""
        return (self.hook_dict.values())

    def remove_all_hook_fns(self, direction='both'):
        """Remove the hooks in ``direction`` ('fwd', 'bwd' or 'both') from every HookPoint."""
        for hp in self.hook_points():
            hp.remove_hooks(direction)

    def clear_contexts(self):
        """Reset the context dictionary of every HookPoint."""
        for hp in self.hook_points():
            hp.clear_context()

    def reset_hooks(self, clear_contexts=True, direction='both'):
        """Remove hooks from every HookPoint, optionally clearing contexts first."""
        if clear_contexts: self.clear_contexts()
        self.remove_all_hook_fns(direction)

    def cache_all(self, cache, incl_bwd=False, device='cuda'):
        """Add hooks that store every HookPoint activation into ``cache``.

        Args:
            cache: Dictionary that is filled in-place on later passes.
                Forward values are stored under the hook name, backward
                values under ``name + '_grad'``.
            incl_bwd: If True, also add backward hooks.
            device: Device the detached tensors are moved to before storing.
        """
        def save_hook(tensor, hook):
            cache[hook.name] = tensor.detach().to(device)
        def save_hook_back(tensor, hook):
            # The backward hook is handed a sequence; its first entry is cached
            cache[hook.name+'_grad'] = tensor[0].detach().to(device)
        for hp in self.hook_points():
            hp.add_hook(save_hook, 'fwd')
            if incl_bwd:
                hp.add_hook(save_hook_back, 'bwd')

    def _add_hooks(self, hooks, dir):
        """Attach each (name, hook) pair in ``hooks`` in direction ``dir``.

        ``name`` is either the name of a module in ``mod_dict`` or a Boolean
        function on hook names, in which case the hook is added to every
        HookPoint whose name evaluates to True.
        """
        for name, hook in hooks:
            if isinstance(name, str):
                self.mod_dict[name].add_hook(hook, dir=dir)
            else:
                # Otherwise, name is a Boolean function on names
                for hook_name, hp in self.hook_dict.items():
                    if name(hook_name):
                        hp.add_hook(hook, dir=dir)

    def run_with_hooks(self, 
                       *args, 
                       fwd_hooks=[], 
                       bwd_hooks=[], 
                       reset_hooks_start=True, 
                       reset_hooks_end=True, 
                       clear_contexts=False):
        '''
        Runs a forward pass with the given hooks attached and returns its output.

        fwd_hooks: A list of (name, hook), where name is either the name of 
        a hook point or a Boolean function on hook names and hook is the 
        function to add to that hook point, or the hook whose names evaluate 
        to True respectively. Ditto bwd_hooks
        reset_hooks_start (bool): If True, all prior hooks are removed at the start
        reset_hooks_end (bool): If True, all hooks are removed at the end (ie, 
        including those added in this run)
        clear_contexts (bool): If True, clears hook contexts whenever hooks are reset
        
        Note that if we want to use backward hooks, we need to set 
        reset_hooks_end to be False, so the backward hooks are still there - this function only runs a forward pass.

        The default hook lists are only read, never mutated.
        '''
        if reset_hooks_start:
            self.reset_hooks(clear_contexts)
        self._add_hooks(fwd_hooks, 'fwd')
        # Backward hooks must be registered in the 'bwd' direction for both
        # string names and Boolean name functions
        self._add_hooks(bwd_hooks, 'bwd')
        out = self.forward(*args)
        if reset_hooks_end:
            if len(bwd_hooks)>0:
                print("WARNING: Hooks were reset at the end of run_with_hooks while backward hooks were set.")
                print("This removes the backward hooks before a backward pass can occur")
            self.reset_hooks(clear_contexts)
        return out
        

##Example

Here's a simple example of how to use the classes:

We define a basic network with two layers that each take a scalar input $x$, square it, and add a constant:
$x_0=x$, $x_1=x_0^2+3$, $x_2=x_1^2-4$.

We wrap the input, each layer's output, and the intermediate value of each layer (the square) in a hook point.

In [None]:
class SquareThenAdd(nn.Module):
    """Toy layer computing ``x * x + offset``, with a hook point on the square."""

    def __init__(self, offset):
        super().__init__()
        # Learnable constant added after squaring
        self.offset = nn.Parameter(torch.tensor(offset))
        # Gives access to the intermediate squared value
        self.hook_square = HookPoint()

    def forward(self, x):
        squared = x * x
        # hook_square leaves the value unchanged unless a hook replaces it
        squared = self.hook_square(squared)
        return self.offset + squared
    
class TwoLayerModel(HookedRootModule):
    """Two stacked SquareThenAdd layers with hook points on the input and each layer's output."""

    def __init__(self):
        super().__init__()
        self.layer1 = SquareThenAdd(3.)
        self.layer2 = SquareThenAdd(-4.)
        self.hook_in = HookPoint()
        self.hook_mid = HookPoint()
        self.hook_out = HookPoint()

        # Must come last: setup() builds the internal dictionaries of modules
        # and hooks, and gives each hook its name.
        super().setup()

    def forward(self, x):
        # Each hook point returns its argument unchanged (unless a hook has
        # been added that explicitly changes it), so wrapping only adds an
        # access point for the value.
        hooked_input = self.hook_in(x)
        first_out = self.layer1(hooked_input)
        hooked_mid = self.hook_mid(first_out)
        second_out = self.layer2(hooked_mid)
        return self.hook_out(second_out)

model = TwoLayerModel()



We can add a cache, to save the activation at each hook point

(There's a custom `cache_all` function on the root module as a convenience, which will add hooks to cache every activation at a hook point - we could also manually add hooks with `run_with_hooks`)

In [None]:
# Attach caching hooks to every hook point; the cache is filled during the
# forward pass below (cache_all moves tensors to its default device, 'cuda')
cache = {}
model.cache_all(cache)
print('Model output:', model(torch.tensor(5.)).item())
for key, cached in cache.items():
    print(f"Value cached at hook {key}", cached.item())

Model output: 780.0
Value cached at hook hook_in 5.0
Value cached at hook layer1.hook_square 25.0
Value cached at hook hook_mid 28.0
Value cached at hook layer2.hook_square 784.0
Value cached at hook hook_out 780.0


We can also use hooks to intervene on activations - eg, we can set the intermediate value in layer 2 (the square) to zero, which changes the output to -4 (just the layer's offset)

In [None]:
def set_to_zero_hook(tensor, hook):
    """Hook that prints the hook point's name and replaces the activation with a scalar zero."""
    print(hook.name)
    return torch.tensor(0.)
# The label names the hook point that is actually intervened on
print('Output after intervening on layer2.hook_square', 
      model.run_with_hooks(torch.tensor(5.),
                           fwd_hooks = [('layer2.hook_square', set_to_zero_hook)]).item())

layer2.hook_square
Output after intervening on layer2.hook_scaled -4.0


#Defining the model

We now define a stripped down transformer. There are helper functions to load in the weights of several families of open source LLMs - OpenAI's GPT-2, Facebook's OPT and Eleuther's GPT-Neo.

Note: OPT-350M is not supported - it applies the LayerNorms to the *outputs* of each layer, which means we cannot fold the weights and biases into other layers, and would require notably different architecture.

**TODO:** Add in GPT-J and GPT-NeoX functionality

In [None]:
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
# HuggingFace model names accepted by EasyTransformer, grouped by family.
# GPT-J and GPT-NeoX are listed here, but convert_config currently raises
# NotImplementedError for them.
VALID_MODEL_NAMES = [
    # OpenAI GPT-2
    'gpt2',
    'gpt2-medium',
    'gpt2-large',
    'gpt2-xl',
    # Facebook OPT
    'facebook/opt-125m',
    'facebook/opt-1.3b',
    'facebook/opt-2.7b',
    'facebook/opt-6.7b',
    'facebook/opt-13b',
    'facebook/opt-30b',
    'facebook/opt-66b',
    # EleutherAI GPT-Neo
    'EleutherAI/gpt-neo-125M',
    'EleutherAI/gpt-neo-1.3B',
    'EleutherAI/gpt-neo-2.7B',
    # EleutherAI GPT-J / GPT-NeoX
    'EleutherAI/gpt-j-6B',
    'EleutherAI/gpt-neox-20b',
]



# TODO: Add Bloom
'''
bloom-350m
bloom-760m
bloom-1b3
bloom-2b5
bloom-6b3
bloom (176B parameters)
https://huggingface.co/docs/transformers/model_doc/bloom
'''

'\nbloom-350m\nbloom-760m\nbloom-1b3\nbloom-2b5\nbloom-6b3\nbloom (176B parameters)\nhttps://huggingface.co/docs/transformers/model_doc/bloom\n'

We define the components of our simple transformer. Notable deviations from the standard framing:
* Each attention head has weight matrices $W_Q,W_K,W_V$ of size [head_index, d_head, d_model] and $W_O$ of size [head_index, d_model, d_head] (ie, rather than concatenating the vectors from each head before the linear map)
* The LayerNorms purely center and normalise, they do not have weights or biases. Instead these are folded into the weights and biases of the layers that come immediately after (attention, MLP and unembed)
  * This means that every query, key, value, MLP input, and the unembed calculations all have biases, even if they don't in the unfolded model


In [None]:
# Define network architecture

# Embed & Unembed
class Embed(nn.Module):
    """Token embedding: looks up a column of ``W_E`` for every token index."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        d_model, d_vocab = self.cfg['d_model'], self.cfg['d_vocab']
        # Uninitialised storage of shape [d_model, d_vocab]
        self.W_E = nn.Parameter(torch.empty(d_model, d_vocab))

    def forward(self, tokens):
        # Indexing the second dimension of a [a, b] tensor with a [c, d] index
        # tensor gives shape [a, c, d]; the indices must be >=0 and <b
        gathered = self.W_E[:, tokens] # [d_model, batch, pos]
        return einops.rearrange(gathered, 'd_model batch pos -> batch pos d_model')

class Unembed(nn.Module):
    """Linear map from d_model-sized vectors to d_vocab-sized logits (with bias)."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        d_vocab, d_model = self.cfg['d_vocab'], self.cfg['d_model']
        # Uninitialised storage for the weight [d_vocab, d_model] and bias [d_vocab]
        self.W_U = nn.Parameter(torch.empty(d_vocab, d_model))
        self.b_U = nn.Parameter(torch.empty(d_vocab))

    def forward(self, tokens):
        # `tokens` is contracted over its last (d_model-sized) dimension
        logits = torch.einsum('vm,bpm->bpv', self.W_U, tokens)
        return logits + self.b_U # [batch, pos, d_vocab]

# Positional Embeddings
class PosEmbed(nn.Module):
    """Learned positional embeddings, one column of ``W_pos`` per position."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        d_model, n_ctx = self.cfg['d_model'], self.cfg['n_ctx']
        # Uninitialised storage of shape [d_model, n_ctx]
        self.W_pos = nn.Parameter(torch.empty(d_model, n_ctx))

    def forward(self, x):
        # Only the size of the last dimension of x is used (number of positions).
        # The result has no batch dim - it will be broadcast along it.
        n_positions = x.size(-1)
        return self.W_pos[:, :n_positions].T # [pos, d_model]

# LayerNormPre
# I fold the LayerNorm weights and biases into later weights and biases. 
# This is just the 'center and normalise' part of LayerNorm
# Centering is equivalent to just deleting one direction of residual space, 
# and is equivalent to centering the weight matrices of everything writing to the residual stream
# Normalising is a funkier non-linear operation, that projects the residual stream onto the unit hypersphere
class LayerNormPre(nn.Module):
    """The weight-free 'center and normalise' part of LayerNorm.

    Subtracts the mean over the last dimension, then divides by
    ``sqrt(mean(x^2) + eps)`` of the centered values.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.eps = self.cfg['eps']

        # Hook point on the normalisation scale factor
        self.hook_scale = HookPoint() # [batch, pos, 1]

    def forward(self, x):
        centered = x - x.mean(axis=-1, keepdim=True) # [batch, pos, d_model]
        mean_square = einops.reduce(centered.pow(2), 
                                    'batch pos embed -> batch pos 1', 
                                    'mean')
        scale = self.hook_scale((mean_square + self.eps).sqrt()) # [batch, pos, 1]
        return centered / scale

# Attention
class Attention(nn.Module):
    """Multi-head causal self-attention with per-head weight tensors.

    W_Q, W_K, W_V have shape [head_index, d_head, d_model] and W_O has shape
    [head_index, d_model, d_head]. Every intermediate activation is wrapped
    in a HookPoint.

    Args:
        cfg: Config dictionary; reads 'n_heads', 'd_head', 'd_model', 'n_ctx',
            'use_attn_scale', 'use_attn_result' and, for local attention,
            'window_size'.
        attn_type: 'global' (full causal mask) or 'local' (banded causal mask).

    Raises:
        ValueError: If ``attn_type`` is neither 'global' nor 'local'.
    """

    def __init__(self, cfg, attn_type='global'):
        super().__init__()
        self.cfg = cfg
        n_heads = self.cfg['n_heads']
        d_head = self.cfg['d_head']
        d_model = self.cfg['d_model']
        n_ctx = self.cfg['n_ctx']

        # Uninitialised per-head weights and biases
        self.W_Q = nn.Parameter(torch.empty(n_heads, d_head, d_model))
        self.W_K = nn.Parameter(torch.empty(n_heads, d_head, d_model))
        self.W_V = nn.Parameter(torch.empty(n_heads, d_head, d_model))
        self.W_O = nn.Parameter(torch.empty(n_heads, d_model, d_head))
        self.b_Q = nn.Parameter(torch.empty(n_heads, d_head))
        self.b_K = nn.Parameter(torch.empty(n_heads, d_head))
        self.b_V = nn.Parameter(torch.empty(n_heads, d_head))
        self.b_O = nn.Parameter(torch.empty(d_model))

        self.attn_type = attn_type
        # query_pos x key_pos Boolean mask, True iff that query position can
        # attend to that key position
        lower_triangular = torch.tril(torch.ones((n_ctx, n_ctx)).bool())
        if self.attn_type == 'global':
            # Lower triangular: key <= query
            allowed = lower_triangular
        elif self.attn_type == 'local':
            # Banded: query - window_size < key <= query
            allowed = torch.triu(lower_triangular, 1-self.cfg['window_size'])
        else:
            raise ValueError(f"Invalid attention type: {self.attn_type}")
        self.register_buffer('mask', allowed)

        # Score written into masked-out positions before the softmax
        self.register_buffer('IGNORE', torch.tensor(-1e5))

        # Scores are divided by sqrt(d_head) only when the config asks for it
        self.attn_scale = np.sqrt(d_head) if self.cfg['use_attn_scale'] else 1.

        self.hook_k = HookPoint() # [batch, pos, head_index, d_head]
        self.hook_q = HookPoint() # [batch, pos, head_index, d_head]
        self.hook_v = HookPoint() # [batch, pos, head_index, d_head]
        self.hook_z = HookPoint() # [batch, pos, head_index, d_head]
        self.hook_attn_scores = HookPoint() # [batch, head_index, query_pos, key_pos]
        self.hook_attn = HookPoint() # [batch, head_index, query_pos, key_pos]
        self.hook_result = HookPoint() # [batch, pos, head_index, d_model]

    def forward(self, x):
        # Per-head projections of the input, each [batch, pos, head_index, d_head]
        queries = self.hook_q(torch.einsum('ihm,bpm->bpih', self.W_Q, x) + self.b_Q)
        keys = self.hook_k(torch.einsum('ihm,bpm->bpih', self.W_K, x) + self.b_K)
        values = self.hook_v(torch.einsum('ihm,bpm->bpih', self.W_V, x) + self.b_V)

        # Scaled and masked scores, then attention pattern:
        # both [batch, head_index, query_pos, key_pos]
        raw_scores = torch.einsum('bpih,bqih->bipq', queries, keys)/self.attn_scale
        masked_scores = self.hook_attn_scores(self.causal_mask(raw_scores))
        pattern = self.hook_attn(F.softmax(masked_scores, dim=-1))

        # Attention-weighted mix of values, [batch, pos, head_index, d_head]
        mixed = self.hook_z(torch.einsum('bpih,biqp->bqih', values, pattern))

        if not self.cfg['use_attn_result']:
            # Sum over heads directly, without materialising per-head outputs
            return torch.einsum('idh,bqih->bqd', self.W_O, mixed)+self.b_O # [batch, pos, d_model]

        # Explicit per-head outputs (hooked), then summed over heads
        per_head = self.hook_result(torch.einsum('imh,bqih->bqim', self.W_O, mixed)) # [batch, pos, head_index, d_model]
        summed = einops.reduce(per_head, 
                               'batch position index model->batch position model', 
                               'sum')
        return summed+self.b_O # [batch, pos, d_model]

    def causal_mask(self, attn_scores):
        """Replace scores at disallowed (query, key) positions with ``IGNORE``."""
        n_query = attn_scores.size(-2)
        n_key = attn_scores.size(-1)
        return torch.where(self.mask[:n_query, :n_key], attn_scores, self.IGNORE)

# MLP Layers
class MLP(nn.Module):
    """Feed-forward layer: linear map to d_mlp, activation, linear map back to d_model.

    Raises:
        ValueError: If cfg['act_fn'] is not 'relu' or 'gelu_new'.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        d_mlp, d_model = self.cfg['d_mlp'], self.cfg['d_model']
        # Uninitialised weights and biases
        self.W_in = nn.Parameter(torch.empty(d_mlp, d_model))
        self.b_in = nn.Parameter(torch.empty(d_mlp))
        self.W_out = nn.Parameter(torch.empty(d_model, d_mlp))
        self.b_out = nn.Parameter(torch.empty(d_model))

        self.hook_pre = HookPoint() # [batch, pos, d_mlp]
        self.hook_post = HookPoint() # [batch, pos, d_mlp]

        # Supported activation functions, keyed by their config name
        activations = {'relu': F.relu, 'gelu_new': gelu_new}
        if self.cfg['act_fn'] not in activations:
            raise ValueError(f"Invalid activation function name: {self.cfg['act_fn']}")
        self.act_fn = activations[self.cfg['act_fn']]

    def forward(self, x):
        pre_act = self.hook_pre(torch.einsum('md,bpd->bpm', self.W_in, x) + self.b_in) # [batch, pos, d_mlp]
        post_act = self.hook_post(self.act_fn(pre_act)) # [batch, pos, d_mlp]
        return torch.einsum('dm,bpm->bpd', self.W_out, post_act) + self.b_out # [batch, pos, d_model]

# Transformer Block
class TransformerBlock(nn.Module):
    """One transformer layer: attention then MLP, each reading a normalised
    copy of the residual stream and adding its output back to it.

    Args:
        cfg: Config dictionary shared with the sub-layers.
        block_index: Index of this block; used to look up the attention type
            in cfg['attn_types'] when cfg['use_local_attn'] is set.
    """

    def __init__(self, cfg, block_index):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNormPre(cfg)
        # Without local attention every block uses global attention
        attn_type = self.cfg['attn_types'][block_index] if self.cfg['use_local_attn'] else 'global'
        self.attn = Attention(cfg, attn_type)
        self.ln2 = LayerNormPre(cfg)
        self.mlp = MLP(cfg)

        self.hook_attn_out = HookPoint() # [batch, pos, d_model]
        self.hook_mlp_out = HookPoint() # [batch, pos, d_model]
        self.hook_resid_pre = HookPoint() # [batch, pos, d_model]
        self.hook_resid_mid = HookPoint() # [batch, pos, d_model]
        self.hook_resid_post = HookPoint() # [batch, pos, d_model]

    def forward(self, x):
        # All values below are [batch, pos, d_model]
        stream_in = self.hook_resid_pre(x)

        # Attention sub-layer with skip connection
        attn_contribution = self.hook_attn_out(self.attn(self.ln1(stream_in)))
        stream_mid = self.hook_resid_mid(stream_in + attn_contribution)

        # MLP sub-layer with skip connection
        mlp_contribution = self.hook_mlp_out(self.mlp(self.ln2(stream_mid)))
        return self.hook_resid_post(stream_mid + mlp_contribution)


In [None]:

# Full transformer
class EasyTransformer(HookedRootModule):
    """
    This class implements a full Transformer using the above components, with 
    HookPoints on every interesting activation. It inherits from HookedRootModule. 
    It is initialised with a model_name, and automatically loads the model weights 
    for that model, loads them into this model, folds in LayerNorm and centers 
    the weights
    """
    def __init__(self, 
                 model_name, 
                 use_attn_result=False, 
                 model=None, 
                 keep_original_model=False, 
                 center_weights=True):
        """
        model_name (str): The name of the model to load, via HuggingFace. Must 
            be one of VALID_MODEL_NAMES
        use_attn_result (bool): Says whether to explicitly calculate the amount 
            each head adds to the residual stream (with a hook) and THEN add it 
            up, vs just calculating the sum. This can be very memory intensive 
            for large models, so defaults to False
        model: The loaded model from HuggingFace. If None, it is automatically 
            loaded from HuggingFace - this just saves memory if the model was 
            already loaded into RAM
        keep_original_model (bool): If False, the original HuggingFace model is 
            deleted, otherwise it's kept as a self.model attribute
        center_weights (bool): If True, self.center_weights() is called after 
            the weights have been loaded
        """
        assert model_name in VALID_MODEL_NAMES
        super().__init__()
        self.model_name = model_name
        self.model_type = self.get_model_type(model_name)
        if model is not None:
            self.model = model
        else:
            self.model = AutoModelForCausalLM.from_pretrained(model_name)
        
        self.cfg = self.convert_config(self.model.config, model_type=self.model_type)
        self.cfg['use_attn_result']=use_attn_result
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        
        self.embed = Embed(self.cfg)
        self.hook_embed = HookPoint() # [batch, pos, d_model]
        
        self.pos_embed = PosEmbed(self.cfg)
        self.hook_pos_embed = HookPoint() # [pos, d_model], broadcast over batch
        
        self.blocks = nn.ModuleList([TransformerBlock(self.cfg, block_index) for block_index in range(self.cfg['n_layers'])])
        self.ln_final = LayerNormPre(self.cfg)
        self.unembed = Unembed(self.cfg)

        # Gives each module a parameter with its name (relative to this root module)
        # Needed for HookPoints to work
        self.setup()

        # Load model weights, and fold in layer norm weights
        if self.model_type=='gpt2':
            self.load_gpt2_weights(self.model)
        elif self.model_type=='neo':
            self.load_neo_weights(self.model)
        elif self.model_type=='gptj':
            self.load_gptj_weights(self.model)
        elif self.model_type=='neox':
            self.load_neox_weights(self.model)
        elif self.model_type=='opt':
            self.load_opt_weights(self.model)
        
        # Set the average of each weight matrix writing to the residual stream to zero
        # (Layer Norm removes the mean anyway, so this simplifies the weights 
        # without changing the computation)
        if center_weights:
            self.center_weights()
        
        if not keep_original_model:
            # Delete the original model to save memory
            del self.model
            # setup() ran while the original model was registered as the 
            # 'model' submodule, so mod_dict still references it and all of 
            # its submodules. Drop those entries too, otherwise the original 
            # model stays alive and no memory is freed.
            stale_names = [name for name in self.mod_dict 
                           if name=='model' or name.startswith('model.')]
            for name in stale_names:
                del self.mod_dict[name]
    
    def forward(self, x):
        """Run the model on a batch of tokens ([batch, pos]) or on a text string.

        Returns the output of the unembed layer, [batch, pos, d_vocab].
        """
        # A text string is first converted to tokens (batch_size=1)
        tokens = self.to_tokens(x) if type(x) is str else x

        token_embed = self.hook_embed(self.embed(tokens)) # [batch, pos, d_model]
        # The positional embedding has no batch dim and is broadcast in the sum
        positional_embed = self.hook_pos_embed(self.pos_embed(tokens))
        residual = token_embed + positional_embed # [batch, pos, d_model]

        # Each block already includes its skip connections, so its output
        # simply replaces the residual stream
        for block in self.blocks:
            residual = block(residual) # [batch, pos, d_model]

        normalised = self.ln_final(residual)
        return self.unembed(normalised) # [batch, pos, d_vocab]
    
    def to_tokens(self, text):
        """Tokenize ``text`` with self.tokenizer (PyTorch tensors requested) and return its 'input_ids'."""
        encoded = self.tokenizer(text, return_tensors='pt')
        return encoded['input_ids']
    
    def get_model_type(self, model_name):
        """Map a HuggingFace model name to its model family.

        Returns one of 'gpt2', 'opt', 'neox', 'gptj' or 'neo'.

        Raises:
            ValueError: If the name matches none of the known families.
        """
        if 'gpt2' in model_name:
            return 'gpt2'
        if 'opt' in model_name:
            return 'opt'
        # These two are matched by exact name, and before the generic 'neo'
        # substring test, which the NeoX name would otherwise also satisfy
        exact_matches = {
            'EleutherAI/gpt-neox-20b': 'neox',
            'EleutherAI/gpt-j-6B': 'gptj',
        }
        if model_name in exact_matches:
            return exact_matches[model_name]
        if 'neo' in model_name:
            return 'neo'
        raise ValueError(f"Invalid model name: {model_name}")
    
    def convert_config(self, config, model_type):
        """Translate a HuggingFace config object into this model's cfg dictionary.

        Args:
            config: The HuggingFace config; the attributes read depend on 
                ``model_type``.
            model_type (str): One of 'neo', 'gpt2', 'opt', 'gptj', 'neox'.

        Returns:
            dict with the keys used by the model components, plus 
            'model_name' (taken from self.model_name) and 'model_type'.

        Raises:
            NotImplementedError: For 'gptj' and 'neox', which are not 
                supported yet.
            ValueError: For any other unrecognised ``model_type``.
        """
        if model_type=='neo':
            cfg = {
                'd_model':config.hidden_size,
                'd_head':config.hidden_size//config.num_heads,
                'n_heads':config.num_heads,
                'd_mlp':config.hidden_size*4,
                'n_layers':config.num_layers,
                'n_ctx':config.max_position_embeddings,
                'eps':config.layer_norm_epsilon,
                'd_vocab':config.vocab_size,
                'attn_types':config.attention_layers,
                'act_fn':config.activation_function,
                'use_attn_scale':False,
                'use_local_attn':True,
                'window_size':config.window_size,
            }
        elif model_type=='gpt2':
            cfg = {
                'd_model':config.n_embd,
                'd_head':config.n_embd//config.n_head,
                'n_heads':config.n_head,
                'd_mlp':config.n_embd*4,
                'n_layers':config.n_layer,
                'n_ctx':config.n_ctx,
                'eps':config.layer_norm_epsilon,
                'd_vocab':config.vocab_size,
                'act_fn':config.activation_function,
                'use_attn_scale':True,
                'use_local_attn':False,
            }
        elif model_type=='opt':
            cfg = {
                'd_model':config.hidden_size,
                'd_head':config.hidden_size//config.num_attention_heads,
                'n_heads':config.num_attention_heads,
                'd_mlp':config.ffn_dim,
                'n_layers':config.num_hidden_layers,
                'n_ctx':config.max_position_embeddings,
                'eps':1e-5,
                'd_vocab':config.vocab_size,
                'act_fn':config.activation_function,
                'use_attn_scale':True,
                'use_local_attn':False,
            }
        elif model_type=='gptj':
            raise NotImplementedError
        elif model_type=='neox':
            raise NotImplementedError
        else:
            # Without this branch an unknown type would fail below with an 
            # UnboundLocalError on cfg
            raise ValueError(f"Invalid model type: {model_type}")
        
        cfg['model_name']=self.model_name
        cfg['model_type']=model_type
        return cfg
    
    def center_weights(self):
        """Subtract the mean from the weights, in place.

        W_E, W_pos, W_U and each block's mlp.W_out have their mean over 
        dimension 0 removed; each block's attn.W_O has its mean over the 
        d_model dimension removed.

        LayerNorm subtracts the mean of the residual stream, and it's always 
        applied when reading from the residual stream, so for weights writing 
        to the residual stream this direction is purely noise. W_U is included 
        since translating the logits doesn't affect the log_probs or loss.
        """
        for weight in (self.embed.W_E, self.pos_embed.W_pos, self.unembed.W_U):
            weight.data -= weight.mean(0, keepdim=True)

        for block in self.blocks:
            attn_out = block.attn.W_O
            attn_out.data -= einops.reduce(attn_out, 
                                           'index d_model d_head -> index 1 d_head',
                                           'mean')
            mlp_out = block.mlp.W_out
            mlp_out.data -= mlp_out.mean(0, keepdim=True)
        
    def load_gpt2_weights(self, gpt2):
        """Load the weights of a GPT-2 style model into this model.

        Starts from this model's own state dict, overwrites the embedding,
        attention, MLP and unembedding entries with tensors derived from
        `gpt2`, and loads the result with `load_state_dict`. The LayerNorm
        scale and bias (ln_1, ln_2, ln_f) are folded into the linear map that
        reads from them: the weight is multiplied by the LayerNorm weight and
        `W @ ln_bias` is added to the layer's bias.

        Args:
            gpt2: object exposing `transformer.wte/wpe/h/ln_f` and `lm_head`
                (presumably a Hugging Face GPT-2 LM-head model - not checked).
        """
        n_heads = self.cfg['n_heads']
        state = self.state_dict()
        hf = gpt2.transformer

        state['embed.W_E'] = hf.wte.weight.T
        state['pos_embed.W_pos'] = hf.wpe.weight.T

        for layer in range(self.cfg['n_layers']):
            hf_block = hf.h[layer]
            ln1_weight, ln1_bias = hf_block.ln_1.weight, hf_block.ln_1.bias

            # In GPT-2, q,k,v are produced by one big linear map, whose output
            # is concat([q, k, v]); split it in three, then split out the heads
            fused_parts = torch.tensor_split(hf_block.attn.c_attn.weight, 3, dim=1)
            W_Q, W_K, W_V = (einops.rearrange(part, 'm (i h)->i h m', i=n_heads)
                             for part in fused_parts)
            # Unpacking iterates the leading `qkv` axis, giving one bias per projection
            bias_Q, bias_K, bias_V = einops.rearrange(hf_block.attn.c_attn.bias,
                                                      '(qkv index head)->qkv index head',
                                                      qkv=3,
                                                      index=n_heads,
                                                      head=self.cfg['d_head'])

            # Fold in layer norm weights
            state[f"blocks.{layer}.attn.W_Q"] = W_Q * ln1_weight
            state[f"blocks.{layer}.attn.W_K"] = W_K * ln1_weight
            state[f"blocks.{layer}.attn.W_V"] = W_V * ln1_weight

            # Fold in layer norm biases (uses the unscaled W, as LN adds its bias after scaling)
            state[f'blocks.{layer}.attn.b_Q'] = W_Q @ ln1_bias + bias_Q
            state[f'blocks.{layer}.attn.b_K'] = W_K @ ln1_bias + bias_K
            state[f'blocks.{layer}.attn.b_V'] = W_V @ ln1_bias + bias_V

            state[f"blocks.{layer}.attn.W_O"] = einops.rearrange(hf_block.attn.c_proj.weight,
                                                                  '(i h) m->i m h',
                                                                  i=n_heads)
            state[f'blocks.{layer}.attn.b_O'] = hf_block.attn.c_proj.bias

            W_in = hf_block.mlp.c_fc.weight.T
            # Fold in layer norm weights
            state[f"blocks.{layer}.mlp.W_in"] = hf_block.ln_2.weight[None, :] * W_in
            # Fold in layer norm biases
            state[f"blocks.{layer}.mlp.b_in"] = hf_block.mlp.c_fc.bias+(W_in @ hf_block.ln_2.bias)
            state[f"blocks.{layer}.mlp.W_out"] = hf_block.mlp.c_proj.weight.T
            state[f"blocks.{layer}.mlp.b_out"] = hf_block.mlp.c_proj.bias

        W_U = gpt2.lm_head.weight
        # Fold in layer norm weights
        state['unembed.W_U'] = hf.ln_f.weight[None, :] * W_U
        # Fold in layer norm biases
        state['unembed.b_U'] = W_U @ hf.ln_f.bias
        self.load_state_dict(state)
    
    def load_neo_weights(self, neo):
        """Load the weights of a GPT-Neo style model into this model.

        Same scheme as the GPT-2 loader: start from this model's state dict,
        overwrite the entries with tensors derived from `neo`, fold each
        LayerNorm's scale and bias into the linear map that reads from it,
        then call `load_state_dict`. The query/key/value biases are set to
        `W @ ln_1.bias` only; no separate projection bias is added here.

        Args:
            neo: object exposing `transformer.wte/wpe/h/ln_f` and `lm_head`,
                with per-layer `attn.attention.{q,k,v,out}_proj` and
                `mlp.c_fc/c_proj` (presumably a Hugging Face GPT-Neo LM-head
                model - not checked).
        """
        n_heads = self.cfg['n_heads']
        state = self.state_dict()
        hf = neo.transformer

        state['embed.W_E'] = hf.wte.weight.T
        state['pos_embed.W_pos'] = hf.wpe.weight.T

        for layer in range(self.cfg['n_layers']):
            hf_block = hf.h[layer]
            hf_attn = hf_block.attn.attention
            ln1_weight, ln1_bias = hf_block.ln_1.weight, hf_block.ln_1.bias

            # Separate q/k/v projections here; only the head axis needs splitting out
            W_Q, W_K, W_V = (einops.rearrange(proj.weight, '(i h) m->i h m', i=n_heads)
                             for proj in (hf_attn.q_proj, hf_attn.k_proj, hf_attn.v_proj))

            # Fold in layer norm weights
            state[f"blocks.{layer}.attn.W_Q"] = W_Q * ln1_weight
            state[f"blocks.{layer}.attn.W_K"] = W_K * ln1_weight
            state[f"blocks.{layer}.attn.W_V"] = W_V * ln1_weight

            # Fold in layer norm biases
            state[f'blocks.{layer}.attn.b_Q'] = W_Q @ ln1_bias
            state[f'blocks.{layer}.attn.b_K'] = W_K @ ln1_bias
            state[f'blocks.{layer}.attn.b_V'] = W_V @ ln1_bias

            state[f"blocks.{layer}.attn.W_O"] = einops.rearrange(hf_attn.out_proj.weight,
                                                                  'm (i h)->i m h',
                                                                  i=n_heads)
            state[f'blocks.{layer}.attn.b_O'] = hf_attn.out_proj.bias

            W_in = hf_block.mlp.c_fc.weight
            # Fold in layer norm weights
            state[f"blocks.{layer}.mlp.W_in"] = hf_block.ln_2.weight[None, :] * W_in
            # Fold in layer norm biases
            state[f"blocks.{layer}.mlp.b_in"] = hf_block.mlp.c_fc.bias+(W_in @ hf_block.ln_2.bias)
            state[f"blocks.{layer}.mlp.W_out"] = hf_block.mlp.c_proj.weight
            state[f"blocks.{layer}.mlp.b_out"] = hf_block.mlp.c_proj.bias

        W_U = neo.lm_head.weight
        state['unembed.W_U'] = hf.ln_f.weight[None, :] * W_U
        state['unembed.b_U'] = W_U @ hf.ln_f.bias
        self.load_state_dict(state)
    
    def load_neox_weights(self, neox):
        """Placeholder for loading GPT-NeoX weights; not implemented yet.

        Args:
            neox: the source model (unused).

        Raises:
            NotImplementedError: always, with a message naming the
                unsupported model family.
        """
        raise NotImplementedError("Loading GPT-NeoX weights is not implemented yet")
    
    def load_gptj_weights(self, gptj):
        """Placeholder for loading GPT-J weights; not implemented yet.

        Args:
            gptj: the source model (unused).

        Raises:
            NotImplementedError: always, with a message naming the
                unsupported model family.
        """
        raise NotImplementedError("Loading GPT-J weights is not implemented yet")
    
    def load_opt_weights(self, opt):
        """Load the weights of an OPT style model into this model.

        Same scheme as the GPT-2 and GPT-Neo loaders: start from this model's
        state dict, overwrite the entries with tensors derived from `opt`,
        fold each LayerNorm's scale and bias into the linear map that reads
        from it, then call `load_state_dict`.

        Args:
            opt: object exposing `model.decoder` (with `embed_tokens`,
                `embed_positions`, `layers`, `final_layer_norm`) and
                `lm_head` (presumably a Hugging Face OPT causal-LM model -
                not checked).
        """
        n_heads = self.cfg['n_heads']
        d_head = self.cfg['d_head']
        decoder = opt.model.decoder
        sd = self.state_dict()

        sd['embed.W_E'] = decoder.embed_tokens.weight.T
        # The first two position entries are dropped
        # NOTE(review): presumably OPT's position table is offset by 2 - confirm
        sd['pos_embed.W_pos'] = decoder.embed_positions.weight.T[:, 2:]

        for l in range(self.cfg['n_layers']):
            layer = decoder.layers[l]
            attn = layer.self_attn
            w_ln_attn = layer.self_attn_layer_norm.weight
            b_ln = layer.self_attn_layer_norm.bias

            # The axis-length keyword must match the axis name used in the
            # pattern (`index`); passing it as `i=` makes einops raise.
            W_Q = einops.rearrange(attn.q_proj.weight, '(index d_head) d_model->index d_head d_model', index=n_heads)
            W_K = einops.rearrange(attn.k_proj.weight, '(index d_head) d_model->index d_head d_model', index=n_heads)
            W_V = einops.rearrange(attn.v_proj.weight, '(index d_head) d_model->index d_head d_model', index=n_heads)

            # Fold in layer norm weights
            sd[f"blocks.{l}.attn.W_Q"] = W_Q * w_ln_attn
            sd[f"blocks.{l}.attn.W_K"] = W_K * w_ln_attn
            sd[f"blocks.{l}.attn.W_V"] = W_V * w_ln_attn

            q_bias = einops.rearrange(attn.q_proj.bias, '(head_index d_head)->head_index d_head', head_index=n_heads, d_head=d_head)
            k_bias = einops.rearrange(attn.k_proj.bias, '(head_index d_head)->head_index d_head', head_index=n_heads, d_head=d_head)
            v_bias = einops.rearrange(attn.v_proj.bias, '(head_index d_head)->head_index d_head', head_index=n_heads, d_head=d_head)

            # Fold in layer norm biases on top of the projection's own bias
            sd[f'blocks.{l}.attn.b_Q'] = W_Q @ b_ln + q_bias
            sd[f'blocks.{l}.attn.b_K'] = W_K @ b_ln + k_bias
            sd[f'blocks.{l}.attn.b_V'] = W_V @ b_ln + v_bias

            W_O = einops.rearrange(attn.out_proj.weight, 'd_model (index d_head)->index d_model d_head', index=n_heads)
            sd[f"blocks.{l}.attn.W_O"] = W_O
            sd[f'blocks.{l}.attn.b_O'] = attn.out_proj.bias

            # `final_layer_norm` on a decoder layer is the LayerNorm read by the MLP (fc1)
            W_in = layer.fc1.weight
            W_out = layer.fc2.weight
            W_in_adj = layer.final_layer_norm.weight[None, :] * W_in
            sd[f"blocks.{l}.mlp.W_in"] = W_in_adj
            sd[f"blocks.{l}.mlp.b_in"] = layer.fc1.bias+(W_in @ layer.final_layer_norm.bias)
            sd[f"blocks.{l}.mlp.W_out"] = W_out
            sd[f"blocks.{l}.mlp.b_out"] = layer.fc2.bias

        W_U = opt.lm_head.weight
        # The decoder-level final_layer_norm is folded into the unembedding
        sd['unembed.W_U'] = decoder.final_layer_norm.weight[None, :] * W_U
        sd['unembed.b_U'] = W_U @ decoder.final_layer_norm.bias
        self.load_state_dict(sd)

    def load_bloom_weights(self, bloom):
        """Placeholder for loading BLOOM weights; not implemented yet.

        Args:
            bloom: the source model (unused).

        Raises:
            NotImplementedError: always, with a message naming the
                unsupported model family.
        """
        raise NotImplementedError("Loading BLOOM weights is not implemented yet")

#Discussion

##Model Simplifications


###Centering $W_U$

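The output of $W_U$ is a $d_{vocab}$ vector (or tensor with that as the final dimension) which is fed into a softmax. Softmax is unchanged by adding the same constant to every logit, so subtracting the mean over the vocabulary from $W_U$ (which shifts all logits at a position by the same amount) does not change the log probs or the loss.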

###LayerNorm Folding


LayerNorm is only applied at the start of a linear layer reading from the residual stream (eg query, key, value, mlp_in or unembed calculations)

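Each LayerNorm has the functional form $LN:\mathbb{R}^n\to\mathbb{R}^n$, 
$LN(x)=s(x) * w_{ln} + b_{ln}$, where $*$ is element-wise multiply and $s(x)=\frac{x-\bar{x}}{\sqrt{\mathrm{Var}(x)+\epsilon}}$ (the input centred and divided by its standard deviation, with a small $\epsilon$ for numerical stability), and $w_{ln},b_{ln}$ are both vectors in $\mathbb{R}^n$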

The linear layer has form $l:\mathbb{R}^n\to\mathbb{R}^m$, $l(y)=Wy+b$ where $W\in \mathbb{R}^{m\times n},b\in \mathbb{R}^m,y\in\mathbb{R}^n$

So $l(LN(x))=W(w_{ln} * s(x)+b_{ln})+b=(W * w_{ln})s(x)+(Wb_{ln}+b)=W_{eff}s(x)+b_{eff}$, where $W_{eff}=W * w_{ln}$ is $W$ with each row multiplied elementwise by $w_{ln}$ (equivalently $W\,\mathrm{diag}(w_{ln})$; showing that elementwise multiplication commutes like this is left as an exercise) and $b_{eff}=Wb_{ln}+b\in \mathbb{R}^m$.

From the perspective of interpretability, it's much nicer to interpret the folded layer $W_{eff},b_{eff}$ - fundamentally, this is the computation being done, and there's no reason to expect $W$ or $w_{ln}$ to be meaningful on their own. 


##Tips for Running Large Models

The current library does not support models on multiple GPUs, which makes running models that are too large for one GPU much harder - if you're interested in adding this functionality, I'd love to chat!

Tips for getting the most out of your GPU:
* Use the `torch.no_grad` context manager or `torch.set_grad_enabled(False)` global to turn off AutoGrad - AutoGrad automatically caches all activations to use when computing gradients, which can consume a lot of GPU memory
* Cast your weights to bfloat16 or FP16 before moving to the GPU - this halves the memory footprint
  * I expect there are a bunch of ways to use mixed precision to achieve this without sacrificing accuracy - though for the purposes of interpretability I expect it to not matter a ton.
* When caching activations, copy them to the CPU - this is *much* slower, but the CPU should have a lot more memory
* General rule of thumb - your GPU is more memory limited than runtime limited, and can fill up very quickly - it's often worth deleting activations that can be easily recomputed.
* It's fairly easy to get 16GB RAM GPUs (eg Colab Pro+). If you want to get more GPU RAM than this, you can rent A100 40GB RAM GPUs in a bunch of places (AWS, GCP, Paperspace, etc). After a cursory search, the only place I've found A100 80GB RAM GPUs for rent is [runpod.io](https://runpod.io)

##Model Details


* GPT-2 and GPT-Neo scale attention scores ($q\cdot k$) down by $\sqrt{d_{head}}$ before softmax, OPT does not
* OPT has ReLU activations, GPT-2 and GPT-Neo have GeLU
* In GPT-Neo, every other attention layer uses local attention (heads can only attend back a fixed number of tokens, compared to dense attention where heads can attend to any prior token)
* All models use learned positional embeddings

#Examples

##Setup

Load in GPT-2 small

In [None]:
# Name of the pretrained model to wrap. The trailing `#@param [...]` annotation is Colab form
# syntax listing the selectable names.
# NOTE(review): the list includes gpt-j and gpt-neox, whose config conversion and weight loaders
# in the class above raise NotImplementedError.
model_name = 'gpt2' #@param ['gpt2', 'gpt2-medium', 'gpt2-large', 'gpt2-xl', 'facebook/opt-125m', 'facebook/opt-1.3b', 'facebook/opt-2.7b', 'facebook/opt-6.7b', 'facebook/opt-13b', 'facebook/opt-30b', 'facebook/opt-66b', 'EleutherAI/gpt-neo-125M', 'EleutherAI/gpt-neo-1.3B', 'EleutherAI/gpt-neo-2.7B', 'EleutherAI/gpt-j-6B', 'EleutherAI/gpt-neox-20b']
model = EasyTransformer(model_name)
# Use the GPU when one is available; otherwise the model is left where it was created
if torch.cuda.is_available():
    model.to('cuda')

Downloading config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/523M [00:00<?, ?B/s]

Downloading vocab.json:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading tokenizer.json:   0%|          | 0.00/1.29M [00:00<?, ?B/s]

Create some reference text to run the models on.

In [None]:
# Two reference prompts, tokenized once so later cells can reuse them
prompt = 'Interpretability is great'
tokens = model.to_tokens(prompt)
prompt_2 = 'AI Alignment is great'
tokens_2 = model.to_tokens(prompt_2)
def show_tokens(tokens):
    """Print the tokens as text, separated by |

    Args:
        tokens: either a string (tokenized with `model.to_tokens` first) or
            an array of token ids as returned by `model.to_tokens`.
    """
    if isinstance(tokens, str):
        # If we input text, tokenize first
        tokens = model.to_tokens(tokens)
    token_ids = tokens.squeeze()
    if token_ids.ndim == 0:
        # A single-token input squeezes down to a 0-d tensor, which cannot be
        # iterated over - turn it back into a length-1 sequence
        token_ids = token_ids.reshape(1)
    text_tokens = [model.tokenizer.decode(t) for t in token_ids]
    print('|'.join(text_tokens))
show_tokens(tokens)
show_tokens(tokens_2)

Inter|pret|ability| is| great
AI| Al|ignment| is| great


In [None]:
# Baseline run: reset the hooks first so that nothing registered by an earlier cell
# should alter these logits, which later cells compare against
model.reset_hooks()
original_logits = model(tokens)
print('Top corner of logits', get_corner(original_logits, 4), sep='\n')

Top corner of logits
tensor([[[-5.0197, -4.0007, -6.4540, -6.6005],
         [-4.1477, -2.2966, -7.4325, -6.9754],
         [-2.7587,  1.3903, -4.3042, -6.5975],
         [-6.1653, -5.0678, -9.1324, -9.5672]]], device='cuda:0',
       grad_fn=<SliceBackward0>)


In [None]:
print('Reference: Hyperparameters for the model')
# model.cfg is a plain dict of hyperparameter name -> value
for hyper_param, value in model.cfg.items():
    print(hyper_param, value)

Reference: Hyperparameters for the model
d_model 768
d_head 64
n_heads 12
d_mlp 3072
n_layers 12
n_ctx 1024
eps 1e-05
d_vocab 50257
act_fn gelu_new
use_attn_scale True
use_local_attn False
model_name gpt2
model_type gpt2
use_attn_result False


##Using the model

The model can be given either text or tokens as an input (text is automatically converted to a `batch_size=1` batch of tokens)

In [None]:
# The same forward pass called two ways: with the pre-computed tokens ...
logits_tokens = model(tokens)
# ... and with the raw prompt string, which the model tokenizes itself
logits_text = model(prompt)

The model gives the same log_probs as the original Hugging Face model 

Though *not* the same logits, as we remove a constant offset from $W_U$

In [None]:
# Reference implementation to compare against; it is called on `tokens` without being
# moved to the GPU, so the easy model's logits are brought back to the CPU as well
original_model = AutoModelForCausalLM.from_pretrained(model_name)
easy_logits = model(tokens).cpu()
original_model_logits = original_model(tokens).logits

easy_log_probs = F.log_softmax(easy_logits, dim=-1)
original_model_log_probs = F.log_softmax(original_model_logits, dim=-1)

# Report the fraction of entries that torch.isclose considers equal, first for the
# log probs (expected to match) and then for the raw logits (expected not to)
comparisons = (
    ('Fraction of log probs the same between easy model and original model:',
     original_model_log_probs, easy_log_probs),
    ('Fraction of logits the same between easy model and original model:',
     original_model_logits, easy_logits),
)
for message, reference, ours in comparisons:
    print(message)
    print(torch.isclose(reference, ours).sum()/ours.numel())

Fraction of log probs the same between easy model and original model:
tensor(1.0000)
Fraction of logits the same between easy model and original model:
tensor(0.)


##Basic Examples

Print the shapes of all activations

**Note:** This cell is a good reference for creating hooks - it's extremely useful to know the shapes of different activations as accessible by each hook!

By convention, each activation is batch x position x ... (where the final dimension(s) is d_model, (head_index x d_head) or d_mlp). The exceptions are hook_attn_scores and hook_attn (attention scores and patterns), which have shape batch x head_index x query_pos x key_pos, and hook_pos_embed, which has no batch dimension (position x d_model)

**Reference:**
`batch_size=4
n_ctx=50
d_head=64
d_model=768
d_mlp=3072
n_heads=12
n_layers=12`

In [None]:
def all_hooks_fn(name):
    """Hook-name filter that accepts every hook."""
    return True

def print_shape(tensor, hook):
    """Forward hook that prints the hook's name followed by the activation's shape."""
    print(f'Activation at hook {hook.name} has shape:', tensor.shape, sep='\n')

# A batch of 4 random sequences of 50 token ids drawn from [1000, 10000)
random_tokens = torch.randint(1000, 10000, (4, 50))
logits = model.run_with_hooks(random_tokens, fwd_hooks=[(all_hooks_fn, print_shape)])

Activation at hook hook_embed has shape:
torch.Size([4, 50, 768])
Activation at hook hook_pos_embed has shape:
torch.Size([50, 768])
Activation at hook blocks.0.hook_resid_pre has shape:
torch.Size([4, 50, 768])
Activation at hook blocks.0.ln1.hook_scale has shape:
torch.Size([4, 50, 1])
Activation at hook blocks.0.attn.hook_q has shape:
torch.Size([4, 50, 12, 64])
Activation at hook blocks.0.attn.hook_k has shape:
torch.Size([4, 50, 12, 64])
Activation at hook blocks.0.attn.hook_v has shape:
torch.Size([4, 50, 12, 64])
Activation at hook blocks.0.attn.hook_attn_scores has shape:
torch.Size([4, 12, 50, 50])
Activation at hook blocks.0.attn.hook_attn has shape:
torch.Size([4, 12, 50, 50])
Activation at hook blocks.0.attn.hook_z has shape:
torch.Size([4, 50, 12, 64])
Activation at hook blocks.0.hook_attn_out has shape:
torch.Size([4, 50, 768])
Activation at hook blocks.0.hook_resid_mid has shape:
torch.Size([4, 50, 768])
Activation at hook blocks.0.ln2.hook_scale has shape:
torch.Size([4

Print the top corner of all activations

**Note:** This is useful to do as a sanity check when debugging a model, to quickly and roughly compare the new activations to the original activations (without looking at the full enormous tensors)

In [None]:
def print_corner(tensor, hook):
    """Forward hook that prints the hook's name followed by the corner of its activation."""
    print(hook.name, get_corner(tensor), sep='\n')

# all_hooks_fn accepts every hook name, so every activation is printed
logits = model.run_with_hooks(tokens, fwd_hooks=[(all_hooks_fn, print_corner)])

hook_embed
tensor([[[ 0.1600, -0.1444],
         [-0.0406, -0.2098]]], device='cuda:0', grad_fn=<SliceBackward0>)
hook_pos_embed
tensor([[-0.0134, -0.1920],
        [ 0.0250, -0.0528]], device='cuda:0', grad_fn=<SliceBackward0>)
blocks.0.hook_resid_pre
tensor([[[ 0.1466, -0.3363],
         [-0.0156, -0.2626]]], device='cuda:0', grad_fn=<SliceBackward0>)
blocks.0.ln1.hook_scale
tensor([[[0.3703],
         [0.2421]]], device='cuda:0', grad_fn=<SliceBackward0>)
blocks.0.attn.hook_q
tensor([[[[-0.6830,  0.1875],
          [ 0.5510,  0.1701]],

         [[ 0.4226,  0.8636],
          [ 0.1361, -0.6476]]]], device='cuda:0', grad_fn=<SliceBackward0>)
blocks.0.attn.hook_k
tensor([[[[-1.1447,  2.1864],
          [ 1.4646,  0.4051]],

         [[-1.5435,  2.9672],
          [ 0.9403, -1.4134]]]], device='cuda:0', grad_fn=<SliceBackward0>)
blocks.0.attn.hook_v
tensor([[[[-0.0110,  0.0460],
          [ 0.4635,  0.0313]],

         [[ 0.1250, -0.3105],
          [ 0.3096,  0.2382]]]], device='cuda:

Cache all activations


In [None]:
# cache_all is handed an empty dict; after the forward pass it holds one entry per hook name
cache = {}
model.reset_hooks()
model.cache_all(cache)
logits = model(tokens)
for name, activation in cache.items():
    print(name, activation.shape)
# Reset again so the caching hooks do not stay attached for later cells
model.reset_hooks()

hook_embed torch.Size([1, 5, 768])
hook_pos_embed torch.Size([5, 768])
blocks.0.hook_resid_pre torch.Size([1, 5, 768])
blocks.0.ln1.hook_scale torch.Size([1, 5, 1])
blocks.0.attn.hook_q torch.Size([1, 5, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 5, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 5, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 5, 5])
blocks.0.attn.hook_attn torch.Size([1, 12, 5, 5])
blocks.0.attn.hook_z torch.Size([1, 5, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 5, 768])
blocks.0.hook_resid_mid torch.Size([1, 5, 768])
blocks.0.ln2.hook_scale torch.Size([1, 5, 1])
blocks.0.mlp.hook_pre torch.Size([1, 5, 3072])
blocks.0.mlp.hook_post torch.Size([1, 5, 3072])
blocks.0.hook_mlp_out torch.Size([1, 5, 768])
blocks.0.hook_resid_post torch.Size([1, 5, 768])
blocks.1.hook_resid_pre torch.Size([1, 5, 768])
blocks.1.ln1.hook_scale torch.Size([1, 5, 1])
blocks.1.attn.hook_q torch.Size([1, 5, 12, 64])
blocks.1.attn.hook_k torch.Size([1, 5, 12, 64])
blocks.1.attn.h

To save GPU memory, we can cache activations to the CPU - note that this is much slower though, since it requires copying.

In [None]:
random_tokens = torch.randint(1000, 10000, (1, 300))
cache = {}
model.reset_hooks()
model.cache_all(cache, device='cpu')
print('Run time when copying to the CPU')
%timeit logits = model(random_tokens)
model.reset_hooks()
model.cache_all(cache, device='cuda')
print('Run time when just caching on GPU')
%timeit logits = model(random_tokens)

Run time when copying to the CPU
1 loop, best of 5: 136 ms per loop
Run time when just caching on GPU
10 loops, best of 5: 31.3 ms per loop


##Editing Activations
**To change an activation, add a hook to that HookPoint which returns the new activation**

Pruning attention heads

In [None]:
# Example - prune heads 0, 3 and 7 from layer 3 and heads 8 and 9 from layer 7
layer = 3
head_indices = torch.tensor([0, 3, 7])
layer_2 = 7
head_indices_2 = torch.tensor([8, 9])

def prune_heads(z, hook, heads):
    """Zero the given head indices of a z activation in place and return it."""
    # The shape of the z tensor is batch x pos x head_index x d_head
    z[:, :, heads, :] = 0.
    return z

# One hook per layer, each bound to that layer's head indices
prune_fn_1 = partial(prune_heads, heads=head_indices)
prune_fn_2 = partial(prune_heads, heads=head_indices_2)
pruning_hooks = [(f'blocks.{layer}.attn.hook_z', prune_fn_1),
                 (f'blocks.{layer_2}.attn.hook_z', prune_fn_2)]
logits = model.run_with_hooks(tokens, fwd_hooks=pruning_hooks)

Restrict all attention heads to only attend to the current and previous token.

**Validation:** The logits for the first 2 positions are the same, the logits for pos 3 are different

In [None]:
# Presumably clears any hooks still attached from earlier cells, so that only the
# restriction hook below affects this run
model.reset_hooks()
def filter_hook_attn(name):
    """Return True when the last dot-separated component of `name` is 'hook_attn'."""
    return name.rsplit('.', 1)[-1] == 'hook_attn'
def restrict_attn(attn, hook):
    """Zero every attention entry except the current and previous key position.

    An entry is kept when key_pos > query_pos - 2, i.e. nothing more than one
    position behind the query; all other entries are replaced by zero. Key
    positions ahead of the query also pass this test and are left untouched
    (presumably already zero from the model's causal mask).

    The mask and the zero are created on `attn`'s own device, so the hook
    works on CPU as well as on GPU.

    Args:
        attn: attention pattern; the last two axes are query_pos x key_pos.
        hook: the hook point (unused).

    Returns:
        A tensor shaped like `attn` with the masked-out entries set to zero.
    """
    # Attn has shape batch x head_index x query_pos x key_pos
    n_ctx = attn.size(-2)
    positions = torch.arange(n_ctx, device=attn.device)
    key_pos = positions[None, :]
    query_pos = positions[:, None]
    mask = key_pos > (query_pos - 2)
    # new_zeros(()) is a 0-d zero matching attn's dtype and device
    return torch.where(mask, attn, attn.new_zeros(()))
# Apply the restriction to every hook accepted by filter_hook_attn, then show the
# restricted and the baseline logits next to each other
logits = model.run_with_hooks(tokens, fwd_hooks=[(filter_hook_attn, restrict_attn)])
for title, shown_logits in (('New logits', logits), ('Original logits', original_logits)):
    print(title)
    print(get_corner(shown_logits, 3))

New logits
tensor([[[-5.0197, -4.0007, -6.4540],
         [-4.1477, -2.2966, -7.4325],
         [-2.6209,  2.9651, -4.4607]]], device='cuda:0',
       grad_fn=<SliceBackward0>)
Original logits
tensor([[[-5.0197, -4.0007, -6.4540],
         [-4.1477, -2.2966, -7.4325],
         [-2.7587,  1.3903, -4.3042]]], device='cuda:0',
       grad_fn=<SliceBackward0>)


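Freezing attention patterns - here we do two runs of the model. First on the original text, caching attn patterns, and secondly on the new text, loading the cached patterns. This only works because both prompts tokenize to the same number of tokens (5 here), so the cached patterns have the right shape.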


In [None]:
# Attention patterns recorded on the first run, keyed by hook name
attn_cache = dict()

def cache_attn(attn, hook):
    """Record this hook's attention pattern under the hook's name (returns nothing)."""
    attn_cache[hook.name] = attn

def freeze_attn(attn, hook):
    """Return the pattern recorded for this hook in place of the freshly computed one."""
    return attn_cache[hook.name]

# First run: record the patterns on the original text
caching_hooks = [(filter_hook_attn, cache_attn)]
logits = model.run_with_hooks(tokens, fwd_hooks=caching_hooks)

# Second run: the new text, with the recorded patterns substituted in
freezing_hooks = [(filter_hook_attn, freeze_attn)]
logits_2 = model.run_with_hooks(tokens_2, fwd_hooks=freezing_hooks)

##Using Hook Contexts

**Each hook point has a dictionary `hook.ctx` that can be used to store information between runs** - this is useful for keeping running totals, etc 

A running total of times a neuron activation was positive


In [None]:
# We focus on neuron 20 in layer 7
# clear_contexts=True (as in the next example) so that re-running this cell presumably
# starts the running total from zero rather than continuing from the previous run
model.reset_hooks(clear_contexts=True)
animal_texts = ['The dog was green', 'The cat was blue', 'The squid was magenta', 'The blobfish was grey']
layer = 7
neuron_index = 20
def running_total_hook(neuron_acts, hook):
    """Add the number of positions where the chosen neuron is positive to hook.ctx['total'].

    Only batch element 0 is examined. The total lives in hook.ctx, so it
    carries over from one model run to the next. Returns nothing.
    """
    if 'total' not in hook.ctx:
        hook.ctx['total']=0
    print('Neuron acts:', neuron_acts[0, :, neuron_index])
    hook.ctx['total']+=(neuron_acts[0, :, neuron_index]>0).sum().item()
    print('Running total:', hook.ctx['total'])

for animal_text in animal_texts:
    show_tokens(animal_text)
    model.run_with_hooks(animal_text, fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', running_total_hook)])

The| dog| was| green
Neuron acts: tensor([-0.0099, -0.1396,  0.6828, -0.0826], device='cuda:0',
       grad_fn=<SelectBackward0>)
Running total: 1
The| cat| was| blue
Neuron acts: tensor([-0.0099, -0.1045,  0.6142, -0.0625], device='cuda:0',
       grad_fn=<SelectBackward0>)
Running total: 2
The| squid| was| mag|enta
Neuron acts: tensor([-0.0099,  0.7520,  0.7486, -0.0986, -0.0460], device='cuda:0',
       grad_fn=<SelectBackward0>)
Running total: 4
The| blob|fish| was| grey
Neuron acts: tensor([-0.0099,  0.0554,  0.4716,  0.7331, -0.0160], device='cuda:0',
       grad_fn=<SelectBackward0>)
Running total: 7


Finding the dataset example that most activates a given neuron


In [None]:
# We focus on neuron 13 in layer 5
model.reset_hooks(clear_contexts=True)
animal_texts = ['The dog was green', 'The cat was blue', 'The squid was magenta', 'The blobfish was grey']
layer = 5
neuron_index = 13
def best_act_hook(neuron_acts, hook, text):
    """Track in hook.ctx the largest activation of the chosen neuron and the text that produced it.

    Only batch element 0 is examined. `hook.ctx['best']` starts at -1e3 and
    is replaced, together with `hook.ctx['text']`, whenever this run's
    maximum over positions is strictly larger.
    """
    acts = neuron_acts[0, :, neuron_index]
    best_so_far = hook.ctx.setdefault('best', -1e3)
    print('Neuron acts:', acts)
    top_act = acts.max().item()
    if best_so_far < top_act:
        print(f'Updating best act from {best_so_far} to {top_act}')
        hook.ctx['best'] = top_act
        hook.ctx['text'] = text

hook_name = f'blocks.{layer}.mlp.hook_post'
for animal_text in animal_texts:
    show_tokens(animal_text)
    # Use partial to give the hook access to the relevant text
    model.run_with_hooks(animal_text, fwd_hooks=[(hook_name, partial(best_act_hook, text=animal_text))])
print()
print('Maximally activating dataset example:', model.hook_dict[hook_name].ctx['text'])
model.reset_hooks(clear_contexts=True)

The| dog| was| green
Neuron acts: tensor([-0.0074, -0.1690,  0.0724,  0.0520], device='cuda:0',
       grad_fn=<SelectBackward0>)
Updating best act from -1000.0 to 0.07240154594182968
The| cat| was| blue
Neuron acts: tensor([-0.0074, -0.1681,  0.1947,  0.0884], device='cuda:0',
       grad_fn=<SelectBackward0>)
Updating best act from 0.07240154594182968 to 0.19472447037696838
The| squid| was| mag|enta
Neuron acts: tensor([-0.0074, -0.1546,  0.0558, -0.1591, -0.1391], device='cuda:0',
       grad_fn=<SelectBackward0>)
The| blob|fish| was| grey
Neuron acts: tensor([-0.0074, -0.1700,  0.0445,  0.1009, -0.0254], device='cuda:0',
       grad_fn=<SelectBackward0>)

Maximally activating dataset example: The cat was blue


##Fancier Examples

Looking for heads that mostly attend to the previous token


In [None]:
long_text = 'Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.'
print('Long text:', long_text)
# We first cache attention patterns
attn_cache = {}
def cache_attn(attn, hook):
    """Record this hook's attention pattern under the hook's name."""
    attn_cache[hook.name]=attn
logits = model.run_with_hooks(long_text, fwd_hooks=[(filter_hook_attn, cache_attn)])

# We then go through the cache and find the average attention paid to previous tokens
n_layers, n_heads = model.cfg['n_layers'], model.cfg['n_heads']
prev_token_scores = np.zeros((n_layers, n_heads))
for layer in range(n_layers):
    # Batch element 0 of this layer's pattern: head_index x query_pos x key_pos
    layer_attn = attn_cache[f"blocks.{layer}.attn.hook_attn"][0]
    for head in range(n_heads):
        # diag(-1) is the first sub-diagonal: each query's attention to the key one position before it
        prev_token_scores[layer, head]=layer_attn[head].diag(-1).mean().item()

px.imshow(prev_token_scores,
          x=[f'Head {hi}' for hi in range(n_heads)],
          y=[f'Layer {i}' for i in range(n_layers)],
          title='Prev Token Scores',
          color_continuous_scale='Blues')

Long text: Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum.


[ROME style](https://rome.baulab.info/) patching for causal tracing - we have two runs with two different prompts and different answers, eg "Steve Jobs founded" -> " Apple" and "Bill Gates founded" -> " Microsoft". We patch parts of the layer outputs or residual stream from specific tokens and positions and see which patches significantly shift the answer from " Apple" to " Microsoft"

In [None]:
prompt_1 = 'Bill Gates founded'
response_1 = ' Microsoft'
logit_index_1 = model.to_tokens(response_1)[0][-1]
show_tokens(prompt_1)
prompt_2 = 'Steve Jobs founded'
response_2 = ' Apple'
logit_index_2 = model.to_tokens(response_2)[0][-1]
show_tokens(prompt_2)
# (answer text, index of its last token in the vocabulary) for both prompts
answers = ((response_1, logit_index_1), (response_2, logit_index_2))

# Cache every activation of the Bill Gates run; these are the values patched in below
model.reset_hooks()
uncorrupted_cache = {}
model.cache_all(uncorrupted_cache)
logits_1 = model(prompt_1)
model.reset_hooks()

# Unpatched Steve Jobs run, for reference
uncorrupted_logits = model(prompt_2)
uncorrupted_log_probs = F.log_softmax(uncorrupted_logits, dim=-1)
for response, logit_index in answers:
    print('Uncorrupted log prob for', response, uncorrupted_log_probs[0, -1, logit_index].item())

# Patch the residual stream from the Bill Gates run to the Steve Jobs run
# at the Jobs/Gates token, at the start of layer 7
layer = 7
position = 1

def patch_resid_pre(resid_pre, hook):
    """Overwrite, in place, the residual stream at `position` with the cached value for this hook."""
    # Move things on the Jobs/Gates token
    resid_pre[:, position] = uncorrupted_cache[hook.name][:, position]
    return resid_pre

corrupted_logits = model.run_with_hooks(prompt_2,
                    fwd_hooks=[(f'blocks.{layer}.hook_resid_pre', patch_resid_pre)])
corrupted_log_probs = F.log_softmax(corrupted_logits, dim=-1)
for response, logit_index in answers:
    print('Corrupted (Residual) log prob for', response, corrupted_log_probs[0, -1, logit_index].item())

Bill| Gates| founded
Steve| Jobs| founded
Uncorrupted log prob for  Microsoft -2.890328884124756
Uncorrupted log prob for  Apple -0.5577379465103149
Corrupted (Residual) log prob for  Microsoft -0.5379166603088379
Corrupted (Residual) log prob for  Apple -4.753259181976318


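We can also patch the MLP activations (`hook_post`) of layers 0 to 6 (`layer_end=7` is exclusive) - this time, rather than giving a hook name, we give a Boolean function that filters for the names of those hooks. Note that, unlike the residual stream patch above, this replaces the activations at every position, not just the Gates/Jobs token.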

In [None]:
# Half-open range of layers to patch: layer_start <= layer < layer_end
layer_start = 0
layer_end = 7

def patch_mlp_post(mlp_post, hook):
    """Replace the whole activation with the value cached for this hook on the other run."""
    return uncorrupted_cache[hook.name]

def filter_middle_mlps(name):
    """Accept names ending in 'hook_post' whose second component is a layer in the range above."""
    parts = name.split('.')
    if parts[-1] != 'hook_post':
        return False
    return layer_start <= int(parts[1]) < layer_end

corrupted_logits = model.run_with_hooks(prompt_2,
                    fwd_hooks=[(filter_middle_mlps, patch_mlp_post)])
corrupted_log_probs = F.log_softmax(corrupted_logits, dim=-1)
for response, logit_index in ((response_1, logit_index_1), (response_2, logit_index_2)):
    print('Corrupted (MLP) log prob for', response, corrupted_log_probs[0, -1, logit_index].item())

Corrupted (MLP) log prob for  Microsoft -1.2620229721069336
Corrupted (MLP) log prob for  Apple -3.783123016357422


Looking for [induction heads](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html), by feeding in a random sequence of tokens repeated twice and looking for heads that attend from a second copy of a token to the token just after the first copy.

In [None]:
seq_len = 100
# 4 random sequences of seq_len token ids, each followed by an exact copy of itself
rand_tokens = torch.randint(1000, 10000, (4, seq_len))
rand_tokens_repeat = einops.repeat(rand_tokens, 'batch pos -> batch (2 pos)')
if torch.cuda.is_available():
    rand_tokens_repeat = rand_tokens_repeat.cuda()

n_layers, n_heads = model.cfg['n_layers'], model.cfg['n_heads']
induction_scores_array = np.zeros((n_layers, n_heads))
def calc_induction_score(attn_pattern, hook):
    """Store each head's mean attention along the induction stripe in induction_scores_array."""
    # Pattern has shape [batch, index, query_pos, key_pos]
    # Offset 1-seq_len selects key_pos = query_pos - (seq_len - 1): from the second copy
    # of a token to the token just after its first copy
    stripe = attn_pattern.diagonal(1-seq_len, dim1=-2, dim2=-1)
    per_head = einops.reduce(stripe, 'batch index pos -> index', 'mean')
    # Store the scores in a common array
    induction_scores_array[hook.layer()] = per_head.detach().cpu().numpy()

def filter_attn_hooks(hook_name):
    """Return True when the last dot-separated component of `hook_name` is 'hook_attn'."""
    return hook_name.rsplit('.', 1)[-1]=='hook_attn'

induction_logits = model.run_with_hooks(rand_tokens_repeat, fwd_hooks=[(filter_attn_hooks, calc_induction_score)])
px.imshow(induction_scores_array, labels={'y':'Layer', 'x':'Head'}, color_continuous_scale='Blues')

**Validation:** We can ablate the top few heads by this metric, and show that performance goes down substantially

In [None]:
induction_logits = model(rand_tokens_repeat)
induction_log_probs = F.log_softmax(induction_logits, dim=-1)
# Log prob assigned to the token that actually comes next, at every position but the last
induction_pred_log_probs = torch.gather(induction_log_probs[:, :-1], -1, rand_tokens_repeat[:, 1:, None])[..., 0]
# The loss is the *negative* mean log prob of the correct token; only predictions from
# position seq_len onwards (inside the repeated copy) are scored
print('Original loss on repeated sequence:', -induction_pred_log_probs[:, seq_len:].mean())

# Mask out the heads with a high induction score
attn_head_mask = induction_scores_array>0.8

def prune_attn_heads(value, hook):
    """Zero, in place, the value vectors of the heads masked for this hook's layer."""
    # Value has shape [batch, pos, index, d_head]
    mask = attn_head_mask[hook.layer()]
    value[:, :, mask] = 0.
    return value

def filter_value_hooks(name):
    """Return True when the last dot-separated component of `name` is 'hook_v'."""
    return name.split('.')[-1]=='hook_v'

ablated_logits = model.run_with_hooks(rand_tokens_repeat, fwd_hooks=[(filter_value_hooks, prune_attn_heads)])
ablated_log_probs = F.log_softmax(ablated_logits, dim=-1)
ablated_pred_log_probs = torch.gather(ablated_log_probs[:, :-1], -1, rand_tokens_repeat[:, 1:, None])[..., 0]
print('Loss on repeated sequence without induction heads:', -ablated_pred_log_probs[:, seq_len:].mean())

px.imshow(attn_head_mask, labels={'y':'Layer', 'x':'Head'}, color_continuous_scale='Blues', title='Mask').show()

Original loss on repeated sequence: tensor(-0.1070, device='cuda:0', grad_fn=<MeanBackward0>)
Loss on repeated sequence without induction heads: tensor(-6.2134, device='cuda:0', grad_fn=<MeanBackward0>)
