Skip to content

Commit

Permalink
[Runtimes] Don't resolve completion of runs per execution (#2888)
Browse files Browse the repository at this point in the history
  • Loading branch information
alonmr committed Jan 21, 2023
1 parent e09cc0e commit 1229312
Show file tree
Hide file tree
Showing 13 changed files with 402 additions and 84 deletions.
1 change: 0 additions & 1 deletion mlrun/db/filedb.py
Expand Up @@ -111,7 +111,6 @@ def store_run(self, struct, uid, project="", iter=0):

def update_run(self, updates: dict, uid, project="", iter=0):
run = self.read_run(uid, project, iter=iter)
# TODO: Should we raise if run not found?
if run and updates:
for key, val in updates.items():
update_in(run, key, val)
Expand Down
159 changes: 120 additions & 39 deletions mlrun/execution.py
Expand Up @@ -19,6 +19,7 @@
from typing import List, Union

import numpy as np
import yaml

import mlrun
from mlrun.artifacts import ModelArtifact
Expand Down Expand Up @@ -93,6 +94,8 @@ def __init__(self, autocommit=False, tmp="", log_stream=None):
self._outputs = []

self._results = {}
# tracks the execution state, completion of runs is not decided by the execution
# as there may be multiple executions for a single run (e.g. mpi)
self._state = "created"
self._error = None
self._commit = ""
Expand All @@ -113,7 +116,7 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, exc_traceback):
if exc_value:
self.set_state(error=exc_value, commit=False)
self.commit()
self.commit(completed=True)

def get_child_context(self, with_parent_params=False, **params):
"""get child context (iteration)
Expand Down Expand Up @@ -259,7 +262,7 @@ def from_dict(
host=None,
log_stream=None,
is_api=False,
update_db=True,
store_run=True,
):
"""create execution context from dict"""

Expand Down Expand Up @@ -314,8 +317,8 @@ def from_dict(
if start:
self._start_time = start
self._state = "running"
if update_db:
self._update_db(commit=True)
if store_run:
self.store_run()
return self

@property
Expand All @@ -330,6 +333,11 @@ def tag(self):
"""run tag (uid or workflow id if exists)"""
return self._labels.get("workflow") or self._uid

@property
def state(self):
    """current state of this execution (e.g. "created", "running" or "error")"""
    return self._state

@property
def iteration(self):
"""child iteration index, for hyper parameters"""
Expand Down Expand Up @@ -445,7 +453,7 @@ def get_param(self, key: str, default=None):
if key not in self._parameters:
self._parameters[key] = default
if default:
self._update_db()
self._update_run()
return default
return self._parameters[key]

Expand Down Expand Up @@ -520,7 +528,7 @@ def log_result(self, key: str, value, commit=False):
:param commit: commit (write to DB now vs wait for the end of the run)
"""
self._results[str(key)] = _cast_result(value)
self._update_db(commit=commit)
self._update_run(commit=commit)

def log_results(self, results: dict, commit=False):
"""log a set of scalar result values
Expand All @@ -539,7 +547,7 @@ def log_results(self, results: dict, commit=False):

for p in results.keys():
self._results[str(p)] = _cast_result(results[p])
self._update_db(commit=commit)
self._update_run(commit=commit)

def log_iteration_results(self, best, summary: list, task: dict, commit=False):
"""Reserved for internal use"""
Expand All @@ -566,7 +574,7 @@ def log_iteration_results(self, best, summary: list, task: dict, commit=False):
if summary is not None:
self._iteration_results = summary
if commit:
self._update_db(commit=True)
self._update_run(commit=True)

def log_metric(self, key: str, value, timestamp=None, labels=None):
"""TBD, log a real-time time-series metric"""
Expand Down Expand Up @@ -648,7 +656,7 @@ def log_artifact(
format=format,
**kwargs,
)
self._update_db()
self._update_run()
return item

def log_dataset(
Expand Down Expand Up @@ -727,7 +735,7 @@ def log_dataset(
db_key=db_key,
labels=labels,
)
self._update_db()
self._update_run()
return item

def log_model(
Expand Down Expand Up @@ -829,7 +837,7 @@ def log_model(
db_key=db_key,
labels=labels,
)
self._update_db()
self._update_run()
return item

def get_cached_artifact(self, key):
Expand All @@ -840,13 +848,16 @@ def update_artifact(self, artifact_object):
"""update an artifact object in the cache and the DB"""
self._artifacts_manager.update_artifact(self, artifact_object)

def commit(self, message: str = "", completed=True):
def commit(self, message: str = "", completed=False):
"""save run state and optionally add a commit message
:param message: commit message to save in the run
:param completed: mark run as completed
"""
completed = completed and self._state == "running"
# changing state to completed is allowed only when the execution is in running state
if self._state != "running":
completed = False

if message:
self._annotations["message"] = message
if completed:
Expand All @@ -855,32 +866,42 @@ def commit(self, message: str = "", completed=True):
if self._parent:
self._parent.update_child_iterations()
self._parent._last_update = now_date()
self._parent._update_db(commit=True, message=message)
self._parent._update_run(commit=True, message=message)

if self._children:
self.update_child_iterations(commit_children=True, completed=completed)
self._last_update = now_date()
self._update_db(commit=True, message=message)
self._update_run(commit=True, message=message)
if completed and not self.iteration:
mlrun.runtimes.utils.global_context.set(None)

def set_state(self, state: str = None, error: str = None, commit=True):
"""modify and store the run state or mark an error
def set_state(self, execution_state: str = None, error: str = None, commit=True):
"""
Modify and store the execution state or mark an error and update the run state accordingly.
This method allows setting the run state to 'completed' in the DB, which is discouraged.
Completion of runs should be decided externally to the execution context.
:param state: set run state
:param error: error message (if exist will set the state to error)
:param commit: will immediately update the state in the DB
:param execution_state: set execution state
:param error: error message (if it exists, the state is set to error)
:param commit: will immediately update the state in the DB
"""
# TODO: The execution context should not set the run state to completed.
# Create a separate state for the execution in the run object.

updates = {"status.last_update": now_date().isoformat()}

if error:
self._state = "error"
self._error = str(error)
updates["status.state"] = "error"
updates["status.error"] = error
elif state and state != self._state and self._state != "error":
self._state = state
updates["status.state"] = state
elif (
execution_state
and execution_state != self._state
and self._state != "error"
):
self._state = execution_state
updates["status.state"] = execution_state
self._last_update = now_date()

if self._rundb and commit:
Expand All @@ -900,9 +921,9 @@ def set_hostname(self, host: str):
def to_dict(self):
"""convert the run context to a dictionary"""

def set_if_valid(struct, key, val):
def set_if_not_none(_struct, key, val):
if val:
struct[key] = val
_struct[key] = val

struct = {
"kind": "run",
Expand All @@ -924,26 +945,52 @@ def set_if_valid(struct, key, val):
run_keys.inputs: {k: v.artifact_url for k, v in self._inputs.items()},
},
"status": {
"state": self._state,
"results": self._results,
"start_time": to_date_str(self._start_time),
"last_update": to_date_str(self._last_update),
},
}

# completion of runs is not decided by the execution as there may be
# multiple executions for a single run (e.g. mpi)
if self._state != "completed":
struct["status"]["state"] = self._state

if not self._iteration:
struct["spec"]["hyperparams"] = self._hyperparams
struct["spec"]["hyper_param_options"] = self._hyper_param_options.to_dict()

set_if_valid(struct["status"], "error", self._error)
set_if_valid(struct["status"], "commit", self._commit)
set_if_not_none(struct["status"], "error", self._error)
set_if_not_none(struct["status"], "commit", self._commit)
set_if_not_none(struct["status"], "iterations", self._iteration_results)

if self._iteration_results:
struct["status"]["iterations"] = self._iteration_results
struct["status"][run_keys.artifacts] = self._artifacts_manager.artifact_list()
self._data_stores.to_dict(struct["spec"])
return struct

def _get_updates(self):
    """build a flat dict of dotted "status.*" keys holding the values to update in the stored run"""
    updates = {
        "status.results": self._results,
        "status.start_time": to_date_str(self._start_time),
        "status.last_update": to_date_str(self._last_update),
    }

    # completion of runs is not decided by the execution as there may be
    # multiple executions for a single run (e.g. mpi)
    if self._state != "completed":
        updates["status.state"] = self._state

    # optional fields are included only when they hold a truthy value
    optional_fields = (
        ("status.error", self._error),
        ("status.commit", self._commit),
        ("status.iterations", self._iteration_results),
    )
    for field, value in optional_fields:
        if value:
            updates[field] = value

    # the artifact list is always sent, even when it is empty
    updates[f"status.{run_keys.artifacts}"] = self._artifacts_manager.artifact_list()
    return updates

def to_yaml(self):
"""convert the run context to a yaml buffer"""
return dict_to_yaml(self.to_dict())
Expand All @@ -952,21 +999,55 @@ def to_json(self):
"""convert the run context to a json buffer"""
return dict_to_json(self.to_dict())

def _update_db(self, commit=False, message=""):
self.last_update = now_date()
if self._tmpfile:
data = self.to_json()
with open(self._tmpfile, "w") as fp:
fp.write(data)
fp.close()
def store_run(self):
    """write the full run to the tmpfile and store it as a whole in the run DB (when one is set)"""
    self._write_tmpfile()

    run_db = self._rundb
    if not run_db:
        return

    # the entire run dict is stored, as opposed to updating selected fields
    run_db.store_run(self.to_dict(), self._uid, self.project, iter=self._iteration)

def _update_run(self, commit=False, message=""):
"""
update the required fields in the run object (using mlrun.utils.helpers.update_in)
instead of overwriting the existing run object
"""
self._merge_tmpfile()
if commit or self._autocommit:
self._commit = message
if self._rundb:
self._rundb.store_run(
self.to_dict(), self._uid, self.project, iter=self._iteration
self._rundb.update_run(
self._get_updates(), self._uid, self.project, iter=self._iteration
)

def _merge_tmpfile(self):
if not self._tmpfile:
return

loaded_run = self._read_tmpfile()
dict_run = self.to_dict()
if loaded_run:
for key, val in dict_run.items():
update_in(loaded_run, key, val)
else:
loaded_run = dict_run

self._write_tmpfile(json=dict_to_json(loaded_run))

def _read_tmpfile(self):
if self._tmpfile:
with open(self._tmpfile) as fp:
return yaml.safe_load(fp)

return None

def _write_tmpfile(self, json=None):
    """refresh the last update time and write the run to the tmpfile (when one is set)

    :param json: pre-serialized run buffer to write, when empty the current run
                 (self.to_json()) is written instead
    """
    # NOTE(review): other methods assign self._last_update (with a leading underscore),
    # confirm that setting self.last_update here is intentional
    self.last_update = now_date()
    if not self._tmpfile:
        return

    data = json or self.to_json()
    # the with statement closes the file on exit, no explicit close is needed
    with open(self._tmpfile, "w") as fp:
        fp.write(data)


def _cast_result(value):
if isinstance(value, (int, str, float)):
Expand Down

0 comments on commit 1229312

Please sign in to comment.