
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Token Patching

In [None]:
#@title Import libraries
import transformer_lens
import torch, gc, itertools, functools, tqdm, copy
import pandas as pd
import numpy as np
from typing import Callable

import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
## inference only: disable autograd globally for the rest of the notebook
torch.set_grad_enabled(False)

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

## cleaner plots: hide the right and top axis spines on every figure
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

import sys
## presumably so that the local `paraMem` package (imported next) can be found;
## NOTE(review): hard-coded notebook location — adjust for your environment
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, gradient, evaluation, localizing, intervening

## Model

In [None]:
## Model under study, loaded on CPU; lr / weight_decay are simply forwarded to load_model.
model_type = "gpt-neo-125M"
model = modelHandlers.load_model(
    model_type=model_type,
    DEVICE="cpu",
    lr=0.0,
    weight_decay=0.01,
)

## Data

In [None]:
## Load the two prediction splits; per the variable names below, the first file is used as the
## memorized set and the second as the non-memorized set.
mem_nonmem_sets  = dataLoaders.load_pile_splits(f"{model_type}/preds", file_names=["50_50_preds.pt", "0_10_preds.pt"], as_torch=True)
mem_set, non_mem_set = mem_nonmem_sets[0], mem_nonmem_sets[1]
train_dl, test_dl = dataLoaders.train_test_batching(mem_set, non_mem_set, mem_batch=100, non_mem_batch=1, test_frac=0.2, add_bos=None, shuffle=False)
## first training batch (leading batch dim of size 1 squeezed away)
## NOTE(review): c_toks_NI is overwritten below by a hard-coded example, and k_toks_NI is not
## used anywhere in this notebook — this batch is currently only loaded, not analysed.
c_toks_NI, k_toks_NI = next(iter(train_dl))
c_toks_NI, k_toks_NI = c_toks_NI.squeeze(0), k_toks_NI.squeeze(0)


## Alternative hard-coded example texts; exactly one is active.
#c_string_NI = ["Sign up for Take Action Now and get three actions in your inbox every week. You will receive occasional promotional offers for programs that support The Nation’s journalism. You can read our Privacy Policy here. Sign up for Take Action Now and get three actions in your inbox every week.\n\nThank you for signing up. For more from The Nation, check out our latest issue\n\nSubscribe now for as little as $2 a month!\n\nSupport Progressive Journalism The Nation is reader supported:"]
c_string_NI = ["The following are trademarks or service marks of Major League Baseball entities and may be used only with permission of Major League Baseball Properties, Inc. or the relevant Major League Baseball entity: Major League, Major League Baseball, MLB, the silhouetted batter logo, World Series, National League, American League, Division Series, League Championship Series, All-Star Game, and the names, nicknames, logos, uniform designs, color combinations, and slogans designating the Major League Baseball clubs and entities, and"]
#c_string_NI = [" headlines out of Washington never seem to slow. Subscribe to The D.C. Brief to make sense of what matters most. Please enter a valid email address. Sign Up Now Check the box if you do not wish to receive promotional offers via email from TIME. You can unsubscribe at any time. By signing up you are agreeing to our Terms of Use and Privacy Policy . This site is protected by reCAPTCHA and the Google Privacy Policy and Terms of Service apply. Thank you! For your"]
## tokenize the example without a BOS token; this replaces the dataloader batch loaded above
c_toks_NI = model.to_tokens(c_string_NI, prepend_bos=False)

In [None]:
#for i in range(0,100):
#    print(i, model.to_string(c_toks_NI[i]))
#model.to_string(c_toks_NI[6])
#c_toks_NI = c_toks_NI[6].unsqueeze(0)

In [None]:
#prompt = "Our Father, who art in heaven, hallowed be thy name; thy kingdom come; thy will be done; on earth as it is in heaven. Give us this day our daily bread. And forgive us our trespasses, as we forgive those who trespass against us. And lead us not into temptation; but deliver us from evil"
#prompt = "An apple a day keeps the doctor away"
#tokens_NI = model.to_tokens(prompt, prepend_bos=True)
## Analyse the example tokenized above: show it, then cut it into prefix / continuation.
toks_NI = c_toks_NI
print(model.to_string(toks_NI))
pref_cont_split = 50  # prefix length in tokens (alternative: toks_NI.shape[-1] // 2)
pref_NI = toks_NI[:, :pref_cont_split]

In [None]:
def get_first2sec_tok(logits_NIT:torch.Tensor, prefix_NI:torch.Tensor, keepNonTop:bool=True):
    """
    Perturb a prefix by swapping its tokens for the model's top-1 / top-2 predictions.

    For every position t >= 1, the top-ranked predictions made at position t-1 are
    compared with the actual token at t:
      * if the actual token equals the top-1 prediction, it is replaced by the top-2 one;
      * otherwise it is replaced by the top-1 prediction when `keepNonTop` is True,
        and left unchanged when `keepNonTop` is False.
    Position 0 (e.g. BOS) is always copied unchanged.

    Args:
        logits_NIT: logits for `prefix_NI`, vocabulary on the last dim.
        prefix_NI: (N, I) token ids of the prefix.
        keepNonTop: controls what happens to tokens that are not the top-1 prediction (see above).

    Returns:
        (N, I) tensor with the perturbed prefix (on CPU, like the computed scores).
    """
    scores_NIT = torch.nn.functional.softmax(logits_NIT.to("cpu"), dim=-1)
    ## assumes get_topK returns indices ordered best-first along the last dim — TODO confirm
    _, top_idcs_Ik = modelHandlers.get_topK(scores_NIT, topK=2, minK=False)
    ## predictions made at position t-1 are compared against the token at position t
    top1_NI = top_idcs_Ik[..., :-1, 0]
    top2_NI = top_idcs_Ik[..., :-1, 1]

    ## keep the comparison on one device even if the prefix lives on the model's device
    prefix_NI = prefix_NI.to(top_idcs_Ik.device)
    target_NI = prefix_NI[:, 1:]

    pertubed_prefix = torch.clone(target_NI).long()
    prefix_is_top = top1_NI == target_NI
    pertubed_prefix[prefix_is_top] = top2_NI[prefix_is_top]  ## pick top 2
    if keepNonTop:
        pertubed_prefix[~prefix_is_top] = top1_NI[~prefix_is_top]  ## pick top 1

    ## re-attach the untouched first token
    bos_N = prefix_NI[:, 0].unsqueeze(-1)
    return torch.cat((bos_N, pertubed_prefix), dim=-1)
    
def get_random_tok(prefix_NI:torch.Tensor, vocab_size:int=50257, seed:int=0):
    """
    Perturb a prefix by replacing every token except the first with a random token id.

    Args:
        prefix_NI: (N, I) token ids; column 0 (e.g. BOS) is kept unchanged.
        vocab_size: random ids are drawn uniformly from [0, vocab_size)
            (pass `model.cfg.d_vocab`).
        seed: if >= 0, draw from a private generator seeded with `seed` (reproducible
            and the global torch RNG is left untouched); if < 0, use the global RNG.

    Returns:
        Tensor with the same shape as `prefix_NI`, on the same device.
    """
    generator = None
    if seed >= 0:
        print(f"fixed torch seed {seed}")
        ## on CPU a fresh generator seeded with `seed` yields the same draws as
        ## torch.manual_seed(seed), without resetting the notebook-wide RNG state
        generator = torch.Generator().manual_seed(seed)

    ## draw the full (N, I) shape and drop the last column (rather than drawing (N, I-1))
    ## so seeded results stay identical to earlier runs of this helper
    random_NI = torch.randint(0, vocab_size, prefix_NI.shape, generator=generator)[..., :-1]
    ## match the prefix's device so the concatenation below also works on GPU
    random_NI = random_NI.to(prefix_NI.device)

    ## re-attach the untouched first token
    bos_N = prefix_NI[:, 0].unsqueeze(-1)
    return torch.cat((bos_N, random_NI), dim=-1)

perturb_type = "random" #first2sec

## Build `pertubed_pref_NI`, the replacement prefix consumed by token_patching_loop below.
## (Previously the "first2sec" branch stored its result under a different name, so the
## patching loop would not see it.)
if perturb_type == "first2sec":
    ## swap each prefix token for the model's top-1 / top-2 prediction at that position
    pertubed_pref_NI = get_first2sec_tok(model(pref_NI).to("cpu"), pref_NI, keepNonTop=True)
elif perturb_type == "random":
    ## replace every prefix token except the first by a random token; seed=-1 -> unseeded
    pertubed_pref_NI = get_random_tok(pref_NI, vocab_size=model.cfg.d_vocab, seed=-1)
else:
    raise ValueError(f"unknown perturb_type {perturb_type!r}; expected 'first2sec' or 'random'")
print(model.to_string(pertubed_pref_NI), pertubed_pref_NI.shape)

In [None]:
def token_patching_loop(model, toks_NI:torch.Tensor, pertubed_pref_NI:torch.Tensor, decode:bool=False, single_tok_perturb:bool=True, disable_tqdm:bool=False):
    """
    Perturb the prefix position by position and measure the change in the continuation.

    `toks_NI` is split into a prefix (its first P = `pertubed_pref_NI.shape[-1]` tokens)
    and a continuation (all remaining tokens). For every prefix position `tok_pos` the
    prefix is patched with tokens from `pertubed_pref_NI` -- only position `tok_pos` if
    `single_tok_perturb`, otherwise every position up to and including `tok_pos` -- and
    the continuation predicted from the patched prefix is scored against the original.

    Args:
        model: model called as `model(tokens)` (logits) and `model.generate(...)`;
            `model.cfg.device` is used for input placement.
        toks_NI: (N, I) original token ids with I > P.
        pertubed_pref_NI: (N, P) replacement tokens for the prefix.
        decode: if True, greedily generate the continuation with `model.generate`;
            otherwise teacher-force the original continuation and take the lowest-NLL
            token at each position that predicts a continuation token.
        single_tok_perturb: patch one position at a time (True) or cumulatively (False).
        disable_tqdm: hide the progress bar.

    Returns:
        nll_metric: (N, P) mean continuation NLL for each patched position.
        em_metric: (N, P) `evaluation.compute_exact_match(..., until_wrong=True)` per position.
        most_changed_preds: (N, I-P) predicted continuation at the position with the
            lowest exact match (the earliest position wins ties).
        interv_tok_pos: (N,) that position.

    Raises:
        ValueError: if `toks_NI` has no tokens after the prefix.
    """
    n_toks = pertubed_pref_NI.shape[-1]
    ## the continuation is everything after the prefix (slicing the last P tokens instead
    ## overlaps or skips tokens whenever the sequence is not exactly 2*P long)
    pref_NI, cont_NI = toks_NI[:, :n_toks], toks_NI[:, n_toks:]
    n_cont = cont_NI.shape[-1]
    if n_cont == 0:
        raise ValueError("toks_NI must contain continuation tokens after the perturbed prefix")
    n_seqs = pref_NI.shape[0]

    with torch.no_grad():
        nll_metric = torch.zeros(n_seqs, n_toks)
        em_metric = torch.zeros(n_seqs, n_toks)
        interv_tok_pos = torch.zeros(n_seqs).long()
        min_em = torch.ones(n_seqs) * 9999  # sentinel so the first position is always recorded
        most_changed_preds = torch.zeros(cont_NI.shape).long()

        for tok_pos in tqdm.tqdm(range(n_toks), total=n_toks, disable=disable_tqdm):

            ## (1) patch the prefix; if pertubed_pref_NI[:, 0] equals the original first token
            ## (as the perturbation helpers in this notebook produce), tok_pos == 0 is the baseline
            pref_NI_interv = torch.clone(pref_NI)
            if single_tok_perturb:
                pref_NI_interv[:, tok_pos] = pertubed_pref_NI[:, tok_pos]
            else:
                ## cumulative: include tok_pos itself so the last step perturbs the whole prefix
                pref_NI_interv[:, :tok_pos + 1] = pertubed_pref_NI[:, :tok_pos + 1]

            ## (2) predict the continuation from the patched prefix
            if decode:
                gen_toks_NI = model.generate(input=pref_NI_interv.to(model.cfg.device), use_past_kv_cache=True, stop_at_eos=False, max_new_tokens=n_cont, do_sample=False)
                pred_nll_NIT = modelHandlers.NegLogLik(model(gen_toks_NI).detach().to("cpu"))
                pred_nll_NI = modelHandlers.gather_token_scores(pred_nll_NIT, gen_toks_NI.to("cpu"))
                pred_toks_NI = gen_toks_NI[:, -n_cont:].to("cpu")
            else:
                toks_NI_interv = torch.cat((pref_NI_interv, cont_NI), dim=-1)
                pred_nll_NIT = modelHandlers.NegLogLik(model(toks_NI_interv.to(model.cfg.device)).to("cpu"))
                ## NLL of the original tokens under the patched context
                pred_nll_NI = modelHandlers.gather_token_scores(pred_nll_NIT, toks_NI)
                ## lowest-NLL token at the positions that predict the continuation tokens
                _, pred_toks_NIk = modelHandlers.get_topK(pred_nll_NIT, topK=1, minK=True)
                pred_toks_NI = pred_toks_NIk[..., -(n_cont + 1):-1, 0].to("cpu")

            ## (3) score against the original continuation
            nll_metric[:, tok_pos] = pred_nll_NI[:, -n_cont:].mean(-1)
            em = torch.as_tensor(evaluation.compute_exact_match(pred_toks_NI, torch.clone(cont_NI).to("cpu"), until_wrong=True))
            em_metric[:, tok_pos] = em

            ## (4) remember the patched position with the lowest exact match so far
            improved = em < min_em
            interv_tok_pos[improved] = tok_pos
            min_em[improved] = em[improved].to(min_em.dtype)
            most_changed_preds[improved] = pred_toks_NI[improved].to(most_changed_preds.dtype)
    return nll_metric, em_metric, most_changed_preds, interv_tok_pos

## run the experiment: patch one prefix position at a time and greedily decode the continuation
nll_metric, em_metric, most_changed_preds, min_tok_pos = token_patching_loop(model, toks_NI, pertubed_pref_NI, single_tok_perturb=True, decode=True)

In [None]:
## original prefix (reuses pref_cont_split instead of a duplicated hard-coded 50)
model.to_string(toks_NI[:, :pref_cont_split])

In [None]:
## original continuation (reuses pref_cont_split instead of a duplicated hard-coded 50)
model.to_string(toks_NI[:, pref_cont_split:])

In [None]:
## continuation predicted at the patched position that gave the lowest exact match
model.to_string(most_changed_preds)

In [None]:
## Baselines: sweep position 0 of the first sequence. Both perturbation helpers copy token 0
## unchanged, so patching it leaves the input as is.
nll_baseline, em_baseline = (metric[..., 0, 0].mean().item() for metric in (nll_metric, em_metric))

## per-position curves for the first sequence in the batch
nll, em = (metric[..., 0, :].numpy() for metric in (nll_metric, em_metric))

x = np.arange(em.shape[-1])
true_prefix, pertubed_prefix = (model.to_str_tokens(t.squeeze()) for t in (pref_NI, pertubed_pref_NI))

## x tick labels: the original prefix tokens (alternative below shows "original -> perturbed")
xlabels = true_prefix
#xlabels = [a + r" $\rightarrow$ " + b for (a,b) in zip(true_prefix, pertubed_prefix)]

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(15, 2), gridspec_kw={'hspace': 0.4})
fontsize = 12

## The curves show the first sequence of the batch. Mark that sequence's most disruptive
## position; min_tok_pos holds one position per sequence, so passing it whole is wrong for N > 1.
min_pos_first = int(min_tok_pos[0])

## top: continuation NLL, bottom: exact match; dashed line = baseline (nll_baseline / em_baseline)
for axis, values, baseline, ylabel in ((ax[0], nll, nll_baseline, r'NLL'), (ax[1], em, em_baseline, r'EM')):
    axis.axhline(y=baseline, linewidth=1, linestyle='--', c="grey", alpha=0.5)
    axis.plot(x, values, c="black")
    axis.set_ylabel(ylabel, fontsize=fontsize)
    axis.axvline(x=min_pos_first, c="orange", linestyle='-')
    axis.tick_params(axis='both', which='major', labelsize=fontsize)
#ax[0].set_yscale("log")

## prefix tokens as x tick labels on the bottom panel only
ax[0].set_xticks([])
ax[1].set_xticks(x)
labels = ax[1].set_xticklabels(xlabels, fontsize=fontsize-2, rotation=90)

#fig.savefig(f"{dataLoaders.ROOT}/results/{model_type}_{perturb_type}_perturb.pdf", dpi=200, bbox_inches="tight")

In [None]:
## Top-1 prediction (per get_topK) at every position of c_toks_NI, with the true tokens as context.
logs = model(c_toks_NI)
pred_k = modelHandlers.get_topK(logs, topK=1, minK=False)[1]
model.to_string(pred_k[..., 0])

In [None]:
## greedy (do_sample=False) continuation of the original prefix, generating exactly as many
## tokens as the true continuation has (instead of a hard-coded 50)
pred_toks_NI = model.generate(input=pref_NI, use_past_kv_cache=True, stop_at_eos=False, max_new_tokens=c_toks_NI.shape[-1] - pref_NI.shape[-1], do_sample=False)
model.to_string(pred_toks_NI)

In [None]:
## exact match of the greedy continuation against the true continuation (until_wrong=False);
## the split reuses pref_cont_split instead of a hard-coded 50
evaluation.compute_exact_match(pred_toks_NI[:, pref_cont_split:], c_toks_NI[:, pref_cont_split:], until_wrong=False)


In [None]:
## decode the full original token sequence back to text
model.to_string(c_toks_NI)