In [43]:
import argparse
import os
import re

import datasets
import numpy as np
import pandas as pd
import torch
import transformers
from datasets import load_dataset
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForCausalLM
from transformers import pipeline

tqdm.pandas()

In [44]:
# Checkpoint to load and the shared on-disk cache used for both tokenizer and weights.
model_path = "meta-llama/Meta-Llama-3.1-8B-Instruct"
MODEL_CACHE_DIR = "/home/jovyan/shares/SR003.nfs2/.cache/models/transformers/"

# Never hard-code credentials in the notebook: read the Hugging Face access token
# from the environment. When HF_TOKEN is unset this is None and from_pretrained
# receives token=None.
hf_token = os.environ.get("HF_TOKEN")

# Left padding is configured on the tokenizer (padding_side='left').
tokenizer = AutoTokenizer.from_pretrained(
    model_path,
    padding_side='left',
    cache_dir=MODEL_CACHE_DIR,
    token=hf_token,
)

# NOTE(review): trust_remote_code=True allows repository-supplied Python to run;
# confirm it is actually required for this checkpoint.
# attn_implementation is 'eager'; "flash_attention_2" was the previously tried alternative.
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    cache_dir=MODEL_CACHE_DIR,
    token=hf_token,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation='eager',
).eval()

Loading checkpoint shards: 100%|██████████| 4/4 [00:03<00:00,  1.26it/s]


In [45]:
def assess_knowledgability(question):
    """Ask the chat model to rate its own knowledge for ``question``.

    Builds a two-turn chat (system instructions + the question as the user
    turn), samples a completion from the module-level ``model`` using the
    module-level ``tokenizer``, and returns only the newly generated text.

    Args:
        question: Text sent as the user message.

    Returns:
        The decoded completion with special tokens removed. The prompt asks
        for a bare number from 0 to 100, but the reply is returned unparsed.
    """
    # Adjacent literals are joined with explicit trailing spaces. The previous
    # backslash continuations glued words together ("basedon", "isaccurate",
    # "confidentwith") and left an unformatted "{question}" placeholder in the
    # text; the question itself is delivered in the user turn below.
    system_prompt = (
        "Answer the following question based on your internal knowledge "
        "with one or few words. If you are sure the answer is accurate and "
        "correct, please say '100'. If you are not confident with the answer, "
        "please range your knowledgability from 0 to 100, say just number. "
        "For example, '40'."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": question},
    ]

    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)

    # Stop on either the tokenizer's EOS id or the "<|eot_id|>" token id.
    terminators = [
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>")
    ]

    outputs = model.generate(
        input_ids,
        max_new_tokens=512,
        eos_token_id=terminators,
        do_sample=True,
        temperature=0.3,
        top_p=0.9,
    )

    # Drop the prompt tokens so only the generated continuation is decoded.
    response = outputs[0][input_ids.shape[-1]:]
    answer = tokenizer.decode(response, skip_special_tokens=True)

    return answer

In [96]:
# Input locations: one CSV per split plus a dataset stored on disk.
# NOTE(review): the CSVs live under "adaptive_rag_hotpotqa" while hf_path names
# "natural_questions" — confirm these are meant to be paired.
train_path = "data/adaptive_rag_hotpotqa/train.csv"
test_path = "data/adaptive_rag_hotpotqa/test.csv"
hf_path = "AdaRAG/data_hf/external_rag_natural_questions_extra_v4.hf"

# Read both CSV splits, then load the on-disk dataset.
df_train, df_test = (pd.read_csv(csv_path) for csv_path in (train_path, test_path))
df_hf = datasets.load_from_disk(hf_path)

In [None]:
# Score every training row: ask the model about each listed entity and store
# the truncated mean of the numeric replies in 'llama_know'.
for i in tqdm(range(len(df_train))):
    labels = df_train.loc[i, 'bela_base_ents_lbls']
    if pd.isna(labels):
        # No entity labels for this row; nothing is written for it here.
        continue

    cur_evals = []
    for ent in labels.split(', '):
        cur_ques = "question: Do you know a lot about " + ent + "?"
        # assess_knowledgability is the scoring function defined in this notebook
        # (the loop previously called an undefined generate_paraphrases).
        # Only the first line of the reply is parsed, with periods removed.
        first_line = assess_knowledgability(cur_ques).split('\n')[0].replace(".", "")
        try:
            cur_evals.append(int(first_line))
        except ValueError:
            # The sampled reply was not a bare integer; skip it rather than
            # aborting the whole run.
            continue

    # Guard against an empty list so no NaN mean is cast to int.
    if cur_evals:
        df_train.loc[i, 'llama_know'] = int(np.mean(cur_evals))

In [None]:
# Score every test row: ask the model about each listed entity and store
# the truncated mean of the numeric replies in 'llama_know'.
for i in tqdm(range(len(df_test))):
    labels = df_test.loc[i, 'bela_base_ents_lbls']
    if pd.isna(labels):
        # No entity labels for this row; nothing is written for it here.
        continue

    cur_evals = []
    for ent in labels.split(', '):
        cur_ques = "question: Do you know a lot about " + ent + "?"
        # assess_knowledgability is the scoring function defined in this notebook
        # (the loop previously called an undefined generate_paraphrases).
        # Only the first line of the reply is parsed, with periods removed.
        first_line = assess_knowledgability(cur_ques).split('\n')[0].replace(".", "")
        try:
            cur_evals.append(int(first_line))
        except ValueError:
            # The sampled reply was not a bare integer; skip it rather than
            # aborting the whole run.
            continue

    # Guard against an empty list so no NaN mean is cast to int.
    if cur_evals:
        df_test.loc[i, 'llama_know'] = int(np.mean(cur_evals))

In [None]:
# Replace missing 'llama_know' values with 0 in both splits.
for split_frame in (df_train, df_test):
    split_frame['llama_know'] = split_frame['llama_know'].fillna(0)

In [None]:
# Attach each split's 'llama_know' column to the matching split of df_hf.
for split_name, split_frame in (('train', df_train), ('test', df_test)):
    df_hf[split_name] = df_hf[split_name].add_column('llama_know', split_frame['llama_know'])

In [None]:
# Write the scored frames back to the CSV paths they were read from.
df_train.to_csv(train_path, index=False)
df_test.to_csv(test_path, index=False)

# Save the augmented dataset under the next version number. The trailing
# "v<N>.hf" is bumped to "v<N+1>.hf", so "..._v3.hf" still becomes "..._v4.hf".
# The previous split('v3.hf') silently produced "..._v4.hfv4.hf" whenever
# hf_path did not contain "v3.hf".
hf_out_path, n_bumped = re.subn(
    r'v(\d+)\.hf$',
    lambda m: f"v{int(m.group(1)) + 1}.hf",
    hf_path,
)
if not n_bumped:
    # Fail loudly instead of writing to a malformed directory name.
    raise ValueError(f"hf_path has no 'v<N>.hf' suffix to bump: {hf_path!r}")
df_hf.save_to_disk(hf_out_path)

In [146]:
# Print mean, standard deviation, minimum and maximum of the 'llama_know' scores.
trivia_scores = df_trivia['llama_know']
for summarize in (np.mean, np.std, np.min, np.max):
    print(summarize(trivia_scores))

80.98556701030928
17.58671587289575
20.0
100.0


In [147]:
# Print mean, standard deviation, minimum and maximum of the 'llama_know' scores.
musique_scores = df_musique['llama_know']
for summarize in (np.mean, np.std, np.min, np.max):
    print(summarize(musique_scores))

65.32520325203252
22.669504581415506
0.0
100.0
