In [None]:

import os

# Restrict this notebook to GPU 1 (takes effect only if set before CUDA is initialised).
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

# Cache directory for downloaded model weights.
os.environ["TRANSFORMERS_CACHE"] = "/data/../llm_cache"
from huggingface_hub import login

# Never hard-code the access token in the notebook: read it from the HF_TOKEN
# environment variable. If it is unset (None), huggingface_hub's login()
# prompts for a token interactively instead.
login(token=os.environ.get("HF_TOKEN"))

In [6]:
import os
import sys
import math
import torch
import numpy as np
import pandas as pd
import argparse
import textwrap
import transformers
from peft import PeftModel
from transformers import GenerationConfig, TextStreamer, BitsAndBytesConfig
from llama_attn_replace import replace_llama_attn


from dataclasses import dataclass, field
from typing import Dict, Optional

from accelerate import Accelerator
import datasets
from datasets import Dataset, load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed

from trl import DPOTrainer


import torch
import torch.nn as nn
import transformers
from torch.utils.data import Dataset
from transformers import Trainer, DataCollatorForLanguageModeling, BitsAndBytesConfig
from llama_attn_replace_sft import replace_llama_attn
from gptneox_attn_replace import replace_gpt_neox_attn
from peft import LoraConfig, get_peft_model
from torch.distributed import barrier
from transformers import TrainerCallback, TrainingArguments, TrainerState, TrainerControl

import re
import os
import json
import string
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from evaluate import load

from metrics.drop_answer_em_f1 import DropAnswerEmAndF1
from metrics.support_em_f1 import SupportEmF1Metric
from metrics.answer_support_recall import AnswerSupportRecallMetric
from metrics.squad_answer_em_f1 import SquadAnswerEmF1Metric


from copy import deepcopy

import os
import json
import jsonlines
# pyserini needs a JVM; JAVA_HOME is set before the import so it is picked up.
os.environ['JAVA_HOME'] = "/usr/lib/jvm/java-11-openjdk-amd64"
from pyserini.search.lucene import LuceneSearcher

# BM25 parameters used for retrieval over the TopiOCQA Lucene index.
bm25_k1, bm25_b = 0.9, 0.4

index_dir_path = "/data2/../nlp_data/topiocqa/indexes/bm25"
searcher = LuceneSearcher(index_dir_path)
searcher.set_bm25(k1=bm25_k1, b=bm25_b)


In [12]:

def get_docs(run_file):
    """Load a TREC-style run file into ``{query_id: {doc_id: value}}``.

    Each non-blank line is split on any whitespace, so tab- and space-separated
    files both work. Field 0 is the query id, field 2 the document id, and
    field 4 the stored value. In the usual TREC run layout
    ``qid Q0 docid rank score tag``, field 4 is the score.

    Doc ids keep their file order within each query. Callers that take
    "the first k" doc ids (e.g. ``generate_eval``) rely on that order.

    Values are parsed as ``int`` when possible, otherwise as ``float``.

    Raises:
        IndexError: if a non-blank line has fewer than 5 fields.
        ValueError: if field 4 is not numeric.
    """
    runs = {}
    with open(run_file, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # blank/trailing lines used to crash the parser
            query, passage, raw_value = fields[0], fields[2], fields[4]
            try:
                value = int(raw_value)
            except ValueError:
                # Standard TREC runs store float scores in this column.
                value = float(raw_value)
            runs.setdefault(query, {})[passage] = value
    return runs


In [13]:
# ours
runs = get_docs("../topiocqa/test_2R_SFT101_DPO500_b5_finefeedback_dpo_3R_refself_beta0.5.trec")
# baselines
runs_retpo = get_docs("../topiocqa/retpo.trec" ,)
runs_hyde = get_docs("../topiocqa/hyde.trec",)
runs_convgqr = get_docs("../topiocqa/convgqr.trec",)
runs_t5qr = get_docs("../topiocqa/t5qr.trec",)
runs_llmcs = get_docs("../topiocqa/llmcs_cot.trec",)
runs_infocqr = get_docs("../topiocqa/infocqr.trec",)
runs_hydellm = get_docs("../topiocqa/hyde_llm.trec",)

# conv data
# Conversational QA records (one JSON object per line). `data_dict` ends up
# mapping each record's 'id' to the full record.
with open("/data2/../nlp_data/topiocqa/qa.jsonl", "r") as f:
    data = f.readlines()
data_dict = [json.loads(line) for line in data]
data_ids = [record['id'] for record in data_dict]
n = len(set(data_ids))  # number of distinct ids
data_dict = {record['id']: record for record in data_dict}


In [24]:

def format_prompt(query, retrieved_documents):
    """Build the retrieval-augmented answer-generation prompt.

    The prompt asks the model to answer ``query`` from ``retrieved_documents``
    and to reply only with JSON of the form {"Answer": "..."}.

    Args:
        query: question text, inserted after "Question:".
        retrieved_documents: reference passages as a single string, inserted
            under "Reference Documents:".

    Returns:
        The prompt string. The triple-quoted literal keeps its leading newline
        and the 4-space indentation of each line verbatim.
    """
    # NOTE(review): the separator after the documents uses em-dashes while the
    # one before uses ASCII hyphens. Kept verbatim because any change to the
    # prompt text changes model outputs.
    PROMPT = f"""
    Based on the given reference documents, answer the following question.
    When answering, do not repeat the question, and only provide the correct answer.
    Provide the answer only in JSON format as {{"Answer":"Your answer"}}.

    Reference Documents:
    ---------------------
    {retrieved_documents}
    ——————————
    Question: {query}
    Answer: 
    """
    return PROMPT


In [25]:

def generate(formatted_prompt, max_new_tokens=128):
    """Run the chat model on one user prompt and return only the new text.

    Uses the notebook-level ``model`` and ``tokenizer``. Decoding is greedy.

    Args:
        formatted_prompt: content of the user message (e.g. from ``format_prompt``).
        max_new_tokens: cap on generated tokens (default 128, as before).

    Returns:
        The decoded continuation (prompt tokens removed, special tokens skipped).
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": formatted_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    # A single unpadded sequence: every position is attended to.
    attention_mask = torch.ones_like(input_ids)

    # Explicit greedy decoding. The original emulated it by sampling with
    # temperature=1e-10. temperature/top_p are cleared so that any sampling
    # settings inherited from model.generation_config are not applied and
    # do not trigger warnings.
    generated_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_new_tokens=max_new_tokens,
        do_sample=False,
        temperature=None,
        top_p=None,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Drop the prompt tokens; only the newly generated ones are decoded.
    prompt_length = input_ids.shape[-1]
    new_token_ids = generated_ids[:, prompt_length:]
    return tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

In [26]:

model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Tokenizer and model from the same checkpoint. torch_dtype="auto" keeps the
# checkpoint's dtype; device_map='auto' lets transformers place the weights.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map='auto',
    torch_dtype="auto",
)

# NOTE(review): this tokenizer may not define a pad token, in which case this
# assigns None. generate() passes pad_token_id=eos explicitly anyway.
model.generation_config.pad_token_id = tokenizer.pad_token_id


In [31]:
import openai
from tqdm import tqdm
import numpy as np
import os
import random

def format_query_prompt(query, gt, pd):
    """Build the chat messages for the LLM-as-judge correctness check.

    Args:
        query: the question.
        gt: the golden answer. It is interpolated as-is, so a list renders as
            its repr.
        pd: the generated answer to judge. The name shadows pandas' ``pd``
            only inside this function.

    Returns:
        ``[system_message, user_message]``. The system message asks for a
        Yes/No verdict; the user message holds the question, the golden answer
        and the generated answer.
    """
    system_message = {'role': 'system',
                      'content': "You are an evaluation tool. Just answer by {Yes} or {No}."}
    # The original literal contained "\ n" (backslash, space, n): an invalid
    # escape that put a literal "\ n" in the prompt instead of a line break.
    user_message = {'role': 'user',
                    'content': f" Here is a question , a golden answer and an AI - generated answer . Judge whether the AI - generated answer is correct according to the question and golden answer , answer with {{ Yes }} or {{ No }}.\nQuestion : { query }.\nGolden answer : { gt }\nGenerated answer : { pd } Response : {{"}
    return [system_message, user_message]


def run_llm(client, model_name, messages):
    """Send ``messages`` to the chat-completions API and return the first reply's text."""
    completion = client.chat.completions.create(messages=messages, model=model_name)
    first_choice = completion.choices[0]
    return first_choice.message.content


def extract_score_from_text(text):
    """Turn the judge's reply into a binary score.

    Returns 1 if the first standalone word "yes" or "no" (case-insensitive) in
    ``text`` is "yes". Otherwise returns 0, including when neither word appears.
    """
    # The original `"yes" in text.lower()` also fired on words like "eyes" and
    # on replies such as "No, ... not yes".
    verdict = re.search(r"\b(yes|no)\b", text, flags=re.IGNORECASE)
    return 1 if verdict and verdict.group(1).lower() == "yes" else 0

In [32]:
# BERTScore metric loaded through `evaluate`.
# NOTE(review): `model_type` is given to `load`, not to `compute`. Verify that
# `evaluate` honours it here; the compute call in `evaluate_by_dicts` passes
# only lang='en', which may select BERTScore's default English model instead.
bertscore = load(path="bertscore", model_type="bert-base-cased")

def normalize_answer(s):
    """Canonicalise an answer string for SQuAD-style comparison.

    Steps, in order:
      1. lowercase;
      2. drop ASCII punctuation;
      3. replace the articles "a"/"an"/"the" with spaces;
      4. collapse whitespace runs to single spaces and trim.
    """
    punctuation = set(string.punctuation)
    article_pattern = re.compile(r"\b(a|an|the)\b", re.UNICODE)

    lowered = s.lower()
    without_punct = "".join(ch for ch in lowered if ch not in punctuation)
    without_articles = article_pattern.sub(" ", without_punct)
    return " ".join(without_articles.split())

def answer_extractor(potentially_cot: str) -> str:
    """Pull the final answer out of a possibly chain-of-thought style string.

    The extraction proceeds in order:
      1. Surrounding double quotes are removed if the text both starts and ends
         with one.
      2. If the (first line of the) text matches "... answer is[:] X", X is
         returned with one trailing period removed.
      3. Otherwise the (unquoted) text is returned unchanged.
    """
    text = potentially_cot
    if text.startswith('"') and text.endswith('"'):
        text = text[1:-1]

    found = re.match(".* answer is:? (.*)\\.?", text)
    if not found:
        return text

    extracted = found.group(1)
    return extracted[:-1] if extracted.endswith(".") else extracted

def calculate_acc(prediction, ground_truth):
    """Return 1 if any gold answer occurs as a substring of ``prediction``, else 0.

    Callers pass already-normalised strings (see ``normalize_answer``). Note
    that an empty gold string matches every prediction.
    """
    return int(any(gt in prediction for gt in ground_truth))

def calculate_bleu(reference, hypothesis):
    """Sentence-level BLEU of ``hypothesis`` against a single ``reference``.

    Uses NLTK smoothing method 4.

    NOTE: NLTK iterates over the inputs as token sequences, so plain strings
    would be scored character by character. Pass token lists.
    """
    return sentence_bleu(
        [reference],
        hypothesis,
        smoothing_function=SmoothingFunction().method4,
    )

def calculate_rouge(reference, hypothesis):
    """Score ``hypothesis`` against ``reference`` with ROUGE-1 and ROUGE-L (stemming on).

    Returns the ``rouge_score`` result mapping, keyed by ``'rouge1'`` and
    ``'rougeL'``.
    """
    return rouge_scorer.RougeScorer(['rouge1', 'rougeL'], use_stemmer=True).score(reference, hypothesis)


def evaluate_by_dicts(input_path, output_path):
    """Score a JSONL file of generated answers and write the averages as JSON.

    Each non-blank line of ``input_path`` must be a JSON object with:
      * ``'answers'``: a list of gold answer strings;
      * ``'generated_answer'``: a string, or a list whose first element is the
        answer. ``generate_eval`` writes a one-element list.

    Metrics written to ``output_path`` (averaged over records):
      * everything ``SquadAnswerEmF1Metric.get_metric()`` returns;
      * ``acc``: 1 if any normalised gold answer is a substring of the
        normalised prediction;
      * ``rouge1`` / ``rougel``: ROUGE F-measures against all gold answers
        joined with spaces;
      * ``bert_f1``: BERTScore F1 of the raw prediction against the gold answers.

    Raises:
        ValueError: if ``input_path`` contains no records.
    """
    squad_metric = SquadAnswerEmF1Metric()

    total_acc = 0
    total_lines = 0
    total_rouge1 = 0
    total_rougel = 0
    total_bert = 0
    metric_failures = 0

    with open(input_path, 'r', encoding='utf-8') as file:
        for line in file:
            if not line.strip():
                continue  # tolerate blank/trailing lines
            record = json.loads(line)
            ground_truth, prediction = record['answers'], record['generated_answer']

            # Normalise the prediction to a list of strings. The original code
            # first wrapped it in another list, so list-valued predictions
            # (what generate_eval writes) were scored via their repr, e.g.
            # "['Paris']", and the string branch below was unreachable.
            assert isinstance(prediction, (str, list))
            if isinstance(prediction, str):
                stripped = prediction.strip()
                if stripped.startswith("[") or stripped.endswith("]"):
                    prediction = prediction.replace('"', "").replace("[", "").replace("]", "").split(",")
                else:
                    prediction = [prediction]
            prediction = [str(e) for e in prediction] or [""]
            raw_prediction = prediction[0]
            prediction = [answer_extractor(p) for p in prediction]

            normalized_prediction = normalize_answer(prediction[0])
            normalized_ground_truth = [normalize_answer(gt) for gt in ground_truth]

            total_acc += calculate_acc(normalized_prediction, normalized_ground_truth)

            rouge_scores = calculate_rouge(" ".join(normalized_ground_truth), normalized_prediction)
            total_rouge1 += rouge_scores['rouge1'].fmeasure
            total_rougel += rouge_scores['rougeL'].fmeasure

            # One prediction scored against all of its references (nested-list
            # form). With a single gold answer this matches the original call;
            # previously a bare string or several gold answers broke it.
            bert_score = bertscore.compute(references=[ground_truth], predictions=[raw_prediction], lang='en')
            total_bert += bert_score['f1'][0]

            total_lines += 1
            try:
                squad_metric(prediction, ground_truth)
            except Exception as exc:
                # Still best-effort, but no longer silent.
                metric_failures += 1
                if metric_failures == 1:
                    print(f"[evaluate_by_dicts] SquadAnswerEmF1Metric failed: {exc!r}")

    if total_lines == 0:
        raise ValueError(f"No records found in {input_path!r}")
    if metric_failures:
        print(f"[evaluate_by_dicts] SquadAnswerEmF1Metric failed on {metric_failures}/{total_lines} records")

    evaluation_results = squad_metric.get_metric()
    evaluation_results['acc'] = total_acc / total_lines
    evaluation_results['rouge1'] = total_rouge1 / total_lines
    evaluation_results['rougel'] = total_rougel / total_lines
    evaluation_results['bert_f1'] = total_bert / total_lines

    save_results(evaluation_results, output_path)
    
def save_results(results_dict, output_path):
    """Write ``results_dict`` to ``output_path`` as JSON indented by 4 spaces (overwrites)."""
    serialized = json.dumps(results_dict, indent=4)
    with open(output_path, "w") as out_file:
        out_file.write(serialized)


In [33]:
def clean_output(output):
    """Return the LLM's answer from an output expected to contain {"Answer": ...}.

    The first JSON value starting at the first '{' is decoded. This tolerates
    leading text or code fences and '}' characters inside the answer string.
    The value under 'Answer' is returned via ``str()``.

    If no such object can be decoded, returns ``'JSON parse error(<text>)'``.
    Here <text> is the span from the first '{' (or the start) up to and
    including the next '}' (or the end).
    """
    start = output.find('{')
    if start == -1:
        start = 0
    try:
        parsed, _ = json.JSONDecoder().raw_decode(output, start)
        return str(parsed['Answer'])
    except (ValueError, KeyError, TypeError):
        # ValueError: not JSON; KeyError: no 'Answer'; TypeError: not a dict.
        end = output.find('}', start)
        json_text = output[start:] if end == -1 else output[start:end + 1]
        return f'JSON parse error({json_text.strip()})'

In [34]:
from random import sample
import random
random.seed(0)
def generate_eval(top_docs, out_path, eval_path, conv=data_dict, topk=16, n_gen = -1, doc_prefix=None, seed=0):
    """Generate answers from the top-k retrieved passages per query and score them.

    Args:
        top_docs: ``{qid: {docid: value}}`` (as from ``get_docs``). The first
            ``topk`` doc ids of each query, in insertion order, are used.
        out_path: JSONL file that receives one record per query, with keys
            ``question``, ``answers`` and ``generated_answer``.
        eval_path: JSON metrics file written by ``evaluate_by_dicts``.
        conv: mapping qid -> record with 'query' and 'answer' keys.
        topk: number of passages placed in each prompt.
        n_gen: if > 0, evaluate a random sample of at most ``n_gen`` queries;
            -1 (default) evaluates all of them.
        doc_prefix: string prepended to each doc id before the index lookup.
            None uses a notebook-level ``qrels_prefix`` if one exists,
            otherwise no prefix.
        seed: RNG seed for ``n_gen`` sampling. Query ids are sorted first, so
            different runs over the same queries get the same subset.

    Returns:
        The metrics dict read back from ``eval_path``.
    """
    if doc_prefix is None:
        # The original referenced `qrels_prefix`, which no visible cell defines.
        # The resulting NameError was swallowed and raw doc ids were sent to the
        # model as context. Fall back to no prefix when it is absent.
        doc_prefix = globals().get("qrels_prefix", "")

    qids = list(top_docs.keys())
    if n_gen > 0:
        qids = random.Random(seed).sample(sorted(qids), min(n_gen, len(qids)))

    result = []
    for qid in tqdm(qids):
        passages = []
        for doc_id in list(top_docs[qid].keys())[:topk]:
            doc = searcher.doc(doc_prefix + doc_id)
            raw = doc.raw() if doc is not None else None
            if raw is None:
                # Skip (and report) missing docs instead of putting ids in the prompt.
                print(f"[generate_eval] doc {doc_prefix + doc_id!r} not found for query {qid!r}")
                continue
            passages.append(json.loads(raw)['contents'])

        query = conv[qid]['query']
        ans = conv[qid]['answer']

        output = generate(format_prompt(query, '\n'.join(passages)))
        # Stored as a one-element list: evaluate_by_dicts and llm_eval read this shape.
        generated_answer = [clean_output(output)]

        result.append({'question': query, 'answers': [ans], 'generated_answer': generated_answer})

    with open(out_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(json.dumps(x, ensure_ascii=False) for x in result))
    evaluate_by_dicts(out_path, eval_path)
    with open(eval_path, 'r', encoding='utf-8') as f:
        return json.load(f)

In [60]:
# Judge client. The key comes from the OPENAI_API_KEY environment variable
# instead of being written into the notebook (the original hard-coded "").
client = openai.OpenAI(api_key=os.environ.get("OPENAI_API_KEY"))
model_name = "chatgpt-4o-latest"  # alternative judge also tried: "gpt-4o-mini-2024-07-18"

def llm_eval(input_path):
    """Have the LLM judge each generated answer in a JSONL file.

    Each line must hold 'question', 'answers' and 'generated_answer'. One judge
    prompt is sent per line, using the notebook-level ``client`` and
    ``model_name``.

    Returns:
        The list of 0/1 scores, in file order.
    """
    judgements = []
    with open(input_path, 'r', encoding='utf-8') as eval_file:
        for raw_line in tqdm(eval_file):
            record = json.loads(raw_line)
            judge_messages = format_query_prompt(
                record['question'], record['answers'], record['generated_answer']
            )
            verdict = run_llm(client, model_name, judge_messages)
            judgements.append(extract_score_from_text(verdict))
    return judgements
    

In [37]:

# Output directory for generated answers and their metric files. The original
# value lacked a trailing separator, so files landed next to the directory as
# e.g. ".../topiocqatest_ours.json".
path = "/data2/../nlp_data/generation/topiocqa/"
os.makedirs(path, exist_ok=True)

GEN_TOPK = 4  # passages per prompt

# method name -> retrieval run. The name is also used in the output file names.
method_runs = {
    "ours": runs,
    "llmcs": runs_llmcs,
    "retpo": runs_retpo,
    "hyde": runs_hyde,
    "convgqr": runs_convgqr,
    "t5qr": runs_t5qr,
    "infocqr": runs_infocqr,
    "hydellm": runs_hydellm,
}

gen_evals = {}
for method, method_run in method_runs.items():
    gen_evals[method] = generate_eval(
        top_docs=method_run,
        out_path=path + f"test_{method}.json",
        eval_path=path + f"test_{method}_out.json",
        topk=GEN_TOPK,
    )
    print(method, gen_evals[method])

# Keep the per-method names used elsewhere in the notebook.
ours_eval = gen_evals["ours"]
llmcs_eval = gen_evals["llmcs"]
retpo_eval = gen_evals["retpo"]
hyde_eval = gen_evals["hyde"]
convgqr_eval = gen_evals["convgqr"]
t5qr_eval = gen_evals["t5qr"]
infocqr_eval = gen_evals["infocqr"]
hydellm_eval = gen_evals["hydellm"]

In [64]:
# LLM-as-judge accuracy (% judged correct) per method, in the original order.
llm_scores = {}
for method in ("ours", "llmcs", "retpo", "hyde", "convgqr", "t5qr", "infocqr", "hydellm"):
    llm_scores[method] = llm_eval(input_path=path + f"test_{method}.json")
    print(np.mean(llm_scores[method]) * 100)

llm_ours = llm_scores["ours"]
llm_llmcs = llm_scores["llmcs"]
llm_retpo = llm_scores["retpo"]
llm_hyde = llm_scores["hyde"]
llm_convgqr = llm_scores["convgqr"]
llm_t5qr = llm_scores["t5qr"]
llm_infocqr = llm_scores["infocqr"]
llm_hydellm = llm_scores["hydellm"]
