# Provenance: contents of
# sdk/python/v1beta1/kubeflow/katib/api/katib_report_metrics.py as introduced
# by patch 1d9be470dbf9499c6a94070572956ff844b1c66e
# ("chore: add report_metrics.", Mon, 24 Jun 2024 23:01:56 +0800).
# Signed-off-by: Electronic-Waste <2690692950@qq.com>

# Copyright 2024 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from datetime import datetime
from typing import Any, Dict

import grpc
import pytz

import kubeflow.katib.katib_api_pb2 as katib_api_pb2
from kubeflow.katib.constants import constants
from kubeflow.katib.utils import utils


def report_metrics(
    metrics: Dict[str, Any],
    db_manager_address: str = constants.DEFAULT_DB_MANAGER_ADDRESS,
    timeout: int = constants.DEFAULT_TIMEOUT,
) -> None:
    """Push metrics directly to the Katib DB.

    The Trial name must always be passed into Katib Trials as the
    `KATIB_TRIAL_NAME` environment variable.

    Args:
        metrics: Dict of metrics pushed to Katib DB.
            For example, `metrics = {"loss": 0.01, "accuracy": 0.99}`.
            Every value is sent as its `str()` representation.
        db_manager_address: Address for the Katib DB Manager in this format:
            `ip-address:port`.
        timeout: Optional, gRPC API Server timeout in seconds to report metrics.

    Raises:
        ValueError: The Trial name is not passed to environment variables, or
            `db_manager_address` is not in the `ip-address:port` format.
        RuntimeError: Unable to push Trial metrics to Katib DB.
    """

    namespace = utils.get_current_k8s_namespace()
    trial_name = os.getenv("KATIB_TRIAL_NAME")
    if trial_name is None:
        raise ValueError(
            "The Trial name is not passed to environment variables"
        )

    # Split on the last colon so only the trailing segment is treated as the
    # port, and fail with a clear message instead of an IndexError.
    host, _, port = db_manager_address.rpartition(":")
    if not host or not port.isdigit():
        raise ValueError(
            f"Invalid Katib DB Manager address {db_manager_address!r}, "
            "expected format `ip-address:port`"
        )
    channel = grpc.beta.implementations.insecure_channel(host, int(port))

    with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
        # Request construction stays inside the `try` so that any failure while
        # building or sending it still surfaces as RuntimeError, as before.
        try:
            # `isoformat` supports at most microsecond precision; asking for
            # "nanoseconds" raises ValueError.
            timestamp = datetime.now(pytz.UTC).isoformat(
                timespec="microseconds"
            )
            metric_logs = [
                katib_api_pb2.MetricLog(
                    time_stamp=timestamp,
                    # Assumes Metric.value is a string field in the Katib API
                    # proto (not visible in this file), so numeric values are
                    # serialized explicitly.
                    metric=katib_api_pb2.Metric(
                        name=metric_name, value=str(metric_value)
                    ),
                )
                for metric_name, metric_value in metrics.items()
            ]
            client.ReportObservationLog(
                request=katib_api_pb2.ReportObservationLogRequest(
                    trial_name=trial_name,
                    observation_logs=katib_api_pb2.ObservationLog(
                        metric_logs=metric_logs
                    ),
                ),
                timeout=timeout,
            )
        except Exception as e:
            raise RuntimeError(
                f"Unable to push metrics to Katib DB for Trial {namespace}/{trial_name}. Exception: {e}"
            ) from e