From 844bb83b84c77b5c0cfed147d918ffcae1077c36 Mon Sep 17 00:00:00 2001
From: niklasvm
Date: Sat, 3 Sep 2022 13:24:38 +0200
Subject: [PATCH] implement to_remote_storage method

Signed-off-by: niklasvm
---
 .../contrib/spark_offline_store/spark.py      | 77 ++++++++++++++++++-
 .../spark_offline_store/tests/data_source.py  |  4 +
 2 files changed, 80 insertions(+), 1 deletion(-)

diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
index 58519014b44..0a4ec05c23d 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -1,4 +1,6 @@
+import os
 import tempfile
+import uuid
 import warnings
 from datetime import datetime
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
@@ -13,6 +15,7 @@
 from pyspark import SparkConf
 from pyspark.sql import SparkSession
 from pytz import utc
+from feast.infra.utils import aws_utils

 from feast import FeatureView, OnDemandFeatureView
 from feast.data_source import DataSource
@@ -46,6 +49,12 @@ class SparkOfflineStoreConfig(FeastConfigBaseModel):
     """ Configuration overlay for the spark session """
     # sparksession is not serializable and we dont want to pass it around as an argument

+    staging_location: Optional[StrictStr] = None
+    """ Remote path for batch materialization jobs """
+
+    region: Optional[StrictStr] = None
+    """ AWS region, if applicable, for s3-based staging locations """
+

 class SparkOfflineStore(OfflineStore):
     @staticmethod
@@ -105,6 +114,7 @@ def pull_latest_from_table_or_query(
         return SparkRetrievalJob(
             spark_session=spark_session,
             query=query,
+            config=config,
             full_feature_names=False,
             on_demand_feature_views=None,
         )
@@ -129,6 +139,7 @@ def get_historical_features(
             "Some functionality may still be unstable so functionality can change in the future.",
             RuntimeWarning,
         )
+
         spark_session = get_spark_session_or_start_new_with_repoconfig(
             store_config=config.offline_store
         )
@@ -192,6 +203,7 @@ def get_historical_features(
                 min_event_timestamp=entity_df_event_timestamp_range[0],
                 max_event_timestamp=entity_df_event_timestamp_range[1],
             ),
+            config=config,
         )

     @staticmethod
@@ -286,7 +298,10 @@ def pull_all_from_table_or_query(
         """

         return SparkRetrievalJob(
-            spark_session=spark_session, query=query, full_feature_names=False
+            spark_session=spark_session,
+            query=query,
+            full_feature_names=False,
+            config=config,
         )


@@ -296,6 +311,7 @@ def __init__(
         spark_session: SparkSession,
         query: str,
         full_feature_names: bool,
+        config: RepoConfig,
         on_demand_feature_views: Optional[List[OnDemandFeatureView]] = None,
         metadata: Optional[RetrievalMetadata] = None,
     ):
@@ -305,6 +321,7 @@ def __init__(
         self._full_feature_names = full_feature_names
         self._on_demand_feature_views = on_demand_feature_views or []
         self._metadata = metadata
+        self._config = config

     @property
     def full_feature_names(self) -> bool:
@@ -342,6 +359,53 @@ def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
             raise ValueError("Cannot persist, table_name is not defined")
         self.to_spark_df().createOrReplaceTempView(table_name)

+    def supports_remote_storage_export(self) -> bool:
+        return self._config.offline_store.staging_location is not None
+
+    def to_remote_storage(self) -> List[str]:
+        """Currently only works for local and s3-based staging locations"""
+        if self.supports_remote_storage_export():
+            sdf: pyspark.sql.DataFrame = self.to_spark_df()
+
+            if self._config.offline_store.staging_location.startswith("file://"):
+                # strip the file:// scheme so os.path functions receive a plain path
+                local_file_staging_location = os.path.abspath(
+                    self._config.offline_store.staging_location[len("file://") :]
+                )
+
+                # write to a uuid-suffixed folder under the staging location
+                output_uri = os.path.join(
+                    local_file_staging_location, str(uuid.uuid4())
+                )
+                sdf.write.parquet(output_uri)
+
+                return _list_files_in_folder(output_uri)
+            elif self._config.offline_store.staging_location.startswith("s3://"):
+                # spark's hadoop-aws connector expects the s3a:// scheme
+                spark_compatible_s3_staging_location = (
+                    self._config.offline_store.staging_location.replace(
+                        "s3://", "s3a://", 1
+                    )
+                )
+
+                # write to a uuid-suffixed folder under the staging location
+                output_uri = os.path.join(
+                    spark_compatible_s3_staging_location, str(uuid.uuid4())
+                )
+                sdf.write.parquet(output_uri)
+
+                # convert the scheme back before listing, since the boto3-based
+                # helper works with s3:// paths
+                return aws_utils.list_s3_files(
+                    self._config.offline_store.region,
+                    output_uri.replace("s3a://", "s3://", 1),
+                )
+            else:
+                raise NotImplementedError(
+                    "to_remote_storage is only implemented for file:// and s3:// uri schemes"
+                )
+        else:
+            raise NotImplementedError(
+                "to_remote_storage requires offline_store.staging_location to be set"
+            )
+
     @property
     def metadata(self) -> Optional[RetrievalMetadata]:
         """
@@ -444,6 +508,17 @@ def _format_datetime(t: datetime) -> str:
     return dt


+def _list_files_in_folder(folder):
+    """List the full paths of the files directly inside a folder"""
+    files = []
+    for file in os.listdir(folder):
+        filename = os.path.join(folder, file)
+        if os.path.isfile(filename):
+            files.append(filename)
+
+    return files
+
+
 def _cast_data_frame(
     df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
 ) -> pyspark.sql.DataFrame:
diff --git a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py
index ab1acbef73e..64a2a01cee4 100644
--- a/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py
+++ b/sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py
@@ -58,6 +58,10 @@ def create_offline_store_config(self):
         self.spark_offline_store_config = SparkOfflineStoreConfig()
         self.spark_offline_store_config.type = "spark"
         self.spark_offline_store_config.spark_conf = self.spark_conf
+        # mkdtemp keeps the directory for the test run; str(TemporaryDirectory())
+        # yields a repr, and the directory vanishes once garbage collected
+        self.spark_offline_store_config.staging_location = "file://" + tempfile.mkdtemp()
+        self.spark_offline_store_config.region = "eu-west-1"
         return self.spark_offline_store_config

     def create_data_source(
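
A brief usage sketch of the new export path (illustrative, not part of the
patch; the repo path, entity dataframe, and feature reference below are
hypothetical, and assume a feature_store.yaml whose offline_store block sets
type: spark plus the new staging_location and region options):

    from datetime import datetime

    import pandas as pd
    from feast import FeatureStore

    store = FeatureStore(repo_path=".")  # hypothetical repo configured as above

    entity_df = pd.DataFrame(
        {
            "driver_id": [1001, 1002],  # hypothetical entity keys
            "event_timestamp": [datetime(2022, 9, 1), datetime(2022, 9, 2)],
        }
    )

    job = store.get_historical_features(
        entity_df=entity_df,
        features=["driver_hourly_stats:conv_rate"],  # hypothetical feature reference
    )

    # With staging_location set, the job writes parquet files under a
    # uuid-suffixed prefix and returns their paths; without it, the new
    # to_remote_storage method raises NotImplementedError.
    if job.supports_remote_storage_export():
        print(job.to_remote_storage())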