Skip to content

Commit 48845d1

Browse files
vancior98 authored and lindong28 committed
[FLINK-31185][python] Support side-output in broadcast processing
This closes apache#22003.
1 parent 484d41c commit 48845d1

File tree

9 files changed

+269
-20
lines changed

9 files changed

+269
-20
lines changed

flink-python/pyflink/datastream/tests/test_data_stream.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -589,6 +589,40 @@ def process_element2(self, value, ctx: 'CoProcessFunction.Context'):
589589
side_expected = ['0', '0', '1', '1', '2', '3']
590590
self.assert_equals_sorted(side_expected, side_sink.get_results())
591591

592+
def test_co_broadcast_side_output(self):
    """Check main and side outputs of a non-keyed broadcast process function.

    Both ``process_element`` and ``process_broadcast_element`` emit one
    record to the main output and one record, paired with the output tag,
    to the side output.
    """
    tag = OutputTag("side", Types.INT())

    class MyBroadcastProcessFunction(BroadcastProcessFunction):

        def process_element(self, value, ctx):
            # Regular input rows are (str, int): the string goes to the main
            # output and the int to the side output.
            yield value[0]
            yield tag, value[1]

        def process_broadcast_element(self, value, ctx):
            # Broadcast rows are (int, str), so the indices are swapped.
            yield value[1]
            yield tag, value[0]

    self.env.set_parallelism(2)

    regular_stream = self.env.from_collection(
        [('a', 0), ('b', 1), ('c', 2)],
        type_info=Types.ROW([Types.STRING(), Types.INT()]))
    broadcast_source = self.env.from_collection(
        [(3, 'd'), (4, 'f')],
        type_info=Types.ROW([Types.INT(), Types.STRING()]))
    state_descriptor = MapStateDescriptor(
        "dummy", key_type_info=Types.INT(), value_type_info=Types.STRING())

    broadcast_stream = broadcast_source.broadcast(state_descriptor)
    processed = regular_stream.connect(broadcast_stream).process(
        MyBroadcastProcessFunction(), output_type=Types.STRING())

    side_sink = DataStreamTestSinkFunction()
    processed.get_side_output(tag).add_sink(side_sink)
    processed.add_sink(self.test_sink)

    self.env.execute("test_co_broadcast_process_side_output")

    # Every broadcast record shows up twice in the expected results, which
    # matches the parallelism of 2 configured above.
    self.assert_equals_sorted(
        ['a', 'b', 'c', 'd', 'd', 'f', 'f'], self.test_sink.get_results())
    self.assert_equals_sorted(
        ['0', '1', '2', '3', '3', '4', '4'], side_sink.get_results())
625+
592626
def test_keyed_process_side_output(self):
593627
tag = OutputTag("side", Types.INT())
594628

@@ -665,6 +699,49 @@ def process_element2(self, value, ctx: 'KeyedCoProcessFunction.Context'):
665699
side_expected = ['1', '1', '2', '2', '3', '3', '4', '4']
666700
self.assert_equals_sorted(side_expected, side_sink.get_results())
667701

702+
def test_keyed_co_broadcast_side_output(self):
    """Check main and side outputs of a keyed broadcast process function.

    The keyed side accumulates the integer field of each row in a reducing
    state and emits the running sum to the side output; the broadcast side
    forwards its integer field to the side output unchanged.
    """
    tag = OutputTag("side", Types.INT())

    class MyKeyedBroadcastProcessFunction(KeyedBroadcastProcessFunction):

        def __init__(self):
            self.reducing_state = None  # type: ReducingState

        def open(self, context: RuntimeContext):
            sum_descriptor = ReducingStateDescriptor(
                "reduce", lambda left, right: left + right, Types.INT())
            self.reducing_state = context.get_reducing_state(sum_descriptor)

        def process_element(self, value, ctx):
            # Fold the int field into the per-key sum, then emit the key to
            # the main output and the updated sum to the side output.
            self.reducing_state.add(value[1])
            yield value[0]
            yield tag, self.reducing_state.get()

        def process_broadcast_element(self, value, ctx):
            # Broadcast rows are (int, str): str to main, int to side output.
            yield value[1]
            yield tag, value[0]

    self.env.set_parallelism(2)

    keyed_source = self.env.from_collection(
        [('a', 0), ('b', 1), ('a', 2), ('b', 3)],
        type_info=Types.ROW([Types.STRING(), Types.INT()]))
    broadcast_source = self.env.from_collection(
        [(5, 'c'), (6, 'd')],
        type_info=Types.ROW([Types.INT(), Types.STRING()]))
    state_descriptor = MapStateDescriptor(
        "dummy", key_type_info=Types.INT(), value_type_info=Types.STRING())

    broadcast_stream = broadcast_source.broadcast(state_descriptor)
    processed = keyed_source.key_by(lambda e: e[0]).connect(broadcast_stream).process(
        MyKeyedBroadcastProcessFunction(), output_type=Types.STRING())

    side_sink = DataStreamTestSinkFunction()
    processed.get_side_output(tag).add_sink(side_sink)
    processed.add_sink(self.test_sink)

    self.env.execute("test_keyed_co_broadcast_process_side_output")

    # Every broadcast record shows up twice in the expected results, which
    # matches the parallelism of 2 configured above.
    self.assert_equals_sorted(
        ['a', 'a', 'b', 'b', 'c', 'c', 'd', 'd'], self.test_sink.get_results())
    # Running sums per key: 'a' -> 0, 2 and 'b' -> 1, 4; broadcast ints 5 and 6.
    self.assert_equals_sorted(
        ['0', '1', '2', '4', '5', '5', '6', '6'], side_sink.get_results())
744+
668745
def test_side_output_stream_execute_and_collect(self):
669746
tag = OutputTag("side", Types.INT())
670747

flink-python/pyflink/fn_execution/datastream/embedded/operations.py

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,8 @@ def extract_process_function(
100100
side_output_context = SideOutputContext(j_side_output_context)
101101

102102
def process_func(values):
103+
if values is None:
104+
return
103105
for value in values:
104106
if isinstance(value, tuple) and isinstance(value[0], OutputTag):
105107
output_tag = value[0] # type: OutputTag
@@ -108,6 +110,8 @@ def process_func(values):
108110
yield value
109111
else:
110112
def process_func(values):
113+
if values is None:
114+
return
111115
yield from values
112116

113117
def open_func():
@@ -174,14 +178,10 @@ def process_element_func2(value):
174178
process_broadcast_element = user_defined_func.process_broadcast_element
175179

176180
def process_element_func1(value):
177-
elements = process_element(value, read_only_broadcast_ctx)
178-
if elements:
179-
yield from elements
181+
yield from process_func(process_element(value, read_only_broadcast_ctx))
180182

181183
def process_element_func2(value):
182-
elements = process_broadcast_element(value, broadcast_ctx)
183-
if elements:
184-
yield from elements
184+
yield from process_func(process_broadcast_element(value, broadcast_ctx))
185185

186186
return TwoInputOperation(
187187
open_func, close_func, process_element_func1, process_element_func2)
@@ -221,19 +221,20 @@ def on_timer_func(timestamp):
221221
timer_context = InternalKeyedBroadcastProcessFunctionOnTimerContext(
222222
j_timer_context, user_defined_function_proto.key_type_info, j_operator_state_backend)
223223

224+
keyed_state_backend = KeyedStateBackend(
225+
read_only_broadcast_ctx,
226+
j_keyed_state_backend)
227+
runtime_context.set_keyed_state_backend(keyed_state_backend)
228+
224229
process_element = user_defined_func.process_element
225230
process_broadcast_element = user_defined_func.process_broadcast_element
226231
on_timer = user_defined_func.on_timer
227232

228233
def process_element_func1(value):
229-
elements = process_element(value[1], read_only_broadcast_ctx)
230-
if elements:
231-
yield from elements
234+
yield from process_func(process_element(value[1], read_only_broadcast_ctx))
232235

233236
def process_element_func2(value):
234-
elements = process_broadcast_element(value, broadcast_ctx)
235-
if elements:
236-
yield from elements
237+
yield from process_func(process_broadcast_element(value, broadcast_ctx))
237238

238239
def on_timer_func(timestamp):
239240
yield from on_timer(timestamp, timer_context)

flink-python/src/main/java/org/apache/flink/python/chain/PythonOperatorChainingOptimizer.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,8 @@
5656
import org.apache.flink.streaming.api.transformations.TimestampsAndWatermarksTransformation;
5757
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
5858
import org.apache.flink.streaming.api.transformations.UnionTransformation;
59+
import org.apache.flink.streaming.api.transformations.python.PythonBroadcastStateTransformation;
60+
import org.apache.flink.streaming.api.transformations.python.PythonKeyedBroadcastStateTransformation;
5961
import org.apache.flink.streaming.runtime.partitioner.ForwardPartitioner;
6062

6163
import org.apache.flink.shaded.guava30.com.google.common.collect.Lists;
@@ -409,6 +411,11 @@ private static boolean areOperatorsChainable(
409411
return false;
410412
}
411413

414+
if (upTransform instanceof PythonBroadcastStateTransformation
415+
|| upTransform instanceof PythonKeyedBroadcastStateTransformation) {
416+
return false;
417+
}
418+
412419
DataStreamPythonFunctionOperator<?> upOperator =
413420
(DataStreamPythonFunctionOperator<?>)
414421
((SimpleOperatorFactory<?>) getOperatorFactory(upTransform)).getOperator();

flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import org.apache.flink.streaming.api.transformations.PartitionTransformation;
4343
import org.apache.flink.streaming.api.transformations.SideOutputTransformation;
4444
import org.apache.flink.streaming.api.transformations.TwoInputTransformation;
45+
import org.apache.flink.streaming.api.transformations.python.DelegateOperatorTransformation;
4546
import org.apache.flink.streaming.api.transformations.python.PythonBroadcastStateTransformation;
4647
import org.apache.flink.streaming.api.transformations.python.PythonKeyedBroadcastStateTransformation;
4748
import org.apache.flink.streaming.api.utils.ByteArrayWrapper;
@@ -152,6 +153,8 @@ public static StreamOperatorFactory<?> getOperatorFactory(Transformation<?> tran
152153
return ((TwoInputTransformation<?, ?, ?>) transform).getOperatorFactory();
153154
} else if (transform instanceof AbstractMultipleInputTransformation) {
154155
return ((AbstractMultipleInputTransformation<?>) transform).getOperatorFactory();
156+
} else if (transform instanceof DelegateOperatorTransformation<?>) {
157+
return ((DelegateOperatorTransformation<?>) transform).getOperatorFactory();
155158
} else {
156159
return null;
157160
}
@@ -214,6 +217,9 @@ private static AbstractPythonFunctionOperator<?> getPythonOperator(
214217
} else if (transformation instanceof AbstractMultipleInputTransformation) {
215218
operatorFactory =
216219
((AbstractMultipleInputTransformation<?>) transformation).getOperatorFactory();
220+
} else if (transformation instanceof DelegateOperatorTransformation) {
221+
operatorFactory =
222+
((DelegateOperatorTransformation<?>) transformation).getOperatorFactory();
217223
}
218224

219225
if (operatorFactory instanceof SimpleOperatorFactory
@@ -260,6 +266,9 @@ public static boolean isPythonDataStreamOperator(Transformation<?> transform) {
260266
} else if (transform instanceof TwoInputTransformation) {
261267
return isPythonDataStreamOperator(
262268
((TwoInputTransformation<?, ?, ?>) transform).getOperatorFactory());
269+
} else if (transform instanceof PythonBroadcastStateTransformation
270+
|| transform instanceof PythonKeyedBroadcastStateTransformation) {
271+
return true;
263272
} else {
264273
return false;
265274
}
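flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/DelegateOperatorTransformation.java

Lines changed: 128 additions & 0 deletions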
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.flink.streaming.api.transformations.python;
19+
20+
import org.apache.flink.api.common.typeinfo.TypeInformation;
21+
import org.apache.flink.configuration.Configuration;
22+
import org.apache.flink.python.env.PythonEnvironmentManager;
23+
import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
24+
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
25+
import org.apache.flink.streaming.api.operators.python.AbstractPythonFunctionOperator;
26+
import org.apache.flink.streaming.api.operators.python.DataStreamPythonFunctionOperator;
27+
import org.apache.flink.util.OutputTag;
28+
29+
import javax.annotation.Nullable;
30+
31+
import java.util.Collection;
32+
import java.util.HashMap;
33+
import java.util.Map;
34+
35+
/**
36+
* For those {@link org.apache.flink.api.dag.Transformation} that don't have an operator entity,
37+
* {@link DelegateOperatorTransformation} provides a {@link SimpleOperatorFactory} containing a
38+
* {@link DelegateOperator} , which can hold special configurations during transformation
39+
* preprocessing for Python jobs, and later be queried at translation stage. Currently, those
40+
* configurations include {@link OutputTag}s, {@code numPartitions} and general {@link
41+
* Configuration}.
42+
*/
43+
public interface DelegateOperatorTransformation<OUT> {
44+
45+
SimpleOperatorFactory<OUT> getOperatorFactory();
46+
47+
static void configureOperator(
48+
DelegateOperatorTransformation<?> transformation,
49+
AbstractPythonFunctionOperator<?> operator) {
50+
DelegateOperator<?> delegateOperator =
51+
(DelegateOperator<?>) transformation.getOperatorFactory().getOperator();
52+
53+
operator.getConfiguration().addAll(delegateOperator.getConfiguration());
54+
55+
if (operator instanceof DataStreamPythonFunctionOperator) {
56+
DataStreamPythonFunctionOperator<?> dataStreamOperator =
57+
(DataStreamPythonFunctionOperator<?>) operator;
58+
dataStreamOperator.addSideOutputTags(delegateOperator.getSideOutputTags());
59+
if (delegateOperator.getNumPartitions() != null) {
60+
dataStreamOperator.setNumPartitions(delegateOperator.getNumPartitions());
61+
}
62+
}
63+
}
64+
65+
/**
66+
* {@link DelegateOperator} holds configurations, e.g. {@link OutputTag}s, which will be applied
67+
* to the actual python operator at translation stage.
68+
*/
69+
class DelegateOperator<OUT> extends AbstractPythonFunctionOperator<OUT>
70+
implements DataStreamPythonFunctionOperator<OUT> {
71+
72+
private final Map<String, OutputTag<?>> sideOutputTags = new HashMap<>();
73+
private @Nullable Integer numPartitions = null;
74+
75+
public DelegateOperator() {
76+
super(new Configuration());
77+
}
78+
79+
@Override
80+
public void addSideOutputTags(Collection<OutputTag<?>> outputTags) {
81+
for (OutputTag<?> outputTag : outputTags) {
82+
sideOutputTags.put(outputTag.getId(), outputTag);
83+
}
84+
}
85+
86+
@Override
87+
public Collection<OutputTag<?>> getSideOutputTags() {
88+
return sideOutputTags.values();
89+
}
90+
91+
@Override
92+
public void setNumPartitions(int numPartitions) {
93+
this.numPartitions = numPartitions;
94+
}
95+
96+
@Nullable
97+
public Integer getNumPartitions() {
98+
return numPartitions;
99+
}
100+
101+
@Override
102+
public TypeInformation<OUT> getProducedType() {
103+
throw new RuntimeException("This should not be invoked on a DelegateOperator!");
104+
}
105+
106+
@Override
107+
public DataStreamPythonFunctionInfo getPythonFunctionInfo() {
108+
throw new RuntimeException("This should not be invoked on a DelegateOperator!");
109+
}
110+
111+
@Override
112+
public <T> DataStreamPythonFunctionOperator<T> copy(
113+
DataStreamPythonFunctionInfo pythonFunctionInfo,
114+
TypeInformation<T> outputTypeInfo) {
115+
throw new RuntimeException("This should not be invoked on a DelegateOperator!");
116+
}
117+
118+
@Override
119+
protected void invokeFinishBundle() throws Exception {
120+
throw new RuntimeException("This should not be invoked on a DelegateOperator!");
121+
}
122+
123+
@Override
124+
protected PythonEnvironmentManager createPythonEnvironmentManager() {
125+
throw new RuntimeException("This should not be invoked on a DelegateOperator!");
126+
}
127+
}
128+
}

flink-python/src/main/java/org/apache/flink/streaming/api/transformations/python/PythonBroadcastStateTransformation.java

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import org.apache.flink.api.dag.Transformation;
2424
import org.apache.flink.configuration.Configuration;
2525
import org.apache.flink.streaming.api.functions.python.DataStreamPythonFunctionInfo;
26+
import org.apache.flink.streaming.api.operators.SimpleOperatorFactory;
2627
import org.apache.flink.streaming.api.transformations.AbstractBroadcastStateTransformation;
2728

2829
import java.util.List;
@@ -34,10 +35,12 @@
3435
*/
3536
@Internal
3637
public class PythonBroadcastStateTransformation<IN1, IN2, OUT>
37-
extends AbstractBroadcastStateTransformation<IN1, IN2, OUT> {
38+
extends AbstractBroadcastStateTransformation<IN1, IN2, OUT>
39+
implements DelegateOperatorTransformation<OUT> {
3840

3941
private final Configuration configuration;
4042
private final DataStreamPythonFunctionInfo dataStreamPythonFunctionInfo;
43+
private final SimpleOperatorFactory<OUT> delegateOperatorFactory;
4144

4245
public PythonBroadcastStateTransformation(
4346
String name,
@@ -57,6 +60,7 @@ public PythonBroadcastStateTransformation(
5760
parallelism);
5861
this.configuration = configuration;
5962
this.dataStreamPythonFunctionInfo = dataStreamPythonFunctionInfo;
63+
this.delegateOperatorFactory = SimpleOperatorFactory.of(new DelegateOperator<>());
6064
updateManagedMemoryStateBackendUseCase(false);
6165
}
6266

@@ -67,4 +71,8 @@ public Configuration getConfiguration() {
6771
public DataStreamPythonFunctionInfo getDataStreamPythonFunctionInfo() {
6872
return dataStreamPythonFunctionInfo;
6973
}
74+
75+
public SimpleOperatorFactory<OUT> getOperatorFactory() {
76+
return delegateOperatorFactory;
77+
}
7078
}

0 commit comments

Comments
 (0)