[FeatureStore] Fix additional filters with passthrough (#5687)
tomerm-iguazio committed Jun 2, 2024
1 parent 02a043b commit bd29590
Showing 3 changed files with 39 additions and 16 deletions.
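
For context, a minimal sketch of the user-facing setup this commit fixes, pieced together from the tests in this diff; the names, path, and filter values are illustrative, and the retrieval call is omitted since it varies by engine and MLRun version:

# Illustrative sketch only; names and paths are not from the commit.
import mlrun.feature_store as fstore
from mlrun.datastore.sources import ParquetSource
from mlrun.datastore.targets import ParquetTarget

# A source with row-level filters applied when the parquet data is read.
parquet_source = ParquetSource(
    "myparquet",
    path="testdata.parquet",  # illustrative
    additional_filters=[("bad", "not in", [38, 100])],
)

# passthrough=True: offline retrieval reads straight from the source,
# so ingest writes no offline copy of the data.
feature_set = fstore.FeatureSet(
    "parquet-filters-fs",
    entities=[fstore.Entity("patient_id")],
    passthrough=True,
)
feature_set.ingest(source=parquet_source, targets=[ParquetTarget()])

# Retrieving features over this set goes through a merger that rebuilds
# the source from its stored attributes, which is the spot patched below.
vec = fstore.FeatureVector(name="test-fs-vec", features=["parquet-filters-fs.*"])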
1 change: 1 addition & 0 deletions mlrun/feature_store/retrieval/spark_merger.py
@@ -241,6 +241,7 @@ def _get_engine_df(
             source_kind = feature_set.spec.source.kind
             source_path = feature_set.spec.source.path
             source_kwargs.update(feature_set.spec.source.attributes)
+            source_kwargs.pop("additional_filters")
         else:
             target = get_offline_target(feature_set)
             if not target:
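Why the pop is needed is not stated in the commit, but one plausible failure mode, given that `additional_filters` sits in `feature_set.spec.source.attributes` and is copied into `source_kwargs` on the line above it, is a keyword collision when the merger rebuilds the source. A hypothetical sketch (the function and its signature are invented for illustration):

# Hypothetical sketch of the suspected failure mode; not the real merger code.
def rebuild_source(source_class, path, source_kwargs, additional_filters=None):
    # If "additional_filters" is still present in source_kwargs (copied from
    # the source attributes) while also being passed explicitly, Python raises
    # TypeError: got multiple values for keyword argument 'additional_filters'.
    # Popping it first, as the one-line fix above does, avoids the collision.
    source_kwargs.pop("additional_filters", None)
    return source_class(path=path, additional_filters=additional_filters, **source_kwargs)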
31 changes: 21 additions & 10 deletions tests/system/feature_store/test_feature_store.py
@@ -4820,7 +4820,12 @@ def test_merge_different_number_of_entities(self):
 
     @pytest.mark.parametrize("local", [True, False])
     @pytest.mark.parametrize("engine", ["local", "dask"])
-    def test_parquet_filters(self, engine, local):
+    @pytest.mark.parametrize("passthrough", [True, False])
+    def test_parquet_filters(self, engine, local, passthrough):
+        if passthrough and engine == "dask":
+            pytest.skip(
+                "Dask engine with passthrough=True is not supported. Open issue ML-6684"
+            )
         config_parameters = {} if local else {"image": "mlrun/mlrun"}
         run_config = fstore.RunConfig(local=local, **config_parameters)
         parquet_path = os.path.relpath(str(self.assets_path / "testdata.parquet"))
@@ -4843,7 +4848,9 @@ def test_parquet_filters(self, engine, local):
             filtered_df.sort_values(by="patient_id").reset_index(drop=True),
         )
         feature_set = fstore.FeatureSet(
-            "parquet-filters-fs", entities=[fstore.Entity("patient_id")]
+            "parquet-filters-fs",
+            entities=[fstore.Entity("patient_id")],
+            passthrough=passthrough,
         )
 
         target = ParquetTarget(
@@ -4855,13 +4862,16 @@ def test_parquet_filters(self, engine, local):
         feature_set.ingest(
             source=parquet_source, targets=[target], run_config=run_config
         )
-        result = target.as_df(additional_filters=[("room", "=", 1)]).reset_index()
-        # We want to include patient_id in the comparison,
-        # sort the columns alphabetically, and sort the rows by patient_id values.
-        result = sort_df(result, "patient_id")
-        expected = sort_df(filtered_df.query("room == 1"), "patient_id")
-        # the content of category column is still checked:
-        assert_frame_equal(result, expected, check_dtype=False, check_categorical=False)
+        if not passthrough:
+            result = target.as_df(additional_filters=[("room", "=", 1)]).reset_index()
+            # We want to include patient_id in the comparison,
+            # sort the columns alphabetically, and sort the rows by patient_id values.
+            result = sort_df(result, "patient_id")
+            expected = sort_df(filtered_df.query("room == 1"), "patient_id")
+            # the content of category column is still checked:
+            assert_frame_equal(
+                result, expected, check_dtype=False, check_categorical=False
+            )
         vec = fstore.FeatureVector(
             name="test-fs-vec", features=["parquet-filters-fs.*"]
         )
@@ -4881,7 +4891,8 @@
             .to_dataframe()
             .reset_index()
         )
-        expected = sort_df(filtered_df.query("bad == 95"), "patient_id")
+        expected = df if passthrough else filtered_df
+        expected = sort_df(expected.query("bad == 95"), "patient_id")
         result = sort_df(result, "patient_id")
         assert_frame_equal(result, expected, check_dtype=False, check_categorical=False)
 
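A note for readers following along: `sort_df` is a test helper defined elsewhere in the suite, not part of this diff. Judging from the inline comment above ("sort the columns alphabetically, and sort the rows by patient_id values"), a hypothetical stand-in consistent with how it is called would be:

import pandas as pd

def sort_df(df: pd.DataFrame, sort_by) -> pd.DataFrame:
    # Hypothetical stand-in for the suite's helper: order the columns
    # alphabetically and the rows by the given key(s), resetting the index
    # so assert_frame_equal compares values rather than row positions.
    df = df[sorted(df.columns)]
    return df.sort_values(by=sort_by).reset_index(drop=True)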
23 changes: 17 additions & 6 deletions tests/system/feature_store/test_spark_engine.py
@@ -334,7 +334,8 @@ def test_basic_remote_spark_ingest(self):
         print(f"expected_stats_df: {expected_stats_df.to_json()}")
         assert stats_df.equals(expected_stats_df)
 
-    def test_parquet_filters(self):
+    @pytest.mark.parametrize("passthrough", [True, False])
+    def test_parquet_filters(self, passthrough):
         parquet_source_path = self.get_pq_source_path()
         source_file_name = "testdata_with_none.parquet"
         parquet_source_path = parquet_source_path.replace(
@@ -358,7 +359,10 @@ def test_parquet_filters(self):
             additional_filters=filters,
         )
         feature_set = fstore.FeatureSet(
-            "parquet-filters-fs", entities=[fstore.Entity("patient_id")], engine="spark"
+            "parquet-filters-fs",
+            entities=[fstore.Entity("patient_id")],
+            engine="spark",
+            passthrough=passthrough,
         )
 
         target = ParquetTarget(
@@ -372,9 +376,14 @@
             spark_context=self.spark_service,
             run_config=run_config,
         )
-        result = sort_df(pd.read_parquet(feature_set.get_target_path()), "patient_id")
-        expected = sort_df(filtered_df, "patient_id")
-        assert_frame_equal(result, expected, check_dtype=False, check_categorical=False)
+        if not passthrough:
+            result = sort_df(
+                pd.read_parquet(feature_set.get_target_path()), "patient_id"
+            )
+            expected = sort_df(filtered_df, "patient_id")
+            assert_frame_equal(
+                result, expected, check_dtype=False, check_categorical=False
+            )
 
         vec = fstore.FeatureVector(
             name="test-fs-vec", features=["parquet-filters-fs.*"]
@@ -399,8 +408,10 @@
 
         result = resp.to_dataframe()
         result.reset_index(drop=False, inplace=True)
+
+        expected = pd.read_parquet(parquet_source_path) if passthrough else filtered_df
         expected = sort_df(
-            filtered_df.query("bad not in [38,100] & movements < 6"),
+            expected.query("bad not in [38,100] & movements < 6"),
             ["patient_id"],
         )
         result = sort_df(result, ["patient_id"])
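Both test files end up asserting the same rule: with passthrough there is no offline target to read back, so the target-level check is skipped, and the expected frame for vector retrieval starts from the raw source data rather than the source-filtered frame, which suggests the source-level additional_filters are not re-applied on a passthrough read. Restated as a sketch, with `parquet_source_path`, `filtered_df`, and `passthrough` bound as in the test above:

import pandas as pd

# Expected frame for the retrieval assertion, as the updated tests build it:
# the raw source data when passthrough, else the source-filtered frame,
# with the same query the test applies on top.
expected = pd.read_parquet(parquet_source_path) if passthrough else filtered_df
expected = expected.query("bad not in [38,100] & movements < 6")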
