Skip to content

Commit

Permalink
Increase import performance (lru_cache) and add pytest-benchmark (#279)
Browse files Browse the repository at this point in the history
* Add pytest-benchmark test dependency

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Add performance files for before the changes

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Use cached versions of inspect.getdoc/getsource

Fixes #278

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Make numpy import lazy

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Simplify the benchmark code a bit

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove .benchmarks file from git history

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Playing around with GitHub actions for benchmark

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Remove upload workflow, add benchmark

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Tweak benchmark.yml and build.yml

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

* Only run workflow on push to master

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>

---------

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Aug 3, 2023
1 parent cd6fd81 commit b3d12de
Show file tree
Hide file tree
Showing 9 changed files with 190 additions and 72 deletions.
55 changes: 55 additions & 0 deletions .github/workflows/benchmark.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
name: Benchmark

# Do not run this workflow on pull requests, since it has permission to modify the repository contents.
on:
push:
branches:
- master
workflow_dispatch: {}

permissions:
# contents permission to update benchmark contents in gh-pages branch
contents: write
# deployments permission to deploy GitHub pages website
deployments: write

jobs:
benchmark:
name: Run pytest-benchmark benchmark example
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python 3.11
uses: actions/setup-python@v3
with:
python-version: 3.11
cache: "pip"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .[all]
- name: Run benchmark
run: |
pytest --benchmark-only --benchmark-json=.benchmark_output.json
- name: Store benchmark result
uses: benchmark-action/github-action-benchmark@v1
with:
name: Python Benchmark with pytest-benchmark
tool: 'pytest'
# Where the output from the benchmark tool is stored
output-file-path: .benchmark_output.json
# # Where the previous data file is stored
# external-data-json-path: ./cache/benchmark-data.json
# NOTE: this uses the default GITHUB_TOKEN. If pushes made with it do not trigger a GitHub Pages build, use a personal access token instead (see https://github.community/t/github-action-not-triggering-gh-pages-upon-push/16096)
github-token: ${{ secrets.GITHUB_TOKEN }}
# NOTE: auto-push must be false when external-data-json-path is set since this action
# reads/writes the given JSON file and never pushes to remote
auto-push: true
# Show alert with commit comment on detecting possible performance regression
alert-threshold: '200%'
comment-on-alert: true
# Workflow will fail when an alert happens
fail-on-alert: true
alert-comment-cc-users: '@lebrice'
4 changes: 2 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,6 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .[all]
- name: Test with pytest
- name: Unit tests with Pytest
run: |
pytest
pytest --benchmark-disable
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
[pytest]
addopts = --doctest-modules
addopts = --doctest-modules --benchmark-autosave
testpaths = test simple_parsing
norecursedirs = examples .git .tox .eggs dist build docs *.egg
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
"pytest",
"pytest-xdist",
"pytest-regressions",
"pytest-benchmark",
"numpy",
# "torch",
],
Expand Down
4 changes: 2 additions & 2 deletions simple_parsing/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from typing import Any, Callable, NamedTuple

import docstring_parser as dp

from simple_parsing.docstring import dp_parse, inspect_getdoc
from . import helpers, parsing


Expand Down Expand Up @@ -55,7 +55,7 @@ def _wrapper(*other_args, **other_kwargs) -> Any:
parameters = signature.parameters

# Parse docstring to use as help strings
docstring = dp.parse(inspect.getdoc(function) or "")
docstring = dp_parse(inspect_getdoc(function) or "")
docstring_param_description = {
param.arg_name: param.description for param in docstring.params
}
Expand Down
11 changes: 8 additions & 3 deletions simple_parsing/docstring.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,17 @@

import functools
import inspect

# from inspect import
from dataclasses import dataclass
from logging import getLogger

import docstring_parser as dp
from docstring_parser.common import Docstring

dp_parse = functools.lru_cache(2048)(dp.parse)
inspect_getsource = functools.lru_cache(2048)(inspect.getsource)
inspect_getdoc = functools.lru_cache(2048)(inspect.getdoc)
logger = getLogger(__name__)


Expand Down Expand Up @@ -102,7 +107,7 @@ def _get_attribute_docstring(dataclass: type, field_name: str) -> AttributeDocSt
Doesn't inspect base classes.
"""
try:
source = inspect.getsource(dataclass)
source = inspect_getsource(dataclass)
except (TypeError, OSError) as e:
logger.debug(
UserWarning(
Expand All @@ -114,9 +119,9 @@ def _get_attribute_docstring(dataclass: type, field_name: str) -> AttributeDocSt

# Parse docstring to use as help strings
desc_from_cls_docstring = ""
cls_docstring = inspect.getdoc(dataclass)
cls_docstring = inspect_getdoc(dataclass)
if cls_docstring:
docstring: Docstring = dp.parse(cls_docstring)
docstring: Docstring = dp_parse(cls_docstring)
for param in docstring.params:
if param.arg_name == field_name:
desc_from_cls_docstring = param.description or ""
Expand Down
112 changes: 55 additions & 57 deletions simple_parsing/helpers/hparams/hyperparameters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import dataclasses
import inspect
import math
Expand All @@ -9,6 +10,7 @@
from logging import getLogger
from pathlib import Path
from typing import Any, ClassVar, Dict, List, NamedTuple, Optional, Tuple, Type, TypeVar
import typing

from simple_parsing import utils
from simple_parsing.helpers.serialization.serializable import Serializable
Expand All @@ -22,18 +24,13 @@
from .hparam import ValueOutsidePriorException
from .priors import Prior

if typing.TYPE_CHECKING:
import numpy

logger = getLogger(__name__)
T = TypeVar("T")
HP = TypeVar("HP", bound="HyperParameters")

numpy_installed = False
try:
import numpy as np

numpy_installed = True
except ImportError:
pass


@dataclass
class BoundInfo(Serializable):
Expand All @@ -51,10 +48,8 @@ class HyperParameters(Serializable, decode_into_subclasses=True): # type: ignor

# Class variable holding the random number generator used to create the
# samples.
if numpy_installed:
np_rng: ClassVar[np.random.RandomState] = np.random
else:
rng: ClassVar[random.Random] = random.Random()

rng: ClassVar[random.Random] = random.Random()

def __post_init__(self):
for name, f in field_dict(self).items():
Expand Down Expand Up @@ -141,8 +136,7 @@ def space_id(cls):
def get_bounds(cls) -> List[BoundInfo]:
"""Returns the bounds of the search domain for this type of HParam.
Returns them as a list of `BoundInfo` objects, in the format expected by
GPyOpt.
Returns them as a list of `BoundInfo` objects, in the format expected by GPyOpt.
"""
bounds: List[BoundInfo] = []
for f in fields(cls):
Expand All @@ -162,9 +156,8 @@ def get_bounds(cls) -> List[BoundInfo]:

@classmethod
def get_bounds_dicts(cls) -> List[Dict[str, Any]]:
"""Returns the bounds of the search space for this type of HParam,
in the format expected by the `GPyOpt` package.
"""
"""Returns the bounds of the search space for this type of HParam, in the format expected
by the `GPyOpt` package."""
return [b.to_dict() for b in cls.get_bounds()]

@classmethod
Expand All @@ -186,13 +179,15 @@ def sample(cls):
else:
prior: Optional[Prior] = field.metadata.get("prior")
if prior is not None:
if numpy_installed:
prior.np_rng = cls.np_rng
else:
try:
import numpy as np

prior.np_rng = np.random
except ImportError:
prior.rng = cls.rng
value = prior.sample()
shape = getattr(prior, "shape", None)
if numpy_installed and isinstance(value, np.ndarray) and not shape:
if shape == () and hasattr(value, "item") and callable(value.item):
value = value.item()
kwargs[field.name] = value
return cls(**kwargs)
Expand All @@ -202,43 +197,46 @@ def replace(self, **new_params):
new_hp = type(self).from_dict(new_hp_dict)
return new_hp

# @classmethod
# @contextmanager
# def use_priors(cls, value: bool = True):
# temp = cls.sample_from_priors
# cls.sample_from_priors = value
# yield
# cls.sample_from_priors = temp

if numpy_installed:
# @classmethod
# @contextmanager
# def use_priors(cls, value: bool = True):
# temp = cls.sample_from_priors
# cls.sample_from_priors = value
# yield
# cls.sample_from_priors = temp

def to_array(self, dtype: numpy.dtype | None = None) -> numpy.ndarray:
import numpy as np

dtype = np.float32 if dtype is None else dtype
values: List[float] = []
for k, v in self.to_dict(dict_factory=OrderedDict).items():
try:
v = float(v)
except Exception:
logger.warning(f"Ignoring field {k} because we can't make a float out of it.")
else:
values.append(v)
return np.array(values, dtype=dtype)

def to_array(self, dtype=np.float32) -> np.ndarray:
values: List[float] = []
for k, v in self.to_dict(dict_factory=OrderedDict).items():
try:
v = float(v)
except Exception:
logger.warning(f"Ignoring field {k} because we can't make a float out of it.")
else:
values.append(v)
return np.array(values, dtype=dtype)

@classmethod
def from_array(cls: Type[HP], array: np.ndarray) -> HP:
if len(array.shape) == 2 and array.shape[0] == 1:
array = array[0]

keys = list(field_dict(cls))
# idea: could use to_dict and to_array together to determine how many
# values to get for each field. For now we assume that each field is one
# variable.
# cls.sample().to_dict()
# assert len(keys) == len(array), "assuming that each field is dim 1 for now."
assert len(keys) == len(array), "assuming that each field is dim 1 for now."
d = OrderedDict(zip(keys, array))
logger.debug(f"Creating an instance of {cls} using args {d}")
d = OrderedDict((k, v.item() if isinstance(v, np.ndarray) else v) for k, v in d.items())
return cls.from_dict(d)
@classmethod
def from_array(cls: Type[HP], array: numpy.ndarray) -> HP:
import numpy as np

if len(array.shape) == 2 and array.shape[0] == 1:
array = array[0]

keys = list(field_dict(cls))
# idea: could use to_dict and to_array together to determine how many
# values to get for each field. For now we assume that each field is one
# variable.
# cls.sample().to_dict()
# assert len(keys) == len(array), "assuming that each field is dim 1 for now."
assert len(keys) == len(array), "assuming that each field is dim 1 for now."
d = OrderedDict(zip(keys, array))
logger.debug(f"Creating an instance of {cls} using args {d}")
d = OrderedDict((k, v.item() if isinstance(v, np.ndarray) else v) for k, v in d.items())
return cls.from_dict(d)

def clip_within_bounds(self: HP) -> HP:
d = self.to_dict()
Expand Down
14 changes: 7 additions & 7 deletions simple_parsing/wrappers/dataclass_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from logging import getLogger
from typing import Any, Callable, Generic, TypeVar, cast

import docstring_parser as dp
from typing_extensions import Literal

from .. import docstring, utils
from ..utils import Dataclass, DataclassT, is_dataclass_instance, is_dataclass_type
from .field_wrapper import FieldWrapper
from .wrapper import Wrapper
from simple_parsing import docstring, utils
from simple_parsing.docstring import dp_parse, inspect_getdoc
from simple_parsing.utils import Dataclass, DataclassT, is_dataclass_instance, is_dataclass_type
from simple_parsing.wrappers.field_wrapper import FieldWrapper
from simple_parsing.wrappers.wrapper import Wrapper

logger = getLogger(__name__)

Expand Down Expand Up @@ -327,11 +327,11 @@ def description(self) -> str:

# NOTE: The class docstring may be EXTREMELY LARGE.

class_docstring = inspect.getdoc(self.dataclass) or ""
class_docstring = inspect_getdoc(self.dataclass) or ""
if not class_docstring:
return ""

doc = dp.parse(class_docstring)
doc = dp_parse(class_docstring)

from simple_parsing.decorators import _description_from_docstring

Expand Down

0 comments on commit b3d12de

Please sign in to comment.