Skip to content

Commit

Permalink
[SQLDB] Refactor partition querying to minimize subquery table size […
Browse files Browse the repository at this point in the history
…1.6.x] (#5540)
  • Loading branch information
alonmr committed May 12, 2024
1 parent 270fcd5 commit 5223d5d
Showing 1 changed file with 15 additions and 12 deletions.
27 changes: 15 additions & 12 deletions server/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import pytz
from sqlalchemy import MetaData, and_, distinct, func, or_, text
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session, aliased
from sqlalchemy.orm import Session

import mlrun
import mlrun.common.schemas
Expand Down Expand Up @@ -2848,6 +2848,10 @@ def _create_partitioned_query(
)
.label("row_number")
)

# Retrieve only the ID from the subquery to minimize the inner table,
# in the final step we inner join the inner table with the full table.
query = query.with_entities(cls.id).add_column(row_number_column)
if max_partitions > 0:
max_partition_value = (
func.max(sort_by_field)
Expand All @@ -2860,18 +2864,15 @@ def _create_partitioned_query(

# Need to generate a subquery so we can filter based on the row_number, since it
# is a window function using over().
subquery = query.add_column(row_number_column).subquery()

subquery = query.subquery()
if max_partitions == 0:
# If we don't query on max-partitions, we end here. Need to alias the subquery so that the ORM will
# be able to properly map it to objects.
result_query = session.query(aliased(cls, subquery)).filter(
subquery.c.row_number <= rows_per_partition
result_query = (
session.query(cls)
.join(subquery, cls.id == subquery.c.id)
.filter(subquery.c.row_number <= rows_per_partition)
)
return result_query

# Otherwise no need for an alias, as this is an internal query and will be wrapped by another one where
# alias will apply. We just apply the filter here.
result_query = session.query(subquery).filter(
subquery.c.row_number <= rows_per_partition
)
Expand All @@ -2883,9 +2884,11 @@ def _create_partitioned_query(
.over(order_by=subquery.c.max_partition_value.desc())
.label("partition_rank")
)
result_query = result_query.add_column(partition_rank).subquery()
result_query = session.query(aliased(cls, result_query)).filter(
result_query.c.partition_rank <= max_partitions
subquery = result_query.add_column(partition_rank).subquery()
result_query = (
session.query(cls)
.join(subquery, cls.id == subquery.c.id)
.filter(subquery.c.partition_rank <= max_partitions)
)
return result_query

Expand Down

0 comments on commit 5223d5d

Please sign in to comment.