From 3980e0c9a762a6ec3bcee5a0e9cdf532994bb1c9 Mon Sep 17 00:00:00 2001 From: Tornike Gurgenidze Date: Sat, 30 Mar 2024 14:32:34 +0400 Subject: [PATCH] feat: Rewrite ibis point-in-time-join w/o feast abstractions (#4023) * feat: refactor ibis point-in-time-join Signed-off-by: tokoko * fix formatting, linting Signed-off-by: tokoko --------- Signed-off-by: tokoko --- .../contrib/ibis_offline_store/ibis.py | 248 ++++++++++-------- .../requirements/py3.10-ci-requirements.txt | 13 +- .../requirements/py3.9-ci-requirements.txt | 13 +- .../unit/infra/offline_stores/test_ibis.py | 138 ++++++++++ setup.py | 1 + 5 files changed, 295 insertions(+), 118 deletions(-) create mode 100644 sdk/python/tests/unit/infra/offline_stores/test_ibis.py diff --git a/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py b/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py index 72e0d970c6..8787d70158 100644 --- a/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py +++ b/sdk/python/feast/infra/offline_stores/contrib/ibis_offline_store/ibis.py @@ -72,112 +72,6 @@ def _get_entity_df_event_timestamp_range( return entity_df_event_timestamp_range - @staticmethod - def _get_historical_features_one( - feature_view: FeatureView, - entity_table: Table, - feature_refs: List[str], - full_feature_names: bool, - timestamp_range: Tuple, - acc_table: Table, - event_timestamp_col: str, - ) -> Table: - fv_table: Table = ibis.read_parquet(feature_view.batch_source.name) - - for old_name, new_name in feature_view.batch_source.field_mapping.items(): - if old_name in fv_table.columns: - fv_table = fv_table.rename({new_name: old_name}) - - timestamp_field = feature_view.batch_source.timestamp_field - - # TODO mutate only if tz-naive - fv_table = fv_table.mutate( - **{ - timestamp_field: fv_table[timestamp_field].cast( - dt.Timestamp(timezone="UTC") - ) - } - ) - - full_name_prefix = feature_view.projection.name_alias or feature_view.name - - feature_refs = [ - fr.split(":")[1] - for fr in feature_refs - if fr.startswith(f"{full_name_prefix}:") - ] - - timestamp_range_start_minus_ttl = ( - timestamp_range[0] - feature_view.ttl - if feature_view.ttl and feature_view.ttl > timedelta(0, 0, 0, 0, 0, 0, 0) - else timestamp_range[0] - ) - - timestamp_range_start_minus_ttl = ibis.literal( - timestamp_range_start_minus_ttl.strftime("%Y-%m-%d %H:%M:%S.%f") - ).cast(dt.Timestamp(timezone="UTC")) - - timestamp_range_end = ibis.literal( - timestamp_range[1].strftime("%Y-%m-%d %H:%M:%S.%f") - ).cast(dt.Timestamp(timezone="UTC")) - - fv_table = fv_table.filter( - ibis.and_( - fv_table[timestamp_field] <= timestamp_range_end, - fv_table[timestamp_field] >= timestamp_range_start_minus_ttl, - ) - ) - - # join_key_map = feature_view.projection.join_key_map or {e.name: e.name for e in feature_view.entity_columns} - # predicates = [fv_table[k] == entity_table[v] for k, v in join_key_map.items()] - - if feature_view.projection.join_key_map: - predicates = [ - fv_table[k] == entity_table[v] - for k, v in feature_view.projection.join_key_map.items() - ] - else: - predicates = [ - fv_table[e.name] == entity_table[e.name] - for e in feature_view.entity_columns - ] - - predicates.append( - fv_table[timestamp_field] <= entity_table[event_timestamp_col] - ) - - fv_table = fv_table.inner_join( - entity_table, predicates, lname="", rname="{name}_y" - ) - - fv_table = ( - fv_table.group_by(by="entity_row_id") - .order_by(ibis.desc(fv_table[timestamp_field])) - .mutate(rn=ibis.row_number()) - ) - - fv_table = 
fv_table.filter(fv_table["rn"] == ibis.literal(0)) - - select_cols = ["entity_row_id"] - select_cols.extend(feature_refs) - fv_table = fv_table.select(select_cols) - - if full_feature_names: - fv_table = fv_table.rename( - {f"{full_name_prefix}__{feature}": feature for feature in feature_refs} - ) - - acc_table = acc_table.left_join( - fv_table, - predicates=[fv_table.entity_row_id == acc_table.entity_row_id], - lname="", - rname="{name}_yyyy", - ) - - acc_table = acc_table.drop(s.endswith("_yyyy")) - - return acc_table - @staticmethod def _to_utc(entity_df: pd.DataFrame, event_timestamp_col): entity_df_event_timestamp = entity_df.loc[ @@ -228,9 +122,11 @@ def get_historical_features( entity_schema=entity_schema, ) + # TODO get range with ibis timestamp_range = IbisOfflineStore._get_entity_df_event_timestamp_range( entity_df, event_timestamp_col ) + entity_df = IbisOfflineStore._to_utc(entity_df, event_timestamp_col) entity_table = ibis.memtable(entity_df) @@ -238,20 +134,61 @@ def get_historical_features( entity_table, feature_views, event_timestamp_col ) - res: Table = entity_table + def read_fv(feature_view, feature_refs, full_feature_names): + fv_table: Table = ibis.read_parquet(feature_view.batch_source.name) - for fv in feature_views: - res = IbisOfflineStore._get_historical_features_one( - fv, - entity_table, + for old_name, new_name in feature_view.batch_source.field_mapping.items(): + if old_name in fv_table.columns: + fv_table = fv_table.rename({new_name: old_name}) + + timestamp_field = feature_view.batch_source.timestamp_field + + # TODO mutate only if tz-naive + fv_table = fv_table.mutate( + **{ + timestamp_field: fv_table[timestamp_field].cast( + dt.Timestamp(timezone="UTC") + ) + } + ) + + full_name_prefix = feature_view.projection.name_alias or feature_view.name + + feature_refs = [ + fr.split(":")[1] + for fr in feature_refs + if fr.startswith(f"{full_name_prefix}:") + ] + + if full_feature_names: + fv_table = fv_table.rename( + { + f"{full_name_prefix}__{feature}": feature + for feature in feature_refs + } + ) + + feature_refs = [ + f"{full_name_prefix}__{feature}" for feature in feature_refs + ] + + return ( + fv_table, + feature_view.batch_source.timestamp_field, + feature_view.projection.join_key_map + or {e.name: e.name for e in feature_view.entity_columns}, feature_refs, - full_feature_names, - timestamp_range, - res, - event_timestamp_col, + feature_view.ttl, ) - res = res.drop("entity_row_id") + res = point_in_time_join( + entity_table=entity_table, + feature_tables=[ + read_fv(feature_view, feature_refs, full_feature_names) + for feature_view in feature_views + ], + event_timestamp_col=event_timestamp_col, + ) return IbisRetrievalJob( res, @@ -285,6 +222,10 @@ def pull_all_from_table_or_query( table = table.select(*fields) + # TODO get rid of this fix + if "__log_date" in table.columns: + table = table.drop("__log_date") + table = table.filter( ibis.and_( table[timestamp_field] >= ibis.literal(start_date), @@ -320,6 +261,7 @@ def write_logged_features( else: kwargs = {} + # TODO always write to directory table.to_parquet( f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs ) @@ -405,3 +347,77 @@ def persist( @property def metadata(self) -> Optional[RetrievalMetadata]: return self._metadata + + +def point_in_time_join( + entity_table: Table, + feature_tables: List[Tuple[Table, str, Dict[str, str], List[str], timedelta]], + event_timestamp_col="event_timestamp", +): + # TODO handle ttl + all_entities = [event_timestamp_col] + for feature_table, 
timestamp_field, join_key_map, _, _ in feature_tables: + all_entities.extend(join_key_map.values()) + + r = ibis.literal("") + + for e in set(all_entities): + r = r.concat(entity_table[e].cast("string")) # type: ignore + + entity_table = entity_table.mutate(entity_row_id=r) + + acc_table = entity_table + + for ( + feature_table, + timestamp_field, + join_key_map, + feature_refs, + ttl, + ) in feature_tables: + predicates = [ + feature_table[k] == entity_table[v] for k, v in join_key_map.items() + ] + + predicates.append( + feature_table[timestamp_field] <= entity_table[event_timestamp_col], + ) + + if ttl: + predicates.append( + feature_table[timestamp_field] + >= entity_table[event_timestamp_col] - ibis.literal(ttl) + ) + + feature_table = feature_table.inner_join( + entity_table, predicates, lname="", rname="{name}_y" + ) + + feature_table = feature_table.drop(s.endswith("_y")) + + feature_table = ( + feature_table.group_by(by="entity_row_id") + .order_by(ibis.desc(feature_table[timestamp_field])) + .mutate(rn=ibis.row_number()) + ) + + feature_table = feature_table.filter( + feature_table["rn"] == ibis.literal(0) + ).drop("rn") + + select_cols = ["entity_row_id"] + select_cols.extend(feature_refs) + feature_table = feature_table.select(select_cols) + + acc_table = acc_table.left_join( + feature_table, + predicates=[feature_table.entity_row_id == acc_table.entity_row_id], + lname="", + rname="{name}_yyyy", + ) + + acc_table = acc_table.drop(s.endswith("_yyyy")) + + acc_table = acc_table.drop("entity_row_id") + + return acc_table diff --git a/sdk/python/requirements/py3.10-ci-requirements.txt b/sdk/python/requirements/py3.10-ci-requirements.txt index 4b1e26e1f3..ac1994da37 100644 --- a/sdk/python/requirements/py3.10-ci-requirements.txt +++ b/sdk/python/requirements/py3.10-ci-requirements.txt @@ -164,6 +164,12 @@ docker==7.0.0 # testcontainers docutils==0.19 # via sphinx +duckdb==0.10.1 + # via + # duckdb-engine + # ibis-framework +duckdb-engine==0.11.2 + # via ibis-framework entrypoints==0.4 # via altair exceptiongroup==1.2.0 @@ -310,7 +316,7 @@ httpx==0.27.0 # via # feast (setup.py) # jupyterlab -ibis-framework==8.0.0 +ibis-framework[duckdb]==8.0.0 # via # feast (setup.py) # ibis-substrait @@ -848,8 +854,13 @@ sphinxcontrib-serializinghtml==1.1.10 # via sphinx sqlalchemy[mypy]==1.4.52 # via + # duckdb-engine # feast (setup.py) + # ibis-framework # sqlalchemy + # sqlalchemy-views +sqlalchemy-views==0.3.2 + # via ibis-framework sqlalchemy2-stubs==0.0.2a38 # via sqlalchemy sqlglot==20.11.0 diff --git a/sdk/python/requirements/py3.9-ci-requirements.txt b/sdk/python/requirements/py3.9-ci-requirements.txt index 99dee08c05..367b5dc050 100644 --- a/sdk/python/requirements/py3.9-ci-requirements.txt +++ b/sdk/python/requirements/py3.9-ci-requirements.txt @@ -164,6 +164,12 @@ docker==7.0.0 # testcontainers docutils==0.19 # via sphinx +duckdb==0.10.1 + # via + # duckdb-engine + # ibis-framework +duckdb-engine==0.11.2 + # via ibis-framework entrypoints==0.4 # via altair exceptiongroup==1.2.0 @@ -310,7 +316,7 @@ httpx==0.27.0 # via # feast (setup.py) # jupyterlab -ibis-framework==8.0.0 +ibis-framework[duckdb]==8.0.0 # via # feast (setup.py) # ibis-substrait @@ -858,8 +864,13 @@ sphinxcontrib-serializinghtml==1.1.10 # via sphinx sqlalchemy[mypy]==1.4.52 # via + # duckdb-engine # feast (setup.py) + # ibis-framework # sqlalchemy + # sqlalchemy-views +sqlalchemy-views==0.3.2 + # via ibis-framework sqlalchemy2-stubs==0.0.2a38 # via sqlalchemy sqlglot==20.11.0 diff --git 
a/sdk/python/tests/unit/infra/offline_stores/test_ibis.py b/sdk/python/tests/unit/infra/offline_stores/test_ibis.py new file mode 100644 index 0000000000..5f105e2af7 --- /dev/null +++ b/sdk/python/tests/unit/infra/offline_stores/test_ibis.py @@ -0,0 +1,138 @@ +from datetime import datetime, timedelta +from typing import Dict, List, Tuple + +import ibis +import pyarrow as pa + +from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import ( + point_in_time_join, +) + + +def pa_datetime(year, month, day): + return pa.scalar(datetime(year, month, day), type=pa.timestamp("s", tz="UTC")) + + +def customer_table(): + return pa.Table.from_arrays( + arrays=[ + pa.array([1, 1, 2]), + pa.array( + [ + pa_datetime(2024, 1, 1), + pa_datetime(2024, 1, 2), + pa_datetime(2024, 1, 1), + ] + ), + ], + names=["customer_id", "event_timestamp"], + ) + + +def features_table_1(): + return pa.Table.from_arrays( + arrays=[ + pa.array([1, 1, 1, 2]), + pa.array( + [ + pa_datetime(2023, 12, 31), + pa_datetime(2024, 1, 2), + pa_datetime(2024, 1, 3), + pa_datetime(2023, 1, 3), + ] + ), + pa.array([11, 22, 33, 22]), + ], + names=["customer_id", "event_timestamp", "feature1"], + ) + + +def point_in_time_join_brute( + entity_table: pa.Table, + feature_tables: List[Tuple[pa.Table, str, Dict[str, str], List[str], timedelta]], + event_timestamp_col="event_timestamp", +): + ret_fields = [entity_table.schema.field(n) for n in entity_table.schema.names] + + from operator import itemgetter + + ret = entity_table.to_pydict() + batch_dict = entity_table.to_pydict() + + for i, row_timestmap in enumerate(batch_dict[event_timestamp_col]): + for ( + feature_table, + timestamp_key, + join_key_map, + feature_refs, + ttl, + ) in feature_tables: + if i == 0: + ret_fields.extend( + [ + feature_table.schema.field(f) + for f in feature_table.schema.names + if f not in join_key_map.values() and f != timestamp_key + ] + ) + + def check_equality(ft_dict, batch_dict, x, y): + return all( + [ft_dict[k][x] == batch_dict[v][y] for k, v in join_key_map.items()] + ) + + ft_dict = feature_table.to_pydict() + found_matches = [ + (j, ft_dict[timestamp_key][j]) + for j in range(entity_table.num_rows) + if check_equality(ft_dict, batch_dict, j, i) + and ft_dict[timestamp_key][j] <= row_timestmap + and ft_dict[timestamp_key][j] >= row_timestmap - ttl + ] + + index_found = ( + max(found_matches, key=itemgetter(1))[0] if found_matches else None + ) + for col in ft_dict.keys(): + if col not in feature_refs: + continue + + if col not in ret: + ret[col] = [] + + if index_found is not None: + ret[col].append(ft_dict[col][index_found]) + else: + ret[col].append(None) + + return pa.Table.from_pydict(ret, schema=pa.schema(ret_fields)) + + +def test_point_in_time_join(): + expected = point_in_time_join_brute( + customer_table(), + feature_tables=[ + ( + features_table_1(), + "event_timestamp", + {"customer_id": "customer_id"}, + ["feature1"], + timedelta(days=10), + ) + ], + ) + + actual = point_in_time_join( + ibis.memtable(customer_table()), + feature_tables=[ + ( + ibis.memtable(features_table_1()), + "event_timestamp", + {"customer_id": "customer_id"}, + ["feature1"], + timedelta(days=10), + ) + ], + ).to_pyarrow() + + assert actual.equals(expected) diff --git a/setup.py b/setup.py index f142da7ec3..2d7bf63778 100644 --- a/setup.py +++ b/setup.py @@ -212,6 +212,7 @@ + HAZELCAST_REQUIRED + IBIS_REQUIRED + GRPCIO_REQUIRED + + DUCKDB_REQUIRED ) DOCS_REQUIRED = CI_REQUIRED
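
Usage note (illustrative, not part of the patch): the new point_in_time_join helper takes the entity table plus a list of 5-tuples (feature_table, timestamp_field, join_key_map, feature_refs, ttl), the same shape that read_fv builds in get_historical_features above. Below is a minimal sketch of calling it directly, assuming a local "driver_stats.parquet" file with driver_id, event_timestamp and conv_rate columns; the file name, column names and entity frame are hypothetical.

from datetime import timedelta

import ibis
import pandas as pd

from feast.infra.offline_stores.contrib.ibis_offline_store.ibis import (
    point_in_time_join,
)

# Entity rows: one feature lookup per (driver_id, event_timestamp) pair.
entity_df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "event_timestamp": [
            pd.Timestamp("2024-01-02", tz="UTC"),
            pd.Timestamp("2024-01-02", tz="UTC"),
        ],
    }
)

# A feature table read straight from parquet, as read_fv does for a batch source.
driver_stats = ibis.read_parquet("driver_stats.parquet")  # hypothetical path

joined = point_in_time_join(
    entity_table=ibis.memtable(entity_df),
    feature_tables=[
        (
            driver_stats,                # ibis Table holding the feature data
            "event_timestamp",           # timestamp_field of the batch source
            {"driver_id": "driver_id"},  # join_key_map: feature column -> entity column
            ["conv_rate"],               # feature_refs, already renamed/prefixed as needed
            timedelta(days=1),           # ttl: only rows within 1 day before the entity timestamp qualify
        )
    ],
    event_timestamp_col="event_timestamp",
)

print(joined.to_pandas())

As in the patch, a falsy ttl (for example timedelta(0)) skips the lower time bound, so any feature row at or before the entity timestamp remains eligible; the most recent qualifying row per entity wins.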