Commit 2a48612
Keep only unique entity rows in entity dataframe in historical retrieval job (#93)

Signed-off-by: Oleksii Moskalenko <moskalenko.alexey@gmail.com>
pyalex committed Aug 19, 2021
1 parent 8015e86 commit 2a48612
Showing 1 changed file with 9 additions and 8 deletions.
python/feast_spark/pyspark/historical_feature_retrieval_job.py
@@ -579,16 +579,18 @@ def filter_feature_table_by_time_range(
 
     time_range_filtered_df = feature_table_df.filter(feature_table_timestamp_filter)
 
+    entities_projected = (
+        entity_df.withColumnRenamed(
+            entity_event_timestamp_column, ENTITY_EVENT_TIMESTAMP_ALIAS
+        )
+        .select(feature_table.entity_names + [ENTITY_EVENT_TIMESTAMP_ALIAS])
+        .distinct()
+    )
+
     time_range_filtered_df = (
         time_range_filtered_df.repartition(200)
         .join(
-            broadcast(
-                entity_df.withColumnRenamed(
-                    entity_event_timestamp_column, ENTITY_EVENT_TIMESTAMP_ALIAS
-                )
-            ),
-            on=feature_table.entity_names,
-            how="inner",
+            broadcast(entities_projected), on=feature_table.entity_names, how="inner",
         )
         .withColumn(
             "distance",
@@ -605,7 +607,6 @@ def filter_feature_table_by_time_range(
             ),
         )
         .where(col("distance") == col("min_distance"))
-        .select(time_range_filtered_df.columns + [ENTITY_EVENT_TIMESTAMP_ALIAS])
     )
     if SparkContext._active_spark_context._jsc.sc().getCheckpointDir().nonEmpty():
         time_range_filtered_df = time_range_filtered_df.checkpoint()
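Below is a minimal PySpark sketch, not part of this commit, illustrating why the new entities_projected step helps: duplicate (entity, timestamp) rows in the entity dataframe both inflate the broadcast payload and fan out the inner join. The dataframes, column names (driver_id, event_timestamp, conv_rate), and values are hypothetical.

# Minimal sketch (assumed example data, not from the commit): effect of
# projecting and deduplicating the entity dataframe before a broadcast join.
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

spark = SparkSession.builder.master("local[*]").getOrCreate()

# An entity dataframe may repeat the same (entity, timestamp) pair.
entity_df = spark.createDataFrame(
    [(1, "2021-08-19"), (1, "2021-08-19"), (2, "2021-08-19")],
    ["driver_id", "event_timestamp"],
)
feature_df = spark.createDataFrame(
    [(1, 0.5), (2, 0.9)], ["driver_id", "conv_rate"]
)

# Without distinct(): duplicate entity rows multiply the join output
# and enlarge the broadcast payload.
naive = feature_df.join(broadcast(entity_df), on="driver_id", how="inner")
print(naive.count())  # 3 -- driver 1 matched twice

# With the projection + distinct() this commit introduces, each unique
# (entity, timestamp) pair joins exactly once.
entities_projected = entity_df.select("driver_id", "event_timestamp").distinct()
deduped = feature_df.join(broadcast(entities_projected), on="driver_id", how="inner")
print(deduped.count())  # 2

Projecting down to the entity keys plus the timestamp alias before distinct() also lets entity rows that differ only in unrelated columns collapse into a single row, so the later min-distance point-in-time filtering operates on a smaller input.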
