Skip to content

Commit

Permalink
agent: fix wait --std-tools
Browse files Browse the repository at this point in the history
  • Loading branch information
ochafik committed Apr 21, 2024
1 parent 9b9f195 commit 2ba7150
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 46 deletions.
51 changes: 30 additions & 21 deletions examples/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from time import sleep
import typer
from pydantic import BaseModel, Json, TypeAdapter
from pydantic_core import SchemaValidator, core_schema
from typing import Annotated, Any, Callable, Dict, List, Union, Optional, Type
import json, requests

Expand All @@ -13,16 +14,12 @@
from examples.openai.prompting import ToolsPromptStyle
from examples.openai.subprocesses import spawn_subprocess

def _get_params_schema(fn: Callable[[Any], Any], verbose):
if isinstance(fn, OpenAPIMethod):
return fn.parameters_schema

# converter = SchemaConverter(prop_order={}, allow_fetch=False, dotall=False, raw_pattern=False)
schema = TypeAdapter(fn).json_schema()
# Do NOT call converter.resolve_refs(schema) here. Let the server resolve local refs.
if verbose:
sys.stderr.write(f'# PARAMS SCHEMA: {json.dumps(schema, indent=2)}\n')
return schema
def make_call_adapter(ta: TypeAdapter, fn: Callable[..., Any]):
args_validator = SchemaValidator(core_schema.call_schema(
arguments=ta.core_schema['arguments_schema'],
function=fn,
))
return lambda **kwargs: args_validator.validate_python(kwargs)

def completion_with_tool_usage(
*,
Expand Down Expand Up @@ -50,18 +47,28 @@ def completion_with_tool_usage(
schema = type_adapter.json_schema()
response_format=ResponseFormat(type="json_object", schema=schema)

tool_map = {fn.__name__: fn for fn in tools}
tools_schemas = [
Tool(
type="function",
function=ToolFunction(
name=fn.__name__,
description=fn.__doc__ or '',
parameters=_get_params_schema(fn, verbose=verbose)
tool_map = {}
tools_schemas = []
for fn in tools:
if isinstance(fn, OpenAPIMethod):
tool_map[fn.__name__] = fn
parameters_schema = fn.parameters_schema
else:
ta = TypeAdapter(fn)
tool_map[fn.__name__] = make_call_adapter(ta, fn)
parameters_schema = ta.json_schema()
if verbose:
sys.stderr.write(f'# PARAMS SCHEMA ({fn.__name__}): {json.dumps(parameters_schema, indent=2)}\n')
tools_schemas.append(
Tool(
type="function",
function=ToolFunction(
name=fn.__name__,
description=fn.__doc__ or '',
parameters=parameters_schema,
)
)
)
for fn in tools
]

i = 0
while (max_iterations is None or i < max_iterations):
Expand Down Expand Up @@ -106,7 +113,7 @@ def completion_with_tool_usage(
sys.stdout.write(f'⚙️ {pretty_call}')
sys.stdout.flush()
tool_result = tool_map[tool_call.function.name](**tool_call.function.arguments)
sys.stdout.write(f" -> {tool_result}\n")
sys.stdout.write(f" {tool_result}\n")
messages.append(Message(
tool_call_id=tool_call.id,
role="tool",
Expand Down Expand Up @@ -203,6 +210,8 @@ def main(
if std_tools:
tool_functions.extend(collect_functions(StandardTools))

sys.stdout.write(f'🛠️ {", ".join(fn.__name__ for fn in tool_functions)}\n')

response_model: Union[type, Json[Any]] = None #str
if format:
if format in types:
Expand Down
62 changes: 37 additions & 25 deletions examples/agent/tools/std_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,18 @@ class Duration(BaseModel):
years: Optional[int] = None

def __str__(self) -> str:
return f"{self.years} years, {self.months} months, {self.days} days, {self.hours} hours, {self.minutes} minutes, {self.seconds} seconds"
return ', '.join([
x
for x in [
f"{self.years} years" if self.years else None,
f"{self.months} months" if self.months else None,
f"{self.days} days" if self.days else None,
f"{self.hours} hours" if self.hours else None,
f"{self.minutes} minutes" if self.minutes else None,
f"{self.seconds} seconds" if self.seconds else None,
]
if x is not None
])

@property
def get_total_seconds(self) -> int:
Expand All @@ -36,25 +47,6 @@ def __call__(self):
sys.stderr.write(f"Waiting for {self.duration}...\n")
time.sleep(self.duration.get_total_seconds)

class WaitForDate(BaseModel):
until: date

def __call__(self):
# Get the current date
current_date = datetime.date.today()

if self.until < current_date:
raise ValueError("Target date cannot be in the past.")

time_diff = datetime.datetime.combine(self.until, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min)

days, seconds = time_diff.days, time_diff.seconds

sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {self.until}...\n")
time.sleep(days * 86400 + seconds)
sys.stderr.write(f"Reached the target date: {self.until}\n")


class StandardTools:

@staticmethod
Expand All @@ -66,12 +58,32 @@ def ask_user(question: str) -> str:
return typer.prompt(question)

@staticmethod
def wait(_for: Union[WaitForDuration, WaitForDate]) -> None:
'''
Wait for a certain amount of time before continuing.
This can be used to wait for a specific duration or until a specific date.
def wait_for_duration(duration: Duration) -> None:
'Wait for a certain amount of time before continuing.'

# sys.stderr.write(f"Waiting for {duration}...\n")
time.sleep(duration.get_total_seconds)

@staticmethod
def wait_for_date(target_date: date) -> None:
f'''
Wait until a specific date is reached before continuing.
Today's date is {datetime.date.today()}
'''
return _for()

# Get the current date
current_date = datetime.date.today()

if target_date < current_date:
raise ValueError("Target date cannot be in the past.")

time_diff = datetime.datetime.combine(target_date, datetime.time.min) - datetime.datetime.combine(current_date, datetime.time.min)

days, seconds = time_diff.days, time_diff.seconds

# sys.stderr.write(f"Waiting for {days} days and {seconds} seconds until {target_date}...\n")
time.sleep(days * 86400 + seconds)
# sys.stderr.write(f"Reached the target date: {target_date}\n")

@staticmethod
def say_out_loud(something: str) -> None:
Expand Down

0 comments on commit 2ba7150

Please sign in to comment.