diff --git a/ads/opctl/operator/lowcode/forecast/model/base_model.py b/ads/opctl/operator/lowcode/forecast/model/base_model.py index 9e5c51e95..e36efadf7 100644 --- a/ads/opctl/operator/lowcode/forecast/model/base_model.py +++ b/ads/opctl/operator/lowcode/forecast/model/base_model.py @@ -573,9 +573,14 @@ def _save_report( if self.spec.generate_explanations: try: if not self.formatted_global_explanation.empty: + # Round to 4 decimal places before writing + global_expl_rounded = self.formatted_global_explanation.copy() + global_expl_rounded = global_expl_rounded.apply( + lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col + ) if self.spec.generate_explanation_files: write_data( - data=self.formatted_global_explanation, + data=global_expl_rounded, filename=os.path.join( unique_output_dir, self.spec.global_explanation_filename ), @@ -583,16 +588,21 @@ def _save_report( storage_options=storage_options, index=True, ) - results.set_global_explanations(self.formatted_global_explanation) + results.set_global_explanations(global_expl_rounded) else: logger.warning( f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations." ) if not self.formatted_local_explanation.empty: + # Round to 4 decimal places before writing + local_expl_rounded = self.formatted_local_explanation.copy() + local_expl_rounded = local_expl_rounded.apply( + lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col + ) if self.spec.generate_explanation_files: write_data( - data=self.formatted_local_explanation, + data=local_expl_rounded, filename=os.path.join( unique_output_dir, self.spec.local_explanation_filename ), @@ -600,7 +610,7 @@ def _save_report( storage_options=storage_options, index=True, ) - results.set_local_explanations(self.formatted_local_explanation) + results.set_local_explanations(local_expl_rounded) else: logger.warning( f"Attempted to generate local explanations for the {self.spec.local_explanation_filename} file, but an issue occured in formatting the explanations." diff --git a/tests/operators/forecast/test_explainers.py b/tests/operators/forecast/test_explainers.py index f158302b0..c634cab67 100644 --- a/tests/operators/forecast/test_explainers.py +++ b/tests/operators/forecast/test_explainers.py @@ -343,8 +343,19 @@ def test_explanations_values(model, num_series, freq): if model == "automlx": pytest.xfail("automlx model does not provide fitted values") + # Check decimal precision for local explanations + local_numeric = local_explanations.select_dtypes(include=["int64", "float64"]) + assert np.allclose(local_numeric, np.round(local_numeric, 4), atol=1e-8), \ + "Local explanations have values with more than 4 decimal places" + + # Check decimal precision for global explanations + global_explanations = results.get_global_explanations() + global_numeric = global_explanations.select_dtypes(include=["int64", "float64"]) + assert np.allclose(global_numeric, np.round(global_numeric, 4), atol=1e-8), \ + "Global explanations have values with more than 4 decimal places" + local_explain_vals = ( - local_explanations.select_dtypes(include=["int64", "float64"]).sum(axis=1) + local_numeric.sum(axis=1) + forecast.fitted_value.mean() ) assert np.allclose(