Skip to content

Commit

Permalink
[Model Monitoring] Add top_level and list_ids to filters to `list…
Browse files Browse the repository at this point in the history
…_endpoints` (#1459)
  • Loading branch information
katyakats committed Nov 3, 2021
1 parent 11ef6d9 commit 9667636
Show file tree
Hide file tree
Showing 10 changed files with 144 additions and 30 deletions.
12 changes: 11 additions & 1 deletion mlrun/api/api/endpoints/model_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,12 +101,16 @@ def list_endpoints(
start: str = Query(default="now-1h"),
end: str = Query(default="now"),
metrics: List[str] = Query([], alias="metric"),
top_level: bool = Query(False, alias="top-level"),
uids: List[str] = Query(None, alias="uid"),
auth_info: mlrun.api.schemas.AuthInfo = Depends(
mlrun.api.api.deps.authenticate_request
),
) -> ModelEndpointList:
"""
Returns a list of endpoints of type 'ModelEndpoint', supports filtering by model, function, tag and labels.
Returns a list of endpoints of type 'ModelEndpoint', supports filtering by model, function, tag,
labels or top level.
If uids are passed: will return ModelEndpointList of endpoints with uid in uids
Labels can be used to filter on the existence of a label:
api/projects/{project}/model-endpoints/?label=mylabel
Expand All @@ -118,10 +122,13 @@ def list_endpoints(
Or by using a "," (comma) separator:
api/projects/{project}/model-endpoints/?label=mylabel=1,myotherlabel=2
Top level: if true will return only routers and endpoint that are NOT children of any router
"""

mlrun.api.utils.auth.verifier.AuthVerifier().query_project_permissions(
project, mlrun.api.schemas.AuthorizationAction.read, auth_info,
)

endpoints = mlrun.api.crud.ModelEndpoints().list_endpoints(
auth_info=auth_info,
project=project,
Expand All @@ -131,13 +138,16 @@ def list_endpoints(
metrics=metrics,
start=start,
end=end,
top_level=top_level,
uids=uids,
)
allowed_endpoints = mlrun.api.utils.auth.verifier.AuthVerifier().filter_project_resources_by_permissions(
mlrun.api.schemas.AuthorizationResourceTypes.model_endpoint,
endpoints.endpoints,
lambda _endpoint: (_endpoint.metadata.project, _endpoint.metadata.uid,),
auth_info,
)

endpoints.endpoints = allowed_endpoints
return endpoints

Expand Down
73 changes: 51 additions & 22 deletions mlrun/api/crud/model_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from mlrun.runtimes.function import get_nuclio_deploy_status
from mlrun.utils.helpers import logger
from mlrun.utils.model_monitoring import (
EndpointType,
parse_model_endpoint_project_prefix,
parse_model_endpoint_store_prefix,
)
Expand Down Expand Up @@ -179,13 +180,17 @@ def list_endpoints(
metrics: Optional[List[str]] = None,
start: str = "now-1h",
end: str = "now",
top_level: Optional[bool] = False,
uids: Optional[List[str]] = None,
) -> ModelEndpointList:
"""
Returns a list of ModelEndpointState objects. Each object represents the current state of a model endpoint.
This functions supports filtering by the following parameters:
1) model
2) function
3) labels
4) top level
5) uids
By default, when no filters are applied, all available endpoints for the given project will be listed.
In addition, this functions provides a facade for listing endpoint related metrics. This facade is time-based
Expand All @@ -201,6 +206,8 @@ def list_endpoints(
:param metrics: A list of metrics to return for each endpoint, read more in 'TimeMetric'
:param start: The start time of the metrics
:param end: The end time of the metrics
:param top_level: if True will return only routers and endpoint that are NOT children of any router
:param uids: will return ModelEndpointList of endpoints with uid in uids
"""

logger.info(
Expand All @@ -212,34 +219,38 @@ def list_endpoints(
metrics=metrics,
start=start,
end=end,
top_level=top_level,
uids=uids,
)

client = get_v3io_client(endpoint=config.v3io_api)
endpoint_list = ModelEndpointList(endpoints=[])

path = config.model_endpoint_monitoring.store_prefixes.default.format(
project=project, kind=mlrun.api.schemas.ModelMonitoringStoreKinds.ENDPOINTS
)
_, container, path = parse_model_endpoint_store_prefix(path)
if uids is None:
client = get_v3io_client(endpoint=config.v3io_api)

cursor = client.kv.new_cursor(
container=container,
table_path=path,
access_key=auth_info.data_session,
filter_expression=self.build_kv_cursor_filter_expression(
project, function, model, labels
),
attribute_names=["endpoint_id"],
raise_for_status=RaiseForStatus.never,
)
path = config.model_endpoint_monitoring.store_prefixes.default.format(
project=project,
kind=mlrun.api.schemas.ModelMonitoringStoreKinds.ENDPOINTS,
)
_, container, path = parse_model_endpoint_store_prefix(path)
cursor = client.kv.new_cursor(
container=container,
table_path=path,
access_key=auth_info.data_session,
filter_expression=self.build_kv_cursor_filter_expression(
project, function, model, labels, top_level,
),
attribute_names=["endpoint_id"],
raise_for_status=RaiseForStatus.never,
)
try:
items = cursor.all()
except Exception:
return endpoint_list

endpoint_list = ModelEndpointList(endpoints=[])
try:
items = cursor.all()
except Exception:
return endpoint_list
uids = [item["endpoint_id"] for item in items]

for item in items:
endpoint_id = item["endpoint_id"]
for endpoint_id in uids:
endpoint = self.get_endpoint(
auth_info=auth_info,
project=project,
Expand Down Expand Up @@ -320,6 +331,12 @@ def get_endpoint(
monitor_configuration = endpoint.get("monitor_configuration")
monitor_configuration = self._json_loads_if_not_none(monitor_configuration)

endpoint_type = endpoint.get("endpoint_type")
endpoint_type = self._json_loads_if_not_none(endpoint_type)

children_uids = endpoint.get("children_uids")
children_uids = self._json_loads_if_not_none(children_uids)

endpoint = ModelEndpoint(
metadata=ModelEndpointMetadata(
project=endpoint.get("project"),
Expand Down Expand Up @@ -348,6 +365,8 @@ def get_endpoint(
accuracy=endpoint.get("accuracy") or None,
error_count=endpoint.get("error_count") or None,
drift_status=endpoint.get("drift_status") or None,
endpoint_type=endpoint_type or None,
children_uids=children_uids or None,
),
)

Expand Down Expand Up @@ -415,6 +434,8 @@ def write_endpoint_to_kv(
current_stats = endpoint.status.current_stats or {}
children = endpoint.status.children or []
monitor_configuration = endpoint.spec.monitor_configuration or {}
endpoint_type = endpoint.status.endpoint_type or None
children_uids = endpoint.status.children_uids or []

client = get_v3io_client(endpoint=config.v3io_api)
function = client.kv.update if update else client.kv.put
Expand Down Expand Up @@ -447,6 +468,8 @@ def write_endpoint_to_kv(
"children": json.dumps(children),
"label_names": json.dumps(label_names),
"monitor_configuration": json.dumps(monitor_configuration),
"endpoint_type": json.dumps(endpoint_type),
"children_uids": json.dumps(children_uids),
**searchable_labels,
},
)
Expand Down Expand Up @@ -689,6 +712,7 @@ def build_kv_cursor_filter_expression(
function: Optional[str] = None,
model: Optional[str] = None,
labels: Optional[List[str]] = None,
top_level: Optional[bool] = False,
):
if not project:
raise MLRunInvalidArgumentError("project can't be empty")
Expand All @@ -710,6 +734,11 @@ def build_kv_cursor_filter_expression(
filter_expression.append(f"{lbl}=='{value}'")
else:
filter_expression.append(f"exists({label})")
if top_level:
filter_expression.append(
f"(endpoint_type=='{str(EndpointType.NODE_EP.value)}' "
f"OR endpoint_type=='{str(EndpointType.ROUTER.value)}')"
)

return " AND ".join(filter_expression)

Expand Down
4 changes: 3 additions & 1 deletion mlrun/api/schemas/model_endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pydantic.main import Extra

from mlrun.api.schemas.object import ObjectKind, ObjectSpec, ObjectStatus
from mlrun.utils.model_monitoring import create_model_endpoint_id
from mlrun.utils.model_monitoring import EndpointType, create_model_endpoint_id


class ModelMonitoringStoreKinds:
Expand Down Expand Up @@ -96,6 +96,8 @@ class ModelEndpointStatus(ObjectStatus):
metrics: Optional[Dict[str, Metric]]
features: Optional[List[Features]]
children: Optional[List[str]]
children_uids: Optional[List[str]]
endpoint_type: Optional[EndpointType]

class Config:
extra = Extra.allow
Expand Down
6 changes: 6 additions & 0 deletions mlrun/db/httpdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2326,6 +2326,8 @@ def list_model_endpoints(
end: str = "now",
metrics: Optional[List[str]] = None,
access_key: Optional[str] = None,
top_level: bool = False,
uids: Optional[List[str]] = None,
) -> schemas.ModelEndpointList:
"""
Returns a list of ModelEndpointState objects. Each object represents the current state of a model endpoint.
Expand All @@ -2348,6 +2350,8 @@ def list_model_endpoints(
:param start: The start time of the metrics
:param end: The end time of the metrics
:param access_key: V3IO access key, when None, will be look for in environ
:param top_level: if true will return only routers and endpoint that are NOT children of any router
:param uids: if passed will return ModelEndpointList of endpoints with uid in uids
"""
access_key = access_key or os.environ.get("V3IO_ACCESS_KEY")
if not access_key:
Expand All @@ -2367,6 +2371,8 @@ def list_model_endpoints(
"start": start,
"end": end,
"metric": metrics or [],
"top-level": top_level,
"uid": uids,
},
headers={"X-V3io-Access-Key": access_key},
)
Expand Down
33 changes: 31 additions & 2 deletions mlrun/serving/routers.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
ModelEndpointStatus,
)
from ..config import config
from ..utils.model_monitoring import EndpointType
from .utils import RouterToDict, _extract_input_data, _update_result_body
from .v2_serving import _ModelLogPusher

Expand Down Expand Up @@ -338,6 +339,7 @@ def __init__(
self.log_router = True
self.prediction_col_name = prediction_col_name or "prediction"
self.format_response_with_col_name_flag = False
self.model_endpoint_uid = None

def post_init(self, mode="sync"):
server = getattr(self.context, "_server", None) or getattr(
Expand All @@ -348,7 +350,7 @@ def post_init(self, mode="sync"):
return

if not self.context.is_mock or self.context.server.track_models:
_init_endpoint_record(server, self)
self.model_endpoint_uid = _init_endpoint_record(server, self)

def _resolve_route(self, body, urlpath):
"""Resolves the appropriate model to send the event to.
Expand Down Expand Up @@ -697,6 +699,8 @@ def validate(self, request):
def _init_endpoint_record(graph_server, voting_ensemble: VotingEnsemble):
logger.info("Initializing endpoint records")

endpoint_uid = None

try:
project, uri, tag, hash_key = parse_versioned_object_uri(
graph_server.function_uri
Expand All @@ -707,6 +711,11 @@ def _init_endpoint_record(graph_server, voting_ensemble: VotingEnsemble):
else:
versioned_model_name = f"{voting_ensemble.name}:latest"

children_uids = []
for _, c in voting_ensemble.routes.items():
if hasattr(c, "endpoint_uid"):
children_uids.append(c.endpoint_uid)

model_endpoint = ModelEndpoint(
metadata=ModelEndpointMetadata(project=project),
spec=ModelEndpointSpec(
Expand All @@ -718,8 +727,13 @@ def _init_endpoint_record(graph_server, voting_ensemble: VotingEnsemble):
),
active=True,
),
status=ModelEndpointStatus(children=list(voting_ensemble.routes.keys())),
status=ModelEndpointStatus(
children=list(voting_ensemble.routes.keys()),
endpoint_type=EndpointType.ROUTER,
children_uids=children_uids,
),
)
endpoint_uid = model_endpoint.metadata.uid

db = mlrun.get_run_db()

Expand All @@ -728,12 +742,27 @@ def _init_endpoint_record(graph_server, voting_ensemble: VotingEnsemble):
endpoint_id=model_endpoint.metadata.uid,
model_endpoint=model_endpoint,
)

for model_endpoint in children_uids:
# here to update that it is a node now
current_endpoint = db.get_model_endpoint(
project=project, endpoint_id=model_endpoint
)
current_endpoint.status.endpoint_type = EndpointType.LEAF_EP

db.create_or_patch_model_endpoint(
project=project,
endpoint_id=model_endpoint,
model_endpoint=current_endpoint,
)

except Exception as exc:
logger.warning(
"Failed creating model endpoint record",
exc=exc,
traceback=traceback.format_exc(),
)
return endpoint_uid


class EnrichmentModelRouter(ModelRouter):
Expand Down
2 changes: 2 additions & 0 deletions mlrun/serving/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,8 @@ def clear_object(self):
def _post_init(self, mode="sync"):
if self._object and hasattr(self._object, "post_init"):
self._object.post_init(mode)
if hasattr(self._object, "model_endpoint_uid"):
self.endpoint_uid = self._object.model_endpoint_uid

def respond(self):
"""mark this step as the responder.
Expand Down

0 comments on commit 9667636

Please sign in to comment.