In [1]:
# %pip install "pyautogen[retrievechat]~=0.2.0b5"

## Set your API Endpoint

The [`config_list_from_json`](https://microsoft.github.io/FLAML/docs/reference/autogen/oai/openai_utils#config_list_from_json) function loads a list of configurations from an environment variable or a json file.


In [2]:
import autogen

config_list = autogen.config_list_from_json(
    env_or_file=".config.local",
    file_location=".",
    filter_dict={
        "model": {
            "gpt-4",
            "gpt4",
            "gpt-4-32k",
            "gpt-4-32k-0314",
            "gpt-35-turbo",
            "gpt-3.5-turbo",
        }
    },
)

assert len(config_list) > 0
config_list[0]['model'] = 'gpt-35-turbo'
print("models to use: ", [config_list[i]["model"] for i in range(len(config_list))])

  from .autonotebook import tqdm as notebook_tqdm


models to use:  ['gpt-35-turbo']


## Construct agents for RetrieveChat

We start by initialzing the `RetrieveAssistantAgent` and `RetrieveUserProxyAgent`. The system message needs to be set to "You are a helpful assistant." for RetrieveAssistantAgent. The detailed instructions are given in the user message. Later we will use the `RetrieveUserProxyAgent.generate_init_prompt` to combine the instructions and a math problem for an initial prompt to be sent to the LLM assistant.

In [1]:
import nltk
from nltk.corpus import words
import random

nltk.download('words')
random.seed(42)
def generate_random_word():
    # Get the list of English words from the nltk corpus
    word_list = words.words()

    # Choose a random word from the list
    random_word = random.choice(word_list)

    return random_word


def random_variation(question):
    # Split the question into words
    words = question.split()

    # Randomly select a word to modify
    index_to_modify = random.randint(0, len(words) - 1)

    # Generate a random modification (you can customize this part)
    modification_options = ['add', 'remove', 'replace']
    modification = random.choice(modification_options)

    if modification == 'add':
        # Add a random word
        new_word = generate_random_word()
        words.insert(index_to_modify, new_word)
    elif modification == 'remove':
        # Remove the selected word
        del words[index_to_modify]
    elif modification == 'replace':
        # Replace the selected word with a random word
        new_word = generate_random_word()
        words[index_to_modify] = new_word

    # Join the modified words back into a question
    modified_question = ' '.join(words)

    return modified_question

[nltk_data] Downloading package words to /home/lijiang1/nltk_data...
[nltk_data]   Package words is already up-to-date!


In [4]:
from autogen.agentchat.contrib.retrieve_assistant_agent import RetrieveAssistantAgent
from autogen.agentchat.contrib.retrieve_user_proxy_agent import RetrieveUserProxyAgent
from autogen.retrieve_utils import create_vector_db_from_dir, query_vector_db
import chromadb

# # Use this class to only retrieve docs with random variations
# class MixRetrieveUserProxyAgent(RetrieveUserProxyAgent):
#     def retrieve_docs(self, problem: str, n_results: int = 20, search_string: str = ""):
#         problem = random_variation(problem)
#         if not self._collection or not self._get_or_create:
#             print("Trying to create collection.")
#             self._client = create_vector_db_from_dir(
#                 dir_path=self._docs_path,
#                 max_tokens=self._chunk_token_size,
#                 client=self._client,
#                 collection_name=self._collection_name,
#                 chunk_mode=self._chunk_mode,
#                 must_break_at_empty_line=self._must_break_at_empty_line,
#                 embedding_model=self._embedding_model,
#                 get_or_create=self._get_or_create,
#                 embedding_function=self._embedding_function,
#                 custom_text_split_function=self.custom_text_split_function,
#                 custom_text_types=self._custom_text_types,
#                 recursive=self._recursive,
#             )
#             self._collection = True
#             self._get_or_create = True

#         results = query_vector_db(
#             query_texts=[problem],
#             n_results=n_results,
#             search_string=search_string,
#             client=self._client,
#             collection_name=self._collection_name,
#             embedding_model=self._embedding_model,
#             embedding_function=self._embedding_function,
#         )
#         self._search_string = search_string
#         self._results = results
#         print("doc_ids: ", results["ids"])


# 1. create an RetrieveAssistantAgent instance named "assistant"
assistant = RetrieveAssistantAgent(
    name="assistant", 
    system_message="You are a helpful assistant.",
    llm_config={
        "timeout": 60,
        "cache_seed": 42,
        "config_list": config_list,
    },
)

# 2. create the RetrieveUserProxyAgent instance named "ragproxyagent"
corpus_file = "https://huggingface.co/datasets/thinkall/NaturalQuestionsQA/resolve/main/corpus.txt"

# Create a new collection for NaturalQuestions dataset
ragproxyagent = RetrieveUserProxyAgent(
    name="ragproxyagent",
    human_input_mode="NEVER",
    max_consecutive_auto_reply=10,
    retrieve_config={
        "task": "qa",
        "docs_path": corpus_file,
        "chunk_token_size": 2000,
        "model": config_list[0]["model"],
        "client": chromadb.PersistentClient(path="/tmp/chromadb"),
        "collection_name": "natural-questions",
        "chunk_mode": "one_line",
        "embedding_model": "all-MiniLM-L6-v2",
        "get_or_create": True,
    },
)

  return torch._C._cuda_getDeviceCount() > 0


### Natural Questions QA

Use RetrieveChat to answer questions for [NaturalQuestion](https://ai.google.com/research/NaturalQuestions) dataset.

We'll first create a new document collection based on all the context corpus, then we select some questions and answer them with RetrieveChat.


In [5]:
import json

queries_file = "https://huggingface.co/datasets/thinkall/NaturalQuestionsQA/resolve/main/queries.jsonl"
!wget -O /tmp/chromadb/queries.jsonl $queries_file
queries = [json.loads(line) for line in open("/tmp/chromadb/queries.jsonl").readlines() if line]
questions = [random_variation(q["text"]) for q in queries]
answers = [q["metadata"]["answer"] for q in queries]
print(questions[:5])
print(answers[:5])
print("Number of questions:", len(questions))

--2023-11-21 18:16:34--  https://huggingface.co/datasets/thinkall/NaturalQuestionsQA/resolve/main/queries.jsonl
Resolving huggingface.co (huggingface.co)... 99.84.108.87, 99.84.108.70, 99.84.108.129, ...
Connecting to huggingface.co (huggingface.co)|99.84.108.87|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1380571 (1.3M) [text/plain]
Saving to: ‘/tmp/chromadb/queries.jsonl’


2023-11-21 18:16:34 (36.4 MB/s) - ‘/tmp/chromadb/queries.jsonl’ saved [1380571/1380571]

['what superelevated is non controlling interest on balance sheet', 'how many episodes are duodenary in chicago fire season 4', 'who sings bruckleness will keep us alive by the eagles', 'who is the leader of the ontario pc presagefully party', 'where did the last name keith Anaptomorphidae come from']
[["the portion of a subsidiary corporation 's stock that is not owned by the parent corporation"], ['23'], ['Timothy B. Schmit'], ['Patrick Walter Brown'], ['from Keith in East Lothian , Scotland', "f

In [6]:
from io import StringIO 
import sys

class Capturing(list):
    def __enter__(self):
        self._stdout = sys.stdout
        sys.stdout = self._stringio = StringIO()
        return self
    def __exit__(self, *args):
        self.extend(self._stringio.getvalue().splitlines())
        del self._stringio    # free up some memory
        sys.stdout = self._stdout

In [7]:
import time

retrieve_answers = []
questions_sample = []
answers_sample = []
num_questions = 100
st = time.time()
for idx, qa_problem in enumerate(questions[:num_questions]):
    if idx % 100 == 0:
        ct = time.time()
        print(f"\nProgress {idx/num_questions*100:.2f}%, Time Used {(ct-st)/3600:.2f} hours\n")
    assistant.reset()
    try:
        with Capturing() as print_output:
            ragproxyagent.initiate_chat(assistant, problem=qa_problem, n_results=10)
        retrieve_answers.append(print_output[-3])
        questions_sample.append(qa_problem)
        answers_sample.append(answers[:num_questions][idx])
    except Exception as e:
        print(e)
        print("Error in problem: ", qa_problem)


Progress 0.00%, Time Used 0.00 hours



	<Table> <Tr> <Th> Film </Th> <Th> Year </Th> <Th> Fuck count </Th> <Th> Minutes </Th> <Th> Uses / mi ...
	<Table> <Tr> <Th> Character </Th> <Th> Ultimate Avengers </Th> <Th> Ultimate Avengers 2 </Th> <Th> I ...
	<Table> <Tr> <Th> Position </Th> <Th> Country </Th> <Th> Town / City </Th> <Th> PM2. 5 </Th> <Th> PM ...
	<Table> <Tr> <Th> Rank </Th> <Th> Country ( or dependent territory ) </Th> <Th> Population </Th> <Th ...
	<Table> <Tr> <Th> Rank </Th> <Th> State </Th> <Th> Gross collections ( in thousands ) </Th> <Th> Rev ...
	<Table> <Tr> <Th> Date </Th> <Th> Province </Th> <Th> Mag . </Th> <Th> MMI </Th> <Th> Deaths </Th> < ...
	<Table> <Tr> <Th> City </Th> <Th> River </Th> <Th> State </Th> </Tr> <Tr> <Td> Gangakhed </Td> <Td>  ...
	<Table> <Tr> <Th> Player </Th> <Th> Pos . </Th> <Th> Team </Th> <Th> Career start </Th> <Th> Career  ...
	<Table> ABO and Rh blood type distribution by country ( population averages ) <Tr> <Th> Country </Th ...
	<Table> <Tr> <Th> </Th> <Th colspan="3"> Tota

In [8]:
print(retrieve_answers[:5])
print("len(retrieve_answers):", len(retrieve_answers))
print("len(answers_sample):", len(answers_sample))
print("len(questions_sample):", len(questions_sample))

["Non-controlling interest, also known as minority interest, is the portion of a subsidiary corporation's stock that is not owned by the parent corporation. It is shown on a company's balance sheet as part of equity.", 'There are 23 episodes in Chicago Fire season 4.', 'The Eagles sing "Love Will Keep Us Alive".', "The leader of the Ontario PC Party and Ontario's Leader of the Official Opposition is Patrick Walter Brown.", 'The last name Keith has several origins. In some cases, it is derived from Keith in East Lothian, Scotland. In other cases, the surname is originated from a nickname, derived from the Middle High German kīt, a word meaning "sprout", "offspring".']
len(retrieve_answers): 100
len(answers_sample): 100
len(questions_sample): 100


In [9]:
# https://qa.fastforwardlabs.com/no%20answer/null%20threshold/bert/distilbert/exact%20match/f1/robust%20predictions/2020/06/09/Evaluating_BERT_on_SQuAD.html#F1
def normalize_text(s):
    """Removing articles and punctuation, and standardizing whitespace are all typical text processing steps."""
    import string, re

    def remove_articles(text):
        regex = re.compile(r"\b(a|an|the)\b", re.UNICODE)
        return re.sub(regex, " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))

def compute_exact_match(prediction, truth):
    return int(normalize_text(prediction) == normalize_text(truth))

def compute_f1_recall(prediction, truth):
    pred_tokens = normalize_text(prediction).split()
    truth_tokens = normalize_text(truth).split()
    
    # if either the prediction or the truth is no-answer then f1 = 1 if they agree, 0 otherwise
    if len(pred_tokens) == 0 or len(truth_tokens) == 0:
        return int(pred_tokens == truth_tokens), int(pred_tokens == truth_tokens)
    
    common_tokens = set(pred_tokens) & set(truth_tokens)
    
    # if there are no common tokens then f1 = 0
    if len(common_tokens) == 0:
        return 0, 0
    
    prec = len(common_tokens) / len(pred_tokens)
    rec = len(common_tokens) / len(truth_tokens)
    
    return 2 * (prec * rec) / (prec + rec), rec

def get_gold_answers(example):
    """helper function that retrieves all possible true answers from a squad2.0 example"""
    
    gold_answers = [answer["text"] for answer in example.answers if answer["text"]]

    # if gold_answers doesn't exist it's because this is a negative example - 
    # the only correct answer is an empty string
    if not gold_answers:
        gold_answers = [""]
        
    return gold_answers

In [10]:
all_em_scores = []
all_f1_scores = []
all_recall_scores = []
for i in range(len(retrieve_answers)):
    prediction = retrieve_answers[i]
    gold_answers = answers_sample[i]

    em_score = max((compute_exact_match(prediction, answer)) for answer in gold_answers)
    f1_score = max((compute_f1_recall(prediction, answer)[0]) for answer in gold_answers)
    recall_score = max((compute_f1_recall(prediction, answer)[1]) for answer in gold_answers)

    all_em_scores.append(em_score)
    all_f1_scores.append(f1_score)
    all_recall_scores.append(recall_score)

    # if i % 10 == 0 or recall_score < 0.3:
    print(f"Question: {questions_sample[i]}")
    print(f"Prediction: {prediction}")
    print(f"True Answers: {gold_answers}")
    print(f"EM: {em_score} \t F1: {f1_score} \t Recall: {recall_score}")

print("=======================================")
print(f"Average EM: {sum(all_em_scores) / len(all_em_scores)}")
print(f"Average F1: {sum(all_f1_scores) / len(all_f1_scores)}")
print(f"Average Recall: {sum(all_recall_scores) / len(all_recall_scores)}")

Question: what superelevated is non controlling interest on balance sheet
Prediction: Non-controlling interest, also known as minority interest, is the portion of a subsidiary corporation's stock that is not owned by the parent corporation. It is shown on a company's balance sheet as part of equity.
True Answers: ["the portion of a subsidiary corporation 's stock that is not owned by the parent corporation"]
EM: 0 	 F1: 0.5 	 Recall: 0.8461538461538461
Question: how many episodes are duodenary in chicago fire season 4
Prediction: There are 23 episodes in Chicago Fire season 4.
True Answers: ['23']
EM: 0 	 F1: 0.19999999999999998 	 Recall: 1.0
Question: who sings bruckleness will keep us alive by the eagles
Prediction: The Eagles sing "Love Will Keep Us Alive".
True Answers: ['Timothy B. Schmit']
EM: 0 	 F1: 0 	 Recall: 0
Question: who is the leader of the ontario pc presagefully party
Prediction: The leader of the Ontario PC Party and Ontario's Leader of the Official Opposition is Patr

In [11]:
import sys

# Save the current sys.stdout for later restoration
original_stdout = sys.stdout

# Redirect sys.stdout to a file
with open('logs-mixed-100.txt', 'w') as f:
    sys.stdout = f
    
    for qa_problem in questions[:100]:
        print(f"\n\n>>>>>>>>>>>>>> case: {qa_problem} <<<<<<<<<<<<<<\n\n")
        assistant.reset()
        try:
            ragproxyagent.initiate_chat(assistant, problem=qa_problem, n_results=10)
        except Exception as e:
            print(f"Exception: {e}")

# Restore sys.stdout to its original value
sys.stdout = original_stdout

In [2]:
# Need to restart the kernel to run the following code, and also rerun the random_variation function to reset the seed
from analysis_log import main
main("logs-mixed-100.txt", question_process=random_variation)

[nltk_data] Downloading package stopwords to
[nltk_data]     /home/lijiang1/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!
[nltk_data] Downloading package punkt to /home/lijiang1/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package words to /home/lijiang1/nltk_data...
[nltk_data]   Package words is already up-to-date!



Analysis log file: logs-mixed-100.txt


Total Number of questions: 6775
len_lines=7896
_cnt_update_context=117


Number of questions: 100
Average EM: 0.0
Average F1: 0.19728938677911745
Average Recall: 0.5331193073843101


Number of questions: 20
Average EM: 0.0
Average F1: 0.0628114478114478
Average Recall: 0.12617559523809524


Number of questions: 80
Average EM: 0.0
Average F1: 0.23090887152103493
Average Recall: 0.6348552354208639
