In [4]:
import mlflow.sklearn
from mlflow import MlflowClient
from mlflow.exceptions import MlflowException
from mlflow.store.artifact.runs_artifact_repo import RunsArtifactRepository
from sklearn.ensemble import RandomForestRegressor

# Use a local SQLite database as the tracking backend (the model registry
# falls back to the tracking URI when no separate registry URI is set).
# Three slashes denote a path relative to the kernel's working directory.
mlflow.set_tracking_uri(uri="sqlite:///mlruns.db")

def print_models_info(mv):
    """Print a four-line summary for every model version in ``mv``.

    Args:
        mv: An iterable of model-version objects exposing ``name``,
            ``version``, ``run_id`` and ``current_stage`` attributes.
    """
    # (printed label, attribute read from each model version), in output order
    summary_fields = (
        ("name", "name"),
        ("latest version", "version"),
        ("run_id", "run_id"),
        ("current_stage", "current_stage"),
    )
    for model_version in mv:
        for label, attr in summary_fields:
            print(f"{label}: {getattr(model_version, attr)}")

# Artifact sub-path each run logs its model under. The `runs:/<run_id>/<path>`
# URIs used for registration below must name this same path. Previously the
# model was logged as "sklearn-model_test" while the registry URIs (and the
# later URI-resolution cell) used "sklearn-model", so every registered version
# pointed at an artifact that was never written.
MODEL_ARTIFACT_PATH = "sklearn-model"

# NOTE(review): with default settings, autolog also records the estimator's
# params and logs its own copy of the fitted model; the explicit
# log_params/log_model calls are kept so the registered artifact path is
# controlled here.
mlflow.sklearn.autolog()

# Two toy runs that differ only in n_estimators, fitted on a single sample.
with mlflow.start_run() as run1:
    params = {"n_estimators": 3, "random_state": 42}
    rfr = RandomForestRegressor(**params).fit([[0, 1]], [1])
    mlflow.log_params(params)
    mlflow.sklearn.log_model(rfr, artifact_path=MODEL_ARTIFACT_PATH)

with mlflow.start_run() as run2:
    params = {"n_estimators": 6, "random_state": 42}
    rfr = RandomForestRegressor(**params).fit([[0, 1]], [1])
    mlflow.log_params(params)
    mlflow.sklearn.log_model(rfr, artifact_path=MODEL_ARTIFACT_PATH)

# Register model name in the model registry. The sqlite database persists
# across kernel restarts, so a re-run must tolerate the name already existing;
# any other registry error is still raised.
name = "RandomForestRegression_test"
client = MlflowClient()
try:
    client.create_registered_model(name)
except MlflowException as exc:
    if exc.error_code != "RESOURCE_ALREADY_EXISTS":
        raise

# Create one new version per run, each pointing at that run's logged model.
for run_id in [run1.info.run_id, run2.info.run_id]:
    model_uri = "runs:/{}/{}".format(run_id, MODEL_ARTIFACT_PATH)
    mv = client.create_model_version(name, model_uri, run_id)
    print("model version {} created".format(mv.version))

2023/03/16 10:41:37 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: RandomForestRegression_test, version 1
2023/03/16 10:41:37 INFO mlflow.tracking._model_registry.client: Waiting up to 300 seconds for model version to finish creation.                     Model name: RandomForestRegression_test, version 2


model version 1 created
model version 2 created


In [6]:
# Fetch the newest version in each stage of the registered model and list
# their version numbers.
# NOTE(review): stage-based registry APIs such as get_latest_versions are
# deprecated in newer MLflow releases in favour of model aliases.
model_versions = client.get_latest_versions(name=name)
for version in model_versions:
    print(f"{version.version}")

2


In [13]:
# Source recorded by the registry for the first returned version. Presumably
# this is the `runs:/` URI that was passed to create_model_version — verify
# for your MLflow version. It is bound to its own name, rather than
# `model_src`, because `model_src` is reassigned to the resolved artifact
# location further down; it is displayed so the two can be compared.
registered_model_src = model_versions[0].source
registered_model_src

In [9]:
# Run that produced the first returned model version. It is displayed as the
# cell's last expression so the recorded output is reproduced on a fresh
# Run All; a bare assignment displays nothing.
run_id = model_versions[0].run_id
run_id

'9fb1c844a2e845d2b88872a5495d50b8'

In [15]:
# Translate run2's run-relative model URI into the concrete artifact
# location it refers to in the tracking store.
runs_uri = f"runs:/{run2.info.run_id}/sklearn-model"
model_src = RunsArtifactRepository.get_underlying_uri(runs_uri)

In [18]:
# Display the artifact location resolved from run2's `runs:/` model URI.
model_src

'./mlruns/0/9fb1c844a2e845d2b88872a5495d50b8/artifacts/sklearn-model'