Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[API] Fix model monitoring endpoints sync/async separation #885

Merged
merged 2 commits into from
Apr 25, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
31 changes: 14 additions & 17 deletions mlrun/api/api/endpoints/grafana_proxy.py
Expand Up @@ -28,7 +28,7 @@


@router.get("/grafana-proxy/model-endpoints", status_code=HTTPStatus.OK.value)
async def grafana_proxy_model_endpoints_check_connection(request: Request):
def grafana_proxy_model_endpoints_check_connection(request: Request):
"""
Root of grafana proxy for the model-endpoints API, used for validating the model-endpoints data source
connectivity.
Expand Down Expand Up @@ -61,7 +61,7 @@ async def grafana_proxy_model_endpoints_query(
# checks again.
target_endpoint = query_parameters["target_endpoint"]
function = NAME_TO_QUERY_FUNCTION_DICTIONARY[target_endpoint]
result = await function(body, query_parameters, access_key)
result = await run_in_threadpool(function, body, query_parameters, access_key)
return result


Expand All @@ -86,20 +86,18 @@ async def grafana_proxy_model_endpoints_search(
# checks again.
target_endpoint = query_parameters["target_endpoint"]
function = NAME_TO_SEARCH_FUNCTION_DICTIONARY[target_endpoint]
result = await function(db_session)
result = await run_in_threadpool(function, db_session)
return result


async def grafana_list_projects(db_session: Session) -> List[str]:
def grafana_list_projects(db_session: Session) -> List[str]:
db = get_db()

projects_output = await run_in_threadpool(
db.list_projects, session=db_session, format_=Format.name_only,
)
projects_output = db.list_projects(session=db_session, format_=Format.name_only)
return projects_output.projects


async def grafana_list_endpoints(
def grafana_list_endpoints(
body: Dict[str, Any], query_parameters: Dict[str, str], access_key: str
) -> List[GrafanaTable]:
project = query_parameters.get("project")
Expand All @@ -118,7 +116,7 @@ async def grafana_list_endpoints(
start = body.get("rangeRaw", {}).get("start", "now-1h")
end = body.get("rangeRaw", {}).get("end", "now")

endpoint_list = await ModelEndpoints.list_endpoints(
endpoint_list = ModelEndpoints.list_endpoints(
access_key=access_key,
project=project,
model=model,
Expand Down Expand Up @@ -176,13 +174,13 @@ async def grafana_list_endpoints(
return [table]


async def grafana_individual_feature_analysis(
def grafana_individual_feature_analysis(
body: Dict[str, Any], query_parameters: Dict[str, str], access_key: str
):
endpoint_id = query_parameters.get("endpoint_id")
project = query_parameters.get("project")

endpoint = await ModelEndpoints.get_endpoint(
endpoint = ModelEndpoints.get_endpoint(
access_key=access_key,
project=project,
endpoint_id=endpoint_id,
Expand Down Expand Up @@ -229,13 +227,13 @@ async def grafana_individual_feature_analysis(
return [table]


async def grafana_overall_feature_analysis(
def grafana_overall_feature_analysis(
body: Dict[str, Any], query_parameters: Dict[str, str], access_key: str
):
endpoint_id = query_parameters.get("endpoint_id")
project = query_parameters.get("project")

endpoint = await ModelEndpoints.get_endpoint(
endpoint = ModelEndpoints.get_endpoint(
access_key=access_key,
project=project,
endpoint_id=endpoint_id,
Expand Down Expand Up @@ -266,15 +264,15 @@ async def grafana_overall_feature_analysis(
return [table]


async def grafana_incoming_features(
def grafana_incoming_features(
body: Dict[str, Any], query_parameters: Dict[str, str], access_key: str
):
endpoint_id = query_parameters.get("endpoint_id")
project = query_parameters.get("project")
start = body.get("rangeRaw", {}).get("from", "now-1h")
end = body.get("rangeRaw", {}).get("to", "now")

endpoint = await ModelEndpoints.get_endpoint(
endpoint = ModelEndpoints.get_endpoint(
access_key=access_key, project=project, endpoint_id=endpoint_id
)

Expand All @@ -298,8 +296,7 @@ async def grafana_incoming_features(
token=access_key, address=config.v3io_framesd, container=container,
)

data: pd.DataFrame = await run_in_threadpool(
client.read,
data: pd.DataFrame = client.read(
backend="tsdb",
table=path,
columns=feature_names,
Expand Down
20 changes: 10 additions & 10 deletions mlrun/api/api/endpoints/model_endpoints.py
Expand Up @@ -14,7 +14,7 @@
"/projects/{project}/model-endpoints/{endpoint_id}",
status_code=HTTPStatus.NO_CONTENT.value,
)
async def create_or_patch(
def create_or_patch(
request: Request, project: str, endpoint_id: str, model_endpoint: ModelEndpoint
) -> Response:
"""
Expand All @@ -30,8 +30,8 @@ async def create_or_patch(
f"Mismatch between endpoint_id {endpoint_id} and ModelEndpoint.metadata.uid {model_endpoint.metadata.uid}."
f"\nMake sure the supplied function_uri, and model are configured as intended"
)
await ModelEndpoints.create_or_patch(
access_key=access_key, model_endpoint=model_endpoint
ModelEndpoints.create_or_patch(
access_key=access_key, model_endpoint=model_endpoint,
)
return Response(status_code=HTTPStatus.NO_CONTENT.value)

Expand All @@ -40,21 +40,21 @@ async def create_or_patch(
"/projects/{project}/model-endpoints/{endpoint_id}",
status_code=HTTPStatus.NO_CONTENT.value,
)
async def delete_endpoint_record(
def delete_endpoint_record(
request: Request, project: str, endpoint_id: str
) -> Response:
"""
Clears endpoint record from KV by endpoint_id
"""
access_key = get_access_key(request.headers)
await ModelEndpoints.delete_endpoint_record(
access_key=access_key, project=project, endpoint_id=endpoint_id
ModelEndpoints.delete_endpoint_record(
access_key=access_key, project=project, endpoint_id=endpoint_id,
)
return Response(status_code=HTTPStatus.NO_CONTENT.value)


@router.get("/projects/{project}/model-endpoints", response_model=ModelEndpointList)
async def list_endpoints(
def list_endpoints(
request: Request,
project: str,
model: Optional[str] = Query(None),
Expand All @@ -79,7 +79,7 @@ async def list_endpoints(
api/projects/{project}/model-endpoints/?label=mylabel=1,myotherlabel=2
"""
access_key = get_access_key(request.headers)
endpoints = await ModelEndpoints.list_endpoints(
endpoints = ModelEndpoints.list_endpoints(
access_key=access_key,
project=project,
model=model,
Expand All @@ -95,7 +95,7 @@ async def list_endpoints(
@router.get(
"/projects/{project}/model-endpoints/{endpoint_id}", response_model=ModelEndpoint
)
async def get_endpoint(
def get_endpoint(
request: Request,
project: str,
endpoint_id: str,
Expand All @@ -105,7 +105,7 @@ async def get_endpoint(
feature_analysis: bool = Query(default=False),
) -> ModelEndpoint:
access_key = get_access_key(request.headers)
endpoint = await ModelEndpoints.get_endpoint(
endpoint = ModelEndpoints.get_endpoint(
access_key=access_key,
project=project,
endpoint_id=endpoint_id,
Expand Down
38 changes: 15 additions & 23 deletions mlrun/api/crud/model_endpoints.py
@@ -1,7 +1,6 @@
import json
from typing import Any, Dict, List, Mapping, Optional

from starlette.concurrency import run_in_threadpool
from v3io.dataplane import RaiseForStatus

from mlrun.api.schemas import (
Expand Down Expand Up @@ -30,7 +29,7 @@

class ModelEndpoints:
@staticmethod
async def create_or_patch(access_key: str, model_endpoint: ModelEndpoint):
def create_or_patch(access_key: str, model_endpoint: ModelEndpoint):
"""
Creates or patch a KV record with the given model_endpoint record

Expand All @@ -51,9 +50,8 @@ async def create_or_patch(access_key: str, model_endpoint: ModelEndpoint):
logger.info(
"Getting model object, inferring column names and collecting feature stats"
)
model_obj: tuple = await run_in_threadpool(
get_model, model_endpoint.spec.model_uri
)
model_obj: tuple = get_model(model_endpoint.spec.model_uri)

model_obj: ModelArtifact = model_obj[1]

if not model_endpoint.status.feature_stats:
Expand Down Expand Up @@ -107,7 +105,7 @@ async def create_or_patch(access_key: str, model_endpoint: ModelEndpoint):
# system
logger.info("Updating model endpoint", endpoint_id=model_endpoint.metadata.uid)

await write_endpoint_to_kv(
write_endpoint_to_kv(
access_key=access_key, endpoint=model_endpoint, update=True,
)

Expand All @@ -116,7 +114,7 @@ async def create_or_patch(access_key: str, model_endpoint: ModelEndpoint):
return model_endpoint

@staticmethod
async def delete_endpoint_record(access_key: str, project: str, endpoint_id: str):
def delete_endpoint_record(access_key: str, project: str, endpoint_id: str):
"""
Deletes the KV record of a given model endpoint, project nad endpoint_id are used for lookup

Expand All @@ -133,8 +131,7 @@ async def delete_endpoint_record(access_key: str, project: str, endpoint_id: str
)
_, container, path = parse_model_endpoint_store_prefix(path)

await run_in_threadpool(
client.kv.delete,
client.kv.delete(
container=container,
table_path=path,
key=endpoint_id,
Expand All @@ -144,7 +141,7 @@ async def delete_endpoint_record(access_key: str, project: str, endpoint_id: str
logger.info("Model endpoint table cleared", endpoint_id=endpoint_id)

@staticmethod
async def list_endpoints(
def list_endpoints(
access_key: str,
project: str,
model: Optional[str] = None,
Expand Down Expand Up @@ -211,7 +208,7 @@ async def list_endpoints(
if item is None:
break
endpoint_id = item["endpoint_id"]
endpoint = await ModelEndpoints.get_endpoint(
endpoint = ModelEndpoints.get_endpoint(
access_key=access_key,
project=project,
endpoint_id=endpoint_id,
Expand All @@ -223,7 +220,7 @@ async def list_endpoints(
return endpoint_list

@staticmethod
async def get_endpoint(
def get_endpoint(
access_key: str,
project: str,
endpoint_id: str,
Expand Down Expand Up @@ -256,8 +253,7 @@ async def get_endpoint(
)
_, container, path = parse_model_endpoint_store_prefix(path)

endpoint = await run_in_threadpool(
client.kv.get,
endpoint = client.kv.get(
container=container,
table_path=path,
key=endpoint_id,
Expand Down Expand Up @@ -330,7 +326,7 @@ async def get_endpoint(
endpoint.status.drift_measures = drift_measures

if metrics:
endpoint_metrics = await get_endpoint_metrics(
endpoint_metrics = get_endpoint_metrics(
access_key=access_key,
project=project,
endpoint_id=endpoint_id,
Expand All @@ -344,9 +340,7 @@ async def get_endpoint(
return endpoint


async def write_endpoint_to_kv(
access_key: str, endpoint: ModelEndpoint, update: bool = True
):
def write_endpoint_to_kv(access_key: str, endpoint: ModelEndpoint, update: bool = True):
"""
Writes endpoint data to KV, a prerequisite for initializing the monitoring process

Expand All @@ -372,8 +366,7 @@ async def write_endpoint_to_kv(
)
_, container, path = parse_model_endpoint_store_prefix(path)

await run_in_threadpool(
function,
function(
container=container,
table_path=path,
key=endpoint.metadata.uid,
Expand Down Expand Up @@ -411,7 +404,7 @@ def _clean_feature_name(feature_name):
return feature_name.replace(" ", "_").replace("(", "").replace(")", "")


async def get_endpoint_metrics(
def get_endpoint_metrics(
access_key: str,
project: str,
endpoint_id: str,
Expand All @@ -432,8 +425,7 @@ async def get_endpoint_metrics(
token=access_key, address=config.v3io_framesd, container=container,
)

data = await run_in_threadpool(
client.read,
data = client.read(
backend="tsdb",
table=path,
columns=["endpoint_id", *metrics],
Expand Down