[Feature Store] Fix additional filters CSVSource bug #5639

Merged
8 changes: 8 additions & 0 deletions mlrun/feature_store/feature_vector.py
@@ -741,6 +741,7 @@ def get_offline_features(
    order_by: Union[str, list[str]] = None,
    spark_service: str = None,
    timestamp_for_filtering: Union[str, dict[str, str]] = None,
    additional_filters: list = None,
):
    """retrieve offline feature vector results

@@ -797,6 +798,12 @@ def get_offline_features(
        By default, the filter executes on the timestamp_key of each feature set.
        Note: the time filtering is performed on each feature set before the
        merge process using start_time and end_time params.
    :param additional_filters: List of additional filter conditions, each given as a tuple.
        Each tuple should be in the format (column_name, operator, value).
        Supported operators: "=", ">=", "<=", ">", "<".
        Example: [("Product", "=", "Computer")]
        For all supported filters, please see:
        https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetDataset.html

    """

@@ -817,6 +824,7 @@
        order_by,
        spark_service,
        timestamp_for_filtering,
        additional_filters,
    )

def get_online_feature_service(
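
For context, a call that exercises the new parameter might look like the sketch below; the vector name and feature references are hypothetical placeholders, and the filter tuples follow the format documented in the docstring above.

import mlrun.feature_store as fstore

# Hypothetical feature vector; the name and feature references are
# placeholders for illustration only.
vector = fstore.FeatureVector("sales-vector", ["sales-set.*"])

# additional_filters is forwarded to the underlying Parquet reader,
# so rows are filtered before the feature-set merge step.
resp = vector.get_offline_features(
    additional_filters=[("Product", "=", "Computer"), ("Price", ">=", 100)],
)
df = resp.to_dataframe()
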
11 changes: 10 additions & 1 deletion mlrun/feature_store/retrieval/spark_merger.py
@@ -17,7 +17,9 @@
import semver

import mlrun
from mlrun.datastore.sources import ParquetSource
from mlrun.datastore.targets import get_offline_target
from mlrun.utils.helpers import additional_filters_warning

from ...runtimes import RemoteSparkRuntime
from ...runtimes.sparkjob import Spark3Runtime
@@ -254,13 +256,20 @@ def _get_engine_df(
    # handling case where there are multiple feature sets and user creates vector where
    # entity_timestamp_column is from a specific feature set (can't be entity timestamp)
    source_driver = mlrun.datastore.sources.source_kind_to_driver[source_kind]

    if source_driver != ParquetSource:
        additional_filters_warning(additional_filters, source_driver)
        additional_filters = None
    additional_filters_dict = (
        {"additional_filters": additional_filters} if additional_filters else {}
    )
    source = source_driver(
        name=self.vector.metadata.name,
        path=source_path,
        time_field=time_column,
        start_time=start_time,
        end_time=end_time,
-       additional_filters=additional_filters,
+       **additional_filters_dict,
        **source_kwargs,
    )

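One plausible reading of this change: before the fix, additional_filters was passed to every source constructor unconditionally, which fails for non-Parquet drivers such as CSVSource that do not handle the keyword; the conditional dict unpacking only forwards it when filters are actually set. A minimal sketch of that failure mode, using a stand-in class rather than mlrun code:

# Stand-in class, not mlrun code: a CSV-like source whose constructor
# does not accept an additional_filters keyword.
class CsvLikeSource:
    def __init__(self, path):
        self.path = path

filters = None

# Passing the keyword unconditionally raises TypeError here:
# CsvLikeSource(path="data.csv", additional_filters=filters)

# Unpacking an optionally-empty dict forwards the keyword only when set:
kwargs = {"additional_filters": filters} if filters else {}
source = CsvLikeSource(path="data.csv", **kwargs)  # works
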
3 changes: 3 additions & 0 deletions tests/feature-store/test_featurevec.py
@@ -70,6 +70,7 @@ def test_get_offline_features(mock_get_offline_features):
    test_order_by = "col1"
    test_spark_service = "test_spark_service"
    test_timestamp_for_filtering = {"col1": "2021-01-01"}
    additional_filters = [("x", "=", 3)]

    fv.get_offline_features(
        entity_rows=test_entity_rows,
@@ -87,6 +88,7 @@
        order_by=test_order_by,
        spark_service=test_spark_service,
        timestamp_for_filtering=test_timestamp_for_filtering,
        additional_filters=additional_filters,
    )
    mock_get_offline_features.assert_called_once_with(
        fv,
@@ -105,4 +107,5 @@
        test_order_by,
        test_spark_service,
        test_timestamp_for_filtering,
        additional_filters,
    )
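
The assertion works because get_offline_features forwards its arguments positionally to the patched implementation; a reduced, self-contained sketch of that pattern (names are illustrative, not the real call chain):

from unittest.mock import MagicMock

mock_impl = MagicMock()

def get_offline_features(vector, additional_filters=None):
    # Forward positionally, mirroring what the test asserts on.
    mock_impl(vector, additional_filters)

get_offline_features("my-vector", additional_filters=[("x", "=", 3)])
mock_impl.assert_called_once_with("my-vector", [("x", "=", 3)])
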
21 changes: 21 additions & 0 deletions tests/system/feature_store/test_spark_engine.py
@@ -417,6 +417,27 @@ def test_parquet_filters(self, passthrough):
    result = sort_df(result, ["patient_id"])
    assert_frame_equal(result, expected, check_dtype=False)

    target = ParquetTarget(
        "vector_target", path=f"{self.output_dir()}-get_offline_features_by_vector"
    )
    # check the vector's get_offline_features method:
    resp = vec.get_offline_features(
        additional_filters=[
            ("bad", "not in", [38, 100]),
            ("movements", "<", 6),
        ],
        with_indexes=True,
        target=target,
        engine="spark",
        run_config=fstore.RunConfig(local=self.run_local, kind=kind),
        spark_service=self.spark_service,
    )

    result = resp.to_dataframe()
    result.reset_index(drop=False, inplace=True)
    result = sort_df(result, ["patient_id"])
    assert_frame_equal(result, expected, check_dtype=False)

def test_basic_remote_spark_ingest_csv(self):
    key = "patient_id"
    name = "measurements"
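
The operators used in the test ("not in", "<") follow pyarrow's Parquet filter syntax, which the new docstring links to; every tuple must hold for a row to pass (AND semantics). A self-contained sketch of the equivalent pure-pyarrow read, with hypothetical data and file name:

import pandas as pd
import pyarrow.parquet as pq

# Hypothetical data and file name, written so the read below is runnable.
df = pd.DataFrame(
    {"patient_id": [1, 2, 3], "bad": [38, 10, 100], "movements": [2, 4, 7]}
)
df.to_parquet("measurements.parquet")

# The same filters the test passes via additional_filters, applied
# directly by pyarrow; only rows satisfying every tuple are read.
table = pq.read_table(
    "measurements.parquet",
    filters=[("bad", "not in", [38, 100]), ("movements", "<", 6)],
)
print(table.to_pandas())  # keeps only the patient_id == 2 row
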