Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 45 additions & 13 deletions evals/cli/oaieval.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,21 @@
import logging
import shlex
import sys
from typing import Any, Mapping, Optional
from typing import Any, Mapping, Optional, Union, cast

import openai

import evals
import evals.api
import evals.base
import evals.record
from evals.eval import Eval
from evals.registry import Registry

logger = logging.getLogger(__name__)


def _purple(str):
def _purple(str: str) -> str:
return f"\033[1;35m{str}\033[0m"


Expand All @@ -41,7 +42,11 @@ def get_parser() -> argparse.ArgumentParser:
"--log_to_file", type=str, default=None, help="Log to a file instead of stdout"
)
parser.add_argument(
"--registry_path", type=str, default=None, action="append", help="Path to the registry"
"--registry_path",
type=str,
default=None,
action="append",
help="Path to the registry",
)
parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=False)
parser.add_argument("--local-run", action=argparse.BooleanOptionalAction, default=True)
Expand All @@ -50,7 +55,25 @@ def get_parser() -> argparse.ArgumentParser:
return parser


def run(args, registry: Optional[Registry] = None):
class OaiEvalArguments(argparse.Namespace):
completion_fn: str
eval: str
extra_eval_params: str
max_samples: Optional[int]
cache: bool
visible: Optional[bool]
seed: int
user: str
record_path: Optional[str]
log_to_file: Optional[str]
registry_path: Optional[str]
debug: bool
local_run: bool
dry_run: bool
dry_run_logging: bool


def run(args: OaiEvalArguments, registry: Optional[Registry] = None) -> str:
if args.debug:
logging.getLogger().setLevel(logging.DEBUG)

Expand All @@ -61,7 +84,7 @@ def run(args, registry: Optional[Registry] = None):

registry = registry or Registry()
if args.registry_path:
registry.add_registry_paths(args.registry_path)
registry.add_registry_paths([args.registry_path])

eval_spec = registry.get_eval(args.eval)
assert (
Expand All @@ -83,6 +106,9 @@ def run(args, registry: Optional[Registry] = None):
}

eval_name = eval_spec.key
if eval_name is None:
raise Exception("you must provide a eval name")

run_spec = evals.base.RunSpec(
completion_fns=completion_fns,
eval_name=eval_name,
Expand All @@ -95,26 +121,30 @@ def run(args, registry: Optional[Registry] = None):
record_path = f"/tmp/evallogs/{run_spec.run_id}_{args.completion_fn}_{args.eval}.jsonl"
else:
record_path = args.record_path

recorder: evals.record.RecorderBase
if args.dry_run:
recorder = evals.record.DummyRecorder(run_spec=run_spec, log=args.dry_run_logging)
elif args.local_run:
recorder = evals.record.LocalRecorder(record_path, run_spec=run_spec)
else:
recorder = evals.record.Recorder(record_path, run_spec=run_spec)

api_extra_options = {}
api_extra_options: dict[str, Any] = {}
if not args.cache:
api_extra_options["cache_level"] = 0

run_url = f"{run_spec.run_id}"
logger.info(_purple(f"Run started: {run_url}"))

def parse_extra_eval_params(param_str: Optional[str]) -> Mapping[str, Any]:
def parse_extra_eval_params(
param_str: Optional[str],
) -> Mapping[str, Union[str, int, float]]:
"""Parse a string of the form "key1=value1,key2=value2" into a dict."""
if not param_str:
return {}

def to_number(x):
def to_number(x: str) -> Union[int, float, str]:
try:
return int(x)
except:
Expand All @@ -131,7 +161,7 @@ def to_number(x):
extra_eval_params = parse_extra_eval_params(args.extra_eval_params)

eval_class = registry.get_class(eval_spec)
eval = eval_class(
eval: Eval = eval_class(
completion_fns=completion_fn_instances,
seed=args.seed,
name=eval_name,
Expand All @@ -150,17 +180,19 @@ def to_number(x):
return run_spec.run_id


def main():
def main() -> None:
parser = get_parser()
args = parser.parse_args(sys.argv[1:])
args = cast(OaiEvalArguments, parser.parse_args(sys.argv[1:]))
logging.basicConfig(
format="[%(asctime)s] [%(filename)s:%(lineno)d] %(message)s",
level=logging.INFO,
filename=args.log_to_file if args.log_to_file else None,
)
logging.getLogger("openai").setLevel(logging.WARN)
if hasattr(openai.error, "set_display_cause"):
openai.error.set_display_cause()

# TODO)) why do we need this?
if hasattr(openai.error, "set_display_cause"): # type: ignore
openai.error.set_display_cause() # type: ignore
run(args)


Expand Down
2 changes: 1 addition & 1 deletion evals/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def t(duration: float) -> str:
return f"{duration//60}min{int(duration%60)}s"


def make_object(object_ref: Any, *args: Any, **kwargs: Any) -> Any:
def make_object(object_ref: str, *args: Any, **kwargs: Any) -> Any:
modname, qualname_separator, qualname = object_ref.partition(":")
obj = importlib.import_module(modname)
if qualname_separator:
Expand Down
4 changes: 4 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ disallow_untyped_defs=True
ignore_errors=False
disallow_untyped_defs=True

[mypy-evals.cli.oaieval]
ignore_errors=False
disallow_untyped_defs=True

[mypy-scripts.*]
ignore_errors=False
disallow_untyped_defs=True
Expand Down