Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions ads/opctl/operator/lowcode/forecast/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -573,34 +573,44 @@ def _save_report(
if self.spec.generate_explanations:
try:
if not self.formatted_global_explanation.empty:
# Round to 4 decimal places before writing
global_expl_rounded = self.formatted_global_explanation.copy()
global_expl_rounded = global_expl_rounded.apply(
lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col
)
if self.spec.generate_explanation_files:
write_data(
data=self.formatted_global_explanation,
data=global_expl_rounded,
filename=os.path.join(
unique_output_dir, self.spec.global_explanation_filename
),
format="csv",
storage_options=storage_options,
index=True,
)
results.set_global_explanations(self.formatted_global_explanation)
results.set_global_explanations(global_expl_rounded)
else:
logger.warning(
f"Attempted to generate global explanations for the {self.spec.global_explanation_filename} file, but an issue occured in formatting the explanations."
)

if not self.formatted_local_explanation.empty:
# Round to 4 decimal places before writing
local_expl_rounded = self.formatted_local_explanation.copy()
local_expl_rounded = local_expl_rounded.apply(
lambda col: np.round(col, 4) if np.issubdtype(col.dtype, np.number) else col
)
if self.spec.generate_explanation_files:
write_data(
data=self.formatted_local_explanation,
data=local_expl_rounded,
filename=os.path.join(
unique_output_dir, self.spec.local_explanation_filename
),
format="csv",
storage_options=storage_options,
index=True,
)
results.set_local_explanations(self.formatted_local_explanation)
results.set_local_explanations(local_expl_rounded)
else:
logger.warning(
f"Attempted to generate local explanations for the {self.spec.local_explanation_filename} file, but an issue occured in formatting the explanations."
Expand Down
13 changes: 12 additions & 1 deletion tests/operators/forecast/test_explainers.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,8 +343,19 @@ def test_explanations_values(model, num_series, freq):
if model == "automlx":
pytest.xfail("automlx model does not provide fitted values")

# Check decimal precision for local explanations
local_numeric = local_explanations.select_dtypes(include=["int64", "float64"])
assert np.allclose(local_numeric, np.round(local_numeric, 4), atol=1e-8), \
"Local explanations have values with more than 4 decimal places"

# Check decimal precision for global explanations
global_explanations = results.get_global_explanations()
global_numeric = global_explanations.select_dtypes(include=["int64", "float64"])
assert np.allclose(global_numeric, np.round(global_numeric, 4), atol=1e-8), \
"Global explanations have values with more than 4 decimal places"

local_explain_vals = (
local_explanations.select_dtypes(include=["int64", "float64"]).sum(axis=1)
local_numeric.sum(axis=1)
+ forecast.fitted_value.mean()
)
assert np.allclose(
Expand Down