# Optimizer Step: Integrate Localization and Intervention

In [1]:
#@title Import libraries
import transformer_lens
import torch, gc, itertools, functools, tqdm
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

## hide the right and top axis spines on all matplotlib figures
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

import sys
## NOTE(review): hard-coded absolute path so the local `paraMem` package is importable;
## this only works on the original machine — consider a configurable project root instead
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, gradient, evaluation, localizing, intervening

## Model

In [2]:
## load gpt-neo-125M as a HookedTransformer on the GPU; per the printed output, load_model
## disables gradients for embeddings, unembedding and all biases.
## NOTE(review): lr=0.0 presumably means this instance is only used for baseline generation — confirm
model = modelHandlers.load_model(model_type="gpt-neo-125M", DEVICE="cuda", lr=0.0, weight_decay=0.01)

Loaded pretrained model gpt-neo-125M into HookedTransformer
setting no_grad on ['embed', 'pos_embed', 'unembed', 'b_in', 'b_out', 'b_K', 'b_Q', 'b_V', 'b_O']


## Data

In [3]:
## (disabled alternative) mem and non-mem set with a train/test split
#(mem_prompts, mem_counts),(non_mem_prompts,non_mem_counts) = dataLoaders.load_pile_splits("acc/gpt-neo-125M", as_torch=True)
#train_dl, test_dl = dataLoaders.train_test_batching(mem_prompts, non_mem_prompts, mem_batch=5, non_mem_batch=5, test_frac=0.0, set_twice=None)
#c_toks_NI, k_toks_NI = next(iter(train_dl))

## load perturbed mem set and original mem set.
## The first returned set becomes the change set (c_*) and the second the keep set (k_*) below.
## NOTE(review): confirm which of the two returned sets is the perturbed one
mem_prompts, non_mem_prompts = dataLoaders.load_perturbed_mem(file_path="acc/gpt-neo-125M")
## pair each change-set prompt with a keep-set prompt; shuffle=False makes the first batch deterministic
train_dl = torch.utils.data.DataLoader(list(zip(mem_prompts, non_mem_prompts)), batch_size=3, shuffle=False)
## first batch of 3 (change, keep) pairs
c_toks_NI, k_toks_NI = next(iter(train_dl))

In [4]:
## first batch again (train_dl uses shuffle=False, so this is the same batch as in the data cell)
c_toks_NI, k_toks_NI = next(iter(train_dl))
## squeeze(0) only drops a leading dimension of size 1; it is a no-op for batch_size > 1
c_toks_NI, k_toks_NI = c_toks_NI.squeeze(0), k_toks_NI.squeeze(0)
## baseline greedy continuations of the first 50 tokens, generated before any editing.
## stop_at_eos=False guarantees exactly 50 new tokens per sequence, so the [50:100] span
## compared during evaluation always has the same length as a re-generation with stop_at_eos=False.
c_orig_pred_NI = model.generate(input=c_toks_NI[:,:50], stop_at_eos=False, max_new_tokens=50, do_sample=False)
k_orig_pred_NI = model.generate(input=k_toks_NI[:,:50], stop_at_eos=False, max_new_tokens=50, do_sample=False)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [5]:
def model_eval(model,c_NI:torch.LongTensor=None,c_orig_pred_NI:torch.LongTensor=None,k_NI:torch.LongTensor=None,k_orig_pred_NI:torch.LongTensor=None,I_range:list=[50,100], print_pred:bool=True):
    """
    Evaluate the language model on one batch of the change set (c_NI) and keep set (k_NI).

    The first ``I_range[0]`` tokens of every sequence are used as the prompt and
    ``I_range[1] - I_range[0]`` tokens are greedily generated (without stopping at EOS).
    The span ``[I_range[0]:I_range[1]]`` of the generated sequences is compared with the
    same span of the reference predictions via exact match, and the mean NLL over that
    span is reported for both sets.

    Args:
        model: model exposing ``generate`` and ``to_string``.
        c_NI: change-set tokens, presumably (batch, seq) — confirm.
        c_orig_pred_NI: reference greedy outputs for c_NI as returned by ``model.generate``.
            If None, exact match for BOTH sets is computed with ``until_wrong=True``, and
            the change set is scored from the argmax predictions of evaluate_nll_greedy.
        k_NI: keep-set tokens.
        k_orig_pred_NI: reference greedy outputs for k_NI (required).
        I_range: [start, end) token positions that are generated and compared.
        print_pred: if True, print the decoded predictions.

    Raises:
        ValueError: if ``k_orig_pred_NI`` is None.
    """
    if k_orig_pred_NI is None:
        raise ValueError("k_orig_pred_NI is required to evaluate the keep set")
    ## prompt length and generation length are both derived from I_range so the
    ## compared span [I_range[0]:I_range[1]] is exactly the generated continuation
    prompt_len, n_new_tokens = I_range[0], I_range[1] - I_range[0]
    until_wrong = c_orig_pred_NI is None

    ## change set
    (c_mean_nll, _), (c_NI_pred, c_NI_true) = evaluation.evaluate_nll_greedy(model, c_NI, batch_size=50)
    if c_orig_pred_NI is not None:
        c_NI_pred = model.generate(input=c_NI[:, :prompt_len], stop_at_eos=False, max_new_tokens=n_new_tokens, do_sample=False)
        c_NI_pred = c_NI_pred[..., I_range[0]:I_range[1]].to("cpu")
        c_orig_pred_NI = c_orig_pred_NI[..., I_range[0]:I_range[1]].to("cpu")
        c_em_N = evaluation.compute_exact_match(c_NI_pred, c_orig_pred_NI, until_wrong=False)
    else: ## argmax greedy decoding
        print("argmax greedy decoding on c_NI")
        c_em_N = evaluation.compute_exact_match(c_NI_pred, c_NI_true, until_wrong=True)

    ## keep set: generated with the same settings as the change set so both are comparable
    (k_mean_nll, _), (_, _) = evaluation.evaluate_nll_greedy(model, k_NI, batch_size=50)
    k_NI_pred = model.generate(input=k_NI[:, :prompt_len], stop_at_eos=False, max_new_tokens=n_new_tokens, do_sample=False)
    k_NI_pred = k_NI_pred[..., I_range[0]:I_range[1]].to("cpu")
    k_orig_pred_NI = k_orig_pred_NI[..., I_range[0]:I_range[1]].to("cpu")
    k_em_N = evaluation.compute_exact_match(k_NI_pred, k_orig_pred_NI, until_wrong=until_wrong)

    ## process change and keep set: mean NLL over the evaluated span
    c_mean_nll = round(c_mean_nll[..., I_range[0]:I_range[1]].mean().detach().item(), 4)
    k_mean_nll = round(k_mean_nll[..., I_range[0]:I_range[1]].mean().detach().item(), 4)

    ## counts (not fractions): change-set sequences that no longer fully match,
    ## keep-set sequences that still fully match
    c_changed_frac = torch.where(c_em_N == int(n_new_tokens), 0, 1).sum()
    k_kept_frac = torch.where(k_em_N == int(n_new_tokens), 1, 0).sum()

    print(f"---Greedy EM--- change set: {c_em_N.mean().item()} [changed {c_changed_frac}/{c_em_N.shape[0]}], keep set: {k_em_N.mean().item()} [kept {k_kept_frac}/{k_em_N.shape[0]}]")
    print(f"---Mean NLL--- change set: {c_mean_nll}, keep set: {k_mean_nll}\n\n")

    if print_pred:
        print(f"c_NI_pred: {model.to_string(c_NI_pred)}\n")
        print(f"k_NI_pred: {model.to_string(k_NI_pred)}")
            
## baseline evaluation before any editing: the unchanged model is compared with its own
## greedy predictions, so both sets should be fully matched here
model_eval(model, c_toks_NI, c_orig_pred_NI, k_toks_NI, k_orig_pred_NI, I_range=[50,100])
## alternative: score the change set with argmax predictions against the true tokens
#model_eval(model, c_toks_NI, None, k_toks_NI, k_orig_pred_NI, I_range=[50,100])

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

---Greedy EM--- change set: 50.0 [changed 0/3], keep set: 50.0 [kept 3/3]
---Mean NLL--- change set: 0.1491, keep set: 0.5782


c_NI_pred: [". I consent to the collection, use, maintenance, and disclosure of my information in accordance with the Postmedia's Privacy Policy.\n\nPostmedia wants to improve your reading experience as well as share the best deals and promotions from our advertisers with you", '       \\setlength{\\oddsidemargin}{-69pt}\n                \\begin{document}$$\\begin{aligned} \\mathbf', '_ISA_USAGE = YES_ERROR;\n\t\t\t\tCLANG_WARN_EMPTY_BODY = YES;\n\t\t\t\tCLANG_WARN_ENUM_CONVERSION = YES;\n\t\t\t']

k_NI_pred: [' of Disclaimer and Privacy Policy and Terms of Use.\n\nPostmedia wants to improve the user experience on our site by providing it with a broad range of information and marketing information that is tailored to meet the specific needs of its users. Postmedia', '  \\setlength{\\oddsidemargin}{-69pt}\n           \\begin{document}$$\\begin{aligned} \\mathbf

In [6]:
def find_topK_grads(model, topK:float=0.001, c_types:list=["W_in", "W_out", "W_K", "W_V", "W_Q", "W_O"]):
    """
    Find the gradient-magnitude threshold that selects the top-K weights.

    Only parameters whose attribute name (last component of the dotted parameter
    name) is in ``c_types`` and that currently hold a gradient are considered.

    Args:
        model: module whose ``.grad`` tensors were populated by a backward pass.
        topK: if ``0 < topK < 1``, the fraction of candidate weights to select
            (at least one weight is always selected); if ``topK >= 1``, the absolute
            number of weights to select (at most the number of candidate weights).
        c_types: parameter-name suffixes to include.

    Returns:
        (min_grad, topk_idcs_flat): the smallest absolute gradient among the selected
        weights, and the indices of the selected weights in the concatenation of the
        flattened candidate gradients.

    Raises:
        ValueError: if no candidate parameter has a gradient, or topK is out of range.
    """
    ## (1) collect |grad| of all candidate params into one flat vector
    candidate_grads = [param.grad.detach().abs().reshape(-1)
                       for name, param in model.named_parameters()
                       if name.split(".")[-1] in c_types and param.grad is not None]
    if not candidate_grads:
        raise ValueError(f"no gradients found for parameter types {c_types}")
    all_param_grads = torch.cat(candidate_grads)
    n_weights = all_param_grads.numel()

    ## (2) identify top params (sparsity)
    if 0.0 < topK < 1.0: ## percentage; never select zero weights
        k = max(1, int(topK * n_weights))
    elif 1.0 <= topK <= n_weights: ## pick top weights
        k = int(topK)
    else:
        raise ValueError(f"topK must be in (0, 1) or [1, {n_weights}], got {topK}")
    topk_vals_flat, topk_idcs_flat = torch.topk(all_param_grads, k=k, largest=True)
    min_grad = torch.min(topk_vals_flat)
    print(f"{len(topk_idcs_flat)} weights in {c_types} with grads > {min_grad.item()}")
    return min_grad, topk_idcs_flat

def clip_grads(model, min_grad:float=None, full_remove_idcs:list=None): ## in-place
    """
    Zero out small gradients in-place, by threshold (record mode) or by replaying stored masks.

    Record mode (``min_grad`` is not None): for every parameter with ``requires_grad``,
    gradient entries with ``-min_grad <= grad <= min_grad`` are set to zero. The boolean
    mask of zeroed entries is appended to ``full_remove_idcs``; a fresh list is created
    when it is None.

    Replay mode (``min_grad`` is None): the masks in ``full_remove_idcs`` are applied,
    one per ``requires_grad`` parameter in ``model.parameters()`` order, so the same
    entries are zeroed again (e.g. on later optimizer steps).

    Args:
        model: module whose ``.grad`` tensors are edited in-place.
        min_grad: absolute gradient threshold, or None to replay stored masks.
        full_remove_idcs: list of masks to replay, or list to append new masks to.

    Returns:
        The list of masks. A ``None`` entry stands for a trainable parameter that had
        no gradient when the masks were recorded.

    Raises:
        ValueError: in replay mode, if the number of masks differs from the number of
            trainable parameters.
    """
    if full_remove_idcs is None:
        ## fresh list per call: a shared mutable default would accumulate masks across
        ## calls, and replay mode would then apply stale masks from earlier runs
        full_remove_idcs = []
    print(f"clipped at {min_grad} or using full_remove_idcs on {len(full_remove_idcs)} modules")
    trainable = [param for param in model.parameters() if param.requires_grad]

    if min_grad is not None: ## record mode
        for param in trainable:
            if param.grad is None:
                full_remove_idcs.append(None) ## keep masks aligned with `trainable`
                continue
            ## annul small positive and negative grads (everything not strictly outside [-min_grad, min_grad])
            remove_mask = ~((param.grad > min_grad) | (param.grad < -min_grad))
            param.grad[remove_mask] = 0.0
            full_remove_idcs.append(remove_mask)
        return full_remove_idcs

    ## replay mode
    if len(full_remove_idcs) != len(trainable):
        raise ValueError(f"expected {len(trainable)} masks (one per trainable parameter), got {len(full_remove_idcs)}")
    for param, remove_mask in zip(trainable, full_remove_idcs):
        if remove_mask is None or param.grad is None:
            continue
        param.grad[remove_mask.bool()] = 0.0
    return full_remove_idcs
                
                
#min_grad, topk_idcs_flat = find_topK_grads(model, topK=0.01)
#full_remove_idcs = clip_grads(model, min_grad)
#full_remove_idcs = clip_grads(model, min_grad=None, full_remove_idcs=full_remove_idcs)

In [29]:
def pool_tensor(orig_tensor:torch.tensor, dims:list, match_size:tuple=None):
    """
    Mean-pool a tensor over each dimension in ``dims`` (in order) and broadcast back.

    After averaging over a dimension, the averaged values are repeated along that
    dimension, either to its current length (``match_size`` is None) or to
    ``match_size[dim]``. Returns ``orig_tensor`` unchanged when ``dims`` is None or empty.
    """
    if dims is None:
        return orig_tensor
    pooled = orig_tensor
    for dim in dims:
        ## repeat count: keep the current length, or expand to the requested target size
        n_repeat = pooled.shape[dim] if match_size is None else match_size[dim]
        pooled = pooled.mean(dim, keepdim=True).repeat_interleave(n_repeat, dim=dim)
    return pooled


def contrast_metric(c_nll:torch.tensor, k_nll:torch.tensor=None, k_nll_fixed:torch.tensor=None, I_range:list=[0,100], norm_sets:float=None, pool:dict={"c": [-1], "k": [0,-1]}, only_set:str=None, mse_weight:float=1000.0):
    """
    Contrastive loss: raise the change-set NLL while preserving the keep set.

    Minimizing the result increases ``c_nll``. For the keep set it either decreases
    ``k_nll`` or, when ``k_nll_fixed`` is given, keeps ``k_nll`` close to those
    reference values.

    Args:
        c_nll: per-token NLL of the change set; token positions on the last dim.
        k_nll: per-token NLL of the keep set. May be None only when ``only_set == "c"``.
        k_nll_fixed: reference keep-set NLL. If given, the keep term becomes
            ``mse_weight * (k_nll - k_nll_fixed)**2`` instead of ``k_nll``.
        I_range: [start, end) slice of token positions (last dim) the loss uses.
        norm_sets: if a number, rescale c_nll so that its (detached) sum equals
            ``norm_sets`` times the (detached) sum of k_nll.
        pool: dims to mean-pool and re-broadcast for each set ("c" and "k"), see pool_tensor.
        only_set: "c" or "k" to use only one set (the other is replaced by zeros);
            None for the normal contrastive loss.
        mse_weight: weight of the keep-set squared-error term (only with k_nll_fixed).

    Returns:
        (contrast_res, None): scalar loss tensor and an unused placeholder.

    Raises:
        ValueError: for an unknown ``only_set``, or a missing ``k_nll`` when it is needed.
    """
    if only_set not in (None, "c", "k"):
        raise ValueError(f"only_set must be None, 'c' or 'k', got {only_set!r}")
    if only_set != "c" and k_nll is None:
        raise ValueError("k_nll is required unless only_set='c'")

    ## (0) option for making one set zero for baseline tests; zeros_like keeps device and
    ## dtype of the other set and carries no gradient
    if only_set == "c":
        k_nll = torch.zeros_like(c_nll)
    elif only_set == "k":
        c_nll = torch.zeros_like(k_nll)

    ## (1) preprocess: select the token positions to apply the metric to
    c_nll, k_nll = c_nll[..., I_range[0]:I_range[1]], k_nll[..., I_range[0]:I_range[1]]

    ## (2) pooling: pool over dims but then expand again to retain shapes
    c_nll = pool_tensor(c_nll, pool["c"], match_size=k_nll.shape)
    k_nll = pool_tensor(k_nll, pool["k"], match_size=None)
    print(f"pooling c_nll {c_nll.shape}, k_nll {k_nll.shape} pool: {pool}")

    ## (3) apply metric
    if only_set == "c":
        contrast_res = c_nll.mean()
    elif only_set == "k":
        contrast_res = k_nll.mean()
    else: ## normal case
        if isinstance(norm_sets, (int, float)) and not isinstance(norm_sets, bool):
            ## balance the loss terms: scale c_nll so its sum is norm_sets * sum(k_nll)
            c_nll = c_nll * (norm_sets * (k_nll.detach().sum() / c_nll.detach().sum()))
        if k_nll_fixed is not None:
            ## squared-error version to enforce a non-changing keep-set NLL
            k_nll_fixed = k_nll_fixed[..., I_range[0]:I_range[1]]
            k_mse = mse_weight * (k_nll - k_nll_fixed.detach())**2
            contrast_res = (k_mse - c_nll).mean()
            print(f"contrast loss: {contrast_res}, c_nll: {c_nll.detach().mean()}, k_nll_mse: {k_mse.detach().mean()}")
        else:
            contrast_res = (k_nll - c_nll).mean()
            print(f"contrast loss: {contrast_res}, c_nll: {c_nll.detach().mean()}, k_nll: {k_nll.detach().mean()}")
    return contrast_res, None


In [39]:
def run_fwd_bwd(model, metric_fn, c_toks_NI:torch.LongTensor=None, k_toks_NI:torch.LongTensor=None, optim_steps:int=-1, topK:float=None, grad_norm:float=None, c_types:list=None, pass_k_nll_fixed:bool=False):
    """
    Run forward/backward passes on the change and keep set and optionally take optimizer steps.

    Each of the ``abs(optim_steps)`` iterations does the following:
      * computes per-token NLLs of both sets;
      * evaluates ``metric_fn(c_nll, k_nll)`` and backpropagates;
      * optionally clips the gradient norm;
      * optionally zeroes all but the top-K gradients;
      * takes an optimizer step when ``optim_steps >= 1`` and the model has an ``optim``
        attribute.
    With a negative ``optim_steps`` no step is taken, and the gradients of the last
    iteration are left in ``param.grad``.

    Args:
        model: model exposing ``cfg.device`` and, for updates, ``optim``.
        metric_fn: callable ``(c_nll, k_nll, **kw) -> (loss, _)``.
        c_toks_NI: change-set token batch.
        k_toks_NI: keep-set token batch.
        optim_steps: number of iterations; only positive values take optimizer steps.
        topK: if given, gradients are localized on the first iteration with
            find_topK_grads, and the same entries are zeroed on every iteration.
        grad_norm: if given, max L2 norm for gradient-norm clipping.
        c_types: parameter types passed to find_topK_grads (its default when None).
        pass_k_nll_fixed: if True, the keep-set NLL before any update is computed once
            (without building a graph) and passed to metric_fn as ``k_nll_fixed``.
    """
    c_toks_NI = c_toks_NI.to(model.cfg.device)
    k_toks_NI = k_toks_NI.to(model.cfg.device)

    def _token_nll(toks):
        ## per-token negative log-likelihood of `toks` under the current model
        return modelHandlers.gather_token_scores(modelHandlers.NegLogLik(model(toks)), toks)

    metric_kwargs = {}
    if pass_k_nll_fixed:
        ## reference keep-set NLL; no autograd graph is kept alive for it
        with torch.no_grad():
            metric_kwargs["k_nll_fixed"] = _token_nll(k_toks_NI)

    topk_kwargs = {} if c_types is None else {"c_types": c_types}
    full_remove_idcs = []
    for step_i in range(abs(optim_steps)):
        c_nll = _token_nll(c_toks_NI)
        k_nll = _token_nll(k_toks_NI)
        metric_res, _ = metric_fn(c_nll, k_nll, **metric_kwargs)

        model.zero_grad()
        metric_res.backward(retain_graph=False)

        if grad_norm is not None:
            print(f"applied grad norm clipping with max norm {grad_norm}")
            torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_norm), norm_type=2.0)

        if topK is not None:
            if step_i == 0:
                ## localize once on the first iteration's gradients
                min_grad, _ = find_topK_grads(model, topK=topK, **topk_kwargs)
                ## explicit fresh list so masks from earlier calls can never be replayed here
                full_remove_idcs = clip_grads(model, min_grad, full_remove_idcs=[])
            full_remove_idcs = clip_grads(model, min_grad=None, full_remove_idcs=full_remove_idcs)

        if optim_steps >= 1 and hasattr(model, 'optim'):
            print(f"{step_i+1}/{abs(optim_steps)}, optimizer step")
            model.optim.step()
            model.optim.zero_grad()

    ## drop local references to the device copies of the tokens before releasing cached memory
    del c_toks_NI
    del k_toks_NI
    torch.cuda.empty_cache()

## reset the model and attach an optimizer (see the printed "reset model" / "added optimizer" lines).
## NOTE(review): weight_decay=1.0 is unusually large — confirm this is intended
model = modelHandlers.load_model(model, lr=1e-05, weight_decay=1.0)
## contrast loss on NLL positions 49..98 with no pooling.
## NOTE(review): evaluation uses token positions [50,100]; the offset by one presumably reflects
## the shifted per-token NLL — confirm against modelHandlers.gather_token_scores
metric_fn = functools.partial(contrast_metric, I_range=[49,99], pool={"c":[],"k":[]}, norm_sets=None)
## 5 optimizer steps; with topK=None there is no gradient localization, so c_types is unused here
fwd_bwd_fn = functools.partial(run_fwd_bwd, optim_steps=5, topK=None, grad_norm=None, c_types=["W_in","W_out","W_K","W_Q","W_V"])
fwd_bwd_fn(model, metric_fn, c_toks_NI, k_toks_NI)

reset model gpt-neo-125M
Loaded pretrained model gpt-neo-125M into HookedTransformer
added optimizer with lr: 1e-05 and weight_decay: 1.0
setting no_grad on ['embed', 'pos_embed', 'unembed', 'b_in', 'b_out', 'b_K', 'b_Q', 'b_V', 'b_O']
pooling c_nll torch.Size([3, 50]), k_nll torch.Size([3, 50]) pool: {'c': [], 'k': []}
contrast loss: 0.4332226812839508, c_nll: 0.14667823910713196, k_nll: 0.5799009203910828
1/5, optimizer step
pooling c_nll torch.Size([3, 50]), k_nll torch.Size([3, 50]) pool: {'c': [], 'k': []}
contrast loss: 0.3361014425754547, c_nll: 0.1701779067516327, k_nll: 0.5062793493270874
2/5, optimizer step
pooling c_nll torch.Size([3, 50]), k_nll torch.Size([3, 50]) pool: {'c': [], 'k': []}
contrast loss: 0.2607945501804352, c_nll: 0.19398583471775055, k_nll: 0.4547803997993469
3/5, optimizer step
pooling c_nll torch.Size([3, 50]), k_nll torch.Size([3, 50]) pool: {'c': [], 'k': []}
contrast loss: 0.18579351902008057, c_nll: 0.23012681305408478, k_nll: 0.41592031717300415
4/5

In [40]:
## re-evaluate the edited model against the pre-edit greedy predictions
## (c_orig_pred_NI / k_orig_pred_NI were generated before training)
model_eval(model, c_toks_NI, c_orig_pred_NI, k_toks_NI, k_orig_pred_NI, I_range=[50,100])
## alternative: score the change set with argmax predictions against the true tokens
#model_eval(model, c_toks_NI, None, k_toks_NI, k_orig_pred_NI, I_range=[50,100])

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

---Greedy EM--- change set: 19.66666603088379 [changed 2/3], keep set: 32.66666793823242 [kept 1/3]
---Mean NLL--- change set: 0.3563, keep set: 0.3621


c_NI_pred: ['.\n\nPostmedia wants to improve the user experience on our site by providing it with the latest user-friendly content, tools and functionality. This includes offering new features and premium content that add new functionality and are offered by third parties. Postmedia', '       \\setlength{\\oddsidemargin}{-69pt}\n                \\begin{document}$$\\begin{aligned} \\mathbf', '_EXCEPTIONS = YES;\n\t\t\t\tCLANG_WARN_DOCUMENTATION_COMMENTS = YES;\n\t\t\t\tCLANG_WARN_EMPTY_BODY = YES;\n\t\t\t\t']

k_NI_pred: [' of Disclaimer and Privacy Policy and Terms of Use.\n\nPostmedia wants to improve the user experience on our site by providing it with the latest and greatest features by improving user experience by removing links to content that violates the Terms of Service and the', '  \\setlength{\\oddsidemargin}{-69pt}\n       

## Test Set

In [None]:
## FIXME: `test_dl` is not defined anywhere in this notebook (its only construction is in a
## commented-out line of the data cell), so this cell raises NameError on a fresh kernel.
c_test_NI, k_test_NI = next(iter(test_dl))
c_test_NI, k_test_NI = c_test_NI.squeeze(0), k_test_NI.squeeze(0)
## FIXME: these "original" predictions are generated by the already-edited model, so model_eval
## below compares the edited model with itself; the baseline must come from the unedited model.
c_orig_test_NI = model.generate(input=c_test_NI[:,:50], max_new_tokens=50, do_sample=False)
k_orig_test_NI = model.generate(input=k_test_NI[:,:50], max_new_tokens=50, do_sample=False)
model_eval(model, c_test_NI, c_orig_test_NI, k_test_NI, k_orig_test_NI, I_range=[50,100])