In [1]:
%cd ..

/shared_data0/weiqiuy/github/ALCE


In [2]:
import os
import openai
import json
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer
import time
import string
import numpy as np
import torch
import re
import sys
# sys.path.append('..')
from searcher import SearcherWithinDocs
import yaml
from utils import *
import nltk
nltk.download('punkt')
from nltk import sent_tokenize
from run import LLM
from soprompt import SoPrompt


# Define an argparse.ArgumentParser instance within a Jupyter notebook environment.
import argparse

# Define a function to encapsulate the argparse setup to mimic the script behavior in a notebook.
def setup_argparse():
    parser = argparse.ArgumentParser(description="Setup for an interactive command line interface simulation.")

    # Adding arguments as specified
    parser.add_argument("--config", type=str, default=None, help="Path to the config file")
    parser.add_argument("--prompt_file", type=str, help="Path to the prompt file")
    parser.add_argument("--eval_file", type=str, help="Path to the eval file")
    parser.add_argument("--quick_test", type=int, default=None, help="Quickly test a few examples")
    parser.add_argument("--ndoc", type=int, help="Number of documents")
    parser.add_argument("--shot", type=int, help="Number of ICL demonstrations")
    parser.add_argument("--seed", type=int, default=42, help="Seed for the random number generator")
    parser.add_argument("--no_doc_in_demo", type=bool, default=False, help="Whether to remove the documents in the demos")
    parser.add_argument("--fewer_doc_in_demo", type=bool, default=False, help="Whether to use fewer documents in the demos")
    parser.add_argument("--ndoc_in_demo", type=int, default=None, help="When using --fewer_doc_in_demo, use this to designate how many docs in demo")
    parser.add_argument("--dataset_name", type=str, help="Name of the dataset (for saving)")
    parser.add_argument("--tag", type=str, help="Tag of run (for saving)")
    parser.add_argument("--model", type=str, help="Model to use")
    parser.add_argument("--openai_api", type=bool, default=False, help="Whether to use OpenAI API")
    parser.add_argument("--azure", action="store_true", default=False, help="Azure openai API")
    parser.add_argument("--temperature", type=float, default=0.5, help="Temperature for decoding")
    parser.add_argument("--top_p", type=float, default=1.0, help="Nucleus sampling top-p")
    parser.add_argument("--max_new_tokens", type=int, default=300, help="Max number of new tokens to generate in one step")
    parser.add_argument("--max_length", type=int, default=2048, help="Max length the model can take. Should set properly wrt the model to avoid position overflow.")
    parser.add_argument("--num_samples", type=int, default=1, help="Sample multiple answers.")
    parser.add_argument("--use_shorter", type=str, default=None, help="Whether to use summary data or extraction data for documents. Option: None, `summary`, `extraction`")
    parser.add_argument("--interactive", type=bool, default=False, help="Whether to run in interactive mode")
    parser.add_argument("--interactive_query", type=str, default=None, help="The query to use in interactive mode, either `doc_id` (corresponding to interact in paper) or `search` (corresponding to inlinesearch in paper).")
    parser.add_argument("--retriever", type=str, default=None, help="When using interactive search mode, which retriever to use. Options: `tfidf`, `gtr-t5-large`")
    parser.add_argument("--retriever_device", type=str, default="cuda", help="Where to put the dense retriever if using. Options: `cuda`, `cpu`")
    parser.add_argument("--retrieve_in_all_docs", type=bool, default=False, help="Retrieve in all documents instead of just top ndoc")
    parser.add_argument("--max_turn", type=int, default=10, help="Max number of all actions")
    parser.add_argument("--max_doc_show", type=int, default=3, help="Max number of documents to show at one time.")
    parser.add_argument("--force_cite_show", type=bool, default=False, help="Force citing the documents that are shown to the model")
    parser.add_argument("--overwrite", action="store_true", help="Overwrite existing citations for posthoc cite")
    
    return parser

# Instantiate the parser setup
parser = setup_argparse()

[nltk_data] Downloading package punkt to /home/runai-home/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to /home/runai-home/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [3]:
command_line_args = [
    '--config', 'configs/asqa_turbo_shot2_ndoc3_gtr_default.yaml',
    '--quick_test', '1',
    '--prompt_file', 'prompts/asqa_soprompt.json',
    '--retriever', 'gtr-t5-large',
    '--temperature', '0',
    '--overwrite',
    '--use_shorter', 'extraction'
]

In [4]:
# Example of simulating command line arguments
args = parser.parse_args(args=command_line_args)
config = yaml.safe_load(open(args.config)) if args.config is not None else {}
parser.set_defaults(**config)
args = parser.parse_args(args=command_line_args)
for k in args.__dict__:
    print(f"{k}: {args.__dict__[k]}")

if "turbo" in args.model:
    # ChatGPT has a longer max length
    args.max_length = 4096

if "16k" in args.model:
    args.max_length = 16384
elif "32k" in args.model:
    args.max_length = 32768
elif "turbo" in args.model:
    args.max_length = 4096
elif "gpt-4" in args.model:
    args.max_length = 8192
elif "llama-2" in args.model.lower() or "llama2" in args.model.lower():
    args.max_length = 4096

config: configs/asqa_turbo_shot2_ndoc3_gtr_default.yaml
prompt_file: prompts/asqa_soprompt.json
eval_file: data/asqa_eval_gtr_top100.json
quick_test: 1
ndoc: 3
shot: 2
seed: 42
no_doc_in_demo: False
fewer_doc_in_demo: False
ndoc_in_demo: None
dataset_name: asqa
tag: gtr
model: gpt-3.5-turbo-0301
openai_api: True
azure: False
temperature: 0.0
top_p: 1.0
max_new_tokens: 300
max_length: 2048
num_samples: 1
use_shorter: extraction
interactive: False
interactive_query: None
retriever: gtr-t5-large
retriever_device: cuda
retrieve_in_all_docs: False
max_turn: 10
max_doc_show: 3
force_cite_show: False
overwrite: True


In [5]:
# Load the model or setup the API
llm = LLM(args)

# Generate prompts
np.random.seed(args.seed)

# Load data
prompt_data = json.load(open(args.prompt_file, encoding='utf-8'))
eval_data = json.load(open(args.eval_file, encoding='utf-8'))
if args.eval_file is not None:
    external = json.load(open(args.eval_file))

In [6]:
soprompt = SoPrompt(llm, args, prompt_data)

2024-03-06 19:43:20,502 - INFO - Load pretrained SentenceTransformer: sentence-transformers/gtr-t5-large


OutOfMemoryError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 79.15 GiB total capacity; 894.13 MiB already allocated; 18.19 MiB free; 896.00 MiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Eval

In [None]:
idx = 1
soprompt_output = soprompt.generate(eval_data[idx], doc_mode='comb', pred_mode='groups',
                                        return_dict=True, external=external)
soprompt_output['output'] = soprompt_output['generation']

In [None]:
import copy

data = [copy.deepcopy(soprompt_output)]

In [None]:
from eval import *

eval_parser = argparse.ArgumentParser()
# eval_parser.add_argument("--f", type=str, required=True, help="Output file. Should have field `question`, `output`, (ROUGE) `answer`, \
#                     (accuracy) `qa_pairs`, (AIS) `docs`")
eval_parser.add_argument("--no_rouge", action="store_true", help="Do not evaluate ROUGE score")
eval_parser.add_argument("--qa", action="store_true", help="Use the QA model")
eval_parser.add_argument("--mauve", action="store_true", help="Use the mauve score model")
eval_parser.add_argument("--citations", action="store_true", help="Evaluation with citation")
eval_parser.add_argument("--at_most_citations", type=int, default=3, help="At most take this many documents (mostly for precision)")
eval_parser.add_argument("--claims_nli", action="store_true", help="Use claims for ELI5")

# QAMPARI
eval_parser.add_argument("--cot", action="store_true", help="For QAMPARI, try to find colon and separate the COT and answer listing")

eval_cl_args = [
    '--citations',
    '--qa',
    '--mauve'
]

eval_args = eval_parser.parse_args(args=eval_cl_args)

if "qampari" in args.dataset_name:
    eval_args.no_rouge = True
    eval_args.qa = False
    eval_args.mauve = False
    eval_args.decontext = False
    qampari = True
else:
    qampari = False

In [None]:
def eval_generation(data):
    for i in range(len(data)):
        data[i]['output'] = data[i]['output'].strip().split("\n")[0]
        data[i]['output'] = data[i]['output'].replace("<|im_end|>", "")


    # Remove all citations for all non-AutoAIS evaluation
    normalized_data = copy.deepcopy(data)
    for i in range(len(normalized_data)):
        normalized_data[i]['output'] = remove_citations(normalized_data[i]['output'])

    result = {}
    result['length'] = compute_len(normalized_data)
    result['str_em'], result['str_hit'] = compute_str_em(normalized_data)
    if qampari:
        result.update(compute_qampari_f1(normalized_data, cot=eval_args.cot))
    if not eval_args.no_rouge:
        result['rougeLsum'] = compute_rouge(normalized_data)
    if eval_args.qa:
        result.update(compute_qa(normalized_data))
    if eval_args.mauve:
        result['mauve'] = compute_mauve(normalized_data)
    if eval_args.citations: 
        result.update(compute_autoais(data, qampari=qampari, at_most_citations=eval_args.at_most_citations))
    if eval_args.claims_nli:
        result["claims_nli"] = compute_claims(normalized_data)

    return result

In [None]:
from tqdm.auto import tqdm
def eval_mode(num_data, doc_mode, pred_mode):
    outputs = []
    for idx in tqdm(range(num_data)):
        soprompt_output = soprompt.generate(eval_data[idx], doc_mode=doc_mode, pred_mode=pred_mode,
                                            return_dict=True, external=external)
        soprompt_output['output'] = soprompt_output['generation']
        outputs.append(soprompt_output)
    result = eval_generation(outputs)
    return result, outputs

In [None]:
from soprompt import SOPROMPT_DOC_MODES, SOPROMPT_PRED_MODES
from tqdm.auto import tqdm

num_eval = 10
results = {}
outputs = {}
for doc_mode in tqdm(SOPROMPT_DOC_MODES):
    for pred_mode in tqdm(SOPROMPT_PRED_MODES):
        print('doc_mode', doc_mode, 'pred_mode', pred_mode)
        result, output = eval_mode(10, doc_mode, pred_mode)
        results[f'{doc_mode}-{pred_mode}'] = result
        outputs[f'{doc_mode}-{pred_mode}'] = output

In [None]:
results

In [None]:
import matplotlib.pyplot as plt

# Extract metric names (assuming all entries have the same metrics)
# Extract metrics and methods
methods = list(results.keys())
metrics = list(results[methods[0]].keys())
# Correction needed in the loop to correctly use metric names from the data dictionary

fig, ax = plt.subplots(figsize=(15, 10))

# Number of groups
num_groups = len(metrics)
# Number of methods (bars in each group)
num_methods = len(methods)

# Positions of the groups
group_positions = np.arange(num_groups)

# Width of a single bar
bar_width = 0.1

# Loop over each metric to create a set of bars
for i, method in enumerate(methods):
    performance = [results[method][metric] for metric in metrics]
    positions = [x + (i * bar_width) for x in group_positions]
    
    ax.bar(positions, performance, bar_width, label=method)

ax.set_xlabel('Metrics', fontsize=14)
ax.set_ylabel('Scores', fontsize=14)
ax.set_title('Scores by Metric and Method', fontsize=16)
ax.set_xticks([r + bar_width for r in range(len(metrics))])
ax.set_xticklabels(metrics, rotation=45, ha="right")
ax.legend()

plt.tight_layout()
plt.show()


In [20]:
outputs.keys()

dict_keys(['retrieved-docs', 'retrieved-groups', 'postcited-docs', 'postcited-groups', 'comb-docs', 'comb-groups', 'none-docs', 'none-groups'])

In [25]:
outputs['retrieved-docs'][idx].keys()

dict_keys(['qa_pairs', 'wikipages', 'annotations', 'sample_id', 'question', 'docs', 'answer', 'doc_mode', 'pred_mode', 'ndoc', 'generation', 'prompt', 'prompt_len', 'output'])

In [None]:
outputs['retrieved-docs']

In [24]:
idx = 1


In [15]:
soprompt_output = soprompt.generate(eval_data[0], mode='closedbook', 
                                            return_dict=True, external=external)
print(soprompt_output.keys())

dict_keys(['qa_pairs', 'wikipages', 'annotations', 'sample_id', 'question', 'docs', 'answer', 'ndoc'])


## Show

In [7]:
def show_example_ggp(idx, mode='closedbook_postcite_repred'):
    assert mode in ['closedbook_postcite_repred', 'groupgen_pred']
    soprompt_output = soprompt.generate(eval_data[idx], mode=mode, 
                                        return_dict=True, external=external)
    print(f'===== MODE: {mode} =====')
    print('-----groupgen_prompt-----')
    print(soprompt_output['groupgen_prompt'])
    print('-----grouppred_prompt-----')
    print(soprompt_output['grouppred_prompt'])
    print('-----groups-----')
    print(soprompt_output['groups'])
    print('-----groups_str-----')
    print(soprompt_output['groups_str'])
    print('-----question-----')
    print(soprompt_output['question'])
    print('-----answer-----')
    print(soprompt_output['answer'])
    print('-----generation-----')
    print(soprompt_output['generation'])
    if mode == 'closedbook_postcite_repred':
        print('-----postcite_best_docs-----')
        print('\n\n'.join([doc['text'] for doc in soprompt_output['postcite_best_docs']]))
        print('-----original_docs-----')
        print('\n\n'.join([doc['text'] for doc in soprompt_output['original_docs'][:args.ndoc]]))
        print('-----entail-----')
        print(soprompt_output['entail'])
        print('-----entail_prec-----')
        print(soprompt_output['entail_prec'])
        print('-----num_postcite_overlap-----')
        print(soprompt_output['num_postcite_overlap'])
        print('-----num_new_used-----')
        print(soprompt_output['num_new_used'])
    return soprompt_output

def show_example_simple(idx, mode='vanilla'):
    assert mode in ['vanilla', 'closedbook']
    vanilla_output = soprompt.generate(eval_data[idx], mode='vanilla')
    print(f'===== MODE: {mode} =====')
    print('-----generation-----')
    print(vanilla_output)
    return vanilla_output

def show_example_all(idx):
    show_example_ggp(idx, mode='closedbook_postcite_repred')
    show_example_ggp(idx, mode='groupgen_pred') 
    show_example_simple(idx, mode='vanilla')
    show_example_simple(idx, mode='closedbook')

In [9]:
for idx in range(10):
    soprompt_output = soprompt.generate(eval_data[idx], mode='closedbook_postcite_repred', 
                                        return_dict=True, external=external)
    print(idx, 'num_postcite_overlap', soprompt_output['num_postcite_overlap'],
         'num_new_used', soprompt_output['num_new_used'])

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:33:30,490 - INFO - Running AutoAIS...
2024-03-05 22:33:30,574 - INFO - Running AutoAIS...
2024-03-05 22:33:30,658 - INFO - Running AutoAIS...
2024-03-05 22:33:30,740 - INFO - Running AutoAIS...


0 num_postcite_overlap 0 num_new_used 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:33:34,631 - INFO - Running AutoAIS...
2024-03-05 22:33:34,692 - INFO - Running AutoAIS...


1 num_postcite_overlap 1 num_new_used 1


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:33:37,754 - INFO - Running AutoAIS...
2024-03-05 22:33:37,817 - INFO - Running AutoAIS...


2 num_postcite_overlap 1 num_new_used 2


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:33:41,289 - INFO - Running AutoAIS...
2024-03-05 22:33:41,357 - INFO - Running AutoAIS...


3 num_postcite_overlap 1 num_new_used 1


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:33:44,489 - INFO - Running AutoAIS...


4 num_postcite_overlap 0 num_new_used 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:33:47,686 - INFO - Running AutoAIS...


5 num_postcite_overlap 0 num_new_used 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:33:50,918 - INFO - Running AutoAIS...


6 num_postcite_overlap 0 num_new_used 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:33:52,967 - INFO - Running AutoAIS...


7 num_postcite_overlap 1 num_new_used 1


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:33:56,643 - INFO - Running AutoAIS...
2024-03-05 22:33:56,747 - INFO - Running AutoAIS...


8 num_postcite_overlap 0 num_new_used 0


Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:33:59,443 - INFO - Running AutoAIS...


9 num_postcite_overlap 0 num_new_used 0


In [8]:
idx = 1
output = show_example_ggp(idx, mode='closedbook_postcite_repred')
# output.keys()

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:31:34,680 - INFO - Loading AutoAIS model...


Loading checkpoint shards:   0%|          | 0/5 [00:00<?, ?it/s]

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
2024-03-05 22:32:05,001 - INFO - Running AutoAIS...
2024-03-05 22:32:05,121 - INFO - Running AutoAIS...
Token indices sequence length is longer than the specified maximum sequence length for this model (1818 > 1024). Running this sequence through the model will result in indexing errors


===== MODE: closedbook_postcite_repred =====
-----groupgen_prompt-----
Instruction: Find groups of sentences from the documents that are useful for answering the given question. Each group should contain sentences from at least one and at most three documents, and should be standalone and can be understood out of context. Always copy the whole sentence, without modification, and add the document id it comes from, using [1][2][3]. If multiple sentences from one or more documents are stating the same fact, only use a minimum sufficient subset of sentences to form the group.

Question: When did the us break away from england?

Document [1](Title: United States withdrawal from Saudi Arabia): United States withdrawal from Saudi Arabia Beginning during Operation Desert Shield in August 1990, while preparing for the Gulf War, the United States sent a large troop contingent to Saudi Arabia. After the war, remnant troops, primarily U.S. Air Force personnel, augmented by a smaller number of coor

In [11]:
show_example_all(1)

Batches:   0%|          | 0/4 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

2024-03-05 22:15:37,834 - INFO - Running AutoAIS...
2024-03-05 22:15:37,898 - INFO - Running AutoAIS...


===== MODE: closedbook_postcite_repred =====
-----groupgen_prompt-----
Instruction: Find groups of sentences from the documents that are useful for answering the given question. Each group should contain sentences from at least one and at most three documents, and should be standalone and can be understood out of context. Always copy the whole sentence, without modification, and add the document id it comes from, using [1][2][3]. If multiple sentences from one or more documents are stating the same fact, only use a minimum sufficient subset of sentences to form the group.

Question: When did the us break away from england?

Document [1](Title: United States withdrawal from Saudi Arabia): United States withdrawal from Saudi Arabia Beginning during Operation Desert Shield in August 1990, while preparing for the Gulf War, the United States sent a large troop contingent to Saudi Arabia. After the war, remnant troops, primarily U.S. Air Force personnel, augmented by a smaller number of coor

NameError: name 'postcite_best_docs' is not defined