feat: Implement spark materialization engine (#3184)
* implement spark materialization engine

Signed-off-by: niklasvm <niklasvm@gmail.com>

* remove redundant code

Signed-off-by: niklasvm <niklasvm@gmail.com>

* make function private

Signed-off-by: niklasvm <niklasvm@gmail.com>

* refactor serializing into a class

Signed-off-by: niklasvm <niklasvm@gmail.com>

* switch to using `foreachPartition`

Signed-off-by: niklasvm <niklasvm@gmail.com>

* remove batch_size parameter

Signed-off-by: niklasvm <niklasvm@gmail.com>

* add partitions parameter

Signed-off-by: niklasvm <niklasvm@gmail.com>

* linting

Signed-off-by: niklasvm <niklasvm@gmail.com>

* rename spark to spark.offline and spark.engine

Signed-off-by: niklasvm <niklasvm@gmail.com>

* fix to test

Signed-off-by: niklasvm <niklasvm@gmail.com>

* forgot to stage

Signed-off-by: niklasvm <niklasvm@gmail.com>

* revert spark.offline to spark to ensure backward compatibility

Signed-off-by: niklasvm <niklasvm@gmail.com>

* fix import

Signed-off-by: niklasvm <niklasvm@gmail.com>

* remove code from testing a large data set

Signed-off-by: niklasvm <niklasvm@gmail.com>

* linting

Signed-off-by: niklasvm <niklasvm@gmail.com>

* test without repartition

Signed-off-by: niklasvm <niklasvm@gmail.com>

* test alternate connection string

Signed-off-by: niklasvm <niklasvm@gmail.com>

* use redis online creator

Signed-off-by: niklasvm <niklasvm@gmail.com>

Signed-off-by: niklasvm <niklasvm@gmail.com>
niklasvm committed Sep 15, 2022
1 parent 7bc1dff commit a59c33a
Showing 3 changed files with 343 additions and 0 deletions.
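
As background for the diff below: the commit registers a new batch materialization engine type, "spark.engine", whose only option is the number of Spark partitions to use when writing to the online store. A minimal sketch (not part of this commit) of constructing that config programmatically, assuming the new file below is the module registered in repo_config.py further down:

from feast.infra.materialization.contrib.spark.spark_materialization_engine import (
    SparkMaterializationEngineConfig,
)

# partitions > 0 repartitions the Spark DataFrame before the per-partition write
# to the online store; the default of 0 leaves the existing partitioning alone.
engine_config = SparkMaterializationEngineConfig(partitions=10)
print(engine_config.type)  # "spark.engine" -- the selector used in feature_store.yaml
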
@@ -0,0 +1,265 @@
import tempfile
from dataclasses import dataclass
from datetime import datetime
from typing import Callable, List, Literal, Optional, Sequence, Union

import dill
import pandas as pd
import pyarrow
from tqdm import tqdm

from feast.batch_feature_view import BatchFeatureView
from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.infra.materialization.batch_materialization_engine import (
    BatchMaterializationEngine,
    MaterializationJob,
    MaterializationJobStatus,
    MaterializationTask,
)
from feast.infra.offline_stores.contrib.spark_offline_store.spark import (
    SparkOfflineStore,
    SparkRetrievalJob,
)
from feast.infra.online_stores.online_store import OnlineStore
from feast.infra.passthrough_provider import PassthroughProvider
from feast.infra.registry.base_registry import BaseRegistry
from feast.protos.feast.core.FeatureView_pb2 import FeatureView as FeatureViewProto
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.stream_feature_view import StreamFeatureView
from feast.utils import (
    _convert_arrow_to_proto,
    _get_column_names,
    _run_pyarrow_field_mapping,
)


class SparkMaterializationEngineConfig(FeastConfigBaseModel):
    """Batch Materialization Engine config for spark engine"""

    type: Literal["spark.engine"] = "spark.engine"
    """ Type selector"""

    partitions: int = 0
    """Number of partitions to use when writing data to online store. If 0, no repartitioning is done"""


@dataclass
class SparkMaterializationJob(MaterializationJob):
    def __init__(
        self,
        job_id: str,
        status: MaterializationJobStatus,
        error: Optional[BaseException] = None,
    ) -> None:
        super().__init__()
        self._job_id: str = job_id
        self._status: MaterializationJobStatus = status
        self._error: Optional[BaseException] = error

    def status(self) -> MaterializationJobStatus:
        return self._status

    def error(self) -> Optional[BaseException]:
        return self._error

    def should_be_retried(self) -> bool:
        return False

    def job_id(self) -> str:
        return self._job_id

    def url(self) -> Optional[str]:
        return None


class SparkMaterializationEngine(BatchMaterializationEngine):
    def update(
        self,
        project: str,
        views_to_delete: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        views_to_keep: Sequence[
            Union[BatchFeatureView, StreamFeatureView, FeatureView]
        ],
        entities_to_delete: Sequence[Entity],
        entities_to_keep: Sequence[Entity],
    ):
        # Nothing to set up.
        pass

    def teardown_infra(
        self,
        project: str,
        fvs: Sequence[Union[BatchFeatureView, StreamFeatureView, FeatureView]],
        entities: Sequence[Entity],
    ):
        # Nothing to tear down.
        pass

    def __init__(
        self,
        *,
        repo_config: RepoConfig,
        offline_store: SparkOfflineStore,
        online_store: OnlineStore,
        **kwargs,
    ):
        if not isinstance(offline_store, SparkOfflineStore):
            raise TypeError(
                "SparkMaterializationEngine is only compatible with the SparkOfflineStore"
            )
        super().__init__(
            repo_config=repo_config,
            offline_store=offline_store,
            online_store=online_store,
            **kwargs,
        )

    def materialize(
        self, registry, tasks: List[MaterializationTask]
    ) -> List[MaterializationJob]:
        return [
            self._materialize_one(
                registry,
                task.feature_view,
                task.start_time,
                task.end_time,
                task.project,
                task.tqdm_builder,
            )
            for task in tasks
        ]

    def _materialize_one(
        self,
        registry: BaseRegistry,
        feature_view: Union[BatchFeatureView, StreamFeatureView, FeatureView],
        start_date: datetime,
        end_date: datetime,
        project: str,
        tqdm_builder: Callable[[int], tqdm],
    ):
        entities = []
        for entity_name in feature_view.entities:
            entities.append(registry.get_entity(entity_name, project))

        (
            join_key_columns,
            feature_name_columns,
            timestamp_field,
            created_timestamp_column,
        ) = _get_column_names(feature_view, entities)

        job_id = f"{feature_view.name}-{start_date}-{end_date}"

        try:
            offline_job: SparkRetrievalJob = (
                self.offline_store.pull_latest_from_table_or_query(
                    config=self.repo_config,
                    data_source=feature_view.batch_source,
                    join_key_columns=join_key_columns,
                    feature_name_columns=feature_name_columns,
                    timestamp_field=timestamp_field,
                    created_timestamp_column=created_timestamp_column,
                    start_date=start_date,
                    end_date=end_date,
                )
            )

            spark_serialized_artifacts = _SparkSerializedArtifacts.serialize(
                feature_view=feature_view, repo_config=self.repo_config
            )

            spark_df = offline_job.to_spark_df()
            if self.repo_config.batch_engine.partitions != 0:
                spark_df = spark_df.repartition(
                    self.repo_config.batch_engine.partitions
                )

            spark_df.foreachPartition(
                lambda x: _process_by_partition(x, spark_serialized_artifacts)
            )

            return SparkMaterializationJob(
                job_id=job_id, status=MaterializationJobStatus.SUCCEEDED
            )
        except BaseException as e:
            return SparkMaterializationJob(
                job_id=job_id, status=MaterializationJobStatus.ERROR, error=e
            )


@dataclass
class _SparkSerializedArtifacts:
    """Class to assist with serializing unpicklable artifacts to the spark workers"""

    feature_view_proto: str
    repo_config_file: str

    @classmethod
    def serialize(cls, feature_view, repo_config):

        # serialize to proto
        feature_view_proto = feature_view.to_proto().SerializeToString()

        # serialize repo_config to disk. Will be used to instantiate the online store
        repo_config_file = tempfile.NamedTemporaryFile(delete=False).name
        with open(repo_config_file, "wb") as f:
            dill.dump(repo_config, f)

        return _SparkSerializedArtifacts(
            feature_view_proto=feature_view_proto, repo_config_file=repo_config_file
        )

    def unserialize(self):
        # unserialize
        proto = FeatureViewProto()
        proto.ParseFromString(self.feature_view_proto)
        feature_view = FeatureView.from_proto(proto)

        # load
        with open(self.repo_config_file, "rb") as f:
            repo_config = dill.load(f)

        provider = PassthroughProvider(repo_config)
        online_store = provider.online_store
        return feature_view, online_store, repo_config


def _process_by_partition(rows, spark_serialized_artifacts: _SparkSerializedArtifacts):
    """Load pandas df to online store"""

    # convert to pyarrow table
    dicts = []
    for row in rows:
        dicts.append(row.asDict())

    df = pd.DataFrame.from_records(dicts)
    if df.shape[0] == 0:
        print("Skipping")
        return

    table = pyarrow.Table.from_pandas(df)

    # unserialize artifacts
    feature_view, online_store, repo_config = spark_serialized_artifacts.unserialize()

    if feature_view.batch_source.field_mapping is not None:
        table = _run_pyarrow_field_mapping(
            table, feature_view.batch_source.field_mapping
        )

    join_key_to_value_type = {
        entity.name: entity.dtype.to_value_type()
        for entity in feature_view.entity_columns
    }

    rows_to_write = _convert_arrow_to_proto(table, feature_view, join_key_to_value_type)
    online_store.online_write_batch(
        repo_config,
        feature_view,
        rows_to_write,
        lambda x: None,
    )
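
With the engine registered in repo_config.py (next file), a repo whose batch_engine type is set to "spark.engine" drives this code through the usual FeatureStore API. A hypothetical usage sketch, not part of this commit, assuming such a repo is configured in the current directory with the Spark offline store:

from datetime import datetime, timedelta

from feast import FeatureStore

# Assumes feature_store.yaml selects the Spark offline store and
# batch_engine: {type: spark.engine, partitions: 10}.
store = FeatureStore(repo_path=".")

# Each feature view becomes a MaterializationTask; SparkMaterializationEngine pulls
# the latest rows from the offline store and writes them partition by partition.
store.materialize(
    start_date=datetime.utcnow() - timedelta(days=1),
    end_date=datetime.utcnow(),
)
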
1 change: 1 addition & 0 deletions sdk/python/feast/repo_config.py
@@ -39,6 +39,7 @@
    "snowflake.engine": "feast.infra.materialization.snowflake_engine.SnowflakeMaterializationEngine",
    "lambda": "feast.infra.materialization.aws_lambda.lambda_engine.LambdaMaterializationEngine",
    "bytewax": "feast.infra.materialization.contrib.bytewax.bytewax_materialization_engine.BytewaxMaterializationEngine",
    "spark.engine": "feast.infra.materialization.contrib.spark.spark_materialization_engine.SparkMaterializationEngine",
}

ONLINE_STORE_CLASS_FOR_TYPE = {
@@ -0,0 +1,77 @@
from datetime import timedelta

import pytest

from feast.entity import Entity
from feast.feature_view import FeatureView
from feast.field import Field
from feast.infra.offline_stores.contrib.spark_offline_store.tests.data_source import (
    SparkDataSourceCreator,
)
from feast.types import Float32
from tests.data.data_creator import create_basic_driver_dataset
from tests.integration.feature_repos.integration_test_repo_config import (
    IntegrationTestRepoConfig,
)
from tests.integration.feature_repos.repo_configuration import (
    construct_test_environment,
)
from tests.integration.feature_repos.universal.online_store.redis import (
    RedisOnlineStoreCreator,
)
from tests.utils.e2e_test_validation import validate_offline_online_store_consistency


@pytest.mark.integration
def test_spark_materialization_consistency():
    spark_config = IntegrationTestRepoConfig(
        provider="local",
        online_store_creator=RedisOnlineStoreCreator,
        offline_store_creator=SparkDataSourceCreator,
        batch_engine={"type": "spark.engine", "partitions": 10},
    )
    spark_environment = construct_test_environment(
        spark_config, None, entity_key_serialization_version=1
    )

    df = create_basic_driver_dataset()

    ds = spark_environment.data_source_creator.create_data_source(
        df,
        spark_environment.feature_store.project,
        field_mapping={"ts_1": "ts"},
    )

    fs = spark_environment.feature_store
    driver = Entity(
        name="driver_id",
        join_keys=["driver_id"],
    )

    driver_stats_fv = FeatureView(
        name="driver_hourly_stats",
        entities=[driver],
        ttl=timedelta(weeks=52),
        schema=[Field(name="value", dtype=Float32)],
        source=ds,
    )

    try:

        fs.apply([driver, driver_stats_fv])

        print(df)

        # materialization is run in two steps and
        # we use timestamp from generated dataframe as a split point
        split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1)

        print(f"Split datetime: {split_dt}")

        validate_offline_online_store_consistency(fs, driver_stats_fv, split_dt)
    finally:
        fs.teardown()


if __name__ == "__main__":
    test_spark_materialization_consistency()
