Merge pull request #12213 from cdapio/feature_release/CDAP-16709-manual-broadcast

CDAP-16709 implement manual broadcasts
albertshau committed May 28, 2020
2 parents 3fa6fb1 + 65d9edc commit 4807776
Showing 9 changed files with 168 additions and 10 deletions.
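In short, this change lets an auto-joiner mark individual join stages as broadcast candidates and threads that hint through to the Spark join. A condensed sketch of the intended usage, modeled on the MockAutoJoiner and the JoinDefinition builder tests in this commit (PURCHASE_SCHEMA, USER_SCHEMA, and the selected fields are illustrative placeholders, not code from the commit):

  @Override
  public JoinDefinition define(AutoJoinerContext context) {
    JoinStage purchases = JoinStage.builder("purchases", PURCHASE_SCHEMA).build();
    // Hint that the smaller users dataset should be broadcast; broadcasting every stage is rejected by the builder.
    JoinStage users = JoinStage.builder("users", USER_SCHEMA).setBroadcast(true).build();
    return JoinDefinition.builder()
      .select(new JoinField("purchases", "id", "purchase_id"),
              new JoinField("users", "name"))
      .from(purchases, users)
      .on(JoinCondition.onKeys()
            .addKey(new JoinKey("purchases", Collections.singletonList("user_id")))
            .addKey(new JoinKey("users", Collections.singletonList("id")))
            .build())
      .build();
  }

MapReduce pipelines ignore the hint; on Spark it is passed to the engine, which may still choose a regular join.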
@@ -102,6 +102,30 @@ public static void setupTest() throws Exception {
setupBatchArtifacts(APP_ARTIFACT_ID, DataPipelineApp.class);
}

@Test
public void testBroadcastJoin() throws Exception {
Schema expectedSchema = Schema.recordOf("purchases.users",
Schema.Field.of("purchases_region", Schema.of(Schema.Type.STRING)),
Schema.Field.of("purchases_purchase_id", Schema.of(Schema.Type.INT)),
Schema.Field.of("purchases_user_id", Schema.of(Schema.Type.INT)),
Schema.Field.of("users_region", Schema.of(Schema.Type.STRING)),
Schema.Field.of("users_user_id", Schema.of(Schema.Type.INT)),
Schema.Field.of("users_name", Schema.of(Schema.Type.STRING)));
Set<StructuredRecord> expected = new HashSet<>();
expected.add(StructuredRecord.builder(expectedSchema)
.set("purchases_region", "us")
.set("purchases_purchase_id", 123)
.set("purchases_user_id", 0)
.set("users_region", "us")
.set("users_user_id", 0)
.set("users_name", "alice").build());

testSimpleAutoJoin(Arrays.asList("users", "purchases"), Collections.singletonList("users"),
expected, Engine.SPARK);
testSimpleAutoJoin(Arrays.asList("users", "purchases"), Collections.singletonList("purchases"),
expected, Engine.SPARK);
}

@Test
public void testAutoInnerJoin() throws Exception {
Schema expectedSchema = Schema.recordOf("purchases.users",
@@ -147,7 +171,7 @@ public void testAutoLeftOuterJoin() throws Exception {
.set("purchases_purchase_id", 456)
.set("purchases_user_id", 2).build());

//testSimpleAutoJoin(Collections.singletonList("purchases"), expected, Engine.SPARK);
testSimpleAutoJoin(Collections.singletonList("purchases"), expected, Engine.SPARK);
testSimpleAutoJoin(Collections.singletonList("purchases"), expected, Engine.MAPREDUCE);
}

@@ -220,6 +244,11 @@ public void testAutoOuterJoin() throws Exception {

private void testSimpleAutoJoin(List<String> required, Set<StructuredRecord> expected,
Engine engine) throws Exception {
testSimpleAutoJoin(required, Collections.emptyList(), expected, engine);
}

private void testSimpleAutoJoin(List<String> required, List<String> broadcast,
Set<StructuredRecord> expected, Engine engine) throws Exception {
/*
users ------|
|--> join --> sink
@@ -235,7 +264,7 @@ private void testSimpleAutoJoin(List<String> required, Set<StructuredRecord> exp
.addStage(new ETLStage("purchases", MockSource.getPlugin(purchaseInput, PURCHASE_SCHEMA)))
.addStage(new ETLStage("join", MockAutoJoiner.getPlugin(Arrays.asList("purchases", "users"),
Arrays.asList("region", "user_id"),
required)))
required, broadcast)))
.addStage(new ETLStage("sink", MockSink.getPlugin(output)))
.addConnection("users", "join")
.addConnection("purchases", "join")
@@ -273,6 +302,53 @@ private void testSimpleAutoJoin(List<String> required, Set<StructuredRecord> exp
Assert.assertEquals(expected, new HashSet<>(outputRecords));
}

@Test
public void testDoubleBroadcastJoin() throws Exception {
Schema expectedSchema = Schema.recordOf(
"purchases.users.interests",
Schema.Field.of("purchases_region", Schema.of(Schema.Type.STRING)),
Schema.Field.of("purchases_purchase_id", Schema.of(Schema.Type.INT)),
Schema.Field.of("purchases_user_id", Schema.of(Schema.Type.INT)),
Schema.Field.of("users_region", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("users_user_id", Schema.nullableOf(Schema.of(Schema.Type.INT))),
Schema.Field.of("users_name", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("interests_region", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("interests_user_id", Schema.nullableOf(Schema.of(Schema.Type.INT))),
Schema.Field.of("interests_interest", Schema.nullableOf(Schema.of(Schema.Type.STRING))));

Set<StructuredRecord> expected = new HashSet<>();
expected.add(StructuredRecord.builder(expectedSchema)
.set("purchases_region", "us")
.set("purchases_purchase_id", 123)
.set("purchases_user_id", 0)
.set("users_region", "us")
.set("users_user_id", 0)
.set("users_name", "alice")
.set("interests_region", "us")
.set("interests_user_id", 0)
.set("interests_interest", "food").build());
expected.add(StructuredRecord.builder(expectedSchema)
.set("purchases_region", "us")
.set("purchases_purchase_id", 123)
.set("purchases_user_id", 0)
.set("users_region", "us")
.set("users_user_id", 0)
.set("users_name", "alice")
.set("interests_region", "us")
.set("interests_user_id", 0)
.set("interests_interest", "sports").build());
expected.add(StructuredRecord.builder(expectedSchema)
.set("purchases_region", "us")
.set("purchases_purchase_id", 456)
.set("purchases_user_id", 2)
.set("interests_region", "us")
.set("interests_user_id", 2)
.set("interests_interest", "gaming").build());

testTripleAutoJoin(Collections.singletonList("purchases"), Arrays.asList("purchases", "interests"),
expected, Engine.SPARK);
}

@Test
public void testTripleAutoLeftRequiredJoin() throws Exception {
Schema expectedSchema = Schema.recordOf(
@@ -418,8 +494,13 @@ public void testTripleAutoTwoRequiredJoin() throws Exception {
testTripleAutoJoin(Arrays.asList("users", "interests"), expected, Engine.MAPREDUCE);
}

public void testTripleAutoJoin(List<String> required, Set<StructuredRecord> expected,
Engine engine) throws Exception {
private void testTripleAutoJoin(List<String> required, Set<StructuredRecord> expected,
Engine engine) throws Exception {
testTripleAutoJoin(required, Collections.emptyList(), expected, engine);
}

private void testTripleAutoJoin(List<String> required, List<String> broadcast,
Set<StructuredRecord> expected, Engine engine) throws Exception {
/*
users ------|
|
@@ -440,7 +521,7 @@ public void testTripleAutoJoin(List<String> required, Set<StructuredRecord> expe
.addStage(new ETLStage("interests", MockSource.getPlugin(interestInput, INTEREST_SCHEMA)))
.addStage(new ETLStage("join", MockAutoJoiner.getPlugin(Arrays.asList("purchases", "users", "interests"),
Arrays.asList("region", "user_id"),
required)))
required, broadcast)))
.addStage(new ETLStage("sink", MockSink.getPlugin(output)))
.addConnection("users", "join")
.addConnection("purchases", "join")
@@ -132,6 +132,10 @@ public JoinDefinition build() {
throw new InvalidJoinException("At least two stages must be specified.");
}

if (stages.stream().allMatch(JoinStage::isBroadcast)) {
throw new InvalidJoinException("Cannot broadcast all stages.");
}

// validate the join condition
if (condition == null) {
throw new InvalidJoinException("A join condition must be specified.");
@@ -97,8 +97,11 @@ public Builder setRequired(boolean required) {
}

/**
* Set whether the stage data should be broadcast during the join. In order to be broadcast, the stage data must
* Hint that the stage data should be broadcast during the join. In order to be broadcast, the stage data must
* be below 8gb and fit entirely in memory. You cannot broadcast both sides of a join.
* This is just a hint and will not always be honored.
* MapReduce pipelines will currently ignore this flag. Spark pipelines will hint to Spark to broadcast, but Spark
* may still decide to do a normal join depending on the type of join being performed and the datasets involved.
*/
public Builder setBroadcast(boolean broadcast) {
this.broadcast = broadcast;
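Downstream, the Spark pipeline code later in this commit applies this hint by wrapping the broadcast side of the Dataset join (condensed from the Spark collection changes below; MapReduce pipelines ignore the flag):

  if (toJoin.isBroadcast()) {
    // org.apache.spark.sql.functions.broadcast marks the right-hand Dataset for a broadcast join
    right = functions.broadcast(right);
  }
  joined = joined.join(right, joinOn, joinType);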
@@ -284,6 +284,34 @@ public void testJoinKeyMismatchedFieldTypeThrowsException() {
}
}

@Test
public void testAllBroadcastThrowsException() {
JoinStage purchases = JoinStage.builder("purchases", PURCHASE_SCHEMA).setBroadcast(true).build();
JoinStage users = JoinStage.builder("users", USER_SCHEMA).setBroadcast(true).build();

try {
JoinDefinition.builder()
.select(new JoinField("purchases", "id", "purchase_id"),
new JoinField("users", "id", "user_id"),
new JoinField("purchases", "ts"),
new JoinField("purchases", "price"),
new JoinField("purchases", "coupon"),
new JoinField("users", "name"),
new JoinField("users", "email"),
new JoinField("users", "age"),
new JoinField("users", "bday"))
.from(purchases, users)
.on(JoinCondition.onKeys()
.addKey(new JoinKey("purchases", Collections.singletonList("user_id")))
.addKey(new JoinKey("users", Collections.singletonList("id")))
.build())
.build();
Assert.fail("Invalid join condition did not fail as expected");
} catch (InvalidJoinException e) {
// expected
}
}

private void testUserPurchaseSchema(JoinStage purchases, JoinStage users, Schema expected) {
JoinDefinition definition = JoinDefinition.builder()
.select(new JoinField("purchases", "id", "purchase_id"),
@@ -68,6 +68,7 @@

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -364,7 +365,20 @@ public void runPipeline(PipelinePhase pipelinePhase, String sourcePluginType,
*/
private SparkCollection<Object> handleAutoJoin(JoinDefinition joinDefinition,
Map<String, SparkCollection<Object>> inputDataCollections) {
Iterator<JoinStage> stageIter = joinDefinition.getStages().iterator();
// sort stages to join so that broadcasts happen last. This is to ensure that the left side is not a broadcast
// so that we don't try to broadcast both sides of the join. It also causes less data to be shuffled for the
// non-broadcast joins.
List<JoinStage> joinOrder = new ArrayList<>(joinDefinition.getStages());
joinOrder.sort((s1, s2) -> {
if (s1.isBroadcast() && !s2.isBroadcast()) {
return 1;
} else if (!s1.isBroadcast() && s2.isBroadcast()) {
return -1;
}
return 0;
});

Iterator<JoinStage> stageIter = joinOrder.iterator();
JoinStage left = stageIter.next();
SparkCollection<Object> leftCollection = inputDataCollections.get(left.getStageName());
Schema leftSchema = left.getSchema();
@@ -95,6 +95,9 @@ public void initialize() throws Exception {

SparkConf sparkConf = new SparkConf();
sparkConf.set("spark.speculation", "false");
// turn off auto-broadcast by default until we better understand the implications and can set this to a
// value that we are confident is safe.
sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "-1");
context.setSparkConf(sparkConf);

Map<String, String> properties = context.getSpecification().getProperties();
@@ -31,6 +31,7 @@
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.StructType;
import scala.collection.JavaConversions;
import scala.collection.Seq;
@@ -111,6 +112,9 @@ public SparkCollection<T> join(JoinRequest joinRequest) {
}
seenRequired = seenRequired || toJoin.isRequired();

if (toJoin.isBroadcast()) {
right = functions.broadcast(right);
}
joined = joined.join(right, joinOn, joinType);

/*
@@ -32,6 +32,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.StructType;
import scala.collection.JavaConversions;
import scala.collection.Seq;
@@ -112,7 +113,10 @@ public SparkCollection<T> join(JoinRequest joinRequest) {
joinType = "outer";
}
seenRequired = seenRequired || toJoin.isRequired();


if (toJoin.isBroadcast()) {
right = functions.broadcast(right);
}
joined = joined.join(right, joinOn, joinType);

/*
@@ -68,14 +68,19 @@ public JoinDefinition define(AutoJoinerContext context) {
Map<String, JoinStage> inputStages = context.getInputStages();
List<JoinStage> from = new ArrayList<>(inputStages.size());
Set<String> required = new HashSet<>(conf.getRequired());
Set<String> broadcast = new HashSet<>(conf.getBroadcast());
List<JoinField> selectedFields = new ArrayList<>();
JoinCondition.OnKeys.Builder condition = JoinCondition.onKeys()
.setNullSafe(conf.isNullSafe());
for (String stageName : conf.getStages()) {
JoinStage stage = inputStages.get(stageName);
JoinStage.Builder stageBuilder = JoinStage.builder(inputStages.get(stageName));
if (!required.contains(stageName)) {
stage = JoinStage.builder(stage).isOptional().build();
stageBuilder.isOptional();
}
if (broadcast.contains(stageName)) {
stageBuilder.setBroadcast(true);
}
JoinStage stage = stageBuilder.build();
from.add(stage);

condition.addKey(new JoinKey(stageName, conf.getKey()));
@@ -110,6 +115,7 @@ public static class Conf extends PluginConfig {
@Nullable
private Boolean nullSafe;

private String broadcast;

List<String> getKey() {
return GSON.fromJson(key, LIST);
@@ -126,13 +132,23 @@ List<String> getRequired() {
boolean isNullSafe() {
return nullSafe == null ? true : nullSafe;
}

List<String> getBroadcast() {
return broadcast == null ? Collections.emptyList() : GSON.fromJson(broadcast, LIST);
}
}

public static ETLPlugin getPlugin(List<String> stages, List<String> key, List<String> required) {
return getPlugin(stages, key, required, Collections.emptyList());
}

public static ETLPlugin getPlugin(List<String> stages, List<String> key, List<String> required,
List<String> broadcast) {
Map<String, String> properties = new HashMap<>();
properties.put("stages", GSON.toJson(stages));
properties.put("required", GSON.toJson(required));
properties.put("key", GSON.toJson(key));
properties.put("broadcast", GSON.toJson(broadcast));
return new ETLPlugin(NAME, BatchJoiner.PLUGIN_TYPE, properties, null);
}

@@ -152,6 +168,7 @@ private static PluginClass getPluginClass() {
properties.put("required", new PluginPropertyField("required", "", "string", false, false));
properties.put("key", new PluginPropertyField("key", "", "string", true, false));
properties.put("nullSafe", new PluginPropertyField("nullSafe", "", "boolean", false, false));
properties.put("broadcast", new PluginPropertyField("broadcast", "", "string", false, false));
return new PluginClass(BatchJoiner.PLUGIN_TYPE, NAME, "", MockAutoJoiner.class.getName(), "conf", properties);
}

