
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Attention Head Analysis

In [None]:
#@title Import libraries
# NOTE(review): several of these names (e.g. gc, itertools, functools, math, pd,
# cv, LogNorm, SymLogNorm, gradient, localizing, patching) are not used in the
# cells visible here — confirm before pruning.
import transformer_lens, torch, gc, itertools, functools, math
import pandas as pd
import numpy as np
import circuitsvis as cv

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm, SymLogNorm

# Hide the right and top axes spines for every matplotlib figure created after this.
mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False

import sys
# Make the project-local `paraMem` package importable. This path appears specific
# to the original Jupyter environment — adjust it for your setup.
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, gradient, localizing, patching

## Model

In [None]:
# Load the model on the CPU via the project helper.
model_type = "gpt-neo-125M"
model = modelHandlers.load_model(
    model_type=model_type,
    DEVICE="cpu",
)

# Submodules passed to set_no_grad — presumably their parameters get
# requires_grad=False (per the helper's name); confirm in modelHandlers.
frozen_modules = ["embed", "pos_embed", "unembed"]
modelHandlers.set_no_grad(model, frozen_modules)

## Load Data

In [None]:
## mem and non-mem set
# The first file becomes `mem_set`, the second `non_mem_set`.
split_files = ["50_50_preds.pt", "0_10_preds.pt"]
mem_nonmem_sets = dataLoaders.load_pile_splits(
    f"{model_type}/preds", file_names=split_files, as_torch=True
)
mem_set = mem_nonmem_sets[0]
non_mem_set = mem_nonmem_sets[1]

train_dl, test_dl = dataLoaders.train_test_batching(
    mem_set,
    non_mem_set,
    mem_batch=1,
    non_mem_batch=50,
    test_frac=0.2,
    add_bos=None,
)

# Keep only the second element of the first training batch (presumably the
# non-mem tokens, given the argument order above — TODO confirm) and drop its
# leading size-1 dimension.
_, raw_k_toks = next(iter(train_dl))
k_toks_NI = raw_k_toks.squeeze(0)

In [None]:
## load original mem set and perturbed mem set
# The first file becomes `mem_set`, the second `perturbed_mem_set`.
perturbed_files = ["mem_toks.pt", "perturbed_mem_toks.pt"]
mem_perturbed_sets = dataLoaders.load_pile_splits(
    f"{model_type}/perturbed", file_names=perturbed_files, as_torch=True
)
mem_set = mem_perturbed_sets[0]
perturbed_mem_set = mem_perturbed_sets[1]

# matched=True / shuffle=False: batches are built without shuffling; presumably
# pairing each original sequence with its perturbed counterpart — TODO confirm.
train_dl, test_dl = dataLoaders.train_test_batching(
    mem_set,
    perturbed_mem_set,
    mem_batch=50,
    non_mem_batch=50,
    matched=True,
    shuffle=False,
    test_frac=0.2,
    add_bos=None,
)

# First training batch: original and perturbed tokens, each with the leading
# size-1 dimension removed.
c_toks_NI, c_perturb_toks_NI = next(iter(train_dl))
c_toks_NI = c_toks_NI.squeeze(0)
c_perturb_toks_NI = c_perturb_toks_NI.squeeze(0)

## Forward Pass

In [None]:
# Candidate prompts for the forward pass. Previously these were successive
# assignments and commented-out lines, so the first assignment was a dead store.
# 0: TIME sign-up text, 1: The Nation sign-up text, 2: MLB trademark notice.
candidate_strings = [
    " headlines out of Washington never seem to slow. Subscribe to The D.C. Brief to make sense of what matters most. Please enter a valid email address. Sign Up Now Check the box if you do not wish to receive promotional offers via email from TIME. You can unsubscribe at any time. By signing up you are agreeing to our Terms of Use and Privacy Policy . This site is protected by reCAPTCHA and the Google Privacy Policy and Terms of Service apply. Thank you! For your",
    "Sign up for Take Action Now and get three actions in your inbox every week. You will receive occasional promotional offers for programs that support The Nation’s journalism. You can read our Privacy Policy here. Sign up for Take Action Now and get three actions in your inbox every week.\n\nThank you for signing up. For more from The Nation, check out our latest issue\n\nSubscribe now for as little as $2 a month!\n\nSupport Progressive Journalism The Nation is reader supported:",
    "The following are trademarks or service marks of Major League Baseball entities and may be used only with permission of Major League Baseball Properties, Inc. or the relevant Major League Baseball entity: Major League, Major League Baseball, MLB, the silhouetted batter logo, World Series, National League, American League, Division Series, League Championship Series, All-Star Game, and the names, nicknames, logos, uniform designs, color combinations, and slogans designating the Major League Baseball clubs and entities, and",
]

# Select the prompt to analyse, wrapped in a one-element list (a batch of one).
string_NI = [candidate_strings[1]]
# Alternatives: decode samples from the matched batch instead, e.g.
# string_NI = model.to_string(c_toks_NI[21])
# string_NI = model.to_string(c_toks_NI[:])

print(string_NI)

# No BOS token is prepended; the cache of all hooked activations is kept.
toks_NI = model.to_tokens(string_NI, prepend_bos=False)
_, activs = model.run_with_cache(toks_NI)


### Attention Pattern Analysis

In [None]:
# String form of each token of the first (only) prompt.
c_str_toks = model.to_str_tokens(toks_NI[0])

# Attention pattern of layer index 1 (0-based) from the cache. squeeze() removes
# size-1 dims — presumably the batch dim when a single prompt was run, leaving
# (head, query, key); TODO confirm.
layer1_pattern = activs["pattern", 1, "attn"]
attention_pattern = torch.squeeze(layer1_pattern)

In [None]:
#cv.attention.attention_patterns(tokens=c_str_toks, attention=attention_pattern)

In [None]:
# Position of the token whose prediction is being explained.
idx = 50

# With several prompts the pattern stays 4-D (presumably batch, head, query, key);
# average over the first axis.
if len(attention_pattern.shape) == 4:
    attention_pattern = attention_pattern.mean(0)

# In a causal LM, the output at query position q predicts token q + 1. The
# prediction of token `idx` is therefore made at query `idx - 1`, which attends
# to keys 0..idx-1 (`:idx`). Taking row `idx` would show a different query and
# cut off its self-attention column.
seq_len = attention_pattern.shape[-1]
if not 1 <= idx < seq_len:
    # idx == 0 would silently wrap to the last row; idx >= seq_len has no target token.
    raise ValueError(f"idx must be in [1, {seq_len - 1}], got {idx}")
pattern = attention_pattern[:, idx - 1, :idx]

In [None]:

# Heatmap of every head's attention over the context tokens preceding position `idx`.
fontsize = 12
str_toks = model.to_str_tokens(toks_NI)

fig, ax = plt.subplots(1, 1, figsize=(14, 3))

# Draw on the explicit axes instead of relying on matplotlib's implicit "current axes".
# detach()/cpu() make the numpy conversion safe even if the tensor is on a GPU.
sns.heatmap(
    pattern.detach().cpu().numpy(),
    ax=ax,
    cmap=mpl.colormaps["Reds"],
    xticklabels=str_toks[:idx],
    yticklabels=np.arange(0, pattern.shape[0]),
    square=True,
    cbar_kws={"location": "right", "pad": 0.01},
)
# No sns.set() here. Applying a seaborn theme after drawing does not restyle this
# figure, but it overwrites global rcParams (including the spine settings
# configured at import time), so re-running this cell would render differently.

ax.set_ylabel("layer 2, head X", fontsize=fontsize)
ax.set_title("KQ attention pattern analysis", fontsize=fontsize, loc="left")
ax.set_yticklabels(ax.get_yticklabels(), rotation=0, horizontalalignment="right")
ax.text(
    1.01,
    -0.2,
    f'predicted\ntoken:"{str_toks[idx]}"',
    color="black",
    fontsize=fontsize - 1,
    horizontalalignment="left",
    verticalalignment="bottom",
    transform=ax.transAxes,
)

fig.savefig(f"{dataLoaders.ROOT}/results/attention_pattern_l2_example2.pdf", dpi=200, bbox_inches="tight")

