Skip to content

Commit

Permalink
[Model Monitoring] Log artifacts in the histogram data drift app (#5280)
Browse files Browse the repository at this point in the history
  • Loading branch information
jond01 committed Mar 19, 2024
1 parent cf03af8 commit 0f21baf
Show file tree
Hide file tree
Showing 14 changed files with 414 additions and 203 deletions.
14 changes: 9 additions & 5 deletions mlrun/artifacts/plots.py
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import base64
import typing
from io import BytesIO

from deprecated import deprecated
Expand All @@ -21,6 +22,9 @@
from ..utils import dict_to_json
from .base import Artifact, LegacyArtifact

if typing.TYPE_CHECKING:
from plotly.graph_objs import Figure


class PlotArtifact(Artifact):
kind = "plot"
Expand Down Expand Up @@ -207,10 +211,10 @@ class PlotlyArtifact(Artifact):

def __init__(
self,
figure=None,
key: str = None,
target_path: str = None,
):
figure: typing.Optional["Figure"] = None,
key: typing.Optional[str] = None,
target_path: typing.Optional[str] = None,
) -> None:
"""
Initialize a Plotly artifact with the given figure.
Expand Down Expand Up @@ -247,7 +251,7 @@ def __init__(
self._figure = figure
self.spec.format = "html"

def get_body(self):
def get_body(self) -> str:
"""
Get the artifact's body - the Plotly figure's html code.
Expand Down
18 changes: 6 additions & 12 deletions mlrun/model_monitoring/api.py
Expand Up @@ -704,8 +704,8 @@ def perform_drift_analysis(
drift_detected_threshold=drift_threshold,
)

# Drift table plot
html_plot = FeaturesDriftTablePlot().produce(
# Drift table artifact
plotly_artifact = FeaturesDriftTablePlot().produce(
sample_set_statistics=sample_set_statistics,
inputs_statistics=inputs_statistics,
metrics=metrics,
Expand All @@ -732,7 +732,7 @@ def perform_drift_analysis(
# Log the different artifacts
_log_drift_artifacts(
context=context,
html_plot=html_plot,
plotly_artifact=plotly_artifact,
metrics_per_feature=metrics_per_feature,
drift_status=drift_status,
drift_metric=drift_metric,
Expand All @@ -742,7 +742,7 @@ def perform_drift_analysis(

def _log_drift_artifacts(
context: mlrun.MLClientCtx,
html_plot: str,
plotly_artifact: mlrun.artifacts.Artifact,
metrics_per_feature: dict[str, float],
drift_status: bool,
drift_metric: float,
Expand All @@ -755,20 +755,14 @@ def _log_drift_artifacts(
3 - Results of the total drift analysis
:param context: MLRun context. Will log the artifacts.
:param html_plot: Body of the html file of the plot.
:param plotly_artifact: The plotly artifact.
:param metrics_per_feature: Dictionary in which the key is a feature name and the value is the drift numerical
result.
:param drift_status: Boolean value that represents the final drift analysis result.
:param drift_metric: The final drift numerical result.
:param artifacts_tag: Tag to use for all the artifacts resulted from the function.
"""
context.log_artifact(
mlrun.artifacts.Artifact(
body=html_plot.encode("utf-8"), format="html", key="drift_table_plot"
),
tag=artifacts_tag,
)
context.log_artifact(plotly_artifact, tag=artifacts_tag)
context.log_artifact(
mlrun.artifacts.Artifact(
body=json.dumps(metrics_per_feature),
Expand Down
42 changes: 21 additions & 21 deletions mlrun/model_monitoring/application.py
Expand Up @@ -16,13 +16,13 @@
import json
import re
from abc import ABC, abstractmethod
from typing import Any, Optional, Union
from typing import Any, Optional, Union, cast

import numpy as np
import pandas as pd

import mlrun.common.helpers
import mlrun.common.schemas.model_monitoring
import mlrun.common.model_monitoring.helpers
import mlrun.common.schemas.model_monitoring.constants as mm_constant
import mlrun.utils.v3io_clients
from mlrun.datastore import get_stream_pusher
Expand Down Expand Up @@ -84,16 +84,16 @@ class ModelMonitoringApplicationBase(StepToDict, ABC):
class MyApp(ApplicationBase):
def do_tracking(
self,
sample_df_stats: pd.DataFrame,
feature_stats: pd.DataFrame,
sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats,
feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats,
start_infer_time: pd.Timestamp,
end_infer_time: pd.Timestamp,
schedule_time: pd.Timestamp,
latest_request: pd.Timestamp,
endpoint_id: str,
output_stream_uri: str,
) -> ModelMonitoringApplicationResult:
self.context.log_artifact(TableArtifact("sample_df_stats", df=sample_df_stats))
self.context.log_artifact(TableArtifact("sample_df_stats", df=self.dict_to_histogram(sample_df_stats)))
return ModelMonitoringApplicationResult(
name="data_drift_test",
value=0.5,
Expand Down Expand Up @@ -126,14 +126,16 @@ def do(
return results, event

def _lazy_init(self, app_name: str):
self.context = self._create_context_for_logging(app_name=app_name)
self.context = cast(
mlrun.MLClientCtx, self._create_context_for_logging(app_name=app_name)
)

@abstractmethod
def do_tracking(
self,
application_name: str,
sample_df_stats: pd.DataFrame,
feature_stats: pd.DataFrame,
sample_df_stats: mlrun.common.model_monitoring.helpers.FeatureStats,
feature_stats: mlrun.common.model_monitoring.helpers.FeatureStats,
sample_df: pd.DataFrame,
start_infer_time: pd.Timestamp,
end_infer_time: pd.Timestamp,
Expand All @@ -147,8 +149,8 @@ def do_tracking(
Implement this method with your custom monitoring logic.
:param application_name: (str) the app name
:param sample_df_stats: (pd.DataFrame) The new sample distribution DataFrame.
:param feature_stats: (pd.DataFrame) The train sample distribution DataFrame.
:param sample_df_stats: (FeatureStats) The new sample distribution dictionary.
:param feature_stats: (FeatureStats) The train sample distribution dictionary.
:param sample_df: (pd.DataFrame) The new sample DataFrame.
:param start_infer_time: (pd.Timestamp) Start time of the monitoring schedule.
:param end_infer_time: (pd.Timestamp) End time of the monitoring schedule.
Expand All @@ -167,8 +169,8 @@ def _resolve_event(
event: dict[str, Any],
) -> tuple[
str,
pd.DataFrame,
pd.DataFrame,
mlrun.common.model_monitoring.helpers.FeatureStats,
mlrun.common.model_monitoring.helpers.FeatureStats,
pd.DataFrame,
pd.Timestamp,
pd.Timestamp,
Expand All @@ -184,8 +186,8 @@ def _resolve_event(
:return: A tuple of:
[0] = (str) application name
[1] = (pd.DataFrame) current input statistics
[2] = (pd.DataFrame) train statistics
[1] = (dict) current input statistics
[2] = (dict) train statistics
[3] = (pd.DataFrame) current input data
[4] = (pd.Timestamp) start time of the monitoring schedule
[5] = (pd.Timestamp) end time of the monitoring schedule
Expand All @@ -197,12 +199,8 @@ def _resolve_event(
end_time = pd.Timestamp(event[mm_constant.ApplicationEvent.END_INFER_TIME])
return (
event[mm_constant.ApplicationEvent.APPLICATION_NAME],
cls._dict_to_histogram(
json.loads(event[mm_constant.ApplicationEvent.CURRENT_STATS])
),
cls._dict_to_histogram(
json.loads(event[mm_constant.ApplicationEvent.FEATURE_STATS])
),
json.loads(event[mm_constant.ApplicationEvent.CURRENT_STATS]),
json.loads(event[mm_constant.ApplicationEvent.FEATURE_STATS]),
ParquetTarget(
path=event[mm_constant.ApplicationEvent.SAMPLE_PARQUET_PATH]
).as_df(start_time=start_time, end_time=end_time, time_column="timestamp"),
Expand All @@ -223,7 +221,9 @@ def _create_context_for_logging(app_name: str):
return context

@staticmethod
def _dict_to_histogram(histogram_dict: dict[str, dict[str, Any]]) -> pd.DataFrame:
def dict_to_histogram(
histogram_dict: mlrun.common.model_monitoring.helpers.FeatureStats,
) -> pd.DataFrame:
"""
Convert histogram dictionary to pandas DataFrame with feature histograms as columns
Expand Down

0 comments on commit 0f21baf

Please sign in to comment.