Skip to content

Commit

Permalink
Fix cascading behavior for trace tag table (mlflow#12102)
Browse files Browse the repository at this point in the history
Signed-off-by: B-Step62 <yuki.watanabe@databricks.com>
  • Loading branch information
B-Step62 committed May 22, 2024
1 parent 30d9c3f commit 3d86e5f
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 13 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""add cascade deletion to trace tables foreign keys
Revision ID: 5b0e9adcef9c
Revises: 867495a8f9d4
Create Date: 2024-05-22 17:44:24.597019
"""
from alembic import op
from mlflow.store.tracking.dbmodels.models import SqlTraceInfo, SqlTraceRequestMetadata, SqlTraceTag


# revision identifiers, used by Alembic.
revision = '5b0e9adcef9c'
down_revision = '867495a8f9d4'
branch_labels = None
depends_on = None


def upgrade():
tables = [SqlTraceTag.__tablename__, SqlTraceRequestMetadata.__tablename__]
for table in tables:
fk_tag_constaint_name = f"fk_{table}_request_id"
# We have to use batch_alter_table as SQLite does not support ALTER outside of a batch operation.
with op.batch_alter_table(table, schema=None) as batch_op:
batch_op.drop_constraint(fk_tag_constaint_name, type_="foreignkey")
batch_op.create_foreign_key(
fk_tag_constaint_name,
SqlTraceInfo.__tablename__,
["request_id"],
["request_id"],
# Add cascade deletion to the foreign key constraint. This is the only change in this migration.
ondelete="CASCADE",
)


def downgrade():
pass
8 changes: 6 additions & 2 deletions mlflow/store/tracking/dbmodels/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,7 +706,9 @@ class SqlTraceTag(Base):
"""
Value associated with tag: `String` (limit 250 characters). Could be *null*.
"""
request_id = Column(String(50), ForeignKey("trace_info.request_id"), nullable=False)
request_id = Column(
String(50), ForeignKey("trace_info.request_id", ondelete="CASCADE"), nullable=False
)
"""
Request ID to which this tag belongs: *Foreign Key* into ``trace_info`` table.
"""
Expand Down Expand Up @@ -734,7 +736,9 @@ class SqlTraceRequestMetadata(Base):
"""
Value associated with metadata: `String` (limit 250 characters). Could be *null*.
"""
request_id = Column(String(50), ForeignKey("trace_info.request_id"), nullable=False)
request_id = Column(
String(50), ForeignKey("trace_info.request_id", ondelete="CASCADE"), nullable=False
)
"""
Request ID to which this metadata belongs: *Foreign Key* into ``trace_info`` table.
"""
Expand Down
2 changes: 1 addition & 1 deletion tests/db/schemas/mssql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ CREATE TABLE trace_request_metadata (
value VARCHAR(8000) COLLATE "SQL_Latin1_General_CP1_CI_AS",
request_id VARCHAR(50) COLLATE "SQL_Latin1_General_CP1_CI_AS" NOT NULL,
CONSTRAINT trace_request_metadata_pk PRIMARY KEY (key, request_id),
CONSTRAINT fk_trace_request_metadata_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id)
CONSTRAINT fk_trace_request_metadata_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id) ON DELETE CASCADE
)


Expand Down
4 changes: 2 additions & 2 deletions tests/db/schemas/mysql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ CREATE TABLE trace_request_metadata (
value VARCHAR(8000),
request_id VARCHAR(50) NOT NULL,
CONSTRAINT trace_request_metadata_pk PRIMARY KEY (key, request_id),
CONSTRAINT fk_trace_request_metadata_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id)
CONSTRAINT fk_trace_request_metadata_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id) ON DELETE CASCADE
)


Expand All @@ -207,5 +207,5 @@ CREATE TABLE trace_tags (
value VARCHAR(8000),
request_id VARCHAR(50) NOT NULL,
CONSTRAINT trace_tag_pk PRIMARY KEY (key, request_id),
CONSTRAINT fk_trace_tags_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id)
CONSTRAINT fk_trace_tags_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id) ON DELETE CASCADE
)
4 changes: 2 additions & 2 deletions tests/db/schemas/postgresql.sql
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ CREATE TABLE trace_request_metadata (
value VARCHAR(8000),
request_id VARCHAR(50) NOT NULL,
CONSTRAINT trace_request_metadata_pk PRIMARY KEY (key, request_id),
CONSTRAINT fk_trace_request_metadata_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id)
CONSTRAINT fk_trace_request_metadata_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id) ON DELETE CASCADE
)


Expand All @@ -205,5 +205,5 @@ CREATE TABLE trace_tags (
value VARCHAR(8000),
request_id VARCHAR(50) NOT NULL,
CONSTRAINT trace_tag_pk PRIMARY KEY (key, request_id),
CONSTRAINT fk_trace_tags_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id)
CONSTRAINT fk_trace_tags_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id) ON DELETE CASCADE
)
4 changes: 2 additions & 2 deletions tests/db/schemas/sqlite.sql
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ CREATE TABLE trace_request_metadata (
value VARCHAR(8000),
request_id VARCHAR(50) NOT NULL,
CONSTRAINT trace_request_metadata_pk PRIMARY KEY (key, request_id),
CONSTRAINT fk_trace_request_metadata_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id)
CONSTRAINT fk_trace_request_metadata_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id) ON DELETE CASCADE
)


Expand All @@ -208,5 +208,5 @@ CREATE TABLE trace_tags (
value VARCHAR(8000),
request_id VARCHAR(50) NOT NULL,
CONSTRAINT trace_tag_pk PRIMARY KEY (key, request_id),
CONSTRAINT fk_trace_tags_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id)
CONSTRAINT fk_trace_tags_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id) ON DELETE CASCADE
)
4 changes: 2 additions & 2 deletions tests/resources/db/latest_schema.sql
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ CREATE TABLE trace_request_metadata (
value VARCHAR(8000),
request_id VARCHAR(50) NOT NULL,
CONSTRAINT trace_request_metadata_pk PRIMARY KEY (key, request_id),
CONSTRAINT fk_trace_request_metadata_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id)
CONSTRAINT fk_trace_request_metadata_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id) ON DELETE CASCADE
)


Expand All @@ -208,6 +208,6 @@ CREATE TABLE trace_tags (
value VARCHAR(8000),
request_id VARCHAR(50) NOT NULL,
CONSTRAINT trace_tag_pk PRIMARY KEY (key, request_id),
CONSTRAINT fk_trace_tags_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id)
CONSTRAINT fk_trace_tags_request_id FOREIGN KEY(request_id) REFERENCES trace_info (request_id) ON DELETE CASCADE
)

8 changes: 6 additions & 2 deletions tests/store/tracking/test_sqlalchemy_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4148,8 +4148,12 @@ def test_delete_traces(store):
exp2 = store.create_experiment("exp2")

for i in range(10):
_create_trace(store, f"tr-exp1-{i}", exp1)
_create_trace(store, f"tr-exp2-{i}", exp2)
_create_trace(
store, f"tr-exp1-{i}", exp1, tags={"tag": "apple"}, request_metadata={"rq": "foo"}
)
_create_trace(
store, f"tr-exp2-{i}", exp2, tags={"tag": "orange"}, request_metadata={"rq": "bar"}
)

traces, _ = store.search_traces([exp1, exp2])
assert len(traces) == 20
Expand Down

0 comments on commit 3d86e5f

Please sign in to comment.