Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed May 7, 2024
1 parent 66d7585 commit 6425a03
Show file tree
Hide file tree
Showing 4 changed files with 101 additions and 9 deletions.
30 changes: 27 additions & 3 deletions onedal/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@
# limitations under the License.
# ==============================================================================

import warnings
from abc import ABCMeta, abstractmethod

import numpy as np

from ..common._base import BaseEstimator
from ..datatypes import _convert_to_supported, from_table, to_table
from ..utils.validation import _check_array


class BaseBasicStatistics(BaseEstimator, metaclass=ABCMeta):
Expand Down Expand Up @@ -72,21 +74,43 @@ def fit(self, data, sample_weight=None, queue=None):
policy = self._get_policy(queue, data, sample_weight)

if not (data is None):
data = np.asarray(data)
data = _check_array(data, ensure_2d=False)
if not (sample_weight is None):
sample_weight = np.asarray(sample_weight)
sample_weight = _check_array(sample_weight, ensure_2d=False)

data, sample_weight = _convert_to_supported(policy, data, sample_weight)
is_single_dim = data.ndim == 1
data_table, weights_table = to_table(data, sample_weight)

dtype = data.dtype
raw_result = self._compute_raw(data_table, weights_table, policy, dtype)
for opt, raw_value in raw_result.items():
value = from_table(raw_value).ravel()
setattr(self, opt, value)
if is_single_dim:
setattr(self, opt, value[0])
else:
setattr(self, opt, value)

return self

def compute(self, data, weights=None, queue=None):
warnings.warn(
"Method `compute` was deprecated in version 2024.3 and will be "
"removed in 2024.5. Use `fit` instead."
)
if not (data is None):
data = _check_array(data, ensure_2d=False)
if not (weights is None):
weights = _check_array(weights, ensure_2d=False)

policy = self._get_policy(queue, data, weights)
data, weights = _convert_to_supported(policy, data, weights)
data_table, weights_table = to_table(data, weights)
dtype = data.dtype
res = self._compute_raw(data_table, weights_table, policy, dtype)

return {k: from_table(v).ravel() for k, v in res.items()}

def _compute_raw(self, data_table, weights_table, policy, dtype=np.float32):
module = self._get_backend("basic_statistics")
params = self._get_onedal_params(dtype)
Expand Down
33 changes: 33 additions & 0 deletions onedal/basic_statistics/tests/test_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,36 @@ def test_all_option_on_random_data(queue, row_count, column_count, weighted, dty
gtr = function(data)
tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, atol=tol)


@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("option", options_and_tests)
@pytest.mark.parametrize("data_size", [100, 1000])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_1d_input_on_random_data(queue, option, data_size, weighted, dtype):
result_option, function, tols = option
fp32tol, fp64tol = tols
seed = 77
gen = np.random.default_rng(seed)
data = gen.uniform(low=-0.3, high=+0.7, size=data_size)
data = data.astype(dtype=dtype)
if weighted:
weights = gen.uniform(low=-0.5, high=+1.0, size=data_size)
weights = weights.astype(dtype=dtype)
else:
weights = None

basicstat = BasicStatistics(result_options=result_option)

result = basicstat.fit(data, sample_weight=weights, queue=queue)

res = getattr(result, result_option)
if weighted:
weighted_data = weights * data
gtr = function(weighted_data)
else:
gtr = function(data)

tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, atol=tol)
12 changes: 6 additions & 6 deletions sklearnex/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils import check_array
from sklearn.utils import check_array, deprecated
from sklearn.utils.validation import _check_sample_weight

from daal4py.sklearn._n_jobs_support import control_n_jobs
Expand Down Expand Up @@ -91,12 +91,9 @@ def _onedal_supported(self, method_name, *data):

def _onedal_fit(self, X, sample_weight=None, queue=None):
if sklearn_check_version("1.0"):
X = self._validate_data(X, dtype=[np.float64, np.float32])
X = self._validate_data(X, dtype=[np.float64, np.float32], ensure_2d=False)
else:
X = check_array(
X,
dtype=[np.float64, np.float32],
)
X = check_array(X, dtype=[np.float64, np.float32], ensure_2d=False)

if sample_weight is not None:
sample_weight = _check_sample_weight(sample_weight, X)
Expand All @@ -110,6 +107,9 @@ def _onedal_fit(self, X, sample_weight=None, queue=None):
self._onedal_estimator.fit(X, sample_weight, queue)
self._save_attributes()

def compute(self, data, weights=None, queue=None):
return self._onedal_estimator.compute(data, weights, queue)

def fit(self, X, y=None, *, sample_weight=None):
"""Compute statistics with X, using minibatches of size batch_size.
Expand Down
35 changes: 35 additions & 0 deletions sklearnex/basic_statistics/tests/test_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,3 +214,38 @@ def test_all_option_on_random_data(
gtr = function(X)
tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, atol=tol)


@pytest.mark.parametrize("dataframe,queue", get_dataframes_and_queues())
@pytest.mark.parametrize("option", options_and_tests)
@pytest.mark.parametrize("data_size", [100, 1000])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_1d_input_on_random_data(dataframe, queue, option, data_size, weighted, dtype):
result_option, function, tols = option
fp32tol, fp64tol = tols
seed = 77
gen = np.random.default_rng(seed)
X = gen.uniform(low=-0.3, high=+0.7, size=data_size)
X = X.astype(dtype=dtype)
X_df = _convert_to_dataframe(X, sycl_queue=queue, target_df=dataframe)
if weighted:
weights = gen.uniform(low=-0.5, high=1.0, size=data_size)
weights = weights.astype(dtype=dtype)
weights_df = _convert_to_dataframe(weights, sycl_queue=queue, target_df=dataframe)
basicstat = BasicStatistics(result_options=result_option)

if weighted:
result = basicstat.fit(X_df, sample_weight=weights_df)
else:
result = basicstat.fit(X_df)

res = getattr(result, result_option)
if weighted:
weighted_data = weights * X
gtr = function(weighted_data)
else:
gtr = function(X)

tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, atol=tol)

0 comments on commit 6425a03

Please sign in to comment.