
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Contrastive Grads

In [None]:
#@title Import libraries
import transformer_lens, torch, gc, itertools, functools, math, glob, tqdm
import pandas as pd
import numpy as np
from pathlib import Path
from collections import defaultdict 

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, LogNorm, SymLogNorm

# Hide all four axes spines by default for every figure created below.
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
    'axes.spines.left': False,
    'axes.spines.bottom': False,
})

import sys
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, gradient, localizing, patching

torch.set_printoptions(precision=2)

## Model

In [None]:
# Name of the model to analyse; also reused below as the directory prefix
# when loading the data splits.
model_type = "gpt-neo-125M"
# Load the model on CPU through the project helper.
model = modelHandlers.load_model(model_type=model_type, DEVICE="cpu")
# Exclude the (positional) embedding and unembedding components from gradient
# tracking. NOTE(review): presumably this sets requires_grad=False on the
# matching parameters; confirm in modelHandlers.set_no_grad.
modelHandlers.set_no_grad(model, ["embed", "pos_embed", "unembed"])

## Load Data

In [None]:
## load the two prediction splits stored under "<model_type>/preds";
## the first file is used as the memorized set, the second as the non-memorized set
mem_nonmem_sets  = dataLoaders.load_pile_splits(f"{model_type}/preds", file_names=["50_50_preds.pt", "0_10_preds.pt"], as_torch=True)
mem_set, non_mem_set = mem_nonmem_sets[0], mem_nonmem_sets[1]
## paired batching of both sets (30 sequences each), 20% held out for testing
train_dl, test_dl = dataLoaders.train_test_batching(mem_set, non_mem_set, mem_batch=30, non_mem_batch=30, test_frac=0.2, shuffle=True, add_bos=None)
## take only the first training batch; squeeze(0) removes a leading dimension of size 1, if present
c_toks_NI, k_toks_NI = next(iter(train_dl))
c_toks_NI, k_toks_NI = c_toks_NI.squeeze(0), k_toks_NI.squeeze(0)

## load perturbed mem set and original mem set
## NOTE: this section overwrites mem_set, train_dl, test_dl and c_toks_NI from above;
## only k_toks_NI is kept from the first (mem / non-mem) loader.
mem_perturbed_sets  = dataLoaders.load_pile_splits(f"{model_type}/perturbed", file_names=["mem_toks.pt", "perturbed_mem_toks.pt"], as_torch=True)
mem_set, perturbed_mem_set = mem_perturbed_sets[0], mem_perturbed_sets[1]
train_dl, test_dl = dataLoaders.train_test_batching(mem_set, perturbed_mem_set, mem_batch=30, non_mem_batch=30, test_frac=0.2, shuffle=True, add_bos=None)
## first training batch of (original, perturbed) memorized sequences
c_toks_NI, c_perturb_toks_NI = next(iter(train_dl))
c_toks_NI, c_perturb_toks_NI, = c_toks_NI.squeeze(0), c_perturb_toks_NI.squeeze(0)

### Run backprop and gather gradients

In [None]:
def collect_pool_activs(model, cache:dict=None, c_type:str=None, I_range=None):
    """Collect one component type from `cache`, pool it and average it.

    The values returned by `localizing.collect_c_type` are pooled with
    `localizing.pool_tensor` (pool="max", abs_vals=True, topP=0.01,
    norm_by_entries=False), optionally restricted to a slice along axis 1,
    and then averaged over axis 1 and afterwards over axis 0.

    Args:
        model: model handed through to `localizing.collect_c_type`.
        cache: activation/gradient cache handed through to `collect_c_type`.
        c_type: component type to collect (e.g. "q", "k", "pre").
        I_range: optional `[start, stop]` pair; when given, only
            `[:, start:stop]` of the pooled values is kept before averaging.
            (Per the original author's comments axis 1 indexes tokens and
            axis 0 indexes sequences -- TODO confirm against `pool_tensor`.)

    Returns:
        `(pool_vals, names)`. If the averaged tensor is 2-D it is returned
        as is, with one name "<c_type> H<i>" per entry of its second axis
        (attention heads); otherwise a trailing singleton axis is appended
        and the single name "<c_type>" is returned (MLP).
    """
    # the names returned by collect_c_type are not used; plot labels are rebuilt below
    raw_vals, _ = localizing.collect_c_type(model=model, cache=cache, c_type=c_type)
    pooled = localizing.pool_tensor(raw_vals, pool="max", abs_vals=True, topP=0.01, norm_by_entries=False)

    ## consider either all positions or only the selected window
    if I_range is not None:
        start, stop = I_range[0], I_range[1]
        pooled = pooled[:, start:stop]

    ## average over axis 1 first, then over axis 0
    pooled = pooled.mean(1).mean(0)

    ## attention: one column label per head
    if pooled.ndim == 2:
        return pooled, [f"{c_type} H{head}" for head in range(pooled.shape[1])]

    ## mlp: add a column axis so the result can be concatenated with the attention results
    return pooled.unsqueeze(-1), [f"{c_type}"]


def gather_activation_grads(model, cache:dict, I_range:list=[50,100], c_types:list=["q", "k", "v", "z", "pre", "post"]):
    """Pool the cached values of several component types into one tensor.

    Calls `collect_pool_activs` once per entry of `c_types` and concatenates
    the pooled tensors along their last axis; the per-type name lists are
    joined in the same order.

    Args:
        model: model handed through to `collect_pool_activs`.
        cache: cache (forward or backward) to read the values from.
        I_range: `[start, stop]` window handed through to `collect_pool_activs`.
        c_types: component types to gather. The original author's comment
            also lists "attn_out" and "mlp_out" as candidates -- not verified here.

    Returns:
        `(vals, names)`: the concatenated tensor and the matching column names.

    NOTE: the list defaults are shared between calls; they are only read here
    and must not be mutated.
    """
    per_type = [
        collect_pool_activs(model=model, cache=cache, c_type=c_type, I_range=I_range)
        for c_type in c_types
    ]
    vals = torch.cat([pooled for pooled, _ in per_type], dim=-1)
    names = [name for _, type_names in per_type for name in type_names]
    return vals, names


def gather_param_grads(model, I_range:list=[50,100], c_types:list=["W_K","W_Q","W_V","W_O","W_in","W_out"]):
    """Pool the collected parameter values of several weight types.

    For every weight type the values returned by
    `localizing.collect_c_type(cache=None)` are flattened to three axes
    (attention weights keep their second axis, one column per head; all other
    weights get a single column), pooled with `localizing.pool_tensor`
    (pool="max", abs_vals=True, topP=0.1, norm_by_entries=False) and finally
    concatenated along axis 1.

    Args:
        model: model handed through to `localizing.collect_c_type`.
        I_range: accepted for signature parity with the activation variant
            but not used by this function.
        c_types: weight types to gather.

    Returns:
        `(param_vals, param_names)`: concatenated pooled values and one
        column name per head (attention) or per weight type (MLP).

    NOTE: the list defaults are shared between calls; they are only read here.
    """
    attn_types = ('W_Q', 'W_K', 'W_V', 'W_O')
    pooled_chunks, labels = [], []
    for c_type in c_types:
        grads, raw_names = localizing.collect_c_type(model=model, cache=None, c_type=c_type)
        is_attn = c_type in attn_types

        ## attention keeps one column per entry of axis 1, mlps collapse to one column
        n_cols = grads.shape[1] if is_attn else 1
        grads = grads.view(grads.shape[0], n_cols, -1)

        ## label with the last dotted component of the first collected name
        short_name = raw_names[0].split('.')[-1]
        if is_attn:
            labels.extend(f"{short_name} H{head}" for head in range(n_cols))
        else:
            labels.append(short_name)

        pooled_chunks.append(
            localizing.pool_tensor(grads, pool="max", abs_vals=True, topP=0.1, norm_by_entries=False)
        )
    return torch.cat(pooled_chunks, dim=1), labels


def gather_param_stats(model, param_stats, I_range:list=[49,99], c_types:list=["W_K","W_Q","W_V","W_O","W_in","W_out"]):
    """Fill `param_stats` with summary statistics per weight type.

    For every weight type the values returned by
    `localizing.collect_c_type(cache=None)` are summarised as:
      * "sum": sum of absolute values,
      * "var": variance,
      * "max": maximum,
      * "min": absolute value of the minimum.
    In addition the per-layer sums of absolute values are accumulated into
    `param_stats["attn"]` (W_K, W_V, W_Q, W_O) and `param_stats["mlp"]`
    (W_in, W_out); both are reset to 0.0 per layer at the start of the call.

    Args:
        model: model providing `cfg.n_layers`; handed to `collect_c_type`.
        param_stats: dict with (at least) the keys "sum", "var", "max", "min"
            mapping to dict-like containers; it is modified in place.
        I_range: not used by this function.
        c_types: weight types to summarise.

    Returns:
        The same `param_stats` object that was passed in.

    NOTE: the list defaults are shared between calls; they are only read here.
    """
    n_layers = model.cfg.n_layers
    param_stats["attn"] = dict.fromkeys(range(n_layers), 0.0)
    param_stats["mlp"] = dict.fromkeys(range(n_layers), 0.0)

    ## which layer-wise accumulator each weight type contributes to
    group_of = {
        "W_K": "attn", "W_V": "attn", "W_Q": "attn", "W_O": "attn",
        "W_in": "mlp", "W_out": "mlp",
    }

    for c_type in c_types:
        vals, _ = localizing.collect_c_type(model=model, cache=None, c_type=c_type)
        abs_vals = torch.abs(vals)
        ## one row per layer, everything else flattened, then summed
        layer_sum = abs_vals.reshape(n_layers, -1).sum(-1)

        param_stats["sum"][c_type] = abs_vals.sum().item()
        param_stats["var"][c_type] = vals.var().item()
        param_stats["max"][c_type] = vals.max().item()
        param_stats["min"][c_type] = torch.abs(vals.min()).item()

        group = group_of.get(c_type)
        if group is not None:
            for layer, layer_total in enumerate(layer_sum.tolist()):
                param_stats[group][layer] += layer_total
    return param_stats

def gradient_stats(pool_vals:torch.Tensor) -> tuple:
    """Summarise a pooled gradient tensor.

    Args:
        pool_vals: tensor of pooled values (not a Python list: tensor methods
            are called on it). The caller in this file passes a 2-D tensor
            whose rows are plotted as layers.

    Returns:
        A tuple `(total_grad, total_grad_per_layer, total_grad_var)`:
          * total_grad: 0-d tensor, sum over all entries,
          * total_grad_per_layer: list of strings, the sum over the last axis
            formatted with two decimals (one string per row for 2-D input),
          * total_grad_var: Python float, `torch.var` over all entries.
    """
    total_grad = pool_vals.sum()
    # convert once to plain floats instead of formatting 0-d tensors one by one
    total_grad_per_layer = [f"{x:.2f}" for x in pool_vals.sum(-1).tolist()]
    total_grad_var = torch.var(pool_vals).item()
    return total_grad, total_grad_per_layer, total_grad_var


In [None]:
# NOTE(review): load_model is called here with the already loaded model object
# plus optimizer-style hyperparameters; presumably this re-wraps/configures the
# model for the backward pass -- confirm in modelHandlers.load_model.
model = modelHandlers.load_model(model, lr=1e-05, weight_decay=1.0)

# Contrastive metric bound to the token window [49, 99], without the perturbed set.
metric_fn = functools.partial(gradient.contrast_metric, I_range=[49,99], use_perturb=False, c_set_norm=0.1)

# FIX: the original list was ["W_V","W_Q","W_V","W_O","W_in","W_out"], i.e. "W_V"
# twice and no "W_K". Every other c_types list in this notebook uses
# W_K/W_Q/W_V/W_O/W_in/W_out, and the key projection is plotted later on,
# so "W_K" is included here as well.
fwd_bwd_fn = functools.partial(gradient.run_contrastive_fwd_bwd, optim_steps=-1, topK=None, grad_norm=None, c_types=["W_K","W_Q","W_V","W_O","W_in","W_out"])

# Single forward/backward pass over (memorized, perturbed, non-memorized) batches.
fwd_cache, bwd_cache, topk_idcs = fwd_bwd_fn(model, metric_fn, c_toks_NI, c_perturb_toks_NI, k_toks_NI)

In [None]:
## which cache to read ("fwd" or "bwd") and what to pool ("activs" or "params")
fwd_bwd = "bwd"
grad_type = "params"

## token window used for pooling and for the plot title / file name
#I_range=[0,99]
I_range=[49,99]

## one nested container per statistic; "attn" and "mlp" are replaced inside gather_param_stats
param_stats = {stat: defaultdict(dict) for stat in ("sum", "var", "max", "min", "attn", "mlp")}

## POOL
if grad_type=="params":
    pool_vals, names = gather_param_grads(model, I_range=I_range)
    param_stats = gather_param_stats(model, param_stats, I_range=I_range)
elif grad_type=="activs":
    pool_vals, names = gather_activation_grads(model, {"fwd":fwd_cache,"bwd":bwd_cache}[fwd_bwd], I_range=I_range)

## overall sum, formatted per-row sums and variance of the pooled values
total_grad, total_grad_per_layer, total_grad_var = gradient_stats(pool_vals)

## Plot Heatmap

In [None]:
fontsize = 12
fig, ax = plt.subplots(1, 1, figsize=(17, 3), gridspec_kw={'hspace': 0.5})

# detach()/cpu() are no-ops for a plain CPU tensor but keep .numpy() from
# failing if the pooled values still carry autograd history or live on a GPU
plot_vals = pool_vals.detach().cpu().numpy()
s = sns.heatmap(plot_vals,cmap="binary",center=None,xticklabels=names,yticklabels=np.arange(0,plot_vals.shape[0]),square=False,ax=ax, cbar_kws={'location': 'right','pad': 0.05})#, norm=LogNorm())
s.set_yticklabels(s.get_yticklabels(), rotation=0, horizontalalignment='right')
# put row 0 at the bottom of the heatmap
ax.invert_yaxis()

ax.set_title(f"Parameter gradients for contrastive objective on tokens {I_range[0]+1} to {I_range[1]+1}", fontsize=fontsize, loc="left")
ax.tick_params(axis='y', which='major', labelsize=fontsize)
ax.tick_params(axis='x', which='major', labelsize=fontsize-3)
ax.set_ylabel("layer", fontsize=fontsize)

# Create a second y-axis on the right side showing the per-row gradient sums.
# FIX: a twin axis has its own y-limits, and heatmap rows are centred at
# i + 0.5; copy the limits of the heatmap axis and offset the ticks by 0.5 so
# that every label lines up with its row.
ax2 = ax.twinx()
ax2.set_ylim(ax.get_ylim())
ax2.set_yticks(np.arange(0,len(total_grad_per_layer)) + 0.5)
ax2.set_yticklabels(total_grad_per_layer, fontsize = fontsize-2)
ax2.set_ylabel("layer-wise gradient sum", fontsize = fontsize-2)

fig.savefig(f"{dataLoaders.ROOT}/results/contrastive_objective_max_{I_range[0]}_{I_range[1]}.pdf", dpi=200, bbox_inches="tight")