Skip to content

Commit

Permalink
Update spmd part
Browse files Browse the repository at this point in the history
  • Loading branch information
olegkkruglov committed May 7, 2024
1 parent 6425a03 commit 68ee924
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
6 changes: 3 additions & 3 deletions examples/sklearnex/basic_statistics_spmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def generate_data(par, size, seed=777):
gtr_std = np.std(weighted_data, axis=0)

bss = BasicStatisticsSpmd(["mean", "standard_deviation"])
res = bss.compute(dpt_data, dpt_weights)
bss.fit(dpt_data, dpt_weights)

print(f"Computed mean on rank {rank}:\n", res["mean"])
print(f"Computed std on rank {rank}:\n", res["standard_deviation"])
print(f"Computed mean on rank {rank}:\n", bss.mean)
print(f"Computed std on rank {rank}:\n", bss.standard_deviation)
11 changes: 11 additions & 0 deletions onedal/spmd/basic_statistics/basic_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# limitations under the License.
# ==============================================================================

import warnings

from onedal.basic_statistics import BasicStatistics as BasicStatistics_Batch

from ..._device_offload import support_usm_ndarray
Expand All @@ -23,4 +25,13 @@
class BasicStatistics(BaseEstimatorSPMD, BasicStatistics_Batch):
@support_usm_ndarray()
def compute(self, data, weights=None, queue=None):
warnings.warn(
"Method `compute` was deprecated in version 2024.3 and will be "
"removed in 2024.5. Use `fit` instead."
)
return super().compute(data, weights=weights, queue=queue)

@support_usm_ndarray()
def fit(self, data, sample_weight=None, queue=None):
super().fit(data, sample_weight=sample_weight, queue=queue)
return self

0 comments on commit 68ee924

Please sign in to comment.