Skip to content

Commit

Permalink
[Model Monitoring] Improvements to mm-apps do_tracking API (#5474)
Browse files Browse the repository at this point in the history
  • Loading branch information
davesh0812 committed May 5, 2024
1 parent b0dbde7 commit f942608
Show file tree
Hide file tree
Showing 32 changed files with 1,797 additions and 699 deletions.
16 changes: 14 additions & 2 deletions docs/api/mlrun.model_monitoring.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,21 @@ mlrun.model_monitoring
.. automodule:: mlrun.model_monitoring.api
:members:

.. autoclass:: mlrun.model_monitoring.application.ModelMonitoringApplicationResult
.. autoclass:: mlrun.model_monitoring.applications.ModelMonitoringApplicationResult
:members:

.. autoclass:: mlrun.model_monitoring.application.ModelMonitoringApplicationBase
.. autoclass:: mlrun.model_monitoring.applications.ModelMonitoringApplicationMetric
:members:

.. autoclass:: mlrun.model_monitoring.applications.MonitoringApplicationContext
:members:
:exclude-members: sample_df, model_endpoint, feature_stats, sample_df_stats, feature_names, label_names, model

.. autoclass:: mlrun.model_monitoring.applications.ModelMonitoringApplicationBaseV2
:members:
:exclude-members: do

.. autoclass:: mlrun.model_monitoring.applications.ModelMonitoringApplicationBase
:members:
:exclude-members: do

2 changes: 2 additions & 0 deletions mlrun/common/schemas/model_monitoring/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,15 @@
FeatureSetFeatures,
FileTargetKind,
FunctionURI,
MetricData,
ModelEndpointTarget,
ModelMonitoringMode,
ModelMonitoringStoreKinds,
MonitoringFunctionNames,
ProjectSecretKeys,
PrometheusEndpoints,
PrometheusMetric,
ResultData,
SchedulingKeys,
TimeSeriesTarget,
VersionedModel,
Expand Down
24 changes: 21 additions & 3 deletions mlrun/common/schemas/model_monitoring/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,21 +99,39 @@ def entity(cls):

class ApplicationEvent:
    """String constants naming the fields of an application event.

    Every attribute's value is the lower-case spelling of its own name.

    NOTE(review): presumably these are the keys of the event dict passed to
    model-monitoring applications - confirm against the event producer.
    """

    APPLICATION_NAME = "application_name"
    START_INFER_TIME = "start_infer_time"
    END_INFER_TIME = "end_infer_time"
    LAST_REQUEST = "last_request"
    ENDPOINT_ID = "endpoint_id"
    OUTPUT_STREAM_URI = "output_stream_uri"
    MLRUN_CONTEXT = "mlrun_context"

    # Deprecated fields - TODO: delete in 1.9.0 (V1 app deprecation)
    # Each name is defined exactly once, in this group, so removing the
    # deprecated fields later means deleting only the lines below.
    SAMPLE_PARQUET_PATH = "sample_parquet_path"
    CURRENT_STATS = "current_stats"
    FEATURE_STATS = "feature_stats"


class WriterEvent(MonitoringStrEnum):
    """Field names of a writer event.

    NOTE(review): presumably these are the keys of the events consumed by the
    monitoring writer - confirm against the code that builds those events.
    """

    # Which application and endpoint the event refers to
    APPLICATION_NAME = "application_name"
    ENDPOINT_ID = "endpoint_id"

    # Start / end inference-time fields
    START_INFER_TIME = "start_infer_time"
    END_INFER_TIME = "end_infer_time"

    # Event discriminator and its payload
    EVENT_KIND = "event_kind"  # metric or result
    DATA = "data"


class WriterEventKind(MonitoringStrEnum):
    """The two kinds of writer event: a metric or a result."""

    METRIC = "metric"
    RESULT = "result"


class MetricData(MonitoringStrEnum):
    """Field names describing a single metric: its name and its value.

    NOTE(review): presumably the keys of the data payload of a metric-kind
    writer event - confirm against the writer.
    """

    METRIC_NAME = "metric_name"
    METRIC_VALUE = "metric_value"


class ResultData(MonitoringStrEnum):
RESULT_NAME = "result_name"
RESULT_VALUE = "result_value"
RESULT_KIND = "result_kind"
Expand Down
59 changes: 41 additions & 18 deletions mlrun/model_monitoring/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@

import mlrun.artifacts
import mlrun.common.helpers
import mlrun.common.schemas.model_monitoring.constants as mm_consts
import mlrun.common.schemas.model_monitoring.constants as mm_constants
import mlrun.feature_store
import mlrun.model_monitoring.application
import mlrun.model_monitoring.applications as mm_app
import mlrun.serving
from mlrun.data_types.infer import InferOptions, get_df_stats
from mlrun.utils import datetime_now, logger
Expand All @@ -48,7 +49,7 @@ def get_or_create_model_endpoint(
sample_set_statistics: dict[str, typing.Any] = None,
drift_threshold: float = None,
possible_drift_threshold: float = None,
monitoring_mode: mm_consts.ModelMonitoringMode = mm_consts.ModelMonitoringMode.disabled,
monitoring_mode: mm_constants.ModelMonitoringMode = mm_constants.ModelMonitoringMode.disabled,
db_session=None,
) -> ModelEndpoint:
"""
Expand Down Expand Up @@ -128,7 +129,7 @@ def record_results(
context: typing.Optional[mlrun.MLClientCtx] = None,
infer_results_df: typing.Optional[pd.DataFrame] = None,
sample_set_statistics: typing.Optional[dict[str, typing.Any]] = None,
monitoring_mode: mm_consts.ModelMonitoringMode = mm_consts.ModelMonitoringMode.enabled,
monitoring_mode: mm_constants.ModelMonitoringMode = mm_constants.ModelMonitoringMode.enabled,
# Deprecated arguments:
drift_threshold: typing.Optional[float] = None,
possible_drift_threshold: typing.Optional[float] = None,
Expand Down Expand Up @@ -282,7 +283,7 @@ def _model_endpoint_validations(
# drift and possible drift thresholds
if drift_threshold:
current_drift_threshold = model_endpoint.spec.monitor_configuration.get(
mm_consts.EventFieldType.DRIFT_DETECTED_THRESHOLD,
mm_constants.EventFieldType.DRIFT_DETECTED_THRESHOLD,
mlrun.mlconf.model_endpoint_monitoring.drift_thresholds.default.drift_detected,
)
if current_drift_threshold != drift_threshold:
Expand All @@ -293,7 +294,7 @@ def _model_endpoint_validations(

if possible_drift_threshold:
current_possible_drift_threshold = model_endpoint.spec.monitor_configuration.get(
mm_consts.EventFieldType.POSSIBLE_DRIFT_THRESHOLD,
mm_constants.EventFieldType.POSSIBLE_DRIFT_THRESHOLD,
mlrun.mlconf.model_endpoint_monitoring.drift_thresholds.default.possible_drift,
)
if current_possible_drift_threshold != possible_drift_threshold:
Expand Down Expand Up @@ -332,14 +333,14 @@ def write_monitoring_df(
)

# Modify the DataFrame to the required structure that will be used later by the monitoring batch job
if mm_consts.EventFieldType.TIMESTAMP not in infer_results_df.columns:
if mm_constants.EventFieldType.TIMESTAMP not in infer_results_df.columns:
# Initialize timestamp column with the current time
infer_results_df[mm_consts.EventFieldType.TIMESTAMP] = infer_datetime
infer_results_df[mm_constants.EventFieldType.TIMESTAMP] = infer_datetime

# `endpoint_id` is the monitoring feature set entity and therefore it should be defined as the df index before
# the ingest process
infer_results_df[mm_consts.EventFieldType.ENDPOINT_ID] = endpoint_id
infer_results_df.set_index(mm_consts.EventFieldType.ENDPOINT_ID, inplace=True)
infer_results_df[mm_constants.EventFieldType.ENDPOINT_ID] = endpoint_id
infer_results_df.set_index(mm_constants.EventFieldType.ENDPOINT_ID, inplace=True)

monitoring_feature_set.ingest(source=infer_results_df, overwrite=False)

Expand All @@ -355,7 +356,7 @@ def _generate_model_endpoint(
sample_set_statistics: dict[str, typing.Any],
drift_threshold: float,
possible_drift_threshold: float,
monitoring_mode: mm_consts.ModelMonitoringMode = mm_consts.ModelMonitoringMode.disabled,
monitoring_mode: mm_constants.ModelMonitoringMode = mm_constants.ModelMonitoringMode.disabled,
) -> ModelEndpoint:
"""
Write a new model endpoint record.
Expand Down Expand Up @@ -394,11 +395,11 @@ def _generate_model_endpoint(
model_endpoint.spec.model_class = "drift-analysis"
if drift_threshold:
model_endpoint.spec.monitor_configuration[
mm_consts.EventFieldType.DRIFT_DETECTED_THRESHOLD
mm_constants.EventFieldType.DRIFT_DETECTED_THRESHOLD
] = drift_threshold
if possible_drift_threshold:
model_endpoint.spec.monitor_configuration[
mm_consts.EventFieldType.POSSIBLE_DRIFT_THRESHOLD
mm_constants.EventFieldType.POSSIBLE_DRIFT_THRESHOLD
] = possible_drift_threshold

model_endpoint.spec.monitoring_mode = monitoring_mode
Expand Down Expand Up @@ -589,7 +590,10 @@ def _create_model_monitoring_function_base(
project: str,
func: typing.Union[str, None] = None,
application_class: typing.Union[
str, mlrun.model_monitoring.application.ModelMonitoringApplicationBase, None
str,
mlrun.model_monitoring.application.ModelMonitoringApplicationBase,
mm_app.ModelMonitoringApplicationBaseV2,
None,
] = None,
name: typing.Optional[str] = None,
image: typing.Optional[str] = None,
Expand All @@ -602,6 +606,20 @@ def _create_model_monitoring_function_base(
Note: this is an internal API only.
This function does not set the labels or mounts v3io.
"""
if isinstance(
application_class,
mlrun.model_monitoring.application.ModelMonitoringApplicationBase,
):
warnings.warn(
"The `ModelMonitoringApplicationBase` class is deprecated from version 1.7.0, "
"please use `ModelMonitoringApplicationBaseV2`. It will be removed in 1.9.0.",
FutureWarning,
)
if name in mm_constants.MonitoringFunctionNames.list():
raise mlrun.errors.MLRunInvalidArgumentError(
f"An application cannot have the following names: "
f"{mm_constants.MonitoringFunctionNames.list()}"
)
if func is None:
func = ""
func_obj = typing.cast(
Expand All @@ -618,14 +636,19 @@ def _create_model_monitoring_function_base(
),
)
graph = func_obj.set_topology(mlrun.serving.states.StepKinds.flow)
prepare_step = graph.to(
class_name="mlrun.model_monitoring.applications._application_steps._PrepareMonitoringEvent",
name="PrepareMonitoringEvent",
application_name=name,
)
if isinstance(application_class, str):
first_step = graph.to(class_name=application_class, **application_kwargs)
app_step = prepare_step.to(class_name=application_class, **application_kwargs)
else:
first_step = graph.to(class_name=application_class)
first_step.to(
class_name="mlrun.model_monitoring.application.PushToMonitoringWriter",
app_step = prepare_step.to(class_name=application_class)
app_step.to(
class_name="mlrun.model_monitoring.applications._application_steps._PushToMonitoringWriter",
name="PushToMonitoringWriter",
project=project,
writer_application_name=mm_consts.MonitoringFunctionNames.WRITER,
writer_application_name=mm_constants.MonitoringFunctionNames.WRITER,
).respond()
return func_obj

0 comments on commit f942608

Please sign in to comment.