feat: Implement spark offline store offline_write_batch method #3076

Merged · 6 commits · Aug 18, 2022
@@ -1,7 +1,7 @@
 import tempfile
 import warnings
 from datetime import datetime
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import pandas
@@ -191,6 +191,67 @@ def get_historical_features(
             ),
         )
 
+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        table: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        if not feature_view.batch_source:
+            raise ValueError(
+                "feature view does not have a batch source to persist offline data"
+            )
+        if not isinstance(config.offline_store, SparkOfflineStoreConfig):
+            raise ValueError(
+                f"offline store config is of type {type(config.offline_store)} when spark type required"
+            )
+        if not isinstance(feature_view.batch_source, SparkSource):
+            raise ValueError(
+                f"feature view batch source is {type(feature_view.batch_source)} not spark source"
+            )
+
+        pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
+        if column_names != table.column_names:
+            raise ValueError(
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
+            )
+
+        spark_session = get_spark_session_or_start_new_with_repoconfig(
+            store_config=config.offline_store
+        )
+
+        if feature_view.batch_source.path:
+            # write data to disk so that it can be loaded into spark (for preserving column types)
+            with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file:
+                pq.write_table(table, tmp_file.name)
+
+                # load data
+                df_batch = spark_session.read.parquet(tmp_file.name)
+
+                # load existing data to get spark table schema
+                df_existing = spark_session.read.format(
+                    feature_view.batch_source.file_format
+                ).load(feature_view.batch_source.path)
+
+                # cast columns if applicable
+                df_batch = _cast_data_frame(df_batch, df_existing)
+
+                df_batch.write.format(feature_view.batch_source.file_format).mode(
+                    "append"
+                ).save(feature_view.batch_source.path)
+        elif feature_view.batch_source.query:
+            raise NotImplementedError(
+                "offline_write_batch not implemented for batch sources specified by query"
+            )
+        else:
+            raise NotImplementedError(
+                "offline_write_batch not implemented for batch sources specified by a table"
+            )
+
     @staticmethod
     @log_exceptions_and_usage(offline_store="spark")
     def pull_all_from_table_or_query(
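Note: a minimal sketch of how this new write path might be exercised end to end. The repo path, feature view name, and dataframe below are hypothetical, and it assumes `FeatureStore.write_to_offline_store` as the public entry point that routes to `offline_write_batch`:

```python
from datetime import datetime

import pandas as pd
from feast import FeatureStore

store = FeatureStore(repo_path=".")  # assumes a repo configured with the Spark offline store

# Columns must match the batch source schema, in this exact order.
df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "conv_rate": [0.45, 0.91],
        "event_timestamp": [datetime(2022, 8, 1), datetime(2022, 8, 2)],
        "created": [datetime(2022, 8, 1), datetime(2022, 8, 2)],
    }
)

# Dispatches to SparkOfflineStore.offline_write_batch for a Spark-configured repo.
store.write_to_offline_store(feature_view_name="driver_hourly_stats", df=df)
```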
@@ -388,6 +450,24 @@ def _format_datetime(t: datetime) -> str:
     return dt
 
 
+def _cast_data_frame(
+    df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
+) -> pyspark.sql.DataFrame:
+    """Convert the new dataframe's columns to the same types as the existing dataframe's, preserving column order"""
+    existing_dtypes = {k: v for k, v in df_existing.dtypes}
+    new_dtypes = {k: v for k, v in df_new.dtypes}
+
+    select_expression = []
+    for col, new_type in new_dtypes.items():
+        existing_type = existing_dtypes[col]
+        if new_type != existing_type:
+            select_expression.append(f"cast({col} as {existing_type}) as {col}")
+        else:
+            select_expression.append(col)
+
+    return df_new.selectExpr(*select_expression)
+
+
 MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
 /*
  Compute a deterministic hash for the `left_table_query_string` that will be used throughout
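To see what `_cast_data_frame` does, here is a small self-contained sketch (local Spark session and toy frames, not from this PR) in which a type mismatch is resolved by casting to the existing schema's types:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("cast-demo").getOrCreate()

# Existing data stores conv_rate as double; the incoming batch carries it as string.
df_existing = spark.createDataFrame([(1, 0.5)], ["driver_id", "conv_rate"])
df_new = spark.createDataFrame([(2, "0.7")], ["driver_id", "conv_rate"])

df_cast = _cast_data_frame(df_new, df_existing)
df_cast.printSchema()  # driver_id: long, conv_rate: double -- now matches df_existing
```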
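The temp-file detour in `offline_write_batch` is worth calling out: round-tripping the pyarrow table through a Parquet file preserves the Arrow column types when Spark reads it back, which a conversion through pandas would not guarantee. A standalone sketch of that hand-off (toy table and names, not from this PR):

```python
import tempfile

import pyarrow as pa
import pyarrow.parquet as pq
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("roundtrip-demo").getOrCreate()

table = pa.table(
    {
        "driver_id": pa.array([1001, 1002], type=pa.int32()),
        "conv_rate": pa.array([0.45, 0.91], type=pa.float32()),
    }
)

with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file:
    pq.write_table(table, tmp_file.name)
    df = spark.read.parquet(tmp_file.name)
    df.printSchema()  # driver_id: integer, conv_rate: float -- Arrow types survive
```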
@@ -1,3 +1,6 @@
+import os
+import shutil
+import tempfile
 import uuid
 from typing import Dict, List
 
@@ -48,6 +51,8 @@ def __init__(self, project_name: str, *args, **kwargs):
 
     def teardown(self):
         self.spark_session.stop()
+        for table in self.tables:
+            shutil.rmtree(table)
 
     def create_offline_store_config(self):
         self.spark_offline_store_config = SparkOfflineStoreConfig()
@@ -86,11 +91,17 @@ def create_data_source(
             .appName("pytest-pyspark-local-testing")
             .getOrCreate()
         )
-        self.spark_session.createDataFrame(df).createOrReplaceTempView(destination_name)
-        self.tables.append(destination_name)
+
+        temp_dir = tempfile.mkdtemp(prefix="spark_offline_store_test_data")
+
+        path = os.path.join(temp_dir, destination_name)
+        self.tables.append(path)
+
+        self.spark_session.createDataFrame(df).write.parquet(path)
         return SparkSource(
-            table=destination_name,
             name=destination_name,
+            file_format="parquet",
+            path=path,
             timestamp_field=timestamp_field,
             created_timestamp_column=created_timestamp_column,
             field_mapping=field_mapping or {"ts_1": "ts"},
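Taken together, the test creator now produces file-backed sources that the new write path can append to, and cleans up its temp directories on teardown. A hedged sketch of how it might be driven (class name, signature, and surrounding harness assumed from the Feast test suite):

```python
import pandas as pd

creator = SparkDataSourceCreator(project_name="demo")  # hypothetical instantiation

df = pd.DataFrame(
    {
        "driver_id": [1001],
        "conv_rate": [0.45],
        "ts_1": [pd.Timestamp("2022-08-01")],  # default field_mapping renames ts_1 -> ts
    }
)

# Writes df to a fresh temp Parquet directory and returns a path-based SparkSource.
source = creator.create_data_source(df=df, destination_name="driver_stats", timestamp_field="ts")
assert source.path and source.file_format == "parquet"

# Stops the session and shutil.rmtree's every temp directory recorded in self.tables.
creator.teardown()
```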