Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SDK] Cache run db instance #614

Merged
merged 2 commits into from
Dec 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,7 @@ from mlrun import get_run_db
# Get an MLRun DB object and connect to an MLRun database/API service.
# Specify the DB path (for example, './' for the current directory) or
# the API URL ('http://mlrun-api:8080' for the default configuration).
db = get_run_db('./').connect()
db = get_run_db('./')

# List all runs
db.list_runs('').show()
Expand Down
2 changes: 1 addition & 1 deletion docs/job-submission-and-tracking.md
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ from mlrun import get_run_db
# Get an MLRun DB object and connect to an MLRun database/API service.
# Specify the DB path (for example, './' for the current directory) or
# the API URL ('http://mlrun-api:8080' for the default configuration).
db = get_run_db('./').connect()
db = get_run_db('./')

# List all runs
db.list_runs('').show()
Expand Down
2 changes: 1 addition & 1 deletion mlrun/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def set_environment(api_path: str = None, artifact_path: str = "", project: str
raise ValueError("DB/API path was not detected, please specify its address")

# check connectivity and load remote defaults
get_run_db().connect()
get_run_db()
if api_path:
environ["MLRUN_DBPATH"] = mlconf.dbpath

Expand Down
14 changes: 7 additions & 7 deletions mlrun/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args):
start = i.status.start_time.strftime("%b %d %H:%M:%S")
print("{:10} {:16} {:8} {}".format(state, start, task, name))
elif kind.startswith("runtime"):
mldb = get_run_db(db or mlconf.dbpath).connect()
mldb = get_run_db(db or mlconf.dbpath)
if name:
# the runtime identifier is its kind
runtime = mldb.get_runtime(kind=name, label_selector=selector)
Expand All @@ -533,7 +533,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args):
runtimes = mldb.list_runtimes(label_selector=selector)
print(dict_to_yaml(runtimes))
elif kind.startswith("run"):
mldb = get_run_db().connect()
mldb = get_run_db()
if name:
run = mldb.read_run(name, project=project)
print(dict_to_yaml(run))
Expand All @@ -550,7 +550,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args):
print(tabulate(df, headers="keys"))

elif kind.startswith("art"):
mldb = get_run_db().connect()
mldb = get_run_db()
artifacts = mldb.list_artifacts(name, project=project, tag=tag, labels=selector)
df = artifacts.to_df()[
["tree", "key", "iter", "kind", "path", "hash", "updated"]
Expand All @@ -560,7 +560,7 @@ def get(kind, name, selector, namespace, uid, project, tag, db, extra_args):
print(tabulate(df, headers="keys"))

elif kind.startswith("func"):
mldb = get_run_db().connect()
mldb = get_run_db()
if name:
f = mldb.get_function(name, project=project, tag=tag)
print(dict_to_yaml(f))
Expand Down Expand Up @@ -618,7 +618,7 @@ def version():
@click.option("--watch", "-w", is_flag=True, help="watch/follow log")
def logs(uid, project, offset, db, watch):
"""Get or watch task logs"""
mldb = get_run_db(db or mlconf.dbpath).connect()
mldb = get_run_db(db or mlconf.dbpath)
if mldb.kind == "http":
state = mldb.watch_log(uid, project, watch=watch, offset=offset)
else:
Expand Down Expand Up @@ -818,7 +818,7 @@ def clean(kind, object_id, api, label_selector, force, grace_period):
# Clean resources for specific job (by uid)
mlrun clean dask 15d04c19c2194c0a8efb26ea3017254b
"""
mldb = get_run_db(api or mlconf.dbpath).connect()
mldb = get_run_db(api or mlconf.dbpath)
if kind:
if object_id:
mldb.delete_runtime_object(
Expand Down Expand Up @@ -908,7 +908,7 @@ def func_url_to_runtime(func_url):
if func_url.startswith("db://"):
func_url = func_url[5:]
project, name, tag, hash_key = parse_function_uri(func_url)
mldb = get_run_db(mlconf.dbpath).connect()
mldb = get_run_db(mlconf.dbpath)
runtime = mldb.get_function(name, project, tag, hash_key)
else:
func_url = "function.yaml" if func_url == "." else func_url
Expand Down
2 changes: 1 addition & 1 deletion mlrun/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def dbpath(self, value):
import mlrun.db

# when dbpath is set we want to connect to it which will sync configuration from it to the client
mlrun.db.get_run_db(value).connect()
mlrun.db.get_run_db(value)


# Global configuration
Expand Down
2 changes: 1 addition & 1 deletion mlrun/datastore/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def set(self, secrets=None, db=None):

def _get_db(self):
if not self._db:
self._db = mlrun.get_run_db().connect(self._secrets)
self._db = mlrun.get_run_db(secrets=self._secrets)
return self._db

def from_dict(self, struct: dict):
Expand Down
21 changes: 19 additions & 2 deletions mlrun/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,26 @@ def get_httpdb_kwargs(host, username, password):
}


def get_run_db(url=""):
_run_db = None
_last_db_url = None


def get_run_db(url="", secrets=None, force_reconnect=False):
"""Returns the runtime database"""
global _run_db, _last_db_url

if not url:
url = get_or_set_dburl("./")

if (
_last_db_url is not None
and url == _last_db_url
and _run_db
and not force_reconnect
):
return _run_db
_last_db_url = url

parsed_url = urlparse(url)
scheme = parsed_url.scheme.lower()
kwargs = {}
Expand All @@ -68,4 +83,6 @@ def get_run_db(url=""):
else:
cls = SQLDB

return cls(url, **kwargs)
_run_db = cls(url, **kwargs)
_run_db.connect(secrets=secrets)
return _run_db
3 changes: 1 addition & 2 deletions mlrun/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,7 @@ def set_logger_stream(self, stream):
def _init_dbs(self, rundb):
if rundb:
if isinstance(rundb, str):
self._rundb = get_run_db(rundb)
self._rundb.connect(self._secrets_manager)
self._rundb = get_run_db(rundb, secrets=self._secrets_manager)
else:
self._rundb = rundb
self._data_stores = store_manager.set(self._secrets_manager, db=self._rundb)
Expand Down
6 changes: 3 additions & 3 deletions mlrun/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ def uid(self):
return self.metadata.uid

def state(self):
db = get_run_db().connect()
db = get_run_db()
run = db.read_run(
uid=self.metadata.uid,
project=self.metadata.project,
Expand All @@ -573,12 +573,12 @@ def state(self):
return get_in(run, "status.state", "unknown")

def show(self):
db = get_run_db().connect()
db = get_run_db()
db.list_runs(uid=self.metadata.uid, project=self.metadata.project).show()

def logs(self, watch=True, db=None):
if not db:
db = get_run_db().connect()
db = get_run_db()
if not db:
print("DB is not configured, cannot show logs")
return None
Expand Down
8 changes: 4 additions & 4 deletions mlrun/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _load_project_dir(context, name="", subpath=""):


def _load_project_from_db(url, secrets):
db = get_run_db().connect(secrets)
db = get_run_db(secrets=secrets)
project_name = url.replace("git://", "")
return db.get_project(project_name)

Expand Down Expand Up @@ -689,7 +689,7 @@ def register_artifacts(self):
def _get_artifact_manager(self):
if self._artifact_manager:
return self._artifact_manager
db = get_run_db().connect(self._secrets)
db = get_run_db(secrets=self._secrets)
sm = store_manager.set(self._secrets, db)
self._artifact_manager = ArtifactManager(sm, db)
return self._artifact_manager
Expand Down Expand Up @@ -1076,7 +1076,7 @@ def get_run_status(
if run_info:
status = run_info["run"].get("status")

mldb = get_run_db().connect(self._secrets)
mldb = get_run_db(secrets=self._secrets)
runs = mldb.list_runs(
project=self.metadata.name, labels=f"workflow={workflow_id}"
)
Expand Down Expand Up @@ -1110,7 +1110,7 @@ def save(self, filepath=None):
self.save_to_db()

def save_to_db(self):
db = get_run_db().connect(self._secrets)
db = get_run_db(secrets=self._secrets)
db.store_project(self.metadata.name, self.to_dict())

def export(self, filepath=None):
Expand Down
8 changes: 4 additions & 4 deletions mlrun/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ def import_function(url="", secrets=None, db=""):
if url.startswith("db://"):
url = url[5:]
project, name, tag, hash_key = parse_function_uri(url)
db = get_run_db(db or get_or_set_dburl()).connect(secrets)
db = get_run_db(db or get_or_set_dburl(), secrets=secrets)
runtime = db.get_function(name, project, tag, hash_key)
if not runtime:
raise KeyError("function {}:{} not found in the DB".format(name, tag))
Expand Down Expand Up @@ -763,7 +763,7 @@ def run_pipeline(
arguments = arguments or {}

if remote or url:
mldb = get_run_db(url).connect()
mldb = get_run_db(url)
if mldb.kind != "http":
raise ValueError(
"run pipeline require access to remote api-service"
Expand Down Expand Up @@ -828,7 +828,7 @@ def wait_for_pipeline_completion(
)

if remote:
mldb = get_run_db().connect()
mldb = get_run_db()

def get_pipeline_if_completed(run_id, namespace=namespace):
resp = mldb.get_pipeline(run_id, namespace=namespace)
Expand Down Expand Up @@ -889,7 +889,7 @@ def get_pipeline(run_id, namespace=None):
namespace = namespace or mlconf.namespace
remote = not get_k8s_helper(silent=True).is_running_inside_kubernetes_cluster()
if remote:
mldb = get_run_db().connect()
mldb = get_run_db()
if mldb.kind != "http":
raise ValueError(
"get pipeline require access to remote api-service"
Expand Down
2 changes: 1 addition & 1 deletion mlrun/runtimes/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def _get_db(self):
self._ensure_run_db()
if not self._db_conn:
if self.spec.rundb:
self._db_conn = get_run_db(self.spec.rundb).connect(self._secrets)
self._db_conn = get_run_db(self.spec.rundb, secrets=self._secrets)
return self._db_conn

def run(
Expand Down
2 changes: 1 addition & 1 deletion mlrun/runtimes/sparkjob.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def _default_image(self):
def deploy(self, watch=True, with_mlrun=True, skip_deployed=False, is_kfp=False):
"""deploy function, build container with dependencies"""
# connect will populate the config from the server config
get_run_db().connect()
get_run_db()
if not self.spec.build.base_image:
self.spec.build.base_image = self._default_image
return super().deploy(
Expand Down
2 changes: 1 addition & 1 deletion mlrun/runtimes/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def resolve_mpijob_crd_version(api_context=False):
elif not in_k8s_cluster and not api_context:
# connect will populate the config from the server config
# TODO: something nicer
get_run_db().connect()
get_run_db()
mpijob_crd_version = config.mpijob_crd_version

# If resolution failed simply use default
Expand Down
7 changes: 4 additions & 3 deletions tests/api/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ def db() -> Generator:
dsn = f"sqlite:///{db_file.name}?check_same_thread=false"
config.httpdb.dsn = dsn

# we're also running client code in tests
config.dbpath = dsn

# TODO: make it simpler - doesn't make sense to call 3 different functions to initialize the db
# we need to force re-init the engine cause otherwise it is cached between tests
_init_engine(config.httpdb.dsn)
Expand All @@ -39,6 +36,10 @@ def db() -> Generator:
init_data(from_scratch=True)
initialize_db()
initialize_project_member()

# we're also running client code in tests so set dbpath as well
# note that setting this attribute triggers connection to the run db therefore must happen after the initialization
config.dbpath = dsn
yield create_session()
logger.info(f"Removing temp db file: {db_file.name}")
db_file.close()
Expand Down
2 changes: 1 addition & 1 deletion tests/rundb/test_rundb.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


def get_db():
return mlrun.get_run_db(rundb_path).connect()
return mlrun.get_run_db(rundb_path)


#
Expand Down
2 changes: 1 addition & 1 deletion tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def test_local_no_context():
).run(spec)
verify_state(result)

db = get_run_db().connect()
db = get_run_db()
state, log = db.get_log(result.metadata.uid)
log = str(log)
print(state)
Expand Down