Skip to content

Commit

Permalink
[SDK] Cache run db instance (#614)
Browse files Browse the repository at this point in the history
  • Loading branch information
yaronha committed Dec 23, 2020
1 parent a138b48 commit d89ed4f
Show file tree
Hide file tree
Showing 17 changed files with 52 additions and 35 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -558,7 +558,7 @@ from mlrun import get_run_db
# Get an MLRun DB object and connect to an MLRun database/API service.
# Specify the DB path (for example, './' for the current directory) or
# the API URL ('http://mlrun-api:8080' for the default configuration).
db = get_run_db('./').connect()
db = get_run_db('./')

# List all runs
db.list_runs('').show()
Expand Down
2 changes: 1 addition & 1 deletion docs/job-submission-and-tracking.md
Expand Up @@ -463,7 +463,7 @@ from mlrun import get_run_db
# Get an MLRun DB object and connect to an MLRun database/API service.
# Specify the DB path (for example, './' for the current directory) or
# the API URL ('http://mlrun-api:8080' for the default configuration).
db = get_run_db('./').connect()
db = get_run_db('./')

# List all runs
db.list_runs('').show()
Expand Down
2 changes: 1 addition & 1 deletion mlrun/__init__.py
Expand Up @@ -80,7 +80,7 @@ def set_environment(api_path: str = None, artifact_path: str = "", project: str
raise ValueError("DB/API path was not detected, please specify its address")

# check connectivity and load remote defaults
get_run_db().connect()
get_run_db()
if api_path:
environ["MLRUN_DBPATH"] = mlconf.dbpath

Expand Down
14 changes: 7 additions & 7 deletions mlrun/__main__.py
Expand Up @@ -524,7 +524,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args):
start = i.status.start_time.strftime("%b %d %H:%M:%S")
print("{:10} {:16} {:8} {}".format(state, start, task, name))
elif kind.startswith("runtime"):
mldb = get_run_db(db or mlconf.dbpath).connect()
mldb = get_run_db(db or mlconf.dbpath)
if name:
# the runtime identifier is its kind
runtime = mldb.get_runtime(kind=name, label_selector=selector)
Expand All @@ -533,7 +533,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args):
runtimes = mldb.list_runtimes(label_selector=selector)
print(dict_to_yaml(runtimes))
elif kind.startswith("run"):
mldb = get_run_db().connect()
mldb = get_run_db()
if name:
run = mldb.read_run(name, project=project)
print(dict_to_yaml(run))
Expand All @@ -550,7 +550,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args):
print(tabulate(df, headers="keys"))

elif kind.startswith("art"):
mldb = get_run_db().connect()
mldb = get_run_db()
artifacts = mldb.list_artifacts(name, project=project, tag=tag, labels=selector)
df = artifacts.to_df()[
["tree", "key", "iter", "kind", "path", "hash", "updated"]
Expand All @@ -560,7 +560,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args):
print(tabulate(df, headers="keys"))

elif kind.startswith("func"):
mldb = get_run_db().connect()
mldb = get_run_db()
if name:
f = mldb.get_function(name, project=project, tag=tag)
print(dict_to_yaml(f))
Expand Down Expand Up @@ -618,7 +618,7 @@ def version():
@click.option("--watch", "-w", is_flag=True, help="watch/follow log")
def logs(uid, project, offset, db, watch):
"""Get or watch task logs"""
mldb = get_run_db(db or mlconf.dbpath).connect()
mldb = get_run_db(db or mlconf.dbpath)
if mldb.kind == "http":
state = mldb.watch_log(uid, project, watch=watch, offset=offset)
else:
Expand Down Expand Up @@ -818,7 +818,7 @@ def clean(kind, object_id, api, label_selector, force, grace_period):
# Clean resources for specific job (by uid)
mlrun clean dask 15d04c19c2194c0a8efb26ea3017254b
"""
mldb = get_run_db(api or mlconf.dbpath).connect()
mldb = get_run_db(api or mlconf.dbpath)
if kind:
if object_id:
mldb.delete_runtime_object(
Expand Down Expand Up @@ -908,7 +908,7 @@ def func_url_to_runtime(func_url):
if func_url.startswith("db://"):
func_url = func_url[5:]
project, name, tag, hash_key = parse_function_uri(func_url)
mldb = get_run_db(mlconf.dbpath).connect()
mldb = get_run_db(mlconf.dbpath)
runtime = mldb.get_function(name, project, tag, hash_key)
else:
func_url = "function.yaml" if func_url == "." else func_url
Expand Down
2 changes: 1 addition & 1 deletion mlrun/config.py
Expand Up @@ -197,7 +197,7 @@ def dbpath(self, value):
import mlrun.db

# when dbpath is set we want to connect to it, which syncs its configuration to the client
mlrun.db.get_run_db(value).connect()
mlrun.db.get_run_db(value)


# Global configuration
Expand Down
2 changes: 1 addition & 1 deletion mlrun/datastore/datastore.py
Expand Up @@ -90,7 +90,7 @@ def set(self, secrets=None, db=None):

def _get_db(self):
if not self._db:
self._db = mlrun.get_run_db().connect(self._secrets)
self._db = mlrun.get_run_db(secrets=self._secrets)
return self._db

def from_dict(self, struct: dict):
Expand Down
21 changes: 19 additions & 2 deletions mlrun/db/__init__.py
Expand Up @@ -43,11 +43,26 @@ def get_httpdb_kwargs(host, username, password):
}


def get_run_db(url=""):
_run_db = None
_last_db_url = None


def get_run_db(url="", secrets=None, force_reconnect=False):
"""Returns the runtime database"""
global _run_db, _last_db_url

if not url:
url = get_or_set_dburl("./")

if (
_last_db_url is not None
and url == _last_db_url
and _run_db
and not force_reconnect
):
return _run_db
_last_db_url = url

parsed_url = urlparse(url)
scheme = parsed_url.scheme.lower()
kwargs = {}
Expand All @@ -68,4 +83,6 @@ def get_run_db(url=""):
else:
cls = SQLDB

return cls(url, **kwargs)
_run_db = cls(url, **kwargs)
_run_db.connect(secrets=secrets)
return _run_db
3 changes: 1 addition & 2 deletions mlrun/execution.py
Expand Up @@ -144,8 +144,7 @@ def set_logger_stream(self, stream):
def _init_dbs(self, rundb):
if rundb:
if isinstance(rundb, str):
self._rundb = get_run_db(rundb)
self._rundb.connect(self._secrets_manager)
self._rundb = get_run_db(rundb, secrets=self._secrets_manager)
else:
self._rundb = rundb
self._data_stores = store_manager.set(self._secrets_manager, db=self._rundb)
Expand Down
6 changes: 3 additions & 3 deletions mlrun/model.py
Expand Up @@ -563,7 +563,7 @@ def uid(self):
return self.metadata.uid

def state(self):
db = get_run_db().connect()
db = get_run_db()
run = db.read_run(
uid=self.metadata.uid,
project=self.metadata.project,
Expand All @@ -573,12 +573,12 @@ def state(self):
return get_in(run, "status.state", "unknown")

def show(self):
db = get_run_db().connect()
db = get_run_db()
db.list_runs(uid=self.metadata.uid, project=self.metadata.project).show()

def logs(self, watch=True, db=None):
if not db:
db = get_run_db().connect()
db = get_run_db()
if not db:
print("DB is not configured, cannot show logs")
return None
Expand Down
8 changes: 4 additions & 4 deletions mlrun/projects/project.py
Expand Up @@ -154,7 +154,7 @@ def _load_project_dir(context, name="", subpath=""):


def _load_project_from_db(url, secrets):
db = get_run_db().connect(secrets)
db = get_run_db(secrets=secrets)
project_name = url.replace("git://", "")
return db.get_project(project_name)

Expand Down Expand Up @@ -689,7 +689,7 @@ def register_artifacts(self):
def _get_artifact_manager(self):
if self._artifact_manager:
return self._artifact_manager
db = get_run_db().connect(self._secrets)
db = get_run_db(secrets=self._secrets)
sm = store_manager.set(self._secrets, db)
self._artifact_manager = ArtifactManager(sm, db)
return self._artifact_manager
Expand Down Expand Up @@ -1076,7 +1076,7 @@ def get_run_status(
if run_info:
status = run_info["run"].get("status")

mldb = get_run_db().connect(self._secrets)
mldb = get_run_db(secrets=self._secrets)
runs = mldb.list_runs(
project=self.metadata.name, labels=f"workflow={workflow_id}"
)
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def save(self, filepath=None):
self.save_to_db()

def save_to_db(self):
db = get_run_db().connect(self._secrets)
db = get_run_db(secrets=self._secrets)
db.store_project(self.metadata.name, self.to_dict())

def export(self, filepath=None):
Expand Down
8 changes: 4 additions & 4 deletions mlrun/run.py
Expand Up @@ -366,7 +366,7 @@ def import_function(url="", secrets=None, db=""):
if url.startswith("db://"):
url = url[5:]
project, name, tag, hash_key = parse_function_uri(url)
db = get_run_db(db or get_or_set_dburl()).connect(secrets)
db = get_run_db(db or get_or_set_dburl(), secrets=secrets)
runtime = db.get_function(name, project, tag, hash_key)
if not runtime:
raise KeyError("function {}:{} not found in the DB".format(name, tag))
Expand Down Expand Up @@ -763,7 +763,7 @@ def run_pipeline(
arguments = arguments or {}

if remote or url:
mldb = get_run_db(url).connect()
mldb = get_run_db(url)
if mldb.kind != "http":
raise ValueError(
"run pipeline require access to remote api-service"
Expand Down Expand Up @@ -828,7 +828,7 @@ def wait_for_pipeline_completion(
)

if remote:
mldb = get_run_db().connect()
mldb = get_run_db()

def get_pipeline_if_completed(run_id, namespace=namespace):
resp = mldb.get_pipeline(run_id, namespace=namespace)
Expand Down Expand Up @@ -889,7 +889,7 @@ def get_pipeline(run_id, namespace=None):
namespace = namespace or mlconf.namespace
remote = not get_k8s_helper(silent=True).is_running_inside_kubernetes_cluster()
if remote:
mldb = get_run_db().connect()
mldb = get_run_db()
if mldb.kind != "http":
raise ValueError(
"get pipeline require access to remote api-service"
Expand Down
2 changes: 1 addition & 1 deletion mlrun/runtimes/base.py
Expand Up @@ -197,7 +197,7 @@ def _get_db(self):
self._ensure_run_db()
if not self._db_conn:
if self.spec.rundb:
self._db_conn = get_run_db(self.spec.rundb).connect(self._secrets)
self._db_conn = get_run_db(self.spec.rundb, secrets=self._secrets)
return self._db_conn

def run(
Expand Down
2 changes: 1 addition & 1 deletion mlrun/runtimes/sparkjob.py
Expand Up @@ -174,7 +174,7 @@ def _default_image(self):
def deploy(self, watch=True, with_mlrun=True, skip_deployed=False, is_kfp=False):
"""deploy function, build container with dependencies"""
# connect will populate the config from the server config
get_run_db().connect()
get_run_db()
if not self.spec.build.base_image:
self.spec.build.base_image = self._default_image
return super().deploy(
Expand Down
2 changes: 1 addition & 1 deletion mlrun/runtimes/utils.py
Expand Up @@ -88,7 +88,7 @@ def resolve_mpijob_crd_version(api_context=False):
elif not in_k8s_cluster and not api_context:
# connect will populate the config from the server config
# TODO: something nicer
get_run_db().connect()
get_run_db()
mpijob_crd_version = config.mpijob_crd_version

# If resolution failed simply use default
Expand Down
7 changes: 4 additions & 3 deletions tests/api/conftest.py
Expand Up @@ -28,9 +28,6 @@ def db() -> Generator:
dsn = f"sqlite:///{db_file.name}?check_same_thread=false"
config.httpdb.dsn = dsn

# we're also running client code in tests
config.dbpath = dsn

# TODO: simplify this - it doesn't make sense to call 3 different functions to initialize the db
# we need to force a re-init of the engine because otherwise it is cached between tests
_init_engine(config.httpdb.dsn)
Expand All @@ -39,6 +36,10 @@ def db() -> Generator:
init_data(from_scratch=True)
initialize_db()
initialize_project_member()

# we're also running client code in tests, so set dbpath as well
# note that setting this attribute triggers a connection to the run db, so it must happen after the initialization
config.dbpath = dsn
yield create_session()
logger.info(f"Removing temp db file: {db_file.name}")
db_file.close()
Expand Down
2 changes: 1 addition & 1 deletion tests/rundb/test_rundb.py
Expand Up @@ -18,7 +18,7 @@


def get_db():
return mlrun.get_run_db(rundb_path).connect()
return mlrun.get_run_db(rundb_path)


#
Expand Down
2 changes: 1 addition & 1 deletion tests/test_run.py
Expand Up @@ -170,7 +170,7 @@ def test_local_no_context():
).run(spec)
verify_state(result)

db = get_run_db().connect()
db = get_run_db()
state, log = db.get_log(result.metadata.uid)
log = str(log)
print(state)
Expand Down

0 comments on commit d89ed4f

Please sign in to comment.