Skip to content

Commit

Permalink
chore: add report_metrics.
Browse files Browse the repository at this point in the history
Signed-off-by: Electronic-Waste <2690692950@qq.com>
  • Loading branch information
Electronic-Waste committed Jun 24, 2024
1 parent 8bbac20 commit 1d9be47
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions sdk/python/v1beta1/kubeflow/katib/api/katib_report_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
# Copyright 2024 The Kubeflow Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from datetime import datetime
from typing import Any, Dict

import grpc
import pytz
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
from kubeflow.katib.constants import constants
from kubeflow.katib.utils import utils

def report_metrics(
    metrics: Dict[str, Any],
    db_manager_address: str = constants.DEFAULT_DB_MANAGER_ADDRESS,
    timeout: int = constants.DEFAULT_TIMEOUT,
) -> None:
    """Push metrics directly to the Katib DB.

    [!!!] The Trial name must always be passed into Katib Trials as the
    environment variable `KATIB_TRIAL_NAME`.

    Args:
        metrics: Dict of metrics pushed to Katib DB.
            For example, `metrics = {"loss": 0.01, "accuracy": 0.99}`.
            Values are converted with `str()` before being sent.
        db_manager_address: Address for the Katib DB Manager in this format:
            `ip-address:port`.
        timeout: Optional, gRPC API Server timeout in seconds to report metrics.

    Raises:
        ValueError: The Trial name is not passed to environment variables, or
            `db_manager_address` is not in the `ip-address:port` format.
        RuntimeError: Unable to push Trial metrics to Katib DB.
    """

    namespace = utils.get_current_k8s_namespace()
    trial_name = os.getenv("KATIB_TRIAL_NAME")
    if trial_name is None:
        raise ValueError(
            "The Trial name is not passed to environment variables"
        )

    # Split on the last colon so a missing port is reported clearly instead
    # of surfacing as a bare IndexError.
    host, separator, port = db_manager_address.rpartition(":")
    if not separator or not host:
        raise ValueError(
            "db_manager_address must be in the format `ip-address:port`, "
            f"got: {db_manager_address!r}"
        )
    try:
        port_number = int(port)
    except ValueError as e:
        raise ValueError(
            "db_manager_address must be in the format `ip-address:port`, "
            f"got: {db_manager_address!r}"
        ) from e

    channel = grpc.beta.implementations.insecure_channel(host, port_number)

    # `datetime.isoformat` only supports precision down to microseconds;
    # requesting "nanoseconds" raises ValueError.
    timestamp = datetime.now(pytz.UTC).isoformat(timespec="microseconds")

    with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
        try:
            # One MetricLog per metric, all sharing the same timestamp.
            # NOTE(review): `observation_log` (singular) and the string-typed
            # `Metric.value` follow the Katib v1beta1 api.proto; confirm
            # against the generated katib_api_pb2 module.
            metric_logs = [
                katib_api_pb2.MetricLog(
                    time_stamp=timestamp,
                    metric=katib_api_pb2.Metric(
                        name=metric_name, value=str(metric_value)
                    ),
                )
                for metric_name, metric_value in metrics.items()
            ]
            client.ReportObservationLog(
                request=katib_api_pb2.ReportObservationLogRequest(
                    trial_name=trial_name,
                    observation_log=katib_api_pb2.ObservationLog(
                        metric_logs=metric_logs
                    ),
                ),
                timeout=timeout,
            )
        except Exception as e:
            # Chain the original error so the gRPC/protobuf traceback is kept.
            raise RuntimeError(
                "Unable to push metrics to Katib DB for Trial "
                f"{namespace}/{trial_name}. Exception: {e}"
            ) from e

0 comments on commit 1d9be47

Please sign in to comment.