
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Rare Token Hypothesis

In [None]:
import transformer_lens, torch, tqdm, copy, collections, operator
import numpy as np
from pathlib import Path
from toolz import compose
import pandas as pd
import seaborn as sns
from scipy.stats import rankdata

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, MaxNLocator
import matplotlib.ticker as plticker

mpl.rcParams['axes.spines.right'] = False
mpl.rcParams['axes.spines.top'] = False


import sys
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, evaluation

## Load Model and Data

In [None]:
## Name of the model under study; passed to the project loader below.
model_type = "gpt-neo-125M"
## Load the model through the project helper, requesting the CPU as device.
model = modelHandlers.load_model(model_type=model_type, DEVICE="cpu")
## NOTE(review): presumably disables gradients for the listed components
## (embedding, positional embedding, unembedding) -- confirm in paraMem.utils.modelHandlers.
modelHandlers.set_no_grad(model, ["embed", "pos_embed", "unembed"])

In [None]:
## Load the two prediction files stored for the current model. The directory is derived
## from model_type so that switching the model above also switches the data that is read
## (the path used to be spelled out literally as "gpt-neo-125M/preds").
mem_nonmem_sets = dataLoaders.load_pile_splits(f"{model_type}/preds", file_names=["50_50_preds.pt", "0_10_preds.pt"], as_torch=True)
## splits come back in the order of file_names: first "50_50_preds.pt", then "0_10_preds.pt"
mem_set, nonmem_set = mem_nonmem_sets[0], mem_nonmem_sets[1]
## pool both splits row-wise; the token counts and attention statistics below use this combined corpus
full_corpus = torch.cat((mem_set, nonmem_set), dim=0)

In [None]:
## Quick inspection: the cell output shows the dimensions of nonmem_set.
nonmem_set.shape

In [None]:
def create_unigram_dict(model, token_set:torch.LongTensor):
    """Count token occurrences in ``token_set`` by decoded string and by token id.

    Args:
        model: object providing ``to_str_tokens``, used to decode one row of ids
            into its string tokens.
        token_set: token ids, indexed row by row along dim 0.

    Returns:
        A pair ``(string_counts, id_counts)`` of ``collections.Counter`` objects.
        The first is keyed by decoded token string, the second by token id; ids that
        decode to the same string share one entry in the first counter.
    """
    string_counts = collections.Counter()
    id_counts = collections.Counter()
    n_rows = token_set.shape[0]
    for row_idx in tqdm.tqdm(range(n_rows)):
        row = token_set[row_idx]
        row_ids = row.tolist()
        row_strings = model.to_str_tokens(row)
        ## walk strings and ids in lockstep so both counters see the same positions
        for piece, piece_id in zip(row_strings, row_ids):
            string_counts[piece] += 1
            id_counts[piece_id] += 1
    return string_counts, id_counts
            
## Token counts over the pooled corpus: one counter keyed by decoded string, one keyed by token id.
full_str_counter, full_tok_counter = create_unigram_dict(model, full_corpus)
## Per-split counters (currently disabled):
#mem_str_counter, mem_tok_counter = create_unigram_dict(model, mem_set)
#nonmem_str_counter, nonmem_tok_counter = create_unigram_dict(model, nonmem_set)

### Collect Attention by Token Frequency

In [None]:
def collect_attn_frequency_scores(model, toks_NI:torch.LongTensor, tok_counter:dict, n_limit:int=100, *,
                                  layer:int=1, n_heads:int=12, tok_idx:int=50, topK_rarest:int=50):
    """Sum, per head, the attention that position ``tok_idx`` pays to prefix tokens, bucketed by frequency rank.

    For each of the first ``n_limit`` rows of ``toks_NI`` the model is run once (batch of one).
    The tokens before ``tok_idx`` are ranked by their count in ``tok_counter`` using dense
    ranking (ties share a rank); rank 0 is the token with the lowest count in that row.
    The attention from query position ``tok_idx`` to each prefix token is then added to the
    bucket ``[head, rank]``.

    Args:
        model: object providing ``run_with_cache`` and ``cfg.device``; the returned cache is
            indexed with ``("pattern", layer, "attn")``.
            NOTE(review): the pattern is assumed to be laid out as [batch, head, query, key] -- confirm.
        toks_NI: token ids, one paragraph per row. Every row must contain a position ``tok_idx``,
            i.e. be longer than ``tok_idx`` tokens, otherwise indexing raises.
        tok_counter: mapping from token id to its corpus count.
        n_limit: number of leading rows of ``toks_NI`` to process.
        layer: layer whose attention pattern is read.
        n_heads: number of heads (rows of the result); heads beyond this are ignored.
        tok_idx: query position whose look-back attention is collected.
        topK_rarest: number of rank buckets (columns of the result); tokens whose rank is
            not below this value are skipped.

    Returns:
        A float tensor of shape ``(n_heads, topK_rarest)`` with the summed attention.
    """
    head_freq_attn = torch.zeros(n_heads, topK_rarest)

    ## nothing here is differentiated, so avoid building autograd graphs during the forward passes
    with torch.no_grad():
        for toks_I in tqdm.tqdm(toks_NI[:n_limit,:]):
            ## corpus frequencies of all prefix tokens of the paragraph
            toks_I_freq = [tok_counter[tok] for tok in toks_I[:tok_idx].tolist()]
            ## dense frequency ranks (ties share a rank), shifted to start at 0; forced to integers so they can index
            freq_ranks = torch.as_tensor(rankdata(toks_I_freq, method='dense'), dtype=torch.long) - 1

            _, activs = model.run_with_cache(toks_I.unsqueeze(0).to(model.cfg.device))
            activs = activs.to("cpu")

            ## drop only the batch axis (size 1 by construction) so a single-head pattern keeps its head axis
            attn_pattern = activs["pattern", layer, "attn"].squeeze(0)
            ## "lookback" attention of position tok_idx onto the prefix, one row per head
            lookback_HI = attn_pattern[:n_heads, tok_idx, :tok_idx]

            ## keep only tokens whose rank falls inside the collected buckets
            keep = freq_ranks < topK_rarest
            ## add every kept prefix position's attention to its rank column, for all heads at once;
            ## several positions may share a rank, index_add_ accumulates them
            head_freq_attn.index_add_(1, freq_ranks[keep], lookback_HI[:, keep].to(head_freq_attn.dtype))
    return head_freq_attn
            

## Put the paragraphs in a random order, so the first n_limit rows analysed below are a random subset.
full_corpus = full_corpus[torch.randperm(full_corpus.shape[0])]
## Accumulate the attention-by-frequency-rank table over 1000 paragraphs.
head_freq_attn = collect_attn_frequency_scores(model, full_corpus, full_tok_counter, n_limit=1000)

In [None]:
## Apply the seaborn theme BEFORE the figure is created: the settings are written to the
## matplotlib rcParams and only affect artists created afterwards. Calling it after the
## heatmap made the first run of this cell look different from every re-run.
sns.set(font_scale=1.12)

fig, ax = plt.subplots(1, 1, figsize=(15, 3.5))
fontsize=15

## optional per-head normalisation (disabled): divide every row by its sum
#norm_by = head_freq_attn.sum(1).unsqueeze(1).repeat(1, head_freq_attn.shape[1])
#vals = (head_freq_attn / norm_by).numpy()
vals = head_freq_attn.numpy()

## one row per head, one column per frequency rank
s = sns.heatmap(vals, #annot=rankdata(vals, axis=-1)-1
              cmap=mpl.colormaps["Greys"], center=None,
              xticklabels=np.arange(0, vals.shape[1]),
              yticklabels=np.arange(0, vals.shape[0]), square=False, cbar=True,
              cbar_kws={'location': 'right', 'pad': 0.01,'label':'attention on tokens\nover 1000 paragraphs'}, ax=ax) 

#ax.set_xlabel("10 least frequent tokens per paragraph",fontsize=fontsize)
ax.set_xlabel("tokens per paragraph ordered by corpus frequency",fontsize=fontsize)

ax.set_ylabel("layer 1, head X",fontsize=fontsize)
ax.set_title("KQ attention on 'distinctive' tokens", fontsize=fontsize, loc="left")
## keep the head indices upright instead of rotated
s.set_yticklabels(s.get_yticklabels(), rotation=0, horizontalalignment='right')


## Disabled: second y-axis annotating each head with the Kendall correlation between rank and attention.
## NOTE(review): re-enabling this needs `import scipy.stats`; the notebook only imports `rankdata` from it.
# Create a second y-axis on the right side
#ax2 = ax.twinx()
#ax2.grid(False)
#ax2.set_yticks(np.linspace(0.5, (vals.shape[0]-0.5), num=12))

#kendall_corr = []
#for head_attn in vals:
#    coef, p = scipy.stats.kendalltau(np.arange(vals.shape[1]), head_attn)
#    kendall_corr.append(round(coef,2))

#ax2.set_yticklabels([corr for corr in kendall_corr], verticalalignment="center", fontsize = fontsize-1)
#ax2.set_ylabel("kendall corr. between attention\nat token and token frequency", fontsize = fontsize-2)
#ax2.invert_yaxis()
#for y_tick_pos in ax.get_yticks():
#    ax.text(1.01, y_tick_pos, f'1', color="black", fontsize=fontsize-1, horizontalalignment='left',verticalalignment='bottom', transform=ax.transAxes)

## make sure the output directory exists so saving does not fail on a fresh checkout
Path(f"{dataLoaders.ROOT}/results").mkdir(parents=True, exist_ok=True)
fig.savefig(f"{dataLoaders.ROOT}/results/least_frequent_50.pdf", dpi=200, bbox_inches="tight")

