Commit

[Feature store] Moving filtering by date from ingest to ParquetSource (
katyakats committed Apr 28, 2021
1 parent 7c02b9e commit 69a9674
Showing 4 changed files with 20 additions and 70 deletions.
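
In short, the time-window filter moves from the ingest() call to the ParquetSource definition. A minimal before/after sketch of the API change, reconstructed from the diff below (the feature set, path, and import aliases are illustrative, not part of the commit):

import mlrun.feature_store as fs
from mlrun.datastore.sources import ParquetSource

# Before this commit: the window was passed to ingest() itself.
df = fs.ingest(
    my_feature_set,  # illustrative feature set
    ParquetSource("myparquet", path="testdata.parquet", time_field="timestamp"),
    start_time="2020-12-01 17:33:15",
    end_time="2020-12-01 17:33:16",
    return_df=True,
)

# After this commit: the window is declared on the source instead.
source = ParquetSource(
    "myparquet",
    path="testdata.parquet",  # illustrative path
    time_field="timestamp",
    start_time="2020-12-01 17:33:15",  # str or datetime, per the new annotations
    end_time="2020-12-01 17:33:16",
)
df = fs.ingest(my_feature_set, source, return_df=True)
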
45 changes: 14 additions & 31 deletions mlrun/datastore/sources.py
@@ -12,7 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from copy import copy
-from typing import Dict
+from datetime import datetime
+from typing import Dict, Optional, Union
 
 import mlrun
 
@@ -28,15 +29,13 @@ def get_source_from_dict(source):
     return source_kind_to_driver[kind].from_dict(source)
 
 
-def get_source_step(
-    source, key_fields=None, time_field=None, start_time=None, end_time=None,
-):
+def get_source_step(source, key_fields=None, time_field=None):
     """initialize the source driver"""
     if hasattr(source, "to_csv"):
         source = DataFrameSource(source)
     if not key_fields and not source.key_fields:
         raise mlrun.errors.MLRunInvalidArgumentError("key column is not defined")
-    return source.to_step(key_fields, time_field, start_time, end_time)
+    return source.to_step(key_fields, time_field)
 
 
 class BaseSourceDriver(DataSource):
@@ -47,14 +46,9 @@ def _get_store(self):
         store, _ = mlrun.store_manager.get_or_create_store(self.path)
         return store
 
-    def to_step(
-        self, key_field=None, time_field=None, start_time=None, end_time=None,
-    ):
+    def to_step(self, key_field=None, time_field=None):
         import storey
 
-        if start_time or end_time:
-            raise NotImplementedError("BaseSource does not support filtering by time")
-
         return storey.SyncEmitSource()
 
     def get_table_object(self):
@@ -93,14 +87,9 @@ def __init__(
     ):
         super().__init__(name, path, attributes, key_field, time_field, schedule)
 
-    def to_step(
-        self, key_field=None, time_field=None, start_time=None, end_time=None,
-    ):
+    def to_step(self, key_field=None, time_field=None):
         import storey
 
-        if start_time or end_time:
-            raise NotImplementedError("CSVSource does not support filtering by time")
-
         attributes = self.attributes or {}
         return storey.CSVSource(
             paths=self.path,
@@ -134,8 +123,12 @@ def __init__(
         key_field: str = None,
         time_field: str = None,
         schedule: str = None,
+        start_time: Optional[Union[str, datetime]] = None,
+        end_time: Optional[Union[str, datetime]] = None,
     ):
         super().__init__(name, path, attributes, key_field, time_field, schedule)
+        self.start_time = start_time
+        self.end_time = end_time
 
     def to_step(
         self, key_field=None, time_field=None, start_time=None, end_time=None,
@@ -148,8 +141,8 @@ def to_step(
             key_field=self.key_field or key_field,
             time_field=self.time_field or time_field,
             storage_options=self._get_store().get_storage_options(),
-            end_filter=end_time,
-            start_filter=start_time,
+            end_filter=self.end_time,
+            start_filter=self.start_time,
             filter_column=self.time_field or time_field,
             **attributes,
         )
@@ -192,16 +185,9 @@ def __init__(self, df, key_fields=None, time_field=None):
         self.key_fields = key_fields
         self.time_field = time_field
 
-    def to_step(
-        self, key_fields=None, time_field=None, start_time=None, end_time=None,
-    ):
+    def to_step(self, key_fields=None, time_field=None):
         import storey
 
-        if start_time or end_time:
-            raise NotImplementedError(
-                "DataFrameSource does not support filtering by time"
-            )
-
         return storey.DataframeSource(
             dfs=self._df,
             key_field=self.key_fields or key_fields,
@@ -241,13 +227,10 @@ def __init__(
         self.workers = workers
 
     def to_step(
-        self, key_field=None, time_field=None, start_time=None, end_time=None,
+        self, key_field=None, time_field=None,
     ):
         import storey
 
-        if start_time or end_time:
-            raise NotImplementedError("Source does not support filtering by time")
-
         return storey.SyncEmitSource(
             key_field=self.key_field or key_field,
             time_field=self.time_field or time_field,
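
Net effect in sources.py: the window becomes constructor state on ParquetSource, and the dead start_time/end_time plumbing (with its NotImplementedError guards) disappears from every other source. Pieced together from the hunks above, the resulting ParquetSource.to_step reads roughly as follows; this is a sketch, not verbatim file contents, and the paths/attributes lines are assumptions mirroring CSVSource:

    def to_step(
        self, key_field=None, time_field=None, start_time=None, end_time=None,
    ):
        import storey

        attributes = self.attributes or {}  # assumption, mirrors CSVSource
        return storey.ParquetSource(
            paths=self.path,  # assumption, mirrors CSVSource
            key_field=self.key_field or key_field,
            time_field=self.time_field or time_field,
            storage_options=self._get_store().get_storage_options(),
            end_filter=self.end_time,      # now read from instance state,
            start_filter=self.start_time,  # not from the to_step arguments
            filter_column=self.time_field or time_field,
            **attributes,
        )
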
12 changes: 1 addition & 11 deletions mlrun/feature_store/api.py
@@ -144,8 +144,6 @@ def ingest(
     run_config: RunConfig = None,
     mlrun_context=None,
     spark_context=None,
-    start_time=None,
-    end_time=None,
 ) -> pd.DataFrame:
     """Read local DataFrame, file, URL, or source into the feature store
     Ingest reads from the source, run the graph transformations, infers metadata and stats
@@ -180,8 +178,6 @@
     :param mlrun_context: mlrun context (when running as a job), for internal use !
     :param spark_context: local spark session for spark ingestion, example for creating the spark context:
         `spark = SparkSession.builder.appName("Spark function").getOrCreate()`
-    :param start_time: datetime/string, low limit of time needed to be filtered. format '2020-11-01 17:33:15'
-    :param end_time: datetime/string, high limit of time needed to be filtered. format '2020-12-01 17:33:15'
     """
     if featureset:
         if isinstance(featureset, str):
@@ -257,13 +253,7 @@ def ingest(
 
     targets = targets or featureset.spec.targets or get_default_targets()
     df = init_featureset_graph(
-        source,
-        featureset,
-        namespace,
-        targets=targets,
-        return_df=return_df,
-        start_time=start_time,
-        end_time=end_time,
+        source, featureset, namespace, targets=targets, return_df=return_df,
     )
     infer_from_static_df(df, featureset, options=infer_stats)
     _post_ingestion(mlrun_context, featureset, spark_context)
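
Since ingest() simply loses the two keyword parameters (removed from the signature and the docstring alike), callers that still pass them now fail fast with an ordinary Python signature error, roughly:

try:
    fs.ingest(measurements, source, start_time="2020-12-01 17:33:15")
except TypeError as err:
    # e.g. "ingest() got an unexpected keyword argument 'start_time'"
    print(err)
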
25 changes: 3 additions & 22 deletions mlrun/feature_store/ingestion.py
@@ -32,13 +32,7 @@
 
 
 def init_featureset_graph(
-    source,
-    featureset,
-    namespace,
-    targets=None,
-    return_df=True,
-    start_time=None,
-    end_time=None,
+    source, featureset, namespace, targets=None, return_df=True,
 ):
     """create storey ingestion graph/DAG from feature set object"""
 
@@ -55,8 +49,6 @@ def init_featureset_graph(
         targets=targets,
         source=source,
         return_df=return_df,
-        start_time=start_time,
-        end_time=end_time,
     )
 
     server = create_graph_server(graph=graph, parameters={})
@@ -130,14 +122,7 @@ def context_to_ingestion_params(context):
 
 
 def _add_data_states(
-    graph,
-    cache,
-    featureset,
-    targets,
-    source,
-    return_df=False,
-    start_time=None,
-    end_time=None,
+    graph, cache, featureset, targets, source, return_df=False,
 ):
     _, default_final_state, _ = graph.check_and_process_graph(allow_empty=True)
     validate_target_list(targets=targets)
@@ -154,11 +139,7 @@
 
     if source is not None:
         source = get_source_step(
-            source,
-            key_fields=key_fields,
-            time_field=featureset.spec.timestamp_key,
-            start_time=start_time,
-            end_time=end_time,
+            source, key_fields=key_fields, time_field=featureset.spec.timestamp_key,
         )
         graph.set_flow_source(source)
 
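
The ingestion plumbing is pared down to match: init_featureset_graph, _add_data_states, and get_source_step no longer thread the window through; it travels on the source object and only resurfaces inside ParquetSource.to_step. The simplified entry point, called with illustrative arguments:

source_step = get_source_step(
    source,                     # e.g. a ParquetSource with start_time/end_time set
    key_fields=["patient_id"],  # illustrative entity key
    time_field="timestamp",     # illustrative timestamp column
)
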
8 changes: 2 additions & 6 deletions tests/system/feature_store/test_feature_store.py
@@ -284,15 +284,11 @@ def test_filtering_parquet_by_time(self):
             "myparquet",
             path=os.path.relpath(str(self.assets_path / "testdata.parquet")),
             time_field="timestamp",
-        )
-
-        resp = fs.ingest(
-            measurements,
-            source,
             start_time=datetime(2020, 12, 1, 17, 33, 15),
             end_time="2020-12-01 17:33:16",
-            return_df=True,
         )
+
+        resp = fs.ingest(measurements, source, return_df=True,)
         assert len(resp) == 10
 
     def test_ordered_pandas_asof_merge(self):
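
The updated test doubles as a usage example: the window sits on the source, and mixed bound types are accepted (a datetime object for the lower bound, a string for the upper). A self-contained sketch under assumed imports (the Entity/FeatureSet import location and the patient_id key are assumptions, not shown in this diff):

from datetime import datetime

import mlrun.feature_store as fs
from mlrun.datastore.sources import ParquetSource
from mlrun.feature_store import Entity, FeatureSet  # assumed import location

measurements = FeatureSet(
    "measurements", entities=[Entity("patient_id")], timestamp_key="timestamp"
)
source = ParquetSource(
    "myparquet",
    path="testdata.parquet",  # illustrative local path
    time_field="timestamp",
    start_time=datetime(2020, 12, 1, 17, 33, 15),  # datetime accepted
    end_time="2020-12-01 17:33:16",                # string accepted
)
resp = fs.ingest(measurements, source, return_df=True)
print(len(resp))  # the system test expects 10 rows for its dataset
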

0 comments on commit 69a9674
