In [1]:
%cd ..

/shared_data0/weiqiuy/github/ALCE


In [87]:
import os
import openai
import json
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer
import time
import string
import numpy as np
import torch
import re
import sys
# sys.path.append('..')
from searcher import SearcherWithinDocs
import yaml
from utils import *
import nltk
nltk.download('punkt')
from nltk import sent_tokenize
from run import LLM


# Define an argparse.ArgumentParser instance within a Jupyter notebook environment.
import argparse

# Define a function to encapsulate the argparse setup to mimic the script behavior in a notebook.
def setup_argparse():
    parser = argparse.ArgumentParser(description="Setup for an interactive command line interface simulation.")

    # Adding arguments as specified
    parser.add_argument("--config", type=str, default=None, help="Path to the config file")
    parser.add_argument("--prompt_file", type=str, help="Path to the prompt file")
    parser.add_argument("--eval_file", type=str, help="Path to the eval file")
    parser.add_argument("--quick_test", type=int, default=None, help="Quickly test a few examples")
    parser.add_argument("--ndoc", type=int, help="Number of documents")
    parser.add_argument("--shot", type=int, help="Number of ICL demonstrations")
    parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator")
    parser.add_argument("--no_doc_in_demo", type=bool, default=False, help="Whether to remove the documents in the demos")
    parser.add_argument("--fewer_doc_in_demo", type=bool, default=False, help="Whether to use fewer documents in the demos")
    parser.add_argument("--ndoc_in_demo", type=int, default=None, help="When using --fewer_doc_in_demo, use this to designate how many docs in demo")
    parser.add_argument("--dataset_name", type=str, help="Name of the dataset (for saving)")
    parser.add_argument("--tag", type=str, help="Tag of run (for saving)")
    parser.add_argument("--model", type=str, help="Model to use")
    parser.add_argument("--openai_api", type=bool, default=False, help="Whether to use OpenAI API")
    parser.add_argument("--azure", action="store_true", default=False, help="Azure openai API")
    parser.add_argument("--temperature", type=float, default=0.5, help="Temperature for decoding")
    parser.add_argument("--top_p", type=float, default=1.0, help="Nucleus sampling top-p")
    parser.add_argument("--max_new_tokens", type=int, default=300, help="Max number of new tokens to generate in one step")
    parser.add_argument("--max_length", type=int, default=2048, help="Max length the model can take. Should set properly wrt the model to avoid position overflow.")
    parser.add_argument("--num_samples", type=int, default=1, help="Sample multiple answers.")
    parser.add_argument("--use_shorter", type=str, default=None, help="Whether to use summary data or extraction data for documents. Option: None, `summary`, `extraction`")
    parser.add_argument("--interactive", type=bool, default=False, help="Whether to run in interactive mode")
    parser.add_argument("--interactive_query", type=str, default=None, help="The query to use in interactive mode, either `doc_id` (corresponding to interact in paper) or `search` (corresponding to inlinesearch in paper).")
    parser.add_argument("--retriever", type=str, default=None, help="When using interactive search mode, which retriever to use. Options: `tfidf`, `gtr-t5-large`")
    parser.add_argument("--retriever_device", type=str, default="cuda", help="Where to put the dense retriever if using. Options: `cuda`, `cpu`")
    parser.add_argument("--retrieve_in_all_docs", type=bool, default=False, help="Retrieve in all documents instead of just top ndoc")
    parser.add_argument("--max_turn", type=int, default=10, help="Max number of all actions")
    parser.add_argument("--max_doc_show", type=int, default=3, help="Max number of documents to show at one time.")
    parser.add_argument("--force_cite_show", type=bool, default=False, help="Force citing the documents that are shown to the model")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing citations for posthoc cite")
    
    return parser

# Instantiate the parser setup
parser = setup_argparse()

[nltk_data] Downloading package punkt to /home/runai-home/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
def make_group_prompt(group, group_id, group_prompt):
    # For doc prompt:
    # - {ID}: doc id (starting from 1)
    # - {T}: title
    # - {P}: text
    # use_shorter: None, "summary", or "extraction"

    text = group
    return group_prompt.replace("{P}", text).replace("{ID}", str(group_id+1))


def make_demo_groupgen(item, prompt, ndoc=None, doc_prompt=None, groupgen_instruction=None, 
                       group_prompt=None, use_shorter=None, 
              test=False):
    # For demo prompt
    # - {INST}: the instruction
    # - {D}: the documents
    # - {Q}: the question
    # - {G}: the groups
    # ndoc: number of documents to put in context
    # use_shorter: None, "summary", or "extraction"

    prompt = prompt.replace("{INST}", groupgen_instruction).replace("{Q}", item['question'])
    if "{D}" in prompt:
        if ndoc == 0:
            prompt = prompt.replace("{D}\n", "") # if there is no doc we also delete the empty line
        else:
            doc_list = get_shorter_text(item, item["docs"], ndoc, use_shorter) if use_shorter is not None else item["docs"][:ndoc]
            text = "".join([make_doc_prompt(doc, doc_id, doc_prompt, use_shorter=use_shorter) for doc_id, doc in enumerate(doc_list)])
            prompt = prompt.replace("{D}", text)
            
    if "{G}" in prompt:
        if not test:
            group_list = item["groups"]
            ngroup = len(item["groups"])
            if ngroup == 0:
                prompt = prompt.replace("{G}\n", "") # if there is no group we also delete the empty line
            else:
                text = "".join([make_group_prompt(group, group_id, group_prompt) \
                                for group_id, group in enumerate(group_list)])
                prompt = prompt.replace("{G}", text)
        else:
            prompt = prompt.replace("{G}", "").rstrip()

    return prompt

def make_demo_grouppred(item, prompt, grouppred_instruction=None, group_prompt=None, test=False):
    # For demo prompt
    # - {INST}: the instruction
    # - {G}: the groups
    # - {Q}: the question
    # - {A}: the answers
    # ndoc: number of documents to put in context

    prompt = prompt.replace("{INST}", grouppred_instruction).replace("{Q}", item['question'])
    
    if "{G}" in prompt:
        # prompt = prompt.replace("{G}", item["groups_str"])
        group_list = item["groups"]
        try:
            ngroup = len(item["groups"])
            if ngroup == 0:
                prompt = prompt.replace("{G}\n", "") # if there is no group we also delete the empty line
            else:
                text = "".join([make_group_prompt(group, group_id, group_prompt) \
                                for group_id, group in enumerate(group_list)])
                prompt = prompt.replace("{G}", text)
        except:
            import pdb; pdb.set_trace()
            
    if not test:
        answer = "\n" + "\n".join(item["answer_from_groups"]) if isinstance(item["answer_from_groups"], list) \
                else item["answer_from_groups"]
        prompt = prompt.replace("{A}", "").rstrip() + answer
    else:
        prompt = prompt.replace("{A}", "").rstrip() # remove any space or \n

    return prompt

In [4]:
# retrieval
def remove_citations(sent):
    return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "")

def find_external_docs_idx(question, external):
    # find the index of externaol doc that matches the question
    for idx, item in enumerate(external):
        if item['question'] == question:
            return idx
    return None

In [None]:
global autoais_model, autoais_tokenizer
autoais_model, autoais_tokenizer = None, None

In [14]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    pipeline
)

def _run_nli_autoais(passage, claim):
    """
    Run inference for assessing AIS between a premise and hypothesis.
    Adapted from https://github.com/google-research-datasets/Attributed-QA/blob/main/evaluation.py
    """
    global autoais_model, autoais_tokenizer
    input_text = "premise: {} hypothesis: {}".format(passage, claim)
    input_ids = autoais_tokenizer(input_text, return_tensors="pt").input_ids.to(autoais_model.device)
    with torch.inference_mode():
        outputs = autoais_model.generate(input_ids, max_new_tokens=10)
    result = autoais_tokenizer.decode(outputs[0], skip_special_tokens=True)
    inference = 1 if result == "1" else 0
    return inference

QA_MODEL="gaotianyu1350/roberta-large-squad"
AUTOAIS_MODEL="google/t5_xxl_true_nli_mixture"


def compute_autoais_single(sent, docs):
    """
    Compute AutoAIS score.

    Args:
        sent: str, one sentence
        docs: list of str, multiple sentences
        data: requires field `output` and `docs`
              - docs should be a list of items with fields `title` and `text` (or `phrase` and `sent` for QA-extracted docs)
        citation: check citations and use the corresponding references.
        decontext: decontextualize the output
    """

    global autoais_model, autoais_tokenizer
    if autoais_model is None:
        logger.info("Loading AutoAIS model...")
        autoais_model = AutoModelForSeq2SeqLM.from_pretrained(AUTOAIS_MODEL, torch_dtype=torch.bfloat16, max_memory=get_max_memory(), device_map="auto")
        autoais_tokenizer = AutoTokenizer.from_pretrained(AUTOAIS_MODEL, use_fast=False)

    logger.info(f"Running AutoAIS...")

    def _format_document(doc):
        """Format document for AutoAIS."""

        if "sent" in doc:
            # QA-extracted docs
            return "Title: %s\n%s" % (doc['title'], doc['sent'])
        else:
            return "Title: %s\n%s" % (doc['title'], doc['text'])
        
    sent = remove_citations(sent).strip()
    target_sent = sent
    joint_passage = '\n'.join([_format_document(doc) for doc in docs])
    joint_entail = _run_nli_autoais(joint_passage, target_sent)
    entail = joint_entail
    # calculate the precision score if applicable
    sent_mcite_support = 0
    sent_mcite_overcite = 0
    entail_prec = 0
    if joint_entail and len(docs) > 1:
        sent_mcite_support += 1
        # Precision check: did the model cite any unnecessary documents?
        for psgs_id in range(len(docs)):
            # condition A
            passage = _format_document(docs[psgs_id]) 
            nli_result = _run_nli_autoais(passage, target_sent)

            # condition B
            if not nli_result:
                subset_exclude = docs[:psgs_id] + docs[psgs_id + 1:]
                passage = '\n'.join([_format_document(doc) for doc in subset_exclude])
                nli_result = _run_nli_autoais(passage, target_sent)
                if nli_result: # psgs_id is not necessary
                    flag = 0
                    sent_mcite_overcite += 1 
                else:
                    entail_prec += 1
            else:
                entail_prec += 1
    else:
        entail_prec += joint_entail 
    total_citations = len(docs)
    
    return {
        'entail': entail,
        'entail_prec': entail_prec / total_citations
    }

In [86]:
import os
import torch
import numpy as np
import copy
# from openai import OpenAI
from transformers import AutoModelForCausalLM, AutoTokenizer
from collections import namedtuple, defaultdict





CompletionOutput = namedtuple("CompletionOutput", ["output_text", "completion_output"])


class SoPrompt():
    def __init__(self, 
                 llm,
                 args,
                 prompt_data,
                 get_groupgen_prompt=None,
                 get_grouppred_prompt=None,
                 get_prompt=None,
                 get_closedbook_prompt=None
                ):
        
        self.llm = llm
        self.args = args
        self.prompt_data = prompt_data
        
        if get_groupgen_prompt is None:
            def get_groupgen_prompt(eval_item, test=False):
                groupgen_prompt = make_demo_groupgen(
                    eval_item, prompt=prompt_data["groupgen_prompt"], ndoc=args.ndoc, 
                    doc_prompt=prompt_data["doc_prompt"],
                    groupgen_instruction=prompt_data["groupgen_instruction"], 
                    group_prompt=prompt_data["group_prompt"],
                    test=test
                )
                return groupgen_prompt
        self.get_groupgen_prompt = get_groupgen_prompt
        
        if get_grouppred_prompt is None:
            def get_grouppred_prompt(eval_item, test=False):
                grouppred_prompt = make_demo_grouppred(
                    eval_item, prompt=prompt_data["grouppred_prompt"], 
                    grouppred_instruction=prompt_data["grouppred_instruction"], 
                    group_prompt=prompt_data["group_prompt"],
                    test=test
                )
                return grouppred_prompt
        self.get_grouppred_prompt = get_grouppred_prompt
        
        if get_prompt is None:
            def get_prompt(eval_item, test=False):
                ndoc = args.ndoc
                if args.no_doc_in_demo:
                    ndoc = 0
                elif args.fewer_doc_in_demo:
                    assert args.ndoc_in_demo is not None
                    ndoc = args.ndoc_in_demo
                prompt = make_demo(
                    eval_item, prompt=prompt_data["demo_prompt"], ndoc=ndoc, 
                    doc_prompt=prompt_data["doc_prompt"], 
                    instruction=prompt_data["instruction"], use_shorter=args.use_shorter, test=test
                )
                return prompt
        self.get_prompt = get_prompt
        
        if get_closedbook_prompt is None:
            def get_closedbook_prompt(eval_item, test=False):
                ndoc = args.ndoc
                if args.no_doc_in_demo:
                    ndoc = 0
                elif args.fewer_doc_in_demo:
                    assert args.ndoc_in_demo is not None
                    ndoc = args.ndoc_in_demo
                prompt = make_demo(
                    eval_item, prompt=prompt_data["closedbook_prompt"], ndoc=0, 
                    doc_prompt=prompt_data["doc_prompt"], 
                    instruction=prompt_data["instruction"], use_shorter=args.use_shorter, test=test
                )
                return prompt
        self.get_closedbook_prompt = get_closedbook_prompt
        
        # Load retrieval model
        if "gtr" in args.retriever:
            from sentence_transformers import SentenceTransformer
            self.gtr_model = SentenceTransformer(f'sentence-transformers/{args.retriever}', 
                                            device=args.retriever_device)
        else:
            self.gtr_model = None
        
        self.init_head_prompts()
        
    def init_head_prompts(self):
        # Generate the demonstration part
        head_prompt = ""
        head_groupgen_prompt = ""
        head_grouppred_prompt = ""
        head_closedbook_prompt = ""
        train_ids = np.random.choice(len(self.prompt_data["demos"]), self.args.shot, replace=False)
        for train_id in train_ids:
            train_item = self.prompt_data["demos"][train_id]
            ndoc = self.args.ndoc
            if self.args.no_doc_in_demo:
                ndoc = 0
            elif self.args.fewer_doc_in_demo:
                assert self.args.ndoc_in_demo is not None
                ndoc = self.args.ndoc_in_demo
            head_prompt += self.get_prompt(train_item)
            head_groupgen_prompt += self.get_groupgen_prompt(train_item)
            head_grouppred_prompt += self.get_grouppred_prompt(train_item)
            head_closedbook_prompt += self.get_closedbook_prompt(train_item)
            head_prompt += self.prompt_data["demo_sep"]
            head_groupgen_prompt += self.prompt_data["demo_sep"]
            head_grouppred_prompt += self.prompt_data["demo_sep"]
            head_closedbook_prompt += self.prompt_data["demo_sep"]
        self.head_prompt = head_prompt
        self.head_groupgen_prompt = head_groupgen_prompt
        self.head_grouppred_prompt = head_grouppred_prompt
        self.head_closedbook_prompt = head_closedbook_prompt
    
    def init_model(self):
        raise NotImplementedError
        
    def get_completion(self):
        raise NotImplementedError
    
    def generate(self, data, mode='groupgen_pred', return_dict=False, external=None):
        assert mode in ['vanilla', 'groupgen_pred', 'closedbook', 'closedbook_phc',
                        'closedbook_phc_repred']
        data = copy.deepcopy(data)
        
        if mode == 'groupgen_pred':
            data = self.groupgen_pred(data)
            generation = data['generation']
            if return_dict:
                return data
        elif mode == 'closedbook':
            generation = self.closedbook_predict(data)['generation']
        elif mode == 'closedbook_phc':
            closedbook_output = self.closedbook_predict(data)
            results = self.add_posthoc_cite(closedbook_output['generation'], 
                                               data['question'],
                                               external)
            if return_dict:
                data.update(results)
                return data
            generation = results['postcite_generation']
        elif mode == 'closedbook_phc_repred':
            closedbook_output = self.closedbook_predict(data)
            results = self.add_posthoc_cite(closedbook_output['generation'], 
                                               data['question'],
                                               external)
            # import pdb; pdb.set_trace()
            sents = results['sents']
            docs = results['best_docs']
            ais_results = defaultdict(list)
            docs_entail = []
            for sent, doc in zip(sents, docs):
                ais_results_single = compute_autoais_single(sent, [doc])
                # import pdb; pdb.set_trace()
                ais_results['entail'].append(ais_results_single['entail'])
                ais_results['entail_prec'].append(ais_results_single['entail_prec'])
                if ais_results_single['entail'] > 0:
                    docs_entail.append(doc)
            data['original_docs'] = copy.deepcopy(data['docs'])
            data['posthoc_docs'] = copy.deepcopy(docs_entail)
            new_docs = docs_entail
            for doc in data['docs']:
                if doc['text'] not in [doc_['text'] for doc_ in new_docs]:
                    new_docs.append(doc)
            data['docs'] = new_docs # First put the docs that are retrieved. #docs_entail + data['docs']
            data['entail'] = ais_results['entail']
            data['entail_prec'] = ais_results['entail_prec']
            
            data = self.groupgen_pred(data)
            generation = data['generation']
            if return_dict:
                data.update(results)
                return data
        else: # vanilla
            results = self.backbone_predict(data)
            if return_dict:
                data.update(results)
                return data
            generation = results['generation']
        return generation
    
    def groupgen_pred(self, data):
        groupgen_output = self.group_gen(data)
        groups = groupgen_output['groups']
        groups_str = groupgen_output['generation']
        groupgen_prompt = groupgen_output['prompt']
        groupgen_prompt_len = groupgen_output['prompt_len']
        data['groups'] = groups
        data['groups_str'] = groups_str
        data['groupgen_prompt'] = groupgen_prompt
        data['groupgen_prompt_len'] = groupgen_prompt_len
        grouppred_output = self.group_predict(data)
        generation = grouppred_output['generation']
        grouppred_prompt = grouppred_output['prompt']
        grouppred_prompt_len = grouppred_output['prompt_len']
        data['generation'] = generation
        data['grouppred_prompt'] = grouppred_prompt
        data['grouppred_prompt_len'] = grouppred_prompt_len
        return data
    
    def add_posthoc_cite(self, prev_generation, question, external):
        external_idx = find_external_docs_idx(question, external)
        doc_list = external[external_idx]['docs']
        searcher = SearcherWithinDocs(doc_list, self.args.retriever, 
                                      model=gtr_model, 
                                      device=self.args.retriever_device)
        output = prev_generation.strip().split("\n")[0] # Remove new lines and content after
        output = prev_generation.replace("<|im_end|>", "")
        if "qampari" in self.args.dataset_name:
            sents = [question + ' ' + x.strip() 
                     for x in prev_generation.rstrip(".").split(",")]
        else:
            sents = sent_tokenize(output)

        new_output = ""
        best_doc_ids = []
        for sent in sents:
            original_ref = [int(r[1:])-1 for r in re.findall(r"\[\d+", sent)] 

            if len(original_ref) == 0 or self.args.overwrite:
                sent = remove_citations(sent)
                best_doc_id = searcher.search(sent)
                sent = f"[{best_doc_id+1}] " + sent
                best_doc_ids.append(best_doc_id)

            if "qampari" in self.args.dataset_name:
                new_output += sent.replace(question, '').strip() + ", "
            else:
                new_output += sent + " "

        closedbook_phc_output = new_output.rstrip().rstrip(",")
        generation = closedbook_phc_output
        return {
            'sents': sents,
            'postcite_generation': generation,
            'best_doc_ids': best_doc_ids,
            'best_docs': [external[external_idx]['docs'][best_doc_id] for best_doc_id in best_doc_ids]
        }
        return generation
    
    def llm_predict(self, prompt):
        prompt_len = len(self.llm.tokenizer.tokenize(prompt))
        generation = self.llm.generate(prompt, min(self.args.max_new_tokens, 
                                                   self.args.max_length-prompt_len))
        return {
            'generation': generation,
            'prompt': prompt,
            'prompt_len': prompt_len
        }

    def closedbook_predict(self, data):
        prompt = self.head_closedbook_prompt + self.get_closedbook_prompt(data, test=True)
        results = self.llm_predict(prompt)
        return results
    
    def backbone_predict(self, data):
        prompt = self.head_prompt + self.get_prompt(data, test=True)
        results = self.llm_predict(prompt)
        return results
    
    def group_gen(self, data):
        groupgen_prompt = self.head_groupgen_prompt + self.get_groupgen_prompt(data, test=True)
        results = self.llm_predict(groupgen_prompt)
        groups_str = results['generation']
        groups = [':'.join(group.split(':')[1:]) for group in groups_str.split('\n')]
        results['groups'] = groups
        return results

    def group_predict(self, data):
        grouppred_prompt = self.head_grouppred_prompt + self.get_grouppred_prompt(data, test=True)
        results = self.llm_predict(grouppred_prompt)
        return results

In [186]:
command_line_args = [
    '--config', 'configs/asqa_opt-6.7b_shot1_ndoc3_gtr_default.yaml', 
    '--quick_test', '1' 
]

In [88]:
command_line_args = [
    '--config', 'configs/asqa_turbo_shot2_ndoc3_gtr_default.yaml',
    '--quick_test', '1',
    '--prompt_file', 'prompts/asqa_soprompt.json',
    '--retriever', 'gtr-t5-large',
    '--temperature', '0',
    '--overwrite'
]

In [89]:
# Example of simulating command line arguments
args = parser.parse_args(args=command_line_args)
config = yaml.safe_load(open(args.config)) if args.config is not None else {}
parser.set_defaults(**config)
args = parser.parse_args(args=command_line_args)
for k in args.__dict__:
    print(f"{k}: {args.__dict__[k]}")

if "turbo" in args.model:
    # ChatGPT has a longer max length
    args.max_length = 4096

if "16k" in args.model:
    args.max_length = 16384
elif "32k" in args.model:
    args.max_length = 32768
elif "turbo" in args.model:
    args.max_length = 4096
elif "gpt-4" in args.model:
    args.max_length = 8192
elif "llama-2" in args.model.lower() or "llama2" in args.model.lower():
    args.max_length = 4096

config: configs/asqa_turbo_shot2_ndoc3_gtr_default.yaml
prompt_file: prompts/asqa_soprompt.json
eval_file: data/asqa_eval_gtr_top100.json
quick_test: 1
ndoc: 3
shot: 2
seed: 42
no_doc_in_demo: False
fewer_doc_in_demo: False
ndoc_in_demo: None
dataset_name: asqa
tag: gtr
model: gpt-3.5-turbo-0301
openai_api: True
azure: False
temperature: 0.0
top_p: 1.0
max_new_tokens: 300
max_length: 2048
num_samples: 1
use_shorter: None
interactive: False
interactive_query: None
retriever: gtr-t5-large
retriever_device: cuda
retrieve_in_all_docs: False
max_turn: 10
max_doc_show: 3
force_cite_show: False
overwrite: True


In [90]:
# Load the model or setup the API
llm = LLM(args)

# Generate prompts
np.random.seed(args.seed)

# Load data
prompt_data = json.load(open(args.prompt_file, encoding='utf-8'))
eval_data = json.load(open(args.eval_file, encoding='utf-8'))

In [91]:
# retrieval
def remove_citations(sent):
    return re.sub(r"\[\d+", "", re.sub(r" \[\d+", "", sent)).replace(" |", "").replace("]", "")

def find_external_docs_idx(question, external):
    # find the index of externaol doc that matches the question
    for idx, item in enumerate(external):
        if item['question'] == question:
            return idx
    return None

if args.eval_file is not None:
    external = json.load(open(args.eval_file))

# Load retrieval model
if "gtr" in args.retriever:
    from sentence_transformers import SentenceTransformer
    gtr_model = SentenceTransformer(f'sentence-transformers/{args.retriever}', 
                                    device=args.retriever_device)

2024-03-05 19:43:21,942 - INFO - Load pretrained SentenceTransformer: sentence-transformers/gtr-t5-large


In [92]:
# class SoPrompt():
#     def __init__(self, 
#                  llm,
#                  args,
#                  prompt_data,
#                  get_groupgen_prompt=None,
#                  get_grouppred_prompt=None,
#                  get_prompt=None
#                 ):

soprompt = SoPrompt(llm, args, prompt_data)

2024-03-05 19:43:24,478 - INFO - Load pretrained SentenceTransformer: sentence-transformers/gtr-t5-large


In [93]:
def show_example(idx, mode='closedbook_phc_repred'):
    soprompt_output = soprompt.generate(eval_data[idx], mode=mode, 
                                        return_dict=True, external=external)
    print(f'===== MODE: {mode} =====')
    print('-----groupgen_prompt-----')
    print(soprompt_output['groupgen_prompt'])
    print('-----grouppred_prompt-----')
    print(soprompt_output['grouppred_prompt'])
    print('-----groups-----')
    print(soprompt_output['groups'])
    print('-----groups_str-----')
    print(soprompt_output['groups_str'])
    print('-----question-----')
    print(soprompt_output['question'])
    print('-----answer-----')
    print(soprompt_output['answer'])
    print('-----generation-----')
    print(soprompt_output['generation'])
    return soprompt_output

def show_example_vanilla(idx):
    vanilla_output = soprompt.generate(eval_data[idx], mode='vanilla')
    print('-----vanilla_output-----')
    print(vanilla_output)
    return vanilla_output

def show_example_all(idx):
    show_example(idx, mode='closedbook_phc_repred')
    show_example(idx, mode='groupgen_pred')
    show_example_vanilla(idx)

In [96]:
for idx in range(10):
    closedbook_phc_repred_output = soprompt.generate(eval_data[idx], mode='closedbook_phc_repred',
                                          return_dict=True,
                                         external=external)
    print(idx, len(closedbook_phc_repred_output['best_docs']), len(closedbook_phc_repred_output['posthoc_docs']))

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:44:19,937 - INFO - Running AutoAIS...
2024-03-05 19:44:20,025 - INFO - Running AutoAIS...
2024-03-05 19:44:20,110 - INFO - Running AutoAIS...
2024-03-05 19:44:20,195 - INFO - Running AutoAIS...


0 4 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:44:24,479 - INFO - Running AutoAIS...
2024-03-05 19:44:24,544 - INFO - Running AutoAIS...


1 2 1


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:44:27,877 - INFO - Running AutoAIS...
2024-03-05 19:44:27,945 - INFO - Running AutoAIS...


2 2 2


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:44:31,536 - INFO - Running AutoAIS...
2024-03-05 19:44:31,601 - INFO - Running AutoAIS...


3 2 1


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:44:34,778 - INFO - Running AutoAIS...


4 1 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:44:38,107 - INFO - Running AutoAIS...


5 1 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:44:41,371 - INFO - Running AutoAIS...


6 1 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:44:43,471 - INFO - Running AutoAIS...


7 1 1


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:44:47,017 - INFO - Running AutoAIS...
2024-03-05 19:44:47,101 - INFO - Running AutoAIS...


8 2 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:44:49,895 - INFO - Running AutoAIS...


9 1 0


In [94]:
show_example_all(1)

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:43:28,297 - INFO - Running AutoAIS...
2024-03-05 19:43:28,365 - INFO - Running AutoAIS...
Token indices sequence length is longer than the specified maximum sequence length for this model (1818 > 1024). Running this sequence through the model will result in indexing errors


===== MODE: closedbook_phc_repred =====
-----groupgen_prompt-----
Instruction: Find groups of sentences from the documents that are useful for answering the given question. Each group should contain sentences from at least one and at most three documents, and should be standalone and can be understood out of context. Always copy the whole sentence, without modification, and add the document id it comes from, using [1][2][3]. If multiple sentences from one or more documents are stating the same fact, only use a minimum sufficient subset of sentences to form the group.

Question: When did the us break away from england?

Document [1](Title: United States withdrawal from Saudi Arabia): United States withdrawal from Saudi Arabia Beginning during Operation Desert Shield in August 1990, while preparing for the Gulf War, the United States sent a large troop contingent to Saudi Arabia. After the war, remnant troops, primarily U.S. Air Force personnel, augmented by a smaller number of coordinat

In [59]:
closedbook_phc_repred_output = soprompt.generate(eval_data[1], mode='closedbook_phc_repred',
                                          return_dict=True,
                                         external=external)

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 19:19:22,732 - INFO - Running AutoAIS...
2024-03-05 19:19:22,799 - INFO - Running AutoAIS...
Token indices sequence length is longer than the specified maximum sequence length for this model (1818 > 1024). Running this sequence through the model will result in indexing errors


In [60]:
closedbook_phc_repred_output.keys()

dict_keys(['qa_pairs', 'wikipages', 'annotations', 'sample_id', 'question', 'docs', 'answer', 'original_docs', 'posthoc_docs', 'entail', 'entail_prec', 'groups', 'groups_str', 'groupgen_prompt', 'groupgen_prompt_len', 'generation', 'grouppred_prompt', 'grouppred_prompt_len', 'sents', 'postcite_generation', 'best_doc_ids', 'best_docs'])

In [61]:
len(closedbook_phc_repred_output['best_docs']), len(closedbook_phc_repred_output['posthoc_docs'])

(2, 1)

In [None]:
closedbook_phc_repred_output

In [55]:
closedbook_output = soprompt.generate(eval_data[0], mode='closedbook')

In [46]:
closedbook_output

'The player with the highest number of goals in world football is disputed, as different sources have different statistics. However, according to the Guinness World Records, the player with the most official goals in football history is Josef Bican, who scored an estimated 805 goals in 530 matches [1]. Other sources claim that Brazilian player Pelé scored 1281 goals in 1363 games, including unofficial matches [2]. Another player often mentioned in the discussion is Lionel Messi, who has scored over 700 goals in his career [3].'

In [68]:
closedbook_phc_output = soprompt.generate(eval_data[0], mode='closedbook_phc',
                                          return_dict=True,
                                         external=external)

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [70]:
closedbook_phc_output.keys()

dict_keys(['qa_pairs', 'wikipages', 'annotations', 'sample_id', 'question', 'docs', 'answer', 'postcite_generation', 'best_doc_ids', 'best_docs'])

In [71]:
closedbook_phc_output['postcite_generation']

'[1] As of October 2021, the football player with the highest number of goals in the world is Lionel Messi, with a total of 754 goals. [74] He surpassed the previous record held by Brazilian player Pele in December 2020.'

In [72]:
closedbook_phc_output['best_docs']

[{'id': '6669150',
  'title': 'Argentina–Brazil football rivalry',
  'text': '"Football Player of the Century", by IFFHS International Federation of Football History and Statistics, 1999, "South America Football Player of the Century", by IFFHS International Federation of Football History and Statistics. Pelé\'s 1281 goals are recognized by FIFA as the highest total achieved by a professional footballer, although the Soccer Statistic Foundation (rssf) recognizes only 767 goals in official mode, occupying the third place after Josef Bican (805) and Romario (772). For his part, Maradona has been named the best soccer player in World Cup history both by The Times and FourFourTwo, publication that also rewarded him as the "Best',
  'score': 0.73388671875,
  'summary': 'Pelé holds the record for the highest total goals achieved by a professional footballer with 1281 goals, recognized by FIFA. However, the Soccer Statistic Foundation recognizes only 767 goals in official mode, with Josef Bic

In [13]:
vanilla_output = soprompt.generate(eval_data[0], mode='vanilla')

Token indices sequence length is longer than the specified maximum sequence length for this model (1739 > 1024). Running this sequence through the model will result in indexing errors


In [14]:
vanilla_output

'According to FIFA, Pelé has scored the highest total of 1281 goals as a professional footballer, although the Soccer Statistic Foundation (rssf) recognizes only 767 goals in official mode, occupying the third place after Josef Bican (805) and Romario (772) [1]. However, the Football Association of Zambia claimed that Godfrey Chitalu scored 116 goals during the 1972 calendar year and 107 during the 1972 season, which is the highest official tally claimed by a national football association [2][3].'

In [15]:
eval_data[0]['question']

'Who has the highest goals in world football?'

In [16]:
external_idx = find_external_docs_idx(eval_data[0]['question'], external)
doc_list = external[external_idx]['docs']

In [18]:
searcher = SearcherWithinDocs(doc_list, args.retriever, model=gtr_model, device=args.retriever_device)

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

In [28]:
eva_item = eval_data[0]
output = vanilla_output.strip().split("\n")[0] # Remove new lines and content after
output = vanilla_output.replace("<|im_end|>", "")
if "qampari" in args.dataset_name:
    sents = [eva_item['question'] + ' ' + x.strip() for x in vanilla_output.rstrip(".").split(",")]
else:
    sents = sent_tokenize(output)

new_output = ""
for sent in sents:
    original_ref = [int(r[1:])-1 for r in re.findall(r"\[\d+", sent)] 

    if len(original_ref) == 0 or args.overwrite:
        print("\n-----")
        print("Original sentence:", sent)
        print("Original ref:", original_ref)
        sent = remove_citations(sent)
        best_doc_id = searcher.search(sent)
        print("New ref:", best_doc_id)
        sent = f"[{best_doc_id+1}] " + sent
        print("New sentence:", sent)
        if "qampari" in args.dataset_name:
            new_output += sent.replace(eva_item['question'], '').strip() + ", "
        else:
            new_output += sent + " "
    else:
        if "qampari" in args.dataset_name:
            new_output += sent.replace(eva_item['question'], '').strip() + ", "
        else:
            new_output += sent + " "

vanilla_output_phc = new_output.rstrip().rstrip(",")
print("Final output: " + vanilla_output_phc)


-----
Original sentence: According to FIFA, Pelé has scored the highest total of 1281 goals as a professional footballer, although the Soccer Statistic Foundation (rssf) recognizes only 767 goals in official mode, occupying the third place after Josef Bican (805) and Romario (772) [1].
Original ref: [0]


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

New ref: 0
New sentence: [1] According to FIFA, Pelé has scored the highest total of 1281 goals as a professional footballer, although the Soccer Statistic Foundation (rssf) recognizes only 767 goals in official mode, occupying the third place after Josef Bican (805) and Romario (772).

-----
Original sentence: However, the Football Association of Zambia claimed that Godfrey Chitalu scored 116 goals during the 1972 calendar year and 107 during the 1972 season, which is the highest official tally claimed by a national football association [2][3].
Original ref: [1, 2]


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

New ref: 1
New sentence: [2] However, the Football Association of Zambia claimed that Godfrey Chitalu scored 116 goals during the 1972 calendar year and 107 during the 1972 season, which is the highest official tally claimed by a national football association.
Final output: [1] According to FIFA, Pelé has scored the highest total of 1281 goals as a professional footballer, although the Soccer Statistic Foundation (rssf) recognizes only 767 goals in official mode, occupying the third place after Josef Bican (805) and Romario (772). [2] However, the Football Association of Zambia claimed that Godfrey Chitalu scored 116 goals during the 1972 calendar year and 107 during the 1972 season, which is the highest official tally claimed by a national football association.


In [195]:
soprompt_output_0 = show_example(0)

-----groupgen_prompt-----
Instruction: Find groups of sentences from the documents that are useful for answering the given question. Each group should contain sentences from at least one and at most three documents, and should be standalone and can be understood out of context. Always copy the whole sentence, without modification, and add the document id it comes from, using [1][2][3]. If multiple sentences from one or more documents are stating the same fact, only use a minimum sufficient subset of sentences to form the group.

Question: When did the us break away from england?

Document [1](Title: United States withdrawal from Saudi Arabia): United States withdrawal from Saudi Arabia Beginning during Operation Desert Shield in August 1990, while preparing for the Gulf War, the United States sent a large troop contingent to Saudi Arabia. After the war, remnant troops, primarily U.S. Air Force personnel, augmented by a smaller number of coordinating and training personnel from the U.S.

In [196]:
vanilla_output_0 = show_example_vanilla(0)

-----vanilla_output-----
According to FIFA, Pelé has the highest total of goals achieved by a professional footballer, with 1281 goals [1]. However, the Soccer Statistic Foundation (rssf) recognizes only 767 goals in official mode, placing Pelé in third place after Josef Bican (805) and Romario (772) [1]. While the Football Association of Zambia claims that Godfrey Chitalu scored 116 goals during the 1972 calendar year and 107 during the 1972 season, FIFA has not ratified this claim as they do not keep statistical track of domestic competitions [2][3].


In [197]:
soprompt_output_1 = show_example(1)

-----groupgen_prompt-----
Instruction: Find groups of sentences from the documents that are useful for answering the given question. Each group should contain sentences from at least one and at most three documents, and should be standalone and can be understood out of context. Always copy the whole sentence, without modification, and add the document id it comes from, using [1][2][3]. If multiple sentences from one or more documents are stating the same fact, only use a minimum sufficient subset of sentences to form the group.

Question: When did the us break away from england?

Document [1](Title: United States withdrawal from Saudi Arabia): United States withdrawal from Saudi Arabia Beginning during Operation Desert Shield in August 1990, while preparing for the Gulf War, the United States sent a large troop contingent to Saudi Arabia. After the war, remnant troops, primarily U.S. Air Force personnel, augmented by a smaller number of coordinating and training personnel from the U.S.

In [198]:
vanilla_output_1 = show_example_vanilla(1)

-----vanilla_output-----
The original artist of "The Sound of Silence" is the American music duo Simon & Garfunkel [1].


In [199]:
soprompt_output_2 = show_example(2)

-----groupgen_prompt-----
Instruction: Find groups of sentences from the documents that are useful for answering the given question. Each group should contain sentences from at least one and at most three documents, and should be standalone and can be understood out of context. Always copy the whole sentence, without modification, and add the document id it comes from, using [1][2][3]. If multiple sentences from one or more documents are stating the same fact, only use a minimum sufficient subset of sentences to form the group.

Question: When did the us break away from england?

Document [1](Title: United States withdrawal from Saudi Arabia): United States withdrawal from Saudi Arabia Beginning during Operation Desert Shield in August 1990, while preparing for the Gulf War, the United States sent a large troop contingent to Saudi Arabia. After the war, remnant troops, primarily U.S. Air Force personnel, augmented by a smaller number of coordinating and training personnel from the U.S.

In [None]:
vanilla_output_2 = show_example_vanilla(2)

In [6]:
# Generate the demonstration part
head_prompt = ""
train_ids = np.random.choice(len(prompt_data["demos"]), args.shot, replace=False)
for train_id in train_ids:
    train_item = prompt_data["demos"][train_id]
    ndoc = args.ndoc
    if args.no_doc_in_demo:
        ndoc = 0
    elif args.fewer_doc_in_demo:
        assert args.ndoc_in_demo is not None
        ndoc = args.ndoc_in_demo
    head_prompt += make_demo(
        train_item, prompt=prompt_data["demo_prompt"], ndoc=ndoc, doc_prompt=prompt_data["doc_prompt"], 
        instruction=prompt_data["instruction"], use_shorter=args.use_shorter 
    )
    head_prompt += prompt_data["demo_sep"]

In [16]:
prompt_data["demo_prompt"]

'{INST}\n\nQuestion: {Q}\n\n{D}\nAnswer: {A}'

In [11]:
head_prompt

'Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.\n\nQuestion: When did the us break away from england?\n\nDocument [1](Title: United States withdrawal from Saudi Arabia): United States withdrawal from Saudi Arabia Beginning during Operation Desert Shield in August 1990, while preparing for the Gulf War, the United States sent a large troop contingent to Saudi Arabia. After the war, remnant troops, primarily U.S. Air Force personnel, augmented by a smaller number of coordinating and training personnel from the U.S. Navy, U.S. Army and U.S. Marine Corps remained 

In [13]:
make_demo(
        train_item, prompt=prompt_data["demo_prompt"], ndoc=ndoc, doc_prompt=prompt_data["doc_prompt"], 
        instruction=prompt_data["instruction"], use_shorter=args.use_shorter 
    )

'Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.\n\nQuestion: Who played galen in planet of the apes?\n\nDocument [1](Title: Planet of the Apes): installment. Jacobs died on June 27, 1973, bringing an end to the APJAC Productions era of the "Planet of the Apes" franchise. Former Fox executive Stan Hough took over as producer for the television project, titled "Planet of the Apes". CBS picked up the series for its 1974 autumn lineup. Ron Harper and James Naughton played Alan Virdon and Peter Burke, two 20th-century American astronauts who pass through a time war

In [14]:
eval_data[0].keys()

dict_keys(['qa_pairs', 'wikipages', 'annotations', 'sample_id', 'question', 'docs', 'answer', 'prompt', 'output'])

In [15]:
 # Sample quick test
if args.quick_test is not None:
    eval_ids = np.random.choice(len(eval_data), args.quick_test, replace=False)
    eval_data = [eval_data[int(idx)] for idx in eval_ids]

logger.info("Generating prompts...") 
incomplete_doc_list = 0 # For some questions there might be fewer than ndoc documents
for idx, eval_item in enumerate(tqdm(eval_data)):
    eval_data[idx]['prompt'] = head_prompt + make_demo(
        eval_item, prompt=prompt_data["demo_prompt"], ndoc=args.ndoc, doc_prompt=prompt_data["doc_prompt"],
        instruction=prompt_data["instruction"], use_shorter=args.use_shorter, 
        test=True
    )
    print(eval_data[idx]['prompt'])
    doc_list = get_shorter_text(eval_item, eval_item["docs"], args.ndoc, args.use_shorter) if args.use_shorter is not None else eval_item["docs"][:args.ndoc]
    if not args.retrieve_in_all_docs:
        # If --retrieve_in_all_docs, we keep the original docs and do not trim them by ndoc
        # Otherwise, take the new docs (truncated by ndoc and filtered if using summary/extraction)
        eval_data[idx]['docs'] = doc_list
    if len(doc_list) < args.ndoc:
        incomplete_doc_list += 1
logger.info("Done.")
if incomplete_doc_list > 0:
    logger.warning(f"There are {incomplete_doc_list} questions that have incomplete document list (may due to a lot of them are filtered out by summary/extraction).")


2024-02-27 20:22:48,185 - INFO - Generating prompts...
100%|██████████| 1/1 [00:00<00:00, 15196.75it/s]
2024-02-27 20:22:48,187 - INFO - Done.


Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.

Question: When did the us break away from england?

Document [1](Title: United States withdrawal from Saudi Arabia): United States withdrawal from Saudi Arabia Beginning during Operation Desert Shield in August 1990, while preparing for the Gulf War, the United States sent a large troop contingent to Saudi Arabia. After the war, remnant troops, primarily U.S. Air Force personnel, augmented by a smaller number of coordinating and training personnel from the U.S. Navy, U.S. Army and U.S. Marine Corps remained in Sa

In [8]:
# Load retriever for interactive search 
if args.interactive and args.interactive_query == "search" and "gtr" in args.retriever:
    from sentence_transformers import SentenceTransformer
    gtr_model = SentenceTransformer(f'sentence-transformers/{args.retriever}', device=args.retriever_device)
    from searcher import SearcherWithinDocs

In [10]:
for idx, item in enumerate(tqdm(eval_data)):
    prompt = item['prompt']
    prompt_len = len(llm.tokenizer.tokenize(prompt))

    if idx == 0:
        print(prompt)

    output_array = []
    for _ in range(args.num_samples):
        if args.interactive:
            print("============ Interactive =============")
            output_answer = ""
            doc_list = item['docs']

            interactive_prompt = prompt.rstrip() + "\n" # Start a new line
            inline_doc = ""
            num_turn = 0

            doc_history = []
            while True:
                # For each action, it should end at the new line
                # Three possible actions
                # - Check: Document [1][2][3] / search query
                # - Output: output 
                # - End
                num_turn += 1
                new_prompt = interactive_prompt + inline_doc
                new_prompt_len = len(llm.tokenizer.tokenize(new_prompt))

                if idx == 0:
                    print(f"-------------- Step {num_turn} prompt --------------")
                    print(new_prompt)
                    print("-----------------------------")

                output = llm.generate(new_prompt, min(args.max_new_tokens, args.max_length-new_prompt_len), stop=["\n", "\n\n"])

                if len(inline_doc) > 0:
                    output = "Output: " + output # "Output: " was included in inline_doc
                inline_doc = "" # Delete inline_doc after use
                interactive_prompt += output + "\n"
                logger.info(f"Model output: \"{output}\"")

                if output.strip().lower()[:3] == "end":
                    # Model decides to end the generation
                    break
                elif "sorry" in output.lower() and ("relevant document" in output.lower() or "relevant information" in output.lower()) or "none of the documents" in output.lower():
                    # Instruction-tuned model may abstain from answer the question
                    break
                elif output.strip().lower()[:5] == "check" or output.strip().lower()[:6] == "search":
                    # Checkout or search documents
                    if args.interactive_query == "search":
                        query = output.replace("Search:", "").replace("search:", "").strip()
                        if len(doc_list) == 0:
                            show_doc_ids = []
                        else:
                            searcher = SearcherWithinDocs(doc_list, args.retriever, model=gtr_model, device=args.retriever_device)
                            show_doc_ids = [int(searcher.search(query))]
                    elif args.interactive_query == "doc_id":
                        show_doc_ids = [int(r[1:])-1 for r in re.findall(r"\[\d+", output)] # In text citation id starts from 1
                        show_doc_ids = [doc_id for doc_id in show_doc_ids if doc_id < len(doc_list) and doc_id >= 0]
                        show_doc_ids = show_doc_ids[:args.max_doc_show] # Avoiding showing too many documents
                    else:
                        raise NotImplementedError

                    inline_doc = "".join([make_doc_prompt(doc_list[doc_id], doc_id, prompt_data["doc_prompt"]) for doc_id in show_doc_ids])
                    inline_doc += "Output:" # Force the model to generate output in the next step
                    doc_history.append(show_doc_ids)
                elif output.strip().lower()[:6] == "output":
                    output = output.strip().replace("Output:", "").strip()
                    if args.force_cite_show:
                        output = remove_citations(output)
                        if len(doc_history) == 0:
                            logger.warn("No doc history??")
                        else:
                            # Just cite whatever documents the model has seen in the last step
                            if "qampari" in args.eval_file:
                                output = ", ".join(["".join([f"[{doc+1}]" for doc in doc_history[-1]]) + " " + entity.strip() for entity in output.rstrip().rstrip(",").split(",")]) + ", "
                            else:
                                output = " ".join(["".join([f"[{doc+1}]" for doc in doc_history[-1]]) + " " + o for o in sent_tokenize(output)]) + "."
                    output_answer += " " + output 
                else:
                    # Sometimes model starts to output random things.
                    break

                if num_turn >= args.max_turn:
                    logger.warning("Reach maximum number of turns. Terminate now.")
                    break

            if "qampari" in args.eval_file:
                output_answer = output_answer.rstrip().rstrip(",")
            output_array.append(output_answer)
            item['prompt'] = interactive_prompt
            item['doc_history'] = doc_history
        else: 
            output_array.append(llm.generate(prompt, min(args.max_new_tokens, args.max_length-prompt_len)))
            item['prompt'] = prompt

        output_array[-1] = output_array[-1].replace("<|im_end|>", "").rstrip()
        if output_array[-1].endswith("End."):
            output_array[-1] = output_array[-1][:-len("End.")]

        logger.info(f"Prompt length={prompt_len}")
        logger.info(f"Question: {item['question']}")
        logger.info(f"Gold answer: {item['answer']}")
        logger.info(f"Final model output: {output_array[-1]}") 

    item['output'] = output_array if len(output_array) > 1 else output_array[0]

  0%|          | 0/1 [00:00<?, ?it/s]

prompt Instruction: Write an accurate, engaging, and concise answer for the given question using only the provided search results (some of which might be irrelevant) and cite them properly. Use an unbiased and journalistic tone. Always cite for any factual claim. When citing several search results, use [1][2][3]. Cite at least one document and at most three documents in each sentence. If multiple documents support the sentence, only cite a minimum sufficient subset of the documents.

Question: When did the us break away from england?

Document [1](Title: United States withdrawal from Saudi Arabia): United States withdrawal from Saudi Arabia Beginning during Operation Desert Shield in August 1990, while preparing for the Gulf War, the United States sent a large troop contingent to Saudi Arabia. After the war, remnant troops, primarily U.S. Air Force personnel, augmented by a smaller number of coordinating and training personnel from the U.S. Navy, U.S. Army and U.S. Marine Corps remaine

2024-02-27 20:11:26,396 - INFO - Prompt length=1749
2024-02-27 20:11:26,396 - INFO - Question: The festival of holi marks the end of winter and the beginning of?
2024-02-27 20:11:26,396 - INFO - Gold answer: Holi ( /ˈhoʊliː/) is a popular ancient Indian festival, also known as the "Festival of Love", the "Festival of Colours" and the "Festival of Spring". Holi celebrates the arrival of spring, the end of winter, the blossoming of love and for many, it is a festive day to meet others, play and laugh, forget and forgive, and repair broken relationships.
2024-02-27 20:11:26,397 - INFO - Final model output: The festival of Holi marks the end of winter and the beginning of spring [1][2][3]. It is a Hindu spring festival celebrated predominantly in India, but has also spread to other areas of Asia and parts of the Western world through the diaspora from the Indian subcontinent [1]. It is also celebrated as a thanksgiving for a good harvest [1].
100%|██████████| 1/1 [00:00<00:00,  1.13it/s]
