
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Single Sequence Activation Gradients


In [None]:
#@title Import libraries
import transformer_lens
import torch, gc, itertools, functools
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

## global matplotlib style: do not draw the right and top axes spines
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

import sys
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, gradient, localizing

## Model

In [None]:
## model identifier; also reused below to build the data path f"{model_type}/perturbed"
model_type = "gpt-neo-125M"
## load the model on CPU through the project helper
model = modelHandlers.load_model(model_type=model_type, DEVICE="cpu")
## NOTE(review): presumably switches off gradients for the listed components
## (embedding, positional embedding, unembedding) -- confirm in modelHandlers.set_no_grad
modelHandlers.set_no_grad(model, ["embed", "pos_embed", "unembed"])

## Load Data

In [None]:
## mem and non-mem set
#(mem_prompts, mem_counts),(non_mem_prompts,non_mem_counts) = dataLoaders.load_pile_splits("acc/gpt2-medium", as_torch=True)
#train_dl, test_dl = dataLoaders.train_test_batching(mem_prompts, non_mem_prompts, mem_batch=10, non_mem_batch=10, test_frac=0.0, shuffle=True, set_twice="k")
#c_toks_NI, k_toks_NI = next(iter(train_dl))
#c_toks_NI, k_toks_NI = c_toks_NI.squeeze(0), k_toks_NI.squeeze(0)

In [None]:
## read the original memorized token set and its perturbed counterpart
mem_perturbed_sets = dataLoaders.load_pile_splits(
    f"{model_type}/perturbed",
    file_names=["mem_toks.pt", "perturbed_mem_toks.pt"],
    as_torch=True,
)
mem_set = mem_perturbed_sets[0]
perturbed_mem_set = mem_perturbed_sets[1]

## batch both sets together (unshuffled, matched) and hold out 20% for testing
train_dl, test_dl = dataLoaders.train_test_batching(
    mem_set,
    perturbed_mem_set,
    mem_batch=30,
    non_mem_batch=30,
    matched=True,
    shuffle=False,
    test_frac=0.2,
    add_bos=None,
)

## take the first training batch and drop the leading dimension of each part
c_toks_NI, c_perturb_toks_NI = next(iter(train_dl))
c_toks_NI = c_toks_NI.squeeze(0)
c_perturb_toks_NI = c_perturb_toks_NI.squeeze(0)

In [None]:
## hand-picked single sequence to analyse; the commented-out assignments are alternative sequences
#string_NI = [" headlines out of Washington never seem to slow. Subscribe to The D.C. Brief to make sense of what matters most. Please enter a valid email address. Sign Up Now Check the box if you do not wish to receive promotional offers via email from TIME. You can unsubscribe at any time. By signing up you are agreeing to our Terms of Use and Privacy Policy . This site is protected by reCAPTCHA and the Google Privacy Policy and Terms of Service apply. Thank you! For your"]
string_NI = ["Sign up for Take Action Now and get three actions in your inbox every week. You will receive occasional promotional offers for programs that support The Nation’s journalism. You can read our Privacy Policy here. Sign up for Take Action Now and get three actions in your inbox every week.\n\nThank you for signing up. For more from The Nation, check out our latest issue\n\nSubscribe now for as little as $2 a month!\n\nSupport Progressive Journalism The Nation is reader supported:"]
#string_NI = ["The following are trademarks or service marks of Major League Baseball entities and may be used only with permission of Major League Baseball Properties, Inc. or the relevant Major League Baseball entity: Major League, Major League Baseball, MLB, the silhouetted batter logo, World Series, National League, American League, Division Series, League Championship Series, All-Star Game, and the names, nicknames, logos, uniform designs, color combinations, and slogans designating the Major League Baseball clubs and entities, and"]
## NOTE: this overwrites the c_toks_NI batch loaded from the data loader above;
## tokenize without prepending a BOS token
c_toks_NI = model.to_tokens(string_NI, prepend_bos=False)

## Gradient Attribution

In [None]:
def single_seq_metric(nll_NI:torch.Tensor, I_range:list=None, pool:dict=None):
    """
    metric for a single sequence: mean of the (pooled) NLL values inside a token range

    Args:
        nll_NI: NLL values; the last dimension is the one sliced by I_range
        I_range: [start, end] slice bounds applied to the last dimension
                 (end exclusive); None keeps the whole last dimension
        pool: dict whose "c" entry is forwarded to gradient.pool_tensor;
              None is treated as {"c": []}

    Returns:
        (metric_res, None): the scalar mean and a None placeholder
        NOTE(review): the second element is always None, presumably to match the
        return signature expected by gradient.run_single_fwd_bwd -- confirm
    """
    if pool is None: ## avoid a shared mutable default argument
        pool = {"c": []}

    ## (1) preprocess________________________________________
    ## select tokens to apply metric to (skipped when no range is given)
    if I_range is not None:
        nll_NI = nll_NI[...,I_range[0]:I_range[1]]

    ## (2) pooling_______________________________________________
    ## NOTE(review): original comment says pooled dims are expanded again to
    ## retain shapes -- confirm in gradient.pool_tensor
    nll_NI = gradient.pool_tensor(nll_NI, pool["c"])
    print(f"pooling nll_NI {nll_NI.shape}, pool: {pool}")

    ## (3) apply metric_______________________________________________
    metric_res = nll_NI.mean()
    print(f"contrast loss: {metric_res}")
    return metric_res, None

## restrict the metric to the single token position 49 (slice [49:50] of the last dim)
metric = functools.partial(single_seq_metric, I_range=[49,50])
## NOTE(review): presumably runs one forward and one backward pass and returns the
## forward / backward activation caches -- confirm in gradient.run_single_fwd_bwd;
## the third return value is discarded
c_fwd_cache, c_bwd_cache, _ = gradient.run_single_fwd_bwd(model, metric_fn=metric, c_toks_NI=c_toks_NI)

In [None]:
## pooling functions, all called as fn(x, dim=...)
POOL_FN = {"l1": lambda x, dim: torch.norm(x, p=1, dim=dim),
         "l2": lambda x, dim: torch.norm(x, p=2, dim=dim),
         ## NOTE: "frob" ignores dim; matrix_norm reduces over the last two dims, so x must be at least 2D
         "frob": lambda x, dim: torch.linalg.matrix_norm(x, ord='fro'),
         "mean_abs": lambda x, dim: torch.mean(torch.abs(x), dim=dim),
         "mean": lambda x, dim: torch.mean(x, dim=dim),
         "max_abs": lambda x, dim: torch.max(torch.abs(x), dim=dim)[0],
         "max": lambda x, dim: torch.max(x, dim=dim)[0],
         ## "pass" returns x unchanged (no reduction)
         "pass": lambda x, dim: (x)}

## distance / similarity functions between two tensors
DIST_FN = {"cos": lambda x1, x2: torch.nn.functional.cosine_similarity(x1, x2, dim=-1),
           "sub": lambda x1, x2: x1-x2,
           "sub_abs": lambda x1, x2: torch.abs(x1-x2)}

def pool_tensor(tensor:torch.Tensor, pool:str="max", abs_vals:bool=True, topP:float=1.0, norm_by_entries:bool=False, n_heads:int=12):
    """
    select the largest entries along the last dimension and pool them

    Args:
        tensor: input tensor; it is never modified in place
        pool: key into POOL_FN (a KeyError is raised for unknown keys)
        abs_vals: take absolute values before selecting / pooling
        topP: if strictly between 0 and 1, the fraction of the last dimension to
              keep (at least one entry); otherwise int(topP) is used directly as
              the number of entries to keep.
              NOTE(review): the default 1.0 therefore keeps only the single
              largest entry, not 100% of the entries -- confirm this is intended
        norm_by_entries: scale all values by 1 / log(number of entries)
        n_heads: for 5D tensors the number of entries is divided by this value
                 before taking the log (default 12 as before)

    Returns:
        the result of POOL_FN[pool] applied to the selected values over the last dimension
    """
    import math ## local import: math is not imported by the notebook's import cell

    if abs_vals: ## take absolute values (creates a new tensor)
        tensor = torch.abs(tensor)

    if norm_by_entries:
        n_params = tensor.numel()
        if len(tensor.shape) == 5: ##ATTN
            n_params = n_params / n_heads ## divide by number of heads
        norm_by = math.log(n_params)#n_params**(1/2)
        ## out-of-place scaling so the caller's tensor stays untouched;
        ## zeros remain zero, so no masking is needed
        tensor = tensor * (1/norm_by)

    if 0.0 < topP < 1.0: ## fraction -> absolute number of entries
        topP = max(int(topP*tensor.shape[-1]), 1)
    topK_vals, _ = torch.topk(tensor, int(topP), dim=-1, largest=True)
    tensorpool = POOL_FN[pool](topK_vals, dim=-1) ## do pooling

    return tensorpool

In [None]:
def collect_pool(cache:dict, second_cache:dict=None, c_type:str=None, n_toks:int=49):
    """
    collect the values of one component type from a cache and pool them

    Args:
        cache: cache passed to localizing.collect_c_type
        second_cache: optional second cache; if given, the result is the cosine
                      similarity (over the last dim) between the values of both
                      caches instead of a pooled value
        c_type: component type forwarded to localizing.collect_c_type and used for the names
        n_toks: number of leading positions along dim 1 that are kept (default 49 as before)

    Returns:
        (pool_vals, names): the pooled values averaged over dim 0, and a list of plot labels
    """
    ## the names returned by collect_c_type are not used; labels are rebuilt below
    vals, _ = localizing.collect_c_type(model=model, cache=cache, c_type=c_type)
    if second_cache is not None: ## compare both caches instead of pooling
        vals2, _ = localizing.collect_c_type(model=model, cache=second_cache, c_type=c_type)
        #vals, vals2 = (vals2.sum()/vals.sum())*vals, vals2 ## normalizing
        #vals = (vals - vals2)
        pool_vals = torch.nn.functional.cosine_similarity(vals, vals2, dim=-1)
    else: ## max over the top 10% absolute values of the last dim
        pool_vals = pool_tensor(vals, pool="max", abs_vals=True, topP=0.1, norm_by_entries=False)

    ## keep only the first n_toks positions of dim 1, then mean over dim 0
    ## NOTE(review): assumes dim 0 indexes sequences and dim 1 tokens -- confirm against collect_c_type
    pool_vals = pool_vals[:,:n_toks].mean(0)

    ## reshape for plotting
    if len(pool_vals.shape) == 2: ## attention
        names = [f"{c_type} H{i}" for i in range(0,pool_vals.shape[1])]
    else: ## mlp
        pool_vals = pool_vals.unsqueeze(-1)
        names = [f"{c_type}"]

    return pool_vals, names ## I, L, H, C

## Gather Activation Gradients

In [None]:
def gather_activation_grads(cache:dict, second_cache:dict=None, c_types:list=None): #"attn_out", "mlp_out", "z", "pre", "post"  #"k", "q", "v", "z", "pre", "post", "attn_out", "mlp_out",  "z", "mlp_in", "post","mlp_out"
    """
    pool the values of several component types and concatenate them

    Args:
        cache: cache forwarded to collect_pool
        second_cache: optional second cache forwarded to collect_pool
        c_types: component types to collect; None is treated as ["q"]

    Returns:
        (vals, names): the per-type results concatenated along the last
        dimension, and the concatenated list of their names
    """
    if c_types is None: ## avoid a shared mutable default argument
        c_types = ["q"]

    vals, names = [], []
    for c_type in c_types:
        c_type_vals, c_type_names = collect_pool(cache, second_cache, c_type=c_type)
        vals.append(c_type_vals)
        names += c_type_names
    vals = torch.cat(vals, dim=-1)
    return vals, names

## component type to analyse; also used in the plot titles below
c_type = "v"
## pooled values for that component type, taken from the backward cache only (no second cache)
c_vals, names = gather_activation_grads(c_bwd_cache, c_types=[c_type])
#diff_vals, names = gather_activation_grads({"forward":c_fwd_cache,"backward":c_bwd_cache}[fwd_bwd], {"forward":k_fwd_cache,"backward":k_bwd_cache}[fwd_bwd], idcs_NI=idcs_NI)

In [None]:
## heatmap of the values of one attention head across layers
head = 2
if len(c_vals.shape)==4:
    ## NOTE(review): assumes c_vals is laid out as (I, L, H, C) -- confirm against collect_pool
    vals = c_vals[:,:,head,:].squeeze().T  ## I, L, H, C
else:
    vals = c_vals.T

fontsize = 12
fig, ax = plt.subplots(1, 1, figsize=(14, 3), gridspec_kw={'hspace': 0.4})
s = sns.heatmap(vals,
              cmap=mpl.colormaps["Reds"], center=None,
              xticklabels=model.to_str_tokens(c_toks_NI[:,:49]),
              ## one label per row instead of a hard-coded 12, so other model sizes are labelled correctly
              yticklabels=np.arange(vals.shape[0]), square=False,
              cbar_kws={'location': 'right','pad': 0.01},
              ax=ax) ## draw on the axes created above instead of the implicit current axes
ax.invert_yaxis() ## first row at the bottom

ax.set_title(f"Activation gradients of {c_type} H{head}", fontsize=fontsize, loc="left")
ax.set_ylabel("layer", fontsize=fontsize)
ax.tick_params(axis='both', which='major', labelsize=fontsize)
s.set_yticklabels(s.get_yticklabels(), rotation=0, horizontalalignment='right')
fig.savefig(f"{dataLoaders.ROOT}/results/activ_grads_layers.pdf", dpi=200, bbox_inches="tight")

In [None]:
## heatmap of the values of all attention heads at one layer
layer = 1
if len(c_vals.shape)==4:
    ## NOTE(review): assumes c_vals is laid out as (I, L, H, C) -- confirm against collect_pool
    vals = c_vals[:,layer,:,:].squeeze().T  ## I, L, H, C
else:
    vals = c_vals.T

fontsize = 12
fig, ax = plt.subplots(1, 1, figsize=(14, 3), gridspec_kw={'hspace': 0.4})
s = sns.heatmap(vals,
              cmap=mpl.colormaps["Reds"], center=None,
              xticklabels=model.to_str_tokens(c_toks_NI[:,:49]),
              ## one label per row instead of a hard-coded 12, so other model sizes are labelled correctly
              yticklabels=np.arange(vals.shape[0]), square=False,
              cbar_kws={'location': 'right','pad': 0.01},
              ax=ax) ## draw on the axes created above instead of the implicit current axes

ax.set_title(f"{c_type} activation gradients at layer {layer}", fontsize=fontsize, loc="left")
ax.set_ylabel("head", fontsize=fontsize)
ax.tick_params(axis='both', which='major', labelsize=fontsize)
s.set_yticklabels(s.get_yticklabels(), rotation=0, horizontalalignment='right')
fig.savefig(f"{dataLoaders.ROOT}/results/activ_grads_heads.pdf", dpi=200, bbox_inches="tight")