feat: Optimize bytewax pod resource with zero-copy #3826

Merged 3 commits on Nov 14, 2023
@@ -1,3 +1,4 @@
+import logging
 import os
 from typing import List
 
@@ -7,11 +8,11 @@
 from bytewax.execution import cluster_main
 from bytewax.inputs import ManualInputConfig
 from bytewax.outputs import ManualOutputConfig
-from tqdm import tqdm
 
 from feast import FeatureStore, FeatureView, RepoConfig
 from feast.utils import _convert_arrow_to_proto, _run_pyarrow_field_mapping
 
+logger = logging.getLogger(__name__)
 DEFAULT_BATCH_SIZE = 1000
 
@@ -29,14 +30,20 @@ def __init__(
         self.feature_view = feature_view
         self.worker_index = worker_index
         self.paths = paths
+        self.mini_batch_size = int(
+            os.getenv("BYTEWAX_MINI_BATCH_SIZE", DEFAULT_BATCH_SIZE)
+        )
 
         self._run_dataflow()
 
     def process_path(self, path):
+        logger.info(f"Processing path {path}")
         dataset = pq.ParquetDataset(path, use_legacy_dataset=False)
         batches = []
         for fragment in dataset.fragments:
-            for batch in fragment.to_table().to_batches():
+            for batch in fragment.to_table().to_batches(
+                max_chunksize=self.mini_batch_size
+            ):
                 batches.append(batch)
 
         return batches
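
Note on the zero-copy angle in the title: `pyarrow.Table.to_batches(max_chunksize=...)` exposes the table as record-batch views over the same underlying Arrow buffers rather than copying rows, so the mini-batch cap bounds per-write memory for free. A minimal sketch (the column name and sizes are illustrative):

    import pyarrow as pa

    # 10,000 rows exposed as views of at most 1,000 rows each; no buffers are copied.
    table = pa.table({"driver_id": list(range(10_000))})
    batches = table.to_batches(max_chunksize=1_000)
    assert len(batches) == 10 and batches[0].num_rows == 1_000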
@@ -45,40 +52,26 @@ def input_builder(self, worker_index, worker_count, _state):
         return [(None, self.paths[self.worker_index])]
 
     def output_builder(self, worker_index, worker_count):
-        def yield_batch(iterable, batch_size):
-            """Yield mini-batches from an iterable."""
-            for i in range(0, len(iterable), batch_size):
-                yield iterable[i : i + batch_size]
-
-        def output_fn(batch):
-            table = pa.Table.from_batches([batch])
+        def output_fn(mini_batch):
+            table: pa.Table = pa.Table.from_batches([mini_batch])
 
             if self.feature_view.batch_source.field_mapping is not None:
                 table = _run_pyarrow_field_mapping(
                     table, self.feature_view.batch_source.field_mapping
                 )
 
             join_key_to_value_type = {
                 entity.name: entity.dtype.to_value_type()
                 for entity in self.feature_view.entity_columns
             }
 
             rows_to_write = _convert_arrow_to_proto(
                 table, self.feature_view, join_key_to_value_type
             )
-            provider = self.feature_store._get_provider()
-            with tqdm(total=len(rows_to_write)) as progress:
-                # break rows_to_write to mini-batches
-                batch_size = int(
-                    os.getenv("BYTEWAX_MINI_BATCH_SIZE", DEFAULT_BATCH_SIZE)
-                )
-                for mini_batch in yield_batch(rows_to_write, batch_size):
-                    provider.online_write_batch(
-                        config=self.config,
-                        table=self.feature_view,
-                        data=mini_batch,
-                        progress=progress.update,
-                    )
+            self.feature_store._get_provider().online_write_batch(
+                config=self.config,
+                table=self.feature_view,
+                data=rows_to_write,
+                progress=None,
+            )
 
         return output_fn
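
Because `process_path` now caps Arrow batches at `self.mini_batch_size`, each `output_fn` call already receives a bounded mini-batch, which is what makes the `yield_batch` re-slicing and the `tqdm` progress bar safe to drop. A rough sketch of the resulting write loop, with hypothetical `convert`/`write` stand-ins for the proto conversion and `online_write_batch`:

    # Each incoming batch is already <= mini_batch_size rows, so it is
    # converted and written in a single call instead of being re-sliced.
    def write_batches(mini_batches, convert, write):
        for mini_batch in mini_batches:
            write(convert(mini_batch))

    write_batches([["r1", "r2"], ["r3"]], convert=list, write=print)  # illustrative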
@@ -111,7 +111,6 @@ def __init__(
         self.offline_store = offline_store
         self.online_store = online_store
 
-        # TODO: Configure k8s here
         k8s_config.load_config()
 
         self.k8s_client = client.api_client.ApiClient()
@@ -299,6 +298,9 @@ def _create_kubernetes_job(self, job_id, paths, feature_view):
                 len(paths),  # Create a pod for each parquet file
                 self.batch_engine_config.env,
             )
+            logger.info(
+                f"Created job `dataflow-{job_id}` on namespace `{self.namespace}`"
+            )
         except FailToCreateError as failures:
             return BytewaxMaterializationJob(job_id, self.namespace, error=failures)
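
The new log line interpolates the generated job id and namespace; rendered with illustrative values it looks like this sketch (the logger name and values are hypothetical):

    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("bytewax_engine")  # stand-in for the module logger
    job_id, namespace = "abc123", "feast"         # hypothetical values
    logger.info(f"Created job `dataflow-{job_id}` on namespace `{namespace}`")
    # -> INFO:bytewax_engine:Created job `dataflow-abc123` on namespace `feast`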

@@ -361,7 +363,7 @@ def _create_job_definition(self, job_id, namespace, pods, env, index_offset=0):
                 },
                 {
                     "name": "BYTEWAX_REPLICAS",
-                    "value": f"{pods}",
+                    "value": "1",
                 },
                 {
                     "name": "BYTEWAX_KEEP_CONTAINER_ALIVE",
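
Hard-coding `BYTEWAX_REPLICAS` to `"1"` matches the one-file-per-pod model: `_create_kubernetes_job` launches `len(paths)` pods, and `input_builder` in the dataflow diff above hands each worker exactly `self.paths[self.worker_index]`, so every pod runs as an independent single-replica dataflow instead of advertising itself as one of `pods` replicas. A sketch of that assignment (paths are illustrative):

    # Each indexed pod materializes exactly one parquet file.
    paths = ["s3://bucket/part-0.parquet", "s3://bucket/part-1.parquet"]
    for worker_index in range(len(paths)):
        print(f"pod {worker_index} -> {paths[worker_index]}")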
@@ -35,13 +35,22 @@ def status(self):
         if job_status.active is not None:
             if job_status.completion_time is None:
                 return MaterializationJobStatus.RUNNING
-        elif job_status.failed is not None:
-            self._error = Exception(f"Job {self.job_id()} failed")
-            return MaterializationJobStatus.ERROR
-        elif job_status.active is None:
-            if job_status.completion_time is not None:
-                if job_status.conditions[0].type == "Complete":
-                    return MaterializationJobStatus.SUCCEEDED
+        else:
+            if (
+                job_status.completion_time is not None
+                and job_status.conditions[0].type == "Complete"
+            ):
+                return MaterializationJobStatus.SUCCEEDED
+
+            if (
+                job_status.conditions is not None
+                and job_status.conditions[0].type == "Failed"
+            ):
+                self._error = Exception(
+                    f"Job {self.job_id()} failed with reason: "
+                    f"{job_status.conditions[0].message}"
+                )
+                return MaterializationJobStatus.ERROR
         return MaterializationJobStatus.WAITING
 
     def should_be_retried(self):
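
The rewritten `status()` keys off the Kubernetes `V1JobStatus.conditions` list rather than the bare `failed` counter, which is also what lets it attach the failure message to the raised error. A minimal sketch of the object shape it inspects, with hypothetical condition values:

    from kubernetes import client

    status = client.V1JobStatus(
        active=None,
        completion_time=None,
        conditions=[
            client.V1JobCondition(
                type="Failed", status="True", message="BackoffLimitExceeded"
            )
        ],
    )
    # The new branch would surface this as MaterializationJobStatus.ERROR:
    print(status.conditions[0].type, status.conditions[0].message)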
@@ -1,3 +1,4 @@
+import logging
 import os
 
 import yaml
@@ -8,6 +9,8 @@
 )
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+
     with open("/var/feast/feature_store.yaml") as f:
         feast_config = yaml.safe_load(f)