Skip to content

Commit

Permalink
Merge pull request #398 from dyson-ai/feature/with_sample_values
Browse files Browse the repository at this point in the history
Feature/with sample values
  • Loading branch information
blooop committed Jun 6, 2024
2 parents 7c5817e + 42c1f1c commit 446b014
Show file tree
Hide file tree
Showing 5 changed files with 86 additions and 22 deletions.
21 changes: 19 additions & 2 deletions bencher/bencher.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,21 @@ def plot_sweep(
else:
const_vars = deepcopy(const_vars)

for i in range(len(input_vars)):
input_vars[i] = self.convert_vars_to_params(input_vars[i], "input")
if isinstance(input_vars, dict):
input_lists = []
for k, v in input_vars.items():
param_var = self.convert_vars_to_params(k, "input")
if isinstance(v, list):
assert len(v) > 0
param_var = param_var.with_sample_values(v)
else:
raise RuntimeError("Unsupported type")
input_lists.append(param_var)

input_vars = input_lists
else:
for i in range(len(input_vars)):
input_vars[i] = self.convert_vars_to_params(input_vars[i], "input")
for i in range(len(result_vars)):
result_vars[i] = self.convert_vars_to_params(result_vars[i], "result")

Expand Down Expand Up @@ -484,6 +497,10 @@ def convert_vars_to_params(self, variable: param.Parameter, var_type: str):
"""
if isinstance(variable, str):
variable = self.worker_class_instance.param.objects(instance=False)[variable]
if isinstance(variable, tuple):
variable = self.worker_class_instance.param.objects(instance=False)[
variable[0]
].with_sample_values(variable[1])
if not isinstance(variable, param.Parameter):
raise TypeError(
f"You need to use {var_type}_vars =[{self.worker_input_cfg}.param.your_variable], instead of {var_type}_vars =[{self.worker_input_cfg}.your_variable]"
Expand Down
40 changes: 40 additions & 0 deletions bencher/example/example_custom_sweep2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import bencher as bch


class Square(bch.ParametrizedSweep):
"""An example of a datatype with an integer and float parameter"""

x = bch.FloatSweep(default=0, bounds=[0, 6])

result = bch.ResultVar("ul", doc="Square of x")

def __call__(self, **kwargs) -> dict:
self.update_params_from_kwargs(**kwargs)
self.result = self.x * self.x
return self.get_results_values_as_dict()


def example_custom_sweep2(
run_cfg: bch.BenchRunCfg = None, report: bch.BenchReport = None
) -> bch.Bench:
"""This example shows how to define a custom set of value to sample from intead of a uniform sweep
Args:
run_cfg (BenchRunCfg): configuration of how to perform the param sweep
Returns:
Bench: results of the parameter sweep
"""

bench = Square().to_bench(run_cfg=run_cfg, report=report)

# These are all equivalent
bench.plot_sweep(input_vars=[Square.param.x.with_sample_values([0, 1, 2])])
bench.plot_sweep(input_vars=dict(x=[2, 3, 4]))
bench.plot_sweep(input_vars=[("x", [3, 4, 5])])

return bench


if __name__ == "__main__":
example_custom_sweep2().report.show()
38 changes: 19 additions & 19 deletions pixi.lock
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-h4ab18f5_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h59595ed_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.0-h4ab18f5_3.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.1-h4ab18f5_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.12.3-hab00c5b_0_cpython.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda
Expand Down Expand Up @@ -112,7 +112,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/a2/73/a68704750a7679d0b6d3ad7aa8d4da8e14e151ae82e6fee774e6e0d05ec8/urllib3-2.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/67/91/1f55f0e026fba8eba15afb7d097bb873bd6a9e466be45a45e7cac40a930b/xarray-2024.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b7/2c/08768a39947864fcebc19f059b758d8169a2ac183a61361359f56c144f7c/xyzservices-2024.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/5f/51/c106f095c33de0b833d3823fbab3383248476b3a9fd4dcd59ba01d950361/xyzservices-2024.6.0-py3-none-any.whl
- pypi: .
py310:
channels:
Expand All @@ -135,7 +135,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-h4ab18f5_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h59595ed_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.0-h4ab18f5_3.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.1-h4ab18f5_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.10.14-hd12c33a_0_cpython.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda
Expand Down Expand Up @@ -227,7 +227,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/a2/73/a68704750a7679d0b6d3ad7aa8d4da8e14e151ae82e6fee774e6e0d05ec8/urllib3-2.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/67/91/1f55f0e026fba8eba15afb7d097bb873bd6a9e466be45a45e7cac40a930b/xarray-2024.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b7/2c/08768a39947864fcebc19f059b758d8169a2ac183a61361359f56c144f7c/xyzservices-2024.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/5f/51/c106f095c33de0b833d3823fbab3383248476b3a9fd4dcd59ba01d950361/xyzservices-2024.6.0-py3-none-any.whl
- pypi: .
py311:
channels:
Expand All @@ -251,7 +251,7 @@ environments:
- conda: https://conda.anaconda.org/conda-forge/linux-64/libxcrypt-4.4.36-hd590300_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/libzlib-1.3.1-h4ab18f5_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/ncurses-6.5-h59595ed_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.0-h4ab18f5_3.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.1-h4ab18f5_0.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/python-3.11.9-hb806964_0_cpython.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/readline-8.2-h8228510_1.conda
- conda: https://conda.anaconda.org/conda-forge/linux-64/tk-8.6.13-noxft_h4845f30_101.conda
Expand Down Expand Up @@ -341,7 +341,7 @@ environments:
- pypi: https://files.pythonhosted.org/packages/a2/73/a68704750a7679d0b6d3ad7aa8d4da8e14e151ae82e6fee774e6e0d05ec8/urllib3-2.2.1-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/f4/24/2a3e3df732393fed8b3ebf2ec078f05546de641fe1b667ee316ec1dcf3b7/webencodings-0.5.1-py2.py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/67/91/1f55f0e026fba8eba15afb7d097bb873bd6a9e466be45a45e7cac40a930b/xarray-2024.5.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/b7/2c/08768a39947864fcebc19f059b758d8169a2ac183a61361359f56c144f7c/xyzservices-2024.4.0-py3-none-any.whl
- pypi: https://files.pythonhosted.org/packages/5f/51/c106f095c33de0b833d3823fbab3383248476b3a9fd4dcd59ba01d950361/xyzservices-2024.6.0-py3-none-any.whl
- pypi: .
packages:
- kind: conda
Expand Down Expand Up @@ -896,9 +896,9 @@ packages:
requires_python: '>=3.7'
- kind: pypi
name: holobench
version: 1.23.0
version: 1.24.0
path: .
sha256: 76ce5d13f28e35b37db41f02f45ee8dce425e1ab53b7bc2bc4a97322d908309a
sha256: 5d6bfc4999f8aae6d6207b11e7bcb5649ea91191c05fb7f739f2bfc03d926d15
requires_dist:
- holoviews>=1.15,<=1.18.3
- numpy>=1.0,<=1.26.4
Expand Down Expand Up @@ -1424,6 +1424,7 @@ packages:
constrains:
- binutils_impl_linux-64 2.40
license: GPL-3.0-only
license_family: GPL
purls: []
size: 708179
timestamp: 1717523002366
Expand Down Expand Up @@ -1835,13 +1836,12 @@ packages:
requires_python: '>=3.9'
- kind: conda
name: openssl
version: 3.3.0
build: h4ab18f5_3
build_number: 3
version: 3.3.1
build: h4ab18f5_0
subdir: linux-64
url: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.0-h4ab18f5_3.conda
sha256: 33dcea0ed3a61b2de6b66661cdd55278640eb99d676cd129fbff3e53641fa125
md5: 12ea6d0d4ed54530eaed18e4835c1f7c
url: https://conda.anaconda.org/conda-forge/linux-64/openssl-3.3.1-h4ab18f5_0.conda
sha256: 9691f8bd6394c5bb0b8d2f47cd1467b91bd5b1df923b69e6b517f54496ee4b50
md5: a41fa0e391cc9e0d6b78ac69ca047a6c
depends:
- ca-certificates
- libgcc-ng >=12
Expand All @@ -1850,8 +1850,8 @@ packages:
license: Apache-2.0
license_family: Apache
purls: []
size: 2891147
timestamp: 1716468354865
size: 2896170
timestamp: 1717546157673
- kind: pypi
name: optuna
version: 3.6.1
Expand Down Expand Up @@ -3548,9 +3548,9 @@ packages:
requires_python: '>=3.9'
- kind: pypi
name: xyzservices
version: 2024.4.0
url: https://files.pythonhosted.org/packages/b7/2c/08768a39947864fcebc19f059b758d8169a2ac183a61361359f56c144f7c/xyzservices-2024.4.0-py3-none-any.whl
sha256: b83e48c5b776c9969fffcfff57b03d02b1b1cd6607a9d9c4e7f568b01ef47f4c
version: 2024.6.0
url: https://files.pythonhosted.org/packages/5f/51/c106f095c33de0b833d3823fbab3383248476b3a9fd4dcd59ba01d950361/xyzservices-2024.6.0-py3-none-any.whl
sha256: fecb2508f0f2b71c819aecf5df2c03cef001c56a4b49302e640f3b34710d25e4
requires_python: '>=3.8'
- kind: conda
name: xz
Expand Down
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "holobench"
version = "1.23.0"
version = "1.24.0"

authors = [{ name = "Austin Gregg-Smith", email = "blooop@gmail.com" }]
description = "A package for benchmarking the performance of arbitrary functions"
Expand Down Expand Up @@ -80,8 +80,11 @@ test = "pytest"
coverage = "coverage run -m pytest && coverage xml -o coverage.xml"
coverage-report = "coverage report -m"
update-lock = "pixi update && git commit -a -m'update pixi.lock'"
fix = { depends_on = ["update-lock", "format", "ruff-lint"] }
push = "git push"
update-lock-push = { depends_on = ["update-lock", "push"] }
fix-commit-push = { depends_on = ["fix", "commit-format", "update-lock-push"] }

ci-no-cover = { depends_on = ["style", "test"] }
ci = { depends_on = [
"format",
Expand Down
4 changes: 4 additions & 0 deletions test/test_bench_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from bencher.example.example_float3D import example_floats3D

from bencher.example.example_custom_sweep import example_custom_sweep
from bencher.example.example_custom_sweep2 import example_custom_sweep2
from bencher.example.example_workflow import example_floats2D_workflow, example_floats3D_workflow
from bencher.example.example_holosweep import example_holosweep
from bencher.example.example_holosweep_tap import example_holosweep_tap
Expand Down Expand Up @@ -85,6 +86,9 @@ def test_example_float3D(self) -> None:
def test_example_custom_sweep(self) -> None:
self.examples_asserts(example_custom_sweep(self.create_run_cfg()))

def test_example_custom2(self) -> None:
self.examples_asserts(example_custom_sweep2(self.create_run_cfg()))

def test_example_floats2D_workflow(self) -> None:
self.examples_asserts(example_floats2D_workflow(self.create_run_cfg()))

Expand Down

0 comments on commit 446b014

Please sign in to comment.