Skip to content

Commit

Permalink
[API] Add partition-by option to list APIs for feature-sets and feature-vectors (#806)
Browse files Browse the repository at this point in the history
  • Loading branch information
theSaarco committed Apr 9, 2021
1 parent 91f5a82 commit ba52a71
Show file tree
Hide file tree
Showing 14 changed files with 359 additions and 4 deletions.
36 changes: 34 additions & 2 deletions mlrun/api/api/endpoints/feature_sets.py
Expand Up @@ -111,10 +111,27 @@ def list_feature_sets(
entities: List[str] = Query(None, alias="entity"),
features: List[str] = Query(None, alias="feature"),
labels: List[str] = Query(None, alias="label"),
partition_by: schemas.FeatureStorePartitionByField = Query(
None, alias="partition-by"
),
rows_per_partition: int = Query(1, alias="rows-per-partition", gt=0),
sort: schemas.SortField = Query(None, alias="partition-sort-by"),
order: schemas.OrderType = Query(schemas.OrderType.desc, alias="partition-order"),
db_session: Session = Depends(deps.get_db_session),
):
feature_sets = get_db().list_feature_sets(
db_session, project, name, tag, state, entities, features, labels
db_session,
project,
name,
tag,
state,
entities,
features,
labels,
partition_by,
rows_per_partition,
sort,
order,
)

return feature_sets
Expand Down Expand Up @@ -190,10 +207,25 @@ def list_feature_vectors(
state: str = None,
tag: str = None,
labels: List[str] = Query(None, alias="label"),
partition_by: schemas.FeatureStorePartitionByField = Query(
None, alias="partition-by"
),
rows_per_partition: int = Query(1, alias="rows-per-partition", gt=0),
sort: schemas.SortField = Query(None, alias="partition-sort-by"),
order: schemas.OrderType = Query(schemas.OrderType.desc, alias="partition-order"),
db_session: Session = Depends(deps.get_db_session),
):
feature_vectors = get_db().list_feature_vectors(
db_session, project, name, tag, state, labels
db_session,
project,
name,
tag,
state,
labels,
partition_by,
rows_per_partition,
sort,
order,
)

return feature_vectors
Expand Down
8 changes: 8 additions & 0 deletions mlrun/api/db/base.py
Expand Up @@ -286,6 +286,10 @@ def list_feature_sets(
entities: List[str] = None,
features: List[str] = None,
labels: List[str] = None,
partition_by: schemas.FeatureStorePartitionByField = None,
rows_per_partition: int = 1,
partition_sort: schemas.SortField = None,
partition_order: schemas.OrderType = schemas.OrderType.desc,
) -> schemas.FeatureSetsOutput:
pass

Expand Down Expand Up @@ -327,6 +331,10 @@ def list_feature_vectors(
tag: str = None,
state: str = None,
labels: List[str] = None,
partition_by: schemas.FeatureStorePartitionByField = None,
rows_per_partition: int = 1,
partition_sort_by: schemas.SortField = None,
partition_order: schemas.OrderType = schemas.OrderType.desc,
) -> schemas.FeatureVectorsOutput:
pass

Expand Down
8 changes: 8 additions & 0 deletions mlrun/api/db/filedb/db.py
Expand Up @@ -227,6 +227,10 @@ def list_feature_sets(
entities: List[str] = None,
features: List[str] = None,
labels: List[str] = None,
partition_by: schemas.FeatureStorePartitionByField = None,
rows_per_partition: int = 1,
partition_sort: schemas.SortField = None,
partition_order: schemas.OrderType = schemas.OrderType.desc,
) -> schemas.FeatureSetsOutput:
raise NotImplementedError()

Expand Down Expand Up @@ -263,6 +267,10 @@ def list_feature_vectors(
tag: str = None,
state: str = None,
labels: List[str] = None,
partition_by: schemas.FeatureStorePartitionByField = None,
rows_per_partition: int = 1,
partition_sort_by: schemas.SortField = None,
partition_order: schemas.OrderType = schemas.OrderType.desc,
) -> schemas.FeatureVectorsOutput:
raise NotImplementedError()

Expand Down
62 changes: 61 additions & 1 deletion mlrun/api/db/sqldb/db.py
Expand Up @@ -8,7 +8,7 @@
import pytz
from sqlalchemy import and_, distinct, func, or_
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import Session
from sqlalchemy.orm import Session, aliased

import mlrun.api.utils.projects.remotes.member
import mlrun.errors
Expand Down Expand Up @@ -1323,6 +1323,36 @@ def list_entities(
)
return schemas.EntitiesOutput(entities=entities_results)

@staticmethod
def _assert_partition_by_parameters(partition_by, sort):
    """Validate the partition-by arguments of a feature-store list query.

    Raises ``MLRunInvalidArgumentError`` when no sort field accompanies the
    partition request, or when an unsupported partition field is given.
    """
    error_message = None
    if sort is None:
        # Partitioning without an explicit sort order is ambiguous - reject it.
        error_message = "sort parameter must be provided when partition_by is used."
    elif partition_by != schemas.FeatureStorePartitionByField.name:
        # For now, name is the only supported value. Remove once more fields are added.
        error_message = f"partition_by for feature-store objects must be 'name'. Value given: '{partition_by.value}'"
    if error_message:
        raise mlrun.errors.MLRunInvalidArgumentError(error_message)

@staticmethod
def _create_partitioned_query(session, query, cls, group_by, order, rows_per_group):
    """Wrap ``query`` so at most ``rows_per_group`` rows survive per partition.

    A ``row_number()`` window function numbers rows within each partition
    (partitioned on the group-by DB field, ordered by ``cls.updated``), and
    the outer query keeps only the first ``rows_per_group`` of each group.
    """
    partition_field = group_by.to_partition_by_db_field(cls)
    sort_predicate = order.to_order_by_predicate(cls.updated)
    window = func.row_number().over(
        partition_by=partition_field, order_by=sort_predicate
    )
    row_number_column = window.label("row_number")

    # row_number is a window function (uses over()), so it cannot be filtered
    # directly in a WHERE clause - wrap it in a subquery and filter outside.
    inner = query.add_column(row_number_column).subquery()
    return session.query(aliased(cls, inner)).filter(
        inner.c.row_number <= rows_per_group
    )

def list_feature_sets(
self,
session,
Expand All @@ -1333,6 +1363,10 @@ def list_feature_sets(
entities: List[str] = None,
features: List[str] = None,
labels: List[str] = None,
partition_by: schemas.FeatureStorePartitionByField = None,
rows_per_partition: int = 1,
partition_sort: schemas.SortField = None,
partition_order: schemas.OrderType = schemas.OrderType.desc,
) -> schemas.FeatureSetsOutput:
obj_id_tags = self._get_records_to_tags_map(
session, FeatureSet, project, tag, name
Expand All @@ -1352,6 +1386,17 @@ def list_feature_sets(
if labels:
query = self._add_labels_filter(session, query, FeatureSet, labels)

if partition_by:
self._assert_partition_by_parameters(partition_by, partition_sort)
query = self._create_partitioned_query(
session,
query,
FeatureSet,
partition_by,
partition_order,
rows_per_partition,
)

feature_sets = []
for feature_set_record in query:
feature_sets.extend(
Expand Down Expand Up @@ -1712,6 +1757,10 @@ def list_feature_vectors(
tag: str = None,
state: str = None,
labels: List[str] = None,
partition_by: schemas.FeatureStorePartitionByField = None,
rows_per_partition: int = 1,
partition_sort_by: schemas.SortField = None,
partition_order: schemas.OrderType = schemas.OrderType.desc,
) -> schemas.FeatureVectorsOutput:
obj_id_tags = self._get_records_to_tags_map(
session, FeatureVector, project, tag, name
Expand All @@ -1727,6 +1776,17 @@ def list_feature_vectors(
if labels:
query = self._add_labels_filter(session, query, FeatureVector, labels)

if partition_by:
self._assert_partition_by_parameters(partition_by, partition_sort_by)
query = self._create_partitioned_query(
session,
query,
FeatureVector,
partition_by,
partition_order,
rows_per_partition,
)

feature_vectors = []
for feature_vector_record in query:
feature_vectors.extend(
Expand Down
10 changes: 9 additions & 1 deletion mlrun/api/schemas/__init__.py
Expand Up @@ -8,7 +8,15 @@
BackgroundTaskState,
BackgroundTaskStatus,
)
from .constants import DeletionStrategy, Format, HeaderNames, PatchMode
from .constants import (
DeletionStrategy,
FeatureStorePartitionByField,
Format,
HeaderNames,
OrderType,
PatchMode,
SortField,
)
from .feature_store import (
EntitiesOutput,
Entity,
Expand Down
28 changes: 28 additions & 0 deletions mlrun/api/schemas/constants.py
Expand Up @@ -55,6 +55,34 @@ class HeaderNames:
secret_store_token = f"{headers_prefix}secret-store-token"


class FeatureStorePartitionByField(str, Enum):
    """Fields by which feature-store list results may be partitioned."""

    # Supported for feature-store objects
    name = "name"

    def to_partition_by_db_field(self, db_cls):
        """Map this enum value to the matching column attribute on ``db_cls``."""
        if self.value != FeatureStorePartitionByField.name:
            raise mlrun.errors.MLRunInvalidArgumentError(
                f"Unknown group by field: {self.value}"
            )
        return db_cls.name


class SortField(str, Enum):
    """Fields that partitioned list queries may sort by.

    For now, we only support sorting by the ``updated`` field.
    """

    updated = "updated"


class OrderType(str, Enum):
    """Sort direction (ascending/descending) for partitioned list queries."""

    asc = "asc"
    desc = "desc"

    def to_order_by_predicate(self, db_field):
        """Return ``db_field`` wrapped with this direction's ordering call."""
        return db_field.asc() if self.value == OrderType.asc else db_field.desc()


labels_prefix = "mlrun/"


Expand Down
8 changes: 8 additions & 0 deletions mlrun/db/base.py
Expand Up @@ -204,6 +204,10 @@ def list_feature_sets(
entities: List[str] = None,
features: List[str] = None,
labels: List[str] = None,
partition_by: Union[schemas.FeatureStorePartitionByField, str] = None,
rows_per_partition: int = 1,
partition_sort_by: Union[schemas.SortField, str] = None,
partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc,
) -> List[dict]:
pass

Expand Down Expand Up @@ -258,6 +262,10 @@ def list_feature_vectors(
tag: str = None,
state: str = None,
labels: List[str] = None,
partition_by: Union[schemas.FeatureStorePartitionByField, str] = None,
rows_per_partition: int = 1,
partition_sort_by: Union[schemas.SortField, str] = None,
partition_order: Union[schemas.OrderType, str] = schemas.OrderType.desc,
) -> List[dict]:
pass

Expand Down
8 changes: 8 additions & 0 deletions mlrun/db/filedb.py
Expand Up @@ -553,6 +553,10 @@ def list_feature_sets(
entities: List[str] = None,
features: List[str] = None,
labels: List[str] = None,
partition_by: str = None,
rows_per_partition: int = 1,
partition_sort_by: str = None,
partition_order: str = "desc",
):
raise NotImplementedError()

Expand Down Expand Up @@ -584,6 +588,10 @@ def list_feature_vectors(
tag: str = None,
state: str = None,
labels: List[str] = None,
partition_by: str = None,
rows_per_partition: int = 1,
partition_sort_by: str = None,
partition_order: str = "desc",
) -> List[dict]:
raise NotImplementedError()

Expand Down

0 comments on commit ba52a71

Please sign in to comment.