Skip to content

Commit

Permalink
Merge pull request #326 from dyson-ai/feature/consts
Browse files Browse the repository at this point in the history
Feature/consts
  • Loading branch information
blooop committed Feb 3, 2024
2 parents 6424876 + a8df444 commit 28d75c5
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 33 deletions.
58 changes: 26 additions & 32 deletions bencher/bencher.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,13 @@ def __init__(

self.cache_size = int(100e9) # default to 100gb

# self.bench_cfg = BenchCfg()

# Maybe put this in SweepCfg
self.input_vars = None
self.result_vars = None
self.const_vars = None

def set_worker(self, worker: Callable, worker_input_cfg: ParametrizedSweep = None) -> None:
"""Set the benchmark worker function and optionally the type the worker expects
Expand All @@ -188,34 +195,6 @@ def set_worker(self, worker: Callable, worker_input_cfg: ParametrizedSweep = Non
logging.info(f"setting worker {worker}")
self.worker_input_cfg = worker_input_cfg

def sweep(
self,
input_vars: List[ParametrizedSweep] = None,
result_vars: List[ParametrizedSweep] = None,
const_vars: List[ParametrizedSweep] = None,
time_src: datetime = None,
description: str = None,
post_description: str = None,
pass_repeat: bool = False,
tag: str = "",
run_cfg: BenchRunCfg = None,
plot: bool = False,
) -> BenchResult:
title = "Sweeping " + " vs ".join(params_to_str(input_vars))
return self.plot_sweep(
title,
input_vars=input_vars,
result_vars=result_vars,
const_vars=const_vars,
time_src=time_src,
description=description,
post_description=post_description,
pass_repeat=pass_repeat,
tag=tag,
run_cfg=run_cfg,
plot=plot,
)

def sweep_sequential(
self,
title="",
Expand Down Expand Up @@ -291,18 +270,27 @@ def plot_sweep(
logging.info(
"No input variables passed, using all param variables in bench class as inputs"
)
input_vars = self.worker_class_instance.get_inputs_only()
if self.input_vars is None:
input_vars = self.worker_class_instance.get_inputs_only()
else:
input_vars = self.input_vars
for i in input_vars:
logging.info(f"input var: {i.name}")
if result_vars is None:
logging.info(
"No results variables passed, using all result variables in bench class:"
)
result_vars = self.worker_class_instance.get_results_only()
if self.result_vars is None:
result_vars = self.worker_class_instance.get_results_only()
else:
result_vars = self.result_vars
for r in result_vars:
logging.info(f"result var: {r.name}")
if const_vars is None:
const_vars = self.worker_class_instance.get_input_defaults()
if self.const_vars is None:
const_vars = self.worker_class_instance.get_input_defaults()
else:
const_vars = self.const_vars
else:
if input_vars is None:
input_vars = []
Expand Down Expand Up @@ -392,7 +380,13 @@ def plot_sweep(
title=title,
pass_repeat=pass_repeat,
tag=run_cfg.run_tag + tag,
auto_plot=plot,
)
return self.run_sweep(bench_cfg, run_cfg, time_src)

def run_sweep(
self, bench_cfg: BenchCfg, run_cfg: BenchRunCfg, time_src: datetime
) -> BenchResult:
print("tag", bench_cfg.tag)

bench_cfg.param.update(run_cfg.param.values())
Expand Down Expand Up @@ -447,7 +441,7 @@ def plot_sweep(

bench_res.post_setup()

if plot and bench_res.bench_cfg.auto_plot:
if bench_cfg.auto_plot:
self.report.append_result(bench_res)
self.results.append(bench_res)
return bench_res
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "holobench"
version = "1.5.0"
version = "1.6.0"

authors = [{ name = "Austin Gregg-Smith", email = "blooop@gmail.com" }]
description = "A package for benchmarking the performance of arbitrary functions"
Expand Down

0 comments on commit 28d75c5

Please sign in to comment.