
Copyright 2024 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.


# Unigram Differences

In [None]:
import transformer_lens, torch, tqdm, copy, collections, operator, scipy
import numpy as np
from pathlib import Path
from toolz import compose
import pandas as pd
from scipy.stats import rankdata

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter, MaxNLocator
import matplotlib.ticker as plticker

# Hide the right and top axes spines on every figure drawn below.
mpl.rcParams.update({'axes.spines.right': False, 'axes.spines.top': False})


import sys
sys.path.append('/home/jupyter/')
from paraMem.utils import modelHandlers, dataLoaders, evaluation

## Load Model and Data

In [None]:
# Load the model on CPU, then pass its embedding, positional-embedding and
# unembedding parameter groups to set_no_grad (presumably freezing them — see
# paraMem.utils.modelHandlers).
model_type = "gpt-neo-125M"
model = modelHandlers.load_model(DEVICE="cpu", model_type=model_type)
modelHandlers.set_no_grad(model, ["embed", "pos_embed", "unembed"])

In [None]:
# Load the two prediction files for this model. Per the unpacking below, the first
# file ("50_50_preds.pt") is used as the memorized split and the second
# ("0_10_preds.pt") as the non-memorized split. The directory is derived from
# model_type instead of repeating the model name in a placeholder-free f-string,
# so the data path follows the model choice.
mem_nonmem_sets = dataLoaders.load_pile_splits(f"{model_type}/preds", file_names=["50_50_preds.pt", "0_10_preds.pt"], as_torch=True)
mem_set, nonmem_set = mem_nonmem_sets[0], mem_nonmem_sets[1]
# Stack both splits along dim 0 so corpus-wide statistics and attention below
# cover paragraphs from the memorized and the non-memorized set together.
full_corpus = torch.cat((mem_set, nonmem_set), dim=0)

In [None]:
def create_unigram_dict(model, token_set:torch.LongTensor):
    """Count how often each token occurs across all rows of ``token_set``.

    Every row is viewed both as Python token ids (``row.tolist()``) and as
    strings (``model.to_str_tokens(row)``), and the two views are tallied
    position by position.

    Args:
        model: object providing ``to_str_tokens`` (e.g. the model loaded above).
        token_set: tensor of token ids iterated along dim 0, one sequence per
            row (assumed 2-D: sequences x positions).

    Returns:
        Pair of ``collections.Counter``: string token -> count, and
        token id -> count.
    """
    str_counts = collections.Counter()
    id_counts = collections.Counter()
    for row in tqdm.tqdm(token_set):
        # zip keeps the string and id views aligned position by position.
        for str_tok, tok_id in zip(model.to_str_tokens(row), row.tolist()):
            str_counts[str_tok] += 1
            id_counts[tok_id] += 1
    return str_counts, id_counts
            
# Unigram counts (string-keyed and id-keyed) for the combined corpus and for
# each split on its own.
str_counter, tok_counter = create_unigram_dict(model, full_corpus)
mem_str_counter, mem_tok_counter = create_unigram_dict(model, mem_set)
nonmem_str_counter, nonmem_tok_counter = create_unigram_dict(model, nonmem_set)

# Distinct token ids that occur anywhere in each split.
mem_toks = set(mem_set.unique().tolist())
nonmem_toks = set(nonmem_set.unique().tolist())

In [None]:
def get_filter_tokens(unigram_dict:dict, topP:float=0.1, bottomP:float=0.0):
    """Select the most- and least-frequent keys of a unigram count dict for removal.

    Keys are ranked by count, highest first. Ties keep ``unigram_dict``'s
    iteration order because the sort is stable. With ``n`` keys, the first
    ``int(n * topP)`` and the last ``int(n * bottomP)`` keys of that ranking
    are returned.

    Args:
        unigram_dict: mapping token -> count (e.g. a ``collections.Counter``).
        topP: fraction of the most frequent keys to remove.
        bottomP: fraction of the least frequent keys to remove.

    Returns:
        List of keys to remove, in descending-count order.
    """
    # Rank by count explicitly. Iterating a dict/Counter yields insertion
    # (first-seen) order, so "top"/"bottom" would not mean most/least frequent.
    ranked_keys = sorted(unigram_dict, key=unigram_dict.__getitem__, reverse=True)
    n_toks = len(ranked_keys)
    top_n, bottom_n = int(n_toks * topP), n_toks - int(n_toks * bottomP)
    # Use a strict '<' so exactly top_n keys are dropped. '<=' dropped
    # top_n + 1 keys, and removed one key even when topP == 0.
    remove_keys = [k for i, k in enumerate(ranked_keys) if i < top_n or i >= bottom_n]
    print(f"remove_keys {len(remove_keys)} from {n_toks}")
    return remove_keys
    
topP = 0.5  # fraction of the most frequent token types dropped from each counter below

# get_filter_tokens returns a list. Convert it to a set once, so each membership
# test in the comprehensions is O(1) instead of a linear scan of removed keys.
remove_keys = set(get_filter_tokens(tok_counter, topP=topP, bottomP=0.0))
tok_counter_filtered = {k: v for k, v in tok_counter.items() if k not in remove_keys}

remove_keys = set(get_filter_tokens(mem_tok_counter, topP=topP, bottomP=0.0))
mem_tok_counter_filtered = {k: v for k, v in mem_tok_counter.items() if k not in remove_keys}

remove_keys = set(get_filter_tokens(nonmem_tok_counter, topP=topP, bottomP=0.0))
nonmem_tok_counter_filtered = {k: v for k, v in nonmem_tok_counter.items() if k not in remove_keys}

In [None]:
#mem_str_counter.subtract(nonmem_str_counter)
#mem_tok_counter.subtract(nonmem_tok_counter)

#nonmem_str_counter.subtract(mem_str_counter)
#nonmem_tok_counter.subtract(mem_tok_counter)

### Collect Attention Scores for All Heads in Layer 1

In [None]:
def collect_attention_scores(model, toks_NI:torch.LongTensor, layer:int=1, tok_idx:int=50, n_limit:int=100):
    """Accumulate, per head of one layer, the attention each token id receives from one query position.

    For each of the first ``n_limit`` rows of ``toks_NI``, the model is run
    once. The attention pattern of ``layer`` is then read at query position
    ``tok_idx``. The weights that query places on the strictly earlier
    positions ``0 .. tok_idx-1`` are summed per token id (repeated tokens
    accumulate), separately for every head.

    Args:
        model: model exposing ``run_with_cache`` and ``cfg.n_heads`` /
            ``cfg.device`` (assumed to be a TransformerLens HookedTransformer).
        toks_NI: tensor of token ids, one sequence per row; each row must be
            longer than ``tok_idx``.
        layer: index of the attention layer to inspect.
        tok_idx: query position whose look-back attention is collected.
        n_limit: maximum number of rows to process.

    Returns:
        attn_head_dicts: one ``defaultdict(float)`` per head, mapping
            token id -> summed attention weight received.
        attn_head_entropy_H: array with one value per head. Each value is the
            entropy of that head's look-back weights (``scipy.stats.entropy``
            renormalises them to sum to 1), averaged over the processed rows.

    Raises:
        ValueError: if no rows are processed (empty ``toks_NI`` or ``n_limit <= 0``).
    """
    # Read the head count from the model config. The hard-coded 12 only fits
    # models with 12 heads per layer.
    n_heads = model.cfg.n_heads
    attn_head_dicts = [collections.defaultdict(float) for _ in range(n_heads)]
    attn_head_entropy_N = []

    # Only activations are read, so skip building the autograd graph.
    with torch.no_grad():
        for toks_I in tqdm.tqdm(toks_NI[:n_limit]):
            _, activs = model.run_with_cache(toks_I.unsqueeze(0).to(model.cfg.device))

            # Drop the batch dim explicitly with [0]. squeeze() would also
            # collapse any other size-1 dim. Only this pattern is moved to CPU,
            # not the whole cache.
            attn_pattern = activs["pattern", layer, "attn"][0].to("cpu")
            lookback_HI = attn_pattern[:, tok_idx, :tok_idx]

            ## collect entropy of attention-token distribution
            attn_head_entropy_N.append(scipy.stats.entropy(lookback_HI.numpy(), axis=-1))

            # Convert to Python lists once, rather than calling .item() for
            # every (head, position) pair.
            prefix_tok_ids = toks_I[:tok_idx].tolist()
            for head_dict, head_weights in zip(attn_head_dicts, lookback_HI.tolist()):
                for tok_id, weight in zip(prefix_tok_ids, head_weights):
                    head_dict[tok_id] += weight

    if not attn_head_entropy_N:
        raise ValueError("collect_attention_scores: no sequences were processed (empty toks_NI or n_limit <= 0)")

    attn_head_entropy_H = np.stack(attn_head_entropy_N).mean(0)
    return attn_head_dicts, attn_head_entropy_H
            

# Shuffle the rows with a fixed seed. Only the first n_limit rows are used, so
# an unseeded permutation made the sampled subset, and every number derived
# from it, change on each run.
full_corpus = full_corpus[torch.randperm(full_corpus.size(0), generator=torch.Generator().manual_seed(0))]
attn_head_dicts, attn_head_entropy_H = collect_attention_scores(model, full_corpus, n_limit=5000)

### Post-processing attention scores

In [None]:
#attn_head_dicts_meaned = []
#for attn_head_dict in attn_head_dicts:
#    attention_sum = np.array(list(attn_head_dict.values())).sum()
#    attn_head_dict_meaned = dict(zip(list(attn_head_dict.keys()), list(attn_head_dict.values()) / attention_sum))
#    attn_head_dicts_meaned.append(attn_head_dict_meaned)

## Compute Correlations

In [None]:
def get_correlations(freq_dict:dict, attn_dict:dict, filter_tok_list:list=None, title:str=""):
    """Correlate token frequency with accumulated attention using Kendall's tau.

    Tokens are visited in ``attn_dict`` iteration order. A token is kept only
    if it appears in ``freq_dict`` and in ``filter_tok_list``; when that list
    is None, every key of ``freq_dict`` is allowed.

    Args:
        freq_dict: token id -> frequency (any monotone score works, since
            Kendall's tau only uses the ordering).
        attn_dict: token id -> attention score.
        filter_tok_list: optional collection of token ids to restrict the
            comparison to.
        title: prefix of the printed summary line.

    Returns:
        coef: Kendall's tau between the paired values.
        entropy: Shannon entropy (natural log) of the kept attention scores,
            which ``scipy.stats.entropy`` normalises to sum to 1.
        tok_freqs, attn_scores: the two aligned numpy arrays.
    """
    # Test membership against a set so each check is O(1). A list made this
    # quadratic, especially when filter_tok_list is None and all freq_dict
    # keys were copied into a list.
    allowed = set(freq_dict if filter_tok_list is None else filter_tok_list)

    tok_freqs, attn_scores = [], []
    for attn_tok_id, attn_score in attn_dict.items():
        if attn_tok_id in freq_dict and attn_tok_id in allowed:
            tok_freqs.append(freq_dict[attn_tok_id])
            attn_scores.append(attn_score)

    tok_freqs, attn_scores = np.array(tok_freqs), np.array(attn_scores)
    coef, p = scipy.stats.kendalltau(tok_freqs, attn_scores)
    entropy = scipy.stats.entropy(attn_scores)
    print(f"{title}---n tokens: {len(tok_freqs)}, correlation: {round(coef,3)}, p: {round(p,3)}, entropy: {round(entropy,2)}")
    return coef, entropy, tok_freqs, attn_scores

## Collect Correlations per Head

In [None]:
# Per-head Kendall correlation between attention received and token frequency,
# plus the entropy of the attention scores, for three token populations.
full_corrs, mem_corrs, nonmem_corrs = [], [], []
full_entropys, mem_entropys, nonmem_entropys = [], [], []
all_head_tok_freqs, all_head_attn_scores = [], []

for head_idx, head_attn_dict in enumerate(attn_head_dicts):
    print(f"\n{head_idx}____________________")
    # Use the raw full-corpus counts. `tok_counter_ranks` is never assigned in
    # this notebook, so referencing it raised NameError. Kendall's tau depends
    # only on the ordering (including ties) of the values, which a dense rank
    # transform preserves, so raw counts give the same coefficient and p-value.
    # NOTE(review): the memorized / non-memorized rows use the top-filtered
    # counters; switch to tok_counter_filtered if the full-corpus row should be
    # filtered the same way.
    full_corr, full_entropy, tok_freqs, attn_scores = get_correlations(tok_counter, head_attn_dict, filter_tok_list=None, title="full corpus toks")
    mem_corr, mem_entropy, _, _ = get_correlations(mem_tok_counter_filtered, head_attn_dict, filter_tok_list=None, title="memorized toks")
    nonmem_corr, nonmem_entropy, _, _ = get_correlations(nonmem_tok_counter_filtered, head_attn_dict, filter_tok_list=None, title="non-memorized toks")

    full_corrs.append(full_corr)
    mem_corrs.append(mem_corr)
    nonmem_corrs.append(nonmem_corr)

    full_entropys.append(full_entropy)
    mem_entropys.append(mem_entropy)
    nonmem_entropys.append(nonmem_entropy)

    all_head_tok_freqs.append(tok_freqs)
    all_head_attn_scores.append(attn_scores)

In [None]:
#fig, axs = plt.subplots(int(len(all_head_tok_freqs)/4), 4, figsize=(10, 5), constrained_layout=True) #gridspec_kw={'hspace': 0.2}
#fontsize = 12

#for i, ax in enumerate(axs.flatten()):  
#    ax.set_title(f"L1 H{i}, Attn Entropy {round(full_entropys[i],2)}", fontsize=fontsize, loc='left')
#    x,y = np.array(all_head_tok_freqs[i]), np.array(all_head_attn_scores[i]) 
#    em = ax.scatter(x, y, color="grey", alpha=0.2, label="exact match")
#    ax.set_xscale('log')
#    ax.set_yscale('log')
    #ax.set_xlim(0, 10000)
    #ax.set_ylim(None, 1)
    
    #b, a = np.polyfit(x, y, deg=1)
    #x_seq = np.arange(0,x.max())
    #ax.plot(x_seq, a + b * x_seq, color="black", lw=1.5, linestyle=":");
#    ax.tick_params(axis='both', which='major', labelsize=fontsize)
    
#fig.supxlabel('Token Frequency', fontsize=fontsize)
#fig.supylabel('Average Attention at Token', fontsize=fontsize)#, x=0.085)
#fig.savefig(f"{dataLoaders.ROOT}/results/attn_entropy.pdf", dpi=200, bbox_inches="tight")

## Plotting

In [None]:
# Grouped bars per head: Kendall tau for the full corpus (grey), memorized
# (red) and non-memorized (blue) tokens. Coloured text placed below the axes
# serves as the legend.
fig, ax = plt.subplots(1, 1, figsize=(14, 2.5), gridspec_kw={'hspace': 0.4})
fontsize=12

x = np.arange(len(mem_corrs))  # one group of bars per head
width, offset = 0.25, 0.25  # bar width and horizontal shift of the two side bars

label_full = ax.bar(x - offset, full_corrs, width, color="grey", label="full corpus")
label_mem = ax.bar(x, mem_corrs, width, color="red", label="memorized paragraphs")
label_nonmem = ax.bar(x + offset, nonmem_corrs, width, color="blue", label="non-memorized paragraphs")

ax.set_ylabel("correlation coefficient",fontsize=fontsize)
ax.set_xlabel("Layer 1, Head X",x=0.85,fontsize=fontsize)

ax.axhline(y=0.0, c="black",linewidth=1.5,zorder=10)
ax.spines['bottom'].set_visible(False)
ax.tick_params(axis='both', which='major', labelsize=fontsize)

ax.text(0.01, -0.17, 'all tokens in corpus', color="grey", fontsize=fontsize, horizontalalignment='left',verticalalignment='top', transform=ax.transAxes)
ax.text(0.18, -0.17, 'tokens in memorized paragraphs', color="red", fontsize=fontsize, horizontalalignment='left',verticalalignment='top', transform=ax.transAxes)
ax.text(0.45, -0.17, 'tokens in non-memorized paragraphs', color="blue", fontsize=fontsize, horizontalalignment='left',verticalalignment='top', transform=ax.transAxes)

ax.set_title("Kendall correlation between attention at tokens and the tokens' frequencies", fontsize=fontsize, loc="left")
loc = plticker.MultipleLocator(base=1.0) # one x tick per head
ax.xaxis.set_major_locator(loc)

# Save under a distinct name. The single-series figure in the following cell
# also writes attn_freq_kendall.pdf, which silently overwrote this plot.
fig.savefig(f"{dataLoaders.ROOT}/results/attn_freq_kendall_by_split.pdf", dpi=200, bbox_inches="tight")


In [None]:
# Single-series bar chart: per-head Kendall tau for the full corpus only.
fig, ax = plt.subplots(1, 1, figsize=(4, 2.5), gridspec_kw={'hspace': 0.4})
fontsize=10

x = np.arange(len(full_corrs))  # one bar per head
width = 0.8  # bar width (the previously assigned `offset` was unused here)

# These bars show full_corrs, so label them as the full corpus
# (the label previously said "memorized paragraphs").
label_full = ax.bar(x, full_corrs, width, color="grey", label="full corpus")

ax.set_ylabel("Kendall corr. coeff.",fontsize=fontsize)
ax.set_xlabel("Layer 1, Head X",fontsize=fontsize)
# Start the y-axis just below the smallest coefficient.
ax.set_ylim(np.array(full_corrs).min()-0.02, None)

ax.axhline(y=0.0, c="black",linewidth=1.5,zorder=10)
ax.spines['bottom'].set_visible(False)
ax.tick_params(axis='both', which='major', labelsize=fontsize)

ax.set_title("Correlation between attention\nat tokens and the tokens' frequencies", fontsize=fontsize, loc="left")
loc = plticker.MultipleLocator(base=1.0) # one x tick per head
ax.xaxis.set_major_locator(loc)

fig.savefig(f"{dataLoaders.ROOT}/results/attn_freq_kendall.pdf", dpi=200, bbox_inches="tight")
