Skip to content

Commit

Permalink
Remove MLClient hint to save import time on evaluate API
Browse files: browse the repository at this point in the history
  • Loading branch information
ninghu committed Jun 11, 2024
1 parent 681daac commit 24a83d8
Showing 1 changed file with 59 additions and 93 deletions.
152 changes: 59 additions & 93 deletions src/promptflow-evals/promptflow/evals/evaluate/_eval_run.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,37 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------
from typing import Any, Dict, Optional, Type

import dataclasses
import json
import logging
import os
import posixpath
import requests
import time
import uuid
from typing import Any, Dict, Optional, Type
from urllib.parse import urlparse

from azure.ai.ml import MLClient
import requests
from azure.storage.blob import BlobClient
from requests.adapters import HTTPAdapter
from urllib.parse import urlparse
from urllib3.util.retry import Retry

from promptflow.evals._version import VERSION
import time

LOGGER = logging.getLogger(__name__)


@dataclasses.dataclass
class RunInfo():
class RunInfo:
"""
A holder for run info, needed for logging.
"""

run_id: str
experiment_id: str

@staticmethod
def generate() -> 'RunInfo':
def generate() -> "RunInfo":
"""
Generate the new RunInfo instance with the RunID and Experiment ID.
"""
Expand All @@ -44,6 +43,7 @@ def generate() -> 'RunInfo':

class Singleton(type):
"""Singleton class, which will be used as a metaclass."""

_instances = {}

def __call__(cls, *args, **kwargs):
Expand All @@ -63,7 +63,7 @@ def destroy(cls: Type) -> None:


class EvalRun(metaclass=Singleton):
'''
"""
The simple singleton run class, used for accessing artifact store.
:param run_name: The name of the run.
Expand All @@ -78,21 +78,22 @@ class EvalRun(metaclass=Singleton):
:type workspace_name: str
:param ml_client: The ml client used for authentication into Azure.
:type ml_client: MLClient
'''
"""

_MAX_RETRIES = 5
_BACKOFF_FACTOR = 2
_TIMEOUT = 5
_SCOPE = "https://management.azure.com/.default"

def __init__(self,
run_name: Optional[str],
tracking_uri: str,
subscription_id: str,
group_name: str,
workspace_name: str,
ml_client: MLClient
):
def __init__(
self,
run_name: Optional[str],
tracking_uri: str,
subscription_id: str,
group_name: str,
workspace_name: str,
ml_client,
):
"""
Constructor
"""
Expand All @@ -101,7 +102,7 @@ def __init__(self,
self._subscription_id: str = subscription_id
self._resource_group_name: str = group_name
self._workspace_name: str = workspace_name
self._ml_client: MLClient = ml_client
self._ml_client = ml_client
self._url_base = urlparse(self._tracking_uri).netloc
self._is_broken = self._start_run()
self._is_terminated = False
Expand All @@ -117,9 +118,7 @@ def _get_scope(self):
:rtype: str
"""
return (
"/subscriptions/{}/resourceGroups/{}/providers"
"/Microsoft.MachineLearningServices"
"/workspaces/{}"
"/subscriptions/{}/resourceGroups/{}/providers" "/Microsoft.MachineLearningServices" "/workspaces/{}"
).format(
self._subscription_id,
self._resource_group_name,
Expand All @@ -133,34 +132,25 @@ def _start_run(self) -> bool:
marked as broken and the logging will be switched off.
:returns: True if the run has started and False otherwise.
"""
url = (
f"https://{self._url_base}/mlflow/v2.0"
f"{self._get_scope()}/api/2.0/mlflow/runs/create")
url = f"https://{self._url_base}/mlflow/v2.0" f"{self._get_scope()}/api/2.0/mlflow/runs/create"
body = {
"experiment_id": "0",
"user_id": "promptflow-evals",
"start_time": int(time.time() * 1000),
"tags": [
{
"key": "mlflow.user",
"value": "promptflow-evals"
}
]
"tags": [{"key": "mlflow.user", "value": "promptflow-evals"}],
}
response = self.request_with_retry(
url=url,
method='POST',
json_dict=body
)
response = self.request_with_retry(url=url, method="POST", json_dict=body)
if response.status_code != 200:
self.info = RunInfo.generate()
LOGGER.error(f"The run failed to start: {response.status_code}: {response.text}."
"The results will be saved locally, but will not be logged to Azure.")
LOGGER.error(
f"The run failed to start: {response.status_code}: {response.text}."
"The results will be saved locally, but will not be logged to Azure."
)
return True
parsed_response = response.json()
self.info = RunInfo(
run_id=parsed_response['run']['info']['run_id'],
experiment_id=parsed_response['run']['info']['experiment_id'],
run_id=parsed_response["run"]["info"]["run_id"],
experiment_id=parsed_response["run"]["info"]["experiment_id"],
)
return False

Expand All @@ -174,28 +164,22 @@ def end_run(self, status: str) -> None:
"""
if status not in ("FINISHED", "FAILED", "KILLED"):
raise ValueError(
f"Incorrect terminal status {status}. "
"Valid statuses are \"FINISHED\", \"FAILED\" and \"KILLED\".")
f"Incorrect terminal status {status}. " 'Valid statuses are "FINISHED", "FAILED" and "KILLED".'
)
if self._is_terminated:
LOGGER.warning("Unable to stop run because it was already terminated.")
return
if self._is_broken:
LOGGER.error("Unable to stop run because the run failed to start.")
return
url = (
f"https://{self._url_base}/mlflow/v2.0"
f"{self._get_scope()}/api/2.0/mlflow/runs/update")
url = f"https://{self._url_base}/mlflow/v2.0" f"{self._get_scope()}/api/2.0/mlflow/runs/update"
body = {
"run_uuid": self.info.run_id,
"status": status,
"end_time": int(time.time() * 1000),
"run_id": self.info.run_id
"run_id": self.info.run_id,
}
response = self.request_with_retry(
url=url,
method='POST',
json_dict=body
)
response = self.request_with_retry(url=url, method="POST", json_dict=body)
if response.status_code != 200:
LOGGER.error("Unable to terminate the run.")
Singleton.destroy(EvalRun)
Expand All @@ -209,25 +193,20 @@ def get_run_history_uri(self) -> str:
f"https://{self._url_base}"
"/history/v1.0"
f"{self._get_scope()}"
f'/experimentids/{self.info.experiment_id}/runs/{self.info.run_id}'
f"/experimentids/{self.info.experiment_id}/runs/{self.info.run_id}"
)

def get_artifacts_uri(self) -> str:
"""
Returns the url to upload the artifacts.
"""
return self.get_run_history_uri() + '/artifacts/batch/metadata'
return self.get_run_history_uri() + "/artifacts/batch/metadata"

def get_metrics_url(self):
"""
Return the url needed to track the mlflow metrics.
"""
return (
f"https://{self._url_base}"
"/mlflow/v2.0"
f"{self._get_scope()}"
f'/api/2.0/mlflow/runs/log-metric'
)
return f"https://{self._url_base}" "/mlflow/v2.0" f"{self._get_scope()}" f"/api/2.0/mlflow/runs/log-metric"

def _get_token(self):
"""The simple method to get token from the MLClient."""
Expand All @@ -237,11 +216,7 @@ def _get_token(self):
return self._ml_client._credential.get_token(EvalRun._SCOPE)

def request_with_retry(
self,
url: str,
method: str,
json_dict: Dict[str, Any],
headers: Optional[Dict[str, str]] = None
self, url: str, method: str, json_dict: Dict[str, Any], headers: Optional[Dict[str, str]] = None
) -> requests.Response:
"""
Send the request with retries.
Expand All @@ -258,8 +233,8 @@ def request_with_retry(
"""
if headers is None:
headers = {}
headers['User-Agent'] = f'promptflow/{VERSION}'
headers['Authorization'] = f'Bearer {self._get_token().token}'
headers["User-Agent"] = f"promptflow/{VERSION}"
headers["Authorization"] = f"Bearer {self._get_token().token}"
retry = Retry(
total=EvalRun._MAX_RETRIES,
connect=EvalRun._MAX_RETRIES,
Expand All @@ -268,18 +243,12 @@ def request_with_retry(
status=EvalRun._MAX_RETRIES,
status_forcelist=(408, 429, 500, 502, 503, 504),
backoff_factor=EvalRun._BACKOFF_FACTOR,
allowed_methods=None
allowed_methods=None,
)
adapter = HTTPAdapter(max_retries=retry)
session = requests.Session()
session.mount("https://", adapter)
return session.request(
method,
url,
headers=headers,
json=json_dict,
timeout=EvalRun._TIMEOUT
)
return session.request(method, url, headers=headers, json=json_dict, timeout=EvalRun._TIMEOUT)

def _log_error(self, failed_op: str, response: requests.Response) -> None:
"""
Expand Down Expand Up @@ -318,42 +287,39 @@ def log_artifact(self, artifact_folder: str) -> None:
return
# First we will list the files and the appropriate remote paths for them.
upload_path = os.path.basename(os.path.normpath(artifact_folder))
remote_paths = {'paths': []}
remote_paths = {"paths": []}
local_paths = []

for (root, _, filenames) in os.walk(artifact_folder):
if root != artifact_folder:
rel_path = os.path.relpath(root, artifact_folder)
if rel_path != '.':
if rel_path != ".":
upload_path = posixpath.join(upload_path, rel_path)
for f in filenames:
remote_file_path = posixpath.join(upload_path, f)
remote_paths['paths'].append({'path': remote_file_path})
remote_paths["paths"].append({"path": remote_file_path})
local_file_path = os.path.join(root, f)
local_paths.append(local_file_path)
# Now we need to reserve the space for files in the artifact store.
headers = {
'Content-Type': "application/json",
'Accept': "application/json",
'Content-Length': str(len(json.dumps(remote_paths))),
'x-ms-client-request-id': str(uuid.uuid1()),
"Content-Type": "application/json",
"Accept": "application/json",
"Content-Length": str(len(json.dumps(remote_paths))),
"x-ms-client-request-id": str(uuid.uuid1()),
}
response = self.request_with_retry(
url=self.get_artifacts_uri(),
method='POST',
json_dict=remote_paths,
headers=headers
url=self.get_artifacts_uri(), method="POST", json_dict=remote_paths, headers=headers
)
if response.status_code != 200:
self._log_error("allocate Blob for the artifact", response)
return
empty_artifacts = response.json()['artifactContentInformation']
empty_artifacts = response.json()["artifactContentInformation"]
# The response from Azure contains the URLs with SAS, which allow us to upload the files to the
# artifact store.
for local, remote in zip(local_paths, remote_paths['paths']):
artifact_loc = empty_artifacts[remote['path']]
blob_client = BlobClient.from_blob_url(artifact_loc['contentUri'], max_single_put_size=32 * 1024 * 1024)
with open(local, 'rb') as fp:
for local, remote in zip(local_paths, remote_paths["paths"]):
artifact_loc = empty_artifacts[remote["path"]]
blob_client = BlobClient.from_blob_url(artifact_loc["contentUri"], max_single_put_size=32 * 1024 * 1024)
with open(local, "rb") as fp:
blob_client.upload_blob(fp)

def log_metric(self, key: str, value: float) -> None:
Expand All @@ -374,15 +340,15 @@ def log_metric(self, key: str, value: float) -> None:
"value": value,
"timestamp": int(time.time() * 1000),
"step": 0,
"run_id": self.info.run_id
"run_id": self.info.run_id,
}
response = self.request_with_retry(
url=self.get_metrics_url(),
method='POST',
method="POST",
json_dict=body,
)
if response.status_code != 200:
self._log_error('save metrics', response)
self._log_error("save metrics", response)

@staticmethod
def get_instance(*args, **kwargs) -> "EvalRun":
Expand Down

0 comments on commit 24a83d8

Please sign in to comment.