diff --git a/evals/cli/oaieval.py b/evals/cli/oaieval.py
index a72b194ae0..cfb5ea4248 100644
--- a/evals/cli/oaieval.py
+++ b/evals/cli/oaieval.py
@@ -5,7 +5,7 @@
 import logging
 import shlex
 import sys
-from typing import Any, Mapping, Optional
+from typing import Any, Mapping, Optional, Union, cast
 
 import openai
 
@@ -13,12 +13,13 @@
 import evals.api
 import evals.base
 import evals.record
+from evals.eval import Eval
 from evals.registry import Registry
 
 logger = logging.getLogger(__name__)
 
 
-def _purple(str):
+def _purple(str: str) -> str:
     return f"\033[1;35m{str}\033[0m"
 
 
@@ -41,7 +42,11 @@ def get_parser() -> argparse.ArgumentParser:
         "--log_to_file", type=str, default=None, help="Log to a file instead of stdout"
     )
     parser.add_argument(
-        "--registry_path", type=str, default=None, action="append", help="Path to the registry"
+        "--registry_path",
+        type=str,
+        default=None,
+        action="append",
+        help="Path to the registry",
     )
     parser.add_argument("--debug", action=argparse.BooleanOptionalAction, default=False)
     parser.add_argument("--local-run", action=argparse.BooleanOptionalAction, default=True)
@@ -50,7 +55,25 @@ def get_parser() -> argparse.ArgumentParser:
     return parser
 
 
-def run(args, registry: Optional[Registry] = None):
+class OaiEvalArguments(argparse.Namespace):
+    completion_fn: str
+    eval: str
+    extra_eval_params: str
+    max_samples: Optional[int]
+    cache: bool
+    visible: Optional[bool]
+    seed: int
+    user: str
+    record_path: Optional[str]
+    log_to_file: Optional[str]
+    registry_path: Optional[str]
+    debug: bool
+    local_run: bool
+    dry_run: bool
+    dry_run_logging: bool
+
+
+def run(args: OaiEvalArguments, registry: Optional[Registry] = None) -> str:
     if args.debug:
         logging.getLogger().setLevel(logging.DEBUG)
 
@@ -61,7 +84,7 @@ def run(args, registry: Optional[Registry] = None):
 
     registry = registry or Registry()
     if args.registry_path:
-        registry.add_registry_paths(args.registry_path)
+        registry.add_registry_paths([args.registry_path])
 
     eval_spec = registry.get_eval(args.eval)
     assert (
@@ -83,6 +106,9 @@ def run(args, registry: Optional[Registry] = None):
     }
 
     eval_name = eval_spec.key
+    if eval_name is None:
+        raise Exception("you must provide a eval name")
+
     run_spec = evals.base.RunSpec(
         completion_fns=completion_fns,
         eval_name=eval_name,
@@ -95,6 +121,8 @@ def run(args, registry: Optional[Registry] = None):
         record_path = f"/tmp/evallogs/{run_spec.run_id}_{args.completion_fn}_{args.eval}.jsonl"
     else:
         record_path = args.record_path
+
+    recorder: evals.record.RecorderBase
     if args.dry_run:
         recorder = evals.record.DummyRecorder(run_spec=run_spec, log=args.dry_run_logging)
     elif args.local_run:
@@ -102,19 +130,21 @@ def run(args, registry: Optional[Registry] = None):
     else:
         recorder = evals.record.Recorder(record_path, run_spec=run_spec)
 
-    api_extra_options = {}
+    api_extra_options: dict[str, Any] = {}
     if not args.cache:
         api_extra_options["cache_level"] = 0
 
     run_url = f"{run_spec.run_id}"
     logger.info(_purple(f"Run started: {run_url}"))
 
-    def parse_extra_eval_params(param_str: Optional[str]) -> Mapping[str, Any]:
+    def parse_extra_eval_params(
+        param_str: Optional[str],
+    ) -> Mapping[str, Union[str, int, float]]:
         """Parse a string of the form "key1=value1,key2=value2" into a dict."""
         if not param_str:
             return {}
 
-        def to_number(x):
+        def to_number(x: str) -> Union[int, float, str]:
             try:
                 return int(x)
             except:
@@ -131,7 +161,7 @@ def to_number(x):
     extra_eval_params = parse_extra_eval_params(args.extra_eval_params)
 
     eval_class = registry.get_class(eval_spec)
-    eval = eval_class(
+    eval: Eval = eval_class(
         completion_fns=completion_fn_instances,
         seed=args.seed,
         name=eval_name,
@@ -150,17 +180,19 @@ def to_number(x):
     return run_spec.run_id
 
 
-def main():
+def main() -> None:
     parser = get_parser()
-    args = parser.parse_args(sys.argv[1:])
+    args = cast(OaiEvalArguments, parser.parse_args(sys.argv[1:]))
     logging.basicConfig(
         format="[%(asctime)s] [%(filename)s:%(lineno)d] %(message)s",
         level=logging.INFO,
         filename=args.log_to_file if args.log_to_file else None,
     )
     logging.getLogger("openai").setLevel(logging.WARN)
-    if hasattr(openai.error, "set_display_cause"):
-        openai.error.set_display_cause()
+
+    # TODO)) why do we need this?
+    if hasattr(openai.error, "set_display_cause"):  # type: ignore
+        openai.error.set_display_cause()  # type: ignore
     run(args)
 
 
diff --git a/evals/utils/misc.py b/evals/utils/misc.py
index 2e19570e16..da8685500e 100644
--- a/evals/utils/misc.py
+++ b/evals/utils/misc.py
@@ -17,7 +17,7 @@ def t(duration: float) -> str:
         return f"{duration//60}min{int(duration%60)}s"
 
 
-def make_object(object_ref: Any, *args: Any, **kwargs: Any) -> Any:
+def make_object(object_ref: str, *args: Any, **kwargs: Any) -> Any:
     modname, qualname_separator, qualname = object_ref.partition(":")
     obj = importlib.import_module(modname)
     if qualname_separator:
diff --git a/mypy.ini b/mypy.ini
index 62225d735f..e7656d5345 100644
--- a/mypy.ini
+++ b/mypy.ini
@@ -30,6 +30,10 @@ disallow_untyped_defs=True
 ignore_errors=False
 disallow_untyped_defs=True
 
+[mypy-evals.cli.oaieval]
+ignore_errors=False
+disallow_untyped_defs=True
+
 [mypy-scripts.*]
 ignore_errors=False
 disallow_untyped_defs=True