
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Intervened Decoding

In [None]:
#@title Import libraries
import torch, gc, itertools, tqdm, scipy, copy, functools, collections, transformer_lens
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rcParams['axes.spines.top'] = False

import sys
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, gradient, intervening, localizing, evaluation
modelHandlers.gpu_check()

### Load Model

In [None]:
## Model identifier; it is reused below as the directory prefix of the prediction files.
model_type = "gpt-neo-125M"
## Load the model through the project helper and request placement on the CPU.
model = modelHandlers.load_model(model_type=model_type, DEVICE="cpu")

### Load Data

In [None]:
## Load two prediction files for this model; the first entry is used as the memorized set,
## the second as the non-memorized set (see the unpacking on the next line).
mem_nonmem_sets  = dataLoaders.load_pile_splits(f"{model_type}/preds", file_names=["50_50_preds.pt", "0_10_preds.pt"], as_torch=True)
mem_set, non_mem_set = mem_nonmem_sets[0], mem_nonmem_sets[1]
## Build train/test loaders; presumably 1 memorized vs. 30 non-memorized paragraphs per batch
## with 20% held out for testing -- TODO confirm against dataLoaders.train_test_batching.
train_dl, test_dl = dataLoaders.train_test_batching(mem_set, non_mem_set, mem_batch=1, non_mem_batch=30, test_frac=0.2, add_bos=None)
## Take the first training batch: c_* is the "change" set and k_* the "keep" set used by the cells below.
c_toks_NI, k_toks_NI = next(iter(train_dl))
## Drop the leading dimension when it has size 1.
c_toks_NI, k_toks_NI = c_toks_NI.squeeze(0), k_toks_NI.squeeze(0)

## Alternative setup (disabled): load perturbed mem set and original mem set
#mem_perturbed_sets  = dataLoaders.load_pile_splits(f"{model_type}/perturbed", file_names=["mem_toks.pt", "perturbed_mem_toks.pt"], as_torch=True)
#mem_set, perturbed_mem_set = mem_perturbed_sets[0], mem_perturbed_sets[1]
#train_dl, test_dl = dataLoaders.train_test_batching(mem_set, perturbed_mem_set, mem_batch=20, non_mem_batch=20, test_frac=0.2, add_bos=None)
#c_toks_NI, c_perturb_toks_NI = next(iter(train_dl))
#c_toks_NI, c_perturb_toks_NI, = c_toks_NI.squeeze(0), c_perturb_toks_NI.squeeze(0)

In [None]:
def model_eval(model,c_NI:torch.LongTensor=None,k_NI:torch.LongTensor=None,I_range:list=[50,100], print_pred:bool=True):
    """
    Evaluate the language model on one change-set batch (c_NI) and one keep-set batch (k_NI).

    Both batches go through ``evaluation.evaluate_nll_greedy``; the greedy predictions are
    then scored with ``evaluation.compute_exact_match`` (``until_wrong=True``). A summary
    is printed, and the decoded predictions as well when ``print_pred`` is set.

    Args:
        model: model under evaluation; also used to decode predictions via ``to_string``.
        c_NI: tokens of the change set.
        k_NI: tokens of the keep set.
        I_range: ``[start, stop]`` slice of the last axis over which the NLL is averaged;
            ``stop - start`` is the exact-match value that counts as fully matched.
        print_pred: whether to print the decoded greedy predictions.

    Returns:
        ``((c_em_N, k_em_N), (c_mean_nll, k_mean_nll))`` -- the exact-match results of both
        sets and their mean NLL rounded to 4 decimals.
    """
    start, stop = I_range[0], I_range[1]
    full_match = int(stop - start)

    ## greedy decoding + NLL for both sets (the second NLL statistic is not used here)
    (c_nll_all, _), (c_NI_pred, c_NI_true) = evaluation.evaluate_nll_greedy(model, c_NI, batch_size=50)
    (k_nll_all, _), (k_NI_pred, k_NI_true) = evaluation.evaluate_nll_greedy(model, k_NI, batch_size=50)

    c_em_N = evaluation.compute_exact_match(c_NI_pred, c_NI_true, until_wrong=True)
    k_em_N = evaluation.compute_exact_match(k_NI_pred, k_NI_true, until_wrong=True)

    ## mean NLL restricted to the evaluated window
    c_mean_nll = round(c_nll_all[..., start:stop].mean().detach().item(), 4)
    k_mean_nll = round(k_nll_all[..., start:stop].mean().detach().item(), 4)

    ## counts (not fractions): change-set rows that are no longer fully matched,
    ## keep-set rows that are still fully matched
    c_is_full = c_em_N == full_match
    k_is_full = k_em_N == full_match
    c_changed_frac = torch.where(c_is_full, 0, 1).sum()
    k_kept_frac = torch.where(k_is_full, 1, 0).sum()

    print(f"---Greedy EM--- change set: {c_em_N.mean().item()} [changed {c_changed_frac}/{c_em_N.shape[0]}], keep set: {k_em_N.mean().item()} [kept {k_kept_frac}/{k_em_N.shape[0]}]")
    print(f"---Mean NLL--- change set: {c_mean_nll}, keep set: {k_mean_nll}\n\n")

    if print_pred:
        print(f"c_NI_pred: {model.to_string(c_NI_pred)}\n")
        print(f"k_NI_pred: {model.to_string(k_NI_pred)}")

    em_pair = (c_em_N, k_em_N)
    nll_pair = (c_mean_nll, k_mean_nll)
    return em_pair, nll_pair
            
## Older call signature, kept for reference:
#model_eval(model, c_toks_NI, c_orig_pred_NI, k_toks_NI, k_orig_pred_NI, I_range=[50,100])
## Baseline evaluation of the unmodified model on the change/keep batch.
## NOTE(review): this uses I_range=[50,100] while the post-intervention evaluations use [49,99] -- confirm which is intended.
(c_em_N, k_em_N), (c_mean_nll, k_mean_nll) = model_eval(model, c_toks_NI, k_toks_NI, I_range=[50,100])

## With Intervention

In [None]:
## Re-load the model with lr=0.0 and weight_decay=1.0 for the gradient pass
## (presumably these configure an optimizer inside load_model -- TODO confirm).
model = modelHandlers.load_model(model, lr=0.0, weight_decay=1.0)
## Contrastive metric between the change and keep sets, with perturbed inputs disabled.
metric_fn = functools.partial(gradient.contrast_metric, I_range=[49,99], use_perturb=False, c_set_norm=0.1)
## Component types for the forward/backward pass. The list previously named "W_V" twice and
## omitted "W_K"; it now names the same six components that are collected and intervened on below.
fwd_bwd_fn = functools.partial(gradient.run_contrastive_fwd_bwd, optim_steps=-1, topK=None, grad_norm=None, c_types=["W_K","W_Q","W_V","W_O","W_in","W_out"])
## c_toks_NI is passed twice: the second copy fills the perturbed-input slot, unused since use_perturb=False
## (presumably -- verify against gradient.run_contrastive_fwd_bwd).
fwd_cache, bwd_cache, topk_idcs = fwd_bwd_fn(model, metric_fn, c_toks_NI, c_toks_NI, k_toks_NI)

In [None]:
def c_type_collection(model, bwd_cache=None, c_types:list=["W_Q","W_K","W_V","W_O","W_in","W_out"]):
    """
    Collect the values of every requested component type into one mapping.

    Each component type is fetched with ``localizing.collect_c_type`` and stored detached.
    Nothing is summed or accumulated here: one entry per component type is returned.

    Args:
        model: model handed through to ``localizing.collect_c_type``.
        bwd_cache: cache handed to ``localizing.collect_c_type`` for activation component
            types ("q","k","v","o","mlp_in","mlp_out"); parameter component types
            ("W_Q","W_K","W_V","W_O","W_in","W_out") are collected with ``cache=None``.
        c_types: names of the component types to collect.

    Returns:
        dict mapping each requested component type to its detached values.

    Raises:
        ValueError: if a name in ``c_types`` is neither a known parameter nor a known
            activation component type.
    """
    param_types = ("W_Q","W_K","W_V","W_O","W_in","W_out")
    activation_types = ("q","k","v","o","mlp_in","mlp_out")

    ## a plain dict: the former defaultdict(torch.tensor) had a factory that cannot be
    ## called without arguments, so a missing key could never produce a usable default
    weight_gradients = {}
    for c_type in c_types:
        if c_type in param_types: ## model params
            c_vals, _ = localizing.collect_c_type(model=model, cache=None, c_type=c_type)
        elif c_type in activation_types: ## activation
            c_vals, _ = localizing.collect_c_type(model=model, cache=bwd_cache, c_type=c_type)
        else:
            raise ValueError(f"No eligible parameter or activation name passed: {c_type}")

        ## detach so that later selection / copying never touches the autograd graph
        weight_gradients[c_type] = c_vals.detach()
    return weight_gradients

## Collect all six parameter component types; bwd_cache is not passed because parameter types are read from the model.
c_weights = c_type_collection(model, c_types=["W_Q","W_K","W_V","W_O","W_in","W_out"])

In [None]:
def get_topK_grads(c_grads:dict, topK:int=100, select_c:list=None, select_l:list=None, select_heads:list=[], return_lk:bool=False, largest:bool=True, select_random:bool=False):
    """
    Select weights per component type, either by absolute value or at random.

    Every tensor in ``c_grads`` is treated as ``(n_layers, ...)``. Positions are reported as
    ``[layer, flat index within that layer]``, where the flat index always refers to the
    FULL per-layer tensor, so it can be unravelled against the layer's original shape.

    Args:
        c_grads: maps component type -> tensor whose first axis is the layer axis. The
            input tensors are not modified.
        topK: number of weights to select per component. A value strictly between 0 and 1
            is a fraction of that component's total number of elements, resolved separately
            for every component. The count is capped at the number of eligible weights.
        select_c: component types to process; ``None`` or empty means all.
        select_l: layers eligible for selection; ``None`` or empty means all.
        select_heads: heads eligible for selection. Only applied to tensors with at least
            4 dimensions, whose axis 1 is taken to be the head axis.
        return_lk: if True, return per-layer lists instead of tensors.
        largest: select the largest (True) or smallest (False) absolute values.
        select_random: if True, draw the weights uniformly at random from the eligible
            positions instead of ranking them (baseline).

    Returns:
        dict mapping component type -> ``{"idcs": ..., "scores": ...}``.
        With ``return_lk=False``: ``idcs`` is a ``(k, 2)`` long tensor of
        ``[layer, flat index]`` rows and ``scores`` a ``(k,)`` tensor (absolute values in
        ranking mode, raw values in random mode).
        With ``return_lk=True``: ``idcs`` and ``scores`` are lists with one list per layer.
    """
    if not c_grads:
        return {}

    ## (1) prepare components, layers and heads
    n_layers = list(c_grads.values())[0].shape[0]
    if select_l is None or len(select_l) == 0:
        select_l = list(range(n_layers))
    if select_c is None or len(select_c) == 0:
        select_c = list(c_grads.keys())
    heads = list(select_heads) if select_heads is not None else []

    ## (2) gather the selected weights per component
    c_top_grads = {}
    for c_type, c_vals in c_grads.items():
        if c_type not in select_c:
            continue
        c_vals = c_vals.detach()
        device = c_vals.device

        ## eligibility mask instead of slicing/zeroing: indices stay relative to the full
        ## tensor and excluded positions can never be picked (also not with largest=False)
        eligible = torch.zeros(c_vals.shape, dtype=torch.bool, device=device)
        eligible[torch.as_tensor(list(select_l), dtype=torch.long, device=device)] = True
        if len(heads) > 0 and c_vals.dim() >= 4:
            head_mask = torch.zeros(c_vals.shape[1], dtype=torch.bool, device=device)
            head_mask[torch.as_tensor(heads, dtype=torch.long, device=device)] = True
            eligible &= head_mask.view(1, -1, *([1] * (c_vals.dim() - 2)))

        flat_vals = c_vals.reshape(-1)
        pool = eligible.reshape(-1).nonzero(as_tuple=True)[0] ## flat positions that may be picked

        ## resolve the budget for THIS component; topK itself is left untouched so that a
        ## fraction is re-applied to every component
        if 0.0 < topK < 1.0:
            k = int(topK * flat_vals.numel())
        else:
            k = int(topK)
        k = max(0, min(k, pool.shape[0]))

        if select_random: ## random baseline among the eligible weights
            picks = torch.randperm(pool.shape[0], device=device)[:k]
            weight_idcs = pool[picks]
            weight_scores = flat_vals[weight_idcs]
        else: ## normal topK selection mode on absolute values
            weight_scores, picks = torch.topk(flat_vals[pool].abs(), k, largest=largest)
            weight_idcs = pool[picks]

        ## flat position -> [layer, flat index within layer]
        per_layer = max(flat_vals.numel() // max(c_vals.shape[0], 1), 1)
        layer_ids = torch.div(weight_idcs, per_layer, rounding_mode="floor")
        weight_idcs = torch.stack((layer_ids, weight_idcs - layer_ids * per_layer), dim=1)

        if return_lk: ## reformat the output to return layer-wise list of lists
            weight_ids_lk = [[] for _ in range(n_layers)]
            weight_scores_lk = [[] for _ in range(n_layers)]
            for (layer, flat_idx), score in zip(weight_idcs.tolist(), weight_scores.tolist()):
                weight_ids_lk[layer].append(flat_idx)
                weight_scores_lk[layer].append(score)
            c_top_grads[c_type] = {"idcs": weight_ids_lk, "scores": weight_scores_lk}
        else:
            c_top_grads[c_type] = {"idcs": weight_idcs, "scores": weight_scores}
    return c_top_grads


## Select 100 weights per component type over all layers, returned as per-layer index lists.
c_weights_lk = get_topK_grads(c_weights, topK=100, select_c=[], select_l=[], select_heads=[], return_lk=True)

In [None]:
def intervene_params(model, c_weights_lk:dict, std_fac:float=1.0):
    """
    Add Gaussian noise to the selected weights of the model parameters.

    The model is first passed through ``modelHandlers.load_model`` again (presumably to
    start from fresh weights -- TODO confirm against modelHandlers). A parameter is matched
    by the last part of its dotted name; the second part of the name is read as the layer
    index. The noise has mean 0 and standard deviation ``std(param) * std_fac``, where the
    standard deviation is taken over the whole parameter tensor.

    Args:
        model: model whose parameters are perturbed.
        c_weights_lk: maps component type -> ``{"idcs": per-layer lists of flat indices}``
            as returned by ``get_topK_grads(..., return_lk=True)``. Flat indices are
            unravelled against the shape of the matching parameter.
        std_fac: factor applied to the parameter's standard deviation.

    Returns:
        The perturbed model; ``model.cfg.intervention`` records the factor, the number of
        perturbed weights and the component types.
    """
    n_weights, layers = 0, []
    model = modelHandlers.load_model(model) ## reloading the model
    for name, param in model.named_parameters():
        name_parts = name.split(".")
        c_type = name_parts[-1]
        if c_type not in c_weights_lk:
            continue
        l = int(name_parts[1])
        weight_ids = torch.LongTensor(c_weights_lk[c_type]["idcs"][l])
        if len(weight_ids) == 0:
            continue

        ## flat per-layer indices -> one index column per parameter dimension
        multidim_ids = np.array(np.unravel_index(weight_ids.numpy(), tuple(param.shape))).T
        multidim_ids = torch.LongTensor(multidim_ids).to(param.device)

        with torch.no_grad():
            std = torch.std(param) * std_fac
            ## noise is created on the parameter's device so the update also works off-CPU
            set_vals = torch.normal(mean=0.0, std=torch.ones(multidim_ids.shape[0], device=param.device) * std)
            ## index with one column per dimension: works for any parameter rank, so no
            ## selected weight is counted without actually being perturbed
            param[tuple(multidim_ids.unbind(dim=1))] += set_vals
        n_weights += multidim_ids.shape[0]
        layers.append(l)
    print(f"intervened with {std_fac} on a total of {n_weights} weights in {list(c_weights_lk.keys())} in layers {set(layers)}")
    model.cfg.intervention = {"std":std_fac,"n_weights":n_weights,"c_types":list(c_weights_lk.keys())}
    return model

## Perturb the selected weights with noise of 0.1 x the parameter's standard deviation ...
model = intervene_params(model, c_weights_lk, std_fac=0.1)
## ... and re-evaluate on the same change/keep batch.
## NOTE(review): I_range is [49,99] here but [50,100] in the baseline evaluation -- confirm which is intended.
ck_em_after, ck_nll_after = model_eval(model, c_toks_NI, k_toks_NI, I_range=[49,99])

## Intervention (1): Loop over Different Intervention Values

In [None]:
def loop_vals_intervention(model, dl, c_weights:dict, val_space:tuple=(-5,5,5), topK:int=1000, select_c:list=[], select_l:list=[], select_heads:list=[]):
    """
    Sweep the noise factor of the intervention over a fixed selection of weights.

    The weights are selected once with ``get_topK_grads``; for every value on the linspace
    described by ``val_space`` the model is perturbed with ``intervene_params`` and
    evaluated with ``model_eval``.

    NOTE: the evaluation reads the notebook-level ``c_toks_NI`` / ``k_toks_NI`` batch;
    ``dl`` is accepted for interface compatibility but is not used.

    Args:
        model: model to intervene on.
        dl: unused (see note above).
        c_weights: maps component type -> tensor of values used to rank the weights.
        val_space: ``(start, end, steps)`` passed to ``torch.linspace``.
        topK: number (or fraction in (0, 1)) of weights to select per component.
        select_c: component types to intervene on; empty means all.
        select_l: layers to intervene on; empty means all.
        select_heads: heads to intervene on; empty means all. This used to be dropped
            silently and is now forwarded to ``get_topK_grads``.

    Returns:
        ``(ck_em_after, ck_nll_after, intervene_vals, model)``: mean exact match and mean
        NLL per sweep step (column 0 = change set, column 1 = keep set), the swept values,
        and the model returned by the last intervention.
    """
    c_weights_lk = get_topK_grads(c_weights, topK=topK, select_c=select_c, select_l=select_l, select_heads=select_heads, return_lk=True, select_random=False)
    intervene_vals = torch.linspace(val_space[0], val_space[1], steps=val_space[2])

    ck_em_after, ck_nll_after = [],[]
    for intervene_val in tqdm.tqdm(intervene_vals):
        model = intervene_params(model, c_weights_lk, std_fac=intervene_val)
        ck_em, ck_nll = model_eval(model, c_toks_NI, k_toks_NI, I_range=[49,99])
        ## reduce the per-sample exact-match results to one mean per set
        ck_em = (ck_em[0].mean(),ck_em[1].mean())
        ck_em_after.append(ck_em)
        ck_nll_after.append(ck_nll)
    ck_em_after, ck_nll_after = torch.tensor(ck_em_after), torch.tensor(ck_nll_after)
    return ck_em_after, ck_nll_after, intervene_vals, model


## Sweep 6 noise factors between 0.0 and 0.0075 on the top 0.1% (topK=0.001) weights of all six component types.
ck_em, ck_nll, intervene_vals, model = loop_vals_intervention(model, train_dl, c_weights, val_space=(0.0,0.0075,6), topK=0.001, select_c=["W_K","W_Q","W_V","W_O","W_in","W_out"], select_l=[], select_heads=[])

In [None]:
## Plot of the noise-factor sweep: exact match as triangles on the left axis, NLL as lines on a twin right axis.
fontsize = 11
fig, ax = plt.subplots(1, 1, figsize=(4.6, 1.5), gridspec_kw={'hspace': 0.4})

## Exact match per sweep step: column 0 = change set (red), column 1 = keep set (blue).
c_em = ax.scatter(intervene_vals, ck_em[:,0].numpy(), c="red", marker="v", label="change set em")
k_em = ax.scatter(intervene_vals, ck_em[:,1].numpy(), c="blue", marker="^", label="keep set em")
#ax.axvline(x=0, color='r', linestyle='-', c="black", alpha=0.2)
#ax.axhline(y=ck_em[0,1], color='r', linestyle='--', c="black", alpha=0.5)

## Mean NLL on a second y axis sharing the same x axis.
ax2 = ax.twinx()
c_nll, = ax2.plot(intervene_vals.numpy(), ck_nll[:,0].numpy(), c="red", alpha=0.3, label="memorized paragraphs")
k_nll, = ax2.plot(intervene_vals.numpy(), ck_nll[:,1].numpy(), c="blue", alpha=0.3, label="non-memorized paragraphs")
#ax.axhline(y=ck_nll[int(ck_nll.shape[0]/2),1], color='r', linestyle='--', c="blue", alpha=0.5)


## Keep the x ticks sparse; the locator is set on the x axis of the current axes.
locator = mpl.ticker.MaxNLocator(5)
plt.gca().xaxis.set_major_locator(locator)

## The title hard-codes "0.1%", which corresponds to topK=0.001 in the sweep call.
plot_summary = f"Intervening on 0.1% of max gradient weights" #{model.cfg.intervention['n_weights']}, {', '.join(model.cfg.intervention['c_types'])}
#plot_summary = f"Intervening on top {model.cfg.intervention['n_weights']} max gradient weights {model.cfg.intervention['c_types']} in layers: {', '.join((str(l) for l in model.cfg.intervention['layers']))}"
ax.set_title(plot_summary, fontsize=fontsize, loc="left", x=-0.1)
## The x values are the std_fac factors passed to intervene_params.
ax.set_xlabel('std of intervention noise', x=0.8, fontsize=fontsize)

ax.set_ylabel('EM (triangles)', fontsize=fontsize) #\bullet
ax2.set_ylabel('NLL (lines)', fontsize=fontsize) #\searrow
#legend = ax.legend(handles=[c_nll, k_nll], frameon=False, bbox_to_anchor=(0.2, -0.15), ncol=2, prop={'size': fontsize}, handlelength=0)
#for text, color in zip(legend.get_texts(), ["red", "blue"]):
#    text.set_color(color) 

## Hand-placed colour key below the axes (axes coordinates) instead of a legend.
ax.text(-0.15, -0.31, f'memorized', color="red", fontsize=fontsize-1, horizontalalignment='left',verticalalignment='center', transform=ax.transAxes)
ax.text(0.1, -0.32, f'non-memorized', color="blue", fontsize=fontsize-1, horizontalalignment='left',verticalalignment='center', transform=ax.transAxes)
    
## Write the figure into the results directory below dataLoaders.ROOT.
fig.savefig(f"{dataLoaders.ROOT}/results/noise_intervention_std.pdf", dpi=200, bbox_inches="tight")

## Intervention (2): Loop over Different Numbers of Intervened Weights

In [None]:
def loop_topK_intervention(model, dl, c_weights:dict, topK_space:tuple=(-5,5,5), intervene_val:int=0, select_c:list=[], select_l:list=[], select_heads:list=[], largest:bool=True):
    """
    Sweep the number of intervened weights at a fixed noise factor.

    For every value on the linspace described by ``topK_space`` (truncated to an int), the
    weights are re-selected with ``get_topK_grads``, perturbed with ``intervene_params``
    and the model is evaluated with ``model_eval`` on the notebook-level
    ``c_toks_NI`` / ``k_toks_NI`` batch. ``dl`` is not used.

    Returns:
        ``(ck_em_after, ck_nll_after, topK_vals, model)``: mean exact match and mean NLL
        per sweep step (change set, keep set), the swept values, and the model returned
        by the last intervention.
    """
    sweep_start, sweep_end, sweep_steps = topK_space[0], topK_space[1], topK_space[2]
    topK_vals = torch.linspace(sweep_start, sweep_end, steps=sweep_steps)

    em_rows = []
    nll_rows = []
    for budget in tqdm.tqdm(topK_vals):
        ## fresh selection for this budget
        selection = get_topK_grads(
            c_weights,
            topK=int(budget),
            select_c=select_c,
            select_l=select_l,
            select_heads=select_heads,
            return_lk=True,
            largest=largest,
            select_random=False,
        )
        model = intervene_params(model, selection, std_fac=intervene_val)
        em_pair, nll_pair = model_eval(model, c_toks_NI, k_toks_NI, I_range=[49,99])
        ## one mean exact-match value per set
        em_rows.append((em_pair[0].mean(), em_pair[1].mean()))
        nll_rows.append(nll_pair)

    return torch.tensor(em_rows), torch.tensor(nll_rows), topK_vals, model

## Fixed noise factor for this sweep (passed on as std_fac to intervene_params).
intervene_val = 0.2
## Sweep 4 weight budgets between 0 and 100, restricted to W_V, layer 1, head 2.
ck_em, ck_nll, topK_vals, model = loop_topK_intervention(model, train_dl, c_weights, topK_space=(0,100,4), intervene_val=intervene_val, select_c=["W_V"], select_l=[1], select_heads=[2])

In [None]:
## Plot of the weight-budget sweep: exact match as triangles on the left axis, NLL as lines on a twin right axis.
fontsize = 9
fig, ax = plt.subplots(1, 1, figsize=(5, 1.5), gridspec_kw={'hspace': 0.4})

## Exact match per sweep step: column 0 = change set (red), column 1 = keep set (blue).
c_em = ax.scatter(topK_vals, ck_em[:,0].numpy(), c="red", marker="v", label="change set em")
k_em = ax.scatter(topK_vals, ck_em[:,1].numpy(), c="blue", marker="^", label="keep set em")
#ax.axhline(y=ck_nll[0,1], color='r', linestyle='--', c="blue", alpha=0.5)

## Mean NLL on a second y axis sharing the same x axis.
ax2 = ax.twinx()
c_nll, = ax2.plot(topK_vals, ck_nll[:,0].numpy(), c="red", alpha=0.3, label="memorized")
k_nll, = ax2.plot(topK_vals, ck_nll[:,1].numpy(), c="blue", alpha=0.3, label="non-memorized")

## Title built from the swept range, the component types recorded by the last intervention and the noise factor.
plot_summary = f"Intervening on top {int(topK_vals[0])} to {int(topK_vals[-1])} max gradient weights of {', '.join(model.cfg.intervention['c_types'])} with std of {intervene_val}"
#plot_summary = f"Intervening on top {model.cfg.intervention['n_weights']} max gradient weights {model.cfg.intervention['c_types']} in layers: {', '.join((str(l) for l in model.cfg.intervention['layers']))}"
## Pad the x range by 2 on both sides so the first and last markers are not clipped.
ax.set_xlim(topK_vals[0]-2,topK_vals[-1]+2)
ax.set_title(plot_summary, fontsize=fontsize, loc="left", x=-0.1)
ax.set_xlabel('number of intervention weights', x=0.7, fontsize=fontsize)
ax.set_ylabel('EM (triangles)', fontsize=fontsize) #\bullet
ax2.set_ylabel('NLL (lines)', fontsize=fontsize) #\searrow
#legend = ax.legend(handles=[c_nll, k_nll], frameon=False, bbox_to_anchor=(0.2, -0.15), ncol=2, prop={'size': fontsize}, handlelength=0)
#for text, color in zip(legend.get_texts(), ["red", "blue"]):
#    text.set_color(color) 

## Hand-placed colour key below the axes (axes coordinates) instead of a legend.
ax.text(0.0, -0.31, f'memorized', color="red", fontsize=fontsize-1, horizontalalignment='left',verticalalignment='center', transform=ax.transAxes)
ax.text(0.2, -0.31, f'non-memorized', color="blue", fontsize=fontsize-1, horizontalalignment='left',verticalalignment='center', transform=ax.transAxes)
    
    
## Write the figure into the results directory below dataLoaders.ROOT.
fig.savefig(f"{dataLoaders.ROOT}/results/noise_intervention_n_weights.pdf", dpi=200, bbox_inches="tight")