support filter for storey engine
davesh0812 committed May 11, 2022
1 parent ea7438f commit a7f10ee
Showing 5 changed files with 18 additions and 45 deletions.
14 changes: 4 additions & 10 deletions mlrun/feature_store/api.py
@@ -98,6 +98,7 @@ def get_offline_features(
     update_stats: bool = False,
     engine: str = None,
     engine_args: dict = None,
+    filter: str = None
 ) -> OfflineVectorResponse:
     """retrieve offline feature vector results
@@ -177,16 +178,9 @@ def get_offline_features(
         end_time = pd.Timestamp.now()
     merger_engine = get_merger(engine)
     merger = merger_engine(feature_vector, **(engine_args or {}))
-    return merger.start(
-        entity_rows,
-        entity_timestamp_column,
-        target=target,
-        drop_columns=drop_columns,
-        start_time=start_time,
-        end_time=end_time,
-        with_indexes=with_indexes,
-        update_stats=update_stats,
-    )
+    return merger.start(entity_rows, entity_timestamp_column, target=target, drop_columns=drop_columns,
+                        start_time=start_time, end_time=end_time, with_indexes=with_indexes, update_stats=update_stats,
+                        filter=filter)
 
 
 def get_online_feature_service(
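
For orientation, a minimal caller-side sketch of the new filter argument. The vector URI and column names are hypothetical, and the expression syntax follows the pandas-query form that the local engine applies (see local_merger.py below):

# Hypothetical usage sketch; the vector URI and column names are illustrative only.
import mlrun.feature_store as fstore

resp = fstore.get_offline_features(
    "store://feature-vectors/my-project/my-vector",  # assumed vector URI
    filter="age > 30 and country == 'US'",  # evaluated against the merged result
)
df = resp.to_dataframe()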
13 changes: 5 additions & 8 deletions mlrun/feature_store/retrieval/base.py
@@ -39,6 +39,7 @@ def start(
         end_time=None,
         with_indexes=None,
         update_stats=None,
+        filter=None
     ):
         self._target = target
 
@@ -69,14 +70,9 @@ def start(
             for key in feature_set.spec.entities.keys():
                 self._append_index(key)
 
-        return self._generate_vector(
-            entity_rows,
-            entity_timestamp_column,
-            feature_set_objects=feature_set_objects,
-            feature_set_fields=feature_set_fields,
-            start_time=start_time,
-            end_time=end_time,
-        )
+        return self._generate_vector(entity_rows, entity_timestamp_column, feature_set_objects=feature_set_objects,
+                                     feature_set_fields=feature_set_fields, start_time=start_time, end_time=end_time,
+                                     filter=filter)
 
     def _write_to_target(self):
         if self._target:
@@ -118,6 +114,7 @@ def _generate_vector(
         feature_set_fields,
         start_time=None,
         end_time=None,
+        filter=None
     ):
         raise NotImplementedError("_generate_vector() operation not supported in class")
 
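
The change keeps the merger's template-method shape: start() does the shared setup and forwards filter to _generate_vector(), which each engine overrides. A stripped-down sketch of that contract (simplified; the real BaseMerger does considerably more):

# Simplified sketch of the pass-through contract, not the full BaseMerger.
class BaseMerger:
    def start(self, entity_rows, entity_timestamp_column, filter=None, **kwargs):
        # shared setup (resolving feature sets, appending index columns) elided
        return self._generate_vector(
            entity_rows, entity_timestamp_column, filter=filter, **kwargs
        )

    def _generate_vector(self, entity_rows, entity_timestamp_column, filter=None, **kwargs):
        # engine-specific merge logic lives in the subclasses
        raise NotImplementedError("_generate_vector() operation not supported in class")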
11 changes: 2 additions & 9 deletions mlrun/feature_store/retrieval/dask_merger.py
@@ -15,15 +15,8 @@ def __init__(self, vector, **engine_args):
         self.client = engine_args.get("dask_client")
         self._dask_cluster_uri = engine_args.get("dask_cluster_uri")
 
-    def _generate_vector(
-        self,
-        entity_rows,
-        entity_timestamp_column,
-        feature_set_objects,
-        feature_set_fields,
-        start_time=None,
-        end_time=None,
-    ):
+    def _generate_vector(self, entity_rows, entity_timestamp_column, feature_set_objects, feature_set_fields,
+                         start_time=None, end_time=None, filter=None):
         # init the dask client if needed
         if not self.client:
             if self._dask_cluster_uri:
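
In the visible hunk the dask merger only gains the filter parameter; where it is applied is not shown. If it mirrors the local engine, dask's pandas-compatible query() would be the natural fit, as in this illustrative sketch (not taken from the commit):

# Illustrative only: dask.dataframe exposes a pandas-style query().
import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"age": [25, 40, 31], "country": ["US", "DE", "US"]})
ddf = dd.from_pandas(pdf, npartitions=1)
print(ddf.query("age > 30 and country == 'US'").compute())  # one matching row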
14 changes: 5 additions & 9 deletions mlrun/feature_store/retrieval/local_merger.py
@@ -22,15 +22,8 @@ class LocalFeatureMerger(BaseMerger):
     def __init__(self, vector, **engine_args):
         super().__init__(vector, **engine_args)
 
-    def _generate_vector(
-        self,
-        entity_rows,
-        entity_timestamp_column,
-        feature_set_objects,
-        feature_set_fields,
-        start_time=None,
-        end_time=None,
-    ):
+    def _generate_vector(self, entity_rows, entity_timestamp_column, feature_set_objects, feature_set_fields,
+                         start_time=None, end_time=None, filter=None):
 
         feature_sets = []
         dfs = []
@@ -70,6 +63,9 @@ def _generate_vector(
                 subset=[self.vector.status.label_column]
             )
 
+        if filter:
+            self._result_df.query(filter, inplace=True)
+
         if self._drop_indexes:
             self._result_df.reset_index(drop=True, inplace=True)
         self._write_to_target()
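
This hunk is the heart of the change for the local engine: the filter string is handed to pandas DataFrame.query(), so it must be a standard pandas boolean expression over the merged result's columns. A self-contained illustration (the column names are made up):

# pandas query() keeps only the rows where the expression evaluates to True.
import pandas as pd

df = pd.DataFrame({"age": [25, 40, 31], "country": ["US", "DE", "US"]})
df.query("age > 30 and country == 'US'", inplace=True)
print(df)  # single remaining row: age=31, country='US'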
11 changes: 2 additions & 9 deletions mlrun/feature_store/retrieval/spark_merger.py
@@ -15,15 +15,8 @@ def __init__(self, vector, **engine_args):
     def to_spark_df(self, session, path):
         return session.read.load(path)
 
-    def _generate_vector(
-        self,
-        entity_rows,
-        entity_timestamp_column,
-        feature_set_objects,
-        feature_set_fields,
-        start_time=None,
-        end_time=None,
-    ):
+    def _generate_vector(self, entity_rows, entity_timestamp_column, feature_set_objects, feature_set_fields,
+                         start_time=None, end_time=None, filter=None):
         from pyspark.sql import SparkSession
         from pyspark.sql.functions import col
 
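
As with the dask merger, only the signature change is visible here; the hunk does not show where filter is applied in the Spark path. A plausible counterpart, assuming Spark SQL expression syntax rather than pandas syntax, is DataFrame.filter() (illustrative, not from the commit):

# Assumption: a SQL-expression filter via pyspark's DataFrame.filter().
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[1]").getOrCreate()
df = spark.createDataFrame([(25, "US"), (40, "DE"), (31, "US")], ["age", "country"])
df.filter("age > 30 AND country = 'US'").show()  # one matching row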
