fix: Fix file offline store logic for feature views without ttl #2971

Merged · 2 commits · Jul 26, 2022
8 changes: 8 additions & 0 deletions sdk/python/feast/infra/offline_stores/file.py
@@ -635,6 +635,14 @@ def _filter_ttl(
             )
         ]
 
         df_to_join = df_to_join.persist()
+    else:
+        df_to_join = df_to_join[
+            # do not drop entity rows if one of the sources returns NaNs
+            df_to_join[timestamp_field].isna()
+            | (df_to_join[timestamp_field] <= df_to_join[entity_df_event_timestamp_col])
+        ]
+
+        df_to_join = df_to_join.persist()
 
     return df_to_join
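Before this change, `_filter_ttl` applied its timestamp filter only when a non-zero ttl was configured; with no ttl the frame passed through unfiltered, so feature rows stamped after an entity row's event timestamp could leak into the join and break point-in-time correctness. The new `else` branch keeps only rows at or before the entity timestamp, and also keeps NaN-timestamp rows so an entity row is not dropped when one of the joined sources has no match. A standalone pandas sketch of that filter (column names are illustrative; the real code operates on a dask DataFrame):

# Standalone sketch of the new else-branch semantics; illustrative column
# names, plain pandas instead of the dask DataFrame used by the store.
import pandas as pd

timestamp_field = "event_timestamp"
entity_df_event_timestamp_col = "entity_event_ts"

df_to_join = pd.DataFrame(
    {
        timestamp_field: pd.to_datetime(
            ["2022-07-01", "2022-07-10", None, "2022-08-01"]
        ),
        entity_df_event_timestamp_col: pd.to_datetime(["2022-07-15"] * 4),
    }
)

# With no ttl there is no lower bound: keep every feature row at or before
# the entity row's event timestamp, plus NaN-timestamp rows so an entity
# row survives when one of the joined sources has no matching record.
filtered = df_to_join[
    df_to_join[timestamp_field].isna()
    | (df_to_join[timestamp_field] <= df_to_join[entity_df_event_timestamp_col])
]
print(filtered)  # keeps 2022-07-01, 2022-07-10, and NaT; drops the future 2022-08-01 row

The remaining hunks below come from the accompanying integration-test module, which this PR updates to cover the no-ttl path.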
@@ -115,60 +115,70 @@ def get_expected_training_df(
         entity_df.to_dict("records"), event_timestamp
     )
 
+    # Set sufficiently large ttl that it effectively functions as infinite for the calculations below.
+    default_ttl = timedelta(weeks=52)
+
     # Manually do point-in-time join of driver, customer, and order records against
     # the entity df
     for entity_row in entity_rows:
         customer_record = find_asof_record(
             customer_records,
             ts_key=customer_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - customer_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(customer_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["customer_id"],
             filter_values=[entity_row["customer_id"]],
         )
         driver_record = find_asof_record(
             driver_records,
             ts_key=driver_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - driver_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(driver_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["driver_id"],
             filter_values=[entity_row["driver_id"]],
         )
         order_record = find_asof_record(
             order_records,
             ts_key=customer_fv.batch_source.timestamp_field,
-            ts_start=entity_row[event_timestamp] - order_fv.ttl,
+            ts_start=entity_row[event_timestamp]
+            - get_feature_view_ttl(order_fv, default_ttl),
             ts_end=entity_row[event_timestamp],
             filter_keys=["customer_id", "driver_id"],
             filter_values=[entity_row["customer_id"], entity_row["driver_id"]],
         )
         origin_record = find_asof_record(
             location_records,
             ts_key=location_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - location_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(location_fv, default_ttl),
             ts_end=order_record[event_timestamp],
             filter_keys=["location_id"],
             filter_values=[order_record["origin_id"]],
         )
         destination_record = find_asof_record(
             location_records,
             ts_key=location_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - location_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(location_fv, default_ttl),
             ts_end=order_record[event_timestamp],
             filter_keys=["location_id"],
             filter_values=[order_record["destination_id"]],
         )
         global_record = find_asof_record(
             global_records,
             ts_key=global_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - global_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(global_fv, default_ttl),
             ts_end=order_record[event_timestamp],
         )
 
         field_mapping_record = find_asof_record(
             field_mapping_records,
             ts_key=field_mapping_fv.batch_source.timestamp_field,
-            ts_start=order_record[event_timestamp] - field_mapping_fv.ttl,
+            ts_start=order_record[event_timestamp]
+            - get_feature_view_ttl(field_mapping_fv, default_ttl),
             ts_end=order_record[event_timestamp],
         )
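`find_asof_record` comes from Feast's test helpers and its implementation is not part of this diff. Judging only from the call sites above, it plausibly performs an as-of lookup like the sketch below; with `default_ttl = timedelta(weeks=52)`, an unset ttl simply stretches the `[ts_start, ts_end]` window into an effectively infinite lookback:

# Illustrative reimplementation inferred from the call sites above; the
# real helper in Feast's test utilities may differ in detail.
from typing import Any, Dict, List, Optional


def find_asof_record_sketch(
    records: List[Dict[str, Any]],
    ts_key: str,
    ts_start: Any,
    ts_end: Any,
    filter_keys: Optional[List[str]] = None,
    filter_values: Optional[List[Any]] = None,
) -> Dict[str, Any]:
    # Return the latest record whose ts_key lies in [ts_start, ts_end] and
    # whose filter_keys fields match filter_values; empty dict if none do.
    keys = filter_keys or []
    values = filter_values or []
    found: Dict[str, Any] = {}
    for record in records:
        if all(record.get(k) == v for k, v in zip(keys, values)):
            if ts_start <= record[ts_key] <= ts_end:
                if not found or record[ts_key] > found[ts_key]:
                    found = record
    return found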

@@ -666,6 +676,78 @@ def test_historical_features_persisting(
     )
 
 
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores
+@pytest.mark.parametrize("full_feature_names", [True, False], ids=lambda v: str(v))
+def test_historical_features_with_no_ttl(
+    environment, universal_data_sources, full_feature_names
+):
+    store = environment.feature_store
+
+    (entities, datasets, data_sources) = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+
+    # Remove ttls.
+    feature_views.customer.ttl = timedelta(seconds=0)
+    feature_views.order.ttl = timedelta(seconds=0)
+    feature_views.global_fv.ttl = timedelta(seconds=0)
+    feature_views.field_mapping.ttl = timedelta(seconds=0)
+
+    store.apply([driver(), customer(), location(), *feature_views.values()])
+
+    entity_df = datasets.entity_df.drop(
+        columns=["order_id", "origin_id", "destination_id"]
+    )
+
+    job = store.get_historical_features(
+        entity_df=entity_df,
+        features=[
+            "customer_profile:current_balance",
+            "customer_profile:avg_passenger_count",
+            "customer_profile:lifetime_trip_count",
+            "order:order_is_success",
+            "global_stats:num_rides",
+            "global_stats:avg_ride_length",
+            "field_mapping:feature_name",
+        ],
+        full_feature_names=full_feature_names,
+    )
+
+    event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
+    expected_df = get_expected_training_df(
+        datasets.customer_df,
+        feature_views.customer,
+        datasets.driver_df,
+        feature_views.driver,
+        datasets.orders_df,
+        feature_views.order,
+        datasets.location_df,
+        feature_views.location,
+        datasets.global_df,
+        feature_views.global_fv,
+        datasets.field_mapping_df,
+        feature_views.field_mapping,
+        entity_df,
+        event_timestamp,
+        full_feature_names,
+    ).drop(
+        columns=[
+            response_feature_name("conv_rate_plus_100", full_feature_names),
+            response_feature_name("conv_rate_plus_100_rounded", full_feature_names),
+            response_feature_name("avg_daily_trips", full_feature_names),
+            response_feature_name("conv_rate", full_feature_names),
+            "origin__temperature",
+            "destination__temperature",
+        ]
+    )
+
+    assert_frame_equal(
+        expected_df,
+        job.to_df(),
+        keys=[event_timestamp, "driver_id", "customer_id"],
+    )
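One subtlety: the `assert_frame_equal` called here accepts a `keys` argument, so it is Feast's own test helper rather than `pandas.testing.assert_frame_equal`, which has no such parameter. A hypothetical sketch of what an order-insensitive comparison might do with `keys`:

# Hypothetical sketch; the actual helper lives in Feast's test utilities
# and may differ in detail.
import pandas as pd


def assert_frame_equal_sketch(expected: pd.DataFrame, actual: pd.DataFrame, keys):
    # Sort both frames by the join keys so row order does not matter,
    # align column order, then delegate the value check to pandas.
    expected = expected.sort_values(by=keys).reset_index(drop=True)
    actual = actual[expected.columns].sort_values(by=keys).reset_index(drop=True)
    pd.testing.assert_frame_equal(expected, actual, check_dtype=False)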


 @pytest.mark.integration
 @pytest.mark.universal_offline_stores
 def test_historical_features_from_bigquery_sources_containing_backfills(environment):

@@ -781,6 +863,13 @@ def response_feature_name(feature: str, full_feature_names: bool) -> str:
     return feature
 
 
+def get_feature_view_ttl(
+    feature_view: FeatureView, default_ttl: timedelta
+) -> timedelta:
+    """Returns the ttl of a feature view if it is non-zero. Otherwise returns the specified default."""
+    return feature_view.ttl if feature_view.ttl else default_ttl
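The helper relies on Python truthiness: `None` and a zero `timedelta` are both falsy, which is why the test above can "remove" a ttl by assigning `timedelta(seconds=0)`. A quick illustration:

# Both an unset ttl (None) and a zero ttl fall through to the default.
from datetime import timedelta

assert not timedelta(seconds=0)  # a zero timedelta is falsy
assert timedelta(weeks=52)       # a non-zero timedelta is truthy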


 def assert_feature_service_correctness(
     store, feature_service, full_feature_names, entity_df, expected_df, event_timestamp
 ):