Skip to content

Commit

Permalink
[DB] Align all upserts to happen from helper method (#1810)
Browse files Browse the repository at this point in the history
(cherry picked from commit feac66f)
  • Loading branch information
Hedingber committed Mar 14, 2022
1 parent 9fb9a61 commit 8a140a1
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 37 deletions.
85 changes: 53 additions & 32 deletions mlrun/api/db/sqldb/db.py
Expand Up @@ -62,6 +62,13 @@
unversioned_tagged_object_uid_prefix = "unversioned-"


conflict_messages = [
"(sqlite3.IntegrityError) UNIQUE constraint failed",
"(pymysql.err.IntegrityError) (1062",
"(pymysql.err.IntegrityError) (1586",
]


def retry_on_conflict(function):
"""
Most of our store_x functions starting from doing get, then if nothing is found creating the object otherwise
Expand All @@ -79,11 +86,7 @@ def _try_function():
try:
return function(*args, **kwargs)
except Exception as exc:
conflict_messages = [
"(sqlite3.IntegrityError) UNIQUE constraint failed",
"(pymysql.err.IntegrityError) (1062",
"(pymysql.err.IntegrityError) (1586",
]

if mlrun.utils.helpers.are_strings_in_exception_chain_messages(
exc, conflict_messages
):
Expand Down Expand Up @@ -172,7 +175,7 @@ def store_run(
run.start_time = start_time
self._update_run_updated_time(run, run_data, now=now)
run.struct = run_data
self._upsert(session, run, ignore=True)
self._upsert(session, [run], ignore=True)

def update_run(self, session, updates: dict, uid, project="", iter=0):
project = project or config.default_project
Expand All @@ -191,8 +194,7 @@ def update_run(self, session, updates: dict, uid, project="", iter=0):
update_labels(run, run_labels(struct))
self._update_run_updated_time(run, struct)
run.struct = struct
session.merge(run)
session.commit()
self._upsert(session, [run])
self._delete_empty_labels(session, Run.Label)

def read_run(self, session, uid, project=None, iter=0):
Expand Down Expand Up @@ -363,7 +365,7 @@ def _store_artifact(
artifact.pop("tag", None)

art.struct = artifact
self._upsert(session, art)
self._upsert(session, [art])
if tag_artifact:
tag = tag or "latest"
self.tag_artifacts(session, [art], project, tag)
Expand Down Expand Up @@ -589,7 +591,7 @@ def store_function(
labels = get_in(function, "metadata.labels", {})
update_labels(fn, labels)
fn.struct = function
self._upsert(session, fn)
self._upsert(session, [fn])
self.tag_objects_v2(session, [fn], project, tag)
return hash_key

Expand Down Expand Up @@ -765,7 +767,7 @@ def create_schedule(
cron_trigger=cron_trigger,
concurrency_limit=concurrency_limit,
)
self._upsert(session, schedule)
self._upsert(session, [schedule])

def update_schedule(
self,
Expand Down Expand Up @@ -804,8 +806,7 @@ def update_schedule(
labels=labels,
concurrency_limit=concurrency_limit,
)
session.merge(schedule)
session.commit()
self._upsert(session, [schedule])

def list_schedules(
self,
Expand Down Expand Up @@ -902,9 +903,10 @@ def tag_artifacts(self, session, artifacts, project: str, name: str):
if not tag:
tag = artifact.Tag(project=project, name=name)
tag.obj_id = artifact.id
self._upsert(session, tag, ignore=True)
self._upsert(session, [tag], ignore=True)

def tag_objects_v2(self, session, objs, project: str, name: str):
tags = []
for obj in objs:
query = self._query(
session, obj.Tag, name=name, project=project, obj_name=obj.name
Expand All @@ -913,8 +915,8 @@ def tag_objects_v2(self, session, objs, project: str, name: str):
if not tag:
tag = obj.Tag(project=project, name=name, obj_name=obj.name)
tag.obj_id = obj.id
session.add(tag)
session.commit()
tags.append(tag)
self._upsert(session, tags)

def create_project(self, session: Session, project: schemas.Project):
logger.debug("Creating project in DB", project=project)
Expand All @@ -931,7 +933,7 @@ def create_project(self, session: Session, project: schemas.Project):
)
labels = project.metadata.labels or {}
update_labels(project_record, labels)
self._upsert(session, project_record)
self._upsert(session, [project_record])

@retry_on_conflict
def store_project(self, session: Session, name: str, project: schemas.Project):
Expand Down Expand Up @@ -1199,7 +1201,7 @@ def _update_project_record_from_project(
project_record.state = project.status.state
labels = project.metadata.labels or {}
update_labels(project_record, labels)
self._upsert(session, project_record)
self._upsert(session, [project_record])

def _patch_project_record_from_project(
self,
Expand All @@ -1221,7 +1223,7 @@ def _patch_project_record_from_project(
)

project_record.full_object = project_record_full_object
self._upsert(session, project_record)
self._upsert(session, [project_record])

def is_project_exists(self, session: Session, name: str, **kwargs):
project_record = self._get_project_record(
Expand Down Expand Up @@ -1789,7 +1791,7 @@ def store_feature_set(
self._update_db_record_from_object_dict(db_feature_set, feature_set_dict, uid)

self._update_feature_set_spec(db_feature_set, feature_set_dict)
self._upsert(session, db_feature_set)
self._upsert(session, [db_feature_set])
self.tag_objects_v2(session, [db_feature_set], project, tag)

return uid
Expand Down Expand Up @@ -1833,7 +1835,7 @@ def create_feature_set(
self._update_db_record_from_object_dict(db_feature_set, feature_set_dict, uid)
self._update_feature_set_spec(db_feature_set, feature_set_dict)

self._upsert(session, db_feature_set)
self._upsert(session, [db_feature_set])
self.tag_objects_v2(session, [db_feature_set], project, tag)

return uid
Expand Down Expand Up @@ -1927,7 +1929,7 @@ def create_feature_vector(
db_feature_vector, feature_vector_dict, uid
)

self._upsert(session, db_feature_vector)
self._upsert(session, [db_feature_vector])
self.tag_objects_v2(session, [db_feature_vector], project, tag)

return uid
Expand Down Expand Up @@ -2080,7 +2082,7 @@ def store_feature_vector(
db_feature_vector, feature_vector_dict, uid
)

self._upsert(session, db_feature_vector)
self._upsert(session, [db_feature_vector])
self.tag_objects_v2(session, [db_feature_vector], project, tag)

return uid
Expand Down Expand Up @@ -2214,29 +2216,45 @@ def _delete_empty_labels(self, session, cls):
session.query(cls).filter(cls.parent == NULL).delete()
session.commit()

def _upsert(self, session, obj, ignore=False):
def _upsert(self, session, objects, ignore=False):
if not objects:
return
for object_ in objects:
session.add(object_)
self._commit(session, objects, ignore)

def _commit(self, session, objects, ignore=False):
def _try_commit_obj():
try:
session.add(obj)
session.commit()
except SQLAlchemyError as err:
session.rollback()
cls = obj.__class__.__name__
cls = objects[0].__class__.__name__
if "database is locked" in str(err):
logger.warning(
"Database is locked. Retrying", cls=cls, err=str(err)
)
raise mlrun.errors.MLRunRuntimeError(
"Failed adding resource, database is locked"
"Failed committing changes, database is locked"
) from err
logger.warning("Conflict adding resource to DB", cls=cls, err=str(err))
logger.warning("Failed committing changes to DB", cls=cls, err=str(err))
if not ignore:
identifiers = ",".join(
object_.get_identifier_string() for object_ in objects
)
# We want to retry only when database is locked so for any other scenario escalate to fatal failure
try:
raise mlrun.errors.MLRunConflictError(
f"Conflict - {cls} already exists: {obj.get_identifier_string()}"
if any([message in str(err) for message in conflict_messages]):
raise mlrun.errors.MLRunConflictError(
f"Conflict - {cls} already exists: {identifiers}"
) from err
raise mlrun.errors.MLRunRuntimeError(
f"Failed committing changes to DB. class={cls} objects={identifiers}"
) from err
except mlrun.errors.MLRunConflictError as exc:
except (
mlrun.errors.MLRunRuntimeError,
mlrun.errors.MLRunConflictError,
) as exc:
raise mlrun.errors.MLRunFatalFailureError(
original_exception=exc
)
Expand Down Expand Up @@ -2584,6 +2602,7 @@ def _move_and_reorder_table_items(

if move_from == move_to:
# It's just modifying the same object - update and exit.
# using merge since primary key is changing
session.merge(moved_object)
session.commit()
return
Expand All @@ -2609,9 +2628,11 @@ def _move_and_reorder_table_items(

for source_record in query:
source_record.index = source_record.index + modifier
# using merge since primary key is changing
session.merge(source_record)

if move_to:
# using merge since primary key is changing
session.merge(moved_object)
else:
session.delete(moved_object)
Expand Down Expand Up @@ -2804,4 +2825,4 @@ def create_data_version(self, session, version):

now = datetime.now(timezone.utc)
data_version_record = DataVersion(version=version, created=now)
self._upsert(session, data_version_record)
self._upsert(session, [data_version_record])
2 changes: 1 addition & 1 deletion mlrun/api/initial_data.py
Expand Up @@ -405,7 +405,7 @@ def _align_runs_table(
)
db._update_run_updated_time(run, run_dict, updated)
run.struct = run_dict
db._upsert(db_session, run, ignore=True)
db._upsert(db_session, [run], ignore=True)


def _perform_version_1_data_migrations(
Expand Down
4 changes: 2 additions & 2 deletions tests/api/api/test_runs.py
Expand Up @@ -112,7 +112,7 @@ def test_list_runs_times_filters(db: Session, client: TestClient) -> None:
updated=run_1_update_time,
)
run.struct = run_1
get_db()._upsert(db, run, ignore=True)
get_db()._upsert(db, [run], ignore=True)

between_run_1_and_2 = datetime.now(timezone.utc)

Expand All @@ -138,7 +138,7 @@ def test_list_runs_times_filters(db: Session, client: TestClient) -> None:
updated=run_2_update_time,
)
run.struct = run_2
get_db()._upsert(db, run, ignore=True)
get_db()._upsert(db, [run], ignore=True)

# all start time range
assert_time_range_request(client, [run_1_uid, run_2_uid])
Expand Down
4 changes: 2 additions & 2 deletions tests/api/db/test_runs.py
Expand Up @@ -138,7 +138,7 @@ def test_data_migration_align_runs_table(db: DBInterface, db_session: Session):
runs = db._find_runs(db_session, None, "*", None).all()
for run in runs:
_change_run_record_to_before_align_runs_migration(run, time_before_creation)
db._upsert(db_session, run, ignore=True)
db._upsert(db_session, [run], ignore=True)

# run the migration
mlrun.api.initial_data._align_runs_table(db, db_session)
Expand Down Expand Up @@ -168,7 +168,7 @@ def test_data_migration_align_runs_table_with_empty_run_body(
# change to be as it will be in field (before the migration) and then empty the body
_change_run_record_to_before_align_runs_migration(run, time_before_creation)
run.struct = {}
db._upsert(db_session, run, ignore=True)
db._upsert(db_session, [run], ignore=True)

# run the migration
mlrun.api.initial_data._align_runs_table(db, db_session)
Expand Down

0 comments on commit 8a140a1

Please sign in to comment.