# Setup

In [None]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
from mamba2mini import Mamba2LMHeadModel
from transformers import AutoTokenizer

# Inference only: switch autograd off globally for everything run after this cell.
torch.set_grad_enabled(mode=False)

In [None]:
# Run configuration.
model_name = "state-spaces/mamba2-1.3b"  # checkpoint passed to Mamba2LMHeadModel.from_pretrained
device = "cuda"  # where the model and prompt ids are placed
seed = 0  # passed to torch.random.manual_seed before evaluation

In [None]:
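# Uncomment the lines below to override the Hugging Face, Triton and XDG cache directories.
# Set them before running the Setup cell: some libraries may read these variables only once, at import time.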

# hf_dir = XXX
# tri_dir = YYY
# xdg_dir = ZZZ
# os.environ['HF_HOME'] = hf_dir
# os.environ['TRITON_CACHE_DIR'] = tri_dir
# os.environ['XDG_CACHE_HOME'] = xdg_dir

# Predict

In [None]:
# Load the evaluation rows and add placeholder result columns that
# forward_eval overwrites batch by batch.
original_data = pd.read_parquet('original_data.parquet').assign(
    true_prob=0.0,
    max_prob=0.0,
    hit=False,
    pred="",
)

In [None]:
# `hf_dir` only exists if the cache-directory cell was uncommented; otherwise
# pass None so transformers uses its default cache location instead of
# failing with a NameError.
try:
    hf_cache_dir = hf_dir
except NameError:
    hf_cache_dir = None
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", cache_dir=hf_cache_dir, use_fast=True)
# Reuse EOS as the pad token so the `padding=True` tokenizer calls in
# forward_eval have a pad token available.
tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Load the pretrained checkpoint named by `model_name` onto `device` via mamba2mini.
# NOTE(review): no cache directory is passed here, so `hf_dir` does not affect
# where this checkpoint is stored — confirm mamba2mini's caching behavior.
model = Mamba2LMHeadModel.from_pretrained(model_name, device=device)

In [None]:
# Decoding settings, passed through forward_eval to model.generate_single.
temperature = 1
top_k = 0
top_p = 1
attention = True

# Seed torch's RNG for reproducibility and switch the model to evaluation mode.
torch.random.manual_seed(seed)
model.eval()

In [None]:
def forward_eval(temperature, top_k, top_p, batch_start, batch_end, attention, print_period=1000):
    """Score the model's next-token prediction for rows [batch_start, batch_end) of `original_data`.

    Runs ``model.generate_single`` on the slice's prompts, treats the last
    element of its output as per-row next-token probabilities, and writes four
    columns back into ``original_data`` in place:

    - ``true_prob``: probability of the first token of ``target_true``
    - ``max_prob``: the largest probability in the row
    - ``hit``: whether the argmax token id equals the first ``target_true`` token id
    - ``pred``: the decoded argmax token

    Args:
        temperature, top_k, top_p: sampling settings forwarded to ``generate_single``.
        batch_start: first row label of the slice (inclusive).
        batch_end: end row label of the slice (exclusive).
        attention: flag forwarded to ``generate_single``.
        print_period: print a progress line whenever the slice crosses a
            multiple of this many rows.

    Returns:
        None. Results are written into ``original_data``.
    """
    # `.loc` slicing is label-based and inclusive at both ends, hence `batch_end - 1`.
    # NOTE(review): assumes `original_data` has a default RangeIndex — confirm.
    rows = slice(batch_start, batch_end - 1)
    prompts = original_data.loc[rows, 'prompt'].tolist()
    true_words = original_data.loc[rows, 'target_true'].tolist()

    # Only the first sub-token of each target is scored.
    # NOTE(review): targets are tokenized without a leading space; if prompts do
    # not end with a space, the natural continuation is the space-prefixed token —
    # confirm this matches the dataset's formatting.
    true_first_ids = tokenizer(true_words, return_tensors="pt", padding=True).input_ids[:, 0].numpy()

    # NOTE(review): if the tokenizer pads on the right and batch_size > 1, shorter
    # prompts end in pad tokens, so their "next token" follows padding.
    # Single-row batches are unaffected.
    input_ids = tokenizer(prompts, return_tensors="pt", padding=True).input_ids.to(device=device)

    out = model.generate_single(
        input_ids=input_ids,
        max_new_length=input_ids.shape[1] + 1,  # one token beyond the prompt
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        # Pass the EOS token *id*; the original passed the token string.
        eos_token_id=tokenizer.eos_token_id,
        attention=attention,
    )
    # NOTE(review): assumes out[-1] is a (batch, vocab) tensor of next-token
    # probabilities — confirm against mamba2mini.
    next_token_probs = out[-1].detach().cpu().numpy()
    row_idx = np.arange(next_token_probs.shape[0])
    max_idx = next_token_probs.argmax(axis=1)

    original_data.loc[rows, 'true_prob'] = next_token_probs[row_idx, true_first_ids]
    original_data.loc[rows, 'max_prob'] = next_token_probs[row_idx, max_idx]
    # Compare token ids, not probabilities: float equality would report a hit
    # when the true token merely ties with (but is not) the argmax.
    original_data.loc[rows, 'hit'] = max_idx == true_first_ids
    original_data.loc[rows, 'pred'] = [tokenizer.decode([int(t)]) for t in max_idx]

    # Print whenever this slice crosses a multiple of `print_period`. For
    # single-row batches this matches `(batch_start + 1) % print_period == 0`;
    # unlike that check, it also fires for larger batch sizes.
    if batch_end // print_period > batch_start // print_period:
        print(f'Finished batch [{batch_start}:{batch_end-1}]')
    torch.cuda.empty_cache()

In [None]:
# Boundaries of consecutive [start, end) row ranges covering all N rows.
# The final boundary is always N, so a ragged last batch is still included.
batch_size = 1
N = len(original_data)
batches = [*np.arange(0, N, batch_size), N]

In [None]:
# Score the final batch on its own, before the main loop.
# NOTE(review): the reason for running it first (e.g. as a warm-up) is not stated — confirm.
last_start, last_end = batches[len(batches) - 2], batches[len(batches) - 1]
forward_eval(temperature, top_k, top_p, last_start, last_end, attention)

In [None]:
# Score every remaining batch by pairing each boundary with its successor.
# The last pair is skipped because the previous cell already scored it.
for start, end in zip(batches[:-2], batches[1:-1]):
    forward_eval(temperature, top_k, top_p, start, end, attention)

In [None]:
# Peek at the first five scored rows.
original_data.head(n=5)

In [None]:
# Fraction of rows that forward_eval flagged as a hit (top-1 on the first target token).
original_data.hit.mean()

In [None]:
# Save the scored rows. The file name records whether attention was enabled,
# so a run with `attention=False` does not overwrite the attention results.
results_path = 'entire_results_attention.parquet' if attention else 'entire_results_no_attention.parquet'
original_data.to_parquet(results_path)