[Feature store] Add support for multiple entities (#818)
katyakats committed Mar 23, 2021
1 parent 39e49f0 commit 88e9450
Showing 4 changed files with 94 additions and 33 deletions.
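
In short: sources and targets were previously keyed on a single entity column (key_field / key_column); this commit threads lists (key_fields / key_columns) through the ingestion graph instead, so a feature set can be keyed on a composite key. A minimal sketch of the user-facing effect, assuming the mlrun feature-store API of this release (entity names borrowed from the new test below):

import mlrun.feature_store as fs
from mlrun.feature_store import Entity

# a feature set keyed on two entity columns instead of one
data_set = fs.FeatureSet(
    "tests2", entities=[Entity("first_name"), Entity("last_name")]
)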
14 changes: 7 additions & 7 deletions mlrun/datastore/sources.py
@@ -28,13 +28,13 @@ def get_source_from_dict(source):
     return source_kind_to_driver[kind].from_dict(source)


-def get_source_step(source, key_field=None, time_field=None):
+def get_source_step(source, key_fields=None, time_field=None):
     """initialize the source driver"""
     if hasattr(source, "to_csv"):
         source = DataFrameSource(source)
-    if not key_field and not source.key_field:
+    if not key_fields and not source.key_fields:
         raise mlrun.errors.MLRunInvalidArgumentError("key column is not defined")
-    return source.to_step(key_field, time_field)
+    return source.to_step(key_fields, time_field)


 class BaseSourceDriver(DataSource):
@@ -170,17 +170,17 @@ def to_step(self, key_field=None, time_field=None):
 class DataFrameSource:
     support_storey = True

-    def __init__(self, df, key_field=None, time_field=None):
+    def __init__(self, df, key_fields=None, time_field=None):
         self._df = df
-        self.key_field = key_field
+        self.key_fields = key_fields
         self.time_field = time_field

-    def to_step(self, key_field=None, time_field=None):
+    def to_step(self, key_fields=None, time_field=None):
         import storey

         return storey.DataframeSource(
             dfs=self._df,
-            key_field=self.key_field or key_field,
+            key_field=self.key_fields or key_fields,
             time_field=self.time_field or time_field,
         )

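Callers of get_source_step now pass a list of key columns. A hedged sketch of driving the helper directly (hypothetical DataFrame; anything with a to_csv attribute is wrapped in DataFrameSource, and storey must be installed for to_step to work):

import pandas as pd

df = pd.DataFrame(
    {
        "time": [pd.Timestamp.now()],
        "first_name": ["moshe"],
        "last_name": ["cohen"],
        "bid": [2000],
    }
)
# returns a storey.DataframeSource step keyed on both entity columns
step = get_source_step(df, key_fields=["first_name", "last_name"], time_field="time")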
52 changes: 28 additions & 24 deletions mlrun/datastore/targets.py
@@ -72,7 +72,7 @@ def validate_target_placement(graph, final_step, targets):
 def add_target_states(graph, resource, targets, to_df=False, final_state=None):
     """add the target states to the graph"""
     targets = targets or []
-    key_column = resource.spec.entities[0].name
+    key_columns = list(resource.spec.entities.keys())
     timestamp_key = resource.spec.timestamp_key
     features = resource.spec.features
     table = None
@@ -85,7 +85,7 @@ def add_target_states(graph, resource, targets, to_df=False, final_state=None):
             graph,
             target.after_state or final_state,
             features=features,
-            key_column=key_column,
+            key_columns=key_columns,
             timestamp_key=timestamp_key,
         )
     if to_df:
@@ -95,7 +95,7 @@ def add_target_states(graph, resource, targets, to_df=False, final_state=None):
             graph,
             final_state,
             features=features,
-            key_column=key_column,
+            key_columns=key_columns,
             timestamp_key=timestamp_key,
         )

@@ -226,7 +226,7 @@ def update_resource_status(self, status="", producer=None, is_dir=None, size=None):
         return target

     def add_writer_state(
-        self, graph, after, features, key_column=None, timestamp_key=None
+        self, graph, after, features, key_columns=None, timestamp_key=None
     ):
         """add storey writer state to graph"""
         raise NotImplementedError()
@@ -255,7 +255,7 @@ def _write_dataframe(df, fs, target_path, **kwargs):
         df.to_parquet(fp, **kwargs)

     def add_writer_state(
-        self, graph, after, features, key_column=None, timestamp_key=None
+        self, graph, after, features, key_columns=None, timestamp_key=None
     ):
         column_list = list(features.keys())
         if timestamp_key and timestamp_key not in column_list:
@@ -268,7 +268,7 @@ def add_writer_state(
class_name="storey.WriteToParquet",
path=self._target_path,
columns=column_list,
index_cols=key_column,
index_cols=key_columns,
storage_options=self._get_store().get_storage_options(),
**self.attributes,
)
@@ -299,13 +299,14 @@ def _write_dataframe(df, fs, target_path, **kwargs):
         df.to_csv(fp, **kwargs)

     def add_writer_state(
-        self, graph, after, features, key_column=None, timestamp_key=None
+        self, graph, after, features, key_columns=None, timestamp_key=None
     ):
         column_list = list(features.keys())
         if timestamp_key:
             column_list = [timestamp_key] + column_list
-        if key_column not in column_list:
-            column_list.insert(0, key_column)
+        for key in reversed(key_columns):
+            if key not in column_list:
+                column_list.insert(0, key)
         graph.add_step(
             name="WriteToCSV",
             after=after,
@@ -314,7 +315,7 @@ def add_writer_state(
             path=self._target_path,
             columns=column_list,
             header=True,
-            index_cols=key_column,
+            index_cols=key_columns,
             storage_options=self._get_store().get_storage_options(),
             **self.attributes,
         )
@@ -342,14 +343,15 @@ def get_table_object(self):
         return Table(uri, V3ioDriver(webapi=endpoint))

     def add_writer_state(
-        self, graph, after, features, key_column=None, timestamp_key=None
+        self, graph, after, features, key_columns=None, timestamp_key=None
     ):
         table = self._resource.uri
         column_list = [
             key for key, feature in features.items() if not feature.aggregate
         ]
-        if key_column not in column_list:
-            column_list.insert(0, key_column)
+        for key in reversed(key_columns):
+            if key not in column_list:
+                column_list.insert(0, key)
         graph.add_step(
             name="WriteToTable",
             after=after,
@@ -379,16 +381,17 @@ class StreamTarget(BaseStoreTarget):
     support_storey = True

     def add_writer_state(
-        self, graph, after, features, key_column=None, timestamp_key=None
+        self, graph, after, features, key_columns=None, timestamp_key=None
     ):
         from storey import V3ioDriver

         endpoint, uri = parse_v3io_path(self._target_path)
         column_list = list(features.keys())
         if timestamp_key and timestamp_key not in column_list:
             column_list = [timestamp_key] + column_list
-        if key_column not in column_list:
-            column_list.insert(0, key_column)
+        for key in reversed(key_columns):
+            if key not in column_list:
+                column_list.insert(0, key)
         graph.add_step(
             name="WriteToStream",
             after=after,
@@ -412,24 +415,25 @@ class TSDBTarget(BaseStoreTarget):
     support_storey = True

     def add_writer_state(
-        self, graph, after, features, key_column=None, timestamp_key=None
+        self, graph, after, features, key_columns=None, timestamp_key=None
     ):
         endpoint, uri = parse_v3io_path(self._target_path)
         column_list = list(features.keys())
         if not timestamp_key:
             raise mlrun.errors.MLRunInvalidArgumentError(
                 "feature set timestamp_key must be specified for TSDBTarget writer"
             )
-        if key_column not in column_list:
-            column_list.insert(0, key_column)
+        for key in reversed(key_columns):
+            if key not in column_list:
+                column_list.insert(0, key)
         graph.add_step(
             name="WriteToTSDB",
             class_name="storey.WriteToTSDB",
             after=after,
             graph_shape="cylinder",
             path=uri,
             time_col=timestamp_key,
-            index_cols=key_column,
+            index_cols=key_columns,
             columns=column_list,
             **self.attributes,
         )
@@ -453,7 +457,7 @@ def __init__(
         super().__init__(name, "", attributes, after_state=after_state)

     def add_writer_state(
-        self, graph, after, features, key_column=None, timestamp_key=None
+        self, graph, after, features, key_columns=None, timestamp_key=None
     ):
         attributes = copy(self.attributes)
         class_name = attributes.pop("class_name")
@@ -480,16 +484,16 @@ def update_resource_status(self, status="", producer=None, is_dir=None):
         pass

     def add_writer_state(
-        self, graph, after, features, key_column=None, timestamp_key=None
+        self, graph, after, features, key_columns=None, timestamp_key=None
     ):
         # todo: column filter
         graph.add_step(
             name="WriteToDataFrame",
             after=after,
             graph_shape="cylinder",
             class_name="storey.ReduceToDataFrame",
-            index=key_column,
-            insert_key_column_as=key_column,
+            index=key_columns,
+            insert_key_column_as=key_columns,
             insert_time_column_as=timestamp_key,
         )

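Every writer above folds its entity columns into column_list with the same idiom: iterate key_columns in reverse and insert each key at position 0, which prepends the keys while preserving their declared order. A standalone illustration:

key_columns = ["first_name", "last_name"]
column_list = ["time", "bid"]
for key in reversed(key_columns):
    if key not in column_list:
        column_list.insert(0, key)
# column_list == ["first_name", "last_name", "time", "bid"]
# without reversed(), the keys would be prepended as ["last_name", "first_name", ...]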
5 changes: 3 additions & 2 deletions mlrun/feature_store/ingestion.py
@@ -93,10 +93,11 @@ def _add_data_states(
         cache.cache_table(featureset.uri, table, True)

     entity_columns = list(featureset.spec.entities.keys())
-    key_field = entity_columns[0] if entity_columns else None
+    key_fields = entity_columns if entity_columns else None

     if source is not None:
         source = get_source_step(
-            source, key_field=key_field, time_field=featureset.spec.timestamp_key,
+            source, key_fields=key_fields, time_field=featureset.spec.timestamp_key,
         )
         graph.set_flow_source(source)

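The ingestion-side change is that every declared entity, not only the first, becomes a source key. Assuming featureset.spec.entities behaves as an ordered mapping of entity name to Entity (which the .keys() call above implies), a two-entity feature set now yields:

entity_columns = list(featureset.spec.entities.keys())  # e.g. ["first_name", "last_name"]
key_fields = entity_columns if entity_columns else None  # before: entity_columns[0], i.e. "first_name" only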
56 changes: 56 additions & 0 deletions tests/system/feature_store/test_feature_store.py
@@ -341,6 +341,62 @@ def test_read_csv(self):
), f"{termination_result}\n!=\n{expected}"
os.remove(csv_path)

def test_multiple_entities(self):

current_time = pd.Timestamp.now()
data = pd.DataFrame(
{
"time": [
current_time,
current_time - pd.Timedelta(minutes=1),
current_time - pd.Timedelta(minutes=2),
current_time - pd.Timedelta(minutes=3),
current_time - pd.Timedelta(minutes=4),
current_time - pd.Timedelta(minutes=5),
],
"first_name": ["moshe", "yosi", "yosi", "yosi", "moshe", "yosi"],
"last_name": ["cohen", "levi", "levi", "levi", "cohen", "levi"],
"bid": [2000, 10, 11, 12, 2500, 14],
}
)

# write to kv
data_set = fs.FeatureSet(
"tests2", entities=[Entity("first_name"), Entity("last_name")]
)

data_set.add_aggregation(
name="bids",
column="bid",
operations=["sum", "max"],
windows=["1h"],
period="10m",
)
fs.infer_metadata(
data_set,
data, # source
entity_columns=["first_name", "last_name"],
timestamp_key="time",
options=fs.InferOptions.default(),
)

data_set.plot(
str(self.results_path / "pipe.png"), rankdir="LR", with_targets=True
)
fs.ingest(data_set, data, return_df=True)

features = [
"tests2.bids_sum_1h",
]

vector = fs.FeatureVector("my-vec", features)
svc = fs.get_online_feature_service(vector)

resp = svc.get([{"first_name": "yosi", "last_name": "levi"}])
print(resp[0])

svc.close()


def verify_ingest(
base_data, keys, infer=False, targets=None, infer_options=fs.InferOptions.default()
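As the new test shows, online lookups against a multi-entity feature set take one dict per requested record, with one key per declared entity. A usage sketch reusing the test's names, extended to several records (the exact response shape follows the service's normal list-of-dicts contract):

svc = fs.get_online_feature_service(vector)
resp = svc.get(
    [
        {"first_name": "yosi", "last_name": "levi"},
        {"first_name": "moshe", "last_name": "cohen"},
    ]
)
print(resp)
svc.close()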
