In [None]:
import os

# Point this process (and any child processes it spawns) at the OpenJDK 11
# installation. Presumably needed by a JVM-backed dependency used later in the
# notebook -- TODO confirm which one.
os.environ.update(JAVA_HOME="/usr/lib/jvm/java-11-openjdk-amd64")

### topi 13k turns!

In [None]:
# from shared_utils.indexing_utils import SparseIndexer, DocumentCollection
import json
import jsonlines
from tqdm import tqdm
from copy import deepcopy
import io
import argparse
from statistics import mean, stdev

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pytrec_eval


In [None]:
# Parse the JSON file at `path`. The file handle is closed as soon as parsing
# finishes instead of being left open for the lifetime of the kernel.
path = "/data/../nlp_data/LongAlpaca-12k/LongAlpaca-12k.json"
with open(path, "r", encoding="utf-8") as json_file:
    lines = json.load(json_file)

# Keys of the first record; the trailing expression displays the key lists of
# the first and last records side by side as the cell output.
attr_required = list(lines[0].keys())
list(lines[0].keys()), list(lines[-1].keys())

In [None]:
# ---------------------------------------------------------------------------
# Instruction fragments. The full instruction for a variant is the base
# sentence plus zero or more of the suffixes below (see _select_instruction).
# ---------------------------------------------------------------------------
_INSTRUCTION_WITH_DOCS = "Given a question, its previous questions (Q), retrieved documents (Document), and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context."
_INSTRUCTION_NO_DOCS = "Given a question, its previous questions (Q) and answers (A), decontextualize the question by addressing coreference and omission issues. The resulting question should retain its original meaning and be as informative as possible, and should not duplicate any previously asked questions in the context."
_FILTER_HINT = " Use the documents to enrich your question if they're relevant, or draw on the Q&A context for a precise reformulation if the documents aren't helpful."
# The summary hint is introduced by " Given" or " Considering" depending on the variant.
_SUMMARY_TAIL = " the potential noise and dependencies within the context, creating a concise summary of it first could be an effective strategy for accurately rephrasing the question. Therefore, start by summarizing the context before you decontextualize the question."
_REASONING_HINT = " Before rewriting, evaluate which parts of the context are essential to address, helping to rewrite your question effectively."

# ---------------------------------------------------------------------------
# Few-shot demonstrations (four per setting) and the per-example rationales
# that are spliced in front of the rewrite for the summary / reasoning variants.
# ---------------------------------------------------------------------------
_EXAMPLES_WITH_DOCS = (
    "Context: [Q: When was Born to Fly released? Document: Born to Fly is a song co-written and recorded by American country music artist Sara Evans. It was released in June 2000 as the first single and title track from her 2000 album of the same name. A: Sara Evans's third studio album, Born to Fly, was released on October 10, 2000.] \nQuestion: Was Born to Fly well received by critics?\nRewrite: Was Born to Fly well received by critics?",
    "Context: [Q: When was Keith Carradine born? Document: No relevant documents. A: Keith Ian Carradine was born August 8, 1949. Q: Is he married? Document: Carradine married Sandra Will on February 6, 1982. They were separated in 1993, before Will filed for divorce in 1999. The couple had two children: Cade Richmond Carradine (born July 19, 1982) and Sorel Johannah Carradine (born June 18, 1985). A: Keith Carradine married Sandra Will on February 6, 1982.]\nQuestion: Do they have any children?\nRewrite: Do Keith Carradine and Sandra Will have any children?",
    "Context: [Q: Who proposed that atoms are the basic units of matter? Document: Arguably the most important of all Dalton's investigations are concerned with the atomic theory in chemistry. While his name is inseparably associated with this theory, the origin of Dalton's atomic theory is not fully understood. The theory may have been suggested to him either by researches on ethylene (olefiant gas) and methane (carburetted hydrogen) or by analysis of nitrous oxide (protoxide of azote) and nitrogen dioxide (deutoxide of azote), both views resting on the authority of Thomas Thomson. A: John Dalton proposed that each chemical element is composed of atoms of a single, unique type, and they can combine to form more complex structures called chemical compounds.] \nQuestion: How did the proposal come about?\nRewrite: How did John Dalton's proposal that each chemical element is composed of atoms of a single unique type, and they can combine to form more complex structures called chemical compounds come about?",
    "Context: [Q: What is it called when two liquids separate? Document: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension. The layer closer to the top of the container—the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out—is poured off, leaving denser liquid or the solid behind. The process typically is unable to remove all of the top layer, meaning the separation is incomplete or at least one of the two separated components is still contaminated by the other one. A: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension.  Q: How does the separation occur?  Document: No relevant documents.  A: The layer closer to the top of the container-the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out-is poured off.]\nQuestion: Then what happens?\nRewrite: Then what happens after the layer closer to the top of the container is poured off with decantation?",
)

_EXAMPLES_NO_DOCS = (
    "Context: [Q: When was Born to Fly released? A: Sara Evans's third studio album, Born to Fly, was released on October 10, 2000.]\nQuestion: Was Born to Fly well received by critics?\nRewrite: Was Born to Fly well received by critics?",
    "Context: [Q: When was Keith Carradine born? A: Keith Ian Carradine was born August 8, 1949. Q: Is he married? A: Keith Carradine married Sandra Will on February 6, 1982.]\nQuestion: Do they have any children?\nRewrite: Do Keith Carradine and Sandra Will have any children?",
    "Context: [Q: Who proposed that atoms are the basic units of matter? A: John Dalton proposed that each chemical element is composed of atoms of a single, unique type, and they can combine to form more complex structures called chemical compounds.]\nQuestion: How did the proposal come about?\nRewrite: How did John Dalton's proposal that each chemical element is composed of atoms of a single unique type, and they can combine to form more complex structures called chemical compounds come about?",
    "Context: [Q: What is it called when two liquids separate? A: Decantation is a process for the separation of mixtures of immiscible liquids or of a liquid and a solid mixture such as a suspension. Q: How does the separation occur? A: The layer closer to the top of the container-the less dense of the two liquids, or the liquid from which the precipitate or sediment has settled out-is poured off.]\nQuestion: Then what happens?\nRewrite: Then what happens after the layer closer to the top of the container is poured off with decantation?",
)

_RATIONALES_WITH_DOCS = {
    'summary': (
        "TLDR Summary: Born to Fly is both a song and the title of Sara Evans's third studio album. The song was released as the album's first single in June 2000, and the album itself was released on October 10, 2000.",
        "TLDR Summary: Keith Ian Carradine, born on August 8, 1949, married Sandra Will on February 6, 1982. They separated in 1993, and Sandra Will filed for divorce in 1999. The couple has two children, Cade Richmond Carradine and Sorel Johannah Carradine.",
        "TLDR Summary: John Dalton proposed the atomic theory, which posits that atoms are the fundamental units of matter, with each chemical element being composed of unique atoms that can combine to form complex compounds. The exact inspiration for Dalton's theory is unclear, but it might have stemmed from his research on gases or the analysis of nitrous oxide and nitrogen dioxide, possibly influenced by Thomas Thomson.",
        "TLDR Summary: The context explains decantation, a separation process for mixtures of immiscible liquids or liquid-solid mixtures like suspensions. It involves pouring off the top, less dense liquid or the liquid cleared of sediment, leaving behind the denser liquid or solid. The process may not completely remove the top layer, potentially leaving some contamination.",
    ),
    'reasoning': (
        "The question is already clear.",
        "The original question uses the pronoun \"they\" which is ambiguous without explicit context. By specifying \"Keith Carradine and Sandra Will\" as the subjects, the revised question eliminates any ambiguity about who \"they\" refers to, directly connecting the inquiry to the individuals mentioned in the previous context.",
        "The original question omits what the proposal actually is. Including the specific details of Dalton's atomic theory (that each chemical element is composed of atoms of a single unique type, and they can combine to form more complex structures called chemical compounds) directly in the question adds necessary context and allows the question to stand alone, making it understandable even without prior knowledge of the conversation.",
        "The context revolves around decantation, a specific scientific process. Recognizing this as the core topic ensures that the rewrite focuses on the next logical step in this particular procedure. Question: Then what happens? is vague without specifying what it refers to. By identifying that it refers to the action of pouring off the top layer in the decantation process, we address coreference issues, making it clear what the 'then' is referring to.",
    ),
}

_RATIONALES_NO_DOCS = {
    'summary': (
        "TLDR Summary: Inquiry about the release date of Sara Evans's album \"Born to Fly,\" which was on October 10, 2000.",
        "TLDR Summary: Inquiry about Keith Carradine's birth date, which is August 8, 1949, and marital status, revealing he married Sandra Will on February 6, 1982.",
        "TLDR Summary: John Dalton proposed atoms as the basic units of matter, which can combine to form chemical compounds.",
        "TLDR Summary: Decantation separates mixtures of immiscible liquids or liquids and solids by pouring off the top layer after settling.",
    ),
    'reasoning': (
        "The question is already clear.",
        "The question \"Do they have any children?\" is ambiguous without directly referencing who \"they\" are. By naming \"Keith Carradine and Sandra Will\" explicitly, we eliminate any ambiguity regarding who the question is about.",
        "The question \"How did the proposal come about?\" is vague because it doesn't specify which proposal it's referring to. By restating that the proposal is about \"each chemical element being composed of atoms of a single, unique type, and they can combine to form more complex structures called chemical compounds,\" we make the question self-contained.",
        "The question \"Then what happens?\" is vague without specifying which process it refers to. By stating \"after the layer closer to the top of the container is poured off,\" the question explicitly refers to the action that was previously described, making it clear which stage of the process we're inquiring about what happens next.",
    ),
}


def _select_instruction(use_pssg, variant):
    """Return the task instruction for `variant`, with or without the document wording."""
    if use_pssg:
        base = _INSTRUCTION_WITH_DOCS
        suffixes = {
            'original': "",
            'filter_irrelevant': _FILTER_HINT,
            'summary': " Given" + _SUMMARY_TAIL,
            'filter_irrelevant_summary': _FILTER_HINT + " Considering" + _SUMMARY_TAIL,
            'reasoning': _FILTER_HINT + _REASONING_HINT,
        }
    else:
        # Without documents the "filter" hint does not apply, so
        # 'filter_irrelevant' is identical to 'original' in this setting.
        base = _INSTRUCTION_NO_DOCS
        suffixes = {
            'original': "",
            'filter_irrelevant': "",
            'summary': " Given" + _SUMMARY_TAIL,
            'filter_irrelevant_summary': " Considering" + _SUMMARY_TAIL,
            'reasoning': _REASONING_HINT,
        }
    if variant not in suffixes:
        raise ValueError(
            f"Unknown instruct_pssg {variant!r}; expected one of {sorted(suffixes)}"
        )
    return base + suffixes[variant]


def _with_rationale(example, rationale):
    """Insert `rationale` after 'Rewrite: ' and wrap the example's rewrite in double quotes."""
    head = example.split('Rewrite:')[0]
    rewrite = example.split('Rewrite: ')[-1]
    return f'{head}Rewrite: {rationale} The rewritten query is "{rewrite}"'


def _few_shot_examples(use_pssg, variant):
    """Return the four demonstrations for this setting, with rationales where the variant uses them."""
    examples = _EXAMPLES_WITH_DOCS if use_pssg else _EXAMPLES_NO_DOCS
    rationales = _RATIONALES_WITH_DOCS if use_pssg else _RATIONALES_NO_DOCS
    if variant in ('summary', 'filter_irrelevant_summary'):
        notes = rationales['summary']
    elif variant == 'reasoning':
        notes = rationales['reasoning']
    else:
        # 'original' / 'filter_irrelevant' use the plain demonstrations.
        return list(examples)
    return [_with_rationale(ex, note) for ex, note in zip(examples, notes)]


def _format_context(line, use_pssg, n_recent):
    """Render the conversation history as the bracketed 'Q: ... A: ...' context string."""
    parts = []
    if use_pssg:
        # line['NewContext'] is read as alternating question/answer sentences
        # (even index = question, odd index = answer).
        n_prev_turns = len(line['NewContext']) // 2
        # Index of the first question that gets a real document; earlier
        # questions get the "No relevant documents." placeholder.
        first_doc_idx = max(n_prev_turns - n_recent, 0) * 2
        recent_docs = [f"Document: {d}." for d in line['Truth_passages_contents'][-n_recent:]]
        # NOTE(review): this assumes 'Truth_passages_contents' holds one entry
        # per previous turn; if it has fewer entries than recent questions the
        # lookup below raises IndexError -- confirm against the data files.
        doc_cursor = 0
        for idx, sent in enumerate(line['NewContext']):
            if idx % 2 == 0:
                parts.append(f"Q: {sent}")
                if idx >= first_doc_idx:
                    parts.append(recent_docs[doc_cursor])
                    doc_cursor += 1
                else:
                    parts.append("Document: No relevant documents.")
            else:
                parts.append(f"A: {sent}")
    else:
        # zip() stops at the shorter list, so an unpaired trailing question or
        # answer is dropped.
        for question, answer in zip(line["history_query"], line["history_answer"]):
            parts.append(f"Q: {question}")
            parts.append(f"A: {answer}")
    return f"[{' '.join(parts)}]"


def set_prompt(line, args, n_recent=3):
    """Build the query-rewriting prompt for one conversation turn.

    Parameters
    ----------
    line : mapping
        One record. Keys read here: 'NewContext' and 'Truth_passages_contents'
        when ``args.use_pssg`` is truthy, otherwise 'history_query' and
        'history_answer'; plus 'query' for the ICL prompt or 'Question' for
        the ZSL prompt.
    args : object
        Must expose ``use_pssg`` (truthy to interleave documents into the
        context), ``instruct_pssg`` (one of 'original', 'filter_irrelevant',
        'summary', 'filter_irrelevant_summary', 'reasoning') and
        ``prompt_type`` ('icl' for four few-shot demonstrations, 'zsl' for
        none).
    n_recent : int, default 3
        With documents enabled, only the last ``n_recent`` previous questions
        are paired with a document; older ones get a placeholder.

    Returns
    -------
    str
        The prompt, ending in ``"Rewrite: "``.

    Raises
    ------
    ValueError
        If ``args.instruct_pssg`` or ``args.prompt_type`` is not a supported
        value. (Previously this surfaced as an UnboundLocalError.)
    """
    if args.prompt_type not in ("icl", "zsl"):
        raise ValueError(
            f"Unknown prompt_type {args.prompt_type!r}; expected 'icl' or 'zsl'"
        )

    instruction = _select_instruction(args.use_pssg, args.instruct_pssg)
    curr_ctx = _format_context(line, args.use_pssg, n_recent)

    # NOTE(review): the ICL prompt reads line['query'] while the ZSL prompt
    # reads line['Question']. Both keys are kept as-is; confirm whether the
    # records really carry both or whether one of them is a leftover.
    if args.prompt_type == "icl":
        demos = "\n\n".join(_few_shot_examples(args.use_pssg, args.instruct_pssg))
        return f"{instruction}\n\n{demos}\n\nContext: {curr_ctx}\nQuestion: {line['query']}\nRewrite: "

    return f"{instruction}\n\nContext: {curr_ctx}\nQuestion: {line['Question']}\nRewrite: "

In [None]:

def print_res(run_file, qrel_data, rel_threshold, return_summary=True):
    """Score a run file against relevance judgements with pytrec_eval.

    Despite the name, nothing is printed; the results are returned.

    Parameters
    ----------
    run_file : str
        Path to a text file with one whitespace-separated record per line.
        Column 0 is read as the query id, column 2 as the passage id and
        column 4 as the run score (parsed with ``int``).
    qrel_data : iterable of str
        Judgement lines, whitespace-separated. Column 0 is the query id,
        column 2 the passage id and column 3 the integer relevance grade.
    rel_threshold : int
        Grades ``>= rel_threshold`` count as relevant (1) and everything else
        as non-relevant (0) for MAP / MRR / Recall. NDCG uses the raw grades.
    return_summary : bool, default True
        If True, return a dict of averaged metrics as percentages rounded to
        two decimals. If False, return the per-query pytrec_eval results with
        the ``ndcg_cut_3`` entry merged into each query's dict.

    Returns
    -------
    dict
    """
    with open(run_file, 'r') as f:
        run_data = f.readlines()

    qrels = {}       # binarised judgements for MAP / MRR / Recall
    qrels_ndcg = {}  # graded judgements for NDCG
    for raw in qrel_data:
        fields = raw.split()
        if not fields:
            # Skip blank lines instead of failing with IndexError.
            continue
        query, passage, rel = fields[0], fields[2], int(fields[3])
        qrels_ndcg.setdefault(query, {})[passage] = rel
        qrels.setdefault(query, {})[passage] = 1 if rel >= rel_threshold else 0

    runs = {}
    for raw in run_data:
        # Split on any whitespace (same as the qrel lines above) so tab- or
        # multi-space-separated runs and trailing blank lines are handled.
        fields = raw.split()
        if not fields:
            continue
        # NOTE(review): the score column is parsed with int(); a run file with
        # fractional scores would raise ValueError here -- confirm that the
        # run writer emits integer scores.
        runs.setdefault(fields[0], {})[fields[2]] = int(fields[4])

    recall_cutoffs = (1, 3, 5, 10, 20, 30, 100)
    measures = {"map", "recip_rank"} | {f"recall.{k}" for k in recall_cutoffs}
    res = pytrec_eval.RelevanceEvaluator(qrels, measures).evaluate(runs)
    res_ndcg = pytrec_eval.RelevanceEvaluator(qrels_ndcg, {"ndcg_cut.3"}).evaluate(runs)

    if not return_summary:
        # Per-query results: fold the NDCG entry into each query's metric dict.
        for query_id in res.keys():
            res[query_id].update(res_ndcg[query_id])
        return res

    def _mean_pct(per_query, key):
        # Mean over queries, expressed as a percentage with two decimals.
        return round(100 * np.average([v[key] for v in per_query.values()]), 2)

    # pytrec_eval reports "recall.K" under the result key "recall_K".
    res_summary = {
        "MRR": _mean_pct(res, "recip_rank"),
        "NDCG@3": _mean_pct(res_ndcg, "ndcg_cut_3"),
    }
    for k in recall_cutoffs:
        res_summary[f"Recall@{k}"] = _mean_pct(res, f"recall_{k}")
    return res_summary


def print_res_pseudo_qrels(run_file, pseudo_qrels, rel_threshold, return_summary=True):
    """Evaluate a TREC-style run file against an in-memory (pseudo) qrels dict.

    Args:
        run_file: Path to a whitespace-separated run file. Columns 0, 2 and 4 of
            each line are read as query id, passage id and retrieval score.
        pseudo_qrels: ``{query_id: {passage_id: relevance}}``; handed unchanged to
            ``pytrec_eval.RelevanceEvaluator``.
        rel_threshold: Not used by this function; kept so existing positional
            calls keep working.
        return_summary: When True, return corpus-level averages (percentages
            rounded to two decimals); otherwise return the per-query metrics.

    Returns:
        dict: either the summary ``{"MRR": ..., "NDCG@3": ..., "Recall@k": ...}``
        or ``{query_id: {measure: value}}`` holding map, recip_rank,
        recall_{1,3,5,10,20,30,100} and ndcg_cut_3.
    """
    runs = {}
    with open(run_file, 'r') as f:
        for raw_line in f:
            fields = raw_line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            query, passage = fields[0], fields[2]
            # float() accepts the integer scores int() did, plus real-valued ones.
            runs.setdefault(query, {})[passage] = float(fields[4])

    # The same qrels serve every measure here, so a single evaluator pass is
    # enough; it yields the same per-query dict the two separate passes produced.
    evaluator = pytrec_eval.RelevanceEvaluator(
        pseudo_qrels,
        {"map", "recip_rank", "recall.1", "recall.3", "recall.5", "recall.10",
         "recall.20", "recall.30", "recall.100", "ndcg_cut.3"},
    )
    res = evaluator.evaluate(runs)

    if not return_summary:
        return res

    def _avg_pct(measure):
        # Mean of one measure over all evaluated queries, as a percentage.
        return round(100*np.average([v[measure] for v in res.values()]), 2)

    return {
        "MRR": _avg_pct("recip_rank"),
        "NDCG@3": _avg_pct("ndcg_cut_3"),
        "Recall@1": _avg_pct("recall_1"),
        "Recall@3": _avg_pct("recall_3"),
        "Recall@5": _avg_pct("recall_5"),
        "Recall@10": _avg_pct("recall_10"),
        "Recall@20": _avg_pct("recall_20"),
        "Recall@30": _avg_pct("recall_30"),
        "Recall@100": _avg_pct("recall_100"),
    }
    

In [None]:
# Gold relevance judgments, kept as raw text lines for print_res.
qrel_file = "/data/../nlp_data/topiocqa/train_gold.trec"
with open(qrel_file, 'r') as f:
    qrel_data = list(f)

# Settings that select which run / prediction files are read further down.
rel_threshold = 1
p_type = "icl"
inst_pssg = "original"
seed = "0"
temp = "8"   # appears as `temp{temp}` in the sampled-prediction file names
topp = "8"   # appears as `p{topp}` in the sampled-prediction file names
eval_type = "oracle"

# Location of the BM25 run files and the run file for the default seed.
run_file_dir = "/data2/../nlp_data/convgqr/bm25/chatgpt/"
run_file = run_file_dir + f"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{seed}_{eval_type}.trec"



In [None]:

# Evaluate every candidate run (seeds 0-11 and 15-17) against the gold qrels,
# keeping the per-query metric dicts: all_res = [res_seed0, ..., res_seed17].
all_res = []
for s in list(range(12)) + list(range(15, 18)):
    fname = f"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec"
    run_file = run_file_dir + fname
    res = print_res(run_file, qrel_data, rel_threshold, return_summary=False)
    all_res.append(res)

# Oracle selection: for each conversation turn keep the metrics of the candidate
# whose mean over all of its metrics is highest (first one wins on ties).
best_res_dict = {}
conv_q_ids = list(all_res[0].keys())
for conv_q_i in conv_q_ids:
    res_list = [run_res[conv_q_i] for run_res in all_res]
    avg_scores = [sum(d.values()) / len(d) for d in res_list]
    index_of_highest_avg = avg_scores.index(max(avg_scores))
    dict_with_highest_avg = res_list[index_of_highest_avg]
    best_res_dict[conv_q_i] = dict_with_highest_avg

# Collect each measure across turns.
metrics = best_res_dict
per_turn = list(metrics.values())
map_list = [m['map'] for m in per_turn]
mrr_list = [m['recip_rank'] for m in per_turn]
recall_1_list = [m["recall_1"] for m in per_turn]
recall_3_list = [m["recall_3"] for m in per_turn]
recall_5_list = [m["recall_5"] for m in per_turn]
recall_10_list = [m["recall_10"] for m in per_turn]
recall_20_list = [m["recall_20"] for m in per_turn]
recall_30_list = [m["recall_30"] for m in per_turn]
recall_100_list = [m["recall_100"] for m in per_turn]
ndcg_3_list = [m['ndcg_cut_3'] for m in per_turn]

np.set_printoptions(precision=4)

# Corpus-level averages, reported as percentages with two decimals.
eval_metrics = {
    label: round(100*np.average(values), 2)
    for label, values in [
        ("MAP", map_list),
        ("MRR", mrr_list),
        ("NDCG@3", ndcg_3_list),
        ("Recall@1", recall_1_list),
        ("Recall@3", recall_3_list),
        ("Recall@5", recall_5_list),
        ("Recall@10", recall_10_list),
        ("Recall@20", recall_20_list),
        ("Recall@30", recall_30_list),
        ("Recall@100", recall_100_list),
    ]
}
eval_metrics

In [None]:
# Per conversation-turn statistics over all candidate runs in all_res
# (a ground-truth run could be appended here, e.g. all_res + [gt_res]).
stat_res_dict = {}
detailed_res_dict = {}
conv_q_ids = list(all_res[0].keys())
for conv_q_i in conv_q_ids:
    # Metric dicts of this turn, one per candidate run.
    res_list = [run_res[conv_q_i] for run_res in all_res]

    # One number per candidate: the mean over all of its metrics.
    values_by_d = [mean(list(candidate.values())) for candidate in res_list]

    # Mean and standard deviation across candidates (std is 0 for a single candidate).
    avg = mean(values_by_d)
    if len(values_by_d) > 1:
        std = stdev(values_by_d)
    else:
        std = 0
    stat_res_dict[conv_q_i] = (avg, std)

    # Keep the full per-candidate metrics for later inspection.
    detailed_res_dict[conv_q_i] = res_list
    
# stat_res_dict




In [None]:
# Step 1 & 2: for every turn, the mean over all metrics of each candidate run.
avg_scores = {
    key: [sum(candidate.values()) / len(candidate) for candidate in values]
    for key, values in detailed_res_dict.items()
}

# Step 3: for every turn, bucket the candidate indices that share the same
# average score; buckets are ordered from the highest score to the lowest.
grouped_by_avg_score = {}
for key, avgs in avg_scores.items():
    groups_by_avg = {}
    for i, avg in enumerate(avgs):
        groups_by_avg.setdefault(avg, []).append(i)
    grouped_by_avg_score[key] = dict(
        sorted(groups_by_avg.items(), key=lambda item: -item[0])
    )

In [None]:
# Rewritten queries per conversation turn, one entry per candidate file.

pred_file_dir = "/data2/../nlp_data/infocqr_data/topiocqa/"

# Candidate files in index order: seeds 0-11 first, then the `_originalQ` files
# for seeds 15-17.
pred_fnames = [
    f"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_temp{temp}_p{topp}_sampled.jsonl"
    for s in range(12)
] + [
    f"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_originalQ_seed{s}_temp{temp}_p{topp}_sampled.jsonl"
    for s in range(15, 18)
]

# all_pred_data[candidate index] -> list of parsed JSONL records
all_pred_data = {}
pred_i = 0
for fname in pred_fnames:
    pred_file = pred_file_dir + fname
    with open(pred_file, "r") as f:
        data = [json.loads(raw) for raw in f]
    all_pred_data[pred_i] = data
    pred_i += 1

# all_proc_preds[candidate index]["<conv_id>-<turn_id>"] -> {'pred_query': ...}
all_proc_preds = {}
for i, data in all_pred_data.items():
    all_proc_preds[i] = {
        f"{dt['conv_id']}-{dt['turn_id']}": {'pred_query': dt['oracle_utt_text']}
        for dt in tqdm(data)
    }
    

In [None]:
import torch

# Answer-based BM25 retrieval results. Later cells use this object as a mapping
# from conv-turn id to a sequence of hits that each carry an 'id' field.
# NOTE(review): torch.load unpickles the file; only load trusted artifacts.
back_retrieval_answer_path = "/data2/../nlp_data/llm_qr/outputs/BM25/topi_back_retrieval_answer"
back_retrieval_answer_bm25 = torch.load(back_retrieval_answer_path)


In [None]:
# Sparse (BM25) predictions: for every candidate run, the retrieved passage ids
# of each query in file order.

run_file_dir = "/data2/../nlp_data/convgqr/bm25/chatgpt/"
all_results_cands_bm25 = []  # one {query id: [passage id, ...]} dict per candidate run

for s in list(range(12)) + list(range(15, 18)):
    fname = f"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec"
    run_file = run_file_dir + fname
    with open(run_file, 'r') as f:
        run_data = f.readlines()

    runs = {}
    for line in run_data:
        line = line.split(" ")
        query, passage = line[0], line[2]
        rel = int(line[4])  # parsed but not used further in this cell
        runs.setdefault(query, []).append(int(passage))

    all_results_cands_bm25.append(runs)


In [None]:
# Build pseudo relevance judgments, one qrels dict per cut-off value.
# For every turn present in back_retrieval_answer_bm25, the first `cut_pseudo`
# distinct passage ids among (at most) its first 100 hits are marked relevant (1).
pseudo_qrels_dict = {}

# Cut-off values to build pseudo qrels for (currently only 3).
for cut_pseudo in tqdm([3,]):
    pseudo_qrels = {}

    # Iterate the keys directly: rebuilding list(keys()) to index it on every
    # iteration made this loop quadratic in the number of turns.
    for qid in grouped_by_avg_score:
        if qid not in back_retrieval_answer_bm25:
            continue

        hits = back_retrieval_answer_bm25[qid]
        n_iters = min(len(hits), 100)
        # A dict de-duplicates ids while keeping first-seen order.
        # NOTE(review): the run parser keeps passage ids as str; confirm that
        # hit['id'] is a str too, otherwise the ids will not match at evaluation.
        pseudo_gold = {}
        for i in range(n_iters):
            pseudo_gold[hits[i]['id']] = 1
            if len(pseudo_gold) >= cut_pseudo:
                break

        pseudo_qrels[qid] = pseudo_gold

    # Store the result for the current cut_pseudo value
    pseudo_qrels_dict[cut_pseudo] = pseudo_qrels

# Report how many turns received pseudo qrels for each cut-off.
for cut_pseudo, qrels in pseudo_qrels_dict.items():
    print(f"cut_pseudo = {cut_pseudo}, number of qrels: {len(qrels)}")

In [None]:
# Per cut-off value: the per-query metrics of every candidate run, scored
# against the pseudo qrels built for that cut-off.
all_res_pseudo_dict = {}

for cut_pseudo, pseudo_qrels in tqdm(pseudo_qrels_dict.items()):
    all_res_pseudo = []

    # Same candidate order as everywhere else: seeds 0-11, then 15-17.
    for s in list(range(12)) + list(range(15, 18)):
        fname = f"train_chatgpt_{p_type}_WOpssg_{inst_pssg}_seed{s}_{eval_type}.trec"
        run_file = run_file_dir + fname
        res = print_res_pseudo_qrels(run_file, pseudo_qrels, rel_threshold, return_summary=False)
        all_res_pseudo += [res]

    all_res_pseudo_dict[cut_pseudo] = all_res_pseudo

# Sanity check: (number of candidate runs, number of evaluated queries) per cut-off.
for cut_pseudo, results in all_res_pseudo_dict.items():
    print(f"cut_pseudo = {cut_pseudo}, results: {len(results), len(results[0])}")


In [None]:
   
# Select, for every turn, the candidate that scores best against the pseudo
# qrels, then report how those selections score against the gold qrels
# (all_res = [res_candidate0, res_candidate1, ...]).
best_res_dict = {}
pseudo_avg_scores = {}
conv_q_ids = list(all_res_pseudo[0].keys())
indice_of_jaccard_highest = {}
for conv_q_i in conv_q_ids:
    # Gold-qrels metrics of this turn, one dict per candidate.
    res_list = [run_res[conv_q_i] for run_res in all_res]

    # Pseudo-qrels metrics of this turn, per cut-off and candidate.
    pseudo_res_lists_dict = {
        cut: [run_res[conv_q_i] for run_res in runs_for_cut]
        for cut, runs_for_cut in all_res_pseudo_dict.items()
    }

    # The number of candidates is the same for every cut-off, so read it from
    # whichever one exists. Indexing the hard-coded key 1 raised KeyError
    # whenever 1 was not among the cut-off values.
    n_candidates = len(next(iter(pseudo_res_lists_dict.values())))

    # Score of a candidate: mean of its pseudo metrics, weighted by 1/cut and
    # summed over the cut-offs.
    avg_scores = []
    for pred_i in range(n_candidates):
        scores = 0
        for cut, pseudo_res_list in pseudo_res_lists_dict.items():
            res_dict = pseudo_res_list[pred_i]
            scores += (sum(res_dict.values()) / len(res_dict))*(1/cut)
        avg_scores += [scores]

    pseudo_avg_scores[conv_q_i] = avg_scores

    # Index of the best-scoring candidate (first one wins on ties).
    index_of_highest_avg = avg_scores.index(max(avg_scores))
    indice_of_jaccard_highest[conv_q_i] = index_of_highest_avg

    # Gold-qrels metrics of the selected candidate.
    dict_with_highest_avg = res_list[index_of_highest_avg]
    best_res_dict[conv_q_i] = dict_with_highest_avg

metrics = best_res_dict
map_list = [v['map'] for v in metrics.values()]
mrr_list = [v['recip_rank'] for v in metrics.values()]
recall_100_list = [v['recall_100'] for v in metrics.values()]
recall_20_list = [v['recall_20'] for v in metrics.values()]
recall_10_list = [v['recall_10'] for v in metrics.values()]
recall_5_list = [v['recall_5'] for v in metrics.values()]
ndcg_3_list = [v['ndcg_cut_3'] for v in metrics.values()]

np.set_printoptions(precision=4)

# Corpus-level averages, reported as percentages with two decimals.
eval_metrics = {
            "MAP": round(100*np.average(map_list),2),
            "MRR": round(100*np.average(mrr_list),2),
            "NDCG@3": round(100*np.average(ndcg_3_list),2),
            "Recall@5": round(100*np.average(recall_5_list),2),
            "Recall@10": round(100*np.average(recall_10_list),2),
            "Recall@20": round(100*np.average(recall_20_list),2),
            "Recall@100": round(100*np.average(recall_100_list),2), 
        }
eval_metrics

In [None]:
# Rewritten query of the selected candidate for every turn, stored under the
# single selection key '1-1'.
best_queries = {}
for qid, pred_ind in indice_of_jaccard_highest.items():
    best_queries[qid] = all_proc_preds[pred_ind][qid]['pred_query']

pseudo_best_reQ = {}
pseudo_best_reQ['1-1'] = best_queries

# Spot-check the rewrite chosen for turn '1-10'.
pseudo_best_reQ['1-1']['1-10'],

In [None]:
# Rebuild pseudo_best_reQ from the selected candidate indices.
# The selection cell fills indice_of_jaccard_highest as a flat
# {conv-turn id: candidate index} mapping, so indexing it unconditionally with
# 'Abm-Pbm' raised KeyError and left pseudo_best_reQ empty. The nested layout
# {'Abm-Pbm': {selection key: {conv-turn id: index}}} is still honoured when
# present; otherwise the flat mapping is stored under the key '1-1'.
nested_indices = indice_of_jaccard_highest.get('Abm-Pbm')
if isinstance(nested_indices, dict):
    indices_by_key = nested_indices
else:
    indices_by_key = {'1-1': indice_of_jaccard_highest}

pseudo_best_reQ = {}
for k in list(indices_by_key.keys()):
    pseudo_best_reQ[k] = {qid: all_proc_preds[pred_ind][qid]['pred_query'] \
                    for qid, pred_ind in indices_by_key[k].items()}

pseudo_best_reQ['1-1']['1-10'],# pseudo_best_reQ['1-100']['1-7']

In [None]:
# Number of turns whose selected candidate index is below 12, i.e. comes from
# the seed 0-11 prediction files that were loaded first.
selected_indices = np.array(list(indice_of_jaccard_highest.values()))
print((selected_indices < 12).sum())

In [None]:
# Display the selection keys currently stored in pseudo_best_reQ.
pseudo_best_reQ.keys()

In [None]:
# Build instruction/output pairs: the prompt of each training turn paired with
# the rewrite selected for it.
split = 'train'
root = "/data/../nlp_data/topiocqa/" # 'datasets/qrecc/' 
input_data = "train_new.json" #
with open(os.path.join(root, input_data), encoding="utf-8") as f:
    # One JSON record per line.
    lines = [json.loads(raw) for raw in f]

# Prompt options passed to set_prompt.
args = argparse.Namespace(
    use_pssg=False,
    instruct_pssg='original',
    prompt_type='icl',
)

gt_data = []
best_data = {}
for k in list(pseudo_best_reQ.keys()):
    selected = pseudo_best_reQ[k]
    records = []
    best_data[k] = records
    for line in tqdm(lines):
        conv_id = f"{line['conv_id']}-{line['turn_id']}"

        prompt = set_prompt(line, args)
        # Turns without a selected rewrite are skipped.
        if conv_id not in selected:
            continue
        records.append({'instruction': prompt, 'output': selected[conv_id]})

best_data.keys(), len(best_data['1-1']), 

In [None]:
# Write the SFT records to disk.
# The original loop wrote every key to the same path, so each iteration silently
# overwrote the previous one. With a single key the original file name is kept;
# with several keys each one gets its own "<stem>_<key>.json" file.
sft_path = "/data/../nlp_data/LongAlpaca-12k/Topiocqa_SFT.json"
sft_keys = list(pseudo_best_reQ.keys())
for k in sft_keys:
    if len(sft_keys) == 1:
        f_path = sft_path
    else:
        sft_stem, sft_ext = os.path.splitext(sft_path)
        f_path = f"{sft_stem}_{k}{sft_ext}"
    with open(f_path, 'w', encoding='utf-8') as f:
        json.dump(best_data[k], f, indent=4)
        
        

--- **edited until here**