From 143f1a70011ff2e09397333357090bb1f0a417a9 Mon Sep 17 00:00:00 2001 From: danielperezz Date: Sun, 9 Nov 2025 10:03:09 +0200 Subject: [PATCH 1/2] histogram data drift module with empty example notebook --- .../assets/feature_stats.csv | 23 ++ .../assets/sample_df_stats.csv | 23 ++ .../histogram_data_drift.ipynb | 31 ++ .../histogram_data_drift.py | 388 ++++++++++++++++++ modules/src/histogram_data_drift/item.yaml | 19 + .../src/histogram_data_drift/requirements.txt | 3 + .../test_histogram_data_drift.py | 279 +++++++++++++ 7 files changed, 766 insertions(+) create mode 100644 modules/src/histogram_data_drift/assets/feature_stats.csv create mode 100644 modules/src/histogram_data_drift/assets/sample_df_stats.csv create mode 100644 modules/src/histogram_data_drift/histogram_data_drift.ipynb create mode 100644 modules/src/histogram_data_drift/histogram_data_drift.py create mode 100644 modules/src/histogram_data_drift/item.yaml create mode 100644 modules/src/histogram_data_drift/requirements.txt create mode 100644 modules/src/histogram_data_drift/test_histogram_data_drift.py diff --git a/modules/src/histogram_data_drift/assets/feature_stats.csv b/modules/src/histogram_data_drift/assets/feature_stats.csv new file mode 100644 index 000000000..de76ff176 --- /dev/null +++ b/modules/src/histogram_data_drift/assets/feature_stats.csv @@ -0,0 +1,23 @@ +,sepal_length_cm,sepal_width_cm,petal_length_cm,petal_width_cm +0,0.0,0.0,0.0,0.0 +1,0.02666666666666667,0.006666666666666667,0.02666666666666667,0.22666666666666666 +2,0.03333333333333333,0.02,0.22,0.04666666666666667 +3,0.04666666666666667,0.02666666666666667,0.07333333333333333,0.04666666666666667 +4,0.10666666666666667,0.02,0.013333333333333334,0.006666666666666667 +5,0.06,0.05333333333333334,0.0,0.006666666666666667 +6,0.03333333333333333,0.09333333333333334,0.0,0.0 +7,0.08666666666666667,0.09333333333333334,0.006666666666666667,0.0 +8,0.09333333333333334,0.06666666666666667,0.013333333333333334,0.04666666666666667 +9,0.06666666666666667,0.17333333333333334,0.02,0.02 +10,0.04,0.07333333333333333,0.03333333333333333,0.03333333333333333 +11,0.06666666666666667,0.12666666666666668,0.08,0.14 +12,0.10666666666666667,0.08,0.09333333333333334,0.08 +13,0.04666666666666667,0.04,0.08,0.02666666666666667 +14,0.07333333333333333,0.02666666666666667,0.11333333333333333,0.013333333333333334 +15,0.02666666666666667,0.06,0.04,0.08 +16,0.013333333333333334,0.013333333333333334,0.08,0.07333333333333333 +17,0.02666666666666667,0.006666666666666667,0.04666666666666667,0.04 +18,0.006666666666666667,0.006666666666666667,0.02666666666666667,0.02 +19,0.03333333333333333,0.006666666666666667,0.013333333333333334,0.05333333333333334 +20,0.006666666666666667,0.006666666666666667,0.02,0.04 +21,0.0,0.0,0.0,0.0 diff --git a/modules/src/histogram_data_drift/assets/sample_df_stats.csv b/modules/src/histogram_data_drift/assets/sample_df_stats.csv new file mode 100644 index 000000000..dc02ef3ba --- /dev/null +++ b/modules/src/histogram_data_drift/assets/sample_df_stats.csv @@ -0,0 +1,23 @@ +,p0,petal_length_cm,petal_width_cm,sepal_length_cm,sepal_width_cm +0,0.0,1.0,1.0,1.0,1.0 +1,0.0,0.0,0.0,0.0,0.0 +2,0.0,0.0,0.0,0.0,0.0 +3,0.0,0.0,0.0,0.0,0.0 +4,0.0,0.0,0.0,0.0,0.0 +5,0.0,0.0,0.0,0.0,0.0 +6,0.0,0.0,0.0,0.0,0.0 +7,0.0,0.0,0.0,0.0,0.0 +8,0.0,0.0,0.0,0.0,0.0 +9,0.0,0.0,0.0,0.0,0.0 +10,0.0,0.0,0.0,0.0,0.0 +11,1.0,0.0,0.0,0.0,0.0 +12,0.0,0.0,0.0,0.0,0.0 +13,0.0,0.0,0.0,0.0,0.0 +14,0.0,0.0,0.0,0.0,0.0 +15,0.0,0.0,0.0,0.0,0.0 +16,0.0,0.0,0.0,0.0,0.0 +17,0.0,0.0,0.0,0.0,0.0 
18,0.0,0.0,0.0,0.0,0.0
19,0.0,0.0,0.0,0.0,0.0
20,0.0,0.0,0.0,0.0,0.0
21,0.0,0.0,0.0,0.0,0.0
diff --git a/modules/src/histogram_data_drift/histogram_data_drift.ipynb b/modules/src/histogram_data_drift/histogram_data_drift.ipynb
new file mode 100644
index 000000000..54a15016a
--- /dev/null
+++ b/modules/src/histogram_data_drift/histogram_data_drift.ipynb
{
 "cells": [
  {
   "metadata": {},
   "cell_type": "markdown",
   "source": "# Histogram Data Drift Demo",
   "id": "2517d91b275da01d"
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
diff --git a/modules/src/histogram_data_drift/histogram_data_drift.py b/modules/src/histogram_data_drift/histogram_data_drift.py
new file mode 100644
index 000000000..b8cdcf299
--- /dev/null
+++ b/modules/src/histogram_data_drift/histogram_data_drift.py
# Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from typing import Final, Optional, Protocol, Union, cast

import numpy as np
from pandas import DataFrame, Series

import mlrun.artifacts
import mlrun.common.model_monitoring.helpers
import mlrun.model_monitoring.applications.context as mm_context
import mlrun.model_monitoring.applications.results as mm_results
import mlrun.model_monitoring.features_drift_table as mm_drift_table
from mlrun.common.schemas.model_monitoring.constants import (
    ResultKindApp,
    ResultStatusApp,
    StatsKind,
)
from mlrun.model_monitoring.applications import (
    ModelMonitoringApplicationBase,
)
from mlrun.model_monitoring.metrics.histogram_distance import (
    HellingerDistance,
    HistogramDistanceMetric,
    KullbackLeiblerDivergence,
    TotalVarianceDistance,
)


class InvalidMetricValueError(ValueError):
    pass


class InvalidThresholdValueError(ValueError):
    pass


class ValueClassifier(Protocol):
    def value_to_status(self, value: float) -> ResultStatusApp: ...


class HistogramDataDriftApplicationConstants:
    NAME = "histogram-data-drift"
    GENERAL_RESULT_NAME = "general_drift"


@dataclass
class DataDriftClassifier:
    """
    Classify data drift numeric values into categorical status.
    """

    potential: float = 0.5
    detected: float = 0.7

    def __post_init__(self) -> None:
        """Catch erroneous threshold values"""
        if not 0 < self.potential < self.detected < 1:
            raise InvalidThresholdValueError(
                "The provided thresholds do not comply with the rules"
            )

    def value_to_status(self, value: float) -> ResultStatusApp:
        """
        Translate the numeric value into a status category.

        :param value: The numeric value of the data drift metric, between 0 and 1.
        :returns: `ResultStatusApp` according to the classification.
        """
        if value > 1 or value < 0:
            raise InvalidMetricValueError(
                f"{value = } is invalid, must be in the range [0, 1]."
            )
        if value >= self.detected:
            return ResultStatusApp.detected
        if value >= self.potential:
            return ResultStatusApp.potential_detection
        return ResultStatusApp.no_detection


class HistogramDataDriftApplication(ModelMonitoringApplicationBase):
    """
    MLRun's default data drift application for model monitoring.

    The application expects tabular numerical data and calculates three metrics over the shared features' histograms.
    The metrics are calculated on features that have reference data from the training dataset. When there is no
    reference data (`feature_stats`), this application sends a warning log and does nothing.
    The three metrics are:

    * Hellinger distance.
    * Total variance distance.
    * Kullback-Leibler divergence.

    Each metric is calculated over all the features individually, and the mean over the features is taken as the
    metric value. The average of the Hellinger and total variance distances is taken as the general drift result.

    The application can log two artifacts (disabled by default due to performance issues):

    * A JSON file with the general drift value per feature.
    * A Plotly table with the various metrics and histograms per feature.

    If you want to change the application defaults, such as the classifier or which artifacts to produce, you can
    either modify the downloaded source code file directly, or inherit from this class (in the same file), and then
    deploy it as any other model monitoring application.
    Please make sure to keep the default application name. This ensures that the full functionality of the
    application, including the statistics view in the UI, is available.
    """

    NAME: Final[str] = HistogramDataDriftApplicationConstants.NAME

    _REQUIRED_METRICS = {HellingerDistance, TotalVarianceDistance}
    _STATS_TYPES: tuple[StatsKind, StatsKind] = (
        StatsKind.CURRENT_STATS,
        StatsKind.DRIFT_MEASURES,
    )

    metrics: list[type[HistogramDistanceMetric]] = [
        HellingerDistance,
        KullbackLeiblerDivergence,
        TotalVarianceDistance,
    ]

    def __init__(
        self,
        value_classifier: Optional[ValueClassifier] = None,
        produce_json_artifact: bool = False,
        produce_plotly_artifact: bool = False,
    ) -> None:
        """
        :param value_classifier: Classifier object that adheres to the :py:class:`~ValueClassifier` protocol.
                                 If not provided, the default :py:class:`~DataDriftClassifier` is used.
        :param produce_json_artifact: Whether to produce the JSON artifact or not, ``False`` by default.
        :param produce_plotly_artifact: Whether to produce the Plotly artifact or not, ``False`` by default.
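
        Example (a minimal sketch; the threshold values below are illustrative, not recommendations)::

            app = HistogramDataDriftApplication(
                value_classifier=DataDriftClassifier(potential=0.4, detected=0.6),
                produce_plotly_artifact=True,
            )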
        """
        self._value_classifier = value_classifier or DataDriftClassifier()
        assert self._REQUIRED_METRICS <= set(
            self.metrics
        ), "TVD and Hellinger distance are required for the general data drift result"

        self._produce_json_artifact = produce_json_artifact
        self._produce_plotly_artifact = produce_plotly_artifact

    def _compute_metrics_per_feature(
        self, monitoring_context: mm_context.MonitoringApplicationContext
    ) -> DataFrame:
        """Compute the metrics for the different features and labels"""
        metrics_per_feature = DataFrame(
            columns=[metric_class.NAME for metric_class in self.metrics]
        )
        feature_stats = monitoring_context.dict_to_histogram(
            monitoring_context.feature_stats
        )
        sample_df_stats = monitoring_context.dict_to_histogram(
            monitoring_context.sample_df_stats
        )
        for feature_name in feature_stats:
            sample_hist = np.asarray(sample_df_stats[feature_name])
            reference_hist = np.asarray(feature_stats[feature_name])
            monitoring_context.logger.info(
                "Computing metrics for feature", feature_name=feature_name
            )
            metrics_per_feature.loc[feature_name] = {  # pyright: ignore[reportCallIssue,reportArgumentType]
                metric.NAME: metric(
                    distrib_t=sample_hist, distrib_u=reference_hist
                ).compute()
                for metric in self.metrics
            }
        monitoring_context.logger.info("Finished computing the metrics")

        return metrics_per_feature

    def _get_general_drift_result(
        self, metrics: list[mm_results.ModelMonitoringApplicationMetric]
    ) -> mm_results.ModelMonitoringApplicationResult:
        """Get the general drift result from the metrics list"""
        value = cast(
            float,
            np.mean(
                [
                    metric.value
                    for metric in metrics
                    if metric.name
                    in [
                        f"{HellingerDistance.NAME}_mean",
                        f"{TotalVarianceDistance.NAME}_mean",
                    ]
                ]
            ),
        )

        status = self._value_classifier.value_to_status(value)

        return mm_results.ModelMonitoringApplicationResult(
            name=HistogramDataDriftApplicationConstants.GENERAL_RESULT_NAME,
            value=value,
            kind=ResultKindApp.data_drift,
            status=status,
        )

    @staticmethod
    def _get_metrics(
        metrics_per_feature: DataFrame,
    ) -> list[mm_results.ModelMonitoringApplicationMetric]:
        """Average each metric over the features"""
        metrics: list[mm_results.ModelMonitoringApplicationMetric] = []

        metrics_mean = metrics_per_feature.mean().to_dict()

        for name, value in metrics_mean.items():
            metrics.append(
                mm_results.ModelMonitoringApplicationMetric(
                    name=f"{name}_mean",
                    value=value,
                )
            )

        return metrics

    @staticmethod
    def _get_stats(
        metrics: list[mm_results.ModelMonitoringApplicationMetric],
        metrics_per_feature: DataFrame,
        monitoring_context: mm_context.MonitoringApplicationContext,
    ) -> list[mm_results._ModelMonitoringApplicationStats]:
        """
        Return a list of the statistics.
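        One entry is produced per stats type: the ``CURRENT_STATS`` entry holds the raw sample statistics,
        while the ``DRIFT_MEASURES`` entry merges the per-feature metrics with their computed means.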

        :param metrics: The calculated, averaged metrics.
        :param metrics_per_feature: The metrics calculated per feature.
        :param monitoring_context: The context object of the current monitoring application.
        :returns: A list of ``_ModelMonitoringApplicationStats`` objects for the histogram data drift application.
        """
        stats = []
        for stats_type in HistogramDataDriftApplication._STATS_TYPES:
            stats.append(
                mm_results._ModelMonitoringApplicationStats(
                    name=stats_type,
                    stats=metrics_per_feature.T.to_dict()
                    | {metric.name: metric.value for metric in metrics}
                    if stats_type == StatsKind.DRIFT_MEASURES
                    else monitoring_context.sample_df_stats,
                    timestamp=monitoring_context.end_infer_time.isoformat(
                        sep=" ", timespec="microseconds"
                    ),
                )
            )
        return stats

    @staticmethod
    def _get_shared_features_sample_stats(
        monitoring_context: mm_context.MonitoringApplicationContext,
    ) -> mlrun.common.model_monitoring.helpers.FeatureStats:
        """
        Filter out features without reference data in `feature_stats`, e.g. `timestamp`.
        """
        return mlrun.common.model_monitoring.helpers.FeatureStats(
            {
                key: monitoring_context.sample_df_stats[key]
                for key in monitoring_context.feature_stats
            }
        )

    @staticmethod
    def _log_json_artifact(
        drift_per_feature_values: Series,
        monitoring_context: mm_context.MonitoringApplicationContext,
    ) -> None:
        """Log the drift values as a JSON artifact"""
        monitoring_context.logger.debug("Logging drift value per feature JSON artifact")
        monitoring_context.log_artifact(
            mlrun.artifacts.Artifact(
                body=drift_per_feature_values.to_json(),
                format="json",
                key="features_drift_results",
            )
        )
        monitoring_context.logger.debug("Logged JSON artifact successfully")

    def _log_plotly_table_artifact(
        self,
        sample_set_statistics: mlrun.common.model_monitoring.helpers.FeatureStats,
        inputs_statistics: mlrun.common.model_monitoring.helpers.FeatureStats,
        metrics_per_feature: DataFrame,
        drift_per_feature_values: Series,
        monitoring_context: mm_context.MonitoringApplicationContext,
    ) -> None:
        """Log the Plotly drift table artifact"""
        monitoring_context.logger.debug(
            "Feature stats",
            sample_set_statistics=sample_set_statistics,
            inputs_statistics=inputs_statistics,
        )

        monitoring_context.logger.debug("Computing drift results per feature")
        drift_results = {
            cast(str, key): (self._value_classifier.value_to_status(value), value)
            for key, value in drift_per_feature_values.items()
        }
        monitoring_context.logger.debug("Producing plotly artifact")
        artifact = mm_drift_table.FeaturesDriftTablePlot().produce(
            sample_set_statistics=sample_set_statistics,
            inputs_statistics=inputs_statistics,
            metrics=metrics_per_feature.T.to_dict(),  # pyright: ignore[reportArgumentType]
            drift_results=drift_results,
        )
        monitoring_context.logger.debug("Logging plotly artifact")
        monitoring_context.log_artifact(artifact)
        monitoring_context.logger.debug("Logged plotly artifact successfully")

    def _log_drift_artifacts(
        self,
        monitoring_context: mm_context.MonitoringApplicationContext,
        metrics_per_feature: DataFrame,
    ) -> None:
        """Log JSON and Plotly drift data per feature artifacts"""
        if not self._produce_json_artifact and not self._produce_plotly_artifact:
            return

        drift_per_feature_values = metrics_per_feature[
            [HellingerDistance.NAME, TotalVarianceDistance.NAME]
        ].mean(axis=1)

        if self._produce_json_artifact:
            self._log_json_artifact(drift_per_feature_values, monitoring_context)

        if self._produce_plotly_artifact:
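            # Pass only the sample statistics of features that also have reference
            # statistics, so the drift table compares matching feature columns.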
            self._log_plotly_table_artifact(
                sample_set_statistics=self._get_shared_features_sample_stats(
                    monitoring_context
                ),
                inputs_statistics=monitoring_context.feature_stats,
                metrics_per_feature=metrics_per_feature,
                drift_per_feature_values=drift_per_feature_values,
                monitoring_context=monitoring_context,
            )

    def do_tracking(
        self, monitoring_context: mm_context.MonitoringApplicationContext
    ) -> list[
        Union[
            mm_results.ModelMonitoringApplicationResult,
            mm_results.ModelMonitoringApplicationMetric,
            mm_results._ModelMonitoringApplicationStats,
        ]
    ]:
        """
        Calculate and return the data drift metrics, averaged over the features.
        """
        monitoring_context.logger.debug("Starting to run the application")
        if not monitoring_context.feature_stats:
            monitoring_context.logger.warning(
                "No feature statistics found, skipping the application.\n"
                "In order to run the application, a training set must be provided when logging the model."
            )
            return []
        metrics_per_feature = self._compute_metrics_per_feature(
            monitoring_context=monitoring_context
        )
        monitoring_context.logger.debug("Saving artifacts")
        self._log_drift_artifacts(
            monitoring_context=monitoring_context,
            metrics_per_feature=metrics_per_feature,
        )
        monitoring_context.logger.debug("Computing average per metric")
        metrics = self._get_metrics(metrics_per_feature)
        result = self._get_general_drift_result(metrics=metrics)
        stats = self._get_stats(
            metrics=metrics,
            monitoring_context=monitoring_context,
            metrics_per_feature=metrics_per_feature,
        )
        metrics_result_and_stats = metrics + [result] + stats
        monitoring_context.logger.debug(
            "Finished running the application", results=metrics_result_and_stats
        )
        return metrics_result_and_stats
diff --git a/modules/src/histogram_data_drift/item.yaml b/modules/src/histogram_data_drift/item.yaml
new file mode 100644
index 000000000..09933d10b
--- /dev/null
+++ b/modules/src/histogram_data_drift/item.yaml
apiVersion: v1
categories:
- model-serving
description: Model-monitoring application for detecting and visualizing data drift.
example: histogram_data_drift.ipynb
generationDate: 2025-11-06
hidden: false
labels:
  author: Iguazio
mlrunVersion: 1.10.0-rc41
name: histogram_data_drift
spec:
  filename: histogram_data_drift.py
  image: mlrun/mlrun
  kind: monitoring_application
  requirements:
  - plotly~=5.23
  - pandas
version: 1.0.0
\ No newline at end of file
diff --git a/modules/src/histogram_data_drift/requirements.txt b/modules/src/histogram_data_drift/requirements.txt
new file mode 100644
index 000000000..4c3614d2b
--- /dev/null
+++ b/modules/src/histogram_data_drift/requirements.txt
hypothesis[numpy]~=6.103
plotly~=5.23
pandas
\ No newline at end of file
diff --git a/modules/src/histogram_data_drift/test_histogram_data_drift.py b/modules/src/histogram_data_drift/test_histogram_data_drift.py
new file mode 100644
index 000000000..018edaa86
--- /dev/null
+++ b/modules/src/histogram_data_drift/test_histogram_data_drift.py
# Copyright 2024 Iguazio
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path +from unittest.mock import Mock + +import pandas as pd +import pytest +from hypothesis import given +from hypothesis import strategies as st + +import mlrun.common.model_monitoring.helpers +import mlrun.model_monitoring.applications +import mlrun.model_monitoring.applications.context as mm_context +import mlrun.utils +from mlrun.common.schemas.model_monitoring.constants import ( + ResultKindApp, + ResultStatusApp, +) +from histogram_data_drift import ( + DataDriftClassifier, + HistogramDataDriftApplication, + InvalidMetricValueError, + InvalidThresholdValueError, +) + +assets_folder = Path(__file__).parent / "assets" + + +@pytest.fixture +def project(tmp_path: Path) -> mlrun.MlrunProject: + project = mlrun.get_or_create_project("temp", allow_cross_project=True) + project.artifact_path = str(tmp_path) + return project + + +@pytest.fixture +def application() -> HistogramDataDriftApplication: + app = HistogramDataDriftApplication( + produce_json_artifact=True, produce_plotly_artifact=True + ) + return app + + +@pytest.fixture +def logger() -> mlrun.utils.Logger: + return mlrun.utils.Logger(level=logging.DEBUG, name="test_histogram_data_drift_app") + + +class TestDataDriftClassifier: + @staticmethod + @pytest.mark.parametrize( + ("potential", "detected"), [(0.4, 0.2), (0.0, 0.5), (0.7, 1.0), (-1, 2)] + ) + def test_invalid_threshold(potential: float, detected: float) -> None: + with pytest.raises(InvalidThresholdValueError): + DataDriftClassifier(potential=potential, detected=detected) + + @staticmethod + @given( + st.one_of( + st.floats(max_value=0, exclude_max=True), + st.floats(min_value=1, exclude_min=True), + ) + ) + def test_invalid_metric(value: float) -> None: + with pytest.raises(InvalidMetricValueError): + DataDriftClassifier().value_to_status(value) + + @staticmethod + @pytest.fixture + def classifier() -> DataDriftClassifier: + return DataDriftClassifier(potential=0.5, detected=0.7) + + @staticmethod + @pytest.mark.parametrize( + ("value", "expected_status"), + [ + (0, ResultStatusApp.no_detection), + (0.2, ResultStatusApp.no_detection), + (0.5, ResultStatusApp.potential_detection), + (0.6, ResultStatusApp.potential_detection), + (0.71, ResultStatusApp.detected), + (1, ResultStatusApp.detected), + ], + ) + def test_status( + classifier: DataDriftClassifier, value: float, expected_status: ResultStatusApp + ) -> None: + assert ( + classifier.value_to_status(value) == expected_status + ), "The status is different than expected" + + +class TestApplication: + COUNT = 12 # the sample df size + + @classmethod + @pytest.fixture + def sample_df_stats(cls) -> mlrun.common.model_monitoring.helpers.FeatureStats: + return mlrun.common.model_monitoring.helpers.FeatureStats( + { + "timestamp": { + "count": cls.COUNT, + "25%": "2024-03-11 09:31:39.152301+00:00", + "50%": "2024-03-11 09:31:39.152301+00:00", + "75%": "2024-03-11 09:31:39.152301+00:00", + "max": "2024-03-11 09:31:39.152301+00:00", + "mean": "2024-03-11 09:31:39.152301+00:00", + "min": "2024-03-11 09:31:39.152301+00:00", + }, + "ticker": { + "count": cls.COUNT, + "unique": 1, + 
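                    # A categorical feature: it has no "hist" entry, so no drift
                    # metrics are computed for it (only histogram features are compared).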
"top": "AAPL", + "freq": cls.COUNT, + }, + "f1": { + "count": cls.COUNT, + "hist": [[2, 3, 0, 3, 1, 3], [-10, -5, 0, 5, 10, 15, 20]], + }, + "f2": { + "count": cls.COUNT, + "hist": [[0, 6, 0, 2, 1, 3], [66, 67, 68, 69, 70, 71, 72]], + }, + "l": { + "count": cls.COUNT, + "hist": [ + [10, 0, 0, 0, 0, 2], + [0.0, 0.16, 0.33, 0.5, 0.67, 0.83, 1.0], + ], + }, + } + ) + + @staticmethod + @pytest.fixture + def feature_stats() -> mlrun.common.model_monitoring.helpers.FeatureStats: + return mlrun.common.model_monitoring.helpers.FeatureStats( + { + "f1": { + "count": 100, + "hist": [[0, 0, 0, 30, 70, 0], [-10, -5, 0, 5, 10, 15, 20]], + }, + "f2": { + "count": 100, + "hist": [[0, 45, 5, 15, 35, 0], [66, 67, 68, 69, 70, 71, 72]], + }, + "l": { + "count": 100, + "hist": [ + [30, 0, 0, 0, 0, 70], + [0.0, 0.16, 0.33, 0.5, 0.67, 0.83, 1.0], + ], + }, + } + ) + + @staticmethod + @pytest.fixture + def monitoring_context( + sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats, + feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats, + application: HistogramDataDriftApplication, + logger: mlrun.utils.Logger, + project: mlrun.MlrunProject, + ) -> mm_context.MonitoringApplicationContext: + monitoring_context = mm_context.MonitoringApplicationContext( + application_name=application.NAME, + event={}, + artifacts_logger=project, + logger=logger, + project=project, + nuclio_logger=logger, # the wrong type but works here + ) + monitoring_context._sample_df_stats = sample_df_stats + monitoring_context._feature_stats = feature_stats + + return monitoring_context + + @classmethod + def test( + cls, + application: HistogramDataDriftApplication, + monitoring_context: mm_context.MonitoringApplicationContext, + project: mlrun.MlrunProject, + ) -> None: + results = application.do_tracking(monitoring_context) + metrics = [] + assert len(results) == 6, "Expected four results & metrics % stats" + for res in results: + if isinstance( + res, + mlrun.model_monitoring.applications.ModelMonitoringApplicationResult, + ): + assert ( + res.kind == ResultKindApp.data_drift + ), "The kind should be data drift" + assert ( + res.name == "general_drift" + ), "The result name should be general_drift" + assert ( + res.status == ResultStatusApp.potential_detection + ), "Expected potential detection in the general drift" + elif isinstance( + res, + mlrun.model_monitoring.applications.ModelMonitoringApplicationMetric, + ): + metrics.append(res) + assert len(metrics) == 3, "Expected three metrics" + + # Check the artifacts + assert project._artifact_manager.artifact_uris.keys() == { + "features_drift_results", + "drift_table_plot", + }, "The artifacts in the artifact manager are different than expected" + assert {f.name for f in Path(project.artifact_path).glob("*")} == { + "drift_table_plot.html", + "features_drift_results.json", + }, "The artifact files were not found or are different than expected" + + +class TestMetricsPerFeature: + @staticmethod + @pytest.fixture + def monitoring_context( + logger: mlrun.utils.Logger, + ) -> mm_context.MonitoringApplicationContext: + ctx = Mock() + + def dict_to_histogram(df: pd.DataFrame) -> pd.DataFrame: + return df + + ctx.dict_to_histogram = dict_to_histogram + ctx.logger = logger + return ctx + + @staticmethod + @pytest.mark.parametrize( + ("sample_df_stats", "feature_stats"), + [ + pytest.param(pd.DataFrame(), pd.DataFrame(), id="empty-dfs"), + pytest.param( + pd.read_csv(assets_folder / "sample_df_stats.csv", index_col=0), + pd.read_csv(assets_folder / "feature_stats.csv", 
index_col=0), + id="real-world-csv-dfs", + ), + ], + ) + def test_compute_metrics_per_feature( + application: HistogramDataDriftApplication, + monitoring_context: Mock, + sample_df_stats: pd.DataFrame, + feature_stats: pd.DataFrame, + ) -> None: + monitoring_context.sample_df_stats = sample_df_stats + monitoring_context.feature_stats = feature_stats + + metrics_per_feature = application._compute_metrics_per_feature( + monitoring_context=monitoring_context + ) + assert set(metrics_per_feature.columns) == { + metric.NAME for metric in application.metrics + }, "Different metrics than expected" + assert set(metrics_per_feature.index) == set( + feature_stats.columns + ), "The features are different than expected" From 94de979a0b015167e29df1760603bdeeb7b726f9 Mon Sep 17 00:00:00 2001 From: danielperezz Date: Sun, 9 Nov 2025 11:05:12 +0200 Subject: [PATCH 2/2] post review fixes --- modules/src/histogram_data_drift/item.yaml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/modules/src/histogram_data_drift/item.yaml b/modules/src/histogram_data_drift/item.yaml index 09933d10b..e439e1699 100644 --- a/modules/src/histogram_data_drift/item.yaml +++ b/modules/src/histogram_data_drift/item.yaml @@ -1,7 +1,8 @@ apiVersion: v1 categories: - model-serving -description: Model-monitoring application for detecting and visualizing data drift. +- structured-ML +description: Model-monitoring application for detecting and visualizing data drift example: histogram_data_drift.ipynb generationDate: 2025-11-06 hidden: false