# Setup

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
import src.models.minimal_mamba2 as minimal_mamba2
import src.models.minimal_mamba2_new as minimal_mamba2_new
from transformers import AutoTokenizer

torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x790ed41ddc70>

In [24]:
from src.consts import FILTERATIONS
from src.datasets.download_dataset import load_splitted_counter_fact
from tqdm import tqdm
from src.types import MODEL_ARCH


In [None]:

device = "cuda:1"
# model_name = "state-spaces/mamba2-1.3b"
model_name = "state-spaces/mamba2-130M"
seed = 0

In [5]:
# Uncomment below to set correct caching directories

# hf_dir = XXX
# tri_dir = YYY
# xdg_dir = ZZZ
# os.environ['HF_HOME'] = hf_dir
# os.environ['TRITON_CACHE_DIR'] = tri_dir
# os.environ['XDG_CACHE_HOME'] = xdg_dir

# Predict

In [6]:


original_data = pd.DataFrame(load_splitted_counter_fact(
    "all", align_to_known=False, filteration=FILTERATIONS.all_correct
))
original_data['true_prob'] = 0.0
original_data['max_prob'] = 0.0
original_data['hit'] = False
original_data['pred'] = ""

In [7]:
# tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b", cache_dir=hf_dir, use_fast=True)


In [8]:
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neox-20b")
tokenizer.pad_token = tokenizer.eos_token

In [27]:
from scripts.evaluate_model import generate_next_tokens
from src.logit_utils import logits_to_probs

model_arch = MODEL_ARCH.MINIMAL_MAMBA2
model = minimal_mamba2.Mamba2LMHeadModel.from_pretrained(model_name, device=device)
get_logits = lambda **kwargs: logits_to_probs(generate_next_tokens(
    model=model, input_ids=kwargs['input_ids'], num_tokens_to_generate=1, model_arch=model_arch
    )[1])

In [19]:
model = minimal_mamba2_new.Mamba2LMHeadModel.from_pretrained(model_name, device=device)

get_logits = lambda **kwargs: model.generate_single(
        input_ids=kwargs['input_ids'],
        max_new_length=kwargs['max_new_length'],
        temperature=kwargs['temperature'],
        top_k=kwargs['top_k'],
        top_p=kwargs['top_p'],
        eos_token_id=tokenizer.eos_token,
        attention=kwargs['attention'],
    )

In [10]:
torch.random.manual_seed(seed)
model.eval()
temperature = 1
top_k = 0
top_p = 1
attention = False

In [11]:
def forward_eval(temperature, top_k, top_p, batch_start, batch_end, attention, print_period=10000):
    prompts = list(original_data.loc[batch_start:batch_end-1, 'prompt'].values)
    true_word = list(original_data.loc[batch_start:batch_end-1, 'target_true'].values)
    true_token = tokenizer(true_word, return_tensors="pt", padding=True)
    true_id = true_token.input_ids.to(device='cpu')
    tokens = tokenizer(prompts, return_tensors="pt", padding=True)
    input_ids = tokens.input_ids.to(device=device)
    max_new_length = input_ids.shape[1] + 1
    out = get_logits(
        input_ids=input_ids,
        max_new_length=max_new_length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        eos_token_id=tokenizer.eos_token,
        attention=attention,
    )
    next_token_probs = out[-1].detach().cpu().numpy()
    max_idx = np.argmax(next_token_probs, axis=1)
    row_idx = np.arange(next_token_probs.shape[0])
    preds = [tokenizer.decode([t]) for t in max_idx]
    original_data.loc[batch_start:batch_end-1, 'true_prob'] = next_token_probs[row_idx, true_id[:, 0]]
    original_data.loc[batch_start:batch_end-1, 'max_prob'] = next_token_probs[row_idx, max_idx]
    original_data.loc[batch_start:batch_end-1, 'hit'] = original_data.loc[batch_start:batch_end-1, 'true_prob'] == original_data.loc[batch_start:batch_end-1, 'max_prob']
    original_data.loc[batch_start:batch_end-1, 'pred'] = preds
    if (batch_start+1) % print_period == 0:
        print(f'Finished batch [{batch_start}:{batch_end-1}]')
    torch.cuda.empty_cache()

In [12]:
batch_size = 1
N = len(original_data)
batches = list(np.arange(0, N, batch_size)) + [N]

In [20]:
idx = original_data.pipe(lambda x: x[x['original_idx'] == 17689]).index.values[0]
forward_eval(temperature, top_k, top_p, batches[idx], batches[idx+1], attention)
original_data.iloc[idx:idx+1]

Unnamed: 0,relation,relation_prefix,relation_suffix,prompt,relation_id,target_false_id,target_true_id,target_true,target_false,subject,original_idx,split,true_prob,max_prob,hit,pred
8,{} worked in the city of,,{} worked in the city of,Pierre-Jean Mariette worked in the city of,P937,Q29364,Q90,Paris,Montgomery,Pierre-Jean Mariette,17689,train1,0.145603,0.145603,True,Paris


In [28]:
idx = original_data.pipe(lambda x: x[x['original_idx'] == 17689]).index.values[0]
forward_eval(temperature, top_k, top_p, batches[idx], batches[idx+1], attention)
original_data.iloc[idx:idx+1]

AxisError: axis 1 is out of bounds for array of dimension 1

In [17]:
original_data.iloc[idx:idx+1]

Unnamed: 0,relation,relation_prefix,relation_suffix,prompt,relation_id,target_false_id,target_true_id,target_true,target_false,subject,original_idx,split,true_prob,max_prob,hit,pred
8,{} worked in the city of,,{} worked in the city of,Pierre-Jean Mariette worked in the city of,P937,Q29364,Q90,Paris,Montgomery,Pierre-Jean Mariette,17689,train1,82.855637,82.855637,True,Paris


In [26]:
for i in tqdm(range(len(batches)-2)):
    forward_eval(temperature, top_k, top_p, batches[i], batches[i+1], attention)

  0%|          | 0/826 [00:00<?, ?it/s]

100%|██████████| 826/826 [06:55<00:00,  1.99it/s]


In [27]:
original_data.pipe(lambda df: df[~df['hit']])

Unnamed: 0,relation,relation_prefix,relation_suffix,prompt,relation_id,target_false_id,target_true_id,target_true,target_false,subject,original_idx,split,true_prob,max_prob,hit,pred
8,{} worked in the city of,,{} worked in the city of,Pierre-Jean Mariette worked in the city of,P937,Q29364,Q90,Paris,Montgomery,Pierre-Jean Mariette,17689,train1,24.26906,24.624636,False,Montreal
186,"{}, created by",,"{}, created by","Microsoft Office Mobile, created by",P178,Q11463,Q2283,Microsoft,Adobe,Microsoft Office Mobile,12300,train3,20.858376,20.880962,False,the
197,{} is a product of,,{} is a product of,Nintendo Video is a product of,P178,Q248,Q8093,Nintendo,Intel,Nintendo Video,19227,train3,16.41267,16.71034,False,the
251,{} belongs to the continent of,,{} belongs to the continent of,Willan Nunatak belongs to the continent of,P30,Q46,Q51,Antarctica,Europe,Willan Nunatak,18892,train4,21.710266,21.801659,False,Australia
259,{} is owned by,,{} is owned by,Yahoo! Tech is owned by,P127,Q1400,Q37093,Yahoo,Pennsylvania,Yahoo! Tech,9222,train4,16.264675,17.205734,False,Verizon
403,{} is a native speaker of,,{} is a native speaker of,Joseph Paul-Boncour is a native speaker of,P103,Q7737,Q150,French,Russian,Joseph Paul-Boncour,11538,train5,21.06838,21.762497,False,the
479,{} is a native speaker of,,{} is a native speaker of,Jean Rouch is a native speaker of,P103,Q652,Q150,French,Italian,Jean Rouch,19175,test,23.496683,23.595036,False,the
545,{} is a product of,,{} is a product of,Honda Airwave is a product of,P176,Q27597,Q9584,Honda,Fiat,Honda Airwave,18118,test,8.129721,8.258763,False,the
597,{} is written in,,{} is written in,Lenta.ru is written in,P407,Q652,Q7737,Russian,Italian,Lenta.ru,9361,test,12.827466,13.290056,False,the
633,{} holds a citizenship from,,{} holds a citizenship from,Andreas Carlsson holds a citizenship from,P27,Q35,Q34,Sweden,Denmark,Andreas Carlsson,16143,test,24.555607,24.633812,False,two


In [28]:
original_data['hit'].mean()

0.9842805320435308

In [17]:
original_data.to_parquet('entire_results_original.parquet')