Skip to content

Commit

Permalink
Merge pull request #399 from dyson-ai/feature/with_level
Browse files Browse the repository at this point in the history
Feature/with level
  • Loading branch information
blooop committed Jun 6, 2024
2 parents 446b014 + d215966 commit 2d3af3a
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 30 deletions.
2 changes: 1 addition & 1 deletion bencher/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .variables.inputs import IntSweep, FloatSweep, StringSweep, EnumSweep, BoolSweep, SweepBase
from .variables.time import TimeSnapshot

from .variables.inputs import box
from .variables.inputs import box, p
from .variables.results import (
ResultVar,
ResultVec,
Expand Down
56 changes: 33 additions & 23 deletions bencher/bencher.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from datetime import datetime
from itertools import product, combinations

from typing import Callable, List
from typing import Callable, List, Optional
from copy import deepcopy
import numpy as np
import param
Expand Down Expand Up @@ -315,23 +315,37 @@ def plot_sweep(
else:
const_vars = deepcopy(const_vars)

if run_cfg is None:
if self.run_cfg is None:
run_cfg = BenchRunCfg()
logging.info("Generate default run cfg")
else:
run_cfg = deepcopy(self.run_cfg)
logging.info("Copy run cfg from bench class")

if run_cfg.only_plot:
run_cfg.use_cache = True

self.last_run_cfg = run_cfg

if isinstance(input_vars, dict):
input_lists = []
for k, v in input_vars.items():
param_var = self.convert_vars_to_params(k, "input")
param_var = self.convert_vars_to_params(k, "input", run_cfg)
if isinstance(v, list):
assert len(v) > 0
param_var = param_var.with_sample_values(v)

else:
raise RuntimeError("Unsupported type")
input_lists.append(param_var)

input_vars = input_lists
else:
for i in range(len(input_vars)):
input_vars[i] = self.convert_vars_to_params(input_vars[i], "input")
input_vars[i] = self.convert_vars_to_params(input_vars[i], "input", run_cfg)
for i in range(len(result_vars)):
result_vars[i] = self.convert_vars_to_params(result_vars[i], "result")
result_vars[i] = self.convert_vars_to_params(result_vars[i], "result", run_cfg)

for r in result_vars:
logging.info(f"result var: {r.name}")
Expand All @@ -342,22 +356,9 @@ def plot_sweep(
for i in range(len(const_vars)):
# consts come as tuple pairs
cv_list = list(const_vars[i])
cv_list[0] = self.convert_vars_to_params(cv_list[0], "const")
cv_list[0] = self.convert_vars_to_params(cv_list[0], "const", run_cfg)
const_vars[i] = cv_list

if run_cfg is None:
if self.run_cfg is None:
run_cfg = BenchRunCfg()
logging.info("Generate default run cfg")
else:
run_cfg = deepcopy(self.run_cfg)
logging.info("Copy run cfg from bench class")

if run_cfg.only_plot:
run_cfg.use_cache = True

self.last_run_cfg = run_cfg

if title is None:
if len(input_vars) > 0:
title = "Sweeping " + " vs ".join([i.name for i in input_vars])
Expand Down Expand Up @@ -485,7 +486,12 @@ def run_sweep(
self.results.append(bench_res)
return bench_res

def convert_vars_to_params(self, variable: param.Parameter, var_type: str):
def convert_vars_to_params(
self,
variable: param.Parameter | str | dict | tuple,
var_type: str,
run_cfg: Optional[BenchRunCfg],
) -> param.Parameter:
"""check that a variable is a subclass of param
Args:
Expand All @@ -497,10 +503,14 @@ def convert_vars_to_params(self, variable: param.Parameter, var_type: str):
"""
if isinstance(variable, str):
variable = self.worker_class_instance.param.objects(instance=False)[variable]
if isinstance(variable, tuple):
variable = self.worker_class_instance.param.objects(instance=False)[
variable[0]
].with_sample_values(variable[1])
if isinstance(variable, dict):
param_var = self.worker_class_instance.param.objects(instance=False)[variable["name"]]
if variable["values"] is not None:
param_var = param_var.with_sample_values(variable["values"])
if variable["max_level"] is not None:
if run_cfg is not None:
param_var = param_var.with_level(run_cfg.level, variable["max_level"])
variable = param_var
if not isinstance(variable, param.Parameter):
raise TypeError(
f"You need to use {var_type}_vars =[{self.worker_input_cfg}.param.your_variable], instead of {var_type}_vars =[{self.worker_input_cfg}.your_variable]"
Expand Down
3 changes: 1 addition & 2 deletions bencher/example/example_custom_sweep2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ def example_custom_sweep2(

# These are all equivalent
bench.plot_sweep(input_vars=[Square.param.x.with_sample_values([0, 1, 2])])
bench.plot_sweep(input_vars=dict(x=[2, 3, 4]))
bench.plot_sweep(input_vars=[("x", [3, 4, 5])])
bench.plot_sweep(input_vars=[bch.p("x", [4, 5, 6])])

return bench

Expand Down
37 changes: 37 additions & 0 deletions bencher/example/example_levels2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import bencher as bch


class Square(bch.ParametrizedSweep):
"""An example of a datatype with an integer and float parameter"""

x = bch.FloatSweep(default=0, bounds=[0, 6])

result = bch.ResultVar("ul", doc="Square of x")

def __call__(self, **kwargs) -> dict:
self.update_params_from_kwargs(**kwargs)
self.result = self.x * self.x
return self.get_results_values_as_dict()


def example_levels2(run_cfg: bch.BenchRunCfg = None, report: bch.BenchReport = None) -> bch.Bench:
"""This example shows how to define a custom set of value to sample from intead of a uniform sweep
Args:
run_cfg (BenchRunCfg): configuration of how to perform the param sweep
Returns:
Bench: results of the parameter sweep
"""

bench = Square().to_bench(run_cfg=run_cfg, report=report)

# These are all equivalent
bench.plot_sweep(input_vars=[Square.param.x.with_level(run_cfg.level, 3)])
bench.plot_sweep(input_vars=[bch.p("x", max_level=3)])

return bench


if __name__ == "__main__":
example_levels2(bch.BenchRunCfg(level=4)).report.show()
8 changes: 7 additions & 1 deletion bencher/variables/inputs.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from enum import Enum
from typing import List, Any
from typing import List, Any, Dict

import numpy as np
from param import Integer, Number, Selector
Expand Down Expand Up @@ -174,6 +174,12 @@ def box(name, center, width):
return var


def p(name: str, values: List[Any] = None, max_level: int = None) -> Dict[str, Any]:
if max_level is not None:
assert max_level > 0
return {"name": name, "values": values, "max_level": max_level}


def with_level(arr: list, level) -> list:
return IntSweep(sample_values=arr).with_level(level).values()
# return tmp.with_sample_values(arr).with_level(level).values()
4 changes: 2 additions & 2 deletions pixi.lock
Original file line number Diff line number Diff line change
Expand Up @@ -896,9 +896,9 @@ packages:
requires_python: '>=3.7'
- kind: pypi
name: holobench
version: 1.24.0
version: 1.25.0
path: .
sha256: 5d6bfc4999f8aae6d6207b11e7bcb5649ea91191c05fb7f739f2bfc03d926d15
sha256: 0cd186ecf71e375f8522dc223a714b35cdf40bf76c572593de379c38680acc1d
requires_dist:
- holoviews>=1.15,<=1.18.3
- numpy>=1.0,<=1.26.4
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "holobench"
version = "1.24.0"
version = "1.25.0"

authors = [{ name = "Austin Gregg-Smith", email = "blooop@gmail.com" }]
description = "A package for benchmarking the performance of arbitrary functions"
Expand Down
4 changes: 4 additions & 0 deletions test/test_bench_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from bencher.example.example_custom_sweep import example_custom_sweep
from bencher.example.example_custom_sweep2 import example_custom_sweep2
from bencher.example.example_levels2 import example_levels2
from bencher.example.example_workflow import example_floats2D_workflow, example_floats3D_workflow
from bencher.example.example_holosweep import example_holosweep
from bencher.example.example_holosweep_tap import example_holosweep_tap
Expand Down Expand Up @@ -89,6 +90,9 @@ def test_example_custom_sweep(self) -> None:
def test_example_custom2(self) -> None:
self.examples_asserts(example_custom_sweep2(self.create_run_cfg()))

def test_example_level2(self) -> None:
self.examples_asserts(example_levels2(self.create_run_cfg()))

def test_example_floats2D_workflow(self) -> None:
self.examples_asserts(example_floats2D_workflow(self.create_run_cfg()))

Expand Down

0 comments on commit 2d3af3a

Please sign in to comment.