[FeatureStore] Spark read optimization (#5514)
tomerm-iguazio committed May 21, 2024
1 parent b96eb9b commit 2bc950e
Showing 8 changed files with 237 additions and 22 deletions.
117 changes: 111 additions & 6 deletions mlrun/datastore/sources.py
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import math
import operator
import os
import warnings
from base64 import b64encode
@@ -178,7 +180,7 @@ def __init__(
self,
name: str = "",
path: str = None,
attributes: dict[str, str] = None,
attributes: dict[str, object] = None,
key_field: str = None,
schedule: str = None,
parse_dates: Union[None, int, str, list[int], list[str]] = None,
@@ -305,14 +307,18 @@ def __init__(
self,
name: str = "",
path: str = None,
attributes: dict[str, str] = None,
attributes: dict[str, object] = None,
key_field: str = None,
time_field: str = None,
schedule: str = None,
start_time: Optional[Union[datetime, str]] = None,
end_time: Optional[Union[datetime, str]] = None,
additional_filters: Optional[list[tuple]] = None,
):
if additional_filters:
attributes = copy(attributes) or {}
attributes["additional_filters"] = additional_filters
self.validate_additional_filters(additional_filters)
super().__init__(
name,
path,
@@ -323,7 +329,6 @@ def __init__(
start_time,
end_time,
)
self.additional_filters = additional_filters

@property
def start_time(self):
@@ -341,6 +346,10 @@ def end_time(self):
def end_time(self, end_time):
self._end_time = self._convert_to_datetime(end_time)

@property
def additional_filters(self):
return self.attributes.get("additional_filters")

@staticmethod
def _convert_to_datetime(time):
if time and isinstance(time, str):
@@ -350,6 +359,25 @@ def _convert_to_datetime(time):
else:
return time

@staticmethod
def validate_additional_filters(additional_filters):
if not additional_filters:
return
for filter_tuple in additional_filters:
if not filter_tuple:
continue
col_name, op, value = filter_tuple
if isinstance(value, float) and math.isnan(value):
raise mlrun.errors.MLRunInvalidArgumentError(
"using NaN in additional_filters is not supported"
)
elif isinstance(value, (list, tuple, set)):
for sub_value in value:
if isinstance(sub_value, float) and math.isnan(sub_value):
raise mlrun.errors.MLRunInvalidArgumentError(
"using NaN in additional_filters is not supported"
)

def to_step(
self,
key_field=None,
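
The validation added above rejects NaN anywhere in additional_filters, presumably because NaN comparisons do not behave consistently across the pandas, Dask, and Spark engines. A minimal sketch of the resulting behavior, mirroring the unit test added in tests/datastore/test_sources.py below (the path argument is a placeholder):

import mlrun.errors
from mlrun.datastore import ParquetSource

# Passes validation: plain comparison and membership filters without NaN
source = ParquetSource(
    "parquet_source",
    path="path/to/file",  # placeholder path
    additional_filters=[("age", ">", 30), ("room", "in", [1, 2])],
)

# Raises immediately: NaN inside a filter tuple
try:
    ParquetSource(
        "parquet_source",
        path="path/to/file",
        additional_filters=[("age", "in", [10, float("nan")])],
    )
except mlrun.errors.MLRunInvalidArgumentError as err:
    print(err)  # "using NaN in additional_filters is not supported"
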
@@ -361,13 +389,12 @@
):
import storey

attributes = self.attributes or {}
attributes = copy(self.attributes)
attributes.pop("additional_filters", None)
if context:
attributes["context"] = context

data_item = mlrun.store_manager.object(self.path)
store, path, url = mlrun.store_manager.get_or_create_store(self.path)

return storey.ParquetSource(
paths=url, # unlike self.path, it already has store:// replaced
key_field=self.key_field or key_field,
@@ -412,6 +439,84 @@ def to_dataframe(
**reader_args,
)

def _build_spark_additional_filters(self, column_types: dict):
if not self.additional_filters:
return None
from pyspark.sql.functions import col, isnan, lit

operators = {
"==": operator.eq,
"=": operator.eq,
">": operator.gt,
"<": operator.lt,
">=": operator.ge,
"<=": operator.le,
"!=": operator.ne,
}

spark_filter = None
new_filter = lit(True)
for filter_tuple in self.additional_filters:
if not filter_tuple:
continue
col_name, op, value = filter_tuple
if op.lower() in ("in", "not in") and isinstance(value, (list, tuple, set)):
none_exists = False
value = list(value)
for sub_value in value:
if sub_value is None:
value.remove(sub_value)
none_exists = True
if none_exists:
filter_nan = column_types[col_name] not in ("timestamp", "date")
if value:
if op.lower() == "in":
new_filter = (
col(col_name).isin(value) | col(col_name).isNull()
)
if filter_nan:
new_filter = new_filter | isnan(col(col_name))

else:
new_filter = (
~col(col_name).isin(value) & ~col(col_name).isNull()
)
if filter_nan:
new_filter = new_filter & ~isnan(col(col_name))
else:
if op.lower() == "in":
new_filter = col(col_name).isNull()
if filter_nan:
new_filter = new_filter | isnan(col(col_name))
else:
new_filter = ~col(col_name).isNull()
if filter_nan:
new_filter = new_filter & ~isnan(col(col_name))
else:
if op.lower() == "in":
new_filter = col(col_name).isin(value)
elif op.lower() == "not in":
new_filter = ~col(col_name).isin(value)
elif op in operators:
new_filter = operators[op](col(col_name), value)
else:
raise mlrun.errors.MLRunInvalidArgumentError(
f"unsupported filter operator: {op}"
)
if spark_filter is not None:
spark_filter = spark_filter & new_filter
else:
spark_filter = new_filter
return spark_filter

def _filter_spark_df(self, df, time_field=None, columns=None):
spark_additional_filters = self._build_spark_additional_filters(
column_types=dict(df.dtypes)
)
if spark_additional_filters is not None:
df = df.filter(spark_additional_filters)
return super()._filter_spark_df(df=df, time_field=time_field, columns=columns)


class BigQuerySource(BaseSourceDriver):
"""
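Taken together, the new ParquetSource methods above translate each additional_filters tuple into a pyspark.sql Column predicate and AND the predicates together before handing them to df.filter() in _filter_spark_df, so filtering happens as part of the Spark read rather than afterwards. A rough, hand-written equivalent of the translation for [("age", ">", 30), ("room", "in", [1, None])], assuming df is an existing Spark DataFrame with a numeric room column:

from pyspark.sql.functions import col, isnan

predicate = (col("age") > 30) & (
    col("room").isin([1]) | col("room").isNull() | isnan(col("room"))
)
# For timestamp/date columns the isnan() term is dropped, since isnan() cannot be applied to them.
filtered_df = df.filter(predicate)  # df is assumed to already exist
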
5 changes: 3 additions & 2 deletions mlrun/feature_store/retrieval/job.py
@@ -198,7 +198,8 @@ def target_uri(self):
from mlrun.datastore.targets import get_target_driver
def merge_handler(context, vector_uri, target, entity_rows=None,
entity_timestamp_column=None, drop_columns=None, with_indexes=None, query=None,
engine_args=None, order_by=None, start_time=None, end_time=None, timestamp_for_filtering=None):
engine_args=None, order_by=None, start_time=None, end_time=None, timestamp_for_filtering=None,
additional_filters=None):
vector = context.get_store_resource(vector_uri)
store_target = get_target_driver(target, vector)
if entity_rows:
@@ -208,7 +209,7 @@ def merge_handler(context, vector_uri, target, entity_rows=None,
merger = mlrun.feature_store.retrieval.{{{engine}}}(vector, **(engine_args or {}))
merger.start(entity_rows, entity_timestamp_column, store_target, drop_columns, with_indexes=with_indexes,
query=query, order_by=order_by, start_time=start_time, end_time=end_time,
timestamp_for_filtering=timestamp_for_filtering)
timestamp_for_filtering=timestamp_for_filtering, additional_filters=additional_filters)
target = vector.status.targets[store_target.name].to_dict()
context.log_result('feature_vector', vector.uri)
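The job.py change threads additional_filters through the generated merge_handler, so a remote merge job receives the filters and forwards them to merger.start(). A hedged sketch of how a caller might exercise this end to end; it assumes get_offline_features exposes a matching additional_filters argument (inferred from the handler signature above) and uses made-up feature set and vector names:

import mlrun.feature_store as fstore

vector = fstore.FeatureVector("patients-vec", ["patients.*"], with_indexes=True)  # hypothetical vector
resp = fstore.get_offline_features(
    vector,
    engine="spark",
    additional_filters=[("bad", "=", 95)],  # forwarded to the ParquetSource read
    run_config=fstore.RunConfig(local=False),  # execute as a remote merge job
)
df = resp.to_dataframe()
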
3 changes: 2 additions & 1 deletion mlrun/feature_store/retrieval/spark_merger.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#

import pandas as pd
import semver

@@ -252,13 +253,13 @@ def _get_engine_df(
# handling case where there are multiple feature sets and user creates vector where
# entity_timestamp_column is from a specific feature set (can't be entity timestamp)
source_driver = mlrun.datastore.sources.source_kind_to_driver[source_kind]

source = source_driver(
name=self.vector.metadata.name,
path=source_path,
time_field=time_column,
start_time=start_time,
end_time=end_time,
additional_filters=additional_filters,
**source_kwargs,
)

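Since spark_merger.py now passes additional_filters into the source, the predicate lands inside Spark's scan of the parquet data, which is presumably where the "Spark read optimization" in the commit title comes from. One way to sanity-check this locally is to look for PushedFilters in the physical plan; a standalone sketch with a local SparkSession and a made-up parquet path:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.master("local[*]").getOrCreate()
df = spark.read.parquet("/tmp/testdata.parquet")  # placeholder path
df.filter(col("bad") == 95).explain()
# For a Parquet scan the plan is expected to include something like:
#   PushedFilters: [IsNotNull(bad), EqualTo(bad,95)]
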
16 changes: 15 additions & 1 deletion tests/datastore/test_sources.py
@@ -19,7 +19,7 @@

import mlrun.config
from mlrun import new_function
from mlrun.datastore import CSVSource, KafkaSource
from mlrun.datastore import CSVSource, KafkaSource, ParquetSource


def test_kafka_source_with_old_nuclio():
@@ -104,3 +104,17 @@ def test_timestamp_format_inference(rundb_mock):
)
)
pd.testing.assert_frame_equal(result_df, expected_result_df)


@pytest.mark.parametrize(
"additional_filters",
[[("age", "=", float("nan"))], [("age", "in", [10, float("nan")])]],
)
def test_nan_additional_filters(additional_filters):
with pytest.raises(
mlrun.errors.MLRunInvalidArgumentError,
match="using NaN in additional_filters is not supported",
):
ParquetSource(
"parquet_source", path="path/to/file", additional_filters=additional_filters
)
Binary file not shown.
18 changes: 6 additions & 12 deletions tests/system/feature_store/test_feature_store.py
@@ -37,6 +37,7 @@
from storey.dtypes import V3ioError

import mlrun
import mlrun.datastore.utils
import mlrun.feature_store as fstore
import tests.conftest
from mlrun.config import config
@@ -70,6 +71,7 @@
from mlrun.features import MinMaxValidator, RegexValidator
from mlrun.model import DataTarget
from tests.system.base import TestMLRunSystem
from tests.system.feature_store.utils import sort_df

from .data_sample import quotes, stocks, trades

@@ -4816,14 +4818,6 @@ def test_merge_different_number_of_entities(self):
).to_dataframe()
assert_frame_equal(expected_all, df, check_dtype=False)

@staticmethod
def _sort_df(df: pd.DataFrame, sort_column: str):
return (
df.reindex(sorted(df.columns), axis=1)
.sort_values(by=sort_column)
.reset_index(drop=True)
)

@pytest.mark.parametrize("engine", ["local", "dask"])
def test_parquet_filters(self, engine):
parquet_path = os.path.relpath(str(self.assets_path / "testdata.parquet"))
@@ -4859,8 +4853,8 @@ def test_parquet_filters(self, engine):
result = target.as_df(additional_filters=("room", "=", 1)).reset_index()
# We want to include patient_id in the comparison,
# sort the columns alphabetically, and sort the rows by patient_id values.
result = self._sort_df(result, "patient_id")
expected = self._sort_df(filtered_df.query("room == 1"), "patient_id")
result = sort_df(result, "patient_id")
expected = sort_df(filtered_df.query("room == 1"), "patient_id")
# the content of category column is still checked:
assert_frame_equal(result, expected, check_dtype=False, check_categorical=False)
vec = fstore.FeatureVector(
@@ -4876,8 +4870,8 @@ def test_parquet_filters(self, engine):
.to_dataframe()
.reset_index()
)
expected = self._sort_df(filtered_df.query("bad == 95"), "patient_id")
result = self._sort_df(result, "patient_id")
expected = sort_df(filtered_df.query("bad == 95"), "patient_id")
result = sort_df(result, "patient_id")
assert_frame_equal(result, expected, check_dtype=False, check_categorical=False)


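The test refactor replaces the private _sort_df staticmethod with a shared sort_df helper imported from tests/system/feature_store/utils; that utils module is not part of this diff, but it presumably carries the same implementation that was removed above. A sketch under that assumption:

import pandas as pd


def sort_df(df: pd.DataFrame, sort_column: str) -> pd.DataFrame:
    # Order columns alphabetically and rows by sort_column so frames compare deterministically
    return (
        df.reindex(sorted(df.columns), axis=1)
        .sort_values(by=sort_column)
        .reset_index(drop=True)
    )
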