Skip to content

Commit

Permalink
Update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed Apr 30, 2024
1 parent d116248 commit b0c7ad2
Show file tree
Hide file tree
Showing 4 changed files with 329 additions and 141 deletions.
227 changes: 157 additions & 70 deletions onedal/basic_statistics/tests/test_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,94 +14,181 @@
# limitations under the License.
# ==============================================================================

import numpy as np
import pytest
from numpy.testing import assert_allclose

from daal4py.sklearn._utils import daal_check_version
from onedal.basic_statistics import BasicStatistics
from onedal.tests.utils._device_selection import get_queues

if daal_check_version((2023, "P", 100)):
import numpy as np
import pytest
from numpy.testing import assert_allclose

from onedal.basic_statistics import BasicStatistics
from onedal.tests.utils._device_selection import get_queues

options_and_tests = [
("sum", np.sum, (1e-5, 1e-7)),
("min", np.min, (1e-5, 1e-7)),
("max", np.max, (1e-5, 1e-7)),
("mean", np.mean, (1e-5, 1e-7)),
("standard_deviation", np.std, (3e-5, 3e-5)),
]

@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("option", options_and_tests)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_option_uniform(queue, option, dtype):
seed = 77
s_count, f_count = 19999, 31

result_option, function, tols = option
fp32tol, fp64tol = tols
def expected_sum(X):
    """Reference implementation: column-wise (axis 0) sum of ``X``."""
    column_totals = np.sum(X, axis=0)
    return column_totals

gen = np.random.default_rng(seed)
data = gen.uniform(low=-0.3, high=+0.7, size=(s_count, f_count))
data = data.astype(dtype=dtype)

alg = BasicStatistics(result_options=result_option)
res = alg.fit(data, queue=queue)
def expected_max(X):
    """Reference implementation: column-wise (axis 0) maximum of ``X``."""
    return np.amax(X, axis=0)

res, gtr = getattr(res, result_option), function(data, axis=0)

tol = fp32tol if res.dtype == np.float32 else fp64tol
assert_allclose(gtr, res, rtol=tol)
def expected_min(X):
    """Reference implementation: column-wise (axis 0) minimum of ``X``."""
    return np.amin(X, axis=0)

@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_multiple_options_uniform(queue, dtype):
seed = 42
s_count, f_count = 700, 29

gen = np.random.default_rng(seed)
data = gen.uniform(low=-0.5, high=+0.6, size=(s_count, f_count))
data = data.astype(dtype=dtype)
def expected_mean(X):
    """Reference implementation: column-wise (axis 0) arithmetic mean."""
    # np.average without weights reduces to np.mean over the given axis.
    return np.average(X, axis=0)

alg = BasicStatistics(result_options=["mean", "max", "sum"])
res = alg.fit(data, queue=queue)

res_mean, res_max, res_sum = res.mean, res.max, res.sum
gtr_mean, gtr_max, gtr_sum = (
np.mean(data, axis=0),
np.max(data, axis=0),
np.sum(data, axis=0),
)
# Reference implementation: column-wise population standard deviation
# (NumPy default ddof=0).
# NOTE(review): this name appears to be redefined later in the file as
# np.sqrt(expected_variance(X)); with ddof=0 both forms are numerically
# equivalent, but the duplicate definition looks like a leftover and
# should probably be removed -- confirm against the final file.
def expected_standard_deviation(X):
    return np.std(X, axis=0)

tol = 2e-5 if res_mean.dtype == np.float32 else 1e-7
assert_allclose(gtr_mean, res_mean, rtol=tol)
assert_allclose(gtr_max, res_max, rtol=tol)
assert_allclose(gtr_sum, res_sum, rtol=tol)

@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("option", options_and_tests)
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_option_weighted(queue, option, dtype):
seed = 999
s_count, f_count = 1024, 127
def expected_variance(X):
    """Reference implementation: column-wise population variance (ddof=0)."""
    per_column_variance = np.var(X, axis=0)
    return per_column_variance


def expected_variation(X):
    """Reference implementation: column-wise coefficient of variation.

    Defined as the population standard deviation (ddof=0) divided by the
    mean, per column.
    """
    # Inlined equivalents of the sibling helpers: with ddof=0,
    # np.std(X) == np.sqrt(np.var(X)).
    deviation = np.std(X, axis=0)
    center = np.mean(X, axis=0)
    return deviation / center

result_option, function, tols = option
fp32tol, fp64tol = tols
fp32tol, fp64tol = 30 * fp32tol, 50 * fp64tol

gen = np.random.default_rng(seed)
data = gen.uniform(low=-5.0, high=+9.0, size=(s_count, f_count))
weights = gen.uniform(low=-0.5, high=+1.0, size=s_count)
def expected_sum_squares(X):
    """Reference implementation: column-wise sum of squared values."""
    squared = np.square(X)
    return np.sum(squared, axis=0)

data = data.astype(dtype=dtype)

def expected_sum_squares_centered(X):
    """Reference implementation: column-wise sum of squared deviations
    from the column mean."""
    centered = X - np.mean(X, axis=0)
    return np.sum(np.square(centered), axis=0)


def expected_standard_deviation(X):
    """Reference implementation: column-wise population standard deviation,
    computed as the square root of the ddof=0 variance."""
    # Inlined equivalent of expected_variance.
    return np.sqrt(np.var(X, axis=0))


def expected_second_order_raw_moment(X):
    """Reference implementation: column-wise mean of squared values."""
    squared = np.square(X)
    return np.mean(squared, axis=0)


# Parametrization table for the tests below.  Each entry is a tuple:
#   (oneDAL result option name,
#    NumPy reference implementation,
#    (float32 tolerance, float64 tolerance))
options_and_tests = [
    ("sum", expected_sum, (3e-4, 1e-7)),
    ("min", expected_min, (1e-7, 1e-7)),
    ("max", expected_max, (1e-7, 1e-7)),
    ("mean", expected_mean, (5e-7, 1e-7)),
    ("variance", expected_variance, (2e-3, 2e-3)),
    ("variation", expected_variation, (5e-2, 5e-2)),
    ("sum_squares", expected_sum_squares, (2e-4, 1e-7)),
    ("sum_squares_centered", expected_sum_squares_centered, (2e-4, 1e-7)),
    ("standard_deviation", expected_standard_deviation, (2e-3, 2e-3)),
    ("second_order_raw_moment", expected_second_order_raw_moment, (1e-6, 1e-7)),
]


@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("option", options_and_tests)
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_single_option_on_random_data(
    queue, option, row_count, column_count, weighted, dtype
):
    """Check one BasicStatistics result option against its NumPy reference
    on uniform random data, optionally with sample weights.

    Fix: removed stale pre-refactor lines that were interleaved here
    (a second fit via ``alg.fit(data, weights, queue=queue)`` and a
    rebinding of the boolean ``weighted`` to ``np.diag(weights) @ data``),
    which clobbered ``res``/``gtr`` and crashed the unweighted case on
    ``np.diag(None)``.
    """
    result_option, function, tols = option
    fp32tol, fp64tol = tols
    seed = 77
    gen = np.random.default_rng(seed)
    data = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
    data = data.astype(dtype=dtype)
    if weighted:
        weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
        weights = weights.astype(dtype=dtype)
    else:
        weights = None

    basicstat = BasicStatistics(result_options=result_option)

    result = basicstat.fit(data, sample_weight=weights, queue=queue)

    res = getattr(result, result_option)
    if weighted:
        # The reference applies weights by scaling each row of the data.
        weighted_data = np.diag(weights) @ data
        gtr = function(weighted_data)
    else:
        gtr = function(data)

    tol = fp32tol if res.dtype == np.float32 else fp64tol
    assert_allclose(gtr, res, atol=tol)


@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_multiple_options_on_random_data(queue, row_count, column_count, weighted, dtype):
    """Check that requesting several result options at once ("mean", "max",
    "sum") matches the NumPy references on uniform random data, optionally
    with sample weights."""
    rng = np.random.default_rng(42)
    data = rng.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
    data = data.astype(dtype=dtype)

    weights = None
    if weighted:
        weights = rng.uniform(low=-0.5, high=+1.0, size=row_count)
        weights = weights.astype(dtype=dtype)

    basicstat = BasicStatistics(result_options=["mean", "max", "sum"])

    result = basicstat.fit(data, sample_weight=weights, queue=queue)

    # The reference applies weights by scaling each row of the data.
    reference_input = np.diag(weights) @ data if weighted else data
    gtr_mean = expected_mean(reference_input)
    gtr_max = expected_max(reference_input)
    gtr_sum = expected_sum(reference_input)

    res_mean, res_max, res_sum = result.mean, result.max, result.sum
    tol = 3e-4 if res_mean.dtype == np.float32 else 1e-7
    assert_allclose(gtr_mean, res_mean, atol=tol)
    assert_allclose(gtr_max, res_max, atol=tol)
    assert_allclose(gtr_sum, res_sum, atol=tol)


@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("row_count", [100, 1000])
@pytest.mark.parametrize("column_count", [10, 100])
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_all_option_on_random_data(queue, row_count, column_count, weighted, dtype):
    """Fit with ``result_options="all"`` and check every option in
    ``options_and_tests`` against its NumPy reference.

    Fix: removed the duplicated final assertion -- a stale
    ``assert_allclose(gtr, res, rtol=tol)`` from the previous version was
    left alongside the intended ``atol`` assertion; the relative-tolerance
    check is not meaningful for values near zero and was superseded.
    """
    seed = 77
    gen = np.random.default_rng(seed)
    data = gen.uniform(low=-0.3, high=+0.7, size=(row_count, column_count))
    data = data.astype(dtype=dtype)
    if weighted:
        weights = gen.uniform(low=-0.5, high=+1.0, size=row_count)
        weights = weights.astype(dtype=dtype)
    else:
        weights = None

    basicstat = BasicStatistics(result_options="all")

    result = basicstat.fit(data, sample_weight=weights, queue=queue)

    if weighted:
        # The reference applies weights by scaling each row of the data.
        weighted_data = np.diag(weights) @ data

    for option in options_and_tests:
        result_option, function, tols = option
        fp32tol, fp64tol = tols
        res = getattr(result, result_option)
        if weighted:
            gtr = function(weighted_data)
        else:
            gtr = function(data)
        tol = fp32tol if res.dtype == np.float32 else fp64tol
        assert_allclose(gtr, res, atol=tol)
64 changes: 6 additions & 58 deletions onedal/basic_statistics/tests/test_incremental_basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,67 +19,15 @@
from numpy.testing import assert_allclose

from onedal.basic_statistics import IncrementalBasicStatistics
from onedal.basic_statistics.tests.test_basic_statistics import (
expected_max,
expected_mean,
expected_sum,
options_and_tests,
)
from onedal.tests.utils._device_selection import get_queues


def expected_sum(X):
return np.sum(X, axis=0)


def expected_max(X):
return np.max(X, axis=0)


def expected_min(X):
return np.min(X, axis=0)


def expected_mean(X):
return np.mean(X, axis=0)


def expected_standard_deviation(X):
return np.std(X, axis=0)


def expected_variance(X):
return np.var(X, axis=0)


def expected_variation(X):
return expected_standard_deviation(X) / expected_mean(X)


def expected_sum_squares(X):
return np.sum(np.square(X), axis=0)


def expected_sum_squares_centered(X):
return np.sum(np.square(X - expected_mean(X)), axis=0)


def expected_standard_deviation(X):
return np.sqrt(expected_variance(X))


def expected_second_order_raw_moment(X):
return np.mean(np.square(X), axis=0)


options_and_tests = [
("sum", expected_sum, (3e-4, 1e-7)),
("min", expected_min, (1e-7, 1e-7)),
("max", expected_max, (1e-7, 1e-7)),
("mean", expected_mean, (3e-7, 1e-7)),
("variance", expected_variance, (2e-3, 2e-3)),
("variation", expected_variation, (5e-2, 5e-2)),
("sum_squares", expected_sum_squares, (2e-4, 1e-7)),
("sum_squares_centered", expected_sum_squares_centered, (2e-4, 1e-7)),
("standard_deviation", expected_standard_deviation, (2e-3, 2e-3)),
("second_order_raw_moment", expected_second_order_raw_moment, (1e-6, 1e-7)),
]


@pytest.mark.parametrize("queue", get_queues())
@pytest.mark.parametrize("weighted", [True, False])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
Expand Down
Loading

0 comments on commit b0c7ad2

Please sign in to comment.