Speed-up join in historical retrieval by replacing pandas with native spark (#89)
pyalex committed Aug 16, 2021
1 parent 9f2f084 commit 96141ae
Showing 3 changed files with 91 additions and 164 deletions.
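The change replaces a pandas-UDF time filter with a native Spark join: entity rows are joined to feature rows on the entity keys, candidate features outside [entity_timestamp - max_age, entity_timestamp] are dropped, and a window minimum over the timestamp distance keeps only the closest feature row per entity row. A minimal, self-contained sketch of that pattern, with illustrative column names and toy data rather than the module's API:

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as func
from pyspark.sql.functions import col

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Entity rows ask: "what was the freshest feature value as of this timestamp?"
entity_df = spark.createDataFrame(
    [("driver_1", 100), ("driver_1", 200)],
    ["driver_id", "event_timestamp_entity"],
)
feature_df = spark.createDataFrame(
    [("driver_1", 90, 0.5), ("driver_1", 180, 0.7)],
    ["driver_id", "event_timestamp", "trips_today"],
)
max_age = 120  # seconds

# Join on the entity key, keep features inside [entity_ts - max_age, entity_ts],
# then keep only the closest feature row per (key, entity timestamp) pair.
candidates = (
    feature_df.join(entity_df, on="driver_id", how="inner")
    .withColumn("distance", col("event_timestamp_entity") - col("event_timestamp"))
    .where((col("distance") >= 0) & (col("distance") <= max_age))
)
window = Window.partitionBy("driver_id", "event_timestamp_entity")
closest = (
    candidates.withColumn("min_distance", func.min("distance").over(window))
    .where(col("distance") == col("min_distance"))
    .drop("distance", "min_distance")
)
closest.show()

Keeping the whole computation in Spark avoids collecting the entity dataframe to the driver and broadcasting it into a pandas UDF, which is what the deleted _make_time_filter_pandas_udf did.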
python/feast_spark/pyspark/historical_feature_retrieval_job.py (123 changes: 49 additions & 74 deletions)
@@ -8,20 +8,13 @@
 from logging.config import dictConfig
 from typing import Any, Dict, List, NamedTuple, Optional
 
-import numpy as np
-import pandas as pd
 from pyspark.sql import DataFrame, SparkSession, Window
-from pyspark.sql.functions import (
-    col,
-    expr,
-    monotonically_increasing_id,
-    row_number,
-    struct,
-)
-from pyspark.sql.pandas.functions import PandasUDFType, pandas_udf
-from pyspark.sql.types import BooleanType, LongType
+from pyspark.sql import functions as func
+from pyspark.sql.functions import col, monotonically_increasing_id, row_number
+from pyspark.sql.types import LongType
 
 EVENT_TIMESTAMP_ALIAS = "event_timestamp"
+ENTITY_EVENT_TIMESTAMP_ALIAS = "event_timestamp_entity"
 CREATED_TIMESTAMP_ALIAS = "created_timestamp"


@@ -283,14 +276,13 @@ class FeatureTable(NamedTuple):
         entities (List[Field]): Primary keys for the features.
         features (List[Field]): Feature list.
         max_age (int): In seconds. determines the lower bound of the timestamp of the retrieved feature.
-            If not specified, this would be unbounded
         project (str): Feast project name.
     """
 
     name: str
     entities: List[Field]
     features: List[Field]
-    max_age: Optional[int] = None
+    max_age: int
     project: Optional[str] = None
 
     @property
@@ -427,14 +419,10 @@ def as_of_join(

     join_cond = (
         entity_with_id[entity_event_timestamp_column]
-        >= aliased_feature_table_df[feature_event_timestamp_column_with_prefix]
+        == aliased_feature_table_df[
+            f"{feature_table.name}__{ENTITY_EVENT_TIMESTAMP_ALIAS}"
+        ]
     )
-    if feature_table.max_age:
-        join_cond = join_cond & (
-            aliased_feature_table_df[feature_event_timestamp_column_with_prefix]
-            >= entity_with_id[entity_event_timestamp_column]
-            - expr(f"INTERVAL {feature_table.max_age} seconds")
-        )
 
     for key in feature_table.entity_names:
         join_cond = join_cond & (
@@ -557,50 +545,19 @@ class SchemaError(Exception):
     pass
 
 
-def _make_time_filter_pandas_udf(
-    spark: SparkSession,
-    entity_pandas: pd.DataFrame,
-    feature_table: FeatureTable,
-    entity_event_timestamp_column: str,
-):
-    entity_br = spark.sparkContext.broadcast(
-        entity_pandas.rename(
-            columns={entity_event_timestamp_column: EVENT_TIMESTAMP_ALIAS}
-        )
-    )
-    entity_names = feature_table.entity_names
-    max_age = feature_table.max_age
-
-    @pandas_udf(BooleanType(), PandasUDFType.SCALAR)
-    def within_time_boundaries(features: pd.DataFrame) -> pd.Series:
-        features["_row_id"] = np.arange(len(features))
-        merged = features.merge(
-            entity_br.value,
-            how="left",
-            on=entity_names,
-            suffixes=("_feature", "_entity"),
-        )
-        merged["distance"] = (
-            merged[f"{EVENT_TIMESTAMP_ALIAS}_entity"]
-            - merged[f"{EVENT_TIMESTAMP_ALIAS}_feature"]
-        )
-        merged["within"] = merged["distance"].dt.total_seconds().between(0, max_age)
-
-        return merged.groupby(["_row_id"]).max()["within"]
-
-    return within_time_boundaries
-
-
-def _filter_feature_table_by_time_range(
-    spark: SparkSession,
+def filter_feature_table_by_time_range(
     feature_table_df: DataFrame,
     feature_table: FeatureTable,
     feature_event_timestamp_column: str,
-    entity_pandas: pd.DataFrame,
+    entity_df: DataFrame,
     entity_event_timestamp_column: str,
-):
-    entity_max_timestamp = entity_pandas[entity_event_timestamp_column].max()
-    entity_min_timestamp = entity_pandas[entity_event_timestamp_column].min()
+) -> DataFrame:
+    entity_max_timestamp = entity_df.agg(
+        {entity_event_timestamp_column: "max"}
+    ).collect()[0][0]
+    entity_min_timestamp = entity_df.agg(
+        {entity_event_timestamp_column: "min"}
+    ).collect()[0][0]
 
     feature_table_timestamp_filter = (
         col(feature_event_timestamp_column).between(
@@ -613,17 +570,32 @@ def _filter_feature_table_by_time_range(

     time_range_filtered_df = feature_table_df.filter(feature_table_timestamp_filter)
 
-    if feature_table.max_age:
-        within_time_boundaries_udf = _make_time_filter_pandas_udf(
-            spark, entity_pandas, feature_table, entity_event_timestamp_column
+    time_range_filtered_df = (
+        time_range_filtered_df.join(
+            entity_df.withColumnRenamed(
+                entity_event_timestamp_column, ENTITY_EVENT_TIMESTAMP_ALIAS
+            ),
+            on=feature_table.entity_names,
+            how="inner",
         )
-
-        time_range_filtered_df = time_range_filtered_df.withColumn(
-            "within_time_boundaries",
-            within_time_boundaries_udf(
-                struct(feature_table.entity_names + [feature_event_timestamp_column])
+        .withColumn(
+            "distance",
+            col(ENTITY_EVENT_TIMESTAMP_ALIAS).cast("long")
+            - col(EVENT_TIMESTAMP_ALIAS).cast("long"),
+        )
+        .where((col("distance") >= 0) & (col("distance") <= feature_table.max_age))
+        .withColumn(
+            "min_distance",
+            func.min("distance").over(
+                Window.partitionBy(
+                    feature_table.entity_names + [ENTITY_EVENT_TIMESTAMP_ALIAS]
+                )
             ),
-        ).filter("within_time_boundaries = true")
+        )
+        .where(col("distance") == col("min_distance"))
+        .select(time_range_filtered_df.columns + [ENTITY_EVENT_TIMESTAMP_ALIAS])
+        .localCheckpoint()
+    )
 
     return time_range_filtered_df

@@ -807,15 +779,14 @@ def retrieve_historical_features(
f"{expected_entity.name} ({expected_entity.spark_type}) is not present in the entity dataframe."
)

entity_pandas = entity_df.toPandas()
entity_df.cache()

feature_table_dfs = [
_filter_feature_table_by_time_range(
spark,
filter_feature_table_by_time_range(
feature_table_df,
feature_table,
feature_table_source.event_timestamp_column,
entity_pandas,
entity_df,
entity_source.event_timestamp_column,
)
for feature_table_df, feature_table, feature_table_source in zip(
@@ -873,11 +844,15 @@ def _get_args():


 def _feature_table_from_dict(dct: Dict[str, Any]) -> FeatureTable:
+    assert (
+        dct.get("max_age") is not None and dct["max_age"] > 0
+    ), "FeatureTable.maxAge must not be None and should be a positive number"
+
     return FeatureTable(
         name=dct["name"],
         entities=[Field(**e) for e in dct["entities"]],
         features=[Field(**f) for f in dct["features"]],
-        max_age=dct.get("max_age"),
+        max_age=dct["max_age"],
         project=dct.get("project"),
     )

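Since max_age is now required to be a positive number of seconds (the new assertion in _feature_table_from_dict above), configs that omit it fail at parse time instead of silently defaulting to an unbounded age. A hedged sketch of that behavior; the dict values are placeholders, and the empty entities/features lists only keep the example short:

from feast_spark.pyspark.historical_feature_retrieval_job import (
    _feature_table_from_dict,
)

table = _feature_table_from_dict(
    {"name": "driver_statistics", "entities": [], "features": [], "max_age": 86400}
)
print(table.max_age)  # 86400

try:
    _feature_table_from_dict({"name": "no_max_age", "entities": [], "features": []})
except AssertionError as err:
    print(err)  # FeatureTable.maxAge must not be None and should be a positive number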
