# Setup

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from mamba2mini import Mamba2LMHeadModel
from transformers import AutoTokenizer

# This notebook only runs inference, so gradient tracking is switched off globally.
torch.set_grad_enabled(False)

In [None]:
# Device string handed to the model loader and used for the prompt tensors.
device = "cuda"
# Checkpoint identifier passed to Mamba2LMHeadModel.from_pretrained.
model_name = "state-spaces/mamba2-1.3b"
# Seed applied to torch's random number generator before evaluation.
seed = 0
# Upper bound on the layer indices covered by the sliding windows in the
# Experiments section.
# NOTE(review): presumably the number of layers of the checkpoint above — confirm.
n_layers = 48

In [None]:
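# Uncomment the lines below and fill in the placeholders to set the caching directories.
# `hf_dir` is also passed as `cache_dir` when the tokenizer is loaded further down.
# NOTE(review): HF_HOME is typically read when the Hugging Face libraries are imported,
# so these environment variables may need to be set before the import cell runs — confirm.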

# hf_dir = XXX
# tri_dir = YYY
# xdg_dir = ZZZ
# os.environ['HF_HOME'] = hf_dir
# os.environ['TRITON_CACHE_DIR'] = tri_dir
# os.environ['XDG_CACHE_HOME'] = xdg_dir

# Prep Data

In [None]:
# Per-prompt results of the two runs (original and attention variants).
original_res = pd.read_parquet('entire_results_original.parquet')
attn_res = pd.read_parquet('entire_results_attention.parquet')

# Keep the rows whose 'hit' flag is True in the attention run and equal in the
# original run, then renumber the surviving rows from 0.
mask = (attn_res['hit'] == True) & (original_res['hit'] == attn_res['hit'])
data = attn_res.loc[mask].reset_index(drop=True)

# Analysis Functionality

In [None]:
# Reuse the custom Hugging Face cache directory when the caching cell above has
# been uncommented and run; otherwise pass None so the library's default cache
# location is used instead of failing with a NameError on the undefined `hf_dir`.
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-neox-20b",
    cache_dir=globals().get("hf_dir"),
    use_fast=True,
)
# The tokenizer calls below pass padding=True, so a pad token must be set;
# the end-of-sequence token is reused for that purpose.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Load the pretrained checkpoint named by `model_name`, passing `device` to the loader.
model = Mamba2LMHeadModel.from_pretrained(model_name, device=device)

In [None]:
# Seed torch's RNG and put the model in evaluation mode before any generation.
torch.random.manual_seed(seed)
model.eval()
# Sampling settings forwarded unchanged to model.generate_single by forward_eval.
# NOTE(review): these values presumably leave the predicted distribution
# unmodified (no temperature scaling, no top-k / top-p filtering) — confirm
# against generate_single.
temperature = 1
top_k = 0
top_p = 1

In [None]:
# Taken from https://github.com/google-research/google-research/blob/master/dissecting_factual_predictions/utils.py 
def decode_tokens(tokenizer, token_array):
    """Decode every token id on its own and return the resulting strings.

    An input that exposes a ``shape`` with more than one dimension is handled
    row by row, giving a nested list with one inner list per row. Any other
    iterable of token ids yields a flat list of strings.
    """
    is_multi_dim = hasattr(token_array, "shape") and len(token_array.shape) > 1
    if is_multi_dim:
        return [decode_tokens(tokenizer, row) for row in token_array]

    pieces = []
    for token_id in token_array:
        # Decode one id at a time so each string maps to exactly one token.
        pieces.append(tokenizer.decode([token_id]))
    return pieces

def find_token_range(tokenizer, token_array, substring):
    """Locate the span of tokens that covers ``substring``.

    The tokens are decoded one by one and concatenated; ``substring`` is then
    searched for in that text (``str.index`` raises ``ValueError`` when it is
    absent). Returns ``(tok_start, tok_end)`` where ``tok_start`` is the index
    of the first token reaching past the substring's first character and
    ``tok_end`` is one past the token in which the substring ends. Either
    entry stays ``None`` if the corresponding condition is never met.
    """
    pieces = decode_tokens(tokenizer, token_array)
    text = "".join(pieces)
    begin_char = text.index(substring)
    end_char = begin_char + len(substring)

    first_tok = None
    last_tok = None
    consumed = 0  # number of characters covered by the tokens seen so far
    for position, piece in enumerate(pieces):
        consumed += len(piece)
        if first_tok is None and consumed > begin_char:
            first_tok = position
        if consumed >= end_char:
            # The substring ends inside this token; nothing further to scan.
            last_tok = position + 1
            break
    return (first_tok, last_tok)

In [None]:
def forward_eval(temperature, top_k, top_p, prompt_idx, window):
    """Score the target token while intervening on one prompt position at a time.

    For the prompt stored at row ``prompt_idx`` of the global ``data`` frame,
    the model is run once per prompt position ``idx``. Each run receives a
    ``num_to_masks`` dict that maps every layer in ``window`` to the single
    pair ``(last_idx, idx)``, and the probability the run assigns to the first
    token of ``target_true`` is recorded.

    NOTE(review): the pair is presumably (receiving position, source position)
    of the flow being blocked into the last prompt token — confirm against
    ``Mamba2LMHeadModel.generate_single``.

    Args:
        temperature: Forwarded unchanged to ``model.generate_single``.
        top_k: Forwarded unchanged to ``model.generate_single``.
        top_p: Forwarded unchanged to ``model.generate_single``.
        prompt_idx: Row label in ``data`` selecting the prompt and its target.
        window: Layer indices on which the intervention is applied.

    Returns:
        A 1-D NumPy array with one probability per prompt token.
    """
    prompt = data.loc[prompt_idx, 'prompt']
    true_word = data.loc[prompt_idx, 'target_true']

    # Only the first token of the target word is scored. Extract it once as a
    # plain int so that indexing the probabilities yields a scalar; assigning
    # a shape-(1,) array into a scalar slot is deprecated since NumPy 1.25.
    true_ids = tokenizer(true_word, return_tensors="pt", padding=True).input_ids
    target_id = int(true_ids[0, 0])

    tokens = tokenizer(prompt, return_tensors="pt", padding=True)
    input_ids = tokens.input_ids.to(device=device)
    seq_len = input_ids.shape[1]
    max_new_length = seq_len + 1
    last_idx = seq_len - 1
    probs = np.zeros(seq_len)

    for idx in range(seq_len):
        num_to_masks = {layer: [(last_idx, idx)] for layer in window}

        # NOTE(review): `eos_token_id` receives the EOS token *string*
        # (tokenizer.eos_token), not tokenizer.eos_token_id. Left unchanged
        # because the effect on generate_single cannot be verified from this
        # notebook — confirm which value it expects.
        out = model.generate_single(
            input_ids=input_ids,
            max_new_length=max_new_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            eos_token_id=tokenizer.eos_token,
            attention=True,
            num_to_masks=num_to_masks,
        )

        # The last element of the output is read as the next-token
        # probabilities, indexed as (batch, vocabulary id).
        next_token_probs = out[-1].detach().cpu().numpy()
        probs[idx] = next_token_probs[0, target_id]
        # Release cached GPU memory between runs.
        torch.cuda.empty_cache()
    return probs

In [None]:
def evaluate(temperature, top_k, top_p, prompt_indices, windows):
    """Plot, save and show one intervention heatmap per prompt.

    For every prompt, ``forward_eval`` is called once per layer window. The
    results are stacked so that rows correspond to prompt tokens and columns
    to windows, drawn as a heatmap, written to
    ``heatmap_idx=<prompt_idx>_ws=<window size>.pdf`` in the working directory
    and displayed.

    Args:
        temperature: Forwarded unchanged to ``forward_eval``.
        top_k: Forwarded unchanged to ``forward_eval``.
        top_p: Forwarded unchanged to ``forward_eval``.
        prompt_indices: Row labels of the global ``data`` frame to analyse.
        windows: Non-empty sequence of layer-index windows; the window size
            reported in the title and the file name is ``len(windows[0])``.

    Returns:
        None. The function is used for its plots and saved files.
    """
    fontsize = 8

    for prompt_idx in prompt_indices:
        # One row of probabilities per window, transposed to tokens x windows.
        prob_mat = np.array(
            [forward_eval(temperature, top_k, top_p, prompt_idx, window) for window in windows]
        ).T

        prompt = data.loc[prompt_idx, 'prompt']
        true_word = data.loc[prompt_idx, 'target_true']
        base_prob = data.loc[prompt_idx, 'true_prob']

        tokens = tokenizer(prompt, return_tensors="pt", padding=True)
        input_ids = tokens.input_ids.to(device=device)
        toks = decode_tokens(tokenizer, input_ids[0])
        last_tok = toks[-1]
        # Mark the last prompt token on the y axis with an asterisk.
        toks[-1] = toks[-1] + '*'

        # Take the window size from the argument so the title and the output
        # file name always agree, instead of reading a notebook-level global
        # that is only defined in a later cell.
        window_size = len(windows[0])

        fig = plt.figure(figsize=(4, 3))
        ax = sns.heatmap(prob_mat, cmap="Purples_r", cbar=True)
        plt.title(
            f'Intervening on flow to:{last_tok}'
            f'\nwindow: {window_size}, base probability: {round(base_prob, 4)}',
            fontsize=fontsize,
        )
        plt.xlabel('')
        plt.ylabel('')

        # Label every fifth column; the +0.5 offset centres ticks on the cells.
        x_pos = np.arange(0, prob_mat.shape[1], 5)
        plt.xticks(ticks=x_pos + 0.5, labels=[str(x) for x in x_pos],
                   rotation=0, fontsize=fontsize)
        plt.yticks(ticks=np.arange(prob_mat.shape[0]) + 0.5, labels=toks,
                   rotation=0, fontsize=fontsize)
        ax.tick_params(axis='both', which='both', length=0)

        cbar = ax.collections[0].colorbar
        # The first character of the target is dropped from the label
        # (presumably a leading space — confirm against the data).
        cbar.ax.set_xlabel(f'p({true_word[1:]})', labelpad=10, fontsize=fontsize)
        cbar.locator = plt.MaxNLocator(nbins=5)
        cbar.update_ticks()
        cbar.ax.tick_params(labelsize=fontsize)

        plt.tight_layout()
        plt.savefig(f'heatmap_idx={prompt_idx}_ws={window_size}.pdf', format="pdf")
        plt.show()
        # Close the figure so figures do not accumulate across prompts.
        plt.close(fig)

# Experiments

In [None]:
# List the hand-picked prompts, one per line, after a short header.
print(
    "The following prompts are used in [Geva et al. 23'] and are ones for which our core model is correct:",
    'Commerzbank, whose headquarters are in',
    'Edvard Grieg, playing the',
    'Statistical Package for the Social Sciences was created by',
    'The mother tongue of Pietro Mennea is',
    sep='\n',
)

In [None]:
# Rows of `data` to analyse.
# NOTE(review): presumably these correspond to the prompts printed in the previous cell — confirm.
prompt_indices = [2841, 661, 3124, 2274]
# Number of consecutive layers intervened on together.
window_size = 5
# Every contiguous run of `window_size` layer indices that fits in [0, n_layers).
windows = [
    list(range(first_layer, first_layer + window_size))
    for first_layer in range(n_layers - window_size + 1)
]

In [None]:
# Run the sweep: one heatmap per selected prompt, covering every layer window.
evaluate(temperature, top_k, top_p, prompt_indices, windows)