Skip to content

Commit

Permalink
Make function name more clear and refactor the logic
Browse files Browse the repository at this point in the history
  • Loading branch information
helloworld1 committed Apr 1, 2024
1 parent 0c49528 commit a126d3c
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/transformers/integrations/integration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,17 +131,23 @@ def is_mlflow_available():
return importlib.util.find_spec("mlflow") is not None


def get_mlflow_version():
def is_mlflow_async_log_available():
# MLFlow can also be provided from mlflow-skinny package, which has the same versioning of mlflow package.
for mlflow_package_name in ["mlflow", "mlflow-skinny"]:
try:
return importlib.metadata.version(mlflow_package_name)
mlflow_version = importlib.metadata.version(mlflow_package_name)

# "synchronous" flag is only available with mlflow version >= 2.8.0
# https://github.com/mlflow/mlflow/pull/9705
# https://github.com/mlflow/mlflow/releases/tag/v2.8.0
if packaging.version.parse(mlflow_version) >= packaging.version.parse("2.8.0"):
return True
except importlib.metadata.PackageNotFoundError:
# We will try different mlflow package candidates
pass

# Unable to determine the mlflow version
return None
# If MLFlow version cannot be determined, fallback to not doing async log to be safe
return False


def is_dagshub_available():
Expand Down Expand Up @@ -1011,13 +1017,7 @@ def setup(self, args, state, model):
self._experiment_name = os.getenv("MLFLOW_EXPERIMENT_NAME", None)
self._flatten_params = os.getenv("MLFLOW_FLATTEN_PARAMS", "FALSE").upper() in ENV_VARS_TRUE_VALUES
self._run_id = os.getenv("MLFLOW_RUN_ID", None)
self._async_log = False
# "synchronous" flag is only available with mlflow version >= 2.8.0
# https://github.com/mlflow/mlflow/pull/9705
# https://github.com/mlflow/mlflow/releases/tag/v2.8.0
mlflow_version = get_mlflow_version()
if mlflow_version is not None and packaging.version.parse(mlflow_version) >= packaging.version.parse("2.8.0"):
self._async_log = True
self._async_log = is_mlflow_async_log_available()
logger.debug(
f"MLflow experiment_name={self._experiment_name}, run_name={args.run_name}, nested={self._nested_run},"
f" tags={self._nested_run}, tracking_uri={self._tracking_uri}"
Expand Down

0 comments on commit a126d3c

Please sign in to comment.