Skip to content

Commit

Permalink
Add persistent checkpoint in historical retrieval (#91)
Browse files Browse the repository at this point in the history
Signed-off-by: Oleksii Moskalenko <moskalenko.alexey@gmail.com>
  • Loading branch information
pyalex committed Aug 18, 2021
1 parent 7e99515 commit b033508
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 2 deletions.
7 changes: 6 additions & 1 deletion python/feast_spark/pyspark/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def __init__(
entity_source: Dict,
destination: Dict,
extra_packages: Optional[List[str]] = None,
checkpoint_path: Optional[str] = None,
):
"""
Args:
Expand Down Expand Up @@ -265,6 +266,7 @@ def __init__(
self._entity_source = entity_source
self._destination = destination
self._extra_packages = extra_packages if extra_packages else []
self._checkpoint_path = checkpoint_path

def get_name(self) -> str:
all_feature_tables_names = [ft["name"] for ft in self._feature_tables]
Expand All @@ -285,7 +287,7 @@ def get_arguments(self) -> List[str]:
def json_b64_encode(obj) -> str:
    """Serialize *obj* to JSON and return it base64-encoded as ASCII text."""
    serialized = json.dumps(obj).encode("utf8")
    return b64encode(serialized).decode("ascii")

return [
args = [
"--feature-tables",
json_b64_encode(self._feature_tables),
"--feature-tables-sources",
Expand All @@ -295,6 +297,9 @@ def json_b64_encode(obj) -> str:
"--destination",
json_b64_encode(self._destination),
]
if self._checkpoint_path:
args.extend(["--checkpoint", self._checkpoint_path])
return args

def get_destination_path(self) -> str:
    """Return the output path from this job's destination configuration."""
    destination_conf = self._destination
    return destination_conf["path"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from logging.config import dictConfig
from typing import Any, Dict, List, NamedTuple, Optional

from pyspark import SparkContext
from pyspark.sql import DataFrame, SparkSession, Window
from pyspark.sql import functions as func
from pyspark.sql.functions import (
Expand Down Expand Up @@ -602,8 +603,9 @@ def filter_feature_table_by_time_range(
)
.where(col("distance") == col("min_distance"))
.select(time_range_filtered_df.columns + [ENTITY_EVENT_TIMESTAMP_ALIAS])
.localCheckpoint()
)
if SparkContext._active_spark_context._jsc.sc().getCheckpointDir().nonEmpty():
time_range_filtered_df = time_range_filtered_df.checkpoint()

return time_range_filtered_df

Expand Down Expand Up @@ -848,6 +850,7 @@ def _get_args():
parser.add_argument(
"--destination", type=str, help="Retrieval result destination in json string"
)
parser.add_argument("--checkpoint", type=str, help="Spark Checkpoint location")
return parser.parse_args()


Expand Down Expand Up @@ -876,6 +879,9 @@ def json_b64_decode(s: str) -> Any:
feature_tables_sources_conf = json_b64_decode(args.feature_tables_sources)
entity_source_conf = json_b64_decode(args.entity_source)
destination_conf = json_b64_decode(args.destination)
if args.checkpoint:
spark.sparkContext.setCheckpointDir(args.checkpoint)

try:
start_job(
spark,
Expand Down
1 change: 1 addition & 0 deletions python/feast_spark/pyspark/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ def start_historical_feature_retrieval_job(
],
destination={"format": output_format, "path": output_path},
extra_packages=extra_packages,
checkpoint_path=client.config.get(opt.CHECKPOINT_PATH),
)
)

Expand Down

0 comments on commit b033508

Please sign in to comment.