Skip to content

Commit

Permalink
[Artifacts] Add data migration to handle old corrupted artifact tags (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
Hedingber committed Jan 12, 2021
1 parent 46bd4fe commit e2361f3
Show file tree
Hide file tree
Showing 4 changed files with 291 additions and 22 deletions.
50 changes: 37 additions & 13 deletions mlrun/api/db/sqldb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
UniqueConstraint,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import relationship
from sqlalchemy.orm import relationship, class_mapper

from mlrun.api import schemas

Expand All @@ -39,7 +39,23 @@
run_time_fmt = "%Y-%m-%dT%H:%M:%S.%fZ"


class HasStruct:
class BaseModel:
    def to_dict(self, exclude=None):
        """
        Serialize the mapped columns of this record into a plain dict.

        datetime column values are converted to their ISO-8601 string form so the
        result is JSON-friendly.

        NOTE - this function (currently) does not handle serializing relationships

        :param exclude: optional list of column names to omit from the result
        :return: dict mapping column name -> (possibly isoformat-ted) value
        """
        exclude = exclude or []
        mapper = class_mapper(self.__class__)
        result = {}
        for column in mapper.columns:
            if column.key in exclude:
                continue
            # single lookup per column (the old lambda called getattr up to 3 times)
            value = getattr(self, column.key)
            if isinstance(value, datetime):
                value = value.isoformat()
            result[column.key] = value
        return result


class HasStruct(BaseModel):
@property
def struct(self):
return pickle.loads(self.body)
Expand All @@ -48,9 +64,17 @@ def struct(self):
def struct(self, value):
self.body = pickle.dumps(value)

def to_dict(self, exclude=None):
"""
NOTE - this function (currently) does not handle serializing relationships
"""
exclude = exclude or []
exclude.append("body")
return super().to_dict(exclude)


def make_label(table):
class Label(Base):
class Label(Base, BaseModel):
__tablename__ = f"{table}_labels"
__table_args__ = (
UniqueConstraint("name", "parent", name=f"_{table}_labels_uc"),
Expand All @@ -65,7 +89,7 @@ class Label(Base):


def make_tag(table):
class Tag(Base):
class Tag(Base, BaseModel):
__tablename__ = f"{table}_tags"
__table_args__ = (
UniqueConstraint("project", "name", "obj_id", name=f"_{table}_tags_uc"),
Expand All @@ -82,7 +106,7 @@ class Tag(Base):
# TODO: don't want to refactor everything in one PR so splitting this function to 2 versions - eventually only this one
# should be used
def make_tag_v2(table):
class Tag(Base):
class Tag(Base, BaseModel):
__tablename__ = f"{table}_tags"
__table_args__ = (
UniqueConstraint("project", "name", "obj_name", name=f"_{table}_tags_uc"),
Expand Down Expand Up @@ -135,7 +159,7 @@ class Function(Base, HasStruct):
updated = Column(TIMESTAMP)
labels = relationship(Label)

class Log(Base):
class Log(Base, BaseModel):
__tablename__ = "logs"

id = Column(Integer, primary_key=True)
Expand All @@ -161,7 +185,7 @@ class Run(Base, HasStruct):
start_time = Column(TIMESTAMP)
labels = relationship(Label)

class Schedule(Base):
class Schedule(Base, BaseModel):
__tablename__ = "schedules_v2"
__table_args__ = (UniqueConstraint("project", "name", name="_schedules_v2_uc"),)

Expand Down Expand Up @@ -203,14 +227,14 @@ def cron_trigger(self, trigger: schemas.ScheduleCronTrigger):
Column("user_id", Integer, ForeignKey("users.id")),
)

class User(Base):
class User(Base, BaseModel):
__tablename__ = "users"
__table_args__ = (UniqueConstraint("name", name="_users_uc"),)

id = Column(Integer, primary_key=True)
name = Column(String)

class Project(Base):
class Project(Base, BaseModel):
__tablename__ = "projects"
# For now since we use project name a lot
__table_args__ = (UniqueConstraint("name", name="_projects_uc"),)
Expand Down Expand Up @@ -240,7 +264,7 @@ def full_object(self):
def full_object(self, value):
self._full_object = pickle.dumps(value)

class Feature(Base):
class Feature(Base, BaseModel):
__tablename__ = "features"
id = Column(Integer, primary_key=True)
feature_set_id = Column(Integer, ForeignKey("feature_sets.id"))
Expand All @@ -251,7 +275,7 @@ class Feature(Base):
Label = make_label(__tablename__)
labels = relationship(Label, cascade="all, delete-orphan")

class Entity(Base):
class Entity(Base, BaseModel):
__tablename__ = "entities"
id = Column(Integer, primary_key=True)
feature_set_id = Column(Integer, ForeignKey("feature_sets.id"))
Expand All @@ -262,7 +286,7 @@ class Entity(Base):
Label = make_label(__tablename__)
labels = relationship(Label, cascade="all, delete-orphan")

class FeatureSet(Base):
class FeatureSet(Base, BaseModel):
__tablename__ = "feature_sets"
__table_args__ = (
UniqueConstraint("name", "project", "uid", name="_feature_set_uc"),
Expand Down Expand Up @@ -295,7 +319,7 @@ def full_object(self):
def full_object(self, value):
self._full_object = json.dumps(value)

class FeatureVector(Base):
class FeatureVector(Base, BaseModel):
__tablename__ = "feature_vectors"
__table_args__ = (
UniqueConstraint("name", "project", "uid", name="_feature_vectors_uc"),
Expand Down
90 changes: 89 additions & 1 deletion mlrun/api/initial_data.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import collections
import datetime
import os
import sqlalchemy.orm
import pathlib
import sqlalchemy.orm
import typing

from mlrun.api.db.init_db import init_db
import mlrun.api.db.sqldb.db
import mlrun.api.db.sqldb.models
import mlrun.api.schemas
from mlrun.api.db.session import create_session, close_session
from mlrun.utils import logger
Expand Down Expand Up @@ -34,6 +38,90 @@ def _perform_data_migrations(db_session: sqlalchemy.orm.Session):
db = mlrun.api.db.sqldb.db.SQLDB("")
logger.info("Performing data migrations")
_fill_project_state(db, db_session)
_fix_artifact_tags_duplications(db, db_session)


def _fix_artifact_tags_duplications(
    db: mlrun.api.db.sqldb.db.SQLDB, db_session: sqlalchemy.orm.Session
):
    """
    Data migration: ensure each (project, artifact key, tag name) has at most one tag
    record, keeping the tag that points at the most recently updated artifact.

    Orphan tags (tags whose obj_id points at no existing artifact record) are also
    removed.
    """
    # get all artifacts
    artifacts = db._find_artifacts(db_session, None, "*")
    # get all artifact tags
    tags = db._query(db_session, mlrun.api.db.sqldb.models.Artifact.Tag).all()
    # artifact record id -> artifact
    artifact_record_id_map = {artifact.id: artifact for artifact in artifacts}
    tags_to_delete = []
    # sanity - split out orphan tags up front. The old code only logged the orphan and
    # then crashed with a KeyError on the artifact_record_id_map lookup right after
    # (and, being inside the per-key loop, would have logged the same orphan repeatedly)
    valid_tags = []
    for tag in tags:
        if tag.obj_id not in artifact_record_id_map:
            logger.warning("Found orphan tag, deleting", tag=tag.to_dict())
            tags_to_delete.append(tag)
        else:
            valid_tags.append(tag)
    projects = {artifact.project for artifact in artifacts}
    for project in projects:
        artifact_keys = {
            artifact.key for artifact in artifacts if artifact.project == project
        }
        for artifact_key in artifact_keys:
            # tags pointing at artifact records with this project and key. The old code
            # filtered by key only, so two projects sharing an artifact key had their
            # tags pooled together and a valid tag could be wrongly deleted
            artifact_key_tags = [
                tag
                for tag in valid_tags
                if artifact_record_id_map[tag.obj_id].project == project
                and artifact_record_id_map[tag.obj_id].key == artifact_key
            ]
            # tag name -> tag records carrying that name
            tag_name_tags_map = collections.defaultdict(list)
            for tag in artifact_key_tags:
                tag_name_tags_map[tag.name].append(tag)
            for tag_name, _tags in tag_name_tags_map.items():
                # exactly one record per tag name is the healthy state
                if len(_tags) == 1:
                    continue
                tags_artifacts = [artifact_record_id_map[tag.obj_id] for tag in _tags]
                last_updated_artifact = _find_last_updated_artifact(tags_artifacts)
                # keep only the tag pointing at the most recently updated artifact
                for tag in _tags:
                    if tag.obj_id != last_updated_artifact.id:
                        tags_to_delete.append(tag)
    if tags_to_delete:
        logger.info(
            "Found duplicated artifact tags. Removing duplications",
            tags_to_delete=[
                tag_to_delete.to_dict() for tag_to_delete in tags_to_delete
            ],
            tags=[tag.to_dict() for tag in tags],
            artifacts=[artifact.to_dict() for artifact in artifacts],
        )
        for tag in tags_to_delete:
            db_session.delete(tag)
        db_session.commit()


def _find_last_updated_artifact(
    artifacts: typing.List["mlrun.api.db.sqldb.models.Artifact"],
):
    """
    Return the artifact with the latest `updated` timestamp.

    Artifacts whose `updated` is None are skipped (the old code compared
    `None > datetime` and raised TypeError, making its own "no update time"
    fallback unreachable). Ties are broken by taking the first artifact seen
    with that timestamp; if no artifact has an update time, the first artifact
    in the list is chosen.

    :raises RuntimeError: if `artifacts` is empty
    """
    # sanity
    if not artifacts:
        raise RuntimeError("No artifacts given")
    last_updated_artifact = None
    last_updated_artifact_time = datetime.datetime.min
    artifacts_with_same_update_time = []
    for artifact in artifacts:
        # missing update time - can't participate in the comparison
        if artifact.updated is None:
            continue
        if artifact.updated > last_updated_artifact_time:
            last_updated_artifact = artifact
            last_updated_artifact_time = last_updated_artifact.updated
            artifacts_with_same_update_time = [last_updated_artifact]
        elif artifact.updated == last_updated_artifact_time:
            artifacts_with_same_update_time.append(artifact)
    # warn once after the loop (the old code warned on every iteration past the tie)
    if len(artifacts_with_same_update_time) > 1:
        logger.warning(
            "Found several artifacts with same update time, heuristically choosing the first",
            artifacts=[
                artifact.to_dict() for artifact in artifacts_with_same_update_time
            ],
        )
        # we don't really need to do anything to choose the first, it's already happening
        # because the first if is > and not >=
    if not last_updated_artifact:
        logger.warning(
            "No artifact had update time, heuristically choosing the first",
            artifacts=[artifact.to_dict() for artifact in artifacts],
        )
        last_updated_artifact = artifacts[0]

    return last_updated_artifact


def _fill_project_state(
Expand Down

0 comments on commit e2361f3

Please sign in to comment.