Skip to content

Commit

Permalink
Revert "Revert "[BEAM-2914] Add portable merging window support to Py…
Browse files Browse the repository at this point in the history
…thon. (apache#12995)""

This reverts commit 9c60fd5.
  • Loading branch information
robertwb authored and kileys committed Mar 11, 2021
1 parent 23b85b7 commit 00812db
Show file tree
Hide file tree
Showing 5 changed files with 315 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,9 @@ def test_callbacks_with_exception(self):
def test_register_finalizations(self):
  # Reported as skipped in this suite rather than executed; the skip reason
  # is the tracking issue id. Presumably this overrides a test inherited from
  # the base test class -- confirm against the enclosing class.
  raise unittest.SkipTest("BEAM-11021")

def test_custom_merging_window(self):
  # Reported as skipped in this suite rather than executed; the skip reason
  # is the tracking issue id. Presumably this overrides a test inherited from
  # the base test class -- confirm against the enclosing class.
  raise unittest.SkipTest("BEAM-11004")

# Inherits all other tests.


Expand Down
241 changes: 229 additions & 12 deletions sdks/python/apache_beam/runners/portability/fn_api_runner/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
import collections
import copy
import itertools
import uuid
import weakref
from typing import TYPE_CHECKING
from typing import Any
from typing import DefaultDict
Expand Down Expand Up @@ -55,6 +57,7 @@
from apache_beam.runners.portability.fn_api_runner.translations import split_buffer_id
from apache_beam.runners.portability.fn_api_runner.translations import unique_name
from apache_beam.runners.worker import bundle_processor
from apache_beam.transforms import core
from apache_beam.transforms import trigger
from apache_beam.transforms import window
from apache_beam.transforms.window import GlobalWindow
Expand All @@ -69,7 +72,6 @@
from apache_beam.runners.portability.fn_api_runner.fn_runner import DataOutput
from apache_beam.runners.portability.fn_api_runner.fn_runner import OutputTimers
from apache_beam.runners.portability.fn_api_runner.translations import DataSideInput
from apache_beam.transforms import core
from apache_beam.transforms.window import BoundedWindow

ENCODED_IMPULSE_VALUE = WindowedValueCoder(
Expand Down Expand Up @@ -338,6 +340,222 @@ def from_runner_api_parameter(window_coder_id, context):
context.coders[window_coder_id.decode('utf-8')])


class GenericMergingWindowFn(window.WindowFn):
  """A WindowFn whose merge() is carried out by an SDK worker.

  The instance wraps a WindowingStrategy proto. To merge, it ships the
  candidate windows to a worker in a dedicated bundle containing a single
  MERGE_WINDOWS transform, then replays the worker's answer onto the
  MergeContext.

  Instances are not serialized. Each one registers itself in the class-level
  _HANDLES dict under a random handle id; payload() returns that id and
  from_runner_api_parameter() looks the live instance up again. The entry is
  dropped once the owning execution context is garbage collected.
  """

  URN = 'internal-generic-merging'

  # Transform ids of the data source and sink in the merge bundle descriptor.
  # merge() uses the same ids to address the data channel.
  TO_SDK_TRANSFORM = 'read'
  FROM_SDK_TRANSFORM = 'write'

  # Live instances keyed by handle id (see payload()).
  _HANDLES = {}  # type: Dict[str, GenericMergingWindowFn]

  def __init__(self, execution_context, windowing_strategy_proto):
    # type: (FnApiRunnerExecutionContext, beam_runner_api_pb2.WindowingStrategy) -> None

    """Registers a new handle bound to the given execution context.

    Args:
      execution_context: context supplying the pipeline context, pipeline
        components and worker handler manager. Only a weak reference is kept.
      windowing_strategy_proto: the strategy whose window_fn is to be invoked
        on the worker.
    """
    self._worker_handler = None  # type: Optional[worker_handlers.WorkerHandler]
    self._handle_id = handle_id = uuid.uuid4().hex
    handles = self._HANDLES
    handles[handle_id] = self
    # ExecutionContexts are expensive, we don't want to keep them in the
    # static dictionary forever.  Instead we hold a weakref and pop self
    # out of the dict once this context goes away.
    # The callback captures the registry dict rather than self, so that the
    # instance does not take part in a reference cycle through its own
    # weakref callback.
    self._execution_context_ref_obj = weakref.ref(
        execution_context, lambda _: handles.pop(handle_id, None))
    self._windowing_strategy_proto = windowing_strategy_proto
    # Source of unique suffixes for uid().
    self._counter = 0
    # Lazily created in make_process_bundle_descriptor()
    self._process_bundle_descriptor = None
    self._bundle_processor_id = None  # type: Optional[str]
    self.windowed_input_coder_impl = None  # type: Optional[CoderImpl]
    self.windowed_output_coder_impl = None  # type: Optional[CoderImpl]

  def _execution_context_ref(self):
    # type: () -> FnApiRunnerExecutionContext

    """Returns the execution context, which must still be alive."""
    result = self._execution_context_ref_obj()
    assert result is not None, (
        'The execution context of this GenericMergingWindowFn no longer '
        'exists.')
    return result

  def payload(self):
    # type: () -> bytes

    """Returns the UTF-8 encoded handle id identifying this instance."""
    return self._handle_id.encode('utf-8')

  @staticmethod
  @window.urns.RunnerApiFn.register_urn(URN, bytes)
  def from_runner_api_parameter(handle_id, unused_context):
    # type: (bytes, Any) -> GenericMergingWindowFn

    """Returns the registered instance for a payload produced by payload().

    Raises:
      KeyError: if no live instance is registered under that handle id.
    """
    return GenericMergingWindowFn._HANDLES[handle_id.decode('utf-8')]

  def assign(self, assign_context):
    # type: (window.WindowFn.AssignContext) -> Iterable[window.BoundedWindow]

    """Not supported; this class only implements merging."""
    raise NotImplementedError()

  def merge(self, merge_context):
    # type: (window.WindowFn.MergeContext) -> None

    """Merges merge_context.windows by running one bundle on the worker.

    Raises:
      RuntimeError: if the worker sends anything other than data elements or
        reports an error for the bundle.
    """
    # Also populates the coder impls and the bundle descriptor id on first use.
    worker_handler = self.worker_handle()

    assert self.windowed_input_coder_impl is not None
    assert self.windowed_output_coder_impl is not None
    process_bundle_id = self.uid('process')
    # The whole request is one globally-windowed KV of (b'', windows).
    to_worker = worker_handler.data_conn.output_stream(
        process_bundle_id, self.TO_SDK_TRANSFORM)
    to_worker.write(
        self.windowed_input_coder_impl.encode_nested(
            window.GlobalWindows.windowed_value((b'', merge_context.windows))))
    to_worker.close()

    process_bundle_req = beam_fn_api_pb2.InstructionRequest(
        instruction_id=process_bundle_id,
        process_bundle=beam_fn_api_pb2.ProcessBundleRequest(
            process_bundle_descriptor_id=self._bundle_processor_id))
    result_future = worker_handler.control_conn.push(process_bundle_req)
    # Stop waiting for output as soon as the bundle is known to have failed.
    for output in worker_handler.data_conn.input_elements(
        process_bundle_id, [self.FROM_SDK_TRANSFORM],
        abort_callback=lambda: bool(result_future.is_done() and result_future.
                                    get().error)):
      if isinstance(output, beam_fn_api_pb2.Elements.Data):
        windowed_result = self.windowed_output_coder_impl.decode_nested(
            output.data)
        # The decoded value is (key, (windows, merges)); see the output coder
        # in make_process_bundle_descriptor(). Only the merges are used here,
        # each being a (merged window, original windows) pair.
        for merge_result, originals in windowed_result.value[1][1]:
          merge_context.merge(originals, merge_result)
      else:
        raise RuntimeError("Unexpected data: %s" % output)

    result = result_future.get()
    if result.error:
      raise RuntimeError(result.error)
    # The result was "returned" via the merge callbacks on merge_context above.

  def get_window_coder(self):
    # type: () -> coders.Coder

    """Returns the coder named by the wrapped strategy's window_coder_id."""
    return self._execution_context_ref().pipeline_context.coders[
        self._windowing_strategy_proto.window_coder_id]

  def worker_handle(self):
    # type: () -> worker_handlers.WorkerHandler

    """Returns the worker handler used for merging, setting it up on first use.

    The first call obtains a handler for the strategy's environment, then
    builds and registers the merge bundle descriptor.
    """
    if self._worker_handler is None:
      worker_handler_manager = self._execution_context_ref(
      ).worker_handler_manager
      worker_handler = worker_handler_manager.get_worker_handlers(
          self._windowing_strategy_proto.environment_id, 1)[0]
      process_bundle_descriptor = self.make_process_bundle_descriptor(
          worker_handler.data_api_service_descriptor(),
          worker_handler.state_api_service_descriptor())
      worker_handler_manager.register_process_bundle_descriptor(
          process_bundle_descriptor)
      # Cache the handler only once the descriptor is registered, so a failure
      # above is retried on the next call instead of leaving a half-initialised
      # instance behind.
      self._worker_handler = worker_handler
    return self._worker_handler

  def make_process_bundle_descriptor(
      self, data_api_service_descriptor, state_api_service_descriptor):
    # type: (Optional[endpoints_pb2.ApiServiceDescriptor], Optional[endpoints_pb2.ApiServiceDescriptor]) -> beam_fn_api_pb2.ProcessBundleDescriptor

    """Creates a ProcessBundleDescriptor for invoking the WindowFn's
    merge operation.

    The descriptor is a three-transform chain: data input -> MERGE_WINDOWS ->
    data output, with both PCollections in the global window. With W the
    strategy's window coder, the element coders are:

      input:  KV<bytes, Iterable<W>>
      output: KV<bytes, KV<Iterable<W>, Iterable<KV<W, Iterable<W>>>>>

    As side effects, the newly created coders are added to the pipeline
    context, and windowed_input_coder_impl, windowed_output_coder_impl and
    _bundle_processor_id are set on self.
    """
    def make_channel_payload(coder_id):
      # type: (str) -> bytes
      data_spec = beam_fn_api_pb2.RemoteGrpcPort(coder_id=coder_id)
      if data_api_service_descriptor:
        data_spec.api_service_descriptor.url = (data_api_service_descriptor.url)
      return data_spec.SerializeToString()

    pipeline_context = self._execution_context_ref().pipeline_context
    global_windowing_strategy_id = self.uid('global_windowing_strategy')
    global_windowing_strategy_proto = core.Windowing(
        window.GlobalWindows()).to_runner_api(pipeline_context)
    coders = dict(pipeline_context.coders.get_id_to_proto_map())

    def make_coder(urn, *components):
      # type: (str, *str) -> str
      # Records the new coder both in the descriptor's coder map and in the
      # pipeline context, which is used below to obtain the coder impls.
      coder_proto = beam_runner_api_pb2.Coder(
          spec=beam_runner_api_pb2.FunctionSpec(urn=urn),
          component_coder_ids=components)
      coder_id = self.uid('coder')
      coders[coder_id] = coder_proto
      pipeline_context.coders.put_proto(coder_id, coder_proto)
      return coder_id

    bytes_coder_id = make_coder(common_urns.coders.BYTES.urn)
    window_coder_id = self._windowing_strategy_proto.window_coder_id
    global_window_coder_id = make_coder(common_urns.coders.GLOBAL_WINDOW.urn)
    iter_window_coder_id = make_coder(
        common_urns.coders.ITERABLE.urn, window_coder_id)
    input_coder_id = make_coder(
        common_urns.coders.KV.urn, bytes_coder_id, iter_window_coder_id)
    output_coder_id = make_coder(
        common_urns.coders.KV.urn,
        bytes_coder_id,
        make_coder(
            common_urns.coders.KV.urn,
            iter_window_coder_id,
            make_coder(
                common_urns.coders.ITERABLE.urn,
                make_coder(
                    common_urns.coders.KV.urn,
                    window_coder_id,
                    iter_window_coder_id))))
    windowed_input_coder_id = make_coder(
        common_urns.coders.WINDOWED_VALUE.urn,
        input_coder_id,
        global_window_coder_id)
    windowed_output_coder_id = make_coder(
        common_urns.coders.WINDOWED_VALUE.urn,
        output_coder_id,
        global_window_coder_id)

    self.windowed_input_coder_impl = pipeline_context.coders[
        windowed_input_coder_id].get_impl()
    self.windowed_output_coder_impl = pipeline_context.coders[
        windowed_output_coder_id].get_impl()

    self._bundle_processor_id = self.uid('merge_windows')
    return beam_fn_api_pb2.ProcessBundleDescriptor(
        id=self._bundle_processor_id,
        transforms={
            self.TO_SDK_TRANSFORM: beam_runner_api_pb2.PTransform(
                unique_name='MergeWindows/Read',
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=bundle_processor.DATA_INPUT_URN,
                    payload=make_channel_payload(windowed_input_coder_id)),
                outputs={'input': 'input'}),
            'Merge': beam_runner_api_pb2.PTransform(
                unique_name='MergeWindows/Merge',
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=common_urns.primitives.MERGE_WINDOWS.urn,
                    payload=self._windowing_strategy_proto.window_fn.
                    SerializeToString()),
                inputs={'input': 'input'},
                outputs={'output': 'output'}),
            self.FROM_SDK_TRANSFORM: beam_runner_api_pb2.PTransform(
                unique_name='MergeWindows/Write',
                spec=beam_runner_api_pb2.FunctionSpec(
                    urn=bundle_processor.DATA_OUTPUT_URN,
                    payload=make_channel_payload(windowed_output_coder_id)),
                inputs={'output': 'output'}),
        },
        pcollections={
            'input': beam_runner_api_pb2.PCollection(
                unique_name='input',
                windowing_strategy_id=global_windowing_strategy_id,
                coder_id=input_coder_id),
            'output': beam_runner_api_pb2.PCollection(
                unique_name='output',
                windowing_strategy_id=global_windowing_strategy_id,
                coder_id=output_coder_id),
        },
        coders=coders,
        windowing_strategies={
            global_windowing_strategy_id: global_windowing_strategy_proto,
        },
        environments=dict(
            self._execution_context_ref().pipeline_components.environments.
            items()),
        state_api_service_descriptor=state_api_service_descriptor,
        # NOTE(review): the timer endpoint is given the data API descriptor;
        # this looks deliberate but confirm against the other descriptors the
        # runner builds.
        timer_api_service_descriptor=data_api_service_descriptor)

  def uid(self, name=''):
    # type: (str) -> str

    """Returns a fresh id of the form <handle id>_<name>_<counter>."""
    self._counter += 1
    return '%s_%s_%s' % (self._handle_id, name, self._counter)


class FnApiRunnerExecutionContext(object):
"""
:var pcoll_buffers: (dict): Mapping of
Expand Down Expand Up @@ -443,23 +661,22 @@ def _make_safe_windowing_strategy(self, id):
windowing_strategy_proto = self.pipeline_components.windowing_strategies[id]
if windowing_strategy_proto.window_fn.urn in SAFE_WINDOW_FNS:
return id
elif (windowing_strategy_proto.merge_status ==
beam_runner_api_pb2.MergeStatus.NON_MERGING) or True:
else:
safe_id = id + '_safe'
while safe_id in self.pipeline_components.windowing_strategies:
safe_id += '_'
safe_proto = copy.copy(windowing_strategy_proto)
safe_proto.window_fn.urn = GenericNonMergingWindowFn.URN
safe_proto.window_fn.payload = (
windowing_strategy_proto.window_coder_id.encode('utf-8'))
if (windowing_strategy_proto.merge_status ==
beam_runner_api_pb2.MergeStatus.NON_MERGING):
safe_proto.window_fn.urn = GenericNonMergingWindowFn.URN
safe_proto.window_fn.payload = (
windowing_strategy_proto.window_coder_id.encode('utf-8'))
else:
window_fn = GenericMergingWindowFn(self, windowing_strategy_proto)
safe_proto.window_fn.urn = GenericMergingWindowFn.URN
safe_proto.window_fn.payload = window_fn.payload()
self.pipeline_context.windowing_strategies.put_proto(safe_id, safe_proto)
return safe_id
elif windowing_strategy_proto.window_fn.urn == python_urns.PICKLED_WINDOWFN:
return id
else:
raise NotImplementedError(
'[BEAM-10119] Unknown merging WindowFn: %s' %
windowing_strategy_proto)

@property
def state_servicer(self):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import print_function

import collections
import gc
import logging
import os
import random
Expand All @@ -46,6 +47,7 @@
from tenacity import stop_after_attempt

import apache_beam as beam
from apache_beam.coders import coders
from apache_beam.coders.coders import StrUtf8Coder
from apache_beam.io import restriction_trackers
from apache_beam.io.watermark_estimators import ManualWatermarkEstimator
Expand Down Expand Up @@ -780,6 +782,21 @@ def test_windowing(self):
| beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1]))))
assert_that(res, equal_to([('k', [1, 2]), ('k', [100, 101, 102])]))

def test_custom_merging_window(self):
  # Runs a GroupByKey under CustomMergingWindowFn. That WindowFn gives each
  # element a one-unit window and merges all windows with an even start, so
  # 2, 100 and 102 are expected in one group while 1 and 101 stay alone.
  with self.create_pipeline() as p:
    res = (
        p
        | beam.Create([1, 2, 100, 101, 102])
        | beam.Map(lambda t: window.TimestampedValue(('k', t), t))
        | beam.WindowInto(CustomMergingWindowFn())
        | beam.GroupByKey()
        | beam.Map(lambda k_vs1: (k_vs1[0], sorted(k_vs1[1]))))
    assert_that(
        res, equal_to([('k', [1]), ('k', [101]), ('k', [2, 100, 102])]))
  # Checks that no GenericMergingWindowFn handle is left registered after the
  # run. The explicit collection is there because the handles are removed by a
  # weakref callback. NOTE(review): assumes the pipeline has finished once the
  # `with` block exits -- confirm.
  gc.collect()
  from apache_beam.runners.portability.fn_api_runner.execution import GenericMergingWindowFn
  self.assertEqual(GenericMergingWindowFn._HANDLES, {})

@unittest.skip('BEAM-9119: test is flaky')
def test_large_elements(self):
with self.create_pipeline() as p:
Expand Down Expand Up @@ -2002,6 +2019,26 @@ def test_gbk_many_values(self):
assert_that(r, equal_to([VALUES_PER_ELEMENT * NUM_OF_ELEMENTS]))


# TODO(robertwb): Why does pickling break when this is inlined?
class CustomMergingWindowFn(window.WindowFn):
  """WindowFn used by test_custom_merging_window.

  Every element gets a one-unit interval window starting at its timestamp;
  on merge, all windows with an even start are combined into a single window.
  """
  def assign(self, assign_context):
    """Returns the single window [timestamp, timestamp + 1)."""
    start = assign_context.timestamp
    return [window.IntervalWindow(start, start + 1)]

  def merge(self, merge_context):
    """Merges all even-start windows into the interval that spans them."""
    even_windows = []
    for candidate in merge_context.windows:
      if candidate.start % 2 == 0:
        even_windows.append(candidate)
    if not even_windows:
      # Nothing to merge; every window is left as it is.
      return
    earliest_start = min(w.start for w in even_windows)
    latest_end = max(w.end for w in even_windows)
    merge_context.merge(
        even_windows, window.IntervalWindow(earliest_start, latest_end))

  def get_window_coder(self):
    """Returns a coder for the IntervalWindows this WindowFn produces."""
    interval_window_coder = coders.IntervalWindowCoder()
    return interval_window_coder


if __name__ == '__main__':
  # When executed as a script: set the root logger to INFO and run the tests
  # through unittest's command-line entry point.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,9 @@ def test_flattened_side_input(self):
super(SparkRunnerTest,
self).test_flattened_side_input(with_transcoding=False)

def test_custom_merging_window(self):
  # Reported as skipped in this suite rather than executed; the skip reason
  # is the tracking issue id. Presumably this overrides the test inherited
  # from PortableRunnerTest -- confirm against the enclosing class.
  raise unittest.SkipTest("BEAM-11004")

# Inherits all other tests from PortableRunnerTest.


Expand Down

0 comments on commit 00812db

Please sign in to comment.