Skip to content

Commit

Permalink
[Notifications] making DB notification objects more generic (#3678)
Browse files Browse the repository at this point in the history
  • Loading branch information
theSaarco committed Jun 1, 2023
1 parent aae8e70 commit 2b73d4d
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 99 deletions.
46 changes: 31 additions & 15 deletions mlrun/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,13 +56,13 @@
Function,
HubSource,
Log,
Notification,
Project,
Run,
Schedule,
User,
_labeled,
_tagged,
_with_notifications,
)
from mlrun.config import config
from mlrun.errors import err_to_str
Expand Down Expand Up @@ -382,7 +382,7 @@ def list_runs(

# Purposefully not using outer join to avoid returning runs without notifications
if with_notifications:
query = query.join(Notification, Run.id == Notification.run)
query = query.join(Run.Notification)

runs = RunList()
for run in query:
Expand Down Expand Up @@ -1769,7 +1769,9 @@ def verify_project_has_no_related_resources(self, session: Session, name: str):
self._verify_empty_list_of_project_related_resources(name, logs, "logs")
runs = self._find_runs(session, None, name, []).all()
self._verify_empty_list_of_project_related_resources(name, runs, "runs")
notifications = self._get_db_notifications(session, project=name)
notifications = []
for cls in _with_notifications:
notifications.extend(self._get_db_notifications(session, cls, project=name))
self._verify_empty_list_of_project_related_resources(
name, notifications, "notifications"
)
Expand Down Expand Up @@ -2914,10 +2916,10 @@ def _find_runs(self, session, uid, project, labels):
return self._add_labels_filter(session, query, Run, labels)

def _get_db_notifications(
self, session, name: str = None, run_id: int = None, project: str = None
self, session, cls, name: str = None, parent_id: int = None, project: str = None
):
return self._query(
session, Notification, name=name, run=run_id, project=project
session, cls.Notification, name=name, parent_id=parent_id, project=project
).all()

def _latest_uid_filter(self, session, query):
Expand Down Expand Up @@ -3260,7 +3262,7 @@ def _transform_project_record_to_schema(

def _transform_notification_record_to_spec_and_status(
self,
notification_record: Notification,
notification_record,
) -> typing.Tuple[dict, dict]:
notification_spec = self._transform_notification_record_to_schema(
notification_record
Expand All @@ -3273,7 +3275,7 @@ def _transform_notification_record_to_spec_and_status(

@staticmethod
def _transform_notification_record_to_schema(
notification_record: Notification,
notification_record,
) -> mlrun.model.Notification:
return mlrun.model.Notification(
kind=notification_record.kind,
Expand Down Expand Up @@ -3666,18 +3668,30 @@ def store_run_notifications(
f"Run not found: uid={run_uid}, project={project}"
)

run_notifications = {
self._store_notifications(session, Run, notification_objects, run.id, project)

def _store_notifications(
self,
session,
cls,
notification_objects: typing.List[mlrun.model.Notification],
parent_id: str,
project: str,
):
db_notifications = {
notification.name: notification
for notification in self._get_db_notifications(session, run_id=run.id)
for notification in self._get_db_notifications(
session, cls, parent_id=parent_id
)
}
notifications = []
for notification_model in notification_objects:
new_notification = False
notification = run_notifications.get(notification_model.name, None)
notification = db_notifications.get(notification_model.name, None)
if not notification:
new_notification = True
notification = Notification(
name=notification_model.name, run=run.id, project=project
notification = cls.Notification(
name=notification_model.name, parent_id=parent_id, project=project
)

notification.kind = notification_model.kind
Expand All @@ -3695,7 +3709,7 @@ def store_run_notifications(
logger.debug(
f"Storing {'new' if new_notification else 'existing'} notification",
notification_name=notification.name,
run_uid=run_uid,
parent_id=parent_id,
project=project,
)
notifications.append(notification)
Expand All @@ -3716,7 +3730,9 @@ def list_run_notifications(

return [
self._transform_notification_record_to_schema(notification)
for notification in self._query(session, Notification, run=run.id).all()
for notification in self._query(
session, Run.Notification, parent_id=run.id
).all()
]

def delete_run_notifications(
Expand All @@ -3742,7 +3758,7 @@ def delete_run_notifications(
if project == "*":
project = None

query = self._get_db_notifications(session, name, run_id, project)
query = self._get_db_notifications(session, Run, name, run_id, project)
for notification in query:
session.delete(notification)

Expand Down
4 changes: 2 additions & 2 deletions mlrun/api/db/sqldb/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from .models_mysql import * # noqa

# importing private variables as well
from .models_mysql import _classes, _labeled, _table2cls, _tagged # noqa # isort:skip
from .models_mysql import _classes, _labeled, _table2cls, _tagged, _with_notifications # noqa # isort:skip
else:
from .models_sqlite import * # noqa

# importing private variables as well
from .models_sqlite import _classes, _labeled, _table2cls, _tagged # noqa # isort:skip
from .models_sqlite import _classes, _labeled, _table2cls, _tagged, _with_notifications # noqa # isort:skip
# fmt: on
90 changes: 50 additions & 40 deletions mlrun/api/db/sqldb/models/models_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,52 @@ class Tag(Base, mlrun.utils.db.BaseModel):
return Tag


def make_notification(table):
class Notification(Base, mlrun.utils.db.BaseModel):
__tablename__ = f"{table}_notifications"
__table_args__ = (
UniqueConstraint("name", "parent_id", name=f"_{table}_notifications_uc"),
)

id = Column(Integer, primary_key=True)
project = Column(String(255, collation=SQLCollationUtil.collation()))
name = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
kind = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
message = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
severity = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
when = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
condition = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
params = Column("params", JSON)
parent_id = Column(Integer, ForeignKey(f"{table}.id"))

# TODO: Separate table for notification state.
# Currently, we are only supporting one notification being sent per DB row (either on completion or on error).
# In the future, we might want to support multiple notifications per DB row, and we might want to support on
# start, therefore we need to separate the state from the notification itself (e.g. this table can be table
# with notification_id, state, when, last_sent, etc.). This will require some refactoring in the code.
sent_time = Column(
sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3),
nullable=True,
)
status = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)

return Notification


# quell SQLAlchemy warnings on duplicate class name (Label)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand Down Expand Up @@ -139,46 +185,6 @@ class Function(Base, mlrun.utils.db.HasStruct):
def get_identifier_string(self) -> str:
return f"{self.project}/{self.name}/{self.uid}"

class Notification(Base, mlrun.utils.db.BaseModel):
__tablename__ = "notifications"
__table_args__ = (UniqueConstraint("name", "run", name="_notifications_uc"),)

id = Column(Integer, primary_key=True)
project = Column(String(255, collation=SQLCollationUtil.collation()))
name = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
kind = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
message = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
severity = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
when = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
condition = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
params = Column("params", JSON)
run = Column(Integer, ForeignKey("runs.id"))

# TODO: Separate table for notification state.
# Currently, we are only supporting one notification being sent per DB row (either on completion or on error).
# In the future, we might want to support multiple notifications per DB row, and we might want to support on
# start, therefore we need to separate the state from the notification itself (e.g. this table can be table
# with notification_id, state, when, last_sent, etc.). This will require some refactoring in the code.
sent_time = Column(
sqlalchemy.dialects.mysql.TIMESTAMP(fsp=3),
nullable=True,
)
status = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)

class Log(Base, mlrun.utils.db.BaseModel):
__tablename__ = "logs"

Expand All @@ -199,6 +205,7 @@ class Run(Base, mlrun.utils.db.HasStruct):

Label = make_label(__tablename__)
Tag = make_tag(__tablename__)
Notification = make_notification(__tablename__)

id = Column(Integer, primary_key=True)
uid = Column(String(255, collation=SQLCollationUtil.collation()))
Expand Down Expand Up @@ -505,5 +512,8 @@ class DataVersion(Base, mlrun.utils.db.BaseModel):
# Must be after all table definitions
_tagged = [cls for cls in Base.__subclasses__() if hasattr(cls, "Tag")]
_labeled = [cls for cls in Base.__subclasses__() if hasattr(cls, "Label")]
_with_notifications = [
cls for cls in Base.__subclasses__() if hasattr(cls, "Notification")
]
_classes = [cls for cls in Base.__subclasses__()]
_table2cls = {cls.__table__.name: cls for cls in Base.__subclasses__()}
78 changes: 44 additions & 34 deletions mlrun/api/db/sqldb/models/models_sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,46 @@ class Tag(Base, mlrun.utils.db.BaseModel):
return Tag


def make_notification(table):
class Notification(Base, mlrun.utils.db.BaseModel):
__tablename__ = f"{table}_notifications"
__table_args__ = (
UniqueConstraint("name", "parent_id", name=f"_{table}_notifications_uc"),
)

id = Column(Integer, primary_key=True)
project = Column(String(255, collation=SQLCollationUtil.collation()))
name = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
kind = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
message = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
severity = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
when = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
condition = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
params = Column("params", JSON)
parent_id = Column(Integer, ForeignKey(f"{table}.id"))
sent_time = Column(
TIMESTAMP(),
nullable=True,
)
status = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)

return Notification


# quell SQLAlchemy warnings on duplicate class name (Label)
with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand Down Expand Up @@ -151,40 +191,6 @@ class Log(Base, mlrun.utils.db.BaseModel):
def get_identifier_string(self) -> str:
return f"{self.project}/{self.uid}"

class Notification(Base, mlrun.utils.db.BaseModel):
__tablename__ = "notifications"
__table_args__ = (UniqueConstraint("name", "run", name="_notifications_uc"),)

id = Column(Integer, primary_key=True)
project = Column(String(255, collation=SQLCollationUtil.collation()))
name = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
kind = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
message = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
severity = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
when = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
condition = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)
params = Column("params", JSON)
run = Column(Integer, ForeignKey("runs.id"))
sent_time = Column(
TIMESTAMP(),
nullable=True,
)
status = Column(
String(255, collation=SQLCollationUtil.collation()), nullable=False
)

class Run(Base, mlrun.utils.db.HasStruct):
__tablename__ = "runs"
__table_args__ = (
Expand All @@ -193,6 +199,7 @@ class Run(Base, mlrun.utils.db.HasStruct):

Label = make_label(__tablename__)
Tag = make_tag(__tablename__)
Notification = make_notification(__tablename__)

id = Column(Integer, primary_key=True)
uid = Column(String(255, collation=SQLCollationUtil.collation()))
Expand Down Expand Up @@ -461,5 +468,8 @@ class DataVersion(Base, mlrun.utils.db.BaseModel):
# Must be after all table definitions
_tagged = [cls for cls in Base.__subclasses__() if hasattr(cls, "Tag")]
_labeled = [cls for cls in Base.__subclasses__() if hasattr(cls, "Label")]
_with_notifications = [
cls for cls in Base.__subclasses__() if hasattr(cls, "Notification")
]
_classes = [cls for cls in Base.__subclasses__()]
_table2cls = {cls.__table__.name: cls for cls in Base.__subclasses__()}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"notifications",
"runs_notifications",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("project", sa.String(length=255, collation="utf8_bin")),
sa.Column("name", sa.String(length=255, collation="utf8_bin"), nullable=False),
Expand All @@ -49,17 +49,19 @@ def upgrade():
"condition", sa.String(length=255, collation="utf8_bin"), nullable=False
),
sa.Column("params", sa.JSON(), nullable=True),
sa.Column("run", sa.Integer(), nullable=True),
# A generic parent_id rather than run_id since notification table is standard across objects, see the
# make_notification function for its definition and usage.
sa.Column("parent_id", sa.Integer(), nullable=True),
sa.Column("sent_time", mysql.TIMESTAMP(fsp=3), nullable=True),
sa.Column(
"status", sa.String(length=255, collation="utf8_bin"), nullable=False
),
sa.ForeignKeyConstraint(
["run"],
["parent_id"],
["runs.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name", "run", name="_notifications_uc"),
sa.UniqueConstraint("name", "parent_id", name="_runs_notifications_uc"),
)
# ### end Alembic commands ###

Expand Down

0 comments on commit 2b73d4d

Please sign in to comment.