Skip to content

Commit

Permalink
Refactored and changed interface for BasicStatistics
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Jan 21, 2024
1 parent b993e34 commit bcdf148
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 143 deletions.
79 changes: 33 additions & 46 deletions onedal/basic_statistics/basic_statistics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,25 +101,21 @@ struct params2desc {
};

template <typename Policy, typename Task>
struct init_compute_ops_dispatcher {};

template <typename Policy>
struct init_compute_ops_dispatcher<Policy, dal::basic_statistics::task::compute> {
void operator()(py::module_& m) {
using Task = dal::basic_statistics::task::compute;
m.def("train",
[](const Policy& policy,
const py::dict& params,
const table& data,
const table& weights) {
using namespace dal::basic_statistics;
using input_t = compute_input<Task>;

compute_ops ops(policy, input_t{ data, weights }, params2desc{});
return fptype2t{ method2t{ Task{}, ops } }(params);
});
}
};
// Registers a Python-callable named "compute" on module `m`.
// `Policy` and `Task` are not declared in this function's own signature; they
// are expected to come from the template header that precedes it.
void init_compute_ops(py::module& m) {
m.def("compute", [](
const Policy& policy,
const py::dict& params,
const table& data,
const table& weights) {
using namespace dal::basic_statistics;
using input_t = compute_input<Task>;

// Bundle the data and weights tables into the compute input and hand it,
// together with the policy and the params2desc converter, to compute_ops.
compute_ops ops(policy, input_t{ data, weights }, params2desc{});
// NOTE(review): fptype2t / method2t are defined outside this view; presumably
// they pick the floating-point type and method from `params` before invoking
// `ops` -- confirm against their definitions.
return fptype2t{ method2t{ Task{}, ops } }(params);
}
);
}


template <typename Policy, typename Task>
void init_finalize_compute_ops(pybind11::module_& m) {
Expand Down Expand Up @@ -148,28 +144,23 @@ void init_partial_compute_ops(py::module& m) {
);
}

template <typename Policy, typename Task>
void init_compute_ops(py::module& m) {
init_compute_ops_dispatcher<Policy, Task>{}(m);
}

template <typename Task>
void init_compute_result(py::module_& m) {
using namespace dal::basic_statistics;
using result_t = compute_result<Task>;

auto cls = py::class_<result_t>(m, "compute_result")
.def(py::init())
.DEF_ONEDAL_PY_PROPERTY(min, result_t)
.DEF_ONEDAL_PY_PROPERTY(max, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum, result_t)
.DEF_ONEDAL_PY_PROPERTY(mean, result_t)
.DEF_ONEDAL_PY_PROPERTY(variance, result_t)
.DEF_ONEDAL_PY_PROPERTY(variation, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum_squares, result_t)
.DEF_ONEDAL_PY_PROPERTY(standard_deviation, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum_squares_centered, result_t)
.DEF_ONEDAL_PY_PROPERTY(second_order_raw_moment, result_t);
py::class_<result_t>(m, "compute_result")
.def(py::init())
.DEF_ONEDAL_PY_PROPERTY(min, result_t)
.DEF_ONEDAL_PY_PROPERTY(max, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum, result_t)
.DEF_ONEDAL_PY_PROPERTY(mean, result_t)
.DEF_ONEDAL_PY_PROPERTY(variance, result_t)
.DEF_ONEDAL_PY_PROPERTY(variation, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum_squares, result_t)
.DEF_ONEDAL_PY_PROPERTY(standard_deviation, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum_squares_centered, result_t)
.DEF_ONEDAL_PY_PROPERTY(second_order_raw_moment, result_t);
}

template <typename Task>
Expand Down Expand Up @@ -201,22 +192,18 @@ ONEDAL_PY_INIT_MODULE(basic_statistics) {
using namespace dal::basic_statistics;

auto sub = m.def_submodule("basic_statistics");
using task_list = types<dal::basic_statistics::task::compute>;

#ifdef ONEDAL_DATA_PARALLEL_SPMD
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list_spmd, task_list);
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list_spmd, task::compute);
#else // ONEDAL_DATA_PARALLEL_SPMD
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task_list);
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task::compute);
#endif // ONEDAL_DATA_PARALLEL_SPMD

ONEDAL_PY_INSTANTIATE(init_partial_compute_ops, sub, policy_list, task_list);
ONEDAL_PY_INSTANTIATE(init_finalize_compute_ops, sub, policy_list, task_list);
ONEDAL_PY_INSTANTIATE(init_compute_result, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_partial_compute_result, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_partial_compute_ops, sub, policy_list, task::compute);
ONEDAL_PY_INSTANTIATE(init_finalize_compute_ops, sub, policy_list, task::compute);
ONEDAL_PY_INSTANTIATE(init_compute_result, sub, task::compute);
ONEDAL_PY_INSTANTIATE(init_partial_compute_result, sub, task::compute);
}

ONEDAL_PY_TYPE2STR(dal::basic_statistics::task::compute, "compute");

#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230100

} // namespace oneapi::dal::python
46 changes: 18 additions & 28 deletions onedal/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,8 @@
# ==============================================================================

from abc import ABCMeta, abstractmethod
from numbers import Number

import numpy as np
from sklearn.base import BaseEstimator

from onedal import _backend

Expand All @@ -31,6 +29,7 @@ class BaseBasicStatistics(metaclass=ABCMeta):
# Stores the requested result options and the algorithm name, and keeps a
# reference to the basic_statistics submodule of the onedal backend.
def __init__(self, result_options, algorithm):
self.options = result_options
self.algorithm = algorithm
self._module = _backend.basic_statistics

@staticmethod
def get_all_result_options():
Expand Down Expand Up @@ -66,17 +65,16 @@ def _get_onedal_params(self, dtype=np.float32):
"result_option": options,
}

def _compute_raw(self, data_table, weights_table, module, policy, dtype=np.float32):
params = self._get_onedal_params(dtype)

result = module.train(policy, params, data_table, weights_table)

options = self._get_result_options(self.options)
options = options.split("|")
class BasicStatistics(BaseBasicStatistics):
"""
Basic Statistics oneDAL implementation.
"""

return {opt: getattr(result, opt) for opt in options}
def __init__(self, result_options="all", algorithm="by_default"):
super().__init__(result_options, algorithm)

def _compute(self, data, weights, module, queue):
def fit(self, data, weights=None, queue=None):
policy = self._get_policy(queue, data, weights)

if not (data is None):
Expand All @@ -85,27 +83,19 @@ def _compute(self, data, weights, module, queue):
weights = np.asarray(weights)

data, weights = _convert_to_supported(policy, data, weights)

data_table, weights_table = to_table(data, weights)

dtype = data.dtype
res = self._compute_raw(data_table, weights_table, module, policy, dtype)
raw_result = self._compute_raw(data_table, weights_table, policy, dtype)
for opt, raw_value in raw_result.items():
value = from_table(raw_value).ravel()
setattr(self, opt, value)

return {k: from_table(v).ravel() for k, v in res.items()}
return self

def _compute_raw(self, data_table, weights_table, policy, dtype=np.float32):
params = self._get_onedal_params(dtype)
result = self._module.compute(policy, params, data_table, weights_table)
options = self._get_result_options(self.options).split("|")

class BasicStatistics(BaseBasicStatistics):
"""
Basic Statistics oneDAL implementation.
"""

def __init__(self, result_options="all", *, algorithm="by_default", **kwargs):
super().__init__(result_options, algorithm)

def compute(self, data, weights=None, queue=None):
return super()._compute(data, weights, _backend.basic_statistics.compute, queue)

def compute_raw(self, data_table, weights_table, policy, dtype=np.float32):
return super()._compute_raw(
data_table, weights_table, _backend.basic_statistics.compute, policy, dtype
)
return {opt: getattr(result, opt) for opt in options}
48 changes: 2 additions & 46 deletions onedal/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# ==============================================================================
# Copyright 2024 Intel Corporation
# Copyright 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,57 +14,13 @@
# limitations under the License.
# ==============================================================================

from abc import ABCMeta, abstractmethod

import numpy as np

from daal4py.sklearn._utils import get_dtype
from onedal import _backend

from ..common._policy import _get_policy
from ..datatypes import _convert_to_supported, from_table, to_table


# Abstract base class that stores the result options / algorithm name and
# converts them into the parameter dict built by _get_onedal_params.
class BaseBasicStatistics(metaclass=ABCMeta):
@abstractmethod
def __init__(self, result_options, algorithm):
# result_options may be "all", a list of option names, or a string;
# those are the forms _get_result_options handles.
self.options = result_options
self.algorithm = algorithm
self._module = _backend.basic_statistics.compute

@staticmethod
def get_all_result_options():
# Option names substituted when result_options == "all".
return [
"min",
"max",
"sum",
"mean",
"variance",
"variation",
"sum_squares",
"standard_deviation",
"sum_squares_centered",
"second_order_raw_moment",
]

def _get_policy(self, queue, *data):
# Delegates to the module-level _get_policy helper imported by this file.
return _get_policy(queue, *data)

def _get_result_options(self, options):
# Normalizes `options` to a single "|"-joined string:
# "all" -> full option list -> joined string; a list -> joined string;
# a string is returned unchanged.
if options == "all":
options = self.get_all_result_options()
if isinstance(options, list):
options = "|".join(options)
# Any other input type trips this assertion.
assert isinstance(options, str)
return options

def _get_onedal_params(self, dtype=np.float32):
# Builds the dict with keys "fptype", "method" and "result_option".
# Every dtype other than np.float32 is reported as "double".
options = self._get_result_options(self.options)
return {
"fptype": "float" if dtype == np.float32 else "double",
"method": self.algorithm,
"result_option": options,
}
from .basic_statistics import BaseBasicStatistics


class IncrementalBasicStatistics(BaseBasicStatistics):
Expand Down
51 changes: 29 additions & 22 deletions onedal/basic_statistics/tests/test_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,24 +32,6 @@
("standard_deviation", np.std, (3e-5, 3e-5)),
]

# Checks the "mean" result of BasicStatistics against numpy on seeded uniform
# data, once per queue returned by get_queues() and per float dtype.
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_basic_uniform(queue, dtype):
seed = 42
s_count, f_count = 70000, 29

# Fixed seed keeps the generated (s_count, f_count) matrix reproducible.
gen = np.random.default_rng(seed)
data = gen.uniform(low=-0.5, high=+0.6, size=(s_count, f_count))
data = data.astype(dtype=dtype)

alg = BasicStatistics(result_options="mean")
res = alg.compute(data, queue=queue)

# The result is indexed by option name; the reference is the per-column mean.
res_mean = res["mean"]
gtr_mean = np.mean(data, axis=0)
# Tolerance follows the dtype of the returned array, not the input dtype.
tol = 2e-5 if res_mean.dtype == np.float32 else 1e-7
assert_allclose(gtr_mean, res_mean, rtol=tol)

@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("option", options_and_tests)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
Expand All @@ -65,13 +47,38 @@ def test_option_uniform(queue, option, dtype):
data = data.astype(dtype=dtype)

alg = BasicStatistics(result_options=result_option)
res = alg.compute(data, queue=queue)
res = alg.fit(data, queue=queue)

res, gtr = res[result_option], function(data, axis=0)
res, gtr = getattr(res, result_option), function(data, axis=0)

tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, rtol=tol)

# Requests several result options in one fit() call and compares each against
# numpy, once per queue returned by get_queues() and per float dtype.
@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_multiple_options_uniform(queue, dtype):
seed = 42
s_count, f_count = 700, 29

# Fixed seed keeps the generated (s_count, f_count) matrix reproducible.
gen = np.random.default_rng(seed)
data = gen.uniform(low=-0.5, high=+0.6, size=(s_count, f_count))
data = data.astype(dtype=dtype)

alg = BasicStatistics(result_options=["mean", "max", "sum"])
res = alg.fit(data, queue=queue)

# Results are read as attributes of the object returned by fit();
# the references are the per-column numpy equivalents.
res_mean, res_max, res_sum = res.mean, res.max, res.sum
gtr_mean, gtr_max, gtr_sum = (
np.mean(data, axis=0),
np.max(data, axis=0),
np.sum(data, axis=0),
)

# One tolerance, chosen from the dtype of the returned mean, is reused for
# all three comparisons.
tol = 2e-5 if res_mean.dtype == np.float32 else 1e-7
assert_allclose(gtr_mean, res_mean, rtol=tol)
assert_allclose(gtr_max, res_max, rtol=tol)
assert_allclose(gtr_sum, res_sum, rtol=tol)

@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("option", options_and_tests)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
Expand All @@ -91,10 +98,10 @@ def test_option_weighted(queue, option, dtype):
weights = weights.astype(dtype=dtype)

alg = BasicStatistics(result_options=result_option)
res = alg.compute(data, weights, queue=queue)
res = alg.fit(data, weights, queue=queue)

weighted = np.diag(weights) @ data
res, gtr = res[result_option], function(weighted, axis=0)
res, gtr = getattr(res, result_option), function(weighted, axis=0)

tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, rtol=tol)
Loading

0 comments on commit bcdf148

Please sign in to comment.