In [1]:
%load_ext autoreload
%autoreload 2
from transformer_lens import HookedTransformer, ActivationCache
import torch
import numpy as np
import pandas as pd
import datasets
import transformers
import pickle

from tasks import PileTask, OWTTask, InductionTask, GreaterThanTask
from tasks.ioi.IOITask import IOITask, IOITask_NPO, IOITask_Uniform
from tasks.induction.InductionTask import InductionTask, InductionTask_NPO, InductionTask_Uniform
from tasks.facts.SportsTask import SportsTask, SportsTask_NPO, SportsTask_Uniform

from tqdm.auto import tqdm

In [2]:
from transformers import GPT2Tokenizer, GPTNeoXTokenizerFast, AutoModelForCausalLM, AutoTokenizer
model_type = "gemma-7b"

if model_type == "pythia":
    reference_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-2.8B")#.cuda()
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8B")
    tokenizer.pad_token_id = tokenizer.eos_token_id

elif model_type == "gemma-7b":
    reference_model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.bfloat16)#.cuda()
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [3]:
reference_model.model.layers[0].self_attn.q_proj.weight.shape

torch.Size([4096, 3072])

In [4]:
tl_model = HookedTransformer.from_pretrained(
    'google/gemma-7b',
    tokenizer=tokenizer,
    device='cuda',
    default_padding_side="right",
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
    dtype=torch.bfloat16
)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loaded pretrained model google/gemma-7b into HookedTransformer


In [5]:
tl_model.blocks[0].attn.W_Q.shape

torch.Size([16, 3072, 256])

In [6]:
tl_model.get_caching_hooks()

({},
 [('hook_embed',
   functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7f80ff653b50>, is_backward=False)),
  ('blocks.0.ln1.hook_scale',
   functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7f80ff653b50>, is_backward=False)),
  ('blocks.0.ln1.hook_normalized',
   functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7f80ff653b50>, is_backward=False)),
  ('blocks.0.ln2.hook_scale',
   functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7f80ff653b50>, is_backward=False)),
  ('blocks.0.ln2.hook_normalized',
   functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7f80ff653b50>, is_backward=False)),
  ('blocks.0.attn.hook_k',
   functools.partial(<function HookedRootModule.get_caching_hooks.<locals>.save_hook at 0x7f80ff653b50>, is_backward=False)),
  ('blocks.0.attn.hook_q',
   functools.partial(<function HookedR

In [7]:
tl_model.blocks[5].attn

GroupedQueryAttention(
  (hook_k): HookPoint()
  (hook_q): HookPoint()
  (hook_v): HookPoint()
  (hook_z): HookPoint()
  (hook_attn_scores): HookPoint()
  (hook_pattern): HookPoint()
  (hook_result): HookPoint()
  (hook_rot_k): HookPoint()
  (hook_rot_q): HookPoint()
)

In [8]:
tl_model.blocks[5].attn.W_Q

Parameter containing:
tensor([[[-5.3101e-03, -1.3580e-03,  4.0054e-04,  ...,  1.6113e-02,
           4.5166e-03, -8.7891e-03],
         [-3.5706e-03, -2.3651e-03, -4.5166e-03,  ..., -7.8735e-03,
           5.9814e-03,  1.9150e-03],
         [ 3.5477e-04,  4.2725e-03, -1.8158e-03,  ...,  3.7079e-03,
           2.0142e-02,  8.4839e-03],
         ...,
         [ 4.1199e-03,  1.4572e-03, -2.0905e-03,  ...,  8.4229e-03,
          -8.3008e-03, -8.4229e-03],
         [-1.8692e-04, -5.7220e-04,  4.3030e-03,  ..., -3.4180e-02,
          -1.2939e-02, -1.0864e-02],
         [ 1.7929e-03, -6.2561e-03, -2.7924e-03,  ..., -4.6692e-03,
          -3.2227e-02, -2.6489e-02]],

        [[-4.4556e-03, -2.5868e-05, -6.1798e-04,  ..., -6.9427e-04,
           1.1780e-02,  1.0681e-02],
         [ 5.6076e-04,  7.0190e-04, -3.1891e-03,  ..., -5.1270e-03,
           1.2756e-02, -3.5706e-03],
         [ 6.7139e-03, -5.8289e-03, -1.6174e-03,  ..., -3.0212e-03,
           5.2185e-03,  1.0620e-02],
         ...,
   

In [9]:
print(tl_model.blocks[27].attn.W_K.shape)
print(tl_model.blocks[27].attn.W_O.shape)
print(tl_model.blocks[27].mlp.W_in.shape)
print(tl_model.blocks[27].mlp.W_out.shape)

torch.Size([16, 3072, 256])
torch.Size([16, 256, 3072])
torch.Size([3072, 24576])
torch.Size([24576, 3072])


In [10]:
import pickle
with open("models/google_gemma-7b_sports_baseball_ap_graph.pkl", "rb") as f:
    ap_graph = pickle.load(f)
for component in ap_graph:
    if "m" in component:
        print(f"{component}: {ap_graph[component]}")

m0: -0.055645283311605453
m1: 0.0552063025534153
m2: -0.11303359270095825
m3: 0.024208657443523407
m4: -0.0113525390625
m5: -0.022718576714396477
m6: -0.007042518351227045
m7: -0.021432731300592422
m8: -0.006188026163727045
m9: -0.0019231943879276514
m10: -0.03130634129047394
m11: -0.0708770751953125
m12: -0.04879526048898697
m13: 0.04687969759106636
m14: 0.035638369619846344
m15: 0.02344219572842121
m16: -0.018733099102973938
m17: -0.06601186841726303
m18: -0.10868014395236969
m19: -0.0049954927526414394
m20: -0.08375901728868484
m21: -0.25811299681663513
m22: -0.11271785199642181
m23: -0.23159556090831757
m24: -0.09144005924463272
m25: -0.08293269574642181
m26: 0.18458910286426544
m27: 0.5513822436332703


## Approach 1: apply masks before doing any calculations
this failed

In [30]:
from torch import nn

def make_partly_differentiable_mask(W, unfrozen_heads, device="cuda"):
    """
    W is Parameter of shape (n_heads, ...). Returns baseline and frozen (both only 1d arrays of (n_heads,)), and forward pass should be W_baseline.float() + W_frozen.float() * W 
    """
    W_baseline = torch.nn.Parameter(torch.zeros(W.shape[0], dtype=torch.bool), requires_grad=False).to(device)

    # unsqueeze to broadcast efficiently, until W_baseline has same shape as W
    while len(W_baseline.shape) < len(W.shape):
        W_baseline = W_baseline.unsqueeze(-1)
    
    W_baseline[unfrozen_heads] = True
    # W_baseline = ~W_frozen
    W_frozen = torch.nn.Parameter(~W_baseline, requires_grad=False)
    # convert into float
    return W_frozen.float(), W_baseline.float()

class WeightMaskedTransformer(nn.Module):
    def __init__(self, tl_transformer, weight_mask_attn_dict=None, weight_mask_mlp_dict=None, torch_dtype=torch.bfloat16):
        """
        weight_mask_attn_dict: {layer: {"W_Q": unfrozen_heads, "W_K": unfrozen_heads, "W_V": unfrozen_heads, "W_O": unfrozen_heads}} (frozen_heads is shape (n_heads,) of bools). If none, train mask over all heads
        weight_mask_mlp_dict: {layer: bool}. If none, train mask over all mlps

        """
        super().__init__()
        self.torch_dtype = torch_dtype
        # tl_transformer should be a HookedTransformer
        self.tl_transformer = tl_transformer
        # turn off gradients for tl_transformer
        # for param in self.tl_transformer.parameters():
        #     param.requires_grad = False

        self.weight_mask_attn_dict = weight_mask_attn_dict
        self.weight_mask_mlp_dict = weight_mask_mlp_dict
        # store weight masks for every component that is unfrozen
        
        # need to store reference weights so that you can reset W_Q, etc after a forward pass
        self.reference_attn_weights = {}
        self.reference_mlp_weights = {}

        self.attention_masks = {}
        self.mlp_masks = {}
        for layer in range(self.tl_transformer.cfg.n_layers):
            self.attention_masks[layer] = {}
            self.reference_attn_weights[layer] = {}
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:
                if self.weight_mask_attn_dict is None:
                    unfrozen_heads = list(range(self.tl_transformer.cfg.n_heads)) # all heads are unfrozen
                else:
                    unfrozen_heads = self.weight_mask_attn_dict[layer][component]
                # make frozen and baseline masks, and also a copy of the original weights

                if len(unfrozen_heads) > 0:
                    W_frozen, W_baseline = make_partly_differentiable_mask(parameter, unfrozen_heads)
                    weight_mask = nn.Parameter(torch.ones_like(parameter).type(torch_dtype), requires_grad=True)
                    
                    self.attention_masks[layer][component] = (W_frozen, W_baseline, weight_mask)
                    self.reference_attn_weights[layer][component] = parameter.clone()

            if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer]:
                in_weight_mask = nn.Parameter(torch.ones_like(self.tl_transformer.blocks[layer].mlp.W_in).type(torch_dtype), requires_grad=True)
                out_weight_mask = nn.Parameter(torch.ones_like(self.tl_transformer.blocks[layer].mlp.W_out).type(torch_dtype), requires_grad=True)

                self.mlp_masks[layer] = (in_weight_mask, out_weight_mask)
                self.reference_mlp_weights[layer] = (self.tl_transformer.blocks[layer].mlp.W_in.clone(), self.tl_transformer.blocks[layer].mlp.W_out.clone())


    def forward(self, *args, **kwargs):
        for layer in range(self.tl_transformer.cfg.n_layers):
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:

                if self.weight_mask_attn_dict is None or len(self.attention_masks[layer]) > 0:
                    W_frozen, W_baseline, weight_mask = self.attention_masks[layer][component]
                    reference_data = self.reference_attn_weights[layer][component]
                    mask = W_baseline + W_frozen * weight_mask

                    # parameter = reference_data * mask
                    if component == "W_Q":
                        self.tl_transformer.blocks[layer].attn._W_Q.data = self.tl_transformer.blocks[layer].attn.W_Q * mask
                    elif component == "W_K":
                        self.tl_transformer.blocks[layer].attn._W_K.data = self.tl_transformer.blocks[layer].attn.W_K * mask
                    elif component == "W_V":
                        self.tl_transformer.blocks[layer].attn._W_V.data = self.tl_transformer.blocks[layer].attn.W_V * mask
                    elif component == "W_O":
                        self.tl_transformer.blocks[layer].attn._W_O.data = self.tl_transformer.blocks[layer].attn.W_O * mask

            if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer]:
                in_weight_mask, out_weight_mask = self.mlp_masks[layer]
                reference_in_data, reference_out_data = self.reference_mlp_weights[layer]
                # self.tl_transformer.blocks[layer].mlp.W_in = reference_in_data * in_weight_mask
                # self.tl_transformer.blocks[layer].mlp.W_out = reference_out_data * out_weight_mask
                self.tl_transformer.blocks[layer].mlp._W_in.data = self.tl_transformer.blocks[layer].mlp.W_in * in_weight_mask
                self.tl_transformer.blocks[layer].mlp._W_out.data = self.tl_transformer.blocks[layer].mlp.W_out * out_weight_mask
        
        return self.tl_transformer(*args, **kwargs)

        # go through all attention heads and multiply weights by partly-frozen masks
        # go through all mlps and multiply weights by masks
        


In [12]:
weight_mask_mlps = {layer: False for layer in range(tl_model.cfg.n_layers)}
for i in range(16):
    weight_mask_mlps[i] = True

weight_mask_attns = {layer: {"W_Q": [], "W_K": [], "W_V": [], "W_O": []} for layer in range(tl_model.cfg.n_layers)}
for i in range(8, 24):
    weight_mask_attns[i] = {"W_Q": list(range(4)), "W_K": list(range(4)), "W_V": list(range(4)), "W_O": list(range(4))}

print(torch.cuda.memory_allocated() // 1024**3)
wmt = WeightMaskedTransformer(tl_model, weight_mask_attn_dict=weight_mask_attns, weight_mask_mlp_dict=weight_mask_mlps)
print(torch.cuda.memory_allocated() // 1024**3)

19
31


In [31]:
sports_test = SportsTask(batch_size=64, tokenizer=tokenizer)
# print(sports_test.get_test_loss(tl_model))

with torch.autocast(device_type="cuda"):
    print(sports_test.get_test_loss(tl_model))
    print(sports_test.get_test_loss(wmt))

tensor(0.1858, device='cuda:0')
tensor(0.2540, device='cuda:0')


In [32]:
print(torch.cuda.memory_allocated() // 1024**3)
print(torch.cuda.max_memory_allocated() // 1024**3)

50
53


##

## Check that gradients flow properly

In [33]:
for param in tl_model.parameters():
    param.requires_grad = True

In [34]:
sports_train = SportsTask(batch_size=8, tokenizer=tokenizer)
with torch.autocast(device_type="cuda"):
    loss = sports_train.get_train_loss(tl_model, 1)
    print(loss)
    loss.backward()
    # loss = sports_train.get_train_loss(wmt, 1)
    # print(loss)
    # loss.backward()


tensor(0.0815, device='cuda:0', grad_fn=<DivBackward0>)


In [35]:
for name, param in tl_model.named_parameters():
    print(name, param.grad.shape)

embed.W_E torch.Size([256000, 3072])
blocks.0.ln1.w torch.Size([3072])
blocks.0.ln2.w torch.Size([3072])
blocks.0.attn.W_Q torch.Size([16, 3072, 256])
blocks.0.attn.W_O torch.Size([16, 256, 3072])
blocks.0.attn.b_Q torch.Size([16, 256])
blocks.0.attn.b_O torch.Size([3072])
blocks.0.attn._W_K torch.Size([16, 3072, 256])
blocks.0.attn._W_V torch.Size([16, 3072, 256])
blocks.0.attn._b_K torch.Size([16, 256])
blocks.0.attn._b_V torch.Size([16, 256])
blocks.0.mlp.W_in torch.Size([3072, 24576])
blocks.0.mlp.W_gate torch.Size([3072, 24576])
blocks.0.mlp.W_out torch.Size([24576, 3072])
blocks.0.mlp.b_in torch.Size([24576])
blocks.0.mlp.b_out torch.Size([3072])
blocks.1.ln1.w torch.Size([3072])
blocks.1.ln2.w torch.Size([3072])
blocks.1.attn.W_Q torch.Size([16, 3072, 256])
blocks.1.attn.W_O torch.Size([16, 256, 3072])
blocks.1.attn.b_Q torch.Size([16, 256])
blocks.1.attn.b_O torch.Size([3072])
blocks.1.attn._W_K torch.Size([16, 3072, 256])
blocks.1.attn._W_V torch.Size([16, 3072, 256])
blocks.1

In [36]:
tl_model.blocks[27].attn._W_K.grad

tensor([[[ 1.1873e-04,  1.1063e-04, -2.3460e-04,  ...,  1.6937e-03,
          -1.0223e-03, -1.7548e-04],
         [ 2.9802e-07, -7.1526e-06,  2.4080e-05,  ..., -2.0599e-04,
           1.1349e-04,  2.8133e-05],
         [ 3.3140e-05,  1.4424e-05, -2.1696e-05,  ...,  1.8358e-05,
          -9.5367e-06,  1.3113e-05],
         ...,
         [-2.2173e-05, -1.5140e-05,  2.6941e-05,  ..., -1.0300e-04,
           7.2479e-05,  9.8944e-06],
         [-2.4319e-05,  7.8678e-06,  3.5763e-06,  ...,  4.2343e-04,
          -2.3746e-04, -4.6253e-05],
         [-1.2279e-05, -1.4424e-05,  1.8358e-05,  ..., -4.1723e-05,
          -2.3603e-05,  3.5286e-05]],

        [[ 6.4453e-02, -4.3457e-02, -1.3184e-02,  ..., -1.3794e-02,
          -5.3467e-02, -6.2256e-02],
         [-1.6479e-03, -8.0872e-04,  3.6240e-04,  ...,  1.7929e-04,
           1.5640e-03,  1.8005e-03],
         [-3.9368e-03,  4.7607e-03,  8.1635e-04,  ...,  6.7139e-04,
           3.0212e-03,  3.5095e-03],
         ...,
         [-1.4648e-02,  1

In [38]:
# zero grad tl_model
for param in tl_model.parameters():
    param.grad = None

In [39]:
sports_train = SportsTask(batch_size=8, tokenizer=tokenizer)
with torch.autocast(device_type="cuda"):
    loss = sports_train.get_train_loss(wmt, 1)
    print(loss)
    loss.backward()

tensor(0.1855, device='cuda:0', grad_fn=<DivBackward0>)


In [40]:
tl_model.blocks[27].attn._W_K.grad

tensor([[[ 3.6621e-04,  2.0695e-04, -4.4250e-04,  ...,  2.5024e-03,
          -1.4725e-03, -4.9591e-04],
         [ 4.0531e-06,  8.8215e-06,  1.1921e-06,  ...,  2.2650e-06,
           2.6822e-06, -3.3379e-06],
         [ 7.2956e-05,  5.4359e-05, -9.7275e-05,  ...,  4.0436e-04,
          -2.3460e-04, -8.8692e-05],
         ...,
         [-1.1146e-05, -7.9870e-06,  2.5511e-05,  ..., -2.7847e-04,
           1.6594e-04,  6.0558e-05],
         [ 5.3406e-05,  3.0398e-05, -5.6028e-05,  ...,  4.1389e-04,
          -2.4796e-04, -8.3923e-05],
         [ 5.7459e-05,  2.4915e-05, -8.4877e-05,  ...,  2.7657e-04,
          -1.9360e-04, -4.0293e-05]],

        [[ 4.5410e-02, -2.9541e-02, -9.1553e-03,  ..., -7.7820e-03,
          -3.7109e-02, -4.3213e-02],
         [ 4.2725e-03, -6.8054e-03, -1.1368e-03,  ..., -5.3406e-04,
          -3.0212e-03, -3.4943e-03],
         [ 3.9978e-03, -3.8147e-03, -9.2697e-04,  ..., -1.9073e-04,
          -3.1738e-03, -3.6163e-03],
         ...,
         [-8.6060e-03,  5

In [41]:
wmt.attention_masks[8]['W_Q'][-1].grad