# Setup

In [1]:
# Reload imported modules (e.g. the project's `src` package) automatically
# before each cell runs, so edits to them take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch
# Inference only: turn autograd off globally for the rest of the notebook.
# The trailing ';' stops the returned set_grad_enabled object from being
# echoed as this cell's output.
torch.set_grad_enabled(False);

<torch.autograd.grad_mode.set_grad_enabled at 0x72e8ab229c70>

In [3]:

from src.consts import FILTERATIONS
from src.datasets.download_dataset import load_splitted_counter_fact
from tqdm import tqdm
from src.types import MODEL_ARCH
from src.models.model_interface import get_model_interface


  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd
  @custom_fwd
  @custom_bwd


In [4]:

# Global random seed (passed to torch.random.manual_seed before inference).
# Device and model are taken from model_interface in the model-loading cell,
# so no separate device / model_name settings are kept here.
seed = 0

# Predict the next token for each CounterFact prompt

In [6]:


# Load CounterFact through load_splitted_counter_fact("all", ...) and attach
# placeholder result columns that forward_eval fills in batch by batch.
original_data = pd.DataFrame(
    load_splitted_counter_fact(
        "all",
        align_to_known=False,
        filteration=FILTERATIONS.all_correct,
    )
).assign(
    true_prob=0.0,
    max_prob=0.0,
    hit=False,
    pred="",
)

In [7]:
# Build the model wrapper for MODEL_ARCH.MINIMAL_MAMBA2_new at model_size "130M".
# Later cells read its .tokenizer and .device and call .generate_logits on it.
model_arch = MODEL_ARCH.MINIMAL_MAMBA2_new
model_interface = get_model_interface(model_arch, model_size="130M")

In [12]:
# Seed torch's global RNG so any stochastic step is reproducible.
torch.random.manual_seed(seed)

# Sampling settings. forward_eval accepts these but predicts with a plain
# argmax over the next-token scores, so they currently do not affect results.
temperature = 1
top_k = 0
top_p = 1

# Attention flag handed to forward_eval for every batch.
attention = True

In [13]:
# Pull the tokenizer and the input device off the model wrapper; forward_eval
# reads both as globals.
tokenizer, device = model_interface.tokenizer, model_interface.device

In [None]:
def forward_eval(temperature, top_k, top_p, batch_start, batch_end, attention, print_period=10000):
    """Score one batch of prompts and write next-token results into `original_data`.

    Rows labelled ``batch_start`` .. ``batch_end - 1`` of the global
    ``original_data`` frame are tokenized and passed to
    ``model_interface.generate_logits``. The 2-D array it returns (one row per
    prompt, indexed by token id) is compared against the first token of each
    row's ``target_true``.

    Columns written in place for those rows:
        true_prob -- score of the first token of ``target_true``
        max_prob  -- highest score over the vocabulary
        hit       -- True when the argmax token id equals the first target token id
        pred      -- decoded argmax token

    Args:
        temperature, top_k, top_p: accepted for interface compatibility but
            unused -- the prediction is a plain argmax.
        batch_start: first row label of the batch (inclusive).
        batch_end: row label one past the end of the batch (exclusive).
        attention: forwarded to ``model_interface.generate_logits``.
        print_period: print a progress line whenever the rows
            ``[batch_start, batch_end)`` cross a multiple of this value.
    """
    # NOTE(review): `.loc` slicing is label-based and inclusive, so this assumes
    # `original_data` has a default 0..N-1 RangeIndex -- confirm.
    rows = slice(batch_start, batch_end - 1)
    prompts = list(original_data.loc[rows, 'prompt'].values)
    true_words = list(original_data.loc[rows, 'target_true'].values)

    # Tokenize targets without padding so the first id is always a real token,
    # regardless of the tokenizer's padding side.
    true_first_ids = np.array([ids[0] for ids in tokenizer(true_words)['input_ids']])

    tokens = tokenizer(prompts, return_tensors="pt", padding=True)
    input_ids = tokens.input_ids.to(device=device)
    # NOTE(review): no attention mask is passed, so batches containing padded
    # prompts depend on how generate_logits treats pad tokens; batch_size=1
    # avoids padding entirely.
    # The returned scores are named "probs" here although the method is called
    # generate_logits -- confirm which one it returns.
    next_token_probs = model_interface.generate_logits(
        input_ids=input_ids,
        attention=attention,
    )

    row_idx = np.arange(next_token_probs.shape[0])
    max_idx = np.argmax(next_token_probs, axis=1)
    preds = [tokenizer.decode([t]) for t in max_idx]

    original_data.loc[rows, 'true_prob'] = next_token_probs[row_idx, true_first_ids]
    original_data.loc[rows, 'max_prob'] = next_token_probs[row_idx, max_idx]
    # Compare token ids, not probabilities: a different token that ties with the
    # target's probability is not a hit.
    original_data.loc[rows, 'hit'] = max_idx == true_first_ids
    original_data.loc[rows, 'pred'] = preds

    # Fires whenever this batch crosses a multiple of print_period, for any batch
    # size (identical to `(batch_start + 1) % print_period == 0` when size is 1).
    if batch_end // print_period > batch_start // print_period:
        print(f'Finished batch [{batch_start}:{batch_end-1}]')
    torch.cuda.empty_cache()

In [15]:
# Batch boundaries: iteration i covers rows [batches[i], batches[i + 1]),
# with N appended so the final (possibly shorter) batch is included.
batch_size = 1
N = len(original_data)
batches = [*np.arange(0, N, batch_size), N]

In [16]:
# Walk consecutive boundary pairs and score each batch in place.
for batch_start, batch_end in tqdm(zip(batches[:-1], batches[1:]), total=len(batches) - 1):
    forward_eval(temperature, top_k, top_p, batch_start, batch_end, attention)

  0%|          | 0/827 [00:00<?, ?it/s]

100%|██████████| 827/827 [00:27<00:00, 30.25it/s]


In [17]:
# Rows that forward_eval flagged as misses (hit is False).
original_data[~original_data['hit']]

Unnamed: 0,relation,relation_prefix,relation_suffix,prompt,relation_id,target_false_id,target_true_id,target_true,target_false,subject,original_idx,split,true_prob,max_prob,hit,pred
137,{} belongs to the continent of,,{} belongs to the continent of,Finland belongs to the continent of,P30,Q51,Q46,Europe,Antarctica,Finland,11576,train2,0.164485,0.31281,False,the
149,The official language of {} is,The official language of,is,The official language of Netherlands Antilles is,P37,Q7737,Q7411,Dutch,Russian,Netherlands Antilles,17961,train2,0.081177,0.083499,False,Spanish
193,{} was born in,,{} was born in,Mizuki Fukumura was born in,P19,Q1085,Q1490,Tokyo,Prague,Mizuki Fukumura,6753,train3,0.066823,0.076437,False,the
229,{} is created by,,{} is created by,.NET Framework is created by,P178,Q95,Q2283,Microsoft,Google,.NET Framework,8180,train3,0.147513,0.1883,False,the
286,{} belongs to the continent of,,{} belongs to the continent of,Slovenia belongs to the continent of,P30,Q51,Q46,Europe,Antarctica,Slovenia,7958,train4,0.18633,0.261546,False,the
354,{} is a product of,,{} is a product of,Sony Mavica is a product of,P176,Q27564,Q41187,Sony,Dodge,Sony Mavica,13965,train5,0.181343,0.244415,False,the
359,{} is developed by,,{} is developed by,Amazon Echo is developed by,P178,Q37156,Q3884,Amazon,IBM,Amazon Echo,7601,train5,0.105153,0.143823,False,Google
454,"{}, produced by",,"{}, produced by","iPad, produced by",P176,Q9584,Q312,Apple,Honda,iPad,13762,test,0.062546,0.067953,False,the
469,{} is produced by,,{} is produced by,Toyota Alphard is produced by,P176,Q181114,Q53268,Toyota,Chrysler,Toyota Alphard,9795,test,0.163757,0.186177,False,the
484,{} is a professional,,{} is a professional,Jozy Altidore is a professional,P641,Q41466,Q2736,soccer,hockey,Jozy Altidore,2621,test,0.1122,0.257095,False,footballer


In [18]:
# Hit rate: fraction of rows whose hit flag is True.
original_data.hit.mean()

0.9818621523579202

In [None]:
# NOTE(review): duplicate of the hit-rate cell above. It was not executed in
# this session (In [None]), and its saved output differs from that cell's, so
# the value shown below comes from an earlier run -- consider deleting this cell.
original_data['hit'].mean()

0.9842805320435308

In [19]:
# Persist every row with its true_prob / max_prob / hit / pred results.
# NOTE(review): the "_attention" suffix is hard-coded and does not follow the
# `attention` flag -- rename the file if that flag is changed.
results_path = 'entire_results_attention.parquet'
original_data.to_parquet(results_path)