From 0f21baf48e9d25911ec9e4da3eed5b841de4143e Mon Sep 17 00:00:00 2001 From: Jonathan Daniel <36337649+jond01@users.noreply.github.com> Date: Tue, 19 Mar 2024 15:42:57 +0200 Subject: [PATCH] [Model Monitoring] Log artifacts in the histogram data drift app (#5280) --- mlrun/artifacts/plots.py | 14 +- mlrun/model_monitoring/api.py | 18 +- mlrun/model_monitoring/application.py | 42 ++--- .../applications/histogram_data_drift.py | 170 +++++++++++++----- mlrun/model_monitoring/batch.py | 43 +---- mlrun/model_monitoring/controller.py | 2 +- .../model_monitoring/features_drift_table.py | 56 +++--- mlrun/model_monitoring/helpers.py | 49 ++++- .../assets/feature_stats.csv | 23 +++ .../assets/sample_df_stats.csv | 23 +++ .../test_histogram_data_drift.py | 108 ++++++++--- .../test_features_drift_table.py | 59 +++--- .../model_monitoring/assets/application.py | 5 +- .../assets/custom_evidently_app.py | 5 +- 14 files changed, 414 insertions(+), 203 deletions(-) create mode 100644 tests/model_monitoring/test_applications/assets/feature_stats.csv create mode 100644 tests/model_monitoring/test_applications/assets/sample_df_stats.csv diff --git a/mlrun/artifacts/plots.py b/mlrun/artifacts/plots.py index 841fd65046a..bfc5995986e 100644 --- a/mlrun/artifacts/plots.py +++ b/mlrun/artifacts/plots.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import base64 +import typing from io import BytesIO from deprecated import deprecated @@ -21,6 +22,9 @@ from ..utils import dict_to_json from .base import Artifact, LegacyArtifact +if typing.TYPE_CHECKING: + from plotly.graph_objs import Figure + class PlotArtifact(Artifact): kind = "plot" @@ -207,10 +211,10 @@ class PlotlyArtifact(Artifact): def __init__( self, - figure=None, - key: str = None, - target_path: str = None, - ): + figure: typing.Optional["Figure"] = None, + key: typing.Optional[str] = None, + target_path: typing.Optional[str] = None, + ) -> None: """ Initialize a Plotly artifact with the given figure. @@ -247,7 +251,7 @@ def __init__( self._figure = figure self.spec.format = "html" - def get_body(self): + def get_body(self) -> str: """ Get the artifact's body - the Plotly figure's html code. diff --git a/mlrun/model_monitoring/api.py b/mlrun/model_monitoring/api.py index 74be2125e8b..7eb3e81539d 100644 --- a/mlrun/model_monitoring/api.py +++ b/mlrun/model_monitoring/api.py @@ -704,8 +704,8 @@ def perform_drift_analysis( drift_detected_threshold=drift_threshold, ) - # Drift table plot - html_plot = FeaturesDriftTablePlot().produce( + # Drift table artifact + plotly_artifact = FeaturesDriftTablePlot().produce( sample_set_statistics=sample_set_statistics, inputs_statistics=inputs_statistics, metrics=metrics, @@ -732,7 +732,7 @@ def perform_drift_analysis( # Log the different artifacts _log_drift_artifacts( context=context, - html_plot=html_plot, + plotly_artifact=plotly_artifact, metrics_per_feature=metrics_per_feature, drift_status=drift_status, drift_metric=drift_metric, @@ -742,7 +742,7 @@ def perform_drift_analysis( def _log_drift_artifacts( context: mlrun.MLClientCtx, - html_plot: str, + plotly_artifact: mlrun.artifacts.Artifact, metrics_per_feature: dict[str, float], drift_status: bool, drift_metric: float, @@ -755,20 +755,14 @@ def _log_drift_artifacts( 3 - Results of the total drift analysis :param context: MLRun context. Will log the artifacts. - :param html_plot: Body of the html file of the plot. + :param plotly_artifact: The plotly artifact. 
:param metrics_per_feature: Dictionary in which the key is a feature name and the value is the drift numerical result. :param drift_status: Boolean value that represents the final drift analysis result. :param drift_metric: The final drift numerical result. :param artifacts_tag: Tag to use for all the artifacts resulted from the function. - """ - context.log_artifact( - mlrun.artifacts.Artifact( - body=html_plot.encode("utf-8"), format="html", key="drift_table_plot" - ), - tag=artifacts_tag, - ) + context.log_artifact(plotly_artifact, tag=artifacts_tag) context.log_artifact( mlrun.artifacts.Artifact( body=json.dumps(metrics_per_feature), diff --git a/mlrun/model_monitoring/application.py b/mlrun/model_monitoring/application.py index 1971ad013b4..30c29609fef 100644 --- a/mlrun/model_monitoring/application.py +++ b/mlrun/model_monitoring/application.py @@ -16,13 +16,13 @@ import json import re from abc import ABC, abstractmethod -from typing import Any, Optional, Union +from typing import Any, Optional, Union, cast import numpy as np import pandas as pd import mlrun.common.helpers -import mlrun.common.schemas.model_monitoring +import mlrun.common.model_monitoring.helpers import mlrun.common.schemas.model_monitoring.constants as mm_constant import mlrun.utils.v3io_clients from mlrun.datastore import get_stream_pusher @@ -84,8 +84,8 @@ class ModelMonitoringApplicationBase(StepToDict, ABC): class MyApp(ApplicationBase): def do_tracking( self, - sample_df_stats: pd.DataFrame, - feature_stats: pd.DataFrame, + sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats, + feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats, start_infer_time: pd.Timestamp, end_infer_time: pd.Timestamp, schedule_time: pd.Timestamp, @@ -93,7 +93,7 @@ def do_tracking( endpoint_id: str, output_stream_uri: str, ) -> ModelMonitoringApplicationResult: - self.context.log_artifact(TableArtifact("sample_df_stats", df=sample_df_stats)) + self.context.log_artifact(TableArtifact("sample_df_stats", df=self.dict_to_histogram(sample_df_stats))) return ModelMonitoringApplicationResult( name="data_drift_test", value=0.5, @@ -126,14 +126,16 @@ def do( return results, event def _lazy_init(self, app_name: str): - self.context = self._create_context_for_logging(app_name=app_name) + self.context = cast( + mlrun.MLClientCtx, self._create_context_for_logging(app_name=app_name) + ) @abstractmethod def do_tracking( self, application_name: str, - sample_df_stats: pd.DataFrame, - feature_stats: pd.DataFrame, + sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats, + feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats, sample_df: pd.DataFrame, start_infer_time: pd.Timestamp, end_infer_time: pd.Timestamp, @@ -147,8 +149,8 @@ def do_tracking( Implement this method with your custom monitoring logic. :param application_name: (str) the app name - :param sample_df_stats: (pd.DataFrame) The new sample distribution DataFrame. - :param feature_stats: (pd.DataFrame) The train sample distribution DataFrame. + :param sample_df_stats: (FeatureStats) The new sample distribution dictionary. + :param feature_stats: (FeatureStats) The train sample distribution dictionary. :param sample_df: (pd.DataFrame) The new sample DataFrame. :param start_infer_time: (pd.Timestamp) Start time of the monitoring schedule. :param end_infer_time: (pd.Timestamp) End time of the monitoring schedule. 
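(Illustration, not part of the patch: a rough sketch of the FeatureStats payload that do_tracking now receives, and of the public dict_to_histogram helper introduced in the hunks below. The feature name, counts, and bin edges are invented; only the [counts, bin_edges] shape follows the patch.)

import mlrun.common.model_monitoring.helpers
from mlrun.model_monitoring.application import ModelMonitoringApplicationBase

# Invented FeatureStats payload: each feature carries its histogram as
# [counts, bin_edges], which is the shape do_tracking now receives.
feature_stats = mlrun.common.model_monitoring.helpers.FeatureStats(
    {
        "f1": {
            "count": 100,
            "hist": [[0, 30, 70, 0], [-10.0, -5.0, 0.0, 5.0, 10.0]],
        }
    }
)

# The public helper (a static method, see the rename hunk below) turns the
# dictionary into a DataFrame with one histogram column per feature.
hist_df = ModelMonitoringApplicationBase.dict_to_histogram(feature_stats)
print(hist_df["f1"])
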
@@ -167,8 +169,8 @@ def _resolve_event( event: dict[str, Any], ) -> tuple[ str, - pd.DataFrame, - pd.DataFrame, + mlrun.common.model_monitoring.helpers.FeatureStats, + mlrun.common.model_monitoring.helpers.FeatureStats, pd.DataFrame, pd.Timestamp, pd.Timestamp, @@ -184,8 +186,8 @@ def _resolve_event( :return: A tuple of: [0] = (str) application name - [1] = (pd.DataFrame) current input statistics - [2] = (pd.DataFrame) train statistics + [1] = (dict) current input statistics + [2] = (dict) train statistics [3] = (pd.DataFrame) current input data [4] = (pd.Timestamp) start time of the monitoring schedule [5] = (pd.Timestamp) end time of the monitoring schedule @@ -197,12 +199,8 @@ def _resolve_event( end_time = pd.Timestamp(event[mm_constant.ApplicationEvent.END_INFER_TIME]) return ( event[mm_constant.ApplicationEvent.APPLICATION_NAME], - cls._dict_to_histogram( - json.loads(event[mm_constant.ApplicationEvent.CURRENT_STATS]) - ), - cls._dict_to_histogram( - json.loads(event[mm_constant.ApplicationEvent.FEATURE_STATS]) - ), + json.loads(event[mm_constant.ApplicationEvent.CURRENT_STATS]), + json.loads(event[mm_constant.ApplicationEvent.FEATURE_STATS]), ParquetTarget( path=event[mm_constant.ApplicationEvent.SAMPLE_PARQUET_PATH] ).as_df(start_time=start_time, end_time=end_time, time_column="timestamp"), @@ -223,7 +221,9 @@ def _create_context_for_logging(app_name: str): return context @staticmethod - def _dict_to_histogram(histogram_dict: dict[str, dict[str, Any]]) -> pd.DataFrame: + def dict_to_histogram( + histogram_dict: mlrun.common.model_monitoring.helpers.FeatureStats, + ) -> pd.DataFrame: """ Convert histogram dictionary to pandas DataFrame with feature histograms as columns diff --git a/mlrun/model_monitoring/applications/histogram_data_drift.py b/mlrun/model_monitoring/applications/histogram_data_drift.py index 5053e8f5fdc..c0e72963cf9 100644 --- a/mlrun/model_monitoring/applications/histogram_data_drift.py +++ b/mlrun/model_monitoring/applications/histogram_data_drift.py @@ -13,13 +13,17 @@ # limitations under the License. 
from dataclasses import dataclass -from typing import Final, Optional, Protocol +from typing import Final, Optional, Protocol, cast import numpy as np -from pandas import DataFrame, Timestamp +from pandas import DataFrame, Series, Timestamp +import mlrun.artifacts +import mlrun.common.model_monitoring.helpers +import mlrun.model_monitoring.features_drift_table as mm_drift_table from mlrun.common.schemas.model_monitoring.constants import ( MLRUN_HISTOGRAM_DATA_DRIFT_APP_NAME, + EventFieldType, ResultKindApp, ResultStatusApp, ) @@ -27,7 +31,7 @@ ModelMonitoringApplicationBase, ModelMonitoringApplicationResult, ) -from mlrun.model_monitoring.batch import ( +from mlrun.model_monitoring.metrics.histogram_distance import ( HellingerDistance, HistogramDistanceMetric, KullbackLeiblerDivergence, @@ -115,31 +119,24 @@ def __init__(self, value_classifier: Optional[ValueClassifier] = None) -> None: def _compute_metrics_per_feature( self, sample_df_stats: DataFrame, feature_stats: DataFrame - ) -> dict[type[HistogramDistanceMetric], list[float]]: + ) -> DataFrame: """Compute the metrics for the different features and labels""" - metrics_per_feature: dict[type[HistogramDistanceMetric], list[float]] = { - metric_class: [] for metric_class in self.metrics - } + metrics_per_feature = DataFrame( + columns=[metric_class.NAME for metric_class in self.metrics] + ) - for (sample_feat, sample_hist), (reference_feat, reference_hist) in zip( - sample_df_stats.items(), feature_stats.items() - ): - assert sample_feat == reference_feat, "The features do not match" + for feature_name in feature_stats: + sample_hist = np.asarray(sample_df_stats[feature_name]) + reference_hist = np.asarray(feature_stats[feature_name]) self.context.logger.info( - "Computing metrics for feature", feature_name=sample_feat + "Computing metrics for feature", feature_name=feature_name ) - sample_arr = np.asarray(sample_hist) - reference_arr = np.asarray(reference_hist) - for metric in self.metrics: - metric_name = metric.NAME - self.context.logger.debug( - "Computing data drift metric", - metric_name=metric_name, - feature_name=sample_feat, - ) - metrics_per_feature[metric].append( - metric(distrib_t=sample_arr, distrib_u=reference_arr).compute() - ) + metrics_per_feature.loc[feature_name] = { # pyright: ignore[reportCallIssue,reportArgumentType] + metric.NAME: metric( + distrib_t=sample_hist, distrib_u=reference_hist + ).compute() + for metric in self.metrics + } self.context.logger.info("Finished computing the metrics") return metrics_per_feature @@ -147,37 +144,37 @@ def _compute_metrics_per_feature( def _add_general_drift_result( self, results: list[ModelMonitoringApplicationResult], value: float ) -> None: + """Add the general drift result to the results list and log it""" + status = self._value_classifier.value_to_status(value) results.append( ModelMonitoringApplicationResult( name="general_drift", value=value, kind=self.METRIC_KIND, - status=self._value_classifier.value_to_status(value), + status=status, ) ) def _get_results( - self, metrics_per_feature: dict[type[HistogramDistanceMetric], list[float]] + self, metrics_per_feature: DataFrame ) -> list[ModelMonitoringApplicationResult]: """Average the metrics over the features and add the status""" results: list[ModelMonitoringApplicationResult] = [] - hellinger_tvd_values: list[float] = [] - for metric_class, metric_values in metrics_per_feature.items(): - self.context.logger.debug( - "Averaging metric over the features", metric_name=metric_class.NAME - ) - value = 
np.mean(metric_values) - if metric_class == KullbackLeiblerDivergence: + + self.context.logger.debug("Averaging metrics over the features") + metrics_mean = metrics_per_feature.mean().to_dict() + + self.context.logger.debug("Creating the results") + for name, value in metrics_mean.items(): + if name == KullbackLeiblerDivergence.NAME: # This metric is not bounded from above [0, inf). # No status is currently reported for KL divergence status = ResultStatusApp.irrelevant else: status = self._value_classifier.value_to_status(value) - if metric_class in self._REQUIRED_METRICS: - hellinger_tvd_values.append(value) results.append( ModelMonitoringApplicationResult( - name=f"{metric_class.NAME}_mean", + name=f"{name}_mean", value=value, kind=self.METRIC_KIND, status=status, @@ -185,16 +182,102 @@ def _get_results( ) self._add_general_drift_result( - results=results, value=np.mean(hellinger_tvd_values) + results=results, + value=np.mean( + [ + metrics_mean[HellingerDistance.NAME], + metrics_mean[TotalVarianceDistance.NAME], + ] + ), ) + self.context.logger.info("Finished with the results") return results + @staticmethod + def _remove_timestamp_feature( + sample_set_statistics: mlrun.common.model_monitoring.helpers.FeatureStats, + ) -> mlrun.common.model_monitoring.helpers.FeatureStats: + """ + Drop the 'timestamp' feature if it exists, as it is irrelevant + in the plotly artifact + """ + sample_set_statistics = mlrun.common.model_monitoring.helpers.FeatureStats( + sample_set_statistics.copy() + ) + if EventFieldType.TIMESTAMP in sample_set_statistics: + del sample_set_statistics[EventFieldType.TIMESTAMP] + return sample_set_statistics + + def _log_json_artifact(self, drift_per_feature_values: Series) -> None: + """Log the drift values as a JSON artifact""" + self.context.logger.debug("Logging drift value per feature JSON artifact") + self.context.log_artifact( + mlrun.artifacts.Artifact( + body=drift_per_feature_values.to_json(), + format="json", + key="features_drift_results", + ) + ) + self.context.logger.debug("Logged JSON artifact successfully") + + def _log_plotly_table_artifact( + self, + sample_set_statistics: mlrun.common.model_monitoring.helpers.FeatureStats, + inputs_statistics: mlrun.common.model_monitoring.helpers.FeatureStats, + metrics_per_feature: DataFrame, + drift_per_feature_values: Series, + ) -> None: + """Log the Plotly drift table artifact""" + self.context.logger.debug( + "Feature stats", + sample_set_statistics=sample_set_statistics, + inputs_statistics=inputs_statistics, + ) + + self.context.logger.debug("Computing drift results per feature") + drift_results = { + cast(str, key): (self._value_classifier.value_to_status(value), value) + for key, value in drift_per_feature_values.items() + } + self.context.logger.debug("Logging plotly artifact") + self.context.log_artifact( + mm_drift_table.FeaturesDriftTablePlot().produce( + sample_set_statistics=sample_set_statistics, + inputs_statistics=inputs_statistics, + metrics=metrics_per_feature.T.to_dict(), + drift_results=drift_results, + ) + ) + self.context.logger.debug("Logged plotly artifact successfully") + + def _log_drift_artifacts( + self, + sample_set_statistics: mlrun.common.model_monitoring.helpers.FeatureStats, + inputs_statistics: mlrun.common.model_monitoring.helpers.FeatureStats, + metrics_per_feature: DataFrame, + log_json_artifact: bool = True, + ) -> None: + """Log JSON and Plotly drift data per feature artifacts""" + drift_per_feature_values = metrics_per_feature[ + [HellingerDistance.NAME, 
TotalVarianceDistance.NAME] + ].mean(axis=1) + + if log_json_artifact: + self._log_json_artifact(drift_per_feature_values) + + self._log_plotly_table_artifact( + sample_set_statistics=self._remove_timestamp_feature(sample_set_statistics), + inputs_statistics=inputs_statistics, + metrics_per_feature=metrics_per_feature, + drift_per_feature_values=drift_per_feature_values, + ) + def do_tracking( self, application_name: str, - sample_df_stats: DataFrame, - feature_stats: DataFrame, + sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats, + feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats, sample_df: DataFrame, start_infer_time: Timestamp, end_infer_time: Timestamp, @@ -210,7 +293,14 @@ def do_tracking( """ self.context.logger.debug("Starting to run the application") metrics_per_feature = self._compute_metrics_per_feature( - sample_df_stats=sample_df_stats, feature_stats=feature_stats + sample_df_stats=self.dict_to_histogram(sample_df_stats), + feature_stats=self.dict_to_histogram(feature_stats), + ) + self.context.logger.debug("Saving artifacts") + self._log_drift_artifacts( + inputs_statistics=feature_stats, + sample_set_statistics=sample_df_stats, + metrics_per_feature=metrics_per_feature, ) self.context.logger.debug("Computing average per metric") results = self._get_results(metrics_per_feature) diff --git a/mlrun/model_monitoring/batch.py b/mlrun/model_monitoring/batch.py index acab78709ff..83ee254987d 100644 --- a/mlrun/model_monitoring/batch.py +++ b/mlrun/model_monitoring/batch.py @@ -33,6 +33,7 @@ import mlrun.data_types.infer import mlrun.feature_store as fstore import mlrun.utils.v3io_clients +from mlrun.model_monitoring.helpers import calculate_inputs_statistics from mlrun.model_monitoring.metrics.histogram_distance import ( HellingerDistance, HistogramDistanceMetric, @@ -353,48 +354,6 @@ def _get_drift_status( return drift_status -def calculate_inputs_statistics( - sample_set_statistics: dict, inputs: pd.DataFrame -) -> dict: - """ - Calculate the inputs data statistics for drift monitoring purpose. - - :param sample_set_statistics: The sample set (stored end point's dataset to reference) statistics. The bins of the - histograms of each feature will be used to recalculate the histograms of the inputs. - :param inputs: The inputs to calculate their statistics and later on - the drift with respect to the - sample set. - - :returns: The calculated statistics of the inputs data. - """ - - # Use `DFDataInfer` to calculate the statistics over the inputs: - inputs_statistics = mlrun.data_types.infer.DFDataInfer.get_stats( - df=inputs, - options=mlrun.data_types.infer.InferOptions.Histogram, - ) - - # Recalculate the histograms over the bins that are set in the sample-set of the end point: - for feature in inputs_statistics.keys(): - if feature in sample_set_statistics: - counts, bins = np.histogram( - inputs[feature].to_numpy(), - bins=sample_set_statistics[feature]["hist"][1], - ) - inputs_statistics[feature]["hist"] = [ - counts.tolist(), - bins.tolist(), - ] - elif "hist" in inputs_statistics[feature]: - # Comply with the other common features' histogram length - mlrun.common.model_monitoring.helpers.pad_hist( - mlrun.common.model_monitoring.helpers.Histogram( - inputs_statistics[feature]["hist"] - ) - ) - - return inputs_statistics - - class BatchProcessor: """ The main object to handle the batch processing job. 
This object is used to get the required configurations and diff --git a/mlrun/model_monitoring/controller.py b/mlrun/model_monitoring/controller.py index 62885ea7db1..ddf11aa7448 100644 --- a/mlrun/model_monitoring/controller.py +++ b/mlrun/model_monitoring/controller.py @@ -31,10 +31,10 @@ from mlrun.datastore import get_stream_pusher from mlrun.datastore.targets import ParquetTarget from mlrun.errors import err_to_str -from mlrun.model_monitoring.batch import calculate_inputs_statistics from mlrun.model_monitoring.helpers import ( _BatchDict, batch_dict2timedelta, + calculate_inputs_statistics, get_monitoring_parquet_path, get_stream_path, ) diff --git a/mlrun/model_monitoring/features_drift_table.py b/mlrun/model_monitoring/features_drift_table.py index b70a8894bd4..624dd7e6fa6 100644 --- a/mlrun/model_monitoring/features_drift_table.py +++ b/mlrun/model_monitoring/features_drift_table.py @@ -21,9 +21,34 @@ from plotly.subplots import make_subplots import mlrun.common.schemas.model_monitoring +from mlrun.artifacts import PlotlyArtifact # A type for representing a drift result, a tuple of the status and the drift mean: -DriftResultType = tuple[mlrun.common.schemas.model_monitoring.DriftStatus, float] +DriftResultType = tuple[ + mlrun.common.schemas.model_monitoring.constants.ResultStatusApp, float +] + + +class _PlotlyTableArtifact(PlotlyArtifact): + """A custom class for plotly table artifacts""" + + @staticmethod + def _disable_table_dragging(figure_html: str) -> str: + """ + Disable the table columns dragging by adding the following + JavaScript code + """ + start, end = figure_html.rsplit(";", 1) + middle = ( + ';for (const element of document.getElementsByClassName("table")) ' + '{element.style.pointerEvents = "none";}' + ) + figure_html = start + middle + end + return figure_html + + def get_body(self) -> str: + """Get the adjusted HTML representation of the figure""" + return self._disable_table_dragging(super().get_body()) class FeaturesDriftTablePlot: @@ -62,9 +87,9 @@ class FeaturesDriftTablePlot: # Status configurations: _STATUS_COLORS = { - mlrun.common.schemas.model_monitoring.DriftStatus.NO_DRIFT: "rgb(0,176,80)", # Green - mlrun.common.schemas.model_monitoring.DriftStatus.POSSIBLE_DRIFT: "rgb(255,192,0)", # Orange - mlrun.common.schemas.model_monitoring.DriftStatus.DRIFT_DETECTED: "rgb(208,0,106)", # Magenta + mlrun.common.schemas.model_monitoring.constants.ResultStatusApp.no_detection: "rgb(0,176,80)", # Green + mlrun.common.schemas.model_monitoring.constants.ResultStatusApp.potential_detection: "rgb(255,192,0)", # Orange + mlrun.common.schemas.model_monitoring.constants.ResultStatusApp.detected: "rgb(208,0,106)", # Magenta } # Font configurations: @@ -97,7 +122,7 @@ def produce( inputs_statistics: dict, metrics: dict[str, Union[dict, float]], drift_results: dict[str, DriftResultType], - ) -> str: + ) -> _PlotlyTableArtifact: """ Produce the html code of the table plot with the given information and the stored configurations in the class. @@ -106,9 +131,8 @@ def produce( :param metrics: The drift detection metrics calculated on the sample set and inputs. :param drift_results: The drift results per feature according to the rules of the monitor. - :return: The full path to the html file of the plot. + :return: The drift table as a plotly artifact. 
""" - # Plot the drift table: figure = self._plot( features=list(inputs_statistics.keys()), sample_set_statistics=sample_set_statistics, @@ -116,19 +140,7 @@ def produce( metrics=metrics, drift_results=drift_results, ) - - # Get its HTML representation: - figure_html = figure.to_html() - - # Turn off the table columns dragging by injecting the following JavaScript code: - start, end = figure_html.rsplit(";", 1) - middle = ( - ';for (const element of document.getElementsByClassName("table")) ' - '{element.style.pointerEvents = "none";}' - ) - figure_html = start + middle + end - - return figure_html + return _PlotlyTableArtifact(figure=figure, key="drift_table_plot") def _read_columns_names(self, statistics_dictionary: dict, drift_metrics: dict): """ @@ -366,10 +378,10 @@ def _plot_histogram_bars( bins = np.array(bins) if bins[0] == -sys.float_info.max: bins[0] = bins[1] - (bins[2] - bins[1]) - hovertext[0] = f"(-∞, {bins[1]})" + hovertext[0] = f"(-inf, {bins[1]})" if bins[-1] == sys.float_info.max: bins[-1] = bins[-2] + (bins[-2] - bins[-3]) - hovertext[-1] = f"({bins[-2]}, ∞)" + hovertext[-1] = f"({bins[-2]}, inf)" # Center the bins (leave the first one): bins = 0.5 * (bins[:-1] + bins[1:]) # Plot the histogram as a line with filled background below it: diff --git a/mlrun/model_monitoring/helpers.py b/mlrun/model_monitoring/helpers.py index e0e1e988147..b844c59d072 100644 --- a/mlrun/model_monitoring/helpers.py +++ b/mlrun/model_monitoring/helpers.py @@ -15,6 +15,9 @@ import datetime import typing +import numpy as np +import pandas as pd + import mlrun import mlrun.common.model_monitoring.helpers import mlrun.common.schemas @@ -36,10 +39,6 @@ class _BatchDict(typing.TypedDict): days: int -class _MLRunNoRunsFoundError(Exception): - pass - - def get_stream_path( project: str = None, function_name: str = mm_constants.MonitoringFunctionNames.STREAM, @@ -212,3 +211,45 @@ def update_model_endpoint_last_request( endpoint_id=model_endpoint.metadata.uid, attributes={EventFieldType.LAST_REQUEST: bumped_last_request}, ) + + +def calculate_inputs_statistics( + sample_set_statistics: dict, inputs: pd.DataFrame +) -> dict: + """ + Calculate the inputs data statistics for drift monitoring purpose. + + :param sample_set_statistics: The sample set (stored end point's dataset to reference) statistics. The bins of the + histograms of each feature will be used to recalculate the histograms of the inputs. + :param inputs: The inputs to calculate their statistics and later on - the drift with respect to the + sample set. + + :returns: The calculated statistics of the inputs data. 
+ """ + + # Use `DFDataInfer` to calculate the statistics over the inputs: + inputs_statistics = mlrun.data_types.infer.DFDataInfer.get_stats( + df=inputs, + options=mlrun.data_types.infer.InferOptions.Histogram, + ) + + # Recalculate the histograms over the bins that are set in the sample-set of the end point: + for feature in inputs_statistics.keys(): + if feature in sample_set_statistics: + counts, bins = np.histogram( + inputs[feature].to_numpy(), + bins=sample_set_statistics[feature]["hist"][1], + ) + inputs_statistics[feature]["hist"] = [ + counts.tolist(), + bins.tolist(), + ] + elif "hist" in inputs_statistics[feature]: + # Comply with the other common features' histogram length + mlrun.common.model_monitoring.helpers.pad_hist( + mlrun.common.model_monitoring.helpers.Histogram( + inputs_statistics[feature]["hist"] + ) + ) + + return inputs_statistics diff --git a/tests/model_monitoring/test_applications/assets/feature_stats.csv b/tests/model_monitoring/test_applications/assets/feature_stats.csv new file mode 100644 index 00000000000..de76ff176fe --- /dev/null +++ b/tests/model_monitoring/test_applications/assets/feature_stats.csv @@ -0,0 +1,23 @@ +,sepal_length_cm,sepal_width_cm,petal_length_cm,petal_width_cm +0,0.0,0.0,0.0,0.0 +1,0.02666666666666667,0.006666666666666667,0.02666666666666667,0.22666666666666666 +2,0.03333333333333333,0.02,0.22,0.04666666666666667 +3,0.04666666666666667,0.02666666666666667,0.07333333333333333,0.04666666666666667 +4,0.10666666666666667,0.02,0.013333333333333334,0.006666666666666667 +5,0.06,0.05333333333333334,0.0,0.006666666666666667 +6,0.03333333333333333,0.09333333333333334,0.0,0.0 +7,0.08666666666666667,0.09333333333333334,0.006666666666666667,0.0 +8,0.09333333333333334,0.06666666666666667,0.013333333333333334,0.04666666666666667 +9,0.06666666666666667,0.17333333333333334,0.02,0.02 +10,0.04,0.07333333333333333,0.03333333333333333,0.03333333333333333 +11,0.06666666666666667,0.12666666666666668,0.08,0.14 +12,0.10666666666666667,0.08,0.09333333333333334,0.08 +13,0.04666666666666667,0.04,0.08,0.02666666666666667 +14,0.07333333333333333,0.02666666666666667,0.11333333333333333,0.013333333333333334 +15,0.02666666666666667,0.06,0.04,0.08 +16,0.013333333333333334,0.013333333333333334,0.08,0.07333333333333333 +17,0.02666666666666667,0.006666666666666667,0.04666666666666667,0.04 +18,0.006666666666666667,0.006666666666666667,0.02666666666666667,0.02 +19,0.03333333333333333,0.006666666666666667,0.013333333333333334,0.05333333333333334 +20,0.006666666666666667,0.006666666666666667,0.02,0.04 +21,0.0,0.0,0.0,0.0 diff --git a/tests/model_monitoring/test_applications/assets/sample_df_stats.csv b/tests/model_monitoring/test_applications/assets/sample_df_stats.csv new file mode 100644 index 00000000000..dc02ef3ba37 --- /dev/null +++ b/tests/model_monitoring/test_applications/assets/sample_df_stats.csv @@ -0,0 +1,23 @@ +,p0,petal_length_cm,petal_width_cm,sepal_length_cm,sepal_width_cm +0,0.0,1.0,1.0,1.0,1.0 +1,0.0,0.0,0.0,0.0,0.0 +2,0.0,0.0,0.0,0.0,0.0 +3,0.0,0.0,0.0,0.0,0.0 +4,0.0,0.0,0.0,0.0,0.0 +5,0.0,0.0,0.0,0.0,0.0 +6,0.0,0.0,0.0,0.0,0.0 +7,0.0,0.0,0.0,0.0,0.0 +8,0.0,0.0,0.0,0.0,0.0 +9,0.0,0.0,0.0,0.0,0.0 +10,0.0,0.0,0.0,0.0,0.0 +11,1.0,0.0,0.0,0.0,0.0 +12,0.0,0.0,0.0,0.0,0.0 +13,0.0,0.0,0.0,0.0,0.0 +14,0.0,0.0,0.0,0.0,0.0 +15,0.0,0.0,0.0,0.0,0.0 +16,0.0,0.0,0.0,0.0,0.0 +17,0.0,0.0,0.0,0.0,0.0 +18,0.0,0.0,0.0,0.0,0.0 +19,0.0,0.0,0.0,0.0,0.0 +20,0.0,0.0,0.0,0.0,0.0 +21,0.0,0.0,0.0,0.0,0.0 diff --git 
a/tests/model_monitoring/test_applications/test_histogram_data_drift.py b/tests/model_monitoring/test_applications/test_histogram_data_drift.py index 73d82687653..677e4b259d5 100644 --- a/tests/model_monitoring/test_applications/test_histogram_data_drift.py +++ b/tests/model_monitoring/test_applications/test_histogram_data_drift.py @@ -14,6 +14,7 @@ import inspect import logging +from pathlib import Path from typing import Any from unittest.mock import Mock @@ -22,6 +23,8 @@ from hypothesis import given from hypothesis import strategies as st +import mlrun.artifacts.manager +import mlrun.common.model_monitoring.helpers from mlrun import MLClientCtx from mlrun.common.schemas.model_monitoring.constants import ( ResultKindApp, @@ -35,6 +38,18 @@ ) from mlrun.utils import Logger +assets_folder = Path(__file__).parent / "assets" + + +@pytest.fixture +def application() -> HistogramDataDriftApplication: + app = HistogramDataDriftApplication() + app.context = MLClientCtx( + log_stream=Logger(name="test_data_drift_app", level=logging.DEBUG) + ) + app.context._artifacts_manager = Mock(spec=mlrun.artifacts.manager.ArtifactManager) + return app + class TestDataDriftClassifier: @staticmethod @@ -84,40 +99,64 @@ def test_status( class TestApplication: @staticmethod @pytest.fixture - def sample_df_stats() -> pd.DataFrame: - return pd.DataFrame.from_dict( + def sample_df_stats() -> mlrun.common.model_monitoring.helpers.FeatureStats: + return mlrun.common.model_monitoring.helpers.FeatureStats( { - "f1": [0.1, 0.3, 0, 0.3, 0.05, 0.25], - "f2": [0, 0.5, 0, 0.2, 0.05, 0.25], - "l": [0.9, 0, 0, 0, 0, 0.1], + "timestamp": { + "count": 1, + "25%": "2024-03-11 09:31:39.152301+00:00", + "50%": "2024-03-11 09:31:39.152301+00:00", + "75%": "2024-03-11 09:31:39.152301+00:00", + "max": "2024-03-11 09:31:39.152301+00:00", + "mean": "2024-03-11 09:31:39.152301+00:00", + "min": "2024-03-11 09:31:39.152301+00:00", + }, + "f1": { + "count": 100, + "hist": [[10, 30, 0, 30, 5, 25], [-10, -5, 0, 5, 10, 15, 20]], + }, + "f2": { + "count": 100, + "hist": [[0, 50, 0, 20, 5, 25], [66, 67, 68, 69, 70, 71, 72]], + }, + "l": { + "count": 100, + "hist": [ + [90, 0, 0, 0, 0, 10], + [0.0, 0.16, 0.33, 0.5, 0.67, 0.83, 1.0], + ], + }, } ) @staticmethod @pytest.fixture - def feature_stats() -> pd.DataFrame: - return pd.DataFrame.from_dict( + def feature_stats() -> mlrun.common.model_monitoring.helpers.FeatureStats: + return mlrun.common.model_monitoring.helpers.FeatureStats( { - "f1": [0, 0, 0, 0.3, 0.7, 0], - "f2": [0, 0.45, 0.05, 0.15, 0.35, 0], - "l": [0.3, 0, 0, 0, 0, 0.7], + "f1": { + "count": 100, + "hist": [[0, 0, 0, 30, 70, 0], [-10, -5, 0, 5, 10, 15, 20]], + }, + "f2": { + "count": 100, + "hist": [[0, 45, 5, 15, 35, 0], [66, 67, 68, 69, 70, 71, 72]], + }, + "l": { + "count": 100, + "hist": [ + [30, 0, 0, 0, 0, 70], + [0.0, 0.16, 0.33, 0.5, 0.67, 0.83, 1.0], + ], + }, } ) - @staticmethod - @pytest.fixture - def application() -> HistogramDataDriftApplication: - app = HistogramDataDriftApplication() - app.context = MLClientCtx( - log_stream=Logger(name="test_data_drift_app", level=logging.DEBUG) - ) - return app - @staticmethod @pytest.fixture def application_kwargs( - sample_df_stats: pd.DataFrame, - feature_stats: pd.DataFrame, + sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats, + feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats, application: HistogramDataDriftApplication, ) -> dict[str, Any]: kwargs = {} @@ -156,3 +195,30 @@ def test( result_by_name["general_drift"]["result_status"] == 
ResultStatusApp.potential_detection ), "Expected potential detection in the general drift" + + +@pytest.mark.parametrize( + ("sample_df_stats", "feature_stats"), + [ + pytest.param(pd.DataFrame(), pd.DataFrame(), id="empty-dfs"), + pytest.param( + pd.read_csv(assets_folder / "sample_df_stats.csv", index_col=0), + pd.read_csv(assets_folder / "feature_stats.csv", index_col=0), + id="real-world-csv-dfs", + ), + ], +) +def test_compute_metrics_per_feature( + application: HistogramDataDriftApplication, + sample_df_stats: pd.DataFrame, + feature_stats: pd.DataFrame, +) -> None: + metrics_per_feature = application._compute_metrics_per_feature( + sample_df_stats=sample_df_stats, feature_stats=feature_stats + ) + assert set(metrics_per_feature.columns) == { + metric.NAME for metric in application.metrics + }, "Different metrics than expected" + assert set(metrics_per_feature.index) == set( + feature_stats.columns + ), "The features are different than expected" diff --git a/tests/model_monitoring/test_features_drift_table.py b/tests/model_monitoring/test_features_drift_table.py index a5d53084602..e2bccc27a05 100644 --- a/tests/model_monitoring/test_features_drift_table.py +++ b/tests/model_monitoring/test_features_drift_table.py @@ -19,11 +19,11 @@ import pytest import mlrun -from mlrun.artifacts import Artifact +import mlrun.model_monitoring.applications.histogram_data_drift as histogram_data_drift +import mlrun.utils from mlrun.common.model_monitoring.helpers import FeatureStats, pad_features_hist from mlrun.data_types.infer import DFDataInfer, default_num_bins -from mlrun.model_monitoring.batch import VirtualDrift, calculate_inputs_statistics -from mlrun.model_monitoring.features_drift_table import FeaturesDriftTablePlot +from mlrun.model_monitoring.helpers import calculate_inputs_statistics def generate_data( @@ -62,53 +62,50 @@ def plot_produce(context: mlrun.MLClientCtx): ) # Calculate statistics: - sample_data_statistics = DFDataInfer.get_stats( - df=sample_data, - options=mlrun.data_types.infer.InferOptions.Histogram, + sample_data_statistics = FeatureStats( + DFDataInfer.get_stats( + df=sample_data, + options=mlrun.data_types.infer.InferOptions.Histogram, + ) ) - pad_features_hist(FeatureStats(sample_data_statistics)) - inputs_statistics = calculate_inputs_statistics( - sample_set_statistics=sample_data_statistics, - inputs=inputs, + pad_features_hist(sample_data_statistics) + inputs_statistics = FeatureStats( + calculate_inputs_statistics( + sample_set_statistics=sample_data_statistics, + inputs=inputs, + ) ) - # Calculate drift: - virtual_drift = VirtualDrift(inf_capping=10) - metrics = virtual_drift.compute_drift_from_histograms( - feature_stats=sample_data_statistics, - current_stats=inputs_statistics, - ) - drift_results = virtual_drift.check_for_drift_per_feature( - metrics_results_dictionary=metrics - ) + # Initialize the app + application = histogram_data_drift.HistogramDataDriftApplication() + application.context = context - # Plot: - html_plot = FeaturesDriftTablePlot().produce( + # Calculate drift + metrics_per_feature = application._compute_metrics_per_feature( + sample_df_stats=application.dict_to_histogram(sample_data_statistics), + feature_stats=application.dict_to_histogram(inputs_statistics), + ) + application._log_drift_artifacts( sample_set_statistics=sample_data_statistics, inputs_statistics=inputs_statistics, - metrics=metrics, - drift_results=drift_results, - ) - - # Log: - context.log_artifact( - Artifact(body=html_plot, format="html", key="drift_table_plot") + 
metrics_per_feature=metrics_per_feature, + log_json_artifact=False, ) def test_plot_produce(tmp_path: Path) -> None: # Run the plot production and logging: - train_run = mlrun.new_function().run( + app_plot_run = mlrun.new_function().run( artifact_path=str(tmp_path), handler=plot_produce, ) # Validate the artifact was logged: - assert len(train_run.status.artifacts) == 1 + assert len(app_plot_run.status.artifacts) == 1 # Check the plot was saved properly (only the drift table plot should appear): artifact_directory_content = list( - Path(train_run.status.artifacts[0]["spec"]["target_path"]).parent.glob("*") + Path(app_plot_run.status.artifacts[0]["spec"]["target_path"]).parent.glob("*") ) assert len(artifact_directory_content) == 1 assert artifact_directory_content[0].name == "drift_table_plot.html" diff --git a/tests/system/model_monitoring/assets/application.py b/tests/system/model_monitoring/assets/application.py index 4e6e56af741..0f4ddd89230 100644 --- a/tests/system/model_monitoring/assets/application.py +++ b/tests/system/model_monitoring/assets/application.py @@ -15,6 +15,7 @@ import pandas as pd import mlrun +import mlrun.common.model_monitoring.helpers from mlrun.common.schemas.model_monitoring.constants import ( ResultKindApp, ResultStatusApp, @@ -41,8 +42,8 @@ def __init_subclass__(cls, check_num_events: bool) -> None: def do_tracking( self, application_name: str, - sample_df_stats: pd.DataFrame, - feature_stats: pd.DataFrame, + sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats, + feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats, sample_df: pd.DataFrame, start_infer_time: pd.Timestamp, end_infer_time: pd.Timestamp, diff --git a/tests/system/model_monitoring/assets/custom_evidently_app.py b/tests/system/model_monitoring/assets/custom_evidently_app.py index 0abca0bba4e..56bc01acfc1 100644 --- a/tests/system/model_monitoring/assets/custom_evidently_app.py +++ b/tests/system/model_monitoring/assets/custom_evidently_app.py @@ -19,6 +19,7 @@ import pandas as pd from sklearn.datasets import load_iris +import mlrun.common.model_monitoring.helpers from mlrun.common.schemas.model_monitoring.constants import ( ResultKindApp, ResultStatusApp, @@ -163,8 +164,8 @@ def _init_evidently_project(self) -> None: def do_tracking( self, application_name: str, - sample_df_stats: pd.DataFrame, - feature_stats: pd.DataFrame, + sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats, + feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats, sample_df: pd.DataFrame, start_infer_time: pd.Timestamp, end_infer_time: pd.Timestamp,
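
(Illustration, not part of the patch: a minimal custom monitoring application sketched against the updated do_tracking signature, where the statistics arguments are FeatureStats dictionaries rather than DataFrames. The class name, the 0.7 threshold, and the trailing *args catch-all for the remaining positional arguments are assumptions made for brevity; see the abstract signature in mlrun/model_monitoring/application.py above for the full parameter list.)

import numpy as np
import pandas as pd

import mlrun.common.model_monitoring.helpers
from mlrun.common.schemas.model_monitoring.constants import (
    ResultKindApp,
    ResultStatusApp,
)
from mlrun.model_monitoring.application import (
    ModelMonitoringApplicationBase,
    ModelMonitoringApplicationResult,
)
from mlrun.model_monitoring.metrics.histogram_distance import TotalVarianceDistance


class TotalVarianceDriftApp(ModelMonitoringApplicationBase):
    """Toy app: average total-variance distance between current and reference histograms."""

    def do_tracking(
        self,
        application_name: str,
        sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats,
        feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats,
        sample_df: pd.DataFrame,
        start_infer_time: pd.Timestamp,
        end_infer_time: pd.Timestamp,
        *args,  # latest_request, endpoint_id, output_stream_uri - unused in this sketch
    ) -> ModelMonitoringApplicationResult:
        # The statistics arrive as FeatureStats dictionaries; convert them to
        # per-feature histogram DataFrames with the public helper.
        current = self.dict_to_histogram(sample_df_stats)
        reference = self.dict_to_histogram(feature_stats)
        # Average the per-feature distances (assumes both sides share the same features).
        value = float(
            np.mean(
                [
                    TotalVarianceDistance(
                        distrib_t=current[feature].to_numpy(),
                        distrib_u=reference[feature].to_numpy(),
                    ).compute()
                    for feature in reference.columns
                ]
            )
        )
        return ModelMonitoringApplicationResult(
            name="total_variance_distance_mean",
            value=value,
            kind=ResultKindApp.data_drift,  # assumed result kind; adjust as needed
            status=(
                ResultStatusApp.detected
                if value > 0.7  # illustrative threshold, not taken from the patch
                else ResultStatusApp.no_detection
            ),
        )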