In [1]:
%load_ext autoreload
%autoreload 2
from transformer_lens import HookedTransformer, ActivationCache
import os
import torch
import numpy as np
import pandas as pd
import datasets
import transformers
import pickle

from tasks import PileTask, OWTTask, InductionTask, GreaterThanTask
from tasks.ioi.IOITask import IOITask, IOITask_NPO, IOITask_Uniform
from tasks.induction.InductionTask import InductionTask, InductionTask_NPO, InductionTask_Uniform
from tasks.facts.SportsTask import SportsTask, SportsTask_NPO, SportsTask_Uniform

from tqdm.auto import tqdm

from transformers import GPT2Tokenizer, GPTNeoXTokenizerFast, AutoModelForCausalLM, AutoTokenizer


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_type = "gemma-2b"

os.environ['HF_TOKEN'] = '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'
if model_type == "pythia":
    reference_model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-2.8B")#.cuda()
    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8B")
    tokenizer.pad_token_id = tokenizer.eos_token_id

elif model_type == "gemma-7b":
    reference_model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", torch_dtype=torch.bfloat16)#.cuda()
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"

elif model_type == "gemma-2b":
    reference_model = AutoModelForCausalLM.from_pretrained("google/gemma-2b", torch_dtype=torch.bfloat16)#.cuda()
    tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
    tokenizer.pad_token_id = tokenizer.eos_token_id
    tokenizer.padding_side = "right"


Gemma's activation function should be approximate GeLU and not exact GeLU.
Changing the activation function to `gelu_pytorch_tanh`.if you want to use the legacy `gelu`, edit the `model.config` to set `hidden_activation=gelu`   instead of `hidden_act`. See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.24s/it]


In [3]:
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2b")
tl_model = HookedTransformer.from_pretrained(
    'google/gemma-2b',
    tokenizer=tokenizer,
    device='cuda',
    default_padding_side="right",
    fold_ln=False,
    fold_value_biases=False,
    center_writing_weights=False,
    dtype=torch.bfloat16
)


Loading checkpoint shards: 100%|██████████| 2/2 [00:02<00:00,  1.14s/it]


Loaded pretrained model google/gemma-2b into HookedTransformer


In [10]:
import pickle
with open("models/google_gemma-7b_sports_baseball_ap_graph.pkl", "rb") as f:
    ap_graph = pickle.load(f)
for component in ap_graph:
    if "m" in component:
        print(f"{component}: {ap_graph[component]}")


m0: -0.055645283311605453
m1: 0.0552063025534153
m2: -0.11303359270095825
m3: 0.024208657443523407
m4: -0.0113525390625
m5: -0.022718576714396477
m6: -0.007042518351227045
m7: -0.021432731300592422
m8: -0.006188026163727045
m9: -0.0019231943879276514
m10: -0.03130634129047394
m11: -0.0708770751953125
m12: -0.04879526048898697
m13: 0.04687969759106636
m14: 0.035638369619846344
m15: 0.02344219572842121
m16: -0.018733099102973938
m17: -0.06601186841726303
m18: -0.10868014395236969
m19: -0.0049954927526414394
m20: -0.08375901728868484
m21: -0.25811299681663513
m22: -0.11271785199642181
m23: -0.23159556090831757
m24: -0.09144005924463272
m25: -0.08293269574642181
m26: 0.18458910286426544
m27: 0.5513822436332703


## Approach 1: apply masks before doing any calculations
this succeeded

In [6]:
import random

def create_random_weight_mask_dicts(model):
    # Creates random weight masks for testing
    weight_mask_attn_dict = {}
    weight_mask_mlp_dict = {}

    for layer in range(model.cfg.n_layers):
        weight_mask_attn_dict[layer] = {}
        # Want bool of length n_head, randomly set to True
        weight_mask_attn_dict[layer]['W_Q'] = torch.rand(model.cfg.n_heads) < 0.8
        weight_mask_attn_dict[layer]['W_K'] = torch.rand(model.cfg.n_heads) < 0.8
        weight_mask_attn_dict[layer]['W_V'] = torch.rand(model.cfg.n_heads) < 0.8
        weight_mask_attn_dict[layer]['W_O'] = torch.rand(model.cfg.n_heads) < 0.8

        # Randomly set to true or false
        weight_mask_mlp_dict[layer] = random.randint(0, 1) == 1

    return weight_mask_attn_dict, weight_mask_mlp_dict


In [7]:
from torch import nn

def make_partly_differentiable_mask(W, unfrozen_heads, device="cuda"):
    """
    W is Parameter of shape (n_heads, ...). Returns baseline and frozen (both only 1d arrays of (n_heads,)), and forward pass should be W_baseline.float() + W_frozen.float() * W 
    """
    W_baseline = torch.nn.Parameter(torch.zeros(W.shape[0], dtype=torch.bool), requires_grad=False).to(device)

    # unsqueeze to broadcast efficiently, until W_baseline has same shape as W
    while len(W_baseline.shape) < len(W.shape):
        W_baseline = W_baseline.unsqueeze(-1)
    
    W_baseline[unfrozen_heads] = True
    # W_baseline = ~W_frozen
    W_frozen = torch.nn.Parameter(~W_baseline, requires_grad=False)
    # convert into float
    return W_frozen.float(), W_baseline.float()

class WeightMaskedTransformer(nn.Module):
    def __init__(self, tl_transformer, weight_mask_attn_dict=None, weight_mask_mlp_dict=None, torch_dtype=torch.bfloat16):
        """
        weight_mask_attn_dict: {layer: {"W_Q": unfrozen_heads, "W_K": unfrozen_heads, "W_V": unfrozen_heads, "W_O": unfrozen_heads}} (frozen_heads is shape (n_heads,) of bools). If none, train mask over all heads
        weight_mask_mlp_dict: {layer: bool}. If none, train mask over all mlps

        """
        super().__init__()
        self.torch_dtype = torch_dtype
        # tl_transformer should be a HookedTransformer
        self.tl_transformer = tl_transformer
        # turn off gradients for tl_transformer
        # for param in self.tl_transformer.parameters():
        #     param.requires_grad = False

        self.weight_mask_attn_dict = weight_mask_attn_dict
        self.weight_mask_mlp_dict = weight_mask_mlp_dict
        # store weight masks for every component that is unfrozen
        
        # need to store reference weights so that you can reset W_Q, etc after a forward pass
        self.reference_attn_weights = {}
        self.reference_mlp_weights = {}

        self.attention_masks = {}
        self.mlp_masks = {}
        for layer in range(self.tl_transformer.cfg.n_layers):
            self.attention_masks[layer] = {}
            self.reference_attn_weights[layer] = {}
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:
                if self.weight_mask_attn_dict is None:
                    unfrozen_heads = list(range(self.tl_transformer.cfg.n_heads)) # all heads are unfrozen
                else:
                    unfrozen_heads = self.weight_mask_attn_dict[layer][component]
                # make frozen and baseline masks, and also a copy of the original weights

                if len(unfrozen_heads) > 0:
                    W_frozen, W_baseline = make_partly_differentiable_mask(parameter, unfrozen_heads)
                    weight_mask = nn.Parameter(torch.ones_like(parameter).type(torch_dtype), requires_grad=True)
                    
                    self.attention_masks[layer][component] = (W_frozen, W_baseline, weight_mask)
                    self.reference_attn_weights[layer][component] = parameter.clone()

            if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer]:
                in_weight_mask = nn.Parameter(torch.ones_like(self.tl_transformer.blocks[layer].mlp.W_in).type(torch_dtype), requires_grad=True)
                out_weight_mask = nn.Parameter(torch.ones_like(self.tl_transformer.blocks[layer].mlp.W_out).type(torch_dtype), requires_grad=True)

                self.mlp_masks[layer] = (in_weight_mask, out_weight_mask)
                self.reference_mlp_weights[layer] = (self.tl_transformer.blocks[layer].mlp.W_in.clone(), self.tl_transformer.blocks[layer].mlp.W_out.clone())

    def forward(self, *args, **kwargs):
        for layer in range(self.tl_transformer.cfg.n_layers):
            for component, parameter in [("W_Q", self.tl_transformer.blocks[layer].attn.W_Q), ("W_K", self.tl_transformer.blocks[layer].attn.W_K), ("W_V", self.tl_transformer.blocks[layer].attn.W_V), ("W_O", self.tl_transformer.blocks[layer].attn.W_O)]:
                if self.weight_mask_attn_dict is None or len(self.attention_masks[layer]) > 0:
                    W_frozen, W_baseline, weight_mask = self.attention_masks[layer][component]
                    reference_data = self.reference_attn_weights[layer][component]
                    mask = W_baseline + W_frozen * weight_mask
                    self.tl_transformer.blocks[layer].attn.__dict__['_parameters'][component] = reference_data * mask

            if self.weight_mask_mlp_dict is None or self.weight_mask_mlp_dict[layer]:
                in_weight_mask, out_weight_mask = self.mlp_masks[layer]
                reference_in_data, reference_out_data = self.reference_mlp_weights[layer]
                self.tl_transformer.blocks[layer].mlp.__dict__['_parameters']['W_out'] = reference_in_data * in_weight_mask
                self.tl_transformer.blocks[layer].mlp.__dict__['_parameters']['W_out'] =  reference_out_data * out_weight_mask

        return self.tl_transformer(*args, **kwargs)


In [8]:
# BELOW IS RANDOM LOCALIZATIONS, REPLACE WITH ACTUAL LOCALIZATIONS
random_weight_mask_attns, random_weight_mask_mlps = create_random_weight_mask_dicts(tl_model)

wmt = WeightMaskedTransformer(tl_model, weight_mask_attn_dict=random_weight_mask_attns, weight_mask_mlp_dict=random_weight_mask_mlps)


In [9]:
sports_test = SportsTask(batch_size=64, tokenizer=tokenizer)
# print(sports_test.get_test_loss(tl_model))

with torch.autocast(device_type="cuda"):
    print(sports_test.get_test_loss(tl_model))
    print(sports_test.get_test_loss(wmt))


Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


train_df: (1252, 8), test_df: (314, 8)
tensor(0.3072, device='cuda:0')
tensor(0.3020, device='cuda:0')


##

## Check that gradients flow properly

In [10]:
for param in tl_model.parameters():
    param.requires_grad = True


In [11]:
sports_train = SportsTask(batch_size=3, tokenizer=tokenizer)
with torch.autocast(device_type="cuda"):
    loss = sports_train.get_train_loss(tl_model, 1)
    print(loss)
    loss.backward()
    # loss = sports_train.get_train_loss(wmt, 1)
    # print(loss)
    # loss.backward()


train_df: (1252, 8), test_df: (314, 8)
tensor(0.9509, device='cuda:0', grad_fn=<DivBackward0>)


In [12]:
# zero grad tl_model
for param in tl_model.parameters():
    param.grad = None


In [13]:
sports_train = SportsTask(batch_size=8, tokenizer=tokenizer)
with torch.autocast(device_type="cuda"):
    loss = sports_train.get_train_loss(wmt, 1)
    print(loss)
    loss.backward()


train_df: (1252, 8), test_df: (314, 8)
tensor(0.3360, device='cuda:0', grad_fn=<DivBackward0>)


In [14]:
tl_model.blocks[7].attn.W_K.grad


  tl_model.blocks[7].attn.W_K.grad


In [18]:
weight_mask_attns[3]


{'W_Q': tensor([ True,  True, False,  True, False,  True,  True, False]),
 'W_K': tensor([False, False,  True,  True,  True,  True, False,  True]),
 'W_V': tensor([True, True, True, True, True, True, True, True]),
 'W_O': tensor([False,  True,  True, False, False, False,  True,  True])}

In [23]:
print(wmt.attention_masks[3]['W_Q'][-1].grad[4])


tensor([[-5.1036e-07, -2.5518e-07, -1.6466e-06,  ...,  4.8280e-06,
          1.1250e-06, -1.3461e-09],
        [ 4.4703e-07,  1.6615e-06,  4.6100e-08,  ..., -1.5736e-05,
          1.0207e-06, -1.1444e-05],
        [ 4.8801e-07, -1.4529e-07, -1.7583e-06,  ...,  1.2159e-05,
         -9.5367e-06,  2.0117e-07],
        ...,
        [-8.1956e-08,  1.2293e-07,  3.0361e-07,  ...,  4.1723e-06,
         -7.2177e-08, -4.3947e-09],
        [ 9.0525e-07, -1.0987e-09,  2.3395e-06,  ..., -2.0504e-05,
          6.3181e-06,  1.8701e-06],
        [-1.5926e-07, -9.4250e-07, -3.5623e-08,  ..., -2.1219e-05,
          1.0014e-05,  1.6928e-05]], device='cuda:0', dtype=torch.bfloat16)
