diff --git a/mlflow/models/model.py b/mlflow/models/model.py index b4548ffce6c90..b34459f9b87d7 100644 --- a/mlflow/models/model.py +++ b/mlflow/models/model.py @@ -692,6 +692,21 @@ def log( ): _logger.warning(_LOG_MODEL_MISSING_SIGNATURE_WARNING) mlflow.tracking.fluent.log_artifacts(local_path, mlflow_model.artifact_path, run_id) + + # if the model_config kwarg is passed in, then log the model config as an params + if "model_config" in kwargs: + model_config = kwargs["model_config"] + if isinstance(model_config, str): + try: + with open(model_config) as f: + model_config = json.load(f) + except Exception as e: + _logger.warning("Failed to load model config from %s: %s", model_config, e) + try: + mlflow.tracking.fluent.log_params(model_config) + except Exception as e: + _logger.warning("Failed to log model config as params: %s", str(e)) + try: mlflow.tracking.fluent._record_logged_model(mlflow_model, run_id) except MlflowException: