diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml
new file mode 100644
index 00000000..a4e30f9e
--- /dev/null
+++ b/.github/workflows/benchmark.yml
@@ -0,0 +1,56 @@
+name: Benchmark
+
+# Do not run this workflow on pull request since this workflow has permission to modify contents.
+on:
+  push:
+    branches:
+      - master
+  workflow_dispatch: {}
+
+permissions:
+  # contents permission to update benchmark contents in gh-pages branch
+  contents: write
+  # deployments permission to deploy GitHub pages website
+  deployments: write
+
+jobs:
+  benchmark:
+    name: Run pytest-benchmark benchmark example
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v3
+      - name: Set up Python 3.11
+        uses: actions/setup-python@v3
+        with:
+          # Quoted so that YAML keeps it a string (an unquoted 3.10 would be read as 3.1).
+          python-version: "3.11"
+          cache: "pip"
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e .[all]
+
+      - name: Run benchmark
+        run: |
+          pytest --benchmark-only --benchmark-json=.benchmark_output.json
+
+      - name: Store benchmark result
+        uses: benchmark-action/github-action-benchmark@v1
+        with:
+          name: Python Benchmark with pytest-benchmark
+          tool: 'pytest'
+          # Where the output from the benchmark tool is stored
+          output-file-path: .benchmark_output.json
+          # # Where the previous data file is stored
+          # external-data-json-path: ./cache/benchmark-data.json
+          # NOTE: this uses the default GITHUB_TOKEN, not a personal access token. If the gh-pages site is not
+          # rebuilt after the push, see https://github.community/t/github-action-not-triggering-gh-pages-upon-push/16096
+          github-token: ${{ secrets.GITHUB_TOKEN }}
+          # NOTE: auto-push must be false when external-data-json-path is set since this action
+          # reads/writes the given JSON file and never pushes to remote
+          auto-push: true
+          # Show alert with commit comment on detecting possible performance regression
+          alert-threshold: '200%'
+          comment-on-alert: true
+          # Workflow will fail when an alert happens
+          fail-on-alert: true
+          alert-comment-cc-users: '@lebrice'
diff --git a/.github/workflows/build.yml
b/.github/workflows/build.yml index 19b96d10..d8610dbd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -21,6 +21,6 @@ jobs: run: | python -m pip install --upgrade pip pip install -e .[all] - - name: Test with pytest + - name: Unit tests with Pytest run: | - pytest + pytest --benchmark-disable diff --git a/pytest.ini b/pytest.ini index 4bc99bbe..b5900ab6 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,4 +1,4 @@ [pytest] -addopts = --doctest-modules +addopts = --doctest-modules --benchmark-autosave testpaths = test simple_parsing norecursedirs = examples .git .tox .eggs dist build docs *.egg diff --git a/setup.py b/setup.py index c50b4707..87cd64d8 100644 --- a/setup.py +++ b/setup.py @@ -20,6 +20,7 @@ "pytest", "pytest-xdist", "pytest-regressions", + "pytest-benchmark", "numpy", # "torch", ], diff --git a/simple_parsing/decorators.py b/simple_parsing/decorators.py index c6235eef..68580eab 100644 --- a/simple_parsing/decorators.py +++ b/simple_parsing/decorators.py @@ -8,7 +8,7 @@ from typing import Any, Callable, NamedTuple import docstring_parser as dp - +from simple_parsing.docstring import dp_parse, inspect_getdoc from . 
import helpers, parsing @@ -55,7 +55,7 @@ def _wrapper(*other_args, **other_kwargs) -> Any: parameters = signature.parameters # Parse docstring to use as help strings - docstring = dp.parse(inspect.getdoc(function) or "") + docstring = dp_parse(inspect_getdoc(function) or "") docstring_param_description = { param.arg_name: param.description for param in docstring.params } diff --git a/simple_parsing/docstring.py b/simple_parsing/docstring.py index c7bc16e5..17d52d0a 100644 --- a/simple_parsing/docstring.py +++ b/simple_parsing/docstring.py @@ -5,12 +5,17 @@ import functools import inspect + +# from inspect import from dataclasses import dataclass from logging import getLogger import docstring_parser as dp from docstring_parser.common import Docstring +dp_parse = functools.lru_cache(2048)(dp.parse) +inspect_getsource = functools.lru_cache(2048)(inspect.getsource) +inspect_getdoc = functools.lru_cache(2048)(inspect.getdoc) logger = getLogger(__name__) @@ -102,7 +107,7 @@ def _get_attribute_docstring(dataclass: type, field_name: str) -> AttributeDocSt Doesn't inspect base classes. 
""" try: - source = inspect.getsource(dataclass) + source = inspect_getsource(dataclass) except (TypeError, OSError) as e: logger.debug( UserWarning( @@ -114,9 +119,9 @@ def _get_attribute_docstring(dataclass: type, field_name: str) -> AttributeDocSt # Parse docstring to use as help strings desc_from_cls_docstring = "" - cls_docstring = inspect.getdoc(dataclass) + cls_docstring = inspect_getdoc(dataclass) if cls_docstring: - docstring: Docstring = dp.parse(cls_docstring) + docstring: Docstring = dp_parse(cls_docstring) for param in docstring.params: if param.arg_name == field_name: desc_from_cls_docstring = param.description or "" diff --git a/simple_parsing/helpers/hparams/hyperparameters.py b/simple_parsing/helpers/hparams/hyperparameters.py index d81f84f2..47a3aa9e 100644 --- a/simple_parsing/helpers/hparams/hyperparameters.py +++ b/simple_parsing/helpers/hparams/hyperparameters.py @@ -1,3 +1,4 @@ +from __future__ import annotations import dataclasses import inspect import math @@ -9,6 +10,7 @@ from logging import getLogger from pathlib import Path from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Type, TypeVar +import typing from simple_parsing import utils from simple_parsing.helpers.serialization.serializable import Serializable @@ -22,18 +24,13 @@ from .hparam import ValueOutsidePriorException from .priors import Prior +if typing.TYPE_CHECKING: + import numpy + logger = getLogger(__name__) T = TypeVar("T") HP = TypeVar("HP", bound="HyperParameters") -numpy_installed = False -try: - import numpy as np - - numpy_installed = True -except ImportError: - pass - @dataclass class BoundInfo(Serializable): @@ -51,10 +48,8 @@ class HyperParameters(Serializable, decode_into_subclasses=True): # type: ignor # Class variable holding the random number generator used to create the # samples. 
- if numpy_installed: - np_rng: ClassVar[np.random.RandomState] = np.random - else: - rng: ClassVar[random.Random] = random.Random() + + rng: ClassVar[random.Random] = random.Random() def __post_init__(self): for name, f in field_dict(self).items(): @@ -141,8 +136,7 @@ def space_id(cls): def get_bounds(cls) -> List[BoundInfo]: """Returns the bounds of the search domain for this type of HParam. - Returns them as a list of `BoundInfo` objects, in the format expected by - GPyOpt. + Returns them as a list of `BoundInfo` objects, in the format expected by GPyOpt. """ bounds: List[BoundInfo] = [] for f in fields(cls): @@ -162,9 +156,8 @@ def get_bounds(cls) -> List[BoundInfo]: @classmethod def get_bounds_dicts(cls) -> List[Dict[str, Any]]: - """Returns the bounds of the search space for this type of HParam, - in the format expected by the `GPyOpt` package. - """ + """Returns the bounds of the search space for this type of HParam, in the format expected + by the `GPyOpt` package.""" return [b.to_dict() for b in cls.get_bounds()] @classmethod @@ -186,13 +179,15 @@ def sample(cls): else: prior: Optional[Prior] = field.metadata.get("prior") if prior is not None: - if numpy_installed: - prior.np_rng = cls.np_rng - else: + try: + import numpy as np + + prior.np_rng = np.random + except ImportError: prior.rng = cls.rng value = prior.sample() shape = getattr(prior, "shape", None) - if numpy_installed and isinstance(value, np.ndarray) and not shape: + if shape == () and hasattr(value, "item") and callable(value.item): value = value.item() kwargs[field.name] = value return cls(**kwargs) @@ -202,43 +197,46 @@ def replace(self, **new_params): new_hp = type(self).from_dict(new_hp_dict) return new_hp - # @classmethod - # @contextmanager - # def use_priors(cls, value: bool = True): - # temp = cls.sample_from_priors - # cls.sample_from_priors = value - # yield - # cls.sample_from_priors = temp - - if numpy_installed: + # @classmethod + # @contextmanager + # def use_priors(cls, value: 
bool = True): + # temp = cls.sample_from_priors + # cls.sample_from_priors = value + # yield + # cls.sample_from_priors = temp + + def to_array(self, dtype: numpy.dtype | None = None) -> numpy.ndarray: + import numpy as np + + dtype = np.float32 if dtype is None else dtype + values: List[float] = [] + for k, v in self.to_dict(dict_factory=OrderedDict).items(): + try: + v = float(v) + except Exception: + logger.warning(f"Ignoring field {k} because we can't make a float out of it.") + else: + values.append(v) + return np.array(values, dtype=dtype) - def to_array(self, dtype=np.float32) -> np.ndarray: - values: List[float] = [] - for k, v in self.to_dict(dict_factory=OrderedDict).items(): - try: - v = float(v) - except Exception: - logger.warning(f"Ignoring field {k} because we can't make a float out of it.") - else: - values.append(v) - return np.array(values, dtype=dtype) - - @classmethod - def from_array(cls: Type[HP], array: np.ndarray) -> HP: - if len(array.shape) == 2 and array.shape[0] == 1: - array = array[0] - - keys = list(field_dict(cls)) - # idea: could use to_dict and to_array together to determine how many - # values to get for each field. For now we assume that each field is one - # variable. - # cls.sample().to_dict() - # assert len(keys) == len(array), "assuming that each field is dim 1 for now." - assert len(keys) == len(array), "assuming that each field is dim 1 for now." - d = OrderedDict(zip(keys, array)) - logger.debug(f"Creating an instance of {cls} using args {d}") - d = OrderedDict((k, v.item() if isinstance(v, np.ndarray) else v) for k, v in d.items()) - return cls.from_dict(d) + @classmethod + def from_array(cls: Type[HP], array: numpy.ndarray) -> HP: + import numpy as np + + if len(array.shape) == 2 and array.shape[0] == 1: + array = array[0] + + keys = list(field_dict(cls)) + # idea: could use to_dict and to_array together to determine how many + # values to get for each field. For now we assume that each field is one + # variable. 
+ # cls.sample().to_dict() + # assert len(keys) == len(array), "assuming that each field is dim 1 for now." + assert len(keys) == len(array), "assuming that each field is dim 1 for now." + d = OrderedDict(zip(keys, array)) + logger.debug(f"Creating an instance of {cls} using args {d}") + d = OrderedDict((k, v.item() if isinstance(v, np.ndarray) else v) for k, v in d.items()) + return cls.from_dict(d) def clip_within_bounds(self: HP) -> HP: d = self.to_dict() diff --git a/simple_parsing/wrappers/dataclass_wrapper.py b/simple_parsing/wrappers/dataclass_wrapper.py index 2c00c98d..6a8c2bf4 100644 --- a/simple_parsing/wrappers/dataclass_wrapper.py +++ b/simple_parsing/wrappers/dataclass_wrapper.py @@ -10,13 +10,13 @@ from logging import getLogger from typing import Any, Callable, Generic, TypeVar, cast -import docstring_parser as dp from typing_extensions import Literal -from .. import docstring, utils -from ..utils import Dataclass, DataclassT, is_dataclass_instance, is_dataclass_type -from .field_wrapper import FieldWrapper -from .wrapper import Wrapper +from simple_parsing import docstring, utils +from simple_parsing.docstring import dp_parse, inspect_getdoc +from simple_parsing.utils import Dataclass, DataclassT, is_dataclass_instance, is_dataclass_type +from simple_parsing.wrappers.field_wrapper import FieldWrapper +from simple_parsing.wrappers.wrapper import Wrapper logger = getLogger(__name__) @@ -327,11 +327,11 @@ def description(self) -> str: # NOTE: The class docstring may be EXTRELEMY LARGE. 
- class_docstring = inspect.getdoc(self.dataclass) or "" + class_docstring = inspect_getdoc(self.dataclass) or "" if not class_docstring: return "" - doc = dp.parse(class_docstring) + doc = dp_parse(class_docstring) from simple_parsing.decorators import _description_from_docstring
diff --git a/test/test_performance.py b/test/test_performance.py
new file mode 100644
index 00000000..566a8a5d
--- /dev/null
+++ b/test/test_performance.py
@@ -0,0 +1,86 @@
+"""Benchmarks for the import time and the parsing time of `simple_parsing`."""
+import functools
+import importlib
+import sys
+from types import ModuleType
+from typing import Callable, Dict, TypeVar
+
+import pytest
+from pytest_benchmark.fixture import BenchmarkFixture
+
+C = TypeVar("C", bound=Callable)
+
+_PACKAGE = "simple_parsing"
+
+
+def _loaded_sp_modules() -> Dict[str, ModuleType]:
+    """Return the `sys.modules` entries of the package and of all its submodules."""
+    return {
+        name: module
+        for name, module in list(sys.modules.items())
+        if name == _PACKAGE or name.startswith(_PACKAGE + ".")
+    }
+
+
+def import_sp():
+    """Import the package; it must not already be present in `sys.modules`."""
+    assert _PACKAGE not in sys.modules
+    __import__(_PACKAGE)
+
+
+def unimport_sp():
+    """Remove the package and every one of its submodules from `sys.modules`.
+
+    Popping only the top-level entry would leave the submodules cached, and the next import
+    would then only re-run the package's `__init__.py` instead of doing a cold import.
+    """
+    for name in _loaded_sp_modules():
+        del sys.modules[name]
+    importlib.invalidate_caches()
+    assert _PACKAGE not in sys.modules
+
+
+def clear_lru_caches():
+    """Empty the docstring-related LRU caches so each round parses from a cold cache."""
+    from simple_parsing.docstring import dp_parse, inspect_getdoc, inspect_getsource
+
+    dp_parse.cache_clear()
+    inspect_getdoc.cache_clear()
+    inspect_getsource.cache_clear()
+
+
+def call_before(before: Callable[[], None], fn: C) -> C:
+    """Wrap `fn` so that `before()` is called ahead of every call to `fn`."""
+
+    @functools.wraps(fn)
+    def wrapped(*args, **kwargs):
+        before()
+        return fn(*args, **kwargs)
+
+    return wrapped  # type: ignore
+
+
+@pytest.mark.benchmark(
+    group="import",
+)
+def test_import_performance(benchmark: BenchmarkFixture):
+    # NOTE: `conftest.py` has already imported simple-parsing, so the package is un-imported
+    # before each round. The original module objects are put back afterwards so the rest of
+    # the test session keeps using a single copy of the package.
+    original_modules = _loaded_sp_modules()
+    try:
+        benchmark(call_before(unimport_sp, import_sp))
+    finally:
+        unimport_sp()
+        sys.modules.update(original_modules)
+
+
+def test_parse_performance(benchmark: BenchmarkFixture):
+    import simple_parsing as sp
+    from test.nesting.example_use_cases import HyperParameters
+
+    benchmark(
+        call_before(clear_lru_caches, sp.parse),
+        HyperParameters,
+        args="--age_group.num_layers 5 --age_group.num_units 65 ",
+    )