
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Change-Perturbed Change Set Activation Gradients

In [None]:
#@title Import libraries
import transformer_lens, torch, gc, itertools, functools, math
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, SymLogNorm

mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

import sys
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, gradient, localizing, patching

## Model

In [None]:
## checkpoint name; the same string is reused further down to locate the matching perturbed data split
model_type = "gpt-neo-125M"
## load the model through the project helper, requesting the CPU device
model = modelHandlers.load_model(model_type=model_type, DEVICE="cpu")
## NOTE(review): presumably switches off gradient tracking for the embedding, positional-embedding
## and unembedding components so only the remaining components receive gradients -- confirm in modelHandlers.set_no_grad
modelHandlers.set_no_grad(model, ["embed", "pos_embed", "unembed"])

## Load Data

In [None]:
## mem and non-mem set
#(mem_prompts, mem_counts),(non_mem_prompts,non_mem_counts) = dataLoaders.load_pile_splits("acc/gpt2-medium", as_torch=True)
#train_dl, test_dl = dataLoaders.train_test_batching(mem_prompts, non_mem_prompts, mem_batch=10, non_mem_batch=10, test_frac=0.0, shuffle=True, set_twice="k")
#c_toks_NI, k_toks_NI = next(iter(train_dl))
#c_toks_NI, k_toks_NI = c_toks_NI.squeeze(0), k_toks_NI.squeeze(0)

In [None]:
## load the memorized token set together with its perturbed counterpart
mem_perturbed_sets = dataLoaders.load_pile_splits(
    f"{model_type}/perturbed",
    file_names=["mem_toks.pt", "perturbed_mem_toks.pt"],
    as_torch=True,
)
mem_set = mem_perturbed_sets[0]
perturbed_mem_set = mem_perturbed_sets[1]

## batch both sets together (matched, unshuffled) and hold out 20% for testing
train_dl, test_dl = dataLoaders.train_test_batching(
    mem_set,
    perturbed_mem_set,
    mem_batch=30,
    non_mem_batch=30,
    matched=True,
    shuffle=False,
    test_frac=0.2,
    add_bos=None,
)

## only the first training batch is analysed; squeeze(0) drops its leading dimension when that has size 1
c_toks_NI, c_perturb_toks_NI = (toks.squeeze(0) for toks in next(iter(train_dl)))

### Identify Intervention Token and Impact Token

In [None]:
def _first_hit_per_row(hits_NI:torch.tensor):
    """
    keep only the first [row, col] pair of every row in the output of a 2D nonzero() call
    (nonzero() returns its hits sorted row-major, so a new row starts wherever the row index changes)
    """
    if hits_NI.shape[0] == 0: ## no hits at all -> nothing to select, return the empty (0, 2) index tensor
        return hits_NI.long()
    row_starts = (hits_NI[:-1,0] != hits_NI[1:,0]).nonzero(as_tuple=True)[0] + 1
    ## the very first hit always opens a row; build that index on the same device / dtype as the others
    first_hit = torch.zeros(1, dtype=torch.long, device=hits_NI.device)
    return hits_NI[torch.cat((first_hit, row_starts))].long()


def get_interv_impact_indeces(c_toks_NI:torch.tensor, k_toks_NI:torch.tensor):
    """
    function to get the positions of the intervention (src) and impact (trg) token

    Both inputs are 2D token tensors of identical shape (sequences x positions).

    returns (src_NI, trg_NI), two long tensors of [row, col] pairs:
        src_NI: per sequence, the first position where the two inputs differ
        trg_NI: per sequence, the position right before the second differing token
                (-1 because we care about the position at which that token is predicted)
    Sequences without a first / second difference have no entry in src_NI / trg_NI, so the
    two results can have different numbers of rows. Identical inputs yield empty (0, 2) tensors.
    """
    ## running count of differing tokens along every sequence
    ck_diff_cumsum = torch.cumsum((c_toks_NI != k_toks_NI).long(), dim=-1)

    ## find intervention: first position where exactly one difference has been seen
    src_NI = _first_hit_per_row((ck_diff_cumsum==1).nonzero())

    ## find impact: first position where two differences have been seen, shifted one column to the left
    trg_NI = (ck_diff_cumsum==2).nonzero()
    trg_NI[:,1] -= 1 ## count 2 can never occur in column 0, so the first hit of a row stays >= 0
    trg_NI = _first_hit_per_row(trg_NI)

    return src_NI, trg_NI

## per sequence: position of the perturbed (source) token and of the first impacted (target) token
src_NI, trg_NI = get_interv_impact_indeces(c_toks_NI, c_perturb_toks_NI)

## Run Backprop and Collect Activations

In [None]:
def single_seq_metric(nll_NI:torch.tensor, NI_idcs:torch.tensor=None, pool:dict=None):
    """
    metric for a single token set: mean of the (pooled) per-token losses at selected positions

    nll_NI: per-token loss values, indexed as [sequence, position]
    NI_idcs: optional tensor of [row, col] pairs selecting the tokens the metric is applied to;
             when None, all entries of nll_NI are used
    pool: dict whose "c" entry is handed to gradient.pool_tensor; defaults to {"c": []}
          (NOTE(review): an empty list presumably means "no pooling" -- confirm in gradient.pool_tensor)

    returns (metric_res, None); the second slot is always None
    """
    if pool is None: ## build the default per call instead of sharing one mutable dict between calls
        pool = {"c": []}

    ## (1) preprocess________________________________________
    ## select tokens to apply metric to
    if NI_idcs is not None:
        nll_NI = nll_NI[NI_idcs[:,0], NI_idcs[:,1]]

    ## (2) pooling_______________________________________________
    nll_NI = gradient.pool_tensor(nll_NI, pool["c"])
    print(f"pooling nll_NI {nll_NI.shape}, pool: {pool}")

    ## (3) apply metric_______________________________________________
    metric_res = nll_NI.mean()
    print(f"contrast loss: {metric_res}")
    return metric_res, None

## bind the target indices so the metric is only evaluated at the first impacted token of every sequence
metric = functools.partial(single_seq_metric, NI_idcs=trg_NI)
## memorized batch -- NOTE(review): the returned caches presumably hold forward activations and backward gradients; confirm in gradient.run_single_fwd_bwd
c_fwd_cache, c_bwd_cache, _ = gradient.run_single_fwd_bwd(model, metric_fn=metric, c_toks_NI=c_toks_NI)
## perturbed batch, passed through the same c_toks_NI keyword and scored with the same metric
k_fwd_cache, k_bwd_cache, _ = gradient.run_single_fwd_bwd(model, metric_fn=metric, c_toks_NI=c_perturb_toks_NI)

In [None]:
## pooling functions, all called as fn(x, dim)
POOL_FN = {"l1": lambda x, dim: torch.norm(x, p=1, dim=dim),
         "l2": lambda x, dim: torch.norm(x, p=2, dim=dim),
         "frob": lambda x, dim: torch.linalg.matrix_norm(x, ord='fro'), ## toDo: issue requires 2D input (dim is ignored)
         "mean_abs": lambda x, dim: torch.mean(torch.abs(x), dim=dim),
         "mean": lambda x, dim: torch.mean(x, dim=dim),
         "max_abs": lambda x, dim: torch.max(torch.abs(x), dim=dim)[0],
         "max": lambda x, dim: torch.max(x, dim=dim)[0],
         "pass": lambda x, dim: (x)}

## distance / similarity functions between two tensors, all called as fn(x1, x2)
DIST_FN = {"cos": lambda x1, x2: torch.nn.functional.cosine_similarity(x1, x2, dim=-1),
           "sub": lambda x1, x2: x1-x2,
           "sub_abs": lambda x1, x2: torch.abs(x1-x2)}

def pool_tensor(tensor:torch.tensor, pool:str="max", abs_vals:bool=True, topP:float=1.0, norm_by_entries:bool=False, n_heads:int=12):
    """
    pool the largest entries of the last dimension of a tensor

    tensor: input values; never modified in place
    pool: key into POOL_FN, applied over the last dimension of the selected entries
    abs_vals: take absolute values before selecting / pooling
    topP: how many entries of the last dimension are kept before pooling;
          a value strictly between 0 and 1 is a fraction of the last dimension (at least one entry),
          any other value is truncated to an int and used as the entry count (so the default 1.0 keeps one entry)
    norm_by_entries: divide all values by log(number of entries); for 5-dim (attention) tensors
                     the number of entries is first divided by n_heads
    n_heads: number of attention heads used for that correction

    returns the pooled tensor (the last dimension is reduced by every POOL_FN entry except "pass" and "frob")
    """
    n_params = tensor.numel()
    if len(tensor.shape) == 5: ##ATTN
        n_params = n_params / n_heads ## divide by number of heads

    if abs_vals: ## take absolute values
        tensor = torch.abs(tensor)

    if norm_by_entries:
        norm_by = math.log(n_params)#n_params**(1/2)
        ## out-of-place on purpose: with abs_vals=False `tensor` is still the caller's object
        tensor = tensor*(1/norm_by)

    if 0.0 < topP < 1.0:
        topP = max(int(topP*tensor.shape[-1]), 1)
    topK_vals = torch.topk(tensor, int(topP), dim=-1, largest=True).values
    tensorpool = POOL_FN[pool](topK_vals, dim=-1) ## do pooling

    return tensorpool

In [None]:
def collect_pool(cache:dict, second_cache:dict=None, c_type:str=None, idcs_NI=None):
    """
    collect the values of one component type from a cache, reduce their last dimension and average over tokens

    cache: cache to read the component values from (uses the notebook-level `model`)
    second_cache: optional second cache; when given, the result is the cosine similarity
                  between the two caches' values instead of the pooled values of the first
    c_type: component type handed to localizing.collect_c_type
    idcs_NI: optional [row, col] pairs; when given only these tokens are averaged,
             otherwise the mean is taken over dim 1 and then over dim 0

    returns (pool_vals, names): the reduced values and the matching column labels for plotting
    """
    vals, _ = localizing.collect_c_type(model=model, cache=cache, c_type=c_type)

    if second_cache is None:
        ## single cache: keep the top 10% absolute values of the last dimension and max-pool them
        pool_vals = pool_tensor(vals, pool="max", abs_vals=True, topP=0.1, norm_by_entries=False)
    else:
        ## two caches: compare them along the last dimension before any token averaging
        other_vals, _ = localizing.collect_c_type(model=model, cache=second_cache, c_type=c_type)
        pool_vals = torch.nn.functional.cosine_similarity(vals, other_vals, dim=-1)

    ## consider either all tokens or only the selected token, then mean over sequences
    if idcs_NI is None:
        pool_vals = pool_vals.mean(1).mean(0)
    else:
        rows, cols = idcs_NI[:,0], idcs_NI[:,1]
        pool_vals = pool_vals[rows, cols].mean(0)

    ## reshape for plotting
    if len(pool_vals.shape) == 2: ## attention: one column per entry of the second dimension
        return pool_vals, [f"{c_type} H{i}" for i in range(0,pool_vals.shape[1])]
    ## mlp: add a trailing dimension so there is a single column
    return pool_vals.unsqueeze(-1), [f"{c_type}"]

## Activation Gradient Pooling and Plotting

In [None]:
## which cache to inspect below: "forward" or "backward"
fwd_bwd = "forward"
## which token position to inspect below: "source" (src_NI) or "target" (trg_NI)
tok_pos = "target"

## [row, col] index pairs belonging to the selected token position
idcs_NI = {"source":src_NI,"target":trg_NI}[tok_pos]

In [None]:
def gather_activation_grads(cache:dict, second_cache:dict=None,idcs_NI:torch.tensor=None,c_types:list=None):
    """
    run collect_pool for several component types and join the results column-wise

    cache / second_cache / idcs_NI: passed through unchanged to collect_pool
    c_types: component types to collect, defaults to ["k", "q"]
             (other types used in this notebook's history: "v", "z", "pre", "post", "attn_out", "mlp_in", "mlp_out")

    returns (vals, names): the per-type values concatenated along the last dimension
    and the column labels in the same order
    """
    if c_types is None: ## build the default per call instead of sharing one mutable list between calls
        c_types = ["k", "q"]

    vals, names = [], []
    for c_type in c_types:
        c_type_vals, c_type_names = collect_pool(cache, second_cache, c_type=c_type, idcs_NI=idcs_NI)
        vals.append(c_type_vals)
        names += c_type_names
    vals = torch.cat(vals, dim=-1)
    return vals, names

## pooled values of the memorized (c) and perturbed (k) run for the cache / token position selected above;
## `names` is overwritten by the second call
c_vals, names = gather_activation_grads({"forward":c_fwd_cache,"backward":c_bwd_cache}[fwd_bwd], idcs_NI=idcs_NI)
k_vals, names = gather_activation_grads({"forward":k_fwd_cache,"backward":k_bwd_cache}[fwd_bwd], idcs_NI=idcs_NI)
## alternative: compare both runs inside collect_pool (second_cache) instead of plotting them side by side
#diff_vals, names = gather_activation_grads({"forward":c_fwd_cache,"backward":c_bwd_cache}[fwd_bwd], {"forward":k_fwd_cache,"backward":k_bwd_cache}[fwd_bwd], idcs_NI=idcs_NI)

In [None]:
## one heatmap per run, stacked vertically
fontsize = 12
plot_types = ["memorized", "perturbed memorized"]
cmaps = ["PuOr", "PuOr"]
## NOTE(review): only the first two entries are used, so the first panel is not centred while the second is centred at 0 -- confirm this is intended
centering = [None, 0.0, 0.0, None]
vals = [c_vals, k_vals]

fig, axs = plt.subplots(2, 1, figsize=(12, 7), gridspec_kw={'hspace': 0.4})

for i, ax in enumerate(axs):
    plot_vals = vals[i].numpy()
    sns.heatmap(
        plot_vals[:,:],
        cmap=mpl.colormaps[cmaps[i]],
        center=centering[i],
        xticklabels=names,
        yticklabels=np.arange(0,plot_vals.shape[0]),
        square=False,
        ax=ax,
        cbar_kws={'location': 'right','pad': 0.01},
    )
    ax.invert_yaxis() ## row 0 at the bottom
    ax.set_title(f"{plot_types[i]}: {fwd_bwd} activations at {tok_pos} token", fontsize=fontsize, loc="left")
    ax.set_ylabel("layer")
    ax.tick_params(axis='both', which='major', labelsize=fontsize-2)

## uncomment to store the figure
#fig.savefig(f"{dataLoaders.ROOT}/results/{tok_pos}_{fwd_bwd}_{model_type}.pdf", dpi=200, bbox_inches="tight")

## Final Plot Creation

In [None]:
## backward cache at the perturbed (source) token, for the memorized (c) and the perturbed (k) run
tok_pos, fwd_bwd = "source", "backward"
source_c_vals, names = gather_activation_grads(
    {"forward":c_fwd_cache,"backward":c_bwd_cache}[fwd_bwd],
    idcs_NI={"source":src_NI,"target":trg_NI}[tok_pos],
)
source_k_vals, names = gather_activation_grads(
    {"forward":k_fwd_cache,"backward":k_bwd_cache}[fwd_bwd],
    idcs_NI={"source":src_NI,"target":trg_NI}[tok_pos],
)

## backward cache at the first impacted (target) token, for both runs
tok_pos, fwd_bwd = "target", "backward"
target_c_vals, names = gather_activation_grads(
    {"forward":c_fwd_cache,"backward":c_bwd_cache}[fwd_bwd],
    idcs_NI={"source":src_NI,"target":trg_NI}[tok_pos],
)
target_k_vals, names = gather_activation_grads(
    {"forward":k_fwd_cache,"backward":k_bwd_cache}[fwd_bwd],
    idcs_NI={"source":src_NI,"target":trg_NI}[tok_pos],
)

In [None]:
## 2x2 grid: rows = memorized / perturbed memorized run, columns = perturbed / first impacted token
fontsize = 11
fig, axs = plt.subplots(2, 2, figsize=(17, 7), gridspec_kw={'hspace': 0.3,'wspace': 0.025})

paragraph_types = ["memorized", "perturbed memorized"]
paragraph_colors = ["red", "purple"]
token_types = ["perturbed", "first impacted"]
vals = [[source_c_vals, target_c_vals], [source_k_vals, target_k_vals]]
## the source / target columns are averaged over the rows of src_NI / trg_NI, so report those counts in the titles
n_paragraphs = [src_NI.shape[0], trg_NI.shape[0]]
cmaps = ["RdBu_r", "PuOr"]
centering = [0, 0]

for i, ax_row in enumerate(axs):
    for j, ax in enumerate(ax_row):
        ## detach / move to cpu so plotting also works for tensors that track gradients or live on another device
        plot_vals = vals[i][j].detach().cpu().numpy()
        s = sns.heatmap(plot_vals[:,:],cmap=mpl.colormaps[cmaps[i]],center=centering[i],xticklabels=names,yticklabels=np.arange(0,plot_vals.shape[0]),square=False,ax=ax, cbar_kws={'location': 'right','pad': 0.01})
        ax.invert_yaxis() ## row 0 at the bottom
        if i==0: ## column titles only on the top row
            ax.set_title(f"activation gradients at {token_types[j]} token (mean over {n_paragraphs[j]} paragraphs)", fontsize=fontsize, loc="left")
        if j==0: ## the left column carries the row label (run type) as its y-label
            ax.set_ylabel(paragraph_types[i], rotation=90, color=paragraph_colors[i], fontsize=fontsize, labelpad=5)
        ax.tick_params(axis='both', which='major', labelsize=fontsize)
        s.set_yticklabels(s.get_yticklabels(), rotation=0, horizontalalignment='right')

fig.savefig(f"{dataLoaders.ROOT}/results/activ_grads_perturbed.pdf", dpi=200, bbox_inches="tight")