feat: Implement spark offline store offline_write_batch method (#3076)
* create integration spark data sets as files rather than a temp table

Signed-off-by: niklasvm <niklasvm@gmail.com>

* add offline_write_batch method to spark offline store

Signed-off-by: niklasvm <niklasvm@gmail.com>

* remove some comments

Signed-off-by: niklasvm <niklasvm@gmail.com>

* fix linting issue

Signed-off-by: niklasvm <niklasvm@gmail.com>

* fix more linting issues

Signed-off-by: niklasvm <niklasvm@gmail.com>

* fix flake8 errors

Signed-off-by: niklasvm <niklasvm@gmail.com>

Signed-off-by: niklasvm <niklasvm@gmail.com>
niklasvm committed Aug 18, 2022
1 parent cdd1b07 commit 5b0cc87
Showing 2 changed files with 95 additions and 4 deletions.
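
Before the diff, a minimal sketch of how the new static method could be driven end to end. The repo path, feature view name, module path, and sample rows below are illustrative assumptions rather than part of the commit; in practice the method is normally reached through Feast's higher-level batch write path rather than called directly.

import pandas as pd
import pyarrow as pa

from feast import FeatureStore
# Module path assumed from Feast's contrib layout at the time of this commit.
from feast.infra.offline_stores.contrib.spark_offline_store.spark import SparkOfflineStore

store = FeatureStore(repo_path=".")  # hypothetical feature repo
feature_view = store.get_feature_view("driver_hourly_stats")  # hypothetical view name

# Column names and order must match the batch source schema exactly
# (see the ValueError check in the diff); these columns are placeholders.
df = pd.DataFrame(
    {
        "driver_id": [1001],
        "conv_rate": [0.85],
        "event_timestamp": [pd.Timestamp("2022-08-18", tz="UTC")],
        "created": [pd.Timestamp("2022-08-18", tz="UTC")],
    }
)

SparkOfflineStore.offline_write_batch(
    config=store.config,
    feature_view=feature_view,
    table=pa.Table.from_pandas(df),
    progress=None,
)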
@@ -1,7 +1,7 @@
import tempfile
import warnings
from datetime import datetime
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import pandas
@@ -191,6 +191,68 @@ def get_historical_features(
            ),
        )

+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        table: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        if not feature_view.batch_source:
+            raise ValueError(
+                "feature view does not have a batch source to persist offline data"
+            )
+        if not isinstance(config.offline_store, SparkOfflineStoreConfig):
+            raise ValueError(
+                f"offline store config is of type {type(config.offline_store)} when spark type required"
+            )
+        if not isinstance(feature_view.batch_source, SparkSource):
+            raise ValueError(
+                f"feature view batch source is {type(feature_view.batch_source)} not spark source"
+            )
+
+        pa_schema, column_names = offline_utils.get_pyarrow_schema_from_batch_source(
+            config, feature_view.batch_source
+        )
+        if column_names != table.column_names:
+            raise ValueError(
+                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+                f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
+            )
+
+        spark_session = get_spark_session_or_start_new_with_repoconfig(
+            store_config=config.offline_store
+        )
+
+        if feature_view.batch_source.path:
+            # write data to disk so that it can be loaded into spark (for preserving column types)
+            with tempfile.NamedTemporaryFile(suffix=".parquet") as tmp_file:
+                print(tmp_file.name)
+                pq.write_table(table, tmp_file.name)
+
+                # load data
+                df_batch = spark_session.read.parquet(tmp_file.name)
+
+                # load existing data to get spark table schema
+                df_existing = spark_session.read.format(
+                    feature_view.batch_source.file_format
+                ).load(feature_view.batch_source.path)
+
+                # cast columns if applicable
+                df_batch = _cast_data_frame(df_batch, df_existing)
+
+                df_batch.write.format(feature_view.batch_source.file_format).mode(
+                    "append"
+                ).save(feature_view.batch_source.path)
+        elif feature_view.batch_source.query:
+            raise NotImplementedError(
+                "offline_write_batch not implemented for batch sources specified by query"
+            )
+        else:
+            raise NotImplementedError(
+                "offline_write_batch not implemented for batch sources specified by a table"
+            )
+
    @staticmethod
    @log_exceptions_and_usage(offline_store="spark")
    def pull_all_from_table_or_query(
@@ -388,6 +450,24 @@ def _format_datetime(t: datetime) -> str:
    return dt


+def _cast_data_frame(
+    df_new: pyspark.sql.DataFrame, df_existing: pyspark.sql.DataFrame
+) -> pyspark.sql.DataFrame:
+    """Convert new dataframe's columns to the same types as existing dataframe while preserving the order of columns"""
+    existing_dtypes = {k: v for k, v in df_existing.dtypes}
+    new_dtypes = {k: v for k, v in df_new.dtypes}
+
+    select_expression = []
+    for col, new_type in new_dtypes.items():
+        existing_type = existing_dtypes[col]
+        if new_type != existing_type:
+            select_expression.append(f"cast({col} as {existing_type}) as {col}")
+        else:
+            select_expression.append(col)
+
+    return df_new.selectExpr(*select_expression)
+
+
MULTIPLE_FEATURE_VIEW_POINT_IN_TIME_JOIN = """
/*
Compute a deterministic hash for the `left_table_query_string` that will be used throughout
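
The _cast_data_frame helper added above aligns the incoming batch's column types with the data already stored at the source path before appending. A small, self-contained sketch of that casting idea, using only PySpark (the dataframes and column names are invented for illustration):

from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("cast-demo").getOrCreate()

# "Existing" data with the schema we want to preserve: id bigint, value double.
df_existing = spark.createDataFrame([(1, 3.5)], ["id", "value"])
# New batch arrives with mismatched types: id string, value bigint.
df_new = spark.createDataFrame([("2", 4)], ["id", "value"])

existing_dtypes = dict(df_existing.dtypes)
select_expression = [
    f"cast({col} as {existing_dtypes[col]}) as {col}"
    if new_type != existing_dtypes[col]
    else col
    for col, new_type in df_new.dtypes
]

# Column order is preserved; only mismatched columns are cast.
df_aligned = df_new.selectExpr(*select_expression)
df_aligned.printSchema()  # id: bigint, value: double -- now appendable alongside df_existing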
@@ -1,3 +1,6 @@
+import os
+import shutil
+import tempfile
import uuid
from typing import Dict, List

@@ -48,6 +51,8 @@ def __init__(self, project_name: str, *args, **kwargs):

    def teardown(self):
        self.spark_session.stop()
+        for table in self.tables:
+            shutil.rmtree(table)

    def create_offline_store_config(self):
        self.spark_offline_store_config = SparkOfflineStoreConfig()
@@ -86,11 +91,17 @@ def create_data_source(
            .appName("pytest-pyspark-local-testing")
            .getOrCreate()
        )
-        self.spark_session.createDataFrame(df).createOrReplaceTempView(destination_name)
-        self.tables.append(destination_name)
+
+        temp_dir = tempfile.mkdtemp(prefix="spark_offline_store_test_data")
+
+        path = os.path.join(temp_dir, destination_name)
+        self.tables.append(path)
+
+        self.spark_session.createDataFrame(df).write.parquet(path)
        return SparkSource(
-            table=destination_name,
            name=destination_name,
+            file_format="parquet",
+            path=path,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            field_mapping=field_mapping or {"ts_1": "ts"},
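
The test data source creator now materializes each dataset as parquet under a temporary directory, so the SparkSource gets a real path for the new write path, and teardown removes those directories afterwards. A hypothetical standalone version of that fixture pattern (the destination name and sample data are placeholders):

import os
import shutil
import tempfile

import pandas as pd
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[*]").appName("fixture-demo").getOrCreate()

temp_dir = tempfile.mkdtemp(prefix="spark_offline_store_test_data")
path = os.path.join(temp_dir, "driver_stats")  # placeholder destination name

df = pd.DataFrame(
    {"driver_id": [1001, 1002], "ts": pd.to_datetime(["2022-08-01", "2022-08-02"])}
)
spark.createDataFrame(df).write.parquet(path)

# ... point a SparkSource(path=path, file_format="parquet", ...) at this data in a test ...

shutil.rmtree(temp_dir)  # mirrors the cleanup added to teardown()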
