Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: BasicStatistics API change #1644

Merged
merged 21 commits into from
Jul 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions examples/sklearnex/basic_statistics_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def generate_data(par, size, seed=777):
gtr_std = np.std(weighted_data, axis=0)

bss = BasicStatisticsSpmd(["mean", "standard_deviation"])
res = bss.compute(dpt_data, dpt_weights)
bss.fit(dpt_data, dpt_weights)

print(f"Computed mean on rank {rank}:\n", res["mean"])
print(f"Computed std on rank {rank}:\n", res["standard_deviation"])
print(f"Computed mean on rank {rank}:\n", bss.mean)
print(f"Computed std on rank {rank}:\n", bss.standard_deviation)
79 changes: 33 additions & 46 deletions onedal/basic_statistics/basic_statistics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,26 +129,21 @@ struct params2desc_incremental {
};

template <typename Policy, typename Task>
struct init_compute_ops_dispatcher {};

template <typename Policy>
struct init_compute_ops_dispatcher<Policy, dal::basic_statistics::task::compute> {
void operator()(py::module_& m) {
using Task = dal::basic_statistics::task::compute;

m.def("train",
[](const Policy& policy,
const py::dict& params,
const table& data,
const table& weights) {
using namespace dal::basic_statistics;
using input_t = compute_input<Task>;

compute_ops ops(policy, input_t{ data, weights }, params2desc{});
return fptype2t{ method2t{ Task{}, ops } }(params);
});
}
};
void init_compute_ops(py::module& m) {
m.def("compute", [](
const Policy& policy,
const py::dict& params,
const table& data,
const table& weights) {
using namespace dal::basic_statistics;
using input_t = compute_input<Task>;

compute_ops ops(policy, input_t{ data, weights }, params2desc{});
return fptype2t{ method2t{ Task{}, ops } }(params);
}
);
}


template <typename Policy, typename Task>
void init_partial_compute_ops(py::module& m) {
Expand Down Expand Up @@ -177,28 +172,23 @@ void init_finalize_compute_ops(pybind11::module_& m) {
});
}

template <typename Policy, typename Task>
void init_compute_ops(py::module& m) {
init_compute_ops_dispatcher<Policy, Task>{}(m);
}

template <typename Task>
void init_compute_result(py::module_& m) {
using namespace dal::basic_statistics;
using result_t = compute_result<Task>;

auto cls = py::class_<result_t>(m, "compute_result")
.def(py::init())
.DEF_ONEDAL_PY_PROPERTY(min, result_t)
.DEF_ONEDAL_PY_PROPERTY(max, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum, result_t)
.DEF_ONEDAL_PY_PROPERTY(mean, result_t)
.DEF_ONEDAL_PY_PROPERTY(variance, result_t)
.DEF_ONEDAL_PY_PROPERTY(variation, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum_squares, result_t)
.DEF_ONEDAL_PY_PROPERTY(standard_deviation, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum_squares_centered, result_t)
.DEF_ONEDAL_PY_PROPERTY(second_order_raw_moment, result_t);
py::class_<result_t>(m, "compute_result")
.def(py::init())
.DEF_ONEDAL_PY_PROPERTY(min, result_t)
.DEF_ONEDAL_PY_PROPERTY(max, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum, result_t)
.DEF_ONEDAL_PY_PROPERTY(mean, result_t)
.DEF_ONEDAL_PY_PROPERTY(variance, result_t)
.DEF_ONEDAL_PY_PROPERTY(variation, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum_squares, result_t)
.DEF_ONEDAL_PY_PROPERTY(standard_deviation, result_t)
.DEF_ONEDAL_PY_PROPERTY(sum_squares_centered, result_t)
.DEF_ONEDAL_PY_PROPERTY(second_order_raw_moment, result_t);
}

template <typename Task>
Expand Down Expand Up @@ -230,21 +220,18 @@ ONEDAL_PY_INIT_MODULE(basic_statistics) {
using namespace dal::basic_statistics;

auto sub = m.def_submodule("basic_statistics");
using task_list = types<dal::basic_statistics::task::compute>;

#ifdef ONEDAL_DATA_PARALLEL_SPMD
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_spmd, task_list);
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_spmd, task::compute);
#else // ONEDAL_DATA_PARALLEL_SPMD
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task_list);
ONEDAL_PY_INSTANTIATE(init_partial_compute_ops, sub, policy_list, task_list);
ONEDAL_PY_INSTANTIATE(init_finalize_compute_ops, sub, policy_list, task_list);
ONEDAL_PY_INSTANTIATE(init_compute_result, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_partial_compute_result, sub, task_list);
ONEDAL_PY_INSTANTIATE(init_compute_ops, sub, policy_list, task::compute);
ONEDAL_PY_INSTANTIATE(init_partial_compute_ops, sub, policy_list, task::compute);
ONEDAL_PY_INSTANTIATE(init_finalize_compute_ops, sub, policy_list, task::compute);
ONEDAL_PY_INSTANTIATE(init_compute_result, sub, task::compute);
ONEDAL_PY_INSTANTIATE(init_partial_compute_result, sub, task::compute);
#endif // ONEDAL_DATA_PARALLEL_SPMD
}

ONEDAL_PY_TYPE2STR(dal::basic_statistics::task::compute, "compute");

#endif // defined(ONEDAL_VERSION) && ONEDAL_VERSION >= 20230100

} // namespace oneapi::dal::python
93 changes: 50 additions & 43 deletions onedal/basic_statistics/basic_statistics.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Comment for the _get_result_options:
Inputs should be validated. For example invalid option name or empty string input.
This cases should be covered by tests.

I think make sense updated in this PR or in other smaller one either.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've never seen it was done anywhere in our repo

Copy link
Contributor

@samir-nasibli samir-nasibli Jun 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please don't take this comment as a nitpick. Let's try to follow secure code development.

In scikit-learn-intelex mostly parameters are validated by sklearn primitives. BS has specific things to check. Worth to add some validation this string input param.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the problem still exists not only here but for all algos without sklearn analogues and then goes beyond this PR scope I guess. I'll create ticket for that

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even if there are problems, I would like us to start this. This PR is a good starting point for project changes. Please add a minimum check for results_option, since API of BS is updated here.
Use with pytest.raise(ValueError) to check error raise on empty string input for results_option

Minimal validation for the input should be covered.

Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,18 @@
# limitations under the License.
# ==============================================================================

import warnings
from abc import ABCMeta, abstractmethod
from numbers import Number

import numpy as np

from onedal import _backend

from ..common._base import BaseEstimator
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils import _is_csr
from ..utils.validation import _check_array


class BaseBasicStatistics(metaclass=ABCMeta):
class BaseBasicStatistics(BaseEstimator, metaclass=ABCMeta):
@abstractmethod
def __init__(self, result_options, algorithm):
self.options = result_options
Expand Down Expand Up @@ -63,59 +62,67 @@ def _get_onedal_params(self, is_csr, dtype=np.float32):
"result_option": options,
}

def _compute_raw(
self, data_table, weights_table, module, policy, dtype=np.float32, is_csr=False
):
params = self._get_onedal_params(is_csr, dtype)

result = module.train(policy, params, data_table, weights_table)

options = self._get_result_options(self.options)
options = options.split("|")
class BasicStatistics(BaseBasicStatistics):
"""
Basic Statistics oneDAL implementation.
"""

return {opt: getattr(result, opt) for opt in options}
def __init__(self, result_options="all", algorithm="by_default"):
super().__init__(result_options, algorithm)

def _compute(self, data, weights, module, queue):
policy = self._get_policy(queue, data, weights)
def fit(self, data, sample_weight=None, queue=None):
policy = self._get_policy(queue, data, sample_weight)

is_csr = _is_csr(data)
if not (data is None) and not is_csr:
data = np.asarray(data)

if not (weights is None):
weights = np.asarray(weights)
if data is not None and not is_csr:
data = _check_array(data, ensure_2d=False)
if sample_weight is not None:
sample_weight = _check_array(sample_weight, ensure_2d=False)

data, weights = _convert_to_supported(policy, data, weights)

data_table, weights_table = to_table(data, weights)
data, sample_weight = _convert_to_supported(policy, data, sample_weight)
is_single_dim = data.ndim == 1
data_table, weights_table = to_table(data, sample_weight)

dtype = data.dtype
res = self._compute_raw(data_table, weights_table, module, policy, dtype, is_csr)
raw_result = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)
for opt, raw_value in raw_result.items():
value = from_table(raw_value).ravel()
if is_single_dim:
setattr(self, opt, value[0])
else:
setattr(self, opt, value)

return {k: from_table(v).ravel() for k, v in res.items()}
return self

def compute(self, data, weights=None, queue=None):
olegkkruglov marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn(
"Method `compute` was deprecated in version 2024.7 and will be "
"removed in 2025.0. Use `fit` instead."
)

is_csr = _is_csr(data)

class BasicStatistics(BaseEstimator, BaseBasicStatistics):
"""
Basic Statistics oneDAL implementation.
"""
if data is not None:
data = _check_array(data, ensure_2d=False)
if weights is not None:
weights = _check_array(weights, ensure_2d=False)

def __init__(self, result_options="all", *, algorithm="by_default", **kwargs):
super().__init__(result_options, algorithm)
policy = self._get_policy(queue, data, weights)
data, weights = _convert_to_supported(policy, data, weights)
data_table, weights_table = to_table(data, weights)
dtype = data.dtype
res = self._compute_raw(data_table, weights_table, policy, dtype, is_csr)

def compute(self, data, weights=None, queue=None):
return super()._compute(
data, weights, self._get_backend("basic_statistics", "compute", None), queue
)
return {k: from_table(v).ravel() for k, v in res.items()}

def compute_raw(
def _compute_raw(
self, data_table, weights_table, policy, dtype=np.float32, is_csr=False
):
return super()._compute_raw(
data_table,
weights_table,
self._get_backend("basic_statistics", "compute", None),
policy,
dtype,
is_csr,
)
module = self._get_backend("basic_statistics")
params = self._get_onedal_params(is_csr, dtype)
result = module.compute(policy, params, data_table, weights_table)
options = self._get_result_options(self.options).split("|")

return {opt: getattr(result, opt) for opt in options}
56 changes: 7 additions & 49 deletions onedal/basic_statistics/incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,56 +14,12 @@
# limitations under the License.
# ==============================================================================

from abc import ABCMeta, abstractmethod

import numpy as np

from daal4py.sklearn._utils import get_dtype
from onedal import _backend

from ..common._policy import _get_policy
from ..datatypes import _convert_to_supported, from_table, to_table


class BaseBasicStatistics(metaclass=ABCMeta):
@abstractmethod
def __init__(self, result_options, algorithm):
self.options = result_options
self.algorithm = algorithm

@staticmethod
def get_all_result_options():
return [
"min",
"max",
"sum",
"mean",
"variance",
"variation",
"sum_squares",
"standard_deviation",
"sum_squares_centered",
"second_order_raw_moment",
]

def _get_policy(self, queue, *data):
return _get_policy(queue, *data)

def _get_result_options(self, options):
if options == "all":
options = self.get_all_result_options()
if isinstance(options, list):
options = "|".join(options)
assert isinstance(options, str)
return options

def _get_onedal_params(self, dtype=np.float32):
options = self._get_result_options(self.options)
return {
"fptype": "float" if dtype == np.float32 else "double",
"method": self.algorithm,
"result_option": options,
}
from .basic_statistics import BaseBasicStatistics


class IncrementalBasicStatistics(BaseBasicStatistics):
Expand Down Expand Up @@ -110,11 +66,11 @@ class IncrementalBasicStatistics(BaseBasicStatistics):

def __init__(self, result_options="all"):
super().__init__(result_options, algorithm="by_default")
module = _backend.basic_statistics.compute
module = self._get_backend("basic_statistics")
samir-nasibli marked this conversation as resolved.
Show resolved Hide resolved
self._partial_result = module.partial_compute_result()

def _reset(self):
module = _backend.basic_statistics.compute
module = self._get_backend("basic_statistics")
self._partial_result = module.partial_train_result()

def partial_fit(self, X, weights=None, queue=None):
Expand Down Expand Up @@ -146,7 +102,8 @@ def partial_fit(self, X, weights=None, queue=None):
self._onedal_params = self._get_onedal_params(dtype)

X_table, weights_table = to_table(X, weights)
self._partial_result = _backend.basic_statistics.compute.partial_compute(
module = self._get_backend("basic_statistics")
self._partial_result = module.partial_compute(
self._policy,
self._onedal_params,
self._partial_result,
Expand All @@ -169,7 +126,8 @@ def finalize_fit(self, queue=None):
self : object
Returns the instance itself.
"""
result = _backend.basic_statistics.compute.finalize_compute(
module = self._get_backend("basic_statistics")
result = module.finalize_compute(
self._policy, self._onedal_params, self._partial_result
)
options = self._get_result_options(self.options).split("|")
Expand Down
Loading
Loading