# Optimizer Step: Integrate Localization and Intervention

In [None]:
#@title Import libraries
import transformer_lens
import torch, gc, itertools, functools, tqdm
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

import sys
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, gradient, evaluation, localizing, intervening

## Model and Data

In [120]:
model = modelHandlers.load_model(model_type="gpt2-medium", DEVICE="cpu", lr=0.0, weight_decay=0.01)

Loaded pretrained model gpt2-medium into HookedTransformer


## Data

In [128]:
train_dl, test_dl = dataLoaders.batched_pile(mem_batch=1, non_mem_batch=5, test_frac=0.0, load_data="acc/gpt2-medium", set_twice=None)
#train_dl = dataLoaders.batched_pop_seqs(model, mem_batch=1, non_mem_batch=50)
c_toks_NI, k_toks_NI = next(iter(train_dl))
c_toks_NI, k_toks_NI = c_toks_NI.squeeze(0), k_toks_NI.squeeze(0)

In [129]:
c_orig_pred_NI = model.generate(input=c_toks_NI[:,:50], max_new_tokens=50, do_sample=False)
k_orig_pred_NI = model.generate(input=k_toks_NI[:,:50], max_new_tokens=50, do_sample=False)

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

In [132]:
def model_eval(model,c_NI:torch.LongTensor=None,c_orig_pred_NI:torch.LongTensor=None,k_NI:torch.LongTensor=None,k_orig_pred_NI:torch.LongTensor=None,I_range:list=[50,100], print_pred:bool=True):
    """
    evaluate the language model on individual batches of c_toks_NI and k_toks_NI
    """
    ## change set
    (c_mean_nll, c_minK_nll), (_, _) = evaluation.evaluate_nll_greedy(model, c_NI, batch_size=50)
    c_NI_pred = model.generate(input=c_NI[:,:50], max_new_tokens=50, do_sample=False)
    c_em_N = evaluation.compute_exact_match(c_NI_pred[...,-(I_range[1]-I_range[0]):], c_orig_pred_NI[:,-(I_range[1]-I_range[0]):], until_wrong=False)
    
    ## keep set
    (k_mean_nll, k_minK_nll), (_, _) = evaluation.evaluate_nll_greedy(model, k_NI, batch_size=50)
    k_NI_pred = model.generate(input=k_NI[:,:50], max_new_tokens=I_range[1]-I_range[0], do_sample=False)
    k_em_N = evaluation.compute_exact_match(k_NI_pred[...,-(I_range[1]-I_range[0]):], k_orig_pred_NI[:,-(I_range[1]-I_range[0]):], until_wrong=False)

    ## process change and keep set
    c_mean_nll, k_mean_nll = round(c_mean_nll[:,I_range[0]:I_range[1]].mean().detach().item(),4), round(k_mean_nll[:,I_range[0]:I_range[1]].mean().detach().item(),4)
    c_em_N, k_em_N = c_em_N.mean().item(), k_em_N.mean().item()
        
    print(f"---Greedy EM--- change set: {c_em_N}, keep set: {k_em_N}  [mean over {k_NI.shape[0]} seqs]")
    print(f"---Mean NLL--- change set: {c_mean_nll}, keep set: {k_mean_nll}")
    
    if print_pred:
        print(f"c_NI_pred: {model.to_string(c_NI_pred)}")
        print(f"k_NI_pred: {model.to_string(k_NI_pred)}")
    #return (c_em_N, k_em_N), (c_mean_nll, k_mean_nll)

In [133]:
model_eval(model, c_toks_NI, c_orig_pred_NI, k_toks_NI, k_orig_pred_NI, I_range=[50,100])


  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

---Greedy EM--- change set: 50.0, keep set: 50.0  [mean over 5 seqs]
---Mean NLL--- change set: 0.987, keep set: 5.2917
c_NI_pred: [' is currently unavailable. Please, try again later. View profileView wishlistStart conversationInvite to friendsInvite to friendsAccept invitationAccept invitationPending invitation...User since {{ user.formattedDateUserJoined }} Friends since {{ user.formattedDateUserFriended }} {{ user.formattedDateUserExists }} {{ user.formattedDateUserIsOffline }} {{ user.formattedDateUserIsJoined }} {{ user.formattedDateUserIsFriends }} {{ user.form']
k_NI_pred: [' Islam. They either rape them or sell them on for £10 or so to new masters. The girls are the victims of slavery, child abuse and forced marriage. Their captors are by extension slavers and rapists.\n\nAs you can see, the Islamic State is not a religion. It is a criminal organisation. It is a criminal organisation that is not only a threat to the world, but to the lives of its members.\n\nThe Islamic State 

In [134]:
def find_topK_grads(model, topK:float=0.001, abs_grad:bool=True, c_types:list=["W_in", "W_out", "W_K", "W_V", "W_Q", "W_O"]):
    """
    find the topK weights in all model parameters
    """
    ## (1) collect all params
    all_params = list()
    all_params = torch.cat([param.grad.view(-1) for name, param in model.named_parameters() if name.split(".")[-1] in c_types])
    if abs_grad:
        all_params = torch.abs(all_params)

    ## (2) identify top params (sparsity)
    if 0.0 < topK < 1.0: ## percentage
        topk_vals, topk_idcs = torch.topk(all_params, k=int(topK*len(all_params)), largest=True)
    elif 1.0 <= topK < len(all_params): ## pick top weights
        topk_vals, topk_idcs = torch.topk(all_params, k=int(topK), largest=True)
    min_grad = torch.min(topk_vals)
    print(f"{len(topk_idcs)} weights in {c_types} with grads > {min_grad.item()} abs_grad: {abs_grad}")
    return min_grad, topk_idcs

def clip_grads(model, min_grad:float, keep_neg:bool=True): ## in-place
    """
    clip gradients than are not above min_grad, and not below min_grad if keep_neg is enabled
    """
    #param_vec = torch.nn.utils.parameters_to_vector(model.parameters())
    #torch.nn.utils.vector_to_parameters(param_vec, model.parameters())
    print(f"clipped at {min_grad} with keep_neg: {keep_neg}")
    for param in model.parameters():
        if param.requires_grad:
            if keep_neg: ## keep very small and very big gradients
                param.grad[(0 < param.grad) & (param.grad < min_grad)] = 0.0 ## annul small positive
                param.grad[(0 > param.grad) & (param.grad > -min_grad)] = 0.0 ## annul small negative
            else: ## keep only very big gradients
                param.grad[(param.grad < min_grad)] = 0.0 ## annul all grads smaller than min_grad

In [135]:
def run_fwd_bwd(model, metric_fn, c_toks_NI:torch.LongTensor=None, k_toks_NI:torch.LongTensor=None, optim_step:bool=False, topK:float=None, grad_norm:float=None, c_types:list=None):
    """
    adding hooks to model, running model on data on metric and returning cached activs, params are cached in model
    """
    fwd_cache, bwd_cache = gradient.add_fwd_bwd_hooks(model, hook_filter={"not in":"_input"})
    c_nll = modelHandlers.gather_token_scores(modelHandlers.NegLogLik(model(c_toks_NI.to(model.cfg.device))).to("cpu"), c_toks_NI)
    k_nll = modelHandlers.gather_token_scores(modelHandlers.NegLogLik(model(k_toks_NI.to(model.cfg.device))).to("cpu"), k_toks_NI)    
    metric_res, metric_norm = metric_fn(c_nll, k_nll)
    
    model.zero_grad()
    metric_res.backward(retain_graph=False)
    
    if grad_norm is not None:
        print(f"applied grad norm clipping with max norm {grad_norm}")
        torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_norm), norm_type=2.0)
               
    topk_idcs = None
    if topK is not None:
        min_grad, topk_idcs = find_topK_grads(model, topK=topK, c_types=c_types, abs_grad=True)
        clip_grads(model, min_grad, keep_neg=True)
    
    if optim_step and hasattr(model, 'optim'):
        print(f"optimizer step to change model params")
        model.optim.step()
        model.optim.zero_grad()

    fwd_cache = transformer_lens.ActivationCache(fwd_cache, model)
    bwd_cache = transformer_lens.ActivationCache(bwd_cache, model)
    return fwd_cache, bwd_cache, topk_idcs

model = modelHandlers.load_model(model, lr=0.0001, weight_decay=0.01)
metric_fn = functools.partial(gradient.contrast_metric, I_range=[50,100], with_mse=True, pool={"c":[0],"k":[0]}, norm_sets=1.0)
fwd_bwd_fn = functools.partial(run_fwd_bwd, optim_step=True, topK=1000, grad_norm=None, c_types=["W_in", "W_out"])#, "W_K","W_V","W_Q"])
fwd_cache, bwd_cache, topk_idcs = fwd_bwd_fn(model, metric_fn, c_toks_NI, k_toks_NI)

reset model gpt2-medium
Loaded pretrained model gpt2-medium into HookedTransformer
added optimizer with lr: 0.0001 and weight_decay: 0.01
contrast_res: -5.291677474975586, c_nll: 5.291677474975586, k_nll: 0.0
1000 weights in ['W_in', 'W_out'] with grads > 2.8466906547546387 abs_grad: True
clipped at 2.8466906547546387 with keep_neg: True
optimizer step to change model params


In [136]:
model_eval(model, c_toks_NI, c_orig_pred_NI, k_toks_NI, k_orig_pred_NI, I_range=[50,100])

  0%|          | 0/50 [00:00<?, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

---Greedy EM--- change set: 1.0, keep set: 6.599999904632568  [mean over 5 seqs]
---Mean NLL--- change set: 7.4672, keep set: 8.0368
c_NI_pred: [' is currently unavailable. Please, try again later. View profileView wishlistStart conversationInvite to friendsInvite to friendsAccept invitationAccept invitationPending invitation...User since {{ user.formattedDateUserJoined }} Friends since {{ user.formatted\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n']
k_NI_pred: [' Islam. They either rape them or sell them on for £10 or so to new masters. The girls are the victims of slavery, child abuse and forced marriage. Their captors are by extension slavers and rapists.\n\nAs you can see\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n', 's prisons on false charges on terrorism. Senator James Lankford is calling for the US to impose sanctions in light of Turkey’s actions.\n\nPlease continue to p

In [119]:
#model = modelHandlers.load_model(model, lr=0.0, weight_decay=0.01)
k_preds_NI, _ = modelHandlers.batch_decode(model, k_toks_NI, new_toks=50)
em = evaluation.compute_exact_match(k_preds_NI, k_preds_orig, until_wrong=False)
print(f"keep exact match: {em}")
model.to_string(k_preds_NI)

0it [00:00, ?it/s]


ValueError: not enough values to unpack (expected 2, got 1)

In [27]:
model.to_string(k_preds_orig)

['ombia); CERN, CERN-CAS, and CERN-CAS-CAS (Colombia); CERN-CAS-CAS-CAS (Colombia); CERN-CAS-CAS',
 ' ());\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n\n}\n',
 'W5hbWVzdGFuZS5jb20vb2VudHMudGhlbWVzdGFuZS5jb20vb2VudHMudGhlbWV',
 'Linux is a free operating system that runs on a variety of hardware. It is a free operating system that runs on a variety of hardware. It is a free operating system that runs on a variety of hardware.\n\nLinux is a free operating system',
 ' and the cultivars are selected for their ability to withstand the stresses of the growing season.\n\nThe selection of cultivars is based on the following criteria:\n\nThe cultivars are selected for their ability to withstand the stresses of the growing season',
 '                                                  ',
 ', and are often the subject of a great deal of discussion.\n\nThe book is divided into three parts, each of which is divided into chapters. The first part, The General Prologue, i

In [34]:
c_preds_NI, c_preds_NI_true = modelHandlers.batch_decode(model, c_toks_NI)
em = evaluation.compute_exact_match(c_preds_NI, c_preds_NI_true, until_wrong=False)
print(f"change exact match: {em}")
model.to_string(c_preds_NI)

0it [00:00, ?it/s]

  0%|          | 0/50 [00:00<?, ?it/s]

1it [00:04,  4.73s/it]

change exact match: tensor([5.])





['istration.py:1062\n\nmsgstr ""\n\n#: ../server/handlers/xmlrpc/registration.py:1063\n\nmsgstr ""\n\n#: ../server/handlers/xml']

In [32]:
model.to_string(c_toks_NI[:,:,50:].squeeze())

'istration.py:1071\nmsgid "Red Hat Satellite Privacy Statement"\nmsgstr ""\n\n#: ../server/handlers/xmlrpc/registration.py:1092\nmsgid "Expected a dictionary'