In [None]:
# %pip install flaml[retrievechat]~=2.0.0rc5

## Set your API Endpoint

The [`config_list_from_json`](https://microsoft.github.io/FLAML/docs/reference/autogen/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file.


In [2]:
import os
os.environ["ALL_PROXY"] = ""

In [3]:
from flaml import autogen

config_list = autogen.config_list_from_json(
    env_or_file=".config.local",
    file_location=".",
    filter_dict={
        "model": {
            "gpt-4",
            "gpt4",
            "gpt-4-32k",
            "gpt-4-32k-0314",
            "gpt-35-turbo",
            "gpt-3.5-turbo",
        }
    },
)

assert len(config_list) > 0
config_list[0]['model'] = 'gpt-35-turbo'
print("models to use: ", [config_list[i]["model"] for i in range(len(config_list))])

models to use:  ['gpt-35-turbo']


## Construct agents for RetrieveChat

We start by initialzing the `RetrieveAssistantAgent` and `RetrieveUserProxyAgent`. The system message needs to be set to "You are a helpful assistant." for RetrieveAssistantAgent. The detailed instructions are given in the user message. Later we will use the `RetrieveUserProxyAgent.generate_init_prompt` to combine the instructions and a math problem for an initial prompt to be sent to the LLM assistant.

In [4]:
from flaml.autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
from flaml.autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
import chromadb

# 1. create an RetrieveAssistantAgent instance named "assistant"
assistant = RetrieveAssistantAgent(
    name="assistant", 
    system_message="You are a helpful assistant.",
    llm_config={
        "request_timeout": 600,
        "seed": 42,
        "config_list": config_list,
    },
)

# 2. create the RetrieveUserProxyAgent instance named "ragproxyagent"
corpus_file = "https://huggingface.co/datasets/thinkall/NaturalQuestionsQA/resolve/main/corpus.txt"

# Create a new collection for NaturalQuestions dataset
ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=10,
    retrieve_config={
        "task": "qa",
        "docs_path": corpus_file,
        "chunk_token_size": 2000,
        "model": config_list[0]["model"],
        "client": chromadb.PersistentClient(path="/tmp/chromadb"),
        "collection_name": "natural-questions",
        "chunk_mode": "one_line",
        "embedding_model": "all-MiniLM-L6-v2",
    },
)

### Natural Questions QA

Use RetrieveChat to answer questions for [NaturalQuestion](https://ai.google.com/research/NaturalQuestions) dataset.

We'll first create a new document collection based on all the context corpus, then we select some questions and answer them with RetrieveChat.


In [5]:
import json

queries_file = "https://huggingface.co/datasets/thinkall/NaturalQuestionsQA/resolve/main/queries.jsonl"
!wget -O /tmp/chromadb/queries.jsonl $queries_file
queries = [json.loads(line) for line in open("/tmp/chromadb/queries.jsonl").readlines() if line]
questions = [q["text"] for q in queries]
answers = [q["metadata"]["answer"] for q in queries]
print(questions[:5])
print(answers[:5])
print("Number of questions:", len(questions))

--2023-08-25 12:35:08--  https://huggingface.co/datasets/thinkall/NaturalQuestionsQA/resolve/main/queries.jsonl
Resolving huggingface.co (huggingface.co)... 143.204.126.33, 143.204.126.6, 143.204.126.36, ...
Connecting to huggingface.co (huggingface.co)|143.204.126.33|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1380571 (1.3M) [text/plain]
Saving to: ‘/tmp/chromadb/queries.jsonl’


2023-08-25 12:35:11 (656 KB/s) - ‘/tmp/chromadb/queries.jsonl’ saved [1380571/1380571]

['what is non controlling interest on balance sheet', 'how many episodes are in chicago fire season 4', 'who sings love will keep us alive by the eagles', 'who is the leader of the ontario pc party', 'where did the last name keith come from']
[["the portion of a subsidiary corporation 's stock that is not owned by the parent corporation"], ['23'], ['Timothy B. Schmit'], ['Patrick Walter Brown'], ['from Keith in East Lothian , Scotland', "from a nickname , derived from the Middle High German kī

In [6]:
from io import StringIO 
import sys

class Capturing(list):
    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self
    def __exit__(self, *args):
        self.extend(self._stringio.getvalue().splitlines())
        del self._stringio    # free up some memory
        sys.stdout = self._stdout

In [13]:
import time

retrieve_answers = []
questions_sample = []
answers_sample = []
num_questions = 7000
st = time.time()
for idx, qa_problem in enumerate(questions[:num_questions]):
    if idx % 100 == 0:
        ct = time.time()
        print(f"\nProgress {idx/num_questions*100:.2f}%, Time Used {(ct-st)/3600:.2f} hours\n")
    assistant.reset()
    try:
        with Capturing() as print_output:
            ragproxyagent.initiate_chat(assistant, problem=qa_problem, n_results=30)
        retrieve_answers.append(print_output[-3])
        questions_sample.append(qa_problem)
        answers_sample.append(answers[:num_questions][idx])
    except Exception as e:
        print(e)
        print("Error in problem: ", qa_problem)


Progress 0.00%, Time Used 0.00 hours

The response was filtered due to the prompt triggering Azure OpenAI’s content management policy. Please modify your prompt and retry. To learn more about our content filtering policies please read our documentation: https://go.microsoft.com/fwlink/?linkid=2198766
Error in problem:  who wrote the theme song for mission impossible

Progress 1.43%, Time Used 0.00 hours

The response was filtered due to the prompt triggering Azure OpenAI’s content management policy. Please modify your prompt and retry. To learn more about our content filtering policies please read our documentation: https://go.microsoft.com/fwlink/?linkid=2198766
Error in problem:  who wrote the theme song for mission impossible

Progress 2.86%, Time Used 0.00 hours

The response was filtered due to the prompt triggering Azure OpenAI’s content management policy. Please modify your prompt and retry. To learn more about our content filtering policies please read our documentation: https

In [14]:
print(retrieve_answers[:5])
print("len(retrieve_answers):", len(retrieve_answers))
print("len(answers_sample):", len(answers_sample))
print("len(questions_sample):", len(questions_sample))

["Non controlling interest on balance sheet refers to the portion of a subsidiary corporation's stock that is not owned by the parent corporation. It is generally less than 50% of outstanding shares and shown as part of equity on the balance sheet.", 'There are 23 episodes in Chicago Fire season 4.', 'The Eagles sing "Love Will Keep Us Alive".', 'Patrick Walter Brown is the leader of the Ontario PC Party.', 'The surname Keith has several origins, including being derived from Keith in East Lothian, Scotland, and from a Middle High German word meaning "sprout" or "offspring".']
len(retrieve_answers): 6661
len(answers_sample): 6661
len(questions_sample): 6661


In [15]:
# https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#F1
def normalize_text(s):
    """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
    import string, re

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def compute_exact_match(prediction, truth):
    return int(normalize_text(prediction) == normalize_text(truth))

def compute_f1_recall(prediction, truth):
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()
    
    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens), int(pred_tokens == truth_tokens)
    
    common_tokens = set(pred_tokens) & set(truth_tokens)
    
    # if there are no common tokens then f1 = 0
    if len(common_tokens) == 0:
        return 0, 0
    
    prec = len(common_tokens) / len(pred_tokens)
    rec = len(common_tokens) / len(truth_tokens)
    
    return 2 * (prec * rec) / (prec + rec), rec

def get_gold_answers(example):
    """helper function that retrieves all possible true answers from a squad2.0 example"""
    
    gold_answers = [answer["text"] for answer in example.answers if answer["text"]]

    # if gold_answers doesn't exist it's because this is a negative example - 
    # the only correct answer is an empty string
    if not gold_answers:
        gold_answers = [""]
        
    return gold_answers

In [16]:
all_em_scores = []
all_f1_scores = []
all_recall_scores = []
for i in range(len(retrieve_answers)):
    prediction = retrieve_answers[i]
    gold_answers = answers_sample[i]

    em_score = max((compute_exact_match(prediction, answer)) for answer in gold_answers)
    f1_score = max((compute_f1_recall(prediction, answer)[0]) for answer in gold_answers)
    recall_score = max((compute_f1_recall(prediction, answer)[1]) for answer in gold_answers)

    all_em_scores.append(em_score)
    all_f1_scores.append(f1_score)
    all_recall_scores.append(recall_score)

    # if i % 10 == 0 or recall_score < 0.3:
    print(f"Question: {questions_sample[i]}")
    print(f"Prediction: {prediction}")
    print(f"True Answers: {gold_answers}")
    print(f"EM: {em_score} \t F1: {f1_score} \t Recall: {recall_score}")

print("=======================================")
print(f"Average EM: {sum(all_em_scores) / len(all_em_scores)}")
print(f"Average F1: {sum(all_f1_scores) / len(all_f1_scores)}")
print(f"Average Recall: {sum(all_recall_scores) / len(all_recall_scores)}")

Question: what is non controlling interest on balance sheet
Prediction: Non controlling interest on balance sheet refers to the portion of a subsidiary corporation's stock that is not owned by the parent corporation. It is generally less than 50% of outstanding shares and shown as part of equity on the balance sheet.
True Answers: ["the portion of a subsidiary corporation 's stock that is not owned by the parent corporation"]
EM: 0 	 F1: 0.4313725490196079 	 Recall: 0.8461538461538461
Question: how many episodes are in chicago fire season 4
Prediction: There are 23 episodes in Chicago Fire season 4.
True Answers: ['23']
EM: 0 	 F1: 0.19999999999999998 	 Recall: 1.0
Question: who sings love will keep us alive by the eagles
Prediction: The Eagles sing "Love Will Keep Us Alive".
True Answers: ['Timothy B. Schmit']
EM: 0 	 F1: 0 	 Recall: 0
Question: who is the leader of the ontario pc party
Prediction: Patrick Walter Brown is the leader of the Ontario PC Party.
True Answers: ['Patrick Wal

In [17]:
for qa_problem in questions[:500]:
    print(f"\n\n>>>>>>>>>>>>>> case: {qa_problem} <<<<<<<<<<<<<<\n\n")
    assistant.reset()
    try:
        ragproxyagent.initiate_chat(assistant, problem=qa_problem, n_results=30)
    except Exception as e:
        print(f"Exception: {e}")



>>>>>>>>>>>>>> case: what is non controlling interest on balance sheet <<<<<<<<<<<<<<


doc_ids:  [['doc_0', 'doc_3334', 'doc_720', 'doc_2732', 'doc_2510', 'doc_5084', 'doc_5068', 'doc_3727', 'doc_1938', 'doc_4689', 'doc_5249', 'doc_1751', 'doc_480', 'doc_3989', 'doc_2115', 'doc_1233', 'doc_2264', 'doc_633', 'doc_2293', 'doc_5274', 'doc_5213', 'doc_3991', 'doc_2880', 'doc_2737', 'doc_1257', 'doc_1748', 'doc_2038', 'doc_4073', 'doc_2876', 'doc_3480']]
[32mAdding doc_id doc_0 to context.[0m
[32mAdding doc_id doc_3334 to context.[0m
[32mAdding doc_id doc_720 to context.[0m
[32mAdding doc_id doc_2732 to context.[0m
[32mAdding doc_id doc_2510 to context.[0m
[32mAdding doc_id doc_5084 to context.[0m
[32mAdding doc_id doc_5068 to context.[0m
[32mAdding doc_id doc_3727 to context.[0m
[32mAdding doc_id doc_1938 to context.[0m
[32mAdding doc_id doc_4689 to context.[0m
[32mAdding doc_id doc_5249 to context.[0m
[32mAdding doc_id doc_1751 to context.[0m
[32mAdding doc_id 