# Setup

In [None]:
import os
import numpy as np
import pandas as pd
import torch
from mamba2mini import Mamba2LMHeadModel
from transformers import AutoTokenizer

# Inference-only notebook: switch autograd off globally for the session.
torch.autograd.set_grad_enabled(False)

In [None]:
# Run configuration shared by every cell below.
device: str = "cuda"  # device the model and prompt token ids are moved to
model_name: str = "state-spaces/mamba2-1.3b"  # checkpoint for Mamba2LMHeadModel.from_pretrained
seed: int = 0  # used for data sampling and for torch's RNG
n_prompts: int = 1000  # number of prompts sampled for the experiments
n_layers: int = 48  # used to build layer windows; must match the model's layer count

In [None]:
# Caching directories.
# `hf_dir` is passed as `cache_dir` when the tokenizer is loaded; keep it as
# None to use the default Hugging Face cache, or set it to a path to override.
# NOTE: huggingface_hub reads HF_HOME when it is first imported, so the
# environment variables below only fully take effect if this cell runs
# before the imports in the Setup cell (restart the kernel otherwise).
hf_dir = None
# tri_dir = YYY
# xdg_dir = ZZZ
# os.environ['HF_HOME'] = hf_dir  # only once hf_dir is set to a real path
# os.environ['TRITON_CACHE_DIR'] = tri_dir
# os.environ['XDG_CACHE_HOME'] = xdg_dir

# Prep Data

In [None]:
original_res = pd.read_parquet('entire_results_original.parquet')
attn_res = pd.read_parquet('entire_results_attention.parquet')

# Keep only prompts that both runs got right: the attention run hits and the
# original run agrees with it. `== True` is an element-wise pandas comparison
# here (not an identity test), so it is intentionally kept.
mask = (attn_res['hit'] == True) & (original_res['hit'] == attn_res['hit'])  # noqa: E712

# Reproducible sample of n_prompts rows, re-indexed 0..n_prompts-1 so that
# later cells can address prompts by position.
data = (
    attn_res[mask]
    .sample(n_prompts, random_state=seed)
    .reset_index(drop=True)
)

# Analysis Functionality

In [None]:
# Load the GPT-NeoX-20B tokenizer. `hf_dir` only exists when the cache-dir
# cell defines it; fall back to cache_dir=None (the library default cache)
# instead of raising NameError on a fresh kernel.
tokenizer = AutoTokenizer.from_pretrained(
    "EleutherAI/gpt-neox-20b",
    cache_dir=globals().get("hf_dir"),
    use_fast=True,
)
# Reuse EOS as the pad token: forward_eval tokenizes with padding=True,
# which requires a pad token to be set.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Load the `model_name` checkpoint with mamba2mini's Mamba2LMHeadModel.
# `device=device` presumably places the weights on that device — confirm in mamba2mini.
model = Mamba2LMHeadModel.from_pretrained(model_name, device=device)

In [None]:
# Decoding parameters that later cells pass through evaluate/forward_eval to
# model.generate_single.
temperature = 1
top_k = 0  # presumably 0 disables top-k filtering — confirm in mamba2mini
top_p = 1  # presumably 1 disables nucleus filtering — confirm in mamba2mini

# Seed torch's global RNG for reproducibility and put the model in eval mode.
torch.manual_seed(seed)
model.eval()

In [None]:
# Taken from https://github.com/google-research/google-research/blob/master/dissecting_factual_predictions/utils.py 
def decode_tokens(tokenizer, token_array):
    """Decode every token id in *token_array* to its own string.

    Inputs exposing a ``shape`` with more than one dimension are processed
    row by row, giving a nested list; anything else is treated as a flat
    sequence of ids.
    """
    ndim = len(getattr(token_array, "shape", ()))
    if ndim <= 1:
        return [tokenizer.decode([token]) for token in token_array]
    return [decode_tokens(tokenizer, row) for row in token_array]

def find_token_range(tokenizer, token_array, substring):
    """Find the tokens corresponding to the given substring in token_array.

    Returns a half-open ``(start, end)`` token span: ``start`` is the first
    token whose text reaches past the substring's first character, and ``end``
    is one past the token that completes it. Only the first occurrence of
    *substring* in the decoded text is considered.

    Raises:
        ValueError: if *substring* does not occur in the decoded text.
    """
    pieces = decode_tokens(tokenizer, token_array)
    target_begin = "".join(pieces).index(substring)
    target_end = target_begin + len(substring)
    start = end = None
    consumed = 0  # number of characters covered by the tokens seen so far
    for pos, piece in enumerate(pieces):
        consumed += len(piece)
        if start is None and consumed > target_begin:
            start = pos
        if consumed >= target_end:
            end = pos + 1
            break
    return start, end

In [None]:
def forward_eval(temperature, top_k, top_p, prompt_idx, window, block=None):
    """Evaluate one prompt with per-layer masks and score its true target.

    The masks are handed to ``model.generate_single`` as ``num_to_masks``: a
    dict mapping each layer in ``window`` to a list of ``(last_idx, source_idx)``
    pairs. Presumably each pair cuts the attention edge from the last prompt
    position to ``source_idx`` in that layer (semantics live in mamba2mini —
    confirm there). Source positions are chosen by ``block``:

    * ``'subject'``: the subject's tokens;
    * ``'relation'``: every prompt token outside the subject (this includes the
      last token unless it is part of the subject);
    * anything else (default ``None``): the last token itself.

    Args:
        temperature, top_k, top_p: Sampling parameters for ``generate_single``.
        prompt_idx: Row label in the global ``data`` frame; its ``prompt``,
            ``subject``, ``target_true`` and ``true_prob`` columns are used.
        window: Iterable of layer indices that receive masks; an empty window
            produces no masks.
        block: Which source positions to mask (see above).

    Returns:
        ``(hit, diff, first_token)`` as plain Python scalars: ``hit`` is True
        when the first token of ``target_true`` has the maximal next-token
        probability; ``diff`` is the relative change of that probability
        versus ``true_prob`` in percent; ``first_token`` is True when the
        subject span includes position 0.
    """
    row = data.loc[prompt_idx]
    base_prob = row['true_prob']

    # Only the first sub-token of the target is scored.
    true_id = tokenizer(row['target_true'], return_tensors="pt", padding=True).input_ids[0, 0].item()
    input_ids = tokenizer(row['prompt'], return_tensors="pt", padding=True).input_ids.to(device=device)
    seq_len = input_ids.shape[1]
    last_idx = seq_len - 1

    tok_start, tok_end = find_token_range(tokenizer, input_ids[0], row['subject'])
    subject_tokens = list(range(tok_start, tok_end))
    first_token = 0 in subject_tokens

    if block == 'subject':
        blocked_idx = subject_tokens
    elif block == 'relation':
        subject_set = set(subject_tokens)
        blocked_idx = [i for i in range(seq_len) if i not in subject_set]
    else:
        blocked_idx = [last_idx]

    # layer -> [(last_idx, source_idx), ...]; a layer gets no entry when
    # there is nothing to block.
    num_to_masks = {}
    for layer in window:
        for idx in blocked_idx:
            num_to_masks.setdefault(layer, []).append((last_idx, idx))

    out = model.generate_single(
        input_ids=input_ids,
        max_new_length=seq_len + 1,  # prompt length + 1 — presumably one generated token
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        # Pass the EOS *id*; the original passed the token string.
        eos_token_id=tokenizer.eos_token_id,
        attention=True,
        num_to_masks=num_to_masks,
    )
    # NOTE(review): assumes out[-1] holds next-token probabilities shaped (1, vocab).
    next_token_probs = out[-1].detach().cpu().numpy()
    # Extract scalars so callers never receive 1-element arrays.
    true_prob = float(next_token_probs[0, true_id])
    max_prob = float(next_token_probs[0].max())
    torch.cuda.empty_cache()

    hit = true_prob == max_prob
    diff = float((true_prob - base_prob) * 100.0 / base_prob)
    return (hit, diff, first_token)

In [None]:
def evaluate(temperature, top_k, top_p, prompt_indices, windows, block=None, print_period=500):
    """Run ``forward_eval`` for every (window, prompt) pair and average the results.

    Prompts are split by the ``first_token`` flag returned by ``forward_eval``
    (whether the subject span includes token 0). Each metric is therefore
    reported three ways: overall, for the "with first" (wf) group and for the
    "without first" (wof) group.

    Args:
        temperature, top_k, top_p: Sampling parameters forwarded to ``forward_eval``.
        prompt_indices: Iterable of row labels in ``data`` to evaluate.
        windows: Sequence of layer-index lists; one pass over all prompts per window.
        block: Forwarded to ``forward_eval`` (``None``, ``'subject'`` or ``'relation'``).
        print_period: Print a progress line every ``print_period`` prompts.

    Returns:
        A 6-tuple of float arrays, each of length ``len(windows)``:
        ``(acc, diff, acc_w_first, diff_w_first, acc_wo_first, diff_wo_first)``.
        ``acc`` is the fraction of prompts for which forward_eval reports a hit.
        ``diff`` is the mean relative probability change in percent.
        The overall metrics are averaged over ``len(prompt_indices)`` prompts.
        A group that contains no prompts yields NaN.
    """
    prompt_indices = list(prompt_indices)  # iterated once per window
    n_total = len(prompt_indices)
    n_windows = len(windows)
    counts_w_first = np.zeros(n_windows)
    counts_wo_first = np.zeros(n_windows)
    diffs_w_first = np.zeros(n_windows)
    diffs_wo_first = np.zeros(n_windows)
    # The first-token flag depends only on the prompt, so the group size is
    # counted during the first window pass only.
    w_first = 0
    for i, window in enumerate(windows):
        print('---------------------------------------------------------------')
        print(f'Starting window {i}: {window}')
        for j, prompt_idx in enumerate(prompt_indices):
            hit, diff, first = forward_eval(temperature, top_k, top_p, prompt_idx, window, block)
            # Coerce to plain floats. This accepts Python scalars as well as
            # size-1 arrays; assigning a size-1 array into an array element is
            # deprecated in recent NumPy.
            hit = np.asarray(hit, dtype=float).item()
            diff = np.asarray(diff, dtype=float).item()
            if first:
                if i == 0:
                    w_first += 1
                counts_w_first[i] += hit
                diffs_w_first[i] += diff
            else:
                counts_wo_first[i] += hit
                diffs_wo_first[i] += diff
            if (j + 1) % print_period == 0:
                print(f'Finished prompt {j}')
    wo_first = n_total - w_first

    def _mean(totals, n):
        # Explicit NaN for an empty group instead of a 0/0 RuntimeWarning.
        return totals / n if n else np.full_like(totals, np.nan)

    return (
        _mean(counts_w_first + counts_wo_first, n_total),
        _mean(diffs_w_first + diffs_wo_first, n_total),
        _mean(counts_w_first, w_first),
        _mean(diffs_w_first, w_first),
        _mean(counts_wo_first, wo_first),
        _mean(diffs_wo_first, wo_first),
    )

# Experiments - no blocking

In [None]:
# Baseline: a single empty window, so forward_eval builds no masks.
prompt_indices = list(range(n_prompts))
windows = [[]]
no_block_acc, no_block_diff = evaluate(
    temperature, top_k, top_p,
    prompt_indices=prompt_indices,
    windows=windows,
)[:2]

In [None]:
# Baseline accuracy, then mean relative probability change (%).
for baseline_metric in (no_block_acc, no_block_diff):
    print(baseline_metric)

# Experiments - window size = 9

In [None]:
window_size = 9
prompt_indices = list(range(n_prompts))
# Every contiguous run of `window_size` layers: [0..8], [1..9], ..., up to layer n_layers - 1.
windows = [
    list(range(first_layer, first_layer + window_size))
    for first_layer in range(n_layers - window_size + 1)
]

## Block last

In [None]:
# Mask the last token (forward_eval's default when block is None).
(
    last_acc, last_diff,
    last_wf_acc, last_wf_diff,
    last_wof_acc, last_wof_diff,
) = evaluate(
    temperature, top_k, top_p,
    prompt_indices=prompt_indices,
    windows=windows,
)

In [None]:
# One parquet file per metric array; wf/wof = subject does / does not include token 0.
for metric, values in (
    ('acc', last_acc),
    ('diff', last_diff),
    ('wf_acc', last_wf_acc),
    ('wf_diff', last_wf_diff),
    ('wof_acc', last_wof_acc),
    ('wof_diff', last_wof_diff),
):
    df = pd.DataFrame(values)
    df.to_parquet(f'block_last_{metric}_ws={window_size}.parquet')

## Block subject

In [None]:
# Mask the subject tokens.
(
    sub_acc, sub_diff,
    sub_wf_acc, sub_wf_diff,
    sub_wof_acc, sub_wof_diff,
) = evaluate(
    temperature, top_k, top_p,
    prompt_indices=prompt_indices,
    windows=windows,
    block='subject',
)

In [None]:
# One parquet file per metric array; wf/wof = subject does / does not include token 0.
for metric, values in (
    ('acc', sub_acc),
    ('diff', sub_diff),
    ('wf_acc', sub_wf_acc),
    ('wf_diff', sub_wf_diff),
    ('wof_acc', sub_wof_acc),
    ('wof_diff', sub_wof_diff),
):
    df = pd.DataFrame(values)
    df.to_parquet(f'block_subject_{metric}_ws={window_size}.parquet')

## Block relation

In [None]:
# Mask every non-subject token.
(
    rel_acc, rel_diff,
    rel_wf_acc, rel_wf_diff,
    rel_wof_acc, rel_wof_diff,
) = evaluate(
    temperature, top_k, top_p,
    prompt_indices=prompt_indices,
    windows=windows,
    block='relation',
)

In [None]:
# One parquet file per metric array; wf/wof = subject does / does not include token 0.
for metric, values in (
    ('acc', rel_acc),
    ('diff', rel_diff),
    ('wf_acc', rel_wf_acc),
    ('wf_diff', rel_wf_diff),
    ('wof_acc', rel_wof_acc),
    ('wof_diff', rel_wof_diff),
):
    df = pd.DataFrame(values)
    df.to_parquet(f'block_relation_{metric}_ws={window_size}.parquet')

# Experiments - window size = 15

In [None]:
window_size = 15
prompt_indices = list(range(n_prompts))
# Every contiguous run of `window_size` layers: [0..14], [1..15], ..., up to layer n_layers - 1.
windows = [
    list(range(first_layer, first_layer + window_size))
    for first_layer in range(n_layers - window_size + 1)
]

## Block last

In [None]:
# Mask the last token (forward_eval's default when block is None).
(
    last_acc, last_diff,
    last_wf_acc, last_wf_diff,
    last_wof_acc, last_wof_diff,
) = evaluate(
    temperature, top_k, top_p,
    prompt_indices=prompt_indices,
    windows=windows,
)

In [None]:
# One parquet file per metric array; wf/wof = subject does / does not include token 0.
for metric, values in (
    ('acc', last_acc),
    ('diff', last_diff),
    ('wf_acc', last_wf_acc),
    ('wf_diff', last_wf_diff),
    ('wof_acc', last_wof_acc),
    ('wof_diff', last_wof_diff),
):
    df = pd.DataFrame(values)
    df.to_parquet(f'block_last_{metric}_ws={window_size}.parquet')

## Block subject

In [None]:
# Mask the subject tokens.
(
    sub_acc, sub_diff,
    sub_wf_acc, sub_wf_diff,
    sub_wof_acc, sub_wof_diff,
) = evaluate(
    temperature, top_k, top_p,
    prompt_indices=prompt_indices,
    windows=windows,
    block='subject',
)

In [None]:
# One parquet file per metric array; wf/wof = subject does / does not include token 0.
for metric, values in (
    ('acc', sub_acc),
    ('diff', sub_diff),
    ('wf_acc', sub_wf_acc),
    ('wf_diff', sub_wf_diff),
    ('wof_acc', sub_wof_acc),
    ('wof_diff', sub_wof_diff),
):
    df = pd.DataFrame(values)
    df.to_parquet(f'block_subject_{metric}_ws={window_size}.parquet')

## Block relation

In [None]:
# Mask every non-subject token.
(
    rel_acc, rel_diff,
    rel_wf_acc, rel_wf_diff,
    rel_wof_acc, rel_wof_diff,
) = evaluate(
    temperature, top_k, top_p,
    prompt_indices=prompt_indices,
    windows=windows,
    block='relation',
)

In [None]:
# One parquet file per metric array; wf/wof = subject does / does not include token 0.
for metric, values in (
    ('acc', rel_acc),
    ('diff', rel_diff),
    ('wf_acc', rel_wf_acc),
    ('wf_diff', rel_wf_diff),
    ('wof_acc', rel_wof_acc),
    ('wof_diff', rel_wof_diff),
):
    df = pd.DataFrame(values)
    df.to_parquet(f'block_relation_{metric}_ws={window_size}.parquet')