diff --git a/ads/opctl/operator/lowcode/forecast/const.py b/ads/opctl/operator/lowcode/forecast/const.py index 3d057b3ec..3299f57d7 100644 --- a/ads/opctl/operator/lowcode/forecast/const.py +++ b/ads/opctl/operator/lowcode/forecast/const.py @@ -81,3 +81,4 @@ class ForecastOutputColumns(str, metaclass=ExtendedEnumMeta): DEFAULT_TRIALS = 10 SUMMARY_METRICS_HORIZON_LIMIT = 10 PROPHET_INTERNAL_DATE_COL = "ds" +RENDER_LIMIT = 5000 diff --git a/ads/opctl/operator/lowcode/forecast/model/base_model.py b/ads/opctl/operator/lowcode/forecast/model/base_model.py index e0f16eca1..2cb06665a 100644 --- a/ads/opctl/operator/lowcode/forecast/model/base_model.py +++ b/ads/opctl/operator/lowcode/forecast/model/base_model.py @@ -598,10 +598,7 @@ def explain_model(self, datetime_col_name, explain_predict_fn) -> dict: data_trimmed[datetime_col_name] = data_trimmed[datetime_col_name].apply(lambda x: x.timestamp()) kernel_explnr = PermutationExplainer( model=explain_predict_fn, - masker=data_trimmed, - keep_index=False - if self.spec.model == SupportedModels.AutoMLX - else True, + masker=data_trimmed ) kernel_explnr_vals = kernel_explnr.shap_values(data_trimmed) diff --git a/ads/opctl/operator/lowcode/forecast/utils.py b/ads/opctl/operator/lowcode/forecast/utils.py index 3d4ad6ee6..8810bd986 100644 --- a/ads/opctl/operator/lowcode/forecast/utils.py +++ b/ads/opctl/operator/lowcode/forecast/utils.py @@ -28,7 +28,7 @@ from ads.dataset.label_encoder import DataFrameLabelEncoder from ads.opctl import logger -from .const import SupportedMetrics, SupportedModels +from .const import SupportedMetrics, SupportedModels, RENDER_LIMIT from .errors import ForecastInputDataError, ForecastSchemaYamlError from .operator_config import ForecastOperatorSpec, ForecastOperatorConfig @@ -417,6 +417,23 @@ def get_forecast_plots( def plot_forecast_plotly(idx, col): fig = go.Figure() forecast_i = forecast_output.get_target_category(col) + actual_length = len(forecast_i) + if actual_length > RENDER_LIMIT: + forecast_i = forecast_i.tail(RENDER_LIMIT) + text = f"To improve rendering speed, subsampled the data from {actual_length}" \ + f" rows to {RENDER_LIMIT} rows for this plot." + fig.update_layout( + annotations=[ + go.layout.Annotation( + x=0.01, + y=1.1, + xref="paper", + yref="paper", + text=text, + showarrow=False + ) + ] + ) upper_bound = forecast_output.upper_bound_name lower_bound = forecast_output.lower_bound_name if upper_bound is not None and lower_bound is not None: