Accept Pandas dataframe as input for historical feature retrieval #1071

Merged
sdk/python/feast/client.py (42 changes: 38 additions & 4 deletions)

@@ -13,10 +13,14 @@
# limitations under the License.
import logging
import multiprocessing
import os
import shutil
import tempfile
import uuid
from datetime import datetime
from itertools import groupby
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse

import grpc
import pandas as pd
@@ -34,6 +38,7 @@
    CONFIG_SERVING_URL_KEY,
    CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT,
    CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION,
    CONFIG_SPARK_STAGING_LOCATION,
    FEAST_DEFAULT_OPTIONS,
)
from feast.core.CoreService_pb2 import (
@@ -88,6 +93,7 @@
    GetOnlineFeaturesRequestV2,
)
from feast.serving.ServingService_pb2_grpc import ServingServiceStub
from feast.staging.storage_client import get_staging_client

_logger = logging.getLogger(__name__)

@@ -780,7 +786,7 @@ def get_online_features(
    def get_historical_features(
        self,
        feature_refs: List[str],
        entity_source: Union[FileSource, BigQuerySource],
        entity_source: Union[pd.DataFrame, FileSource, BigQuerySource],
        project: str = None,
    ) -> RetrievalJob:
        """
@@ -791,9 +797,14 @@ def get_historical_features(
                Each feature reference should have the following format:
                "feature_table:feature" where "feature_table" & "feature" refer to
                the feature and feature table names respectively.
            entity_source (Union[FileSource, BigQuerySource]): Source for the entity rows.
                The user needs to make sure that the source is accessible from the Spark cluster
                that will be used for the retrieval job.
            entity_source (Union[pd.DataFrame, FileSource, BigQuerySource]): Source for the entity rows.
                If entity_source is a pandas DataFrame, the dataframe will be exported to the staging
                location as a Parquet file. It is also assumed that the column event_timestamp is present
                in the dataframe, and is of type datetime without timezone information.

                The user needs to make sure that the source (or staging location, if entity_source is
                a pandas DataFrame) is accessible from the Spark cluster that will be used for the
                retrieval job.
            project: Specifies the project that contains the feature tables
                which the requested features belong to.

@@ -821,6 +832,29 @@
        )
        output_format = self._config.get(CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT)

        if isinstance(entity_source, pd.DataFrame):
            staging_location = self._config.get(CONFIG_SPARK_STAGING_LOCATION)
            entity_staging_uri = urlparse(
                os.path.join(staging_location, str(uuid.uuid4()))
            )
            staging_client = get_staging_client(entity_staging_uri.scheme)
            with tempfile.NamedTemporaryFile() as df_export_path:
                # Export the dataframe to a local Parquet file, then upload it
                # to a unique path under the configured staging location.
                entity_source.to_parquet(df_export_path.name)
                bucket = (
                    None
                    if entity_staging_uri.scheme == "fs"
                    else entity_staging_uri.netloc
                )

[Review comment, Collaborator, on the "fs" check above]: what's fs scheme? maybe file?

                staging_client.upload_file(
                    df_export_path.name, bucket, entity_staging_uri.path
                )
                # Re-point entity_source at the staged Parquet file.
                entity_source = FileSource(
                    "event_timestamp",
                    "created_timestamp",
                    ParquetFormat(),
                    entity_staging_uri.path,
                )

[Review comment, Collaborator, on the FileSource path argument]: should be full url

        return start_historical_feature_retrieval_job(
            self, entity_source, feature_tables, output_format, output_location
        )
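
A minimal usage sketch of the new code path (the feature reference, entity values, and client settings below are hypothetical; spark_staging_location is assumed to be configured and reachable from the Spark cluster):

import pandas as pd
from feast import Client

client = Client(core_url="localhost:6565")  # staging location assumed set via config

# Per the docstring above, the entity dataframe must carry an
# "event_timestamp" column of dtype datetime64[ns] (timezone-naive),
# alongside the entity key columns.
entity_df = pd.DataFrame(
    {
        "event_timestamp": pd.to_datetime(
            ["2020-10-01 09:00:00", "2020-10-01 10:00:00"]
        ),
        "customer_id": [1001, 1002],  # hypothetical entity column
    }
)

# Feature references use the "feature_table:feature" format.
job = client.get_historical_features(
    feature_refs=["transactions:total_amount"],  # hypothetical feature
    entity_source=entity_df,
)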
sdk/python/feast/constants.py (4 changes: 2 additions & 2 deletions)

@@ -67,6 +67,8 @@ class AuthProvider(Enum):
# Spark Job Config
CONFIG_SPARK_LAUNCHER = "spark_launcher" # standalone, dataproc, emr

CONFIG_SPARK_STAGING_LOCATION = "spark_staging_location"

CONFIG_SPARK_INGESTION_JOB_JAR = "spark_ingestion_jar"

CONFIG_SPARK_STANDALONE_MASTER = "spark_standalone_master"
@@ -75,7 +77,6 @@ class AuthProvider(Enum):
CONFIG_SPARK_DATAPROC_CLUSTER_NAME = "dataproc_cluster_name"
CONFIG_SPARK_DATAPROC_PROJECT = "dataproc_project"
CONFIG_SPARK_DATAPROC_REGION = "dataproc_region"
CONFIG_SPARK_DATAPROC_STAGING_LOCATION = "dataproc_staging_location"

CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_FORMAT = "historical_feature_output_format"
CONFIG_SPARK_HISTORICAL_FEATURE_OUTPUT_LOCATION = "historical_feature_output_location"
@@ -87,7 +88,6 @@ class AuthProvider(Enum):
CONFIG_SPARK_EMR_REGION = "emr_region"
CONFIG_SPARK_EMR_CLUSTER_ID = "emr_cluster_id"
CONFIG_SPARK_EMR_CLUSTER_TEMPLATE_PATH = "emr_cluster_template_path"
CONFIG_SPARK_EMR_STAGING_LOCATION = "emr_staging_location"
CONFIG_SPARK_EMR_LOG_LOCATION = "emr_log_location"


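
The launcher-specific staging options (dataproc_staging_location, emr_staging_location) are consolidated into the single spark_staging_location key added above. A short sketch of reading the unified key back through the SDK's Config (the bucket value is a placeholder; this assumes Config accepts an options dict, as it is used elsewhere in the SDK):

from feast.config import Config
from feast.constants import CONFIG_SPARK_STAGING_LOCATION

# One staging key now serves every Spark launcher.
config = Config(options={"spark_staging_location": "s3://my-bucket/staging"})
print(config.get(CONFIG_SPARK_STAGING_LOCATION))  # -> s3://my-bucket/staging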
sdk/python/feast/pyspark/launcher.py (7 changes: 3 additions & 4 deletions)

@@ -12,15 +12,14 @@
    CONFIG_SPARK_DATAPROC_CLUSTER_NAME,
    CONFIG_SPARK_DATAPROC_PROJECT,
    CONFIG_SPARK_DATAPROC_REGION,
    CONFIG_SPARK_DATAPROC_STAGING_LOCATION,
    CONFIG_SPARK_EMR_CLUSTER_ID,
    CONFIG_SPARK_EMR_CLUSTER_TEMPLATE_PATH,
    CONFIG_SPARK_EMR_LOG_LOCATION,
    CONFIG_SPARK_EMR_REGION,
    CONFIG_SPARK_EMR_STAGING_LOCATION,
    CONFIG_SPARK_HOME,
    CONFIG_SPARK_INGESTION_JOB_JAR,
    CONFIG_SPARK_LAUNCHER,
    CONFIG_SPARK_STAGING_LOCATION,
    CONFIG_SPARK_STANDALONE_MASTER,
)
from feast.data_source import BigQuerySource, DataSource, FileSource, KafkaSource
@@ -54,7 +53,7 @@ def _dataproc_launcher(config: Config) -> JobLauncher:

    return gcloud.DataprocClusterLauncher(
        config.get(CONFIG_SPARK_DATAPROC_CLUSTER_NAME),
        config.get(CONFIG_SPARK_DATAPROC_STAGING_LOCATION),
        config.get(CONFIG_SPARK_STAGING_LOCATION),
        config.get(CONFIG_SPARK_DATAPROC_REGION),
        config.get(CONFIG_SPARK_DATAPROC_PROJECT),
    )
@@ -71,7 +70,7 @@ def _get_optional(option):
        region=config.get(CONFIG_SPARK_EMR_REGION),
        existing_cluster_id=_get_optional(CONFIG_SPARK_EMR_CLUSTER_ID),
        new_cluster_template_path=_get_optional(CONFIG_SPARK_EMR_CLUSTER_TEMPLATE_PATH),
        staging_location=config.get(CONFIG_SPARK_EMR_STAGING_LOCATION),
        staging_location=config.get(CONFIG_SPARK_STAGING_LOCATION),
        emr_log_location=config.get(CONFIG_SPARK_EMR_LOG_LOCATION),
    )

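
Both launchers now read the same option at construction time. A hedged sketch of configuring an EMR retrieval client under the new key (all values are placeholders; this assumes Client forwards keyword options to its Config, as the SDK does for other options):

from feast import Client

client = Client(
    core_url="localhost:6565",
    spark_launcher="emr",
    spark_staging_location="s3://my-bucket/staging",  # replaces emr_staging_location
    emr_region="us-west-2",
    emr_log_location="s3://my-bucket/emr-logs",
)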