From 68ee92498a7fb254bb895eb0c468110c2cda25bd Mon Sep 17 00:00:00 2001 From: "Kruglov, Oleg" Date: Tue, 7 May 2024 06:44:00 -0700 Subject: [PATCH] Update spmd part --- examples/sklearnex/basic_statistics_spmd.py | 6 +++--- onedal/spmd/basic_statistics/basic_statistics.py | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/examples/sklearnex/basic_statistics_spmd.py b/examples/sklearnex/basic_statistics_spmd.py index 29864aea62..909c842cb9 100644 --- a/examples/sklearnex/basic_statistics_spmd.py +++ b/examples/sklearnex/basic_statistics_spmd.py @@ -58,7 +58,7 @@ def generate_data(par, size, seed=777): gtr_std = np.std(weighted_data, axis=0) bss = BasicStatisticsSpmd(["mean", "standard_deviation"]) -res = bss.compute(dpt_data, dpt_weights) +bss.fit(dpt_data, dpt_weights) -print(f"Computed mean on rank {rank}:\n", res["mean"]) -print(f"Computed std on rank {rank}:\n", res["standard_deviation"]) +print(f"Computed mean on rank {rank}:\n", bss.mean) +print(f"Computed std on rank {rank}:\n", bss.standard_deviation) diff --git a/onedal/spmd/basic_statistics/basic_statistics.py b/onedal/spmd/basic_statistics/basic_statistics.py index 27e37b1abc..c0ae2193a4 100644 --- a/onedal/spmd/basic_statistics/basic_statistics.py +++ b/onedal/spmd/basic_statistics/basic_statistics.py @@ -14,6 +14,8 @@ # limitations under the License. # ============================================================================== +import warnings + from onedal.basic_statistics import BasicStatistics as BasicStatistics_Batch from ..._device_offload import support_usm_ndarray @@ -23,4 +25,13 @@ class BasicStatistics(BaseEstimatorSPMD, BasicStatistics_Batch): @support_usm_ndarray() def compute(self, data, weights=None, queue=None): + warnings.warn( + "Method `compute` was deprecated in version 2024.3 and will be " + "removed in 2024.5. Use `fit` instead." + ) return super().compute(data, weights=weights, queue=queue) + + @support_usm_ndarray() + def fit(self, data, sample_weight=None, queue=None): + super().fit(data, sample_weight=sample_weight, queue=queue) + return self