In [None]:
import os
import json
import requests
import pandas as pd
from tqdm import tqdm
import random
import hashlib
import numpy as np
from pydantic import BaseModel
from typing import Optional, Literal, List

class QuestionSetSpecs(BaseModel):
	"""
	Configuration for one evaluation run: which QA dataset to sample, which models to
	use, API credentials, and which retrieval settings to test.
	"""
	set: Literal["hotpotqa", "triviaqa"]  # Source dataset.
	subset: Optional[Literal["easy", "medium", "hard"]] = None  # HotpotQA `level` filter; None keeps every level.
	samples: int  # Number of questions evaluated; the loader draws a 2x candidate pool.
	batch_size: int  # NOTE(review): not read in the visible cells — presumably used by the evaluation loop.
	seed: int  # Seed for the dataset shuffle before sampling.
	model: str  # Model under test for self-guided search.
	gt_answer_model: str # Needed for triviaqa, which has no GT answers.
	evaluator_model: str # For grading the model's answers

	# API credentials: local QueryLake server key plus external provider keys.
	ql_api_key: str
	openai_api_key: str
	deep_infra_api_key: str

	# Providing an int `n` here means we will test cases
	# 1, 2, 3, ..., n, plus the unbounded case.
	# Providing `None` means we will only test the unbounded case.
	max_search_depth: Optional[int] = None
	rerank_n: Optional[int] = None  # Rerank candidate count — presumably forwarded as `rerank_n`; verify at call site.
	include_chunk_metadata: Optional[bool] = False
	unbounded_cases: Optional[List[str]] = None  # e.g. ["BM25", "HS", "HS_RR"]; None counts as one unbounded case.

	holdout: Optional[str] = None  # Path to a JSON list of question hashes (see hash_string) to exclude.
 


# Run configuration. Uncomment the alternatives to switch dataset or model.
DATASET_SPECS = QuestionSetSpecs(
    # set="hotpotqa",
    set="triviaqa",
    # subset="hard",
    samples=800,
    batch_size=3,
    seed=43,
    # model="deepinfra/meta-llama/Meta-Llama-3.1-8B-Instruct",
    # model="openai/gpt-4o",
    # model="openai/gpt-4o-mini",
    # model="deepinfra/meta-llama/Llama-3.3-70B-Instruct",
    model="llama-3.1-8b-instruct-sgs-tuned",
    gt_answer_model="openai/gpt-4o-mini",
    evaluator_model="openai/gpt-4o-mini",
    max_search_depth=5,
    rerank_n=100,
    include_chunk_metadata=True,
    unbounded_cases=["BM25", "HS", "HS_RR"],

    # unbounded_cases=["BM25"],

    holdout="finetune_train_set_question_hashes.json",

    # Credentials are read from the environment so they never end up in the notebook.
    # QL_API_KEY is the local QueryLake deployment key. The key that used to be
    # hard-coded here should be treated as leaked and rotated.
    ql_api_key=os.environ.get("QL_API_KEY", ""),
    openai_api_key=os.environ.get("OPENAI_API_KEY", ""),
    deep_infra_api_key=os.environ.get("DEEPINFRA_API_KEY", ""),
)


# QL Document Collection IDs for RAG
TRIVIAQA_ALL_ARTICLES = "J0N8YoPAykCmSG5Q5iHcWKCMvR9I6QUN"
HOTPOTQA_ALL_ARTICLES = "OKT6X298qw2H2chedi4Y12Nnhk4u1RXd"

# Retrieval searches the article collection that matches the selected dataset.
TARGET_COLLECTIONS = [HOTPOTQA_ALL_ARTICLES] if DATASET_SPECS.set == "hotpotqa" else [TRIVIAQA_ALL_ARTICLES]

QUESTION_SELECTIONS = [] # Each entry must be dict with `question`, `answer`, and `pages` keys

def hash_string(input_string: str) -> str:
    """
    Return the lowercase hex SHA-256 digest of `input_string` (UTF-8 encoded).

    Used as a stable identifier for questions (holdout matching, answer caching)
    and for chaining the dataset fingerprint.
    """
    digest = hashlib.sha256(input_string.encode('utf-8'))
    return digest.hexdigest()

# Change into the notebook's directory so the relative data/cache paths below resolve.
# Under IPython/Jupyter, `_dh` (directory history) exists in the user namespace and its
# first entry is used. Otherwise (plain Python) fall back to this script's own directory.
# Only the expected lookup/chdir failures are caught; a bare `except:` would also swallow
# KeyboardInterrupt and SystemExit.
try:
    os.chdir(globals()['_dh'][0])
except (KeyError, IndexError, TypeError, OSError):
    os.chdir(os.path.dirname(os.path.abspath(__file__)))


OUTPUT_FILE = f"outputs_{'hotpot' if DATASET_SPECS.set == 'hotpotqa' else 'trivia'}.json"


def _load_json_cache(path: str):
    """
    Load a JSON cache file, first creating it as an empty object (`{}`) if it is missing.

    Args:
        path: file path, relative to the current working directory.

    Returns:
        The decoded JSON content (a dict for the caches used in this notebook).
    """
    if not os.path.exists(path):
        with open(path, "w") as f:
            json.dump({}, f, indent=4)
    with open(path, "r") as f:
        return json.load(f)


# Reference answers keyed by question hash (TriviaQA answers are generated, see below).
CORRECT_ANSWERS_LOOKUP_TRIVIA = _load_json_cache("correct_answers_lookup.json")
CORRECT_ANSWERS_LOOKUP_HOTPOT = _load_json_cache("correct_answers_lookup_hotpot.json")
# Accumulated evaluation results for the selected dataset.
EVAL_RESULTS = _load_json_cache(OUTPUT_FILE)

current_dir = os.getcwd()
# This cell goes two directory levels up. The dataset-loading cell below recomputes
# `parent_dir` / `dir_over` only one level up and overrides these values.
parent_dir = os.path.dirname(os.path.dirname(current_dir))
dir_over = os.path.join(parent_dir, 'hotpotqa')

print(current_dir, parent_dir, dir_over, sep="\n")

# Upper bound on evaluated questions, mirroring the configured sample count.
MAX_EVALUATIONS = DATASET_SPECS.samples


In [None]:
current_dir = os.getcwd()
parent_dir = os.path.dirname(current_dir)

# Data and prompt directories live next to the notebook's own folder.
dir_over = hotpot_dir = os.path.join(parent_dir, 'hotpotqa')
triviaqa_dir = os.path.join(parent_dir, 'triviaqa_wikipedia')
geval_prompt_dir = os.path.join(parent_dir, 'geval_prompts')

print(current_dir, parent_dir, dir_over, sep="\n")


GEVAL_PROMPTS = {}

# One G-Eval prompt template per scored dimension, read from `geval_prompt_dir`.
for prompt_id, filename in (
    ("coherence", "coh_detailed.txt"),
    ("consistency", "con_detailed.txt"),
    ("fluency", "flu_detailed.txt"),
    ("relevance", "rel_detailed.txt"),
):
    with open(os.path.join(geval_prompt_dir, filename), "r") as prompt_file:
        GEVAL_PROMPTS[prompt_id] = prompt_file.read()

# Hashes (see hash_string) of questions to exclude from evaluation, loaded from a JSON
# list when a holdout file is configured (e.g. questions used for fine-tuning).
if DATASET_SPECS.holdout is None:
    HOLDOUT_SET = set()
else:
    with open(DATASET_SPECS.holdout, "r") as holdout_file:
        HOLDOUT_SET = set(json.load(holdout_file))

# Number of sampled questions skipped because their hash is in HOLDOUT_SET.
holdout_count = 0

# `dataset_hash` is a running SHA-256 chain (seeded with "abc") over every accepted
# question dict, serving as a fingerprint of the exact evaluation set.
evaluation_questions, dataset_hash = [], "abc"

if DATASET_SPECS.set == "hotpotqa":
	# Massive face-palm, we were using the wrong dataset.
	# NOTE(review): the commented-out load below is identical to the active one — confirm which file was intended.
	# with open(os.path.join(hotpot_dir, 'hotpot_train_v1.1.json'), 'r') as f:
	# 	HOTPOT_ALL = json.load(f)
	# 	f.close()

	with open(os.path.join(hotpot_dir, 'hotpot_train_v1.1.json'), 'r') as f:
		HOTPOT_ALL = json.load(f)
		f.close()

	# Seeded in-place shuffle so the sample is reproducible for a given `seed`.
	np.random.seed(DATASET_SPECS.seed)
	np.random.shuffle(HOTPOT_ALL)

	# Collect the distinct `level` values for reference (printed only).
	all_levels = set()
	for q in HOTPOT_ALL:
		all_levels.add(q["level"])

	print("All categories for hotpotqa:", list(all_levels))

	# Optional difficulty filter: `subset` is matched against each question's `level`.
	if DATASET_SPECS.subset:
		HOTPOT_ALL = [x for x in HOTPOT_ALL if x["level"] == DATASET_SPECS.subset]

	# Keep a 2x candidate pool. Holdout removal below (and CANNOT_ANSWER filtering
	# later) can shrink it before it is cut down to `samples`.
	HOTPOT_ALL = HOTPOT_ALL[:DATASET_SPECS.samples*2]


	for question in tqdm(HOTPOT_ALL):
		question_hash = hash_string(question["question"])

		if question_hash in HOLDOUT_SET:
			holdout_count += 1
			continue


		evaluation_questions.append({
			"question": question["question"],
			"answer": question["answer"],
			# Each `context` entry is indexed as [title, list of paragraphs].
			"pages": [{
				"title": e[0],
				"content": "\n".join(e[1])
			} for e in question["context"]]
		})

		# Chain the fingerprint with the JSON of the entry just added.
		dataset_hash = hash_string(f"{dataset_hash} - {json.dumps(evaluation_questions[-1])}")

		# evaluation_questions.append(question)
elif DATASET_SPECS.set == "triviaqa":
	# TriviaQA is read from 7 parquet shards in `triviaqa_dir`.
	parquet_file_paths = [
		f'trivia_qa_wiki_{i}.parquet'
		for i in range(1, 8)
		# for i in range(1, 2)
	]

	total_scanned = 0  # NOTE(review): never updated or read in this cell.
	ALL_TRIVIAQA = []
	evaluation_questions = []
	for parquet_file_path in tqdm(parquet_file_paths):
		# Use pandas to read the parquet file
		df = pd.read_parquet(os.path.join(triviaqa_dir, parquet_file_path))

		# Display the first few rows of the dataframe
		# print(df.head())

		# Flatten every row into a plain dict.
		for index, row in df.iterrows():
			ALL_TRIVIAQA.append(row.to_dict())


	# Same seeded shuffle and 2x candidate pool as the HotpotQA branch.
	np.random.seed(DATASET_SPECS.seed)
	np.random.shuffle(ALL_TRIVIAQA)
	ALL_TRIVIAQA = ALL_TRIVIAQA[:DATASET_SPECS.samples*2]
	for row in ALL_TRIVIAQA:
		question_hash = hash_string(row["question"])

		if question_hash in HOLDOUT_SET:
			holdout_count += 1
			continue


		# No "answer" key here: reference answers are generated later from the pages.
		evaluation_questions.append({
			"question": row["question"],
			# "answer": row["answer"],
			# The SHA-256 of the page text stands in for a page title.
			"pages": [{
				"title": hash_string(e),
				"content": e
			} for e in row["entity_pages"]["wiki_context"].tolist()]
		})

		dataset_hash = hash_string(f"{dataset_hash} - {json.dumps(evaluation_questions[-1])}")

else:
    raise ValueError("Invalid dataset set")

print(f"Dataset Hash: {dataset_hash}")

# if "dataset_integrity" in EVAL_RESULTS and EVAL_RESULTS["dataset_integrity"] != dataset_hash:
# 	raise ValueError(f"Dataset integrity check failed; hash mismatch: {EVAL_RESULTS['dataset_integrity']} != {dataset_hash}")

# NOTE(review): with the check above disabled, a changed sample silently overwrites the stored hash.
EVAL_RESULTS["dataset_integrity"] = dataset_hash

print(f"Total holdouts: {holdout_count}")

In [3]:
import requests
import json
from openai import OpenAI
from typing import Literal

# Running token totals keyed by model id: "openai/<model>", "deepinfra/<model>",
# or the raw QueryLake model id.
total_input_tokens = {}
total_output_tokens = {}

# Price per one million tokens (presumably USD), keyed like the totals above.
rates_per_million = {
    "deepinfra/meta-llama/Llama-3.3-70B-Instruct": {"input": 0.23, "output": 0.40},
    "deepinfra/meta-llama/Llama-3.3-70B-Instruct-Turbo": {"input": 0.12, "output": 0.30},
    "deepinfra/meta-llama/Meta-Llama-3.1-70B-Instruct": {"input": 0.23, "output": 0.40},
    "deepinfra/meta-llama/Meta-Llama-3.1-8B-Instruct": {"input": 0.03, "output": 0.05},
    "openai/gpt-4o-mini": {"input": 0.15, "output": 0.60},
    "openai/gpt-4o": {"input": 2.50, "output": 10.00},
}


def get_money_spent():
    """
    Estimate the money spent so far from the accumulated token counts.

    Input tokens are summed first, then output tokens. Models without an entry in
    `rates_per_million` (e.g. locally served QueryLake models) contribute nothing.
    """
    total_cost = 0
    for direction, counts in (("input", total_input_tokens), ("output", total_output_tokens)):
        for model, tokens in counts.items():
            pricing = rates_per_million.get(model)
            if pricing is None:
                continue
            total_cost += tokens * pricing[direction] / 1000000
    return total_cost


def call_llm_external(chat_history, model_parameters, functions_available=None, provider : Literal["openai", "deep_infra"] = "deep_infra"):
    """
    Run one chat completion on an external OpenAI-compatible provider (OpenAI or DeepInfra).

    The local QueryLake server first formats (chops) `chat_history` for the model.
    The result is sent to the provider, and token usage is added to the global
    counters under "openai/<model>" or "deepinfra/<model>". When
    `functions_available` is not None, QueryLake parses the reply for function calls.

    Args:
        chat_history: list of dicts with "role" (user/assistant/system) and "content".
        model_parameters: dict with "model" (provider-side model name) and optional
            "max_tokens" (default 100), "temperature" (default 0.0) and "top_p" (default 1.0).
        functions_available: optional list of function definitions (dicts with "name").
            Only parsed calls to these names are returned.
        provider: "openai" or "deep_infra".

    Returns:
        {"output": <reply text>}, plus "function_calls" when `functions_available` is not None.

    Raises:
        AssertionError: malformed arguments, or a QueryLake response with success == False.
        requests.HTTPError: a non-2xx QueryLake response.
    """
    global total_input_tokens, total_output_tokens
    assert isinstance(model_parameters, dict), "model_parameters must be a dictionary"
    assert isinstance(chat_history, list), "chat_history must be a list"
    assert all(["content" in e and "role" in e and e["role"] in ["user", "assistant", "system"] for e in chat_history]), \
        "chat_history must be a list of dictionaries with 'content' and 'role' keys"

    model = model_parameters["model"]

    format_response = requests.get("http://localhost:8000/api/format_chat_history", json={
        "auth": {"api_key": DATASET_SPECS.ql_api_key},
        "chat_history": chat_history,
        **({"functions_available": functions_available} if functions_available is not None else {})
    })
    format_response.raise_for_status()
    format_result = format_response.json()
    # .get(): an error payload without an "error" key must still fail as an AssertionError, not a KeyError.
    assert not ("success" in format_result and format_result["success"] == False), format_result.get("error", format_result)
    chat_history_chopped = format_result["result"]

    # Provider-specific client settings (DeepInfra exposes an OpenAI-compatible endpoint).
    client_kwargs = {
        "deep_infra": {"api_key": DATASET_SPECS.deep_infra_api_key, "base_url": "https://api.deepinfra.com/v1/openai"},
        "openai": {"api_key": DATASET_SPECS.openai_api_key},
    }[provider]
    client = OpenAI(**client_kwargs)

    chat_completion = client.chat.completions.create(
        model=model,
        messages=chat_history_chopped,
        stop=['<|eot_id|>'],
        max_tokens=model_parameters.get("max_tokens", 100),
        temperature=model_parameters.get("temperature", 0.0),
        top_p=model_parameters.get("top_p", 1.0),
    )

    # Key matches the "<provider>/<model>" ids in rates_per_million.
    token_model = ("openai" if provider == "openai" else "deepinfra") + "/" + model
    total_input_tokens[token_model] = total_input_tokens.get(token_model, 0) + chat_completion.usage.prompt_tokens
    total_output_tokens[token_model] = total_output_tokens.get(token_model, 0) + chat_completion.usage.completion_tokens

    result_message = chat_completion.choices[0].message.content

    call_results = {}
    if functions_available is not None:
        calls_response = requests.get("http://localhost:8000/api/find_function_calls", json={
            "auth": {"api_key": DATASET_SPECS.ql_api_key},
            "text_in": result_message,
        })
        calls_response.raise_for_status()
        calls_result = calls_response.json()
        assert not ("success" in calls_result and calls_result["success"] == False), calls_result.get("error", calls_result)

        # Drop anything the model "called" that it was not offered.
        allowed_names = [e["name"] for e in functions_available]
        calls = [e for e in calls_result["result"] if "function" in e and e["function"] in allowed_names]
        call_results = {"function_calls": calls}

    return {"output": result_message, **call_results}


def call_llm_querylake(chat_history, model_parameters, functions_available=None):
    """
    Run one chat completion through the local QueryLake `/api/llm` endpoint.

    Args:
        chat_history: list of dicts with "role" (user/assistant/system) and "content".
        model_parameters: dict forwarded as-is; "model" is the QueryLake model id and is
            also the key used for the global token counters.
        functions_available: optional list of function definitions forwarded to the server.

    Returns:
        The server's `result` dict (includes "input_token_count" and "output_token_count").

    Raises:
        AssertionError: malformed arguments, a response with success == False (message is
            the server's "error"), or a result missing its token counts.
        requests.HTTPError: a non-2xx response.
    """
    assert isinstance(model_parameters, dict), "model_parameters must be a dictionary"
    assert isinstance(chat_history, list), "chat_history must be a list"
    assert all(["content" in e and "role" in e and e["role"] in ["user", "assistant", "system"] for e in chat_history]), \
        "chat_history must be a list of dictionaries with 'content' and 'role' keys"

    model = model_parameters["model"]

    response = requests.get("http://localhost:8000/api/llm", json={
        "auth": {"api_key": DATASET_SPECS.ql_api_key},
        "chat_history": chat_history,
        "model_parameters": model_parameters,
        **({"functions_available": functions_available} if functions_available is not None else {})
    })
    response.raise_for_status()
    result = response.json()

    # Check for a server-side failure *before* reading result["result"]. An error payload
    # typically has no "result", which previously surfaced as a KeyError and hid the real error.
    failed = "success" in result and result["success"] == False
    if failed:
        print(json.dumps(result, indent=4))
    assert not failed, result.get("error", result)

    llm_result = result["result"]
    assert "input_token_count" in llm_result, "input_token_count not in result"
    assert "output_token_count" in llm_result, "output_token_count not in result"

    total_input_tokens[model] = total_input_tokens.get(model, 0) + llm_result["input_token_count"]
    total_output_tokens[model] = total_output_tokens.get(model, 0) + llm_result["output_token_count"]

    return llm_result


def call_llm(chat_history,
             model_parameters,
             functions_available=None):
    """
    Route a chat completion to the right backend based on the model id prefix.

    "openai/<name>" goes to OpenAI and "deepinfra/<name>" goes to DeepInfra (via
    call_llm_external, with the prefix stripped from "model"). Every other id,
    including ones containing "/", is sent unchanged to QueryLake.
    """
    prefix, slash, remainder = model_parameters["model"].partition("/")

    if not slash or prefix not in ("openai", "deepinfra"):
        return call_llm_querylake(chat_history, model_parameters, functions_available)

    provider = "openai" if prefix == "openai" else "deep_infra"
    external_parameters = {**model_parameters, "model": remainder}
    return call_llm_external(chat_history, external_parameters, functions_available, provider=provider)

def querylake_chop_chat_history(chat_history, model_parameters, functions_available=None):
    """
    Ask QueryLake to format (chop) a chat history for a model, without generating.

    Sends `only_format_prompt=True` to `/api/llm` with every model parameter except
    "model", and returns the server's formatted chat history.
    """
    assert isinstance(model_parameters, dict), "model_parameters must be a dictionary"
    assert isinstance(chat_history, list), "chat_history must be a list"
    assert all(["content" in e and "role" in e and e["role"] in ["user", "assistant", "system"] for e in chat_history]), \
        "chat_history must be a list of dictionaries with 'content' and 'role' keys"

    request_body = {
        "auth": {"api_key": DATASET_SPECS.ql_api_key},
        "chat_history": chat_history,
        "model_parameters": {name: value for name, value in model_parameters.items() if name != "model"},
        "only_format_prompt": True,
    }
    if functions_available is not None:
        request_body["functions_available"] = functions_available

    response = requests.get("http://localhost:8000/api/llm", json=request_body)
    response.raise_for_status()
    payload = response.json()

    assert not ("success" in payload and payload["success"] == False), payload["error"]

    return payload["result"]["chat_history"]

def _call_querylake_search(endpoint: str, parameters: dict):
    """Shared request/validation logic for the QueryLake search endpoints."""
    response = requests.get(f"http://localhost:8000/api/{endpoint}", json={
        "auth": {"api_key": DATASET_SPECS.ql_api_key},
        **parameters
    })
    response.raise_for_status()

    result = response.json()

    # .get(): an error payload without an "error" key must still fail as an AssertionError.
    assert not ("success" in result and result["success"] == False), result.get("error", result)

    return result["result"]


def call_search_hybrid(parameters):
    """
    Run `/api/search_hybrid` with `parameters` (e.g. query, collection_ids, limit_bm25,
    limit_similarity, rerank) merged into the request body next to the API key.

    Returns the server's list of result chunks. Raises AssertionError if the server
    reports success == False, and requests.HTTPError on a non-2xx response.
    """
    return _call_querylake_search("search_hybrid", parameters)


def call_search_bm25(parameters):
    """
    Run `/api/search_bm25` with `parameters` (e.g. query, collection_ids, limit)
    merged into the request body next to the API key.

    Returns the server's list of result chunks. Raises AssertionError if the server
    reports success == False, and requests.HTTPError on a non-2xx response.
    """
    return _call_querylake_search("search_bm25", parameters)

def count_tokens(input : str, model_id: str = "llama-3.1-8b-instruct"):
    """
    Count the tokens in `input` using QueryLake's tokenizer endpoint.

    Args:
        input: text to tokenize (the name shadows the builtin but is kept for compatibility).
        model_id: QueryLake model whose tokenizer is used. Defaults to the id that was
            previously hard-coded.

    Returns:
        The endpoint's `result` value (presumably an integer token count).

    Raises:
        AssertionError: the server reports success == False.
        requests.HTTPError: a non-2xx response.
    """
    response = requests.get("http://localhost:8000/api/llm_count_tokens", json={
        "auth": {"api_key": DATASET_SPECS.ql_api_key},
        "model_id": model_id,
        "input_string": input
    })
    response.raise_for_status()
    result = response.json()

    assert not ("success" in result and result["success"] == False), result.get("error", result)

    return result["result"]

# Parallel Processing

In [4]:
import concurrent.futures

def call_function_batch(function_in, tuple_args):
    """
    Call `function_in(*args)` for every tuple in `tuple_args` on a thread pool.

    Results come back in input order. The first exception raised by a call (in input
    order) propagates to the caller.
    """
    def _apply(args):
        return function_in(*args)

    with concurrent.futures.ThreadPoolExecutor() as pool:
        return list(pool.map(_apply, tuple_args))

# The Main Implementation

In [5]:
from typing import Callable, Any, Union, Awaitable, List, Dict, Tuple
import time

class DocumentChunkDictionary(BaseModel):
    """
    Expected schema of a document chunk returned by the QueryLake search endpoints.

    NOTE(review): search results are handled as plain dicts in this notebook
    (e.g. `source["id"]`). This model is only used in type annotations and
    documents the expected keys.
    """
    # A single chunk id, or a list of ids (list-valued ids are expanded when tracking sources).
    id: Union[str, int, List[str], List[int]]
    creation_timestamp: float
    collection_type: Optional[Union[str, None]]
    document_id: Optional[Union[str, None]]
    # Chunk position in its document; presumably a (start, end) pair when chunks are merged — verify.
    document_chunk_number: Optional[Union[int, Tuple[int, int], None]]
    # document_integrity: Optional[Union[str, None]]
    collection_id: Optional[Union[str, None]]
    document_name: str
    # website_url : Optional[Union[str, None]]
    # private: bool
    md: dict  # presumably chunk-level metadata
    document_md: dict  # presumably document-level metadata
    text: str  # chunk text, shown to the model inside <CONTENT>
    embedding: Optional[List[float]] = None  # excluded from the <METADATA> dump shown to the model

    # Retrieval scores; which ones are populated presumably depends on the search mode.
    hybrid_score: Optional[float] = None
    bm25_score: Optional[float] = None
    similarity_score: Optional[float] = None

# System prompt for the self-guided search loop: answers must come only from retrieved information.
TRAINING_SYSTEM_PROMPT = """
You are tasked with searching for information to answer a question.
You will be given a question, and you must perform searches to find the information necessary to answer it.
DO NOT answer with information not provided by searches.
DO NOT make up information.
ONLY answer with information you have found in the database.
"""

# Alternative "sample exam" framing of the task, formatted with `question`.
# NOTE(review): not referenced by self_guided_search_manual, which uses TRAINING_PROMPT_2.
TRAINING_PROMPT_1 = """
You are showing how to do a sample exam to prepare students in a research class.
You will be given a question on a topic.
You must attempt to answer it by performing consecutive searches and retrieving sources until you are ready to answer.
You must perform these searches and make notes of the information as you parse through it until you feel confident that you can answer the question.
When you perform a search or otherwise call a function, you will be met with the result as a response. You may then continue.
Only respond with your next step.
Complete the process in an ideal way by calling provided functions.
Your functions will be search_database() for searching, ready_to_answer() for when you decide to answer, and cannot_answer() for if you feel you couldn't answer effectively.
Take as many searches as you need, and feel free to try different angles to gain insight.. 
Only respond with your reasoning and actions.

Here is your question: {question}

Only respond with your next step as if you were taking the exam.
While searching, provide some brief reasoning or observations before performing the search if helpful.
"""


# User prompt used by self_guided_search_manual. It is formatted with `question` and
# `max_search_string` (e.g. "5 searches"), so any other literal braces would break .format().
# NOTE(review): the model under test appears to be fine-tuned ("sgs-tuned"); wording changes may affect results.
TRAINING_PROMPT_2 = """
You will be given a question on a topic requiring 1-2 paragraphs to answer
You must attempt to answer it by performing consecutive searches and retrieving sources until you are ready to answer.
You must perform these searches and make notes of the information as you parse through it until you feel confident that you can answer the question.
When you perform a search or otherwise call a function, you will be met with the result as a response. You may then continue.
Only respond with your next step.
Complete the process in an ideal way by calling provided functions.
Your functions will be search_database() for searching, ready_to_answer() to indicate that you are ready to answer, and cannot_answer() for if you feel you couldn't answer effectively.
You are allowed {max_search_string} before answering, no more.
Only respond with your reasoning and actions.

Here is your question: {question}

Only respond with your next step.
"""

# Sent after the model calls ready_to_answer(). The visible code uses it verbatim (never
# .format()-ed), so the literal "{cite:1}" example is safe here.
answer_enabled_prompt = """
Answering enabled. Go ahead and write your final answer, and nothing else. 
Please cite your sources with frequent inline citations whenever they are even slightly applicable.
Use the citations provided inside the <CITATION> XML of respective sources (i.e. {cite:1}, do not include the XML tags).
Again, citations should be as frequent as possible. While claiming something, 
you should ideally have an appropriate citation at the end of every sentence.
Do not make any claims not supported by the sources provided.
Continue.
"""

# Sent after the model calls cannot_answer(); also used verbatim.
could_not_answer_prompt = """
You have indicated that you cannot answer.
Please provide a brief explanation as to why you cannot answer the question.
Also, provide what information you have uncovered. 
Please cite your sources with frequent inline citations whenever they are applicable.
Use the citations provided inside the <CITATION> XML of respective sources (i.e. {cite:1}, do not include the XML tags).
Continue.
"""


# Function definitions offered to the model. In call_llm_external, parsed calls are kept
# only when their "function" matches one of the offered "name"s.
search_func_description = """
Perform a search of the database for information.
This is deterministic, so don't repeat the same search.
"""

# search_database(question): self_guided_search_manual reads arguments["question"] as the query.
search_func_def = {
    "name": "search_database",
    "description": search_func_description,
    "parameters": [
        {
            "name": "question",
            "type": "str",
            "description": "Effectively a google search. Will be used to retrieve information from the database via term similarity.",
        }
    ]
}

# ready_to_answer(): no arguments; the loop replies with answer_enabled_prompt and takes the next reply as the answer.
answer_func_def = {
    "name": "ready_to_answer",
    "description": """
If you feel you have enough information to answer the question, call this to let the system know you are ready to answer.
Only then can you write your answer.
""",
    "parameters": []
}

# cannot_answer(): no arguments; the loop replies with could_not_answer_prompt and takes the next reply as output.
cannot_answer_func_def = {
    "name": "cannot_answer",
    "description": "If you feel you are unable to answer the question even with the tools available, call this.",
    "parameters": []
}


def _sgs_citation_key(source: dict):
    """Key used to number a search result: its id, or the first id when the id is a list."""
    return source["id"] if isinstance(source["id"], str) else source["id"][0]


def _sgs_format_sources(searched_sources: list, citation_map: dict, use_metadata: bool) -> str:
    """Render search results as the <CITATION>/<CONTENT>(/<METADATA>) blocks shown to the model."""
    rendered = []
    for source in searched_sources:
        block = (
            "<CITATION>\n\t{cite:" + str(citation_map[_sgs_citation_key(source)]) + "}\n</CITATION>\n"
            + f"<CONTENT>\n{source['text']}\n</CONTENT>"
        )
        if use_metadata:
            metadata = {
                k: v for k, v in source.items()
                if k not in ["embedding", "collection_id", "collection_type", "creation_timestamp", "text"]
            }
            block += "\n<METADATA>\n" + json.dumps(metadata, indent=4) + "\n</METADATA>"
        rendered.append(block)
    return "\n\n".join(rendered)


def self_guided_search_manual(
    model : str,
    question: str,
    collection_ids: Optional[List[str]] = None,
    max_searches : int = 5,
    use_hybrid: bool = False,
    use_rerank: Optional[int] = None,
    use_metadata: bool = False,
) -> dict:
    """
    Client-side self-guided search: let `model` iteratively search, then answer.

    The model is prompted with TRAINING_PROMPT_2 and may call `search_database`,
    `ready_to_answer` or `cannot_answer`. Each new search runs BM25 (or hybrid)
    retrieval, excluding chunk ids already returned, and feeds the results back with
    `{cite:N}` labels. After `ready_to_answer` / `cannot_answer`, the next model reply
    is taken as the final output.

    Args:
        model: model id passed to call_llm.
        question: the question to answer.
        collection_ids: QueryLake collections to search (None means []).
        max_searches: searches advertised to the model; once used up the search
            function is no longer offered.
        use_hybrid: use hybrid search instead of BM25.
        use_rerank: with use_hybrid, enables reranking and requests use_rerank // 2
            candidates from each of BM25 and similarity search.
        use_metadata: append each chunk's metadata as JSON to its rendered result.

    Returns:
        dict with chat_history, output, responses, time_taken, sources, answer_found,
        searches, input_token_count and output_token_count. If 15 model responses pass
        without an answer, output is "Model ran out of responses." and sources is [].
    """
    if collection_ids is None:
        collection_ids = []

    start_time = time.time()

    answer_flag, answer, sources_matched, responses, searches = False, None, [], 0, 0
    ready_to_answer_flag = False
    answer_found = False
    max_responses, prompt_tokens, output_tokens = 15, 0, 0
    previous_searches, previous_results, all_sources = set(), [], []
    searches_return = []

    max_search_string = f"{max_searches} searches" if max_searches > 1 else "1 search"
    chat_history_1 = [
        {"role": "system", "content": TRAINING_SYSTEM_PROMPT},
        {"role": "user", "content": TRAINING_PROMPT_2.format(question=question, max_search_string=max_search_string)}
    ]

    while True:
        # One round after ready_to_answer/cannot_answer, the reply is taken as the answer.
        if ready_to_answer_flag:
            answer_flag = True

        model_parameters = {
            "model": model,
            "max_tokens": 200 if not answer_flag else 4096,
            "temperature": 0,
            "top_p": 1.0,
            "repetition_penalty": 1.15
        }
        all_functions_available = [answer_func_def, cannot_answer_func_def]
        last_statement = ""
        if searches < max_searches and not answer_flag:
            all_functions_available.append(search_func_def)
            last_statement = f"\n\nYou have {max_searches - searches} searches remaining."
        elif searches >= max_searches and not answer_flag:
            last_statement = "\n\n-------ATTENTION------- You have no searches remaining."

        chat_history_1[-1]["content"] += last_statement

        model_response = call_llm(
            chat_history=chat_history_1,
            model_parameters=model_parameters,
            functions_available=all_functions_available if not answer_flag else [],
        )

        prompt_tokens += model_response.get("input_token_count", 0)
        output_tokens += model_response.get("output_token_count", 0)
        responses += 1

        function_calls = model_response.get("function_calls", [])
        chat_history_1.append({"role": "assistant", "content": model_response["output"], "function_calls": function_calls})

        citation_map = {}

        if function_calls:
            last_call = function_calls[-1]

            if last_call["function"] == "search_database" and "question" in last_call["arguments"]:
                new_search = last_call["arguments"]["question"]
                if new_search in previous_searches:
                    # Do NOT `continue` here: that skipped the max_responses check below and
                    # could loop forever if the model kept repeating the same search.
                    chat_history_1.append({"role": "user", "content": "You have already requested this search. Refer to the results from that attempt. Continue."})
                else:
                    searches += 1
                    # Exclude every chunk id already shown to the model.
                    excluded_chunks = " ".join([f"-id:\"{result}\"" for result in previous_results])
                    search_make = new_search + f" {excluded_chunks}"
                    searches_return.append({"search": f"Searching: \"{new_search}\""})

                    if not use_hybrid:
                        searched_sources = call_search_bm25(dict(
                            query=search_make,
                            collection_ids=collection_ids,
                            limit=5,
                        ))
                    else:
                        split_size = 5 if use_rerank is None else use_rerank // 2
                        searched_sources = call_search_hybrid(dict(
                            query=search_make,
                            collection_ids=collection_ids,
                            limit_bm25=split_size,
                            limit_similarity=split_size,
                            rerank=(use_rerank is not None)
                        ))
                        searched_sources = searched_sources[:5]

                    all_sources.extend(searched_sources)

                    # Citation numbers continue from the count of chunk ids seen so far.
                    for i, source in enumerate(searched_sources):
                        citation_map[_sgs_citation_key(source)] = len(sources_matched) + i + 1

                    for source in searched_sources:
                        if isinstance(source["id"], list):
                            sources_matched.extend(source["id"])
                        else:
                            sources_matched.append(source["id"])

                    searched_sources_string = _sgs_format_sources(searched_sources, citation_map, use_metadata)
                    chat_history_1.append({"role": "user", "content": f"<SEARCH_RESULTS>\n{searched_sources_string}\n</SEARCH_RESULTS>"})
                    previous_searches.add(new_search)
                    previous_results.extend([source["id"] for source in searched_sources])

            elif last_call["function"] == "ready_to_answer":
                chat_history_1.append({"role": "user", "content": answer_enabled_prompt})
                ready_to_answer_flag = True
                answer_flag = False
                answer_found = True

            elif last_call["function"] == "cannot_answer":
                chat_history_1.append({"role": "user", "content": could_not_answer_prompt})
                ready_to_answer_flag = True
                answer_flag = False

        if answer_flag:
            answer = model_response["output"]
            break

        if chat_history_1[-1]["role"] != "user":
            chat_history_1.append({"role": "user", "content": "No function calls were parsed. Please remember all function calls must be ended with ' &&'. Continue"})

        if responses >= max_responses:
            # The former citation-index assertion here parsed chunk ids as "{cite:N}" strings
            # (and crashed on an empty map), so it was removed.
            return {
                "chat_history": chat_history_1,
                "output": "Model ran out of responses.",
                "responses": responses,
                "time_taken": time.time() - start_time,
                "sources": [],
                "answer_found": answer_found,
                "searches": searches_return,
                "input_token_count": prompt_tokens,
                "output_token_count": output_tokens
            }

    return {
        "chat_history": chat_history_1,
        "output": answer,
        "responses": responses,
        "time_taken": time.time() - start_time,
        "sources": all_sources,
        "answer_found": answer_found,
        "searches": searches_return,
        "input_token_count": prompt_tokens,
        "output_token_count": output_tokens
    }

In [None]:
# Prompts for generating TriviaQA reference answers from a question's gold pages.
# The user prompt is formatted with `question` and `sources`. Replies containing
# CANNOT_ANSWER cause the question to be dropped from the evaluation set.
system_prompt_create_correct_answer = """
You are an assistant that answers user questions, but you only do so using information sources provided along with requests.
"""

prompt_create_correct_answer = """
Answer the following questions using the provided sources.
If you cannot answer the question, simply write CANNOT_ANSWER.


Question: {question}


<SOURCES>
{sources}
</SOURCES>
"""

try:
    for i, entry in tqdm(enumerate(evaluation_questions), total=len(evaluation_questions)):
        question_hash = hash_string(entry["question"])

        # TriviaQA has no reference answers, so one is generated with `gt_answer_model`
        # from the question's gold pages and cached by question hash.
        if ("answer" not in entry) and (question_hash not in CORRECT_ANSWERS_LOOKUP_TRIVIA):
            sources = "\n".join([f"INFORMATION:\n{e['content']}\n\n" for e in entry["pages"]])
            prompt = prompt_create_correct_answer.format(question=entry["question"], sources=sources)

            chat_history = [
                {"content": system_prompt_create_correct_answer, "role": "system"},
                {"content": prompt, "role": "user"},
            ]

            model_response = call_llm(
                chat_history=chat_history,
                model_parameters={
                    "model": DATASET_SPECS.gt_answer_model,
                    "max_tokens": 4096,
                    "temperature": 0.0,
                    "top_p": 1.0
                },
                functions_available=None,
            )

            answer = model_response["output"]

            CORRECT_ANSWERS_LOOKUP_TRIVIA[question_hash] = answer

        if DATASET_SPECS.set == "triviaqa":
            evaluation_questions[i]["answer"] = CORRECT_ANSWERS_LOOKUP_TRIVIA[question_hash]
finally:
    # Persist generated answers, even after a mid-run failure, so re-runs reuse them
    # instead of paying to regenerate. This file is loaded into the dict at startup.
    with open("correct_answers_lookup.json", "w") as f:
        json.dump(CORRECT_ANSWERS_LOOKUP_TRIVIA, f, indent=4)

# Drop questions without a usable reference answer.
evaluation_questions_filtered = [e for e in evaluation_questions if "answer" in e and "CANNOT_ANSWER" not in e["answer"]]


assert len(evaluation_questions_filtered) >= DATASET_SPECS.samples, \
    f"Only {len(evaluation_questions_filtered)} questions available ({DATASET_SPECS.samples} requested)."

evaluation_questions = evaluation_questions_filtered[:DATASET_SPECS.samples]


# Chunk Hit Rate

In [7]:
def measure_chunk_hit_rate(chunks, documents):
	"""Return the fraction of `chunks` found in at least one of `documents`.

	Membership is tested with `chunk in document` (a substring test when both
	are strings).

	Args:
		chunks: sized sequence of retrieved chunks.
		documents: iterable of reference documents.

	Returns:
		hits / len(chunks) as a float in [0, 1]; 0.0 when `chunks` is empty
		(instead of raising ZeroDivisionError).
	"""
	if len(chunks) == 0:
		return 0.0
	# Each chunk counts at most once, no matter how many documents contain it.
	hits = sum(1 for chunk in chunks if any(chunk in document for document in documents))
	return hits / len(chunks)

# Self Guided Search Function

In [8]:
# Debug toggle, presumably read by the search helpers to echo raw model output (its use is not visible in this section).
PRINT_MODEL_RESPONSE: bool = False


In [9]:
def self_guided_search_bounded(
	question: str,
	max_searches: int,
	use_hybrid: bool = False,
	use_rerank: bool = False,
	use_context_expansion: bool = False,
	rerank_n: int = 100,
	timeout: Optional[float] = None,
):
	"""Run self-guided search through the local QueryLake HTTP endpoint.

	Posts to http://localhost:8000/api/self_guided_search using TARGET_COLLECTIONS
	and DATASET_SPECS.model, then adds the returned token counts to the global
	total_input_tokens / total_output_tokens dicts under DATASET_SPECS.model.

	Args:
		question: the question to answer.
		max_searches: search budget sent as "max_searches".
		use_hybrid: sent as "use_hybrid".
		use_rerank: when True, "use_rerank": rerank_n is sent; otherwise the key is omitted.
		use_context_expansion: accepted for interface compatibility but NOT included
			in the request. TODO: wire up if the endpoint supports it.
		rerank_n: value sent as "use_rerank" when use_rerank is True.
		timeout: optional requests timeout in seconds. None (default) waits
			indefinitely, matching the previous behaviour.

	Returns:
		The "result" object of the JSON response, which is checked to contain
		"input_token_count" and "output_token_count".

	Raises:
		requests.HTTPError: on a non-2xx HTTP status (via raise_for_status).
		AssertionError: when the server reports success == False or omits token counts.
	"""
	payload = {
		"auth": {"api_key": DATASET_SPECS.ql_api_key},
		"question": question,
		"collection_ids": TARGET_COLLECTIONS,
		"model": DATASET_SPECS.model,
		"max_searches": max_searches,
		"use_hybrid": use_hybrid,
	}
	if use_rerank:
		# The rerank depth is sent as the value of "use_rerank".
		payload["use_rerank"] = rerank_n

	response = requests.post(
		"http://localhost:8000/api/self_guided_search",
		json=payload,
		timeout=timeout,
	)
	response.raise_for_status()

	body = response.json()

	# A failure payload may lack an "error" key; fall back to the whole body so the
	# assertion reports the server response instead of raising KeyError.
	assert not ("success" in body and body["success"] == False), body.get("error", body)

	result = body["result"]

	assert "input_token_count" in result, "input_token_count not in result"
	assert "output_token_count" in result, "output_token_count not in result"

	model = DATASET_SPECS.model
	total_input_tokens[model] = total_input_tokens.get(model, 0) + result["input_token_count"]
	total_output_tokens[model] = total_output_tokens.get(model, 0) + result["output_token_count"]

	return result

In [None]:
# Grading prompt sent to DATASET_SPECS.evaluator_model; it is asked to reply YES or NO.
prompt_check_answer_correctness = """
You must grade a student's answer for correctness given the correct answer.

Question: {question}

Correct answer: {correct_answer}

Student Answer: {student_answer}


Was the student answer factually correct?
Respond YES or NO, and nothing else.
"""

import re

# Module-level counters used for progress / cache reporting.
cached_amount = 0
samples_completed = 0

# Samples per question = number of unbounded configurations (a single default
# case when none are configured) + one bounded run per depth 1..max_search_depth.
samples_per_question = (
	(len(DATASET_SPECS.unbounded_cases) if DATASET_SPECS.unbounded_cases is not None else 1)
	+ (DATASET_SPECS.max_search_depth if DATASET_SPECS.max_search_depth is not None else 0)
)

total_samples = samples_per_question * DATASET_SPECS.samples

# Wall-clock start used for elapsed / per-sample timing.
time_started = time.time()

def notify_sample_completed():
	"""Count one finished sample and redraw a single-line progress/timing status.

	Increments the global samples_completed, then prints progress against
	total_samples plus elapsed time, average seconds per sample, and the implied
	seconds per question (per-sample time * samples_per_question). The line ends
	with a carriage return so successive calls overwrite it.
	"""
	global samples_completed
	samples_completed += 1

	elapsed = time.time() - time_started
	seconds_per_sample = elapsed / samples_completed
	seconds_per_question = seconds_per_sample * samples_per_question

	status_line = "Samples completed: %05d / %05d ( %10.2fs %10.2fs/sample %10.2fs/q )        " % (
		samples_completed,
		total_samples,
		elapsed,
		seconds_per_sample,
		seconds_per_question,
	)
	print(status_line, end="\r")


def _generate_correct_answer(entry):
	"""Ask DATASET_SPECS.gt_answer_model for a reference answer grounded in the entry's gold pages.

	Returns the raw model output. Per the prompt, it contains CANNOT_ANSWER when
	the sources were insufficient.
	"""
	# Each gold page is a dict whose "content" field holds the page text.
	sources = "\n".join([f"INFORMATION:\n{e['content']}\n\n" for e in entry["pages"]])
	prompt = prompt_create_correct_answer.format(question=entry["question"], sources=sources)

	chat_history = [
		{"content": system_prompt_create_correct_answer, "role": "system"},
		{"content": prompt, "role": "user"},
	]

	model_response = call_llm(
		chat_history=chat_history,
		model_parameters={
			"model": DATASET_SPECS.gt_answer_model,
			"max_tokens": 4096,
			"temperature": 0.0,
			"top_p": 1.0
		},
		functions_available=None,
	)
	return model_response["output"]


def _grade_answer(question, correct_answer, test_answer, fallback_output):
	"""Grade `test_answer` against `correct_answer` with DATASET_SPECS.evaluator_model.

	Returns the call_llm response dict. When the search produced no answer
	(`test_answer is None`), the evaluator is skipped and {"output": fallback_output}
	is returned instead.
	"""
	if test_answer is None:
		return {"output": fallback_output}

	correctness_prompt = prompt_check_answer_correctness.format(
		question=question, correct_answer=correct_answer, student_answer=test_answer
	)
	return call_llm(
		chat_history=[{"content": correctness_prompt, "role": "user"}],
		model_parameters={
			"model": DATASET_SPECS.evaluator_model,
			"max_tokens": 4096,
			"temperature": 0.0,
			"top_p": 1.0
		},
		functions_available=None
	)


def _run_and_record_case(entry, correct_answer, question_results, case_key, case_fields, fallback_output, **search_kwargs):
	"""Run one self-guided-search configuration for `entry`, grade it, and store the record.

	The record is written to question_results[case_key], with `case_fields`
	inserted right after "chat_sequence". `search_kwargs` are forwarded to
	self_guided_search_manual. If the search raises, an error marker
	({"error_occurrence": True, "error": repr(exc)}) is stored instead. The case
	is then skipped on re-runs and excluded from aggregation.
	"""
	try:
		sgs_result = self_guided_search_manual(
			model=DATASET_SPECS.model,
			question=entry["question"],
			collection_ids=TARGET_COLLECTIONS,
			use_metadata=DATASET_SPECS.include_chunk_metadata,
			**search_kwargs,
		)
	except Exception as exc:
		# Best-effort: record the failure and move on. Exception (not a bare except)
		# so KeyboardInterrupt / SystemExit still stop the run.
		question_results[case_key] = {"error_occurrence": True, "error": repr(exc)}
		return

	chat_sequence, test_answer, duration, response_count = \
		sgs_result["chat_history"], sgs_result["output"], \
		sgs_result["time_taken"], sgs_result["responses"]

	grading = _grade_answer(entry["question"], correct_answer, test_answer, fallback_output)

	notify_sample_completed()

	question_results[case_key] = {
		"chat_sequence": chat_sequence,
		**case_fields,
		"question": entry["question"],
		"correct_answer": correct_answer,
		"ground_truth_sources": entry["pages"],
		"retrieved_sources": sgs_result["sources"],
		"test_answer": test_answer,
		"duration": duration,
		"response_count": response_count,
		"correctness_rating": grading["output"],
	}


def evaluate_question(entry, entry_i, question_hash):
	"""Evaluate one question under every configured search setting and store results in EVAL_RESULTS.

	Runs each unbounded case (DATASET_SPECS.unbounded_cases, or ["BM25"] when unset)
	and each bounded depth 1..DATASET_SPECS.max_search_depth. Cases already present
	in EVAL_RESULTS[model][question_hash] are skipped, so an interrupted run can resume.

	Args:
		entry: question dict with "question" and "pages", plus "answer" when a gold answer exists.
		entry_i: index of the entry in evaluation_questions (currently unused).
		question_hash: hash_string(entry["question"]); key into the answer/result caches.

	Returns:
		None. Results are written to EVAL_RESULTS as a side effect.
	"""
	global cached_amount
	model = DATASET_SPECS.model

	if "error_occurrence" in entry:
		return

	# TriviaQA has no gold answers: generate (and cache) one from the gold sources.
	if ("answer" not in entry) and (question_hash not in CORRECT_ANSWERS_LOOKUP_TRIVIA):
		CORRECT_ANSWERS_LOOKUP_TRIVIA[question_hash] = _generate_correct_answer(entry)
	else:
		# Reference answer already available (dataset-provided or previously generated).
		cached_amount += 1

	correct_answer = entry["answer"] if "answer" in entry else CORRECT_ANSWERS_LOOKUP_TRIVIA[question_hash]

	# The answer model could not answer from the gold sources, so the question is unusable.
	if "CANNOT_ANSWER" in correct_answer:
		return

	question_results = EVAL_RESULTS.setdefault(model, {}).setdefault(question_hash, {})

	# Unbounded cases: "_"-joined flags, BM25 or HS (hybrid), optionally RR (rerank) / CE (context expansion).
	unbounded_cases = ["BM25"] if DATASET_SPECS.unbounded_cases is None else DATASET_SPECS.unbounded_cases

	for unbounded_case in unbounded_cases:
		specs_raw = unbounded_case.split("_")
		use_hybrid_tmp = "HS" in specs_raw
		assert use_hybrid_tmp == ("BM25" not in specs_raw), "Cannot use both BM25 and HS"
		use_rerank_tmp = "RR" in specs_raw
		use_context_expansion_tmp = "CE" in specs_raw

		case_key = f"unbounded_{unbounded_case}"
		if case_key in question_results:
			continue

		_run_and_record_case(
			entry, correct_answer, question_results, case_key,
			case_fields={
				"use_hybrid": use_hybrid_tmp,
				"rerank": DATASET_SPECS.rerank_n if use_rerank_tmp else False,
				# Recorded only; context expansion is not passed to the search call.
				"context_expansion": use_context_expansion_tmp,
			},
			fallback_output="ANSWER_UNAVAILABLE",
			max_searches=99,  # large budget standing in for "unbounded"
			use_hybrid=use_hybrid_tmp,
			use_rerank=DATASET_SPECS.rerank_n if use_rerank_tmp else None,
		)

	# Bounded cases: one run per search budget 1..max_search_depth (none when unset).
	max_depth = DATASET_SPECS.max_search_depth
	for max_searches in range(1, max_depth + 1) if isinstance(max_depth, int) else []:
		case_key = f"max_searches_{max_searches}"
		if case_key in question_results:
			continue

		# No hybrid/rerank overrides: self_guided_search_manual's own defaults apply.
		_run_and_record_case(
			entry, correct_answer, question_results, case_key,
			case_fields={},
			fallback_output="0",
			max_searches=max_searches,
		)

# Fan evaluate_question out over the questions in batches of DATASET_SPECS.batch_size.
# call_function_batch receives one (entry, entry_i, question_hash) tuple per question.
arguments = []
last_index = len(evaluation_questions) - 1

for entry_i, entry in enumerate(tqdm(evaluation_questions)):
	question_hash = hash_string(entry["question"])
	arguments.append((entry, entry_i, question_hash))

	# Flush when the batch is full, or after the final question (possibly a partial batch).
	batch_is_full = entry_i % DATASET_SPECS.batch_size == DATASET_SPECS.batch_size - 1
	if batch_is_full or entry_i == last_index:
		call_function_batch(evaluate_question, arguments)
		arguments = []

# Percentage of questions counted in the global cached_amount counter
# (max(1, ...) guards against an empty question list).
print("Cache rate: %.2f%%" % (100 * cached_amount / max(1, len(evaluation_questions))))
print("Total cost $%.4f" % (get_money_spent()))

In [11]:
def _write_json(path, obj):
	"""Serialize `obj` to `path` as 4-space-indented JSON, overwriting any existing file.

	The `with` block closes the file, so no explicit close() is needed.
	"""
	with open(path, "w") as out_file:
		json.dump(obj, out_file, indent=4)


# Persist the ground-truth answer caches.
_write_json("correct_answers_lookup_hotpot.json", CORRECT_ANSWERS_LOOKUP_HOTPOT)
_write_json("correct_answers_lookup.json", CORRECT_ANSWERS_LOOKUP_TRIVIA)

# Persist the evaluation results (model -> question hash -> case key -> record).
_write_json(OUTPUT_FILE, EVAL_RESULTS)

In [None]:
# Aggregate correctness ratings per test configuration ("max_searches_N" / "unbounded_<case>")
# for the model under test, then write the percentages to a timestamped JSON file.
model_results = EVAL_RESULTS.get(DATASET_SPECS.model, {})

# Full set of configurations each question is expected to have, mirroring the cases
# evaluate_question runs (handles max_search_depth=None and every configured unbounded
# case). Not used in the aggregation below; kept for optionally restricting the
# analysis to fully evaluated questions.
_max_depth = DATASET_SPECS.max_search_depth
_unbounded = ["BM25"] if DATASET_SPECS.unbounded_cases is None else DATASET_SPECS.unbounded_cases
unique_keys = (
	[f"max_searches_{i}" for i in (range(1, _max_depth + 1) if isinstance(_max_depth, int) else [])]
	+ [f"unbounded_{case}" for case in _unbounded]
)

# Ratings grouped by configuration key, in first-seen order. A configuration whose runs
# all failed still gets an (empty) list, so it is reported with 0 samples.
RESULTS_BY_TYPE = {}
for question_result in model_results.values():
	for key, case_result in question_result.items():
		if key == "error_occurrence":
			continue
		ratings = RESULTS_BY_TYPE.setdefault(key, [])
		if "error_occurrence" in case_result:
			continue
		assert "correctness_rating" in case_result, f"correctness_rating not in entry for key {key}"
		ratings.append(case_result["correctness_rating"])

RESULTS_PERCENTAGES = {
	"dataset_hash": dataset_hash,
	"configuration": DATASET_SPECS.model_dump()
}

for key, ratings in RESULTS_BY_TYPE.items():
	# Any rating not containing "YES" (NO, "ANSWER_UNAVAILABLE", "0", ...) counts as incorrect.
	yes_count = sum(1 for rating in ratings if "YES" in rating)
	RESULTS_PERCENTAGES[key] = {
		"percentage": 100 * yes_count / max(1, len(ratings)),
		"samples": len(ratings)
	}

print(json.dumps(RESULTS_PERCENTAGES, indent=4))

with open(f"results_percentages_{int(time.time())}.json", "w") as f:
	json.dump(RESULTS_PERCENTAGES, f, indent=4)