Fix multi-entity online retrieval (#1435)
* Fix bug when multiple entity keys are required for online retrieval

Signed-off-by: Willem Pienaar <git@willem.co>

* Add clearer method naming to key variables

Signed-off-by: Willem Pienaar <git@willem.co>
woop committed Apr 5, 2021
1 parent 67bd17f commit 9e5377c
Showing 4 changed files with 181 additions and 62 deletions.
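
For orientation, this is a minimal sketch of the multi-entity online retrieval call that the commit fixes. It assumes a locally applied feature repo containing the definitions from example_feature_repo_1.py below; the repo_path is illustrative.

from feast import FeatureStore

# Illustrative repo path; assumes the example definitions below have been applied.
store = FeatureStore(repo_path=".")

# Each entity row carries the union of entity keys needed by all requested
# feature views. The fix maps that union onto the subset of entities each
# feature view's table actually expects before reading from the online store.
features = store.get_online_features(
    feature_refs=[
        "driver_locations:lon",
        "customer_profile:avg_orders_day",
        "customer_driver_combined:trips",
    ],
    entity_rows=[{"driver": 1, "customer": 5}],
).to_dict()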
52 changes: 43 additions & 9 deletions sdk/python/feast/feature_store.py
@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from collections import OrderedDict, defaultdict
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple, Union
@@ -278,10 +278,8 @@ def materialize_incremental(
>>> from datetime import datetime, timedelta
>>> from feast.feature_store import FeatureStore
>>>
>>> fs = FeatureStore(config=RepoConfig(provider="gcp"))
>>> fs.materialize_incremental(
>>> end_date=datetime.utcnow() - timedelta(minutes=5)
>>> )
>>> fs = FeatureStore(config=RepoConfig(provider="gcp", registry="gs://my-fs/", project="my_fs_proj"))
>>> fs.materialize_incremental(end_date=datetime.utcnow() - timedelta(minutes=5))
"""
feature_views_to_materialize = []
if feature_views is None:
@@ -333,8 +331,7 @@ def materialize(
>>>
>>> fs = FeatureStore(config=RepoConfig(provider="gcp"))
>>> fs.materialize(
>>> start_date=datetime.utcnow() - timedelta(hours=3),
>>> end_date=datetime.utcnow() - timedelta(minutes=10)
>>> start_date=datetime.utcnow() - timedelta(hours=3), end_date=datetime.utcnow() - timedelta(minutes=10)
>>> )
"""
feature_views_to_materialize = []
@@ -440,11 +437,11 @@ def _get_online_features(

provider = self._get_provider()

entity_keys = []
union_of_entity_keys = []
result_rows: List[GetOnlineFeaturesResponse.FieldValues] = []

for row in entity_rows:
entity_keys.append(_entity_row_to_key(row))
union_of_entity_keys.append(_entity_row_to_key(row))
result_rows.append(_entity_row_to_field_values(row))

all_feature_views = self._registry.list_feature_views(
@@ -453,6 +450,7 @@

grouped_refs = _group_refs(feature_refs, all_feature_views)
for table, requested_features in grouped_refs:
entity_keys = _get_table_entity_keys(table, union_of_entity_keys)
read_rows = provider.online_read(
project=project, table=table, entity_keys=entity_keys,
)
@@ -640,4 +638,40 @@ def _get_requested_feature_views(
feature_refs: List[str], all_feature_views: List[FeatureView]
) -> List[FeatureView]:
"""Get list of feature views based on feature references"""
# TODO: Get rid of this function. We only need _group_refs
return list(view for view, _ in _group_refs(feature_refs, all_feature_views))


def _get_table_entity_keys(
table: FeatureView, entity_keys: List[EntityKeyProto]
) -> List[EntityKeyProto]:
required_entities = OrderedDict.fromkeys(sorted(table.entities))
entity_key_protos = []
for entity_key in entity_keys:
required_entities_to_values = required_entities.copy()
for i in range(len(entity_key.entity_names)):
entity_name = entity_key.entity_names[i]
entity_value = entity_key.entity_values[i]

if entity_name in required_entities_to_values:
if required_entities_to_values[entity_name] is not None:
raise ValueError(
f"Duplicate entity keys detected. Table {table.name} expects {table.entities}. The entity "
f"{entity_name} was provided at least twice"
)
required_entities_to_values[entity_name] = entity_value

entity_names = []
entity_values = []
for entity_name, entity_value in required_entities_to_values.items():
if entity_value is None:
raise ValueError(
f"Table {table.name} expects entity field {table.entities}. No entity value was found for "
f"{entity_name}"
)
entity_names.append(entity_name)
entity_values.append(entity_value)
entity_key_protos.append(
EntityKeyProto(entity_names=entity_names, entity_values=entity_values)
)
return entity_key_protos
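
The core of the fix is _get_table_entity_keys above: each row supplies the union of entity keys for all requested feature views, and this helper narrows that union down to the entities a particular table expects, raising on duplicate or missing values. A simplified, self-contained model of that selection logic (plain dicts instead of EntityKeyProto; names are illustrative):

from typing import Dict, List


def select_table_entities(
    table_entities: List[str], row_keys: List[Dict[str, int]]
) -> List[Dict[str, int]]:
    """Keep only the entities a table expects, in a stable (sorted) order."""
    selected = []
    for row in row_keys:
        missing = [e for e in sorted(table_entities) if e not in row]
        if missing:
            raise ValueError(f"No entity value was found for {missing}")
        selected.append({e: row[e] for e in sorted(table_entities)})
    return selected


# A row carrying both keys is narrowed to just "driver" for driver_locations,
# and keeps both keys (sorted) for customer_driver_combined.
rows = [{"driver": 1, "customer": 5}]
assert select_table_entities(["driver"], rows) == [{"driver": 1}]
assert select_table_entities(["customer", "driver"], rows) == [{"customer": 5, "driver": 1}]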
10 changes: 5 additions & 5 deletions sdk/python/feast/infra/local_sqlite.py
@@ -82,12 +82,12 @@ def online_write_batch(

with conn:
for entity_key, values, timestamp, created_ts in data:
for feature_name, val in values.items():
entity_key_bin = serialize_entity_key(entity_key)
timestamp = _to_naive_utc(timestamp)
if created_ts is not None:
created_ts = _to_naive_utc(created_ts)
entity_key_bin = serialize_entity_key(entity_key)
timestamp = _to_naive_utc(timestamp)
if created_ts is not None:
created_ts = _to_naive_utc(created_ts)

for feature_name, val in values.items():
conn.execute(
f"""
UPDATE {_table_id(project, table)}
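
The local_sqlite.py hunk above hoists entity-key serialization and timestamp normalization out of the per-feature loop, so they run once per entity key instead of once per feature. A self-contained toy sketch of that corrected write-loop shape (the schema and helpers are stand-ins, not Feast's actual sqlite layout):

import sqlite3
from datetime import datetime, timezone
from typing import Dict, List, Tuple


def to_naive_utc(ts: datetime) -> datetime:
    # Stand-in for the _to_naive_utc helper.
    return ts.astimezone(timezone.utc).replace(tzinfo=None) if ts.tzinfo else ts


def write_batch(conn: sqlite3.Connection, data: List[Tuple[str, Dict[str, float], datetime]]) -> None:
    with conn:
        for entity_key, values, timestamp in data:
            # Per-entity-key work happens once, outside the feature loop.
            entity_key_bin = entity_key.encode("utf8")  # stand-in for serialize_entity_key
            timestamp = to_naive_utc(timestamp)
            for feature_name, value in values.items():
                conn.execute(
                    "INSERT INTO online_store VALUES (?, ?, ?, ?)",
                    (entity_key_bin, feature_name, value, timestamp.isoformat()),
                )


conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE online_store (entity_key BLOB, feature TEXT, value REAL, ts TEXT)")
write_batch(conn, [("driver:1", {"lat": 0.1, "lon": 1.0}, datetime.utcnow())])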
42 changes: 33 additions & 9 deletions sdk/python/tests/example_feature_repo_1.py
@@ -1,4 +1,4 @@
from google.protobuf.duration_pb2 import Duration
from datetime import timedelta

from feast import BigQuerySource, Entity, Feature, FeatureView, ValueType

@@ -8,18 +8,32 @@
created_timestamp_column="created_timestamp",
)

customer_profile_source = BigQuerySource(
table_ref="rh_prod.ride_hailing_co.customers",
event_timestamp_column="event_timestamp",
)

customer_driver_combined_source = BigQuerySource(
table_ref="rh_prod.ride_hailing_co.customer_driver",
event_timestamp_column="event_timestamp",
)

driver = Entity(
name="driver", # The name is derived from this argument, not object name.
value_type=ValueType.INT64,
description="driver id",
)

customer = Entity(
name="customer", # The name is derived from this argument, not object name.
value_type=ValueType.STRING,
)


driver_locations = FeatureView(
name="driver_locations",
entities=["driver"],
ttl=Duration(seconds=86400 * 1),
ttl=timedelta(days=1),
features=[
Feature(name="lat", dtype=ValueType.FLOAT),
Feature(name="lon", dtype=ValueType.STRING),
@@ -29,16 +43,26 @@
tags={},
)

driver_locations_2 = FeatureView(
name="driver_locations_2",
entities=["driver"],
ttl=Duration(seconds=86400 * 1),
customer_profile = FeatureView(
name="customer_profile",
entities=["customer"],
ttl=timedelta(days=1),
features=[
Feature(name="lat", dtype=ValueType.FLOAT),
Feature(name="lon", dtype=ValueType.STRING),
Feature(name="avg_orders_day", dtype=ValueType.FLOAT),
Feature(name="name", dtype=ValueType.STRING),
Feature(name="age", dtype=ValueType.INT64),
],
online=True,
input=driver_locations_source,
input=customer_profile_source,
tags={},
)

customer_driver_combined = FeatureView(
name="customer_driver_combined",
entities=["customer", "driver"],
ttl=timedelta(days=1),
features=[Feature(name="trips", dtype=ValueType.INT64)],
online=True,
input=customer_driver_combined_source,
tags={},
)
139 changes: 100 additions & 39 deletions sdk/python/tests/test_online_retrieval.py
@@ -18,25 +18,24 @@ def test_online() -> None:
runner = CliRunner()
with runner.local_repo(get_example_repo("example_feature_repo_1.py")) as store:
# Write some data to two tables
registry = store._registry
table = registry.get_feature_view(
project=store.config.project, name="driver_locations"
)
table_2 = registry.get_feature_view(
project=store.config.project, name="driver_locations_2"

driver_locations_fv = store.get_feature_view(name="driver_locations")
customer_profile_fv = store.get_feature_view(name="customer_profile")
customer_driver_combined_fv = store.get_feature_view(
name="customer_driver_combined"
)

provider = store._get_provider()

entity_key = EntityKeyProto(
driver_key = EntityKeyProto(
entity_names=["driver"], entity_values=[ValueProto(int64_val=1)]
)
provider.online_write_batch(
project=store.config.project,
table=table,
table=driver_locations_fv,
data=[
(
entity_key,
driver_key,
{
"lat": ValueProto(double_val=0.1),
"lon": ValueProto(string_val="1.0"),
@@ -48,15 +47,19 @@ def test_online() -> None:
progress=None,
)

customer_key = EntityKeyProto(
entity_names=["customer"], entity_values=[ValueProto(int64_val=5)]
)
provider.online_write_batch(
project=store.config.project,
table=table_2,
table=customer_profile_fv,
data=[
(
entity_key,
customer_key,
{
"lat": ValueProto(double_val=2.0),
"lon": ValueProto(string_val="2.0"),
"avg_orders_day": ValueProto(float_val=1.0),
"name": ValueProto(string_val="John"),
"age": ValueProto(int64_val=3),
},
datetime.utcnow(),
datetime.utcnow(),
@@ -65,15 +68,44 @@ def test_online() -> None:
progress=None,
)

# Retrieve two features using two keys, one valid one non-existing
result = store.get_online_features(
feature_refs=["driver_locations:lon", "driver_locations_2:lon"],
entity_rows=[{"driver": 1}, {"driver": 123}],
customer_key = EntityKeyProto(
entity_names=["customer", "driver"],
entity_values=[ValueProto(int64_val=5), ValueProto(int64_val=1)],
)
provider.online_write_batch(
project=store.config.project,
table=customer_driver_combined_fv,
data=[
(
customer_key,
{"trips": ValueProto(int64_val=7)},
datetime.utcnow(),
datetime.utcnow(),
)
],
progress=None,
)

assert "driver_locations__lon" in result.to_dict()
assert result.to_dict()["driver_locations__lon"] == ["1.0", None]
assert result.to_dict()["driver_locations_2__lon"] == ["2.0", None]
# Retrieve two features using two keys, one valid one non-existing
result = store.get_online_features(
feature_refs=[
"driver_locations:lon",
"customer_profile:avg_orders_day",
"customer_profile:name",
"customer_driver_combined:trips",
],
entity_rows=[{"driver": 1, "customer": 5}, {"driver": 1, "customer": 5}],
).to_dict()

assert "driver_locations__lon" in result
assert "customer_profile__avg_orders_day" in result
assert "customer_profile__name" in result
assert result["driver"] == [1, 1]
assert result["customer"] == [5, 5]
assert result["driver_locations__lon"] == ["1.0", "1.0"]
assert result["customer_profile__avg_orders_day"] == [1.0, 1.0]
assert result["customer_profile__name"] == ["John", "John"]
assert result["customer_driver_combined__trips"] == [7, 7]

# invalid table reference
with pytest.raises(ValueError):
@@ -96,10 +128,16 @@ def test_online() -> None:

# Should download the registry and cache it permanently (or until manually refreshed)
result = fs_fast_ttl.get_online_features(
feature_refs=["driver_locations:lon", "driver_locations_2:lon"],
entity_rows=[{"driver": 1}, {"driver": 123}],
)
assert result.to_dict()["driver_locations__lon"] == ["1.0", None]
feature_refs=[
"driver_locations:lon",
"customer_profile:avg_orders_day",
"customer_profile:name",
"customer_driver_combined:trips",
],
entity_rows=[{"driver": 1, "customer": 5}],
).to_dict()
assert result["driver_locations__lon"] == ["1.0"]
assert result["customer_driver_combined__trips"] == [7]

# Rename the registry.db so that it cant be used for refreshes
os.rename(store.config.registry, store.config.registry + "_fake")
@@ -110,19 +148,30 @@ def test_online() -> None:
# Will try to reload registry because it has expired (it will fail because we deleted the actual registry file)
with pytest.raises(FileNotFoundError):
fs_fast_ttl.get_online_features(
feature_refs=["driver_locations:lon", "driver_locations_2:lon"],
entity_rows=[{"driver": 1}, {"driver": 123}],
)
feature_refs=[
"driver_locations:lon",
"customer_profile:avg_orders_day",
"customer_profile:name",
"customer_driver_combined:trips",
],
entity_rows=[{"driver": 1, "customer": 5}],
).to_dict()

# Restore registry.db so that we can see if it actually reloads registry
os.rename(store.config.registry + "_fake", store.config.registry)

# Test if registry is actually reloaded and whether results return
result = fs_fast_ttl.get_online_features(
feature_refs=["driver_locations:lon", "driver_locations_2:lon"],
entity_rows=[{"driver": 1}, {"driver": 123}],
)
assert result.to_dict()["driver_locations__lon"] == ["1.0", None]
feature_refs=[
"driver_locations:lon",
"customer_profile:avg_orders_day",
"customer_profile:name",
"customer_driver_combined:trips",
],
entity_rows=[{"driver": 1, "customer": 5}],
).to_dict()
assert result["driver_locations__lon"] == ["1.0"]
assert result["customer_driver_combined__trips"] == [7]

# Create a registry with infinite cache (for users that want to manually refresh the registry)
fs_infinite_ttl = FeatureStore(
@@ -138,10 +187,16 @@ def test_online() -> None:

# Should return results (and fill the registry cache)
result = fs_infinite_ttl.get_online_features(
feature_refs=["driver_locations:lon", "driver_locations_2:lon"],
entity_rows=[{"driver": 1}, {"driver": 123}],
)
assert result.to_dict()["driver_locations__lon"] == ["1.0", None]
feature_refs=[
"driver_locations:lon",
"customer_profile:avg_orders_day",
"customer_profile:name",
"customer_driver_combined:trips",
],
entity_rows=[{"driver": 1, "customer": 5}],
).to_dict()
assert result["driver_locations__lon"] == ["1.0"]
assert result["customer_driver_combined__trips"] == [7]

# Wait a bit so that an arbitrary TTL would take effect
time.sleep(2)
@@ -151,10 +206,16 @@ def test_online() -> None:

# TTL is infinite so this method should use registry cache
result = fs_infinite_ttl.get_online_features(
feature_refs=["driver_locations:lon", "driver_locations_2:lon"],
entity_rows=[{"driver": 1}, {"driver": 123}],
)
assert result.to_dict()["driver_locations__lon"] == ["1.0", None]
feature_refs=[
"driver_locations:lon",
"customer_profile:avg_orders_day",
"customer_profile:name",
"customer_driver_combined:trips",
],
entity_rows=[{"driver": 1, "customer": 5}],
).to_dict()
assert result["driver_locations__lon"] == ["1.0"]
assert result["customer_driver_combined__trips"] == [7]

# Force registry reload (should fail because file is missing)
with pytest.raises(FileNotFoundError):
