## Model Grafting

In [1]:
try:
    import google.colab
    IN_COLAB = True
    # %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    renderer = "colab"
except:
    IN_COLAB = False
    from IPython import get_ipython
    %load_ext autoreload
    %autoreload 2
    renderer = "jupyterlab"

In [2]:
%%bash
cd ../
pip install poetry
poetry install
cd notebooks

Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com
Collecting urllib3<2.0.0,>=1.26.0 (from poetry)
  Using cached urllib3-1.26.16-py2.py3-none-any.whl (143 kB)
Collecting platformdirs<4.0.0,>=3.0.0 (from poetry)
  Using cached platformdirs-3.10.0-py3-none-any.whl (17 kB)
Installing collected packages: urllib3, platformdirs
  Attempting uninstall: urllib3
    Found existing installation: urllib3 2.0.3
    Uninstalling urllib3-2.0.3:
      Successfully uninstalled urllib3-2.0.3
  Attempting uninstall: platformdirs
    Found existing installation: platformdirs 3.8.0
    Uninstalling platformdirs-3.8.0:
      Successfully uninstalled platformdirs-3.8.0
Successfully installed platformdirs-3.10.0 urllib3-1.26.16
Installing dependencies from lock file

Package operations: 0 installs, 2 updates, 0 removals

  • Updating urllib3 (1.26.16 -> 2.0.3)
  • Updating platformdirs (3.10.0 -> 3.8.0)

Installing the current project: transformer-lens (0.0.0)


ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
docker-compose 1.29.2 requires jsonschema<4,>=2.5.1, but you have jsonschema 4.17.3 which is incompatible.
docker-compose 1.29.2 requires PyYAML<6,>=3.10, but you have pyyaml 6.0 which is incompatible.
docker-compose 1.29.2 requires websocket-client<1,>=0.32.0, but you have websocket-client 1.6.1 which is incompatible.
sagemaker 2.155.0 requires attrs<23,>=20.3.0, but you have attrs 23.1.0 which is incompatible.
sagemaker 2.155.0 requires importlib-metadata<5.0,>=1.4.0, but you have importlib-metadata 5.2.0 which is incompatible.
sagemaker 2.155.0 requires protobuf<4.0,>=3.1, but you have protobuf 4.23.3 which is incompatible.
sagemaker 2.155.0 requires PyYAML==5.4.1, but you have pyyaml 6.0 which is incompatible.
sparkmagic 0.20.5 requires nest-asyncio==1.5.5, but you have nest-asyncio 1.5.6 which is incompatible

In [3]:

# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = renderer


# Import stuff
import types
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from jaxtyping import Float, Int
from typing import List, Union, Optional
from functools import partial
import copy

import requests
import json

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML
import ast
import pandas as pd
from matplotlib import pyplot as plt
import numpy as np
import matplotlib.pyplot as plt
import torch
from matplotlib.cm import ScalarMappable
import pickle
if IN_COLAB: 
    import pysvelte

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens.utilities import devices
from transformer_lens.past_key_value_caching import HookedTransformerKeyValueCache
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [4]:
import torch.nn.functional as F
import torch.optim as optim


In [5]:
def imshow(tensor, renderer=None, **kwargs):
    px.imshow(utils.to_numpy(tensor), color_continuous_midpoint=0.0, color_continuous_scale="RdBu", **kwargs).show(renderer)

def line(tensor, renderer=None, **kwargs):
    px.line(y=utils.to_numpy(tensor), **kwargs).show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(y=y, x=x, labels={"x":xaxis, "y":yaxis, "color":caxis}, **kwargs).show(renderer)

def cuda():
    return torch.cuda.is_available()

def get_device(): 
    return "cuda" if cuda() else "cpu"

def save_pickle(obj, filename):
    with open(filename, 'wb') as f:
        pickle.dump(obj, f)

# Function to load a pickle object from a file
def load_pickle(filename):
    with open(filename, 'rb') as f:
        obj = pickle.load(f)
    return obj


device = get_device()
device

'cuda'

## Model

- fold_ln: Whether to fold in the LayerNorm weights to the subsequent linear layer. This does not change the computation.
- center_writing_weights: Whether to center weights writing to the residual stream (ie set mean to be zero). Due to LayerNormthis doesn't change the computation.
- center_unembed : Whether to center W_U (ie set mean to be zero). Softmax is translation invariant so this doesn't affect log probs or loss, but does change logits. Defaults to True.
- refactor_factored_attn_matrices: Whether to convert the factoredmatrices (W_Q & W_K, and W_O & W_V) to be "even". 

In [6]:
MODEL_NAME = "gpt2-large"
model = HookedTransformer.from_pretrained(
        MODEL_NAME,
        center_unembed=True,  
        center_writing_weights=True,              # Whether to center weights writing to the residual stream (ie set mean to be zero). Due to LayerNorm this doesn't change the computation.      
        fold_ln=True,                             # Whether to  fold in the LayerNorm weights to the subsequent linear layer.
        refactor_factored_attn_matrices=True,
    )

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


## Dataset
We use the CounterFact Dataset

In [7]:
url = "https://rome.baulab.info/data/dsets/counterfact.json"

# Fetch the JSON data from the URL
response = requests.get(url)
dataset = response.json()


def sample_dataset(dataset, idx = None): 
    if idx is None: 
        prompt = "The {} is located in"
        subject = "Eiffel Tower"
        target = "Paris"
    else: 
        sample = dataset[idx]
        prompt = sample["requested_rewrite"]["prompt"]
        subject = sample["requested_rewrite"]["subject"]
        target = sample["requested_rewrite"]["target_true"]["str"]
    return prompt, subject, target


## Ablation Methods: 
- Noise ablation 
- Resample ablation

In [8]:
def noise_ablation(prompt, subject, target, n_noise_samples=5, vx=3):
    subject_tokens = model.to_tokens(subject)
    
    #shape: batch, n_tokens, embedding_dim
    subject_embedding = model.embed(subject_tokens)
    _, n_tokens, embedding_dim = subject_embedding.shape
    
    #noise: N(0,v), v = 3*std(embedding)
    embedding = model.W_E
    v = vx*torch.std(embedding, dim=0) #for each v in V
    noise = torch.randn(
        (n_noise_samples, n_tokens, embedding_dim)
    ).to(device) + v
    
    subject_embedding_w_noise = subject_embedding + noise
    
    #shape: batch, n_tokens, vocab_size (logits)
    unembedded_subject = model.unembed(subject_embedding_w_noise)

    noisy_subject_tokens = torch.argmax(unembedded_subject, dim=-1)
    noisy_subject_str = [
        model.to_string(nst) for nst in noisy_subject_tokens
    ]
    true_prompt = prompt.format(subject)
    corrupted_prompts = [
        prompt.format(nss) for nss in noisy_subject_str
    ]
    return true_prompt, corrupted_prompts, target

def resample_ablation(prompt, subject, target, n_noise_samples=20):
    subject_tokens = model.to_tokens(subject)
    embedding = model.W_E
    #we select n random rows from the embedding matrix
    permutations = torch.randperm(embedding.size(0))[:n_noise_samples]
    random_samples = embedding[permutations]
    #unsqueeze a token dimension between batch and embedding dims
    random_samples = random_samples.unsqueeze(dim=1)
    #we de-embed these rows
    random_embeddings = model.unembed(random_samples)
    random_tokens = torch.argmax(random_embeddings, dim=-1)
    random_subject_str = [
        model.to_string(t) for t in random_tokens
    ]
    corrupted_facts = [
        prompt.format(s) for s in random_subject_str
    ]
    true_fact = prompt.format(subject)
    return true_fact, corrupted_facts, target
    

    
    

In [13]:
def pad_from_left(tokens : torch.tensor, maxlen:int):
    pad_token = model.tokenizer.pad_token_id
    padded_tokenized_inputs = torch.zeros(tokens.shape[0], maxlen)
    
    n_pads = maxlen - tokens.shape[-1]
    padded_tokenized_inputs[:,n_pads] = pad_token
    padded_tokenized_inputs[:,n_pads:] = tokens
    return padded_tokenized_inputs.long()

def pad_to_same_length(clean_tokens, corrupted_tokens): 
    
    maxlen = max([clean_tokens.shape[-1], corrupted_tokens.shape[-1]])
    
    if clean_tokens.shape[-1] > corrupted_tokens.shape[-1]: 
        corrupted_tokens = pad_from_left(corrupted_tokens, maxlen)
    elif clean_tokens.shape[-1] < corrupted_tokens.shape[-1]: 
        clean_tokens = pad_from_left(clean_tokens, maxlen)
    return clean_tokens, corrupted_tokens

In [14]:
possible_patches = [
    'ln1.hook_scale',
    'ln1.hook_normalized',
    'attn.hook_q',
    'attn.hook_k',
    'attn.hook_v',
    # 'attn.hook_attn_scores',
    # 'attn.hook_pattern',
    # 'attn.hook_z',
    # 'hook_attn_out',
    # 'hook_resid_mid',
    'ln2.hook_scale',
    'ln2.hook_normalized',
    'mlp.hook_pre',
    'mlp.hook_post',
    # 'hook_mlp_out',
    'hook_resid_post'
]

In [112]:

def patch(
    corrupted_residual_component: Float[torch.Tensor, "batch pos d_model"],
    hook,
    pos,
    cache):
    corrupted_residual_component[:, pos, :] = cache[hook.name][:, pos, :]
    """
    Restore a patch in clean run. 
    """
    return corrupted_residual_component


def unembedding_function(residual_stack, cache, mlp=False) -> float:
    #we are only interested in applying the layer norm of the final layer on the final token
    #shape: [74, 5, 10, 1280] = n_layers, prompts, tokens, d_model
    z = cache.apply_ln_to_stack(residual_stack, layer = -1, mlp_input=mlp)
    z = z @ model.W_U
    return z


def initialize_graft(clean_cache, corrupted_cache,target_idx):
    #should we initialize the graft using a decomposition of the residual logits
    residual_clean_stack, clean_labels = clean_cache.decompose_resid(layer=-1, mode="attn_out", return_labels=True)
    residual_corrupted_stack, corrupted_labels = corrupted_cache.decompose_resid(layer=-1, mode="attn_out", return_labels=True)
    assert clean_labels==corrupted_labels
    #shape: [74, 4, 9, 50257] = n_layers, batch, tokens, vocab_size
    clean_logit_contributions_to_residual = unembedding_function(residual_clean_stack, clean_cache)[:,:,-1,:]
    corrupted_logit_contributions_to_residual = unembedding_function(residual_corrupted_stack, corrupted_cache)[:,:,-1,:]
        
    target_idx_expanded = target_idx.repeat(clean_logit_contributions_to_residual.shape[0],1,1)
    #center the logits
    clean_logit_contributions_to_residual = (clean_logit_contributions_to_residual.gather(index=target_idx_expanded, dim=-1) 
                                             - clean_logit_contributions_to_residual.mean(dim=-1, keepdim=True))
    
    corrupted_logit_contributions_to_residual = (corrupted_logit_contributions_to_residual.gather(index=target_idx_expanded, dim=-1)
                                                        - corrupted_logit_contributions_to_residual.mean(dim=-1, keepdim=True))
    

def similarity_loss(logits_original, logits_grafted, alpha=1.0, beta=1.0):
    #TODO: should we be using top n logits only
    probs_original = F.softmax(logits_original, dim=-1)
    probs_grafted = F.softmax(logits_grafted, dim=-1)
    kl_loss = F.kl_div(probs_grafted.log(), probs_original, reduction='batchmean')
    
    mse_loss = F.mse_loss(logits_grafted, logits_original)
    total_loss = alpha * kl_loss + beta * mse_loss
    
    return total_loss
        
def run_w_graft(graft, model, all_activation_keys, clean_cache, corrupted_tokens): 
    fwd_hooks = []
    for i in range(graft.shape[0]): 
        patch_name = all_activation_keys[i]
        for tok_pos in range(graft.shape[1]): 
            graft_val = graft[i, tok_pos]
            if graft_val==1: 
                hook_fn = partial(patch, pos=tok_pos, cache=clean_cache)
                hook = (patch_name, hook_fn)
                fwd_hooks.append(hook)

    patched_logits = model.run_with_hooks(
        corrupted_tokens,
        fwd_hooks=fwd_hooks,
        return_type="logits"
    )   
    patched_logits = patched_logits[:,-1,:]
    return patched_logits

    
def learn_graft(
             clean_prompt: str,
             corrupted_prompts: List[str],
             target: str,
             learning_rate=10,
             l1_penalty=0.001,
             num_epochs=10):
    
    clean_tokens = model.to_tokens(clean_prompt, prepend_bos=True) 
    corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
    #for now, use single corruption
    corrupted_tokens = corrupted_tokens[0].unsqueeze(0)
    
    target_token_idx = model.to_tokens(target)[:,1] #remember to slice out the start pad token
    #we pad the clean and corrupted tokens to the same length (pad from left)
    clean_tokens, corrupted_tokens = pad_to_same_length(clean_tokens, corrupted_tokens)
    
    #repeat clean_tokens to match corrupted_tokens
    clean_tokens = clean_tokens.expand(corrupted_tokens.shape[0], -1)
    target_token_idx = target_token_idx.expand(corrupted_tokens.shape[0], -1)

    _, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
    clean_logits, clean_cache = model.run_with_cache(clean_tokens, return_type="logits")
    clean_logits = clean_logits[:,-1,:]
    # clean_target_logit = clean_logits[:,target_token_idx]
    # clean_loss = F.cross_entropy(clean_logits.squeeze(), target_token_idx.squeeze())
    # print(clean_cache.keys())
    
    token_positions = torch.arange(corrupted_tokens.shape[-1], device=device)
    all_activation_keys = []
    for layer in range(model.cfg.n_layers): 
        for activation_patch in possible_patches:
            patch_name = f"blocks.{layer}.{activation_patch}"
            all_activation_keys.append(patch_name)

    
    # graft_base = initialize_graft(clean_cache, corrupted_cache, target_token_idx)
    graft_base_random = torch.rand(len(all_activation_keys), token_positions.shape[-1], device=device)
    threshold = 0.1  # Adjust this threshold value as needed
    # graft_base = (graft_base_random < threshold).long()
    transformed_graft = (torch.sigmoid(graft_base_random) < threshold).long()
    
    # graft_base = torch.ones(size=(len(all_activation_keys),token_positions.shape[-1]), device=device)
    # S = torch.normal(mean=0.1, std=0.001, size=(len(all_activation_keys), token_positions.shape[0]), requires_grad=True, device=device)
    
    optimizer = optim.SGD([transformed_graft], lr=learning_rate)

    for epoch in range(num_epochs):
        optimizer.zero_grad()
        
        # graft = graft_base * (1 - torch.sigmoid(S)) + (1 - graft_base) * torch.sigmoid(S)
        graft = (torch.sigmoid(transformed_graft) < 0.5).long()
    
        with torch.no_grad():
            patched_logits = run_w_graft(
                graft,
                model,
                all_activation_keys,
                clean_cache,
                corrupted_tokens
            )
        
        similarity_loss_val = similarity_loss(clean_logits, patched_logits, alpha=1.0, beta=1.0)
        regularized_loss = l1_penalty * torch.sum(torch.abs(graft))
        total_loss = similarity_loss_val #+ regularized_loss

        total_loss.backward(retain_graph=True)
        optimizer.step()

        print(f"Epoch [{epoch + 1}/{num_epochs}] - Similarity Loss: {similarity_loss_val.item()} - L1 Loss: {regularized_loss.item()} - Sparisty: {graft.sum() / graft.flatten().shape[-1]}")

    return graft
        

In [113]:
clean_prompt, subject, target = sample_dataset(dataset)
original_fact, corrupted_facts, target = resample_ablation(clean_prompt, subject, target, n_noise_samples=10)

graft = learn_graft(original_fact, corrupted_facts, target,
                    learning_rate=0.001,
                     l1_penalty=0.001,
                     num_epochs=10)

Epoch [1/10] - Similarity Loss: 36.952144622802734 - L1 Loss: 0.0 - Sparisty: 0.0
Epoch [2/10] - Similarity Loss: 36.952144622802734 - L1 Loss: 0.0 - Sparisty: 0.0
Epoch [3/10] - Similarity Loss: 36.952144622802734 - L1 Loss: 0.0 - Sparisty: 0.0
Epoch [4/10] - Similarity Loss: 36.952144622802734 - L1 Loss: 0.0 - Sparisty: 0.0
Epoch [5/10] - Similarity Loss: 36.952144622802734 - L1 Loss: 0.0 - Sparisty: 0.0
Epoch [6/10] - Similarity Loss: 36.952144622802734 - L1 Loss: 0.0 - Sparisty: 0.0
Epoch [7/10] - Similarity Loss: 36.952144622802734 - L1 Loss: 0.0 - Sparisty: 0.0
Epoch [8/10] - Similarity Loss: 36.952144622802734 - L1 Loss: 0.0 - Sparisty: 0.0
Epoch [9/10] - Similarity Loss: 36.952144622802734 - L1 Loss: 0.0 - Sparisty: 0.0
Epoch [10/10] - Similarity Loss: 36.952144622802734 - L1 Loss: 0.0 - Sparisty: 0.0


tensor(0.4997, device='cuda:0')