-
Notifications
You must be signed in to change notification settings - Fork 4.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
959 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
from typing import Dict, Type | ||
|
||
from llama_index.agent import OpenAIAgent, ReActAgent | ||
from llama_index.agent.types import BaseAgent | ||
from llama_index.llms import Anthropic, OpenAI | ||
from llama_index.llms.base import LLM | ||
|
||
OPENAI_MODELS = [ | ||
"text-davinci-003", | ||
"gpt-3.5-turbo-0613", | ||
"gpt-4-0613", | ||
] | ||
ANTHROPIC_MODELS = ["claude-instant-1", "claude-2"] | ||
ALL_MODELS = OPENAI_MODELS + ANTHROPIC_MODELS | ||
|
||
AGENTS: Dict[str, Type[BaseAgent]] = { | ||
"react": ReActAgent, | ||
"openai": OpenAIAgent, | ||
} | ||
|
||
|
||
def get_model(model: str) -> LLM: | ||
llm: LLM | ||
if model in OPENAI_MODELS: | ||
llm = OpenAI(model=model) | ||
elif model in ANTHROPIC_MODELS: | ||
llm = Anthropic(model=model) | ||
else: | ||
raise ValueError(f"Unknown model {model}") | ||
return llm | ||
|
||
|
||
def is_valid_combination(agent: str, model: str) -> bool: | ||
if agent == "openai" and model not in ["gpt-3.5-turbo-0613", "gpt-4-0613"]: | ||
print(f"{agent} does not work with {model}") | ||
return False | ||
return True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
from typing import Callable, Dict | ||
|
||
from task import Task | ||
|
||
from llama_index.tools.function_tool import FunctionTool | ||
|
||
|
||
class Phone: | ||
def __init__(self) -> None: | ||
self.number = "" | ||
self.entered = False | ||
|
||
def dial_digit(self, number: str) -> None: | ||
"""Dial a digit on the phone.""" | ||
assert len(number) == 1 and number.isdigit() | ||
self.number += number | ||
|
||
def enter(self) -> None: | ||
"""Press the enter key on the phone.""" | ||
if self.entered: | ||
raise Exception("Already entered") | ||
|
||
self.entered = True | ||
|
||
def evaluate(self, response: str, expected_response: str) -> bool: | ||
return self.number == expected_response and self.entered | ||
|
||
|
||
def search_number(first_name: str, last_name: str) -> str: | ||
"""Search for a person by first and last name.""" | ||
if first_name == "John" and last_name == "Smith": | ||
return "2135" | ||
else: | ||
return "No results found. Please capitalize both first and last name." | ||
|
||
|
||
search_number_tool = FunctionTool.from_defaults(fn=search_number) | ||
|
||
|
||
def get_dial_then_enter() -> Task: | ||
phone = Phone() | ||
|
||
dial_digit_tool = FunctionTool.from_defaults(fn=phone.dial_digit) | ||
enter_tool = FunctionTool.from_defaults(fn=phone.enter) | ||
|
||
task = Task( | ||
message="Dial the number 4151 then hit enter.", | ||
expected_response="4151", | ||
tools=[dial_digit_tool, enter_tool], | ||
eval_fn=phone.evaluate, | ||
) | ||
return task | ||
|
||
|
||
def get_search_then_dial() -> Task: | ||
phone = Phone() | ||
|
||
dial_digit_tool = FunctionTool.from_defaults(fn=phone.dial_digit) | ||
enter_tool = FunctionTool.from_defaults(fn=phone.enter) | ||
|
||
task = Task( | ||
message="Dial the number for john smith, then hit enter.", | ||
expected_response="2135", | ||
tools=[dial_digit_tool, enter_tool, search_number_tool], | ||
eval_fn=phone.evaluate, | ||
) | ||
return task | ||
|
||
|
||
TASKS: Dict[str, Callable[..., Task]] = { | ||
"dial_then_enter": get_dial_then_enter, | ||
"search_then_dial": get_search_then_dial, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
def contains_expected_response(response: str, expected_response: str) -> bool: | ||
"""Check if the response contains the expected response.""" | ||
return expected_response in response |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
from typing import List, cast | ||
|
||
import pandas as pd | ||
from agent_utils import AGENTS, ALL_MODELS, get_model, is_valid_combination | ||
from button_tasks import TASKS as BUTTON_TASKS | ||
from fire import Fire | ||
from math_tasks import TASKS as MATH_TASKS | ||
|
||
from llama_index.agent.types import BaseAgent | ||
|
||
ALL_TASKS = list(MATH_TASKS.keys()) + list(BUTTON_TASKS.keys()) | ||
|
||
|
||
def evaluate(agent: str, model: str, task_name: str, verbose: bool = False) -> bool: | ||
if task_name in MATH_TASKS: | ||
task = MATH_TASKS[task_name]() | ||
elif task_name in BUTTON_TASKS: | ||
task = BUTTON_TASKS[task_name]() | ||
else: | ||
raise ValueError(f"Unknown task {task_name}") | ||
|
||
print("=========================================") | ||
print(f"Evaluating | {agent} | {model} | {task.message} |") | ||
|
||
llm = get_model(model) | ||
agent_cls = AGENTS[agent] | ||
if agent == "react": | ||
additional_kwargs = {"max_iterations": 10} | ||
elif agent == "openai": | ||
additional_kwargs = {"max_function_calls": 10} | ||
else: | ||
raise ValueError(f"Unknown agent {agent}") | ||
|
||
agent_ = agent_cls.from_tools( # type: ignore | ||
tools=task.tools, | ||
llm=llm, | ||
verbose=verbose, | ||
**additional_kwargs, | ||
) # type: ignore | ||
agent_ = cast(BaseAgent, agent_) | ||
try: | ||
actual_response = agent_.chat(task.message).response | ||
outcome = task.eval_fn(actual_response, task.expected_response) | ||
except Exception as e: | ||
if verbose: | ||
print("Failed due to: ", e) | ||
|
||
actual_response = None | ||
outcome = False | ||
|
||
if verbose: | ||
print(f"Expected response: {task.expected_response}") | ||
print(f"Actual response: {actual_response}") | ||
print(f"Outcome: {outcome}") | ||
return outcome | ||
|
||
|
||
def benchmark( | ||
agents: List[str] = list(AGENTS.keys()), | ||
models: List[str] = ALL_MODELS, | ||
tasks: List[str] = ALL_TASKS, | ||
verbose: bool = False, | ||
output: str = "results.csv", | ||
save: bool = True, | ||
) -> pd.DataFrame: | ||
data = [] | ||
for agent in agents: | ||
for model in models: | ||
for task in tasks: | ||
if not is_valid_combination(agent, model): | ||
continue | ||
outcome = evaluate(agent, model, task, verbose) | ||
data.append( | ||
{ | ||
"agent": agent, | ||
"model": model, | ||
"task": task, | ||
"outcome": outcome, | ||
} | ||
) | ||
df = pd.DataFrame(data) | ||
if save: | ||
df.to_csv(output) | ||
return df | ||
|
||
|
||
if __name__ == "__main__": | ||
Fire(benchmark) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
from typing import Callable, Dict | ||
|
||
from eval import contains_expected_response | ||
from task import Task | ||
|
||
from llama_index.tools.function_tool import FunctionTool | ||
|
||
|
||
def add(a: int, b: int) -> int: | ||
"""Add two integers and returns the result integer""" | ||
return a + b | ||
|
||
|
||
def multiply(a: int, b: int) -> int: | ||
"""Multiple two integers and returns the result integer""" | ||
return a * b | ||
|
||
|
||
add_tool = FunctionTool.from_defaults(fn=add) | ||
multiply_tool = FunctionTool.from_defaults(fn=multiply) | ||
|
||
|
||
POWER_TASK = Task( | ||
message="What is 3 to the power of 4?", | ||
expected_response="81", | ||
tools=[add_tool, multiply_tool], | ||
eval_fn=contains_expected_response, | ||
) | ||
|
||
ADD_THEN_MULTIPLY_TASK = Task( | ||
message="What is 123 + 321 * 2?", | ||
expected_response="765", | ||
tools=[add_tool, multiply_tool], | ||
eval_fn=contains_expected_response, | ||
) | ||
|
||
|
||
TASKS: Dict[str, Callable[..., Task]] = { | ||
"add_then_multiply": lambda: ADD_THEN_MULTIPLY_TASK, | ||
"power": lambda: POWER_TASK, | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
,agent,model,task,outcome | ||
0,react,text-davinci-003,add_then_multiply,True | ||
1,react,text-davinci-003,power,False | ||
2,react,text-davinci-003,dial_then_enter,False | ||
3,react,text-davinci-003,search_then_dial,False | ||
4,react,gpt-3.5-turbo-0613,add_then_multiply,False | ||
5,react,gpt-3.5-turbo-0613,power,True | ||
6,react,gpt-3.5-turbo-0613,dial_then_enter,False | ||
7,react,gpt-3.5-turbo-0613,search_then_dial,False | ||
8,react,gpt-4-0613,add_then_multiply,True | ||
9,react,gpt-4-0613,power,True | ||
10,react,gpt-4-0613,dial_then_enter,True | ||
11,react,gpt-4-0613,search_then_dial,True | ||
12,react,claude-instant-1,add_then_multiply,False | ||
13,react,claude-instant-1,power,True | ||
14,react,claude-instant-1,dial_then_enter,False | ||
15,react,claude-instant-1,search_then_dial,False | ||
16,react,claude-2,add_then_multiply,True | ||
17,react,claude-2,power,False | ||
18,react,claude-2,dial_then_enter,True | ||
19,react,claude-2,search_then_dial,True | ||
20,openai,gpt-3.5-turbo-0613,add_then_multiply,False | ||
21,openai,gpt-3.5-turbo-0613,power,True | ||
22,openai,gpt-3.5-turbo-0613,dial_then_enter,True | ||
23,openai,gpt-3.5-turbo-0613,search_then_dial,False | ||
24,openai,gpt-4-0613,add_then_multiply,True | ||
25,openai,gpt-4-0613,power,True | ||
26,openai,gpt-4-0613,dial_then_enter,True | ||
27,openai,gpt-4-0613,search_then_dial,False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from typing import Callable, List | ||
|
||
from pydantic import BaseModel | ||
from llama_index.tools.types import BaseTool | ||
|
||
|
||
class Task(BaseModel): | ||
message: str | ||
expected_response: str | ||
tools: List[BaseTool] | ||
eval_fn: Callable[[str, str], bool] | ||
|
||
class Config: | ||
arbitrary_types_allowed = True |