From d89ed4f8e82df5f513492523aaf12f3c8efd951e Mon Sep 17 00:00:00 2001 From: Yaron Haviv Date: Wed, 23 Dec 2020 03:01:01 +0200 Subject: [PATCH] [SDK] Cache run db instance (#614) --- README.md | 2 +- docs/job-submission-and-tracking.md | 2 +- mlrun/__init__.py | 2 +- mlrun/__main__.py | 14 +++++++------- mlrun/config.py | 2 +- mlrun/datastore/datastore.py | 2 +- mlrun/db/__init__.py | 21 +++++++++++++++++++-- mlrun/execution.py | 3 +-- mlrun/model.py | 6 +++--- mlrun/projects/project.py | 8 ++++---- mlrun/run.py | 8 ++++---- mlrun/runtimes/base.py | 2 +- mlrun/runtimes/sparkjob.py | 2 +- mlrun/runtimes/utils.py | 2 +- tests/api/conftest.py | 7 ++++--- tests/rundb/test_rundb.py | 2 +- tests/test_run.py | 2 +- 17 files changed, 52 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 954328525de..889c8e48d5d 100644 --- a/README.md +++ b/README.md @@ -558,7 +558,7 @@ from mlrun import get_run_db # Get an MLRun DB object and connect to an MLRun database/API service. # Specify the DB path (for example, './' for the current directory) or # the API URL ('http://mlrun-api:8080' for the default configuration). -db = get_run_db('./').connect() +db = get_run_db('./') # List all runs db.list_runs('').show() diff --git a/docs/job-submission-and-tracking.md b/docs/job-submission-and-tracking.md index 6f1fbfd1e96..95b06bd7c43 100644 --- a/docs/job-submission-and-tracking.md +++ b/docs/job-submission-and-tracking.md @@ -463,7 +463,7 @@ from mlrun import get_run_db # Get an MLRun DB object and connect to an MLRun database/API service. # Specify the DB path (for example, './' for the current directory) or # the API URL ('http://mlrun-api:8080' for the default configuration). -db = get_run_db('./').connect() +db = get_run_db('./') # List all runs db.list_runs('').show() diff --git a/mlrun/__init__.py b/mlrun/__init__.py index 6b90725d77c..4a7e17b66a7 100644 --- a/mlrun/__init__.py +++ b/mlrun/__init__.py @@ -80,7 +80,7 @@ def set_environment(api_path: str = None, artifact_path: str = "", project: str raise ValueError("DB/API path was not detected, please specify its address") # check connectivity and load remote defaults - get_run_db().connect() + get_run_db() if api_path: environ["MLRUN_DBPATH"] = mlconf.dbpath diff --git a/mlrun/__main__.py b/mlrun/__main__.py index 4823164d1d3..a91c58b905c 100644 --- a/mlrun/__main__.py +++ b/mlrun/__main__.py @@ -524,7 +524,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args): start = i.status.start_time.strftime("%b %d %H:%M:%S") print("{:10} {:16} {:8} {}".format(state, start, task, name)) elif kind.startswith("runtime"): - mldb = get_run_db(db or mlconf.dbpath).connect() + mldb = get_run_db(db or mlconf.dbpath) if name: # the runtime identifier is its kind runtime = mldb.get_runtime(kind=name, label_selector=selector) @@ -533,7 +533,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args): runtimes = mldb.list_runtimes(label_selector=selector) print(dict_to_yaml(runtimes)) elif kind.startswith("run"): - mldb = get_run_db().connect() + mldb = get_run_db() if name: run = mldb.read_run(name, project=project) print(dict_to_yaml(run)) @@ -550,7 +550,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args): print(tabulate(df, headers="keys")) elif kind.startswith("art"): - mldb = get_run_db().connect() + mldb = get_run_db() artifacts = mldb.list_artifacts(name, project=project, tag=tag, labels=selector) df = artifacts.to_df()[ ["tree", "key", "iter", "kind", "path", "hash", "updated"] @@ -560,7 +560,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args): print(tabulate(df, headers="keys")) elif kind.startswith("func"): - mldb = get_run_db().connect() + mldb = get_run_db() if name: f = mldb.get_function(name, project=project, tag=tag) print(dict_to_yaml(f)) @@ -618,7 +618,7 @@ def version(): @click.option("--watch", "-w", is_flag=True, help="watch/follow log") def logs(uid, project, offset, db, watch): """Get or watch task logs""" - mldb = get_run_db(db or mlconf.dbpath).connect() + mldb = get_run_db(db or mlconf.dbpath) if mldb.kind == "http": state = mldb.watch_log(uid, project, watch=watch, offset=offset) else: @@ -818,7 +818,7 @@ def clean(kind, object_id, api, label_selector, force, grace_period): # Clean resources for specific job (by uid) mlrun clean dask 15d04c19c2194c0a8efb26ea3017254b """ - mldb = get_run_db(api or mlconf.dbpath).connect() + mldb = get_run_db(api or mlconf.dbpath) if kind: if object_id: mldb.delete_runtime_object( @@ -908,7 +908,7 @@ def func_url_to_runtime(func_url): if func_url.startswith("db://"): func_url = func_url[5:] project, name, tag, hash_key = parse_function_uri(func_url) - mldb = get_run_db(mlconf.dbpath).connect() + mldb = get_run_db(mlconf.dbpath) runtime = mldb.get_function(name, project, tag, hash_key) else: func_url = "function.yaml" if func_url == "." else func_url diff --git a/mlrun/config.py b/mlrun/config.py index c89f84deeb9..d7902a57bd1 100644 --- a/mlrun/config.py +++ b/mlrun/config.py @@ -197,7 +197,7 @@ def dbpath(self, value): import mlrun.db # when dbpath is set we want to connect to it which will sync configuration from it to the client - mlrun.db.get_run_db(value).connect() + mlrun.db.get_run_db(value) # Global configuration diff --git a/mlrun/datastore/datastore.py b/mlrun/datastore/datastore.py index 0a7d36871a8..1337822bc9d 100644 --- a/mlrun/datastore/datastore.py +++ b/mlrun/datastore/datastore.py @@ -90,7 +90,7 @@ def set(self, secrets=None, db=None): def _get_db(self): if not self._db: - self._db = mlrun.get_run_db().connect(self._secrets) + self._db = mlrun.get_run_db(secrets=self._secrets) return self._db def from_dict(self, struct: dict): diff --git a/mlrun/db/__init__.py b/mlrun/db/__init__.py index f0681581ac7..5e0fae5edd6 100644 --- a/mlrun/db/__init__.py +++ b/mlrun/db/__init__.py @@ -43,11 +43,26 @@ def get_httpdb_kwargs(host, username, password): } -def get_run_db(url=""): +_run_db = None +_last_db_url = None + + +def get_run_db(url="", secrets=None, force_reconnect=False): """Returns the runtime database""" + global _run_db, _last_db_url + if not url: url = get_or_set_dburl("./") + if ( + _last_db_url is not None + and url == _last_db_url + and _run_db + and not force_reconnect + ): + return _run_db + _last_db_url = url + parsed_url = urlparse(url) scheme = parsed_url.scheme.lower() kwargs = {} @@ -68,4 +83,6 @@ def get_run_db(url=""): else: cls = SQLDB - return cls(url, **kwargs) + _run_db = cls(url, **kwargs) + _run_db.connect(secrets=secrets) + return _run_db diff --git a/mlrun/execution.py b/mlrun/execution.py index ec75ae76ec8..1391c0c6481 100644 --- a/mlrun/execution.py +++ b/mlrun/execution.py @@ -144,8 +144,7 @@ def set_logger_stream(self, stream): def _init_dbs(self, rundb): if rundb: if isinstance(rundb, str): - self._rundb = get_run_db(rundb) - self._rundb.connect(self._secrets_manager) + self._rundb = get_run_db(rundb, secrets=self._secrets_manager) else: self._rundb = rundb self._data_stores = store_manager.set(self._secrets_manager, db=self._rundb) diff --git a/mlrun/model.py b/mlrun/model.py index ecc120a555b..5fee047884a 100644 --- a/mlrun/model.py +++ b/mlrun/model.py @@ -563,7 +563,7 @@ def uid(self): return self.metadata.uid def state(self): - db = get_run_db().connect() + db = get_run_db() run = db.read_run( uid=self.metadata.uid, project=self.metadata.project, @@ -573,12 +573,12 @@ def state(self): return get_in(run, "status.state", "unknown") def show(self): - db = get_run_db().connect() + db = get_run_db() db.list_runs(uid=self.metadata.uid, project=self.metadata.project).show() def logs(self, watch=True, db=None): if not db: - db = get_run_db().connect() + db = get_run_db() if not db: print("DB is not configured, cannot show logs") return None diff --git a/mlrun/projects/project.py b/mlrun/projects/project.py index daa726cd086..4aad8352f58 100644 --- a/mlrun/projects/project.py +++ b/mlrun/projects/project.py @@ -154,7 +154,7 @@ def _load_project_dir(context, name="", subpath=""): def _load_project_from_db(url, secrets): - db = get_run_db().connect(secrets) + db = get_run_db(secrets=secrets) project_name = url.replace("git://", "") return db.get_project(project_name) @@ -689,7 +689,7 @@ def register_artifacts(self): def _get_artifact_manager(self): if self._artifact_manager: return self._artifact_manager - db = get_run_db().connect(self._secrets) + db = get_run_db(secrets=self._secrets) sm = store_manager.set(self._secrets, db) self._artifact_manager = ArtifactManager(sm, db) return self._artifact_manager @@ -1076,7 +1076,7 @@ def get_run_status( if run_info: status = run_info["run"].get("status") - mldb = get_run_db().connect(self._secrets) + mldb = get_run_db(secrets=self._secrets) runs = mldb.list_runs( project=self.metadata.name, labels=f"workflow={workflow_id}" ) @@ -1110,7 +1110,7 @@ def save(self, filepath=None): self.save_to_db() def save_to_db(self): - db = get_run_db().connect(self._secrets) + db = get_run_db(secrets=self._secrets) db.store_project(self.metadata.name, self.to_dict()) def export(self, filepath=None): diff --git a/mlrun/run.py b/mlrun/run.py index 30536296406..5a6a528d48e 100644 --- a/mlrun/run.py +++ b/mlrun/run.py @@ -366,7 +366,7 @@ def import_function(url="", secrets=None, db=""): if url.startswith("db://"): url = url[5:] project, name, tag, hash_key = parse_function_uri(url) - db = get_run_db(db or get_or_set_dburl()).connect(secrets) + db = get_run_db(db or get_or_set_dburl(), secrets=secrets) runtime = db.get_function(name, project, tag, hash_key) if not runtime: raise KeyError("function {}:{} not found in the DB".format(name, tag)) @@ -763,7 +763,7 @@ def run_pipeline( arguments = arguments or {} if remote or url: - mldb = get_run_db(url).connect() + mldb = get_run_db(url) if mldb.kind != "http": raise ValueError( "run pipeline require access to remote api-service" @@ -828,7 +828,7 @@ def wait_for_pipeline_completion( ) if remote: - mldb = get_run_db().connect() + mldb = get_run_db() def get_pipeline_if_completed(run_id, namespace=namespace): resp = mldb.get_pipeline(run_id, namespace=namespace) @@ -889,7 +889,7 @@ def get_pipeline(run_id, namespace=None): namespace = namespace or mlconf.namespace remote = not get_k8s_helper(silent=True).is_running_inside_kubernetes_cluster() if remote: - mldb = get_run_db().connect() + mldb = get_run_db() if mldb.kind != "http": raise ValueError( "get pipeline require access to remote api-service" diff --git a/mlrun/runtimes/base.py b/mlrun/runtimes/base.py index 09ef433abca..949156ac838 100644 --- a/mlrun/runtimes/base.py +++ b/mlrun/runtimes/base.py @@ -197,7 +197,7 @@ def _get_db(self): self._ensure_run_db() if not self._db_conn: if self.spec.rundb: - self._db_conn = get_run_db(self.spec.rundb).connect(self._secrets) + self._db_conn = get_run_db(self.spec.rundb, secrets=self._secrets) return self._db_conn def run( diff --git a/mlrun/runtimes/sparkjob.py b/mlrun/runtimes/sparkjob.py index 6cb12bea6e4..f813986953f 100644 --- a/mlrun/runtimes/sparkjob.py +++ b/mlrun/runtimes/sparkjob.py @@ -174,7 +174,7 @@ def _default_image(self): def deploy(self, watch=True, with_mlrun=True, skip_deployed=False, is_kfp=False): """deploy function, build container with dependencies""" # connect will populate the config from the server config - get_run_db().connect() + get_run_db() if not self.spec.build.base_image: self.spec.build.base_image = self._default_image return super().deploy( diff --git a/mlrun/runtimes/utils.py b/mlrun/runtimes/utils.py index d3a696c906d..8d3469e568a 100644 --- a/mlrun/runtimes/utils.py +++ b/mlrun/runtimes/utils.py @@ -88,7 +88,7 @@ def resolve_mpijob_crd_version(api_context=False): elif not in_k8s_cluster and not api_context: # connect will populate the config from the server config # TODO: something nicer - get_run_db().connect() + get_run_db() mpijob_crd_version = config.mpijob_crd_version # If resolution failed simply use default diff --git a/tests/api/conftest.py b/tests/api/conftest.py index c823020eb72..e7c0d1dd910 100644 --- a/tests/api/conftest.py +++ b/tests/api/conftest.py @@ -28,9 +28,6 @@ def db() -> Generator: dsn = f"sqlite:///{db_file.name}?check_same_thread=false" config.httpdb.dsn = dsn - # we're also running client code in tests - config.dbpath = dsn - # TODO: make it simpler - doesn't make sense to call 3 different functions to initialize the db # we need to force re-init the engine cause otherwise it is cached between tests _init_engine(config.httpdb.dsn) @@ -39,6 +36,10 @@ def db() -> Generator: init_data(from_scratch=True) initialize_db() initialize_project_member() + + # we're also running client code in tests so set dbpath as well + # note that setting this attribute triggers connection to the run db therefore must happen after the initialization + config.dbpath = dsn yield create_session() logger.info(f"Removing temp db file: {db_file.name}") db_file.close() diff --git a/tests/rundb/test_rundb.py b/tests/rundb/test_rundb.py index 2d072b32aa6..81f991abbff 100644 --- a/tests/rundb/test_rundb.py +++ b/tests/rundb/test_rundb.py @@ -18,7 +18,7 @@ def get_db(): - return mlrun.get_run_db(rundb_path).connect() + return mlrun.get_run_db(rundb_path) # diff --git a/tests/test_run.py b/tests/test_run.py index 32fc6967224..91757115abd 100644 --- a/tests/test_run.py +++ b/tests/test_run.py @@ -170,7 +170,7 @@ def test_local_no_context(): ).run(spec) verify_state(result) - db = get_run_db().connect() + db = get_run_db() state, log = db.get_log(result.metadata.uid) log = str(log) print(state)