Skip to content

Commit

Permalink
Honor local parallelism of fused vertices [HZ-2493] [5.3.z] (#24942)
Browse files Browse the repository at this point in the history
Backports #24859
Fixes #24683

In addition, the following members are renamed:
1. xform2vertex → transform2vertex: We should avoid using "x" for "trans" for the sake of readability.
2. findFusableChain() → findFusibleChain(): "Fusible" is the standard spelling; "fusable" is a nonstandard variant.
  • Loading branch information
burakgok committed Jul 4, 2023
1 parent af9eba4 commit 916253a
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public class Planner {
private static final int MAXIMUM_WATERMARK_GAP = 1000;

public final DAG dag = new DAG();
public final Map<Transform, PlannerVertex> xform2vertex = new HashMap<>();
public final Map<Transform, PlannerVertex> transform2vertex = new HashMap<>();
private final PipelineImpl pipeline;

Planner(PipelineImpl pipeline) {
Expand Down Expand Up @@ -119,7 +119,7 @@ DAG createDag(Context context) {
List<Transform> transforms = new ArrayList<>(adjacencyMap.keySet());
for (int i = 0; i < transforms.size(); i++) {
Transform transform = transforms.get(i);
List<Transform> chain = findFusableChain(transform, adjacencyMap);
List<Transform> chain = findFusibleChain(transform, adjacencyMap);
if (chain == null) {
continue;
}
Expand Down Expand Up @@ -149,7 +149,7 @@ DAG createDag(Context context) {
return dag;
}

private static List<Transform> findFusableChain(
private static List<Transform> findFusibleChain(
@Nonnull Transform transform,
@Nonnull Map<Transform, List<Transform>> adjacencyMap
) {
Expand Down Expand Up @@ -212,12 +212,11 @@ private static Transform fuseFlatMapTransforms(List<Transform> chain) {
}
fused = new FlatMapTransform(name, chain.get(0).upstream().get(0), flatMapFn);
}
// if the first stage of the chain is rebalanced, then we set
// the rebalance flag of the created fused stage. Only consider
// the case when first element of the chain is rebalanced
// because there isn't any other case. If any stage in the
// middle includes rebalance, then those stages are not fused
// by findFusableChain().
fused.localParallelism(chain.get(0).localParallelism());
// If the first stage of the chain is rebalanced, then we set the rebalance flag
// of the created fused stage. Only the first element of the chain needs to be
// considered: a stage in the middle of a chain is never rebalanced, because
// findFusibleChain() does not fuse stages that include a rebalance.
fused.setRebalanceInput(0, chain.get(0).shouldRebalanceInput(0));
return fused;
}
Expand Down Expand Up @@ -263,14 +262,14 @@ public PlannerVertex addVertex(Transform transform, String name, int localParall
ProcessorMetaSupplier metaSupplier) {
PlannerVertex pv = new PlannerVertex(dag.newVertex(name, metaSupplier));
pv.v.localParallelism(localParallelism);
xform2vertex.put(transform, pv);
transform2vertex.put(transform, pv);
return pv;
}

public void addEdges(Transform transform, Vertex toVertex, ObjIntConsumer<Edge> configureEdgeFn) {
int destOrdinal = 0;
for (Transform fromTransform : transform.upstream()) {
PlannerVertex fromPv = xform2vertex.get(fromTransform);
PlannerVertex fromPv = transform2vertex.get(fromTransform);
Edge edge = from(fromPv.v, fromPv.nextAvailableOrdinal()).to(toVertex, destOrdinal);
dag.edge(edge);
configureEdgeFn.accept(edge, destOrdinal);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ public HashJoinTransform(
@SuppressWarnings("unchecked")
public void addToDag(Planner p, Context context) {
determineLocalParallelism(LOCAL_PARALLELISM_USE_DEFAULT, context, p.isPreserveOrder());
PlannerVertex primary = p.xform2vertex.get(this.upstream().get(0));
PlannerVertex primary = p.transform2vertex.get(this.upstream().get(0));
List keyFns = toList(this.clauses, JoinClause::leftKeyFn);

List<Tag> tags = this.tags;
Expand All @@ -147,7 +147,7 @@ public void addToDag(Planner p, Context context) {
String collectorName = name() + "-collector";
int collectorOrdinal = 1;
for (Transform fromTransform : tailList(this.upstream())) {
PlannerVertex fromPv = p.xform2vertex.get(fromTransform);
PlannerVertex fromPv = p.transform2vertex.get(fromTransform);
JoinClause<?, ?, ?, ?> clause = this.clauses.get(collectorOrdinal - 1);
FunctionEx<Object, Object> getKeyFn = (FunctionEx<Object, Object>) clause.rightKeyFn();
FunctionEx<Object, Object> projectFn = (FunctionEx<Object, Object>) clause.rightProjectFn();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ public PeekTransform(
@Override
public void addToDag(Planner p, Context context) {
determineLocalParallelism(LOCAL_PARALLELISM_USE_DEFAULT, context, p.isPreserveOrder());
PlannerVertex peekedPv = p.xform2vertex.get(this.upstream().get(0));
PlannerVertex peekedPv = p.transform2vertex.get(this.upstream().get(0));
// Peeking transform doesn't add a vertex, so point to the upstream
// transform's vertex:
p.xform2vertex.put(this, peekedPv);
p.transform2vertex.put(this, peekedPv);
peekedPv.v.updateMetaSupplier(sup -> peekOutputP(toStringFn, shouldLogFn, sup));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@

public class StreamStageTest extends PipelineStreamTestSupport {

private static BiFunction<String, Integer, String> ENRICHING_FORMAT_FN =
private static final BiFunction<String, Integer, String> ENRICHING_FORMAT_FN =
(prefix, i) -> String.format("%s-%04d", prefix, i);

@Rule
Expand Down Expand Up @@ -284,7 +284,7 @@ public void fusing_mapToNull_trailing() {
test_fusing(
stage -> stage
.flatMap(Traversers::traverseItems)
.map(item -> (String) null),
.map(item -> null),
item -> Stream.empty()
);
}
Expand All @@ -297,23 +297,18 @@ private void test_fusing(Function<GeneralStage<Integer>, GeneralStage<String>> a
// When
StreamStage<Integer> sourceStage = streamStageFromList(input);
GeneralStage<String> mappedStage = addToPipelineFn.apply(sourceStage);
mappedStage.writeTo(sink);

// Then
mappedStage.writeTo(sink);
assertVertexCount(p.toDag(), 4);
assertContainsFused(true);
DAG dag = p.toDag();
assertContainsFused(dag, true);
assertVertexCount(dag, 4);
execute();
assertEquals(
streamToString(input.stream().flatMap(plainFlatMapFn), Objects::toString),
streamToString(sinkList.stream(), Object::toString));
}

private void assertVertexCount(DAG dag, int expectedCount) {
int[] count = {0};
dag.iterator().forEachRemaining(v -> count[0]++);
assertEquals("unexpected vertex count in DAG:\n" + dag.toDotString(), expectedCount, count[0]);
}

@Test
public void fusing_testWithBranch() {
// Given
Expand All @@ -327,8 +322,9 @@ public void fusing_testWithBranch() {
p.writeTo(sink, mapped1, mapped2);

// Then
assertContainsFused(false);
assertVertexCount(p.toDag(), 6);
DAG dag = p.toDag();
assertContainsFused(dag, false);
assertVertexCount(dag, 6);
execute();
assertEquals(
streamToString(input.stream().flatMap(t -> Stream.of(t + "-x-branch1", t + "-x-branch2")), identity()),
Expand All @@ -349,19 +345,47 @@ public void fusing_when_localParallelismDifferent_then_notFused() {
.writeTo(sink);

// Then
assertContainsFused(false);
assertVertexCount(p.toDag(), 5);
DAG dag = p.toDag();
assertContainsFused(dag, false);
assertVertexCount(dag, 5);
execute();
assertEquals(
streamToString(input.stream().map(t -> t + "-ab"), identity()),
streamToString(sinkList.stream(), Object::toString));
}

private void assertContainsFused(boolean expectedContains) {
String dotString = p.toDag().toDotString();
// Verifies that each fused vertex takes over the local parallelism that was
// set on the stages it was built from: the filter+map pair (local
// parallelism 3) and the flat-map+map pair (local parallelism 5) are looked
// up in the DAG by their fused names and checked separately.
@Test
public void fusing_testLocalParallelism() {
// Given
List<Integer> input = sequence(itemCount);

// When
// Two pairs of stages with different local parallelism (3 and 5).
streamStageFromList(input)
.filter(item -> item % 2 == 0).setLocalParallelism(3)
.map(item -> item / 2).setLocalParallelism(3)
.flatMap(item -> Traversers.traverseItems(2 * item, 2 * item + 1)).setLocalParallelism(5)
.map(item -> item + 1).setLocalParallelism(5)
.writeTo(sink);

// Then
DAG dag = p.toDag();
// Each fused vertex must report the local parallelism of its source stages.
assertEquals(3, dag.getVertex("fused(filter, map)").getLocalParallelism());
assertEquals(5, dag.getVertex("fused(flat-map, map-2)").getLocalParallelism());
execute();
// Keeping the even items, halving them, expanding each to (2 * item, 2 * item + 1)
// and adding 1 is expected to yield every input item plus one.
// NOTE(review): this presumably relies on sequence(itemCount) being 0..itemCount-1
// with an even itemCount - confirm. The sink contents are sorted before comparison,
// presumably because the output order is not guaranteed - confirm.
assertEquals(
input.stream().map(t -> t + 1).collect(toList()),
sinkList.stream().sorted().collect(toList()));
}

/**
 * Asserts whether the DOT representation of the given DAG mentions a
 * "fused" vertex.
 *
 * @param dag              the DAG whose DOT string is inspected
 * @param expectedContains {@code true} if the DOT string is expected to
 *                         contain the text "fused", {@code false} otherwise
 */
private void assertContainsFused(DAG dag, boolean expectedContains) {
    // The DOT string doubles as the failure message so a failing run shows the DAG.
    final String dagAsDot = dag.toDotString();
    final boolean mentionsFused = dagAsDot.contains("fused");
    assertEquals(dagAsDot, expectedContains, mentionsFused);
}

/**
 * Asserts that the given DAG has exactly the expected number of vertices.
 * On failure the message includes the DAG's DOT representation.
 *
 * @param dag           the DAG whose vertices are counted
 * @param expectedCount the number of vertices the DAG should contain
 */
private void assertVertexCount(DAG dag, int expectedCount) {
    final String failureMessage = "Unexpected vertex count in DAG:\n" + dag.toDotString();
    final int actualCount = dag.vertices().size();
    assertEquals(failureMessage, expectedCount, actualCount);
}

@Test
public void mapUsingService() {
// Given
Expand Down

0 comments on commit 916253a

Please sign in to comment.