
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Fine-Tuning: Integrate Localization and Intervention

In [None]:
#@title Import libraries
import transformer_lens
import torch, gc, itertools, functools, tqdm
import pandas as pd
import numpy as np
import torch.nn.functional as F

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

import sys
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, gradient, evaluation, localizing, intervening

## Model

In [None]:
model_type = "gpt-neo-125M"
# initial load with lr=0.0; the model is re-loaded with a non-zero lr before fine-tuning
model = modelHandlers.load_model(
    model_type=model_type,
    DEVICE="cuda",
    lr=0.0,
    weight_decay=0.01,
)

## Data

In [None]:
## mem and non-mem set: the keep batch (k_toks_NI) comes from the non-mem side
mem_nonmem_sets = dataLoaders.load_pile_splits(
    f"{model_type}/preds",
    file_names=["50_50_preds.pt", "0_10_preds.pt"],
    as_torch=True,
)
mem_set = mem_nonmem_sets[0]
non_mem_set = mem_nonmem_sets[1]
train_dl, test_dl = dataLoaders.train_test_batching(
    mem_set,
    non_mem_set,
    mem_batch=1,
    non_mem_batch=10,
    test_frac=0.2,
    add_bos=None,
)
# only the second element of the first training batch is kept
_, k_toks_NI = next(iter(train_dl))
k_toks_NI = k_toks_NI.squeeze(0)

## load perturbed mem set and original mem set (change batch and its perturbed targets)
# NOTE(review): train_dl / test_dl are rebound here, so any later cell reading test_dl gets
# the (mem, perturbed mem) split rather than the (mem, non-mem) split above — confirm intent.
mem_perturbed_sets = dataLoaders.load_pile_splits(
    f"{model_type}/perturbed",
    file_names=["mem_toks.pt", "perturbed_mem_toks.pt"],
    as_torch=True,
)
mem_set = mem_perturbed_sets[0]
perturbed_mem_set = mem_perturbed_sets[1]
train_dl, test_dl = dataLoaders.train_test_batching(
    mem_set,
    perturbed_mem_set,
    mem_batch=10,
    non_mem_batch=10,
    matched=True,
    shuffle=False,
    test_frac=0.2,
    add_bos=None,
)
c_toks_NI, c_perturb_toks_NI = next(iter(train_dl))
c_toks_NI = c_toks_NI.squeeze(0)
c_perturb_toks_NI = c_perturb_toks_NI.squeeze(0)

### Evaluate the Original Model (greedy continuations before fine-tuning)

In [None]:
def _float_mean(values_N: torch.Tensor) -> float:
    """Return the mean of `values_N` as a Python float, promoting integer tensors to float first."""
    # Tensor.mean() rejects integer dtypes; floating tensors are used as-is so their value is unchanged.
    if not values_N.is_floating_point():
        values_N = values_N.float()
    return values_N.mean().item()


def model_eval(model, c_NI: torch.LongTensor = None, k_NI: torch.LongTensor = None, I_range: tuple = (50, 100), print_pred: bool = True):
    """
    Evaluate the language model on one change batch and one keep batch and print a summary.

    Both batches go through `evaluation.evaluate_nll_greedy` (batch_size=50) and
    `evaluation.compute_exact_match(..., until_wrong=True)`. Printed statistics are the mean
    exact-match value, how many sequences changed / were kept, and the mean NLL over the
    position window `I_range`.

    Args:
        model: model accepted by `evaluation.evaluate_nll_greedy`; must expose `to_string`
            when `print_pred` is True.
        c_NI: change-set token batch. Required in practice despite the None default.
        k_NI: keep-set token batch. Required in practice despite the None default.
        I_range: (start, end) position window. A sequence counts as fully reproduced when its
            exact-match value equals end - start (presumably the number of correct leading
            tokens in the window — confirm in `compute_exact_match`).
        print_pred: if True, also print the decoded greedy predictions of both batches.

    Returns:
        (c_em_N, k_em_N): per-sequence exact-match tensors for the change and keep batches.
    """
    start, end = I_range[0], I_range[1]
    window_len = int(end - start)

    (c_nll, _c_minK_nll), (c_NI_pred, c_NI_true) = evaluation.evaluate_nll_greedy(model, c_NI, batch_size=50)
    (k_nll, _k_minK_nll), (k_NI_pred, k_NI_true) = evaluation.evaluate_nll_greedy(model, k_NI, batch_size=50)

    c_em_N = evaluation.compute_exact_match(c_NI_pred, c_NI_true, until_wrong=True)
    k_em_N = evaluation.compute_exact_match(k_NI_pred, k_NI_true, until_wrong=True)

    ## process change and keep set: mean NLL restricted to the evaluated window
    c_mean_nll = round(c_nll[..., start:end].mean().detach().item(), 4)
    k_mean_nll = round(k_nll[..., start:end].mean().detach().item(), 4)

    # counts (not fractions): change-set sequences no longer fully reproduced over the window,
    # and keep-set sequences still fully reproduced over the window
    c_changed_n = torch.where(c_em_N == window_len, 0, 1).sum()
    k_kept_n = torch.where(k_em_N == window_len, 1, 0).sum()

    print(f"---Greedy EM--- change set: {_float_mean(c_em_N)} [changed {c_changed_n}/{c_em_N.shape[0]}], keep set: {_float_mean(k_em_N)} [kept {k_kept_n}/{k_em_N.shape[0]}]")
    print(f"---Mean NLL--- change set: {c_mean_nll}, keep set: {k_mean_nll}\n\n")

    if print_pred:
        print(f"c_NI_pred: {model.to_string(c_NI_pred)}\n")
        print(f"k_NI_pred: {model.to_string(k_NI_pred)}")

    return c_em_N, k_em_N
            
# baseline: evaluate the un-tuned model on the change and keep batches
c_em_N, k_em_N = model_eval(
    model,
    c_NI=c_toks_NI,
    k_NI=k_toks_NI,
    I_range=[50, 100],
)

In [None]:
def find_topK_grads(model, topK: float = 0.001, c_types: list = None):
    """
    Find the largest-magnitude gradients among the selected parameter types.

    The gradients of every parameter whose name ends in one of `c_types` (the last
    dot-separated component, e.g. "W_K" in "blocks.0.attn.W_K") are flattened into one
    vector of absolute values, and the top entries of that vector are selected.

    Args:
        model: module whose parameters already carry `.grad` (call after backward()).
            Parameters whose grad is None are skipped.
        topK: the sign is ignored. A value in (0, 1) is a fraction of the selected weights;
            a value >= 1 is an absolute count, capped at the number of selected weights.
        c_types: parameter-name suffixes to consider; None means
            ["W_K", "W_Q", "W_V", "W_O", "W_in", "W_out"].

    Returns:
        (min_grad, topk_idcs_flat): the smallest selected absolute gradient (0-dim tensor)
        and the indices of the selected entries in the concatenated flat vector.

    Raises:
        ValueError: if no selected parameter has a gradient, or topK selects zero weights.
    """
    if c_types is None:
        c_types = ["W_K", "W_Q", "W_V", "W_O", "W_in", "W_out"]

    ## (1) collect absolute grads of the selected params
    selected_grads = [
        torch.abs(param.grad.reshape(-1))  # reshape: grads are not guaranteed contiguous
        for name, param in model.named_parameters()
        if name.split(".")[-1] in c_types and param.grad is not None
    ]
    if not selected_grads:
        raise ValueError(f"no gradients found for parameter types {c_types}")
    all_param_grads = torch.cat(selected_grads)
    n_weights = all_param_grads.numel()

    ## (2) resolve topK into a number of weights (sparsity)
    topK = abs(topK)
    if 0.0 < topK < 1.0:  ## percentage
        topK = int(topK * n_weights)
    k = min(int(topK), n_weights)  # requesting more weights than exist selects all of them
    if k < 1:
        raise ValueError(f"topK={topK} selects no weights out of {n_weights} in {c_types}")

    topk_vals_flat, topk_idcs_flat = torch.topk(all_param_grads, k=k, largest=True)

    min_grad = torch.min(topk_vals_flat)
    print(f"{len(topk_idcs_flat)} weights in {c_types} with grads > {min_grad.item()}\n")
    return min_grad, topk_idcs_flat


def clip_grads(model, min_grad: float = None, full_remove_idcs: list = None, topK=0.0):  ## in-place
    """
    Zero out gradient entries in place, either computing new masks or replaying stored ones.

    Record mode (`min_grad` is not None): for every parameter with requires_grad, build a
    0/1 mask (1 = zero this gradient entry), apply it, and append it to `full_remove_idcs`.
      * if -1 < topK < 0: random mask; each entry is kept with probability |topK|.
      * otherwise: entries with |grad| >= min_grad are kept, all others are zeroed.
    Replay mode (`min_grad` is None): apply the masks already in `full_remove_idcs`, matched
    to parameters by position in `model.parameters()` order.

    The masks cover every trainable parameter of `model`, not only the parameter types that
    `find_topK_grads` ranked. A parameter whose grad is None gets a None placeholder in record
    mode and is skipped in replay mode.

    Args:
        model: module whose gradients are modified in place.
        min_grad: magnitude threshold (float or 0-dim tensor); None selects replay mode.
        full_remove_idcs: list to append masks to (record) or read masks from (replay);
            None creates a fresh list.
        topK: only its sign/range matters here; see the random-mask rule above.

    Returns:
        The list of masks (the list passed in, or a new one).
    """
    if full_remove_idcs is None:
        # fresh list per call; a shared mutable default would accumulate masks across calls
        full_remove_idcs = []
    record = min_grad is not None
    list_idx = 0
    kept_n_weights = 0
    for param in model.parameters():
        if not param.requires_grad:
            continue
        if record:
            if param.grad is None:
                # nothing to mask; keep a placeholder so replay stays aligned by position
                full_remove_idcs.append(None)
                continue
            if -1.0 < topK < 0.0:
                # random baseline: each entry is removed with probability 1 - |topK|
                remove_idcs = torch.bernoulli(torch.full_like(param.grad, 1.0 - abs(topK)))
            else:
                # keep entries whose magnitude reaches the threshold (positive or negative grads)
                remove_idcs = torch.where(param.grad.abs() >= min_grad, 0, 1)
            full_remove_idcs.append(remove_idcs)
        else:
            remove_idcs = full_remove_idcs[list_idx]
            list_idx += 1
            if remove_idcs is None or param.grad is None:
                continue
        remove_mask = remove_idcs.bool()
        kept_n_weights += int((~remove_mask).sum().item())
        param.grad[remove_mask] = 0.0  ## annul small positive and negative grads

    print(f"clipped at {min_grad} / kept {kept_n_weights} weights")
    return full_remove_idcs
                
                
#min_grad, topk_idcs_flat = find_topK_grads(model, topK=0.01)
#full_remove_idcs = clip_grads(model, min_grad)
#full_remove_idcs = clip_grads(model, min_grad=None, full_remove_idcs=full_remove_idcs)

In [None]:
# KL(target || input) with input given as log-probs; "batchmean" divides the sum by input.size(0)
KLDiv = torch.nn.KLDivLoss(reduction="batchmean")


def contrast_metric(c_nll_NIT, c_toks_NI, c_perturb_toks_NI, k_logits_NIT, k_logits_fixed_NIT, I_range: tuple = (49, 99), use_perturb: bool = True, c_set_norm: float = None):
    """
    Contrastive loss: keep keep-set predictions near a fixed reference while moving the change set.

    keep   = KL(softmax(reference keep logits) || softmax(current keep logits)) over the window.
    change = c_set_norm * mean NLL of the gathered change-set tokens over the window.

    Args:
        c_nll_NIT: per-position NLL over the vocabulary for the change set (as produced by
            `modelHandlers.NegLogLik`); assumed (N, I, T) per the suffix — confirm.
        c_toks_NI: original change-set tokens; gathered when `use_perturb` is False.
        c_perturb_toks_NI: perturbed change-set tokens; gathered when `use_perturb` is True.
        k_logits_NIT: current keep-set logits.
        k_logits_fixed_NIT: reference keep-set logits; detached before use.
        I_range: (start, end) position window applied to both terms.
        use_perturb: True -> loss = keep + change (for positive c_set_norm this lowers the NLL
            of the perturbed tokens); False -> loss = keep - change (raises the NLL of the
            original tokens).
        c_set_norm: weight of the change term; None is treated as 1.0.

    Returns:
        (loss, None): scalar loss tensor and a placeholder for a normalisation value.
    """
    target_toks_NI = c_perturb_toks_NI if use_perturb else c_toks_NI
    c_nll_NI = modelHandlers.gather_token_scores(c_nll_NIT, target_toks_NI)

    start, end = I_range[0], I_range[1]
    c_nll_Nc = c_nll_NI[..., start:end]
    k_logits_NcT = k_logits_NIT[..., start:end, :]
    k_logits_fixed_NcT = k_logits_fixed_NIT[..., start:end, :]

    keep = KLDiv(F.log_softmax(k_logits_NcT, dim=-1), F.softmax(k_logits_fixed_NcT.detach(), dim=-1)).mean()
    # the default None used to raise TypeError (None * tensor); treat it as an unweighted term
    change_weight = 1.0 if c_set_norm is None else c_set_norm
    change = change_weight * c_nll_Nc.mean()

    contrast_res = (keep + change) if use_perturb else (keep - change)

    print(f"loss: {contrast_res}, mem: {change.detach()}, non mem: {keep.detach()}, use_perturb: {use_perturb}, c_set_norm: {c_set_norm}")
    return contrast_res, None


In [None]:
def run_contrastive_fwd_bwd(model, metric_fn, c_toks_NI, c_perturb_toks_NI, k_toks_NI, optim_steps: int = -1, topK: float = None, grad_norm: float = None, c_types: list = None):
    """
    Fine-tune `model` on a contrastive objective while hooks record activations and gradients.

    Hooks are added with `gradient.add_fwd_bwd_hooks` (filter {"not in": "_input"}; its exact
    semantics live in that helper). The three token batches are moved to `model.cfg.device`
    and reference keep-set logits are computed once, before any update. Then `abs(optim_steps)`
    iterations run:
    forward on the change and keep batches, `metric_fn`, backward, optional grad-norm clipping,
    optional sparse gradient masking, and an optimizer step only when `optim_steps >= 1` and
    the model has an `optim` attribute.

    Args:
        model: model exposing `cfg.device`; updates go through `model.optim` when present.
        metric_fn: called as metric_fn(c_nll_NIT, c_toks_NI, c_perturb_toks_NI, k_logits_NIT,
            k_logits_fixed_NIT); must return (loss, extra) — `extra` is ignored.
        c_toks_NI, c_perturb_toks_NI, k_toks_NI: change, perturbed-change and keep token batches.
        optim_steps: the number of iterations is abs(optim_steps); a negative value runs the
            passes without optimizer steps.
        topK: if not None, the weights to update are chosen once from the first iteration's
            gradients (find_topK_grads + clip_grads) and the same masks are replayed afterwards.
        grad_norm: if not None, max L2 norm for clip_grad_norm_ (applied before masking).
        c_types: parameter-name suffixes forwarded to find_topK_grads; None uses its default.

    Returns:
        (fwd_cache, bwd_cache, full_remove_idcs): hook caches wrapped as ActivationCache and
        the gradient masks from clip_grads (None when topK is None or no iteration ran).
    """
    fwd_cache, bwd_cache = gradient.add_fwd_bwd_hooks(model, hook_filter={"not in": "_input"})
    c_toks_NI = c_toks_NI.to(model.cfg.device)
    c_perturb_toks_NI = c_perturb_toks_NI.to(model.cfg.device)
    k_toks_NI = k_toks_NI.to(model.cfg.device)
    # reference logits are only ever used detached, so drop their autograd graph right away
    # instead of keeping it alive for the whole loop
    k_logits_fixed_NIT = model(k_toks_NI).detach()

    full_remove_idcs = None  # stays None when topK is None or optim_steps == 0
    n_steps = abs(optim_steps)
    for step_i in range(n_steps):

        c_nll_NIT = modelHandlers.NegLogLik(model(c_toks_NI))
        k_logits_NIT = model(k_toks_NI)

        metric_res, _ = metric_fn(c_nll_NIT, c_toks_NI, c_perturb_toks_NI, k_logits_NIT, k_logits_fixed_NIT)

        model.zero_grad()
        metric_res.backward(retain_graph=False)

        if grad_norm is not None:
            print(f"applied grad norm clipping with max norm {grad_norm}")
            torch.nn.utils.clip_grad_norm_(model.parameters(), float(grad_norm), norm_type=2.0)

        if topK is not None:
            if step_i == 0:
                # pick the weights once from the first step's gradients, then reuse the masks
                topk_kwargs = {} if c_types is None else {"c_types": c_types}
                min_grad, _ = find_topK_grads(model, topK=topK, **topk_kwargs)
                full_remove_idcs = clip_grads(model, min_grad, full_remove_idcs=[], topK=topK)
            else:
                full_remove_idcs = clip_grads(model, min_grad=None, full_remove_idcs=full_remove_idcs)

        if optim_steps >= 1 and hasattr(model, 'optim'):
            # NOTE(review): if model.optim applies weight decay, parameters whose grads were
            # zeroed above are still modified by the step — confirm this is intended.
            print(f"{step_i+1}/{abs(optim_steps)}, optimizer step")
            model.optim.step()
            model.optim.zero_grad()

    del c_toks_NI
    del k_toks_NI
    del c_perturb_toks_NI
    del k_logits_fixed_NIT
    torch.cuda.empty_cache()

    fwd_cache = transformer_lens.ActivationCache(fwd_cache, model)
    bwd_cache = transformer_lens.ActivationCache(bwd_cache, model)
    return fwd_cache, bwd_cache, full_remove_idcs

# re-load the model with a non-zero learning rate for the fine-tuning run
model = modelHandlers.load_model(model, lr=1e-05, weight_decay=1.0)
metric_fn = functools.partial(
    contrast_metric,
    I_range=[49, 99],
    use_perturb=True,
    c_set_norm=0.1,
)
fwd_bwd_fn = functools.partial(
    run_contrastive_fwd_bwd,
    optim_steps=5,
    topK=0.01,
    grad_norm=None,
    c_types=["W_K", "W_Q", "W_V", "W_O", "W_in", "W_out"],
)
# the third return value holds the per-parameter gradient masks from clip_grads
fwd_cache, bwd_cache, topk_idcs = fwd_bwd_fn(model, metric_fn, c_toks_NI, c_perturb_toks_NI, k_toks_NI)

In [None]:
# re-evaluate the fine-tuned model on the same change/keep batches used for training
c_em, k_em = model_eval(
    model,
    c_NI=c_toks_NI,
    k_NI=k_toks_NI,
    I_range=[50, 100],
)

## Test Set

In [None]:
# NOTE(review): test_dl was last assigned from the (mem, perturbed mem) split, so k_test_NI is
# the perturbed memorized set rather than a held-out non-memorized keep set — confirm intent.
c_test_NI, k_test_NI = next(iter(test_dl))
c_test_NI = c_test_NI.squeeze(0)
k_test_NI = k_test_NI.squeeze(0)
model_eval(model, c_NI=c_test_NI, k_NI=k_test_NI, I_range=[50, 100])