Skip to content

Commit

Permalink
Merge pull request #331 from dyson-ai/feature/select_level
Browse files Browse the repository at this point in the history
Feature/select level
  • Loading branch information
blooop committed Feb 7, 2024
2 parents fd7f31f + 07470bc commit 37a04b6
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 3 deletions.
40 changes: 40 additions & 0 deletions bencher/results/bench_result_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
)

from bencher.results.composable_container.composable_container_panel import ComposableContainerPanel
from bencher.utils import listify

# todo add plugins
# https://gist.github.com/dorneanu/cce1cd6711969d581873a88e0257e312
Expand Down Expand Up @@ -412,6 +413,45 @@ def ds_to_container(
return container(val, styles={"background": "white"}, **kwargs)
return val

@staticmethod
def select_level(
dataset: xr.Dataset,
level: int,
include_types: List[type] = None,
exclude_names: List[str] = None,
) -> xr.Dataset:
"""Given a dataset, return a reduced dataset that only contains data from a specified level. By default all types of variables are filtered at the specified level. If you only want to get a reduced level for some types of data you can pass in a list of types to get filtered, You can also pass a list of variables names to exclude from getting filtered
Args:
dataset (xr.Dataset): dataset to filter
level (int): desired data resolution level
include_types (List[type], optional): Only filter data of these types. Defaults to None.
exclude_names (List[str], optional): Only filter data with these variable names. Defaults to None.
Returns:
xr.Dataset: A reduced dataset at the specified level
Example: a dataset with float_var: [1,2,3,4,5] cat_var: [a,b,c,d,e]
select_level(ds,2) -> [1,5] [a,e]
select_level(ds,2,(float)) -> [1,5] [a,b,c,d,e]
select_level(ds,2,exclude_names=["cat_var]) -> [1,5] [a,b,c,d,e]
see test_bench_result_base.py -> test_select_level()
"""
coords_no_repeat = {}
for c, v in dataset.coords.items():
if c != "repeat":
vals = v.to_numpy()
print(vals.dtype)
include = True
if include_types is not None and vals.dtype not in listify(include_types):
include = False
if exclude_names is not None and c in listify(exclude_names):
include = False
if include:
coords_no_repeat[c] = with_level(v.to_numpy(), level)
return dataset.sel(coords_no_repeat)

# MAPPING TO LOWER LEVEL BENCHCFG functions so they are available at a top level.
def to_sweep_summary(self, **kwargs):
return self.bench_cfg.to_sweep_summary(**kwargs)
Expand Down
8 changes: 7 additions & 1 deletion bencher/results/panel_result.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from functools import partial
import panel as pn
from param import Parameter
import holoviews as hv
from bencher.results.bench_result_base import BenchResultBase, ReduceType
from bencher.results.video_result import VideoControls
from bencher.variables.results import (
Expand All @@ -20,16 +21,21 @@ def to_video(self, result_var: Parameter = None, **kwargs):
def to_panes(
self,
result_var: Parameter = None,
hv_dataset=None,
target_dimension: int = 0,
container=None,
level: int = None,
**kwargs
) -> Optional[pn.pane.panel]:
if container is None:
container = pn.pane.panel
if hv_dataset is None:
hv_dataset = self.to_hv_dataset(ReduceType.SQUEEZE, level=level)
elif not isinstance(hv_dataset, hv.Dataset):
hv_dataset = hv.Dataset(hv_dataset)
return self.map_plot_panes(
partial(self.ds_to_container, container=container),
hv_dataset=self.to_hv_dataset(ReduceType.SQUEEZE, level=level),
hv_dataset=hv_dataset,
target_dimension=target_dimension,
result_var=result_var,
result_types=PANEL_TYPES,
Expand Down
2 changes: 1 addition & 1 deletion bencher/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def callable_name(any_callable: Callable[..., Any]) -> str:


def listify(obj) -> list:
"""Take an object and turn it into a list if its not already a list"""
"""Take an object and turn it into a list if its not already a list. However if the object is none, don't turn it into a list"""
if obj is None:
return None
if isinstance(obj, list):
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "holobench"
version = "1.8.0"
version = "1.9.0"

authors = [{ name = "Austin Gregg-Smith", email = "blooop@gmail.com" }]
description = "A package for benchmarking the performance of arbitrary functions"
Expand Down
38 changes: 38 additions & 0 deletions test/test_bench_result_base.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,20 @@
import unittest
import bencher as bch
import numpy as np

from bencher.example.meta.example_meta import BenchableObject


class TestBench(bch.ParametrizedSweep):
float_var = bch.FloatSweep(default=0, bounds=[0, 4])
cat_var = bch.StringSweep(["a", "b", "c", "d", "e"])
result = bch.ResultVar()

def __call__(self, **kwargs):
self.result = 1
return super().__call__()


class TestBenchResultBase(unittest.TestCase):
def test_to_dataset(self):
bench = BenchableObject().to_bench()
Expand Down Expand Up @@ -39,3 +50,30 @@ def test_to_dataset(self):
)

# bm.__call__(float_vars=1, sample_with_repeats=1)

def test_select_level(self):
bench = TestBench().to_bench()

res = bench.plot_sweep(
input_vars=["float_var", "cat_var"],
run_cfg=bch.BenchRunCfg(level=4),
plot=False,
)

def asserts(ds, expected_float, expected_cat):
np.testing.assert_array_equal(
ds.coords["float_var"].to_numpy(), np.array(expected_float, dtype=float)
)
np.testing.assert_array_equal(ds.coords["cat_var"].to_numpy(), np.array(expected_cat))

ds_raw = res.to_dataset()
asserts(ds_raw, [0, 1, 2, 3, 4], ["a", "b", "c", "d", "e"])

ds_filtered_all = res.select_level(ds_raw, 2)
asserts(ds_filtered_all, [0, 4], ["a", "e"])

ds_filtered_types = res.select_level(ds_raw, 2, float)
asserts(ds_filtered_types, [0, 4], ["a", "b", "c", "d", "e"])

ds_filtered_names = res.select_level(ds_raw, 2, exclude_names="cat_var")
asserts(ds_filtered_names, [0, 4], ["a", "b", "c", "d", "e"])

0 comments on commit 37a04b6

Please sign in to comment.