
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Cosine Similarity on Forward (FWD) and Backward (BWD) Passes


In [None]:
#@title Import libraries
import transformer_lens
import torch, gc, itertools, functools, copy
import pandas as pd
import numpy as np

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt

# Hide the right and top axis spines on every axes created after this point.
mpl.rcParams.update({
    'axes.spines.right': False,
    'axes.spines.top': False,
})

import sys
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, gradient, localizing

## Model

In [None]:
# Model identifier; also used below to build the data paths and the output figure file names.
model_type="gpt-neo-125M"
# Load the model on CPU via the project helper. lr=0.0 is passed through — presumably so that any
# optimizer the helper creates cannot change the weights while gradients are computed; TODO confirm.
model = modelHandlers.load_model(model_type=model_type, DEVICE="cpu", lr=0.0)

## Load Data

In [None]:
## mem and non-mem set
# Earlier loading variant, kept for reference (reads the "acc/gpt-neo-125M" splits instead):
#(mem_prompts, mem_counts),(non_mem_prompts,non_mem_counts) = dataLoaders.load_pile_splits("acc/gpt-neo-125M", as_torch=True)
#train_dl, test_dl = dataLoaders.train_test_batching(mem_prompts, non_mem_prompts, mem_batch=5, non_mem_batch=5, test_frac=0.0, set_twice=None)
#c_toks_NI, k_toks_NI = next(iter(train_dl))

# Load the memorized ("50_50_preds.pt") and non-memorized ("0_10_preds.pt") token sets for this
# model and take the first training batch of each (mem_batch/non_mem_batch=50, test_frac=0.2).
mem_nonmem_sets  = dataLoaders.load_pile_splits(f"{model_type}/preds", file_names=["50_50_preds.pt", "0_10_preds.pt"], as_torch=True)
mem_set, non_mem_set = mem_nonmem_sets[0], mem_nonmem_sets[1]
train_dl, test_dl = dataLoaders.train_test_batching(mem_set, non_mem_set, mem_batch=50, non_mem_batch=50, test_frac=0.2, add_bos=None)
c_toks_NI, k_toks_NI = next(iter(train_dl))
# Drop the leading singleton axis of the batch — presumably leaving (N sequences, I tokens); TODO confirm.
c_toks_NI, k_toks_NI = c_toks_NI.squeeze(0), k_toks_NI.squeeze(0)

## load perturbed mem set and original mem set
# NOTE(review): this section REBINDS c_toks_NI (and mem_set, train_dl, test_dl) to the first
# batch of the mem / perturbed-mem split (mem_batch=20), while k_toks_NI still holds the
# non-memorized batch loaded above (non_mem_batch=50). When the cells run top to bottom, the
# "memorized vs non-memorized" comparisons below therefore use this mem_toks.pt batch, not the
# 50_50_preds.pt batch — confirm this is intended, or re-run only the first section before them.
mem_perturbed_sets  = dataLoaders.load_pile_splits(f"{model_type}/perturbed", file_names=["mem_toks.pt", "perturbed_mem_toks.pt"], as_torch=True)
mem_set, perturbed_mem_set = mem_perturbed_sets[0], mem_perturbed_sets[1]
# matched=True, shuffle=False: presumably keeps each perturbed sequence aligned with its original; TODO confirm.
train_dl, test_dl = dataLoaders.train_test_batching(mem_set, perturbed_mem_set, mem_batch=20, non_mem_batch=20, matched=True, shuffle=False, test_frac=0.2, add_bos=None)
c_toks_NI, c_perturb_toks_NI = next(iter(train_dl))
c_toks_NI, c_perturb_toks_NI, = c_toks_NI.squeeze(0), c_perturb_toks_NI.squeeze(0)

# Optional: reduce each batch to its first sequence.
#c_toks_NI, k_toks_NI = c_toks_NI[0,:], k_toks_NI[0,:]

## Cosine Similarity

In [None]:
def tensor_diffs(tok_set_vals:list, abs_vals:bool=False):
    """Compute cosine similarity between two parallel lists of component tensors.

    Despite the name, the result is a similarity (1 = same direction), not a difference.

    Args:
        tok_set_vals: a pair ``[first, second]`` of lists of tensors. The i-th tensor of
            ``first`` is compared with the i-th tensor of ``second`` (looked up by position).
            ``collect_param_grads`` / ``collect_activs`` produce tensors shaped
            (rows, columns, features), with columns = heads or 1.
        abs_vals: if True, compare element-wise absolute values instead of signed values.

    Returns:
        The per-component similarities (computed over the last dimension) concatenated
        along dim 1.
    """
    similarities = []
    for idx, lhs in enumerate(tok_set_vals[0]):
        rhs = tok_set_vals[1][idx]
        if abs_vals:
            lhs, rhs = torch.abs(lhs), torch.abs(rhs)
        # Reduce over the feature axis: one similarity per (row, column).
        # Alternative metric once tried here: mean signed difference, (lhs - rhs).mean(-1).
        similarities.append(torch.nn.functional.cosine_similarity(lhs, rhs, dim=-1))
    return torch.cat(similarities, dim=1)

### Parameter Differences

In [None]:
def collect_param_grads(toks_NI:torch.Tensor, NI_idcs=[49,99], c_types=("W_K","W_Q","W_V","W_O","W_in","W_out")):
    """Run one forward/backward pass on ``toks_NI`` and collect per-component parameter values.

    The values come from ``localizing.collect_c_type(cache=None)`` on the module-level
    ``model``. Because the backward pass runs first, these are presumably the parameter
    gradients; TODO confirm against ``localizing``.

    Args:
        toks_NI: token tensor passed to ``gradient.run_single_fwd_bwd`` as ``c_toks_NI``
            (presumably (N sequences, I tokens); TODO confirm).
        NI_idcs: forwarded unchanged to ``gradient.single_seq_metric`` as ``NI_idcs``.
            It is never mutated here.
        c_types: parameter component names to collect, e.g. "W_Q" or "W_in"; any sequence of str.

    Returns:
        ``(c_type_vals, c_type_names)``:
        - ``c_type_vals`` holds one tensor per entry of ``c_types``. Attention weights
          (W_Q/W_K/W_V/W_O) are reshaped to (dim0, dim1, -1), with dim1 labelled as heads.
          All other weights are reshaped to (dim0, 1, -1).
        - ``c_type_names`` is the flat list of column labels ("W_Q H0", ..., "W_in", ...),
          in the same order as the concatenated dim 1.

    NOTE(review): if ``run_single_fwd_bwd`` does not zero parameter gradients before its
    backward pass, repeated calls accumulate gradients from earlier inputs — confirm.
    """
    attn_weights = {"W_Q", "W_K", "W_V", "W_O"}
    metric = functools.partial(gradient.single_seq_metric, NI_idcs=NI_idcs)
    # The returned activation caches are not needed here; discarding them lets them be freed
    # before the collection loop instead of being held until the function returns.
    gradient.run_single_fwd_bwd(model, metric_fn=metric, c_toks_NI=toks_NI)
    c_type_vals, c_type_names = [], []
    for c_type in c_types:
        vals, names = localizing.collect_c_type(model=model, cache=None, c_type=c_type)
        name = names[0].split('.')[-1]
        if name in attn_weights:
            # Keep dim 1 (one column per head) and flatten the remaining weight dims.
            # reshape (not view) also works if the collected tensor is non-contiguous.
            vals = vals.reshape(vals.shape[0], vals.shape[1], -1)
            names = [f"{name} H{i}" for i in range(vals.shape[1])]
        else:
            # MLP weights: a single column per row.
            vals = vals.reshape(vals.shape[0], 1, -1)
            names = [name]
        c_type_vals.append(vals)
        c_type_names.extend(names)
    return c_type_vals, c_type_names

### Activation Differences

In [None]:
def collect_activs(toks_NI:torch.Tensor, fwd_bwd:str="bwd", I_range=[49,99], c_types=("q", "k", "v", "z", "pre", "post")):  #"q", "k", "v", "z", "pre", "post", "attn_out", "mlp_out"
    """Collect pooled forward activations or backward activation-gradients per component type.

    Runs one forward/backward pass on ``toks_NI`` using the module-level ``model``. It then
    reads each component from the selected cache via ``localizing.collect_c_type``. Values
    are averaged over token positions, then over sequences.

    Args:
        toks_NI: token tensor passed to ``gradient.run_single_fwd_bwd`` as ``c_toks_NI``.
        fwd_bwd: "fwd" to read the forward cache, "bwd" to read the backward cache.
        I_range: ``[start, stop)`` token positions. It is forwarded to
            ``gradient.single_seq_metric`` as ``NI_idcs`` and used to slice the token axis
            before pooling. ``None`` pools over all tokens.
        c_types: component names, e.g. "q", "k", "v", "z", "pre", "post"; any sequence of str.

    Returns:
        ``(c_type_vals, c_type_names)``:
        - For q/k/v/z, one tensor of shape (dim0, heads, features) with labels "<name> H<i>".
        - For everything else, a (dim0, 1, features) tensor with a single label.
        Shapes assume ``collect_c_type`` returns (N seqs, I tokens, L layers, [H heads,] D),
        per the original inline note; TODO confirm.

    Raises:
        ValueError: if ``fwd_bwd`` is not "fwd" or "bwd". This is checked before the model is
            run, so a typo no longer costs a full forward+backward pass.

    NOTE(review): the label keeps only the text after the last "_" of the component name.
    If hook names look like "hook_attn_out" / "hook_mlp_out", both would be labelled "out".
    """
    if fwd_bwd not in ("fwd", "bwd"):
        raise ValueError(f"fwd_bwd must be 'fwd' or 'bwd', got {fwd_bwd!r}")
    attn_types = {"q", "k", "v", "z"}
    metric = functools.partial(gradient.single_seq_metric, NI_idcs=I_range)
    fwd_cache, bwd_cache, _ = gradient.run_single_fwd_bwd(model, metric_fn=metric, c_toks_NI=toks_NI)
    cache = fwd_cache if fwd_bwd == "fwd" else bwd_cache
    c_type_vals, c_type_names = [], []
    for c_type in c_types:
        vals, names = localizing.collect_c_type(model=model, cache=cache, c_type=c_type) ## vals shape NILHD

        if I_range is not None:
            vals = vals[:, I_range[0]:I_range[1]]
        vals = vals.mean(1)  # pool over tokens
        vals = vals.mean(0)  # pool over sequences

        name = names[0].split('.')[-1].split('_')[-1]
        if name in attn_types:
            # Attention components keep their head axis: one column per head.
            names = [f"{name} H{i}" for i in range(vals.shape[1])]
        else:
            # MLP / other components: add a singleton column axis.
            vals = vals.reshape(vals.shape[0], 1, -1)
            names = [name]

        c_type_vals.append(vals)
        c_type_names.extend(names)
    return c_type_vals, c_type_names

## Run and Plot

### Parameters

In [None]:
# Compare parameter values collected after a backward pass on c_toks_NI vs k_toks_NI.
# c_toks_NI is whatever the data-loading cell bound last; its perturbed-set section rebinds it.
c_types=["W_K","W_Q","W_V","W_O", "W_in","W_out"] #"W_K","W_Q","W_V","W_O"
# The labels are presumably identical for both runs, so the second call just rebinds c_names_list.
# NOTE(review): confirm parameter gradients are reset between these two calls; otherwise the
# second result also contains the gradient from c_toks_NI.
c_vals_list, c_names_list = collect_param_grads(c_toks_NI, c_types=c_types)
k_vals_list, c_names_list = collect_param_grads(k_toks_NI, c_types=c_types)
dist_vals = tensor_diffs(tok_set_vals=[c_vals_list, k_vals_list])

In [None]:
# Heatmap of the parameter cosine similarities: rows are plotted as layers, columns follow c_names_list.
fontsize = 12
fig, ax = plt.subplots(1, 1, figsize=(13, 3), gridspec_kw={'hspace': 0.4})

plot_vals = dist_vals.numpy()
# Alternative: plot only rows 1-10.
#s = sns.heatmap(plot_vals[1:11,:],cmap=mpl.colormaps["coolwarm_r"],center=0,xticklabels=c_names_list,yticklabels=np.arange(1,plot_vals.shape[0]-1),cbar_kws={'location': 'right','pad': 0.01})
# center=0 centres the diverging colormap on zero similarity; the heatmap is drawn on the current axes (ax).
s = sns.heatmap(plot_vals[:,:],cmap=mpl.colormaps["coolwarm_r"],center=0,xticklabels=c_names_list,yticklabels=np.arange(0,plot_vals.shape[0]),cbar_kws={'location': 'right','pad': 0.01})

# NOTE(review): the title hard-codes "50 memorized and 50 non-memorized". If the perturbed-set
# section of the data cell ran last, c_toks_NI holds its mem_batch=20 batch instead — confirm.
ax.set_title(f"Cosine similarity between parameter gradients\nof 50 memorized and 50 non-memorized paragraphs", fontsize=fontsize, loc="left")
s.set_yticklabels(s.get_yticklabels(), rotation=0, horizontalalignment='right')
ax.tick_params(axis='y', which='major', labelsize=fontsize)
ax.tick_params(axis='x', which='major', labelsize=fontsize-3)
ax.set_ylabel('layer', rotation=90, color="black", fontsize=fontsize, labelpad=5)
# seaborn draws row 0 at the top; flip so row 0 is at the bottom.
ax.invert_yaxis()

fig.savefig(f"{dataLoaders.ROOT}/results/{model_type}_param_cosine.png", dpi=200, bbox_inches="tight")

### Activations

In [None]:
# Backward cache: compare the "v" component for c_toks_NI vs k_toks_NI.
# Values are pooled over token slice [49:99], i.e. positions 49..98.
c_types = ["v"]
I_range = [49,99]
c_vals_list, c_names_list = collect_activs(c_toks_NI, fwd_bwd="bwd", c_types=c_types, I_range=I_range)
k_vals_list, c_names_list = collect_activs(k_toks_NI, fwd_bwd="bwd", c_types=c_types, I_range=I_range)
dist_vals = tensor_diffs(tok_set_vals=[c_vals_list, k_vals_list])

In [None]:
# Heatmap of the backward activation-gradient cosine similarities (rows plotted as layers).
fontsize = 13
fig, ax = plt.subplots(1, 1, figsize=(5, 3.5), gridspec_kw={'hspace': 0.4})

plot_vals = dist_vals.numpy()
# Alternative: plot only rows 1-10.
#s = sns.heatmap(plot_vals[1:11,:],cmap=mpl.colormaps["coolwarm_r"],center=0,xticklabels=c_names_list,yticklabels=np.arange(1,plot_vals.shape[0]-1),cbar_kws={'location': 'right','pad': 0.01})
# center=None: the colormap is not centred on a fixed value (unlike the parameter plot).
s = sns.heatmap(plot_vals[:,:],cmap=mpl.colormaps["coolwarm_r"],center=None,xticklabels=c_names_list,yticklabels=np.arange(0,plot_vals.shape[0]),cbar_kws={'location': 'right','pad': 0.03})

# NOTE(review): the title hard-codes "50 memorized and 50 non-memorized"; c_toks_NI may hold the
# mem_batch=20 batch from the perturbed-set section of the data cell — confirm.
ax.set_title(f"Cosine similarity between activation gradients\nover 50 memorized and 50 non-memorized paragraphs", fontsize=fontsize-2, x=-0.1, loc="left")
s.set_yticklabels(s.get_yticklabels(), rotation=0, horizontalalignment='right')
ax.tick_params(axis='y', which='major', labelsize=fontsize)
ax.tick_params(axis='x', which='major', labelsize=fontsize-1)
ax.set_ylabel('layer', rotation=90, color="black", fontsize=fontsize, labelpad=5)
# seaborn draws row 0 at the top; flip so row 0 is at the bottom.
ax.invert_yaxis()

fig.savefig(f"{dataLoaders.ROOT}/results/{model_type}_bwd_activs_cosine.pdf", dpi=200, bbox_inches="tight")

In [None]:
# Forward cache: compare k/q/v of the memorized batch (c_toks_NI) with its perturbed counterpart
# (c_perturb_toks_NI) at the single token slice [49:50], i.e. position 49.
# Other component lists that were tried: ["attn_out", "mlp_out"], ["k", "q", "v", "pre", "post"].
c_types = ["k", "q", "v"]#["attn_out", "mlp_out"]#["k", "q", "v", "pre","post"]
I_range = [49,50]
c_vals_list, c_names_list = collect_activs(c_toks_NI, fwd_bwd="fwd", c_types=c_types, I_range=I_range)
k_vals_list, c_names_list = collect_activs(c_perturb_toks_NI, fwd_bwd="fwd", c_types=c_types, I_range=I_range)
dist_vals = tensor_diffs(tok_set_vals=[c_vals_list, k_vals_list])

In [None]:
# Heatmap of forward-pass activation cosine similarities between memorized paragraphs
# (c_toks_NI) and their perturbed versions (c_perturb_toks_NI); rows are plotted as layers.
fontsize = 12
fig, ax = plt.subplots(1, 1, figsize=(10, 3), gridspec_kw={'hspace': 0.4})

plot_vals = dist_vals.numpy()
# Alternative: plot only rows 1-10.
#s = sns.heatmap(plot_vals[1:11,:],cmap=mpl.colormaps["coolwarm_r"],center=0,xticklabels=c_names_list,yticklabels=np.arange(1,plot_vals.shape[0]-1),cbar_kws={'location': 'right','pad': 0.01})
# center=None: the colormap is not centred on a fixed value.
s = sns.heatmap(plot_vals[:,:],cmap=mpl.colormaps["coolwarm_r"],center=None,xticklabels=c_names_list,yticklabels=np.arange(0,plot_vals.shape[0]),cbar_kws={'location': 'right','pad': 0.01})

# The run cell above compares c_toks_NI with c_perturb_toks_NI, i.e. memorized paragraphs with
# their perturbed versions, not memorized with non-memorized paragraphs.
ax.set_title("Cosine similarity between activations (forward pass)\nover memorized and perturbed memorized paragraphs", fontsize=fontsize, x=-0.1, loc="left")
s.set_yticklabels(s.get_yticklabels(), rotation=0, horizontalalignment='right')
ax.tick_params(axis='y', which='major', labelsize=fontsize)
ax.tick_params(axis='x', which='major', labelsize=fontsize)
ax.set_ylabel('layer', rotation=90, color="black", fontsize=fontsize, labelpad=5)
# seaborn draws row 0 at the top; flip so row 0 is at the bottom.
ax.invert_yaxis()

# Prefix with model_type like the other figures so runs for different models do not overwrite each other.
fig.savefig(f"{dataLoaders.ROOT}/results/{model_type}_fwd_activs_cosine.pdf", dpi=200, bbox_inches="tight")