[Feature Store] Split graph support (#872)
urihoenig committed Apr 22, 2021
1 parent bb2259a commit 6e2ac39
Showing 4 changed files with 159 additions and 31 deletions.
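
In short, this commit lets a feature-set processing graph split into branches: a target can attach to any branch via after_state, targets accept an explicit columns list, column selection is centralized in a shared _get_column_list helper, and infer() now fails fast when a split graph lacks a default final state. A minimal usage sketch, distilled from the new test below (step and target names come from that test; `quotes` is assumed to be a pandas DataFrame of quote events):

    import mlrun.feature_store as fs
    from mlrun.datastore.targets import CSVTarget

    quotes_set = fs.FeatureSet("stock-quotes", entities=[fs.Entity("ticker")])
    # two branches off the source: a filter on the main path, an extra
    # derived column on the side path
    quotes_set.graph.to("storey.Filter", "filter", _fn="(event['bid'] > 70)")
    quotes_set.graph.to(
        "storey.Extend", name="side-step", _fn="({'extra2': event['bid'] * 17})"
    )

    # route a CSV target to the side branch; the default targets keep
    # writing from the declared default final state
    quotes_set.set_targets(
        targets=[CSVTarget(name="side-target", after_state="side-step")],
        default_final_state="filter",
    )
    df = fs.ingest(quotes_set, quotes, return_df=True)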
2 changes: 0 additions & 2 deletions docs/store/feature-store-demo.ipynb
@@ -15,8 +15,6 @@
"\n",
"Install the latest MLRun package and the following package before running the demo and restart the notebook\n",
"\n",
" !pip install storey\n",
"\n",
"Setting up the environment and project"
]
},
72 changes: 46 additions & 26 deletions mlrun/datastore/targets.py
@@ -88,7 +88,7 @@ def add_target_states(graph, resource, targets, to_df=False, final_state=None):
driver.add_writer_state(
graph,
target.after_state or final_state,
features=features,
features=features if not target.after_state else None,
key_columns=key_columns,
timestamp_key=timestamp_key,
)
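
Why the conditional: the inferred feature list describes the default flow only, so a target attached mid-graph via after_state now receives features=None and falls back to whatever columns actually reach its branch, or to an explicit columns list (see _get_column_list below).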
@@ -154,11 +154,13 @@ def __init__(
path=None,
attributes: typing.Dict[str, str] = None,
after_state=None,
columns=None,
):
self.name = name
self.path = str(path) if path is not None else None
self.after_state = after_state
self.attributes = attributes or {}
self.columns = columns or []

self._target = None
self._resource = None
@@ -168,6 +170,20 @@ def _get_store(self):
store, _ = mlrun.store_manager.get_or_create_store(self._target_path)
return store

def _get_column_list(self, features, timestamp_key, key_columns):
column_list = None
if self.columns:
return self.columns
elif features:
column_list = list(features.keys())
if timestamp_key and timestamp_key not in column_list:
column_list = [timestamp_key] + column_list
if key_columns:
for key in reversed(key_columns):
if key not in column_list:
column_list.insert(0, key)
return column_list
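
# Editorial sketch (not part of the diff): the precedence above restated as a
# standalone function, with an extra guard for the features-empty case;
# resolve_columns is a made-up name for illustration.
def resolve_columns(explicit, features, timestamp_key, key_columns):
    if explicit:  # a user-forced column list wins outright, unmodified
        return explicit
    cols = list(features) if features else None
    if cols and timestamp_key and timestamp_key not in cols:
        cols = [timestamp_key] + cols
    if cols and key_columns:
        for key in reversed(key_columns):  # prepend keys, keeping their order
            if key not in cols:
                cols.insert(0, key)
    return cols
# resolve_columns([], {"bid": 1, "ask": 2}, "time", ["ticker"])
# -> ["ticker", "time", "bid", "ask"]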

def write_dataframe(
self, df, key_column=None, timestamp_key=None, **kwargs,
) -> typing.Optional[int]:
@@ -205,6 +221,8 @@ def from_spec(cls, spec: DataTargetBase, resource=None):
driver.name = spec.name
driver.path = spec.path
driver.attributes = spec.attributes
if hasattr(spec, "columns"):
driver.columns = spec.columns
driver._resource = resource
return driver
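# Editorial note (not in the diff): the hasattr guard keeps from_spec backward
# compatible with target specs serialized before `columns` existed; such
# drivers simply keep the default empty list.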

@@ -263,9 +281,9 @@ def _write_dataframe(df, fs, target_path, **kwargs):
def add_writer_state(
self, graph, after, features, key_columns=None, timestamp_key=None
):
column_list = list(features.keys())
if timestamp_key and timestamp_key not in column_list:
column_list = [timestamp_key] + column_list
column_list = self._get_column_list(
features=features, timestamp_key=timestamp_key, key_columns=None
)

graph.add_step(
name="WriteToParquet",
@@ -307,12 +325,10 @@ def _write_dataframe(df, fs, target_path, **kwargs):
def add_writer_state(
self, graph, after, features, key_columns=None, timestamp_key=None
):
column_list = list(features.keys())
if timestamp_key:
column_list = [timestamp_key] + column_list
for key in reversed(key_columns):
if key not in column_list:
column_list.insert(0, key)
column_list = self._get_column_list(
features=features, timestamp_key=timestamp_key, key_columns=key_columns
)

graph.add_step(
name="WriteToCSV",
after=after,
@@ -352,12 +368,17 @@ def add_writer_state
self, graph, after, features, key_columns=None, timestamp_key=None
):
table = self._resource.uri
column_list = [
key for key, feature in features.items() if not feature.aggregate
]
for key in reversed(key_columns):
if key not in column_list:
column_list.insert(0, key)
column_list = self._get_column_list(
features=features, timestamp_key=None, key_columns=key_columns
)
if not self.columns:
aggregate_features = (
[key for key, feature in features.items() if feature.aggregate]
if features
else []
)
column_list = [col for col in column_list if col not in aggregate_features]
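# Editorial note (not in the diff): aggregate features are persisted by the
# aggregation step itself, so the KV writer excludes them here unless the
# user forced an explicit `columns` list.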

graph.add_step(
name="WriteToTable",
after=after,
@@ -411,12 +432,10 @@ def add_writer_state
from storey import V3ioDriver

endpoint, uri = parse_v3io_path(self._target_path)
column_list = list(features.keys())
if timestamp_key and timestamp_key not in column_list:
column_list = [timestamp_key] + column_list
for key in reversed(key_columns):
if key not in column_list:
column_list.insert(0, key)
column_list = self._get_column_list(
features=features, timestamp_key=timestamp_key, key_columns=key_columns
)

graph.add_step(
name="WriteToStream",
after=after,
@@ -443,14 +462,15 @@ def add_writer_state
self, graph, after, features, key_columns=None, timestamp_key=None
):
endpoint, uri = parse_v3io_path(self._target_path)
column_list = list(features.keys())
if not timestamp_key:
raise mlrun.errors.MLRunInvalidArgumentError(
"feature set timestamp_key must be specified for TSDBTarget writer"
)
for key in reversed(key_columns):
if key not in column_list:
column_list.insert(0, key)

column_list = self._get_column_list(
features=features, timestamp_key=None, key_columns=key_columns
)

graph.add_step(
name="WriteToTSDB",
class_name="storey.WriteToTSDB",
8 changes: 8 additions & 0 deletions mlrun/feature_store/api.py
@@ -292,6 +292,13 @@ def infer(

namespace = namespace or get_caller_globals()
if featureset.spec.require_processing():
_, default_final_state, _ = featureset.graph.check_and_process_graph(
allow_empty=True
)
if not default_final_state:
raise mlrun.errors.MLRunPreconditionFailedError(
"Split flow graph must have a default final state defined"
)
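# Editorial note (not in the diff): a side branch means multiple leaf states,
# so the feature set must declare its default final state (e.g. via
# set_targets(default_final_state=...)) before infer()/preview() will run.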
# find/update entities schema
if len(featureset.spec.entities) == 0:
infer_from_static_df(
@@ -309,6 +316,7 @@

# keep for backwards compatibility
infer_metadata = infer
preview = infer


def _run_ingestion_job(
108 changes: 105 additions & 3 deletions tests/system/feature_store/test_feature_store.py
@@ -11,7 +11,7 @@
import mlrun.feature_store as fs
from mlrun.data_types.data_types import ValueType
from mlrun.datastore.sources import CSVSource
from mlrun.datastore.targets import CSVTarget, TargetTypes
from mlrun.datastore.targets import CSVTarget, ParquetTarget, TargetTypes
from mlrun.feature_store import Entity, FeatureSet
from mlrun.feature_store.steps import FeaturesetValidator
from mlrun.features import MinMaxValidator
@@ -399,6 +399,106 @@ def test_multiple_entities(self):

svc.close()

_split_graph_expected_default = pd.DataFrame(
{
"time": [
pd.Timestamp("2016-05-25 13:30:00.023"),
pd.Timestamp("2016-05-25 13:30:00.048"),
pd.Timestamp("2016-05-25 13:30:00.049"),
pd.Timestamp("2016-05-25 13:30:00.072"),
],
"ticker": ["GOOG", "GOOG", "AAPL", "GOOG"],
"bid": [720.50, 720.50, 97.99, 720.50],
"ask": [720.93, 720.93, 98.01, 720.88],
"xx": [2161.50, 2161.50, 293.97, 2161.50],
"zz": [9, 9, 9, 9],
"extra": [55478.50, 55478.50, 7545.23, 55478.50],
}
)

_split_graph_expected_side = pd.DataFrame(
{
"time": [
pd.Timestamp("2016-05-25 13:30:00.023"),
pd.Timestamp("2016-05-25 13:30:00.023"),
pd.Timestamp("2016-05-25 13:30:00.030"),
pd.Timestamp("2016-05-25 13:30:00.041"),
pd.Timestamp("2016-05-25 13:30:00.048"),
pd.Timestamp("2016-05-25 13:30:00.049"),
pd.Timestamp("2016-05-25 13:30:00.072"),
pd.Timestamp("2016-05-25 13:30:00.075"),
],
"ticker": ["GOOG", "MSFT", "MSFT", "MSFT", "GOOG", "AAPL", "GOOG", "MSFT"],
"bid": [720.50, 51.95, 51.97, 51.99, 720.50, 97.99, 720.50, 52.01],
"ask": [720.93, 51.96, 51.98, 52.00, 720.93, 98.01, 720.88, 52.03],
"extra2": [
12248.50,
883.15,
883.49,
883.83,
12248.50,
1665.83,
12248.50,
884.17,
],
}
)
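
# Editorial note (not in the diff): the default frame reflects the full main
# chain (MyMap's xx/zz columns, extra = bid * 77, and the bid > 70 filter,
# which drops the MSFT rows), while the side frame taps the stream before the
# filter, keeping all eight input rows and adding only extra2 = bid * 17.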

def test_split_graph(self):
quotes_set = fs.FeatureSet("stock-quotes", entities=[fs.Entity("ticker")])

quotes_set.graph.to("MyMap", "somemap1", field="multi1", multiplier=3).to(
"storey.Extend", _fn="({'extra': event['bid'] * 77})"
).to("storey.Filter", "filter", _fn="(event['bid'] > 70)").to(
FeaturesetValidator()
)

side_step_name = "side-step"
quotes_set.graph.to(
"storey.Extend", name=side_step_name, _fn="({'extra2': event['bid'] * 17})"
)
with pytest.raises(mlrun.errors.MLRunPreconditionFailedError):
fs.infer_metadata(quotes_set, quotes)

non_default_target_name = "side-target"
quotes_set.set_targets(
targets=[
CSVTarget(name=non_default_target_name, after_state=side_step_name)
],
default_final_state="FeaturesetValidator",
)

quotes_set.plot(with_targets=True)

inf_out = fs.infer_metadata(quotes_set, quotes)
ing_out = fs.ingest(quotes_set, quotes, return_df=True)

default_file_path = quotes_set.get_target_path(TargetTypes.parquet)
side_file_path = quotes_set.get_target_path(non_default_target_name)

side_file_out = pd.read_csv(side_file_path)
default_file_out = pd.read_parquet(default_file_path)
self._split_graph_expected_default.set_index("ticker", inplace=True)

assert all(self._split_graph_expected_default == default_file_out.round(2))
assert all(self._split_graph_expected_default == ing_out.round(2))
assert all(self._split_graph_expected_default == inf_out.round(2))

assert all(
self._split_graph_expected_side.sort_index(axis=1)
== side_file_out.sort_index(axis=1).round(2)
)

def test_forced_columns_target(self):
columns = ["time", "ask"]
targets = [ParquetTarget(columns=columns)]
quotes_set, _ = prepare_feature_set(
"forced-columns", "ticker", quotes, timestamp_key="time", targets=targets
)

df = pd.read_parquet(quotes_set.get_target_path())
assert all(df.columns.values == columns)
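
# Editorial note (not in the diff): a forced `columns` list bypasses the
# automatic key/timestamp prefixing (the early return in _get_column_list),
# so the parquet file holds exactly ["time", "ask"].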


def verify_ingest(
base_data, keys, infer=False, targets=None, infer_options=fs.InferOptions.default()
@@ -422,12 +522,14 @@ def verify_ingest(
assert all(df.values[idx] == data.values[idx])


def prepare_feature_set(name: str, entity: str, data: pd.DataFrame, timestamp_key=None):
def prepare_feature_set(
name: str, entity: str, data: pd.DataFrame, timestamp_key=None, targets=None
):
df_source = mlrun.datastore.sources.DataFrameSource(data, entity, timestamp_key)

feature_set = fs.FeatureSet(
name, entities=[fs.Entity(entity)], timestamp_key=timestamp_key
)
feature_set.set_targets()
feature_set.set_targets(targets=targets, with_defaults=False if targets else True)
df = fs.ingest(feature_set, df_source, infer_options=fs.InferOptions.default())
return feature_set, df
