diff --git a/qlib/workflow/exp.py b/qlib/workflow/exp.py index 08f429eb32..f80cb041ac 100644 --- a/qlib/workflow/exp.py +++ b/qlib/workflow/exp.py @@ -1,6 +1,7 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +from typing import Union import mlflow, logging from mlflow.entities import ViewType from mlflow.exceptions import MlflowException @@ -324,14 +325,21 @@ def delete_recorder(self, recorder_id=None, recorder_name=None): UNLIMITED = 50000 # FIXME: Mlflow can only list 50000 records at most!!!!!!! - def list_recorders(self, max_results=UNLIMITED, status=None): + def list_recorders(self, max_results: int=UNLIMITED, status: Union[str, None]=None): + """ + Parameters + ---------- + max_results : int + the number limitation of the results + status : str + the criteria based on status to filter results. + `None` indicates no filtering. + """ runs = self._client.search_runs(self.id, run_view_type=ViewType.ACTIVE_ONLY, max_results=max_results) recorders = dict() for i in range(len(runs)): recorder = MLflowRecorder(self.id, self._uri, mlflow_run=runs[i]) - if status is not None: - if recorder.status != status: - continue - recorders[runs[i].info.run_id] = recorder + if status is None or recorder.status == status: + recorders[runs[i].info.run_id] = recorder return recorders