In [1]:
# %pip install "pyautogen[retrievechat]~=0.2.0b5"

## Set your API Endpoint

The [`config_list_from_json`](https://microsoft.github.io/FLAML/docs/reference/autogen/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a JSON file.


In [2]:
import autogen

# Load the endpoint configurations from ".config.local" (looked up in the
# current directory), filtered on the "model" field via filter_dict.
config_list = autogen.config_list_from_json(
    env_or_file=".config.local",
    file_location=".",
    filter_dict={
        "model": {"gpt-4", "gpt4", "gpt-4-32k", "gpt-4-32k-0314", "gpt-35-turbo", "gpt-3.5-turbo"},
    },
)

# At least one configuration must be available.
assert len(config_list) >= 1
# Overwrite the model name of the first configuration entry.
config_list[0]["model"] = "gpt-35-turbo"
print("models to use: ", [endpoint["model"] for endpoint in config_list])

  from .autonotebook import tqdm as notebook_tqdm


models to use:  ['gpt-35-turbo']


## Construct agents for RetrieveChat

We start by initializing the `RetrieveAssistantAgent` and `RetrieveUserProxyAgent`. The system message needs to be set to "You are a helpful assistant." for `RetrieveAssistantAgent`. The detailed instructions are given in the user message. Later we will use `RetrieveUserProxyAgent.generate_init_prompt` to combine the instructions and a question into the initial prompt sent to the LLM assistant.

In [3]:
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
import chromadb

# Location of the corpus that backs the NaturalQuestions collection.
corpus_file = "https://huggingface.co/datasets/thinkall/NaturalQuestionsQA/resolve/main/corpus.txt"

# 1. The RetrieveAssistantAgent instance, named "assistant".
assistant_llm_config = {
    "timeout": 60,
    "cache_seed": 42,
    "config_list": config_list,
}
assistant = RetrieveAssistantAgent(
    name="assistant",
    system_message="You are a helpful assistant.",
    llm_config=assistant_llm_config,
)

# 2. The RetrieveUserProxyAgent instance, named "ragproxyagent", bound to the
#    "natural-questions" collection of a persistent Chroma client stored
#    under /tmp/chromadb.
nq_retrieve_config = {
    "task": "qa",
    "docs_path": corpus_file,
    "chunk_token_size": 2000,
    "model": config_list[0]["model"],
    "client": chromadb.PersistentClient(path="/tmp/chromadb"),
    "collection_name": "natural-questions",
    "chunk_mode": "one_line",
    "embedding_model": "all-MiniLM-L6-v2",
    "get_or_create": True,
}
ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=10,
    retrieve_config=nq_retrieve_config,
)

  return torch._C._cuda_getDeviceCount() > 0


### Natural Questions QA

Use RetrieveChat to answer questions from the [NaturalQuestions](https://ai.google.com/research/NaturalQuestions) dataset.

We'll first create a new document collection from the entire context corpus, then select some questions and answer them with RetrieveChat.


In [15]:
import json

queries_file = "https://huggingface.co/datasets/thinkall/NaturalQuestionsQA/resolve/main/queries.jsonl"
!wget -O /tmp/chromadb/queries.jsonl $queries_file
queries = [json.loads(line) for line in open("/tmp/chromadb/queries.jsonl").readlines() if line]
questions = [q["text"] for q in queries]
answers = [q["metadata"]["answer"] for q in queries]
print(questions[:5])
print(answers[:5])
print("Number of questions:", len(questions))

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


--2023-11-22 00:12:47--  https://huggingface.co/datasets/thinkall/NaturalQuestionsQA/resolve/main/queries.jsonl
Resolving huggingface.co (huggingface.co)... 99.84.108.55, 99.84.108.129, 99.84.108.87, ...
Connecting to huggingface.co (huggingface.co)|99.84.108.55|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1380571 (1.3M) [text/plain]
Saving to: ‘/tmp/chromadb/queries.jsonl’


2023-11-22 00:12:47 (40.6 MB/s) - ‘/tmp/chromadb/queries.jsonl’ saved [1380571/1380571]

['what is non controlling interest on balance sheet', 'how many episodes are in chicago fire season 4', 'who sings love will keep us alive by the eagles', 'who is the leader of the ontario pc party', 'where did the last name keith come from']
[["the portion of a subsidiary corporation 's stock that is not owned by the parent corporation"], ['23'], ['Timothy B. Schmit'], ['Patrick Walter Brown'], ['from Keith in East Lothian , Scotland', "from a nickname , derived from the Middle High German kīt , a

In [5]:
from io import StringIO 
import sys

class Capturing(list):
    """List that collects the lines printed to stdout inside its ``with`` block.

    On entry ``sys.stdout`` is replaced by an in-memory ``StringIO`` buffer.
    On exit the buffered text is split into lines and appended to this list,
    and the previous ``sys.stdout`` is put back. Exceptions raised inside the
    block are not suppressed.
    """

    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self

    def __exit__(self, *args):
        try:
            self.extend(self._stringio.getvalue().splitlines())
        finally:
            # Always undo the redirection, even if collecting the lines fails;
            # otherwise every later print would vanish into the buffer.
            del self._stringio    # free up some memory
            sys.stdout = self._stdout

In [6]:
import time

# Parallel result lists: entry k of each list belongs to the same question.
retrieve_answers = []
questions_sample = []
answers_sample = []
num_questions = 7000
# The dataset may hold fewer than num_questions entries; report progress
# against the number of questions that will actually be processed.
total_questions = min(num_questions, len(questions))
st = time.time()
for idx, qa_problem in enumerate(questions[:num_questions]):
    if idx % 100 == 0:
        ct = time.time()
        print(f"\nProgress {idx/total_questions*100:.2f}%, Time Used {(ct-st)/3600:.2f} hours\n")
    assistant.reset()
    try:
        # Capture everything the chat prints so the answer can be pulled out
        # of the transcript.
        with Capturing() as print_output:
            ragproxyagent.initiate_chat(assistant, problem=qa_problem, n_results=10)
        # NOTE(review): relies on the answer being the third line from the end
        # of the captured transcript -- confirm against the chat output format.
        retrieved_answer = print_output[-3]
        # Index directly instead of re-slicing answers on every iteration.
        gold_answer = answers[idx]
    except Exception as e:
        # Best effort: report the failing question and move on.
        print(e)
        print("Error in problem: ", qa_problem)
    else:
        # Append only once all three values exist, so the lists stay aligned.
        retrieve_answers.append(retrieved_answer)
        questions_sample.append(qa_problem)
        answers_sample.append(gold_answer)


Progress 0.00%, Time Used 0.00 hours



	<Table> <Tr> <Th> Film </Th> <Th> Year </Th> <Th> Fuck count </Th> <Th> Minutes </Th> <Th> Uses / mi ...
	<Table> <Tr> <Th> Character </Th> <Th> Ultimate Avengers </Th> <Th> Ultimate Avengers 2 </Th> <Th> I ...
	<Table> <Tr> <Th> Position </Th> <Th> Country </Th> <Th> Town / City </Th> <Th> PM2. 5 </Th> <Th> PM ...
	<Table> <Tr> <Th> Rank </Th> <Th> Country ( or dependent territory ) </Th> <Th> Population </Th> <Th ...
	<Table> <Tr> <Th> Rank </Th> <Th> State </Th> <Th> Gross collections ( in thousands ) </Th> <Th> Rev ...
	<Table> <Tr> <Th> Date </Th> <Th> Province </Th> <Th> Mag . </Th> <Th> MMI </Th> <Th> Deaths </Th> < ...
	<Table> <Tr> <Th> City </Th> <Th> River </Th> <Th> State </Th> </Tr> <Tr> <Td> Gangakhed </Td> <Td>  ...
	<Table> <Tr> <Th> Player </Th> <Th> Pos . </Th> <Th> Team </Th> <Th> Career start </Th> <Th> Career  ...
	<Table> ABO and Rh blood type distribution by country ( population averages ) <Tr> <Th> Country </Th ...
	<Table> <Tr> <Th> </Th> <Th colspan="3"> Tota

Error code: 400 - {'error': {'message': "The response was filtered due to the prompt triggering Azure OpenAI's content management policy. Please modify your prompt and retry. To learn more about our content filtering policies please read our documentation: https://go.microsoft.com/fwlink/?linkid=2198766", 'type': None, 'param': 'prompt', 'code': 'content_filter', 'status': 400, 'innererror': {'code': 'ResponsibleAIPolicyViolation', 'content_filter_result': {'hate': {'filtered': False, 'severity': 'safe'}, 'self_harm': {'filtered': False, 'severity': 'safe'}, 'sexual': {'filtered': False, 'severity': 'safe'}, 'violence': {'filtered': True, 'severity': 'medium'}}}}}
Error in problem:  who played tom on as the world turns

Progress 1.43%, Time Used 0.24 hours

Error code: 400 - {'error': {'message': "The response was filtered due to the prompt triggering Azure OpenAI's content management policy. Please modify your prompt and retry. To learn more about our content filtering policies please

In [7]:
# Show the first few collected answers, then the size of each result list.
print(retrieve_answers[:5])
for label, collected in (
    ("len(retrieve_answers):", retrieve_answers),
    ("len(answers_sample):", answers_sample),
    ("len(questions_sample):", questions_sample),
):
    print(label, len(collected))

["Non controlling interest or minority interest is the portion of a subsidiary corporation's stock that is not owned by the parent corporation. It is shown as a separate line item on the balance sheet under equity.", 'The fourth season of Chicago Fire has 23 episodes.', 'The Eagles sing Love Will Keep Us Alive.', 'Patrick Walter Brown is the current leader of the Ontario PC Party.', 'The surname Keith has several origins. In some cases it is derived from Keith in East Lothian, Scotland. In other cases, the surname is originated from a nickname, derived from the Middle High German kīt, a word meaning "sprout", "offspring".']
len(retrieve_answers): 6761
len(answers_sample): 6761
len(questions_sample): 6761


In [8]:
# https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#F1
def normalize_text(s):
    """Canonicalise a string before comparing answers.

    The steps run in this order: lowercase the text, delete every character in
    ``string.punctuation``, replace the standalone words "a", "an" and "the"
    with a space, then collapse each run of whitespace to a single space and
    trim both ends.
    """
    import re
    import string

    lowered = s.lower()
    # Punctuation is deleted outright, not replaced by a space.
    punctuation_free = lowered.translate(str.maketrans("", "", string.punctuation))
    without_articles = re.sub(r"\b(a|an|the)\b", " ", punctuation_free, flags=re.UNICODE)
    return " ".join(without_articles.split())

def compute_exact_match(prediction, truth):
    """Return 1 if prediction and truth are equal after ``normalize_text``, else 0."""
    normalized_prediction = normalize_text(prediction)
    normalized_truth = normalize_text(truth)
    return 1 if normalized_prediction == normalized_truth else 0

def compute_f1_recall(prediction, truth):
    """Token-level F1 and recall of ``prediction`` measured against ``truth``.

    Both strings are passed through ``normalize_text`` and split on
    whitespace. The overlap is a multiset intersection, so a token that occurs
    several times in both strings counts once per shared occurrence. Counting
    the overlap with plain sets while dividing by the full token counts would
    under-score answers that repeat a token (two identical strings could score
    below 1).

    Returns:
        A ``(f1, recall)`` tuple. If either side has no tokens after
        normalisation, both values are 1 when the two token lists are equal
        and 0 otherwise. If no token is shared, both values are 0.
    """
    from collections import Counter

    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()

    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if not pred_tokens or not truth_tokens:
        agree = int(pred_tokens == truth_tokens)
        return agree, agree

    # Multiset intersection keeps min(count_in_pred, count_in_truth) per token.
    shared = Counter(pred_tokens) & Counter(truth_tokens)
    num_same = sum(shared.values())

    # if there are no common tokens then f1 = 0
    if num_same == 0:
        return 0, 0

    prec = num_same / len(pred_tokens)
    rec = num_same / len(truth_tokens)

    return 2 * (prec * rec) / (prec + rec), rec

def get_gold_answers(example):
    """Collect every non-empty answer text from a squad2.0-style example.

    Reads ``example.answers`` and keeps the ``"text"`` value of each entry
    whose text is truthy. When nothing is kept (a negative example), the only
    correct answer is the empty string, so ``[""]`` is returned.
    """
    gold_answers = []
    for answer in example.answers:
        if answer["text"]:
            gold_answers.append(answer["text"])

    return gold_answers or [""]

In [None]:
# Score every collected prediction against its gold answers. For each metric
# the best value over all gold answers is kept.
all_em_scores = []
all_f1_scores = []
all_recall_scores = []
for i in range(len(retrieve_answers)):
    prediction = retrieve_answers[i]
    gold_answers = answers_sample[i]

    # Compute (f1, recall) once per gold answer and reuse it for both metrics.
    f1_recall_pairs = [compute_f1_recall(prediction, answer) for answer in gold_answers]

    # default=0 scores a question with no gold answers as a miss instead of
    # letting max() raise on an empty sequence.
    em_score = max((compute_exact_match(prediction, answer) for answer in gold_answers), default=0)
    f1_score = max((pair[0] for pair in f1_recall_pairs), default=0)
    recall_score = max((pair[1] for pair in f1_recall_pairs), default=0)

    all_em_scores.append(em_score)
    all_f1_scores.append(f1_score)
    all_recall_scores.append(recall_score)

    # if i % 10 == 0 or recall_score < 0.3:
    print(f"Question: {questions_sample[i]}")
    print(f"Prediction: {prediction}")
    print(f"True Answers: {gold_answers}")
    print(f"EM: {em_score} \t F1: {f1_score} \t Recall: {recall_score}")

print("=======================================")
num_scored = len(all_em_scores)
if num_scored:
    print(f"Average EM: {sum(all_em_scores) / num_scored}")
    print(f"Average F1: {sum(all_f1_scores) / num_scored}")
    print(f"Average Recall: {sum(all_recall_scores) / num_scored}")
else:
    # Nothing was collected; avoid dividing by zero.
    print("No predictions to score.")

In [10]:
import sys

# Save the current sys.stdout for later restoration
original_stdout = sys.stdout

# Redirect sys.stdout to a file so the full chat transcript of every question
# ends up in logs-original-all.txt.
with open('logs-original-all.txt', 'w') as f:
    sys.stdout = f
    try:
        for qa_problem in questions[:7000]:
            # Marker line that separates one question's transcript from the next.
            print(f"\n\n>>>>>>>>>>>>>> case: {qa_problem} <<<<<<<<<<<<<<\n\n")
            assistant.reset()
            try:
                ragproxyagent.initiate_chat(assistant, problem=qa_problem, n_results=10)
            except Exception as e:
                # Record the failure in the log and continue with the next question.
                print(f"Exception: {e}")
    finally:
        # Restore sys.stdout to its original value before the file is closed,
        # even when the loop is interrupted (e.g. KeyboardInterrupt); otherwise
        # later prints would target a closed file.
        sys.stdout = original_stdout

In [1]:
# Analyse the chat transcript stored in "logs-original-all.txt" with the
# project-local analysis_log module.
# NOTE(review): analysis_log is not visible here; presumably main() parses the
# log and reports the EM / F1 / recall summaries -- confirm against the module.
from analysis_log import main
main("logs-original-all.txt", question_process=None)

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/lijiang1/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /home/lijiang1/nltk_data...
[nltk_data]   Package punkt is already up-to-date!



Analysis log file: logs-original-all.txt
Total Number of questions: 6775
len_lines=344802
_cnt_update_context=1597


Number of questions: 6761
Average EM: 0.007691169945274368
Average F1: 0.2725437837938005
Average Recall: 0.6939445335402984


Number of questions: 422
Average EM: 0.004739336492890996
Average F1: 0.10725425464574016
Average Recall: 0.26880411462104403


Number of questions: 6339
Average EM: 0.007887679444707366
Average F1: 0.28354744072714694
Average Recall: 0.7222469876787945
