In [2]:
# Prompt phrasings used to wrap a raw question. Every entry contains a single
# `{question}` placeholder; the record-building loop picks one entry at random
# per sample and fills it with `format_map`.
# Some entries append "?" themselves and others do not.
# NOTE(review): presumably the source questions carry no trailing question
# mark -- confirm against the dataset.
templates_for_qa = [
    "Question: {question}?\nAnswer:",
    "{question}?",
    "Answer the following question:\n\n{question}",
    "Answer this question:\n\n{question}?",
    "Please answer this question: {question}",
    "Answer the question...{question}?",
    "What is the answer to this question? {question}\n\n",
    "Can you tell me the answer to {question}?",
    "Next question: {question}\n\n",
    "Q: {question} A:",
    "{question}\nWhat is the answer?",
    "Write the answer: {question}",
    "{question}???",
]
# Layout for pairing a question with a retrieved passage; expects the
# `{question}` and `{background}` placeholders to be filled in.
# The identifier is spelled "templete" (sic); it is kept as-is because other
# cells refer to it by this exact name.
background_templete = "Question: {question}\n Document: {background}"


In [None]:
from datasets import load_dataset
import random
from src.eval.utils import has_answer , get_substring_match_score
import json
from tqdm import tqdm

# Target model whose closed-book generations are scored below.
model_name = 'mistral' # llama, mistral, qwen

# Question / answer pairs.
dataset = load_dataset("nq_open")

# Retrieved contexts from ColBERT (source of background, gt_background and
# hn_background).
# Hugging Face repo: eunseong/nq_colbertv2
#   hf download eunseong/nq_colbertv2 --repo-type dataset --local-dir nq_colbertv2
ret_file_path = 'data_care/nq_colbertv2.json'
with open(ret_file_path, 'r') as ret_file:
    ret_data = json.load(ret_file)

# Closed-book generations of the target model (used for closed_book_correct).
# NOTE: these results must be generated beforehand with the target model,
# e.g. nq_train_closed_book_mistral.json via evaluate_closedbook.sh.
model_output_path = f'ft_results/nq_train_closed_book_{model_name}.json'
with open(model_output_path, 'r') as output_file:
    generated_data = json.load(output_file)

# Score every prediction against its reference answer; only the per-sample
# scores are kept, the first return value is discarded.
answers = [entry['answer'] for entry in generated_data]
preds = [entry['pred'] for entry in generated_data]
_, score_per_sample = get_substring_match_score(preds, answers)

  from .autonotebook import tqdm as notebook_tqdm
Downloading readme: 8.77kB [00:00, 11.5MB/s]
Downloading data: 100%|██████████| 4.46M/4.46M [00:00<00:00, 4.82MB/s]
Downloading data: 100%|██████████| 214k/214k [00:00<00:00, 474kB/s]
Generating train split: 100%|██████████| 87925/87925 [00:00<00:00, 2130112.11 examples/s]
Generating validation split: 100%|██████████| 3610/3610 [00:00<00:00, 704351.19 examples/s]


In [5]:
# Build one record per training question: a chat-style QA pair plus one
# answer-bearing ("gt") passage and one hard-negative ("hn") passage taken from
# the retrieved contexts, and a flag for whether the closed-book prediction
# was scored as correct.

# Dedicated, seeded generator: the template chosen for each sample is the same
# on every run, and the global `random` state is left untouched.
template_rng = random.Random(42)


def _format_background(question, passage):
    """Return `passage` rendered together with `question` via `background_templete`."""
    return background_templete.format_map(dict(question=question, background=passage))


total_data = []
# NOTE: zip stops at the shortest of the three inputs; they are expected to be
# aligned index-by-index -- TODO confirm the lengths match.
aligned_inputs = zip(dataset['train'], ret_data, score_per_sample)
for idx, (sample, ret_sample, s) in enumerate(tqdm(aligned_inputs, total=len(ret_data))):
    answer = sample['answer']
    question = sample['question']
    prompt = template_rng.choice(templates_for_qa).format_map(dict(question=question))

    # The first listed answer is used as the assistant turn.
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": answer[0]},
    ]

    cur = {
        "id": f"nq_{idx}",
        "hn_background_id": f"hn_nq_{idx}",
        "gt_background_id": f"nq_{idx}",
        "messages": messages,
        "task_type": "open_qa",
        "answers": answer
    }

    # background, hn_background, gt_background:
    # first context for which has_answer() is true -> gt_background / background,
    # first context for which it is false          -> hn_background.
    ctxs = ret_sample['ctxs']
    for ctx in ctxs:
        # Stop before scoring another passage once both slots are filled.
        if 'hn_background' in cur and 'gt_background' in cur:
            break
        is_pos = has_answer(answer, ctx['text'])
        if is_pos and 'gt_background' not in cur:
            # gt_background keeps the raw passage; background is its templated form.
            cur['gt_background'] = ctx['text']
            cur['background'] = _format_background(question, ctx['text'])
        elif not is_pos and 'hn_background' not in cur:
            cur['hn_background'] = _format_background(question, ctx['text'])

    # Fallbacks use ctxs[1]: if no passage was positive, ctxs[0] already became
    # the hard negative; if every passage was positive, ctxs[0] already became
    # the gt. Index 1 therefore presumably keeps the two backgrounds distinct.
    # Requires at least two retrieved contexts per sample.
    if 'gt_background' not in cur:
        cur['gt_background'] = ctxs[1]['text']
        cur['background'] = _format_background(question, ctxs[1]['text'])
    if 'hn_background' not in cur:
        cur['hn_background'] = _format_background(question, ctxs[1]['text'])

    # Coerce to a builtin bool so the record stays JSON-serializable even if the
    # score is a non-builtin numeric type (e.g. a numpy scalar).
    cur['closed_book_correct'] = bool(s == 1)
    total_data.append(cur)
        

  0%|          | 100/87925 [00:00<04:56, 295.79it/s]


In [None]:
# Split into train / validation (first 10% after shuffling -> validation)
# and write each split as JSON Lines (one record per line).
random.seed(42)
# Shuffle a copy rather than `total_data` itself: re-running this cell then
# always starts from the same order and reproduces the same split.
shuffled_data = list(total_data)
random.shuffle(shuffled_data)

split_idx = int(len(shuffled_data) * 0.1)
# Validation records get a "valid_" id prefix. Build prefixed copies instead of
# editing the dicts in place, so the records held by `total_data` are not
# modified and a re-run cannot stack the prefix ("valid_valid_...").
# Overriding an existing key keeps its position, so the written key order is
# unchanged.
valid_data = [
    {**item, 'id': f"valid_{item['id']}"}
    for item in shuffled_data[:split_idx]
]
train_data = shuffled_data[split_idx:]

## save train / validation
train_data_path = 'train_file_path'
with open(train_data_path, 'w') as f:
    for item in train_data:
        f.write(json.dumps(item) + '\n')

val_data_path = 'valid_file_path'
with open(val_data_path, 'w') as f:
    for item in valid_data:
        f.write(json.dumps(item) + '\n')

79133 8792
