diff --git a/src/transformers/integrations.py b/src/transformers/integrations.py index e072de4c51b8..ded47f8fd1b5 100644 --- a/src/transformers/integrations.py +++ b/src/transformers/integrations.py @@ -96,6 +96,8 @@ def is_azureml_available(): def is_mlflow_available(): + if os.getenv("DISABLE_MLFLOW_INTEGRATION", "FALSE").upper() == "TRUE": + return False return importlib.util.find_spec("mlflow") is not None @@ -758,7 +760,8 @@ def on_log(self, args, state, control, logs=None, **kwargs): class MLflowCallback(TrainerCallback): """ - A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). + A [`TrainerCallback`] that sends the logs to [MLflow](https://www.mlflow.org/). Can be disabled by setting + environment variable `DISABLE_MLFLOW_INTEGRATION = TRUE`. """ def __init__(self): @@ -789,7 +792,8 @@ def setup(self, args, state, model): if log_artifacts in {"TRUE", "1"}: self._log_artifacts = True if state.is_world_process_zero: - self._ml_flow.start_run(run_name=args.run_name) + if self._ml_flow.active_run is None: + self._ml_flow.start_run(run_name=args.run_name) combined_dict = args.to_dict() if hasattr(model, "config") and model.config is not None: model_config = model.config.to_dict()