CDAP-16709 implement manual broadcasts #12213

Merged · 1 commit · May 28, 2020
@@ -102,6 +102,30 @@ public static void setupTest() throws Exception {
setupBatchArtifacts(APP_ARTIFACT_ID, DataPipelineApp.class);
}

@Test
public void testBroadcastJoin() throws Exception {
Schema expectedSchema = Schema.recordOf("purchases.users",
Schema.Field.of("purchases_region", Schema.of(Schema.Type.STRING)),
Schema.Field.of("purchases_purchase_id", Schema.of(Schema.Type.INT)),
Schema.Field.of("purchases_user_id", Schema.of(Schema.Type.INT)),
Schema.Field.of("users_region", Schema.of(Schema.Type.STRING)),
Schema.Field.of("users_user_id", Schema.of(Schema.Type.INT)),
Schema.Field.of("users_name", Schema.of(Schema.Type.STRING)));
Set<StructuredRecord> expected = new HashSet<>();
expected.add(StructuredRecord.builder(expectedSchema)
.set("purchases_region", "us")
.set("purchases_purchase_id", 123)
.set("purchases_user_id", 0)
.set("users_region", "us")
.set("users_user_id", 0)
.set("users_name", "alice").build());

testSimpleAutoJoin(Arrays.asList("users", "purchases"), Collections.singletonList("users"),
expected, Engine.SPARK);
testSimpleAutoJoin(Arrays.asList("users", "purchases"), Collections.singletonList("purchases"),
expected, Engine.SPARK);
}

@Test
public void testAutoInnerJoin() throws Exception {
Schema expectedSchema = Schema.recordOf("purchases.users",
@@ -147,7 +171,7 @@ public void testAutoLeftOuterJoin() throws Exception {
.set("purchases_purchase_id", 456)
.set("purchases_user_id", 2).build());

//testSimpleAutoJoin(Collections.singletonList("purchases"), expected, Engine.SPARK);
testSimpleAutoJoin(Collections.singletonList("purchases"), expected, Engine.SPARK);
testSimpleAutoJoin(Collections.singletonList("purchases"), expected, Engine.MAPREDUCE);
}

@@ -220,6 +244,11 @@ public void testAutoOuterJoin() throws Exception {

private void testSimpleAutoJoin(List<String> required, Set<StructuredRecord> expected,
Engine engine) throws Exception {
testSimpleAutoJoin(required, Collections.emptyList(), expected, engine);
}

private void testSimpleAutoJoin(List<String> required, List<String> broadcast,
Set<StructuredRecord> expected, Engine engine) throws Exception {
/*
users ------|
|--> join --> sink
@@ -235,7 +264,7 @@ private void testSimpleAutoJoin(List<String> required, Set<StructuredRecord> exp
.addStage(new ETLStage("purchases", MockSource.getPlugin(purchaseInput, PURCHASE_SCHEMA)))
.addStage(new ETLStage("join", MockAutoJoiner.getPlugin(Arrays.asList("purchases", "users"),
Arrays.asList("region", "user_id"),
required)))
required, broadcast)))
.addStage(new ETLStage("sink", MockSink.getPlugin(output)))
.addConnection("users", "join")
.addConnection("purchases", "join")
@@ -273,6 +302,53 @@ private void testSimpleAutoJoin(List<String> required, Set<StructuredRecord> exp
Assert.assertEquals(expected, new HashSet<>(outputRecords));
}

@Test
public void testDoubleBroadcastJoin() throws Exception {
Schema expectedSchema = Schema.recordOf(
"purchases.users.interests",
Schema.Field.of("purchases_region", Schema.of(Schema.Type.STRING)),
Schema.Field.of("purchases_purchase_id", Schema.of(Schema.Type.INT)),
Schema.Field.of("purchases_user_id", Schema.of(Schema.Type.INT)),
Schema.Field.of("users_region", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("users_user_id", Schema.nullableOf(Schema.of(Schema.Type.INT))),
Schema.Field.of("users_name", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("interests_region", Schema.nullableOf(Schema.of(Schema.Type.STRING))),
Schema.Field.of("interests_user_id", Schema.nullableOf(Schema.of(Schema.Type.INT))),
Schema.Field.of("interests_interest", Schema.nullableOf(Schema.of(Schema.Type.STRING))));

Set<StructuredRecord> expected = new HashSet<>();
expected.add(StructuredRecord.builder(expectedSchema)
.set("purchases_region", "us")
.set("purchases_purchase_id", 123)
.set("purchases_user_id", 0)
.set("users_region", "us")
.set("users_user_id", 0)
.set("users_name", "alice")
.set("interests_region", "us")
.set("interests_user_id", 0)
.set("interests_interest", "food").build());
expected.add(StructuredRecord.builder(expectedSchema)
.set("purchases_region", "us")
.set("purchases_purchase_id", 123)
.set("purchases_user_id", 0)
.set("users_region", "us")
.set("users_user_id", 0)
.set("users_name", "alice")
.set("interests_region", "us")
.set("interests_user_id", 0)
.set("interests_interest", "sports").build());
expected.add(StructuredRecord.builder(expectedSchema)
.set("purchases_region", "us")
.set("purchases_purchase_id", 456)
.set("purchases_user_id", 2)
.set("interests_region", "us")
.set("interests_user_id", 2)
.set("interests_interest", "gaming").build());

testTripleAutoJoin(Collections.singletonList("purchases"), Arrays.asList("purchases", "interests"),
expected, Engine.SPARK);
}

@Test
public void testTripleAutoLeftRequiredJoin() throws Exception {
Schema expectedSchema = Schema.recordOf(
@@ -418,8 +494,13 @@ public void testTripleAutoTwoRequiredJoin() throws Exception {
testTripleAutoJoin(Arrays.asList("users", "interests"), expected, Engine.MAPREDUCE);
}

public void testTripleAutoJoin(List<String> required, Set<StructuredRecord> expected,
Engine engine) throws Exception {
private void testTripleAutoJoin(List<String> required, Set<StructuredRecord> expected,
Engine engine) throws Exception {
testTripleAutoJoin(required, Collections.emptyList(), expected, engine);
}

private void testTripleAutoJoin(List<String> required, List<String> broadcast,
Set<StructuredRecord> expected, Engine engine) throws Exception {
/*
users ------|
|
@@ -440,7 +521,7 @@ public void testTripleAutoJoin(List<String> required, Set<StructuredRecord> expe
.addStage(new ETLStage("interests", MockSource.getPlugin(interestInput, INTEREST_SCHEMA)))
.addStage(new ETLStage("join", MockAutoJoiner.getPlugin(Arrays.asList("purchases", "users", "interests"),
Arrays.asList("region", "user_id"),
required)))
required, broadcast)))
.addStage(new ETLStage("sink", MockSink.getPlugin(output)))
.addConnection("users", "join")
.addConnection("purchases", "join")
@@ -132,6 +132,10 @@ public JoinDefinition build() {
throw new InvalidJoinException("At least two stages must be specified.");
}

if (stages.stream().allMatch(JoinStage::isBroadcast)) {
throw new InvalidJoinException("Cannot broadcast all stages.");
}

// validate the join condition
if (condition == null) {
throw new InvalidJoinException("A join condition must be specified.");
@@ -97,8 +97,11 @@ public Builder setRequired(boolean required) {
}

/**
* Set whether the stage data should be broadcast during the join. In order to be broadcast, the stage data must
* Hint that the stage data should be broadcast during the join. In order to be broadcast, the stage data must
* be below 8gb and fit entirely in memory. You cannot broadcast both sides of a join.
* This is just a hint and will not always be honored.
* MapReduce pipelines will currently ignore this flag. Spark pipelines will hint to Spark to broadcast, but Spark
* may still decide to do a normal join depending on the type of join being performed and the datasets involved.
*/
public Builder setBroadcast(boolean broadcast) {
this.broadcast = broadcast;
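The javadoc above makes the broadcast flag a hint rather than a guarantee. As a rough illustration of how a joiner plugin could pass that hint through the builder, here is a sketch that mirrors the builder calls exercised in this PR's tests; the stage names, schemas, and selected fields are hypothetical and imports are omitted for brevity.

```java
// Hypothetical define(): "users" is a small dimension input, so it carries the broadcast
// hint; "purchases" stays a regular (shuffled) input. Schemas are illustrative only.
Schema userSchema = Schema.recordOf(
  "user",
  Schema.Field.of("user_id", Schema.of(Schema.Type.INT)),
  Schema.Field.of("name", Schema.of(Schema.Type.STRING)));
Schema purchaseSchema = Schema.recordOf(
  "purchase",
  Schema.Field.of("purchase_id", Schema.of(Schema.Type.INT)),
  Schema.Field.of("user_id", Schema.of(Schema.Type.INT)));

JoinStage purchases = JoinStage.builder("purchases", purchaseSchema).build();
JoinStage users = JoinStage.builder("users", userSchema)
  .isOptional()         // left outer join: the users side is not required
  .setBroadcast(true)   // hint only; MapReduce ignores it and Spark may still shuffle
  .build();

JoinDefinition definition = JoinDefinition.builder()
  .select(new JoinField("purchases", "purchase_id"),
          new JoinField("purchases", "user_id"),
          new JoinField("users", "name", "user_name"))
  .from(purchases, users)
  .on(JoinCondition.onKeys()
        .addKey(new JoinKey("purchases", Collections.singletonList("user_id")))
        .addKey(new JoinKey("users", Collections.singletonList("user_id")))
        .build())
  .build();
```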
@@ -284,6 +284,34 @@ public void testJoinKeyMismatchedFieldTypeThrowsException() {
}
}

@Test
public void testAllBroadcastThrowsException() {
JoinStage purchases = JoinStage.builder("purchases", PURCHASE_SCHEMA).setBroadcast(true).build();
JoinStage users = JoinStage.builder("users", USER_SCHEMA).setBroadcast(true).build();

try {
JoinDefinition.builder()
.select(new JoinField("purchases", "id", "purchase_id"),
new JoinField("users", "id", "user_id"),
new JoinField("purchases", "ts"),
new JoinField("purchases", "price"),
new JoinField("purchases", "coupon"),
new JoinField("users", "name"),
new JoinField("users", "email"),
new JoinField("users", "age"),
new JoinField("users", "bday"))
.from(purchases, users)
.on(JoinCondition.onKeys()
.addKey(new JoinKey("purchases", Collections.singletonList("user_id")))
.addKey(new JoinKey("users", Collections.singletonList("id")))
.build())
.build();
Assert.fail("Invalid join condition did not fail as expected");
} catch (InvalidJoinException e) {
// expected
}
}

private void testUserPurchaseSchema(JoinStage purchases, JoinStage users, Schema expected) {
JoinDefinition definition = JoinDefinition.builder()
.select(new JoinField("purchases", "id", "purchase_id"),
@@ -68,6 +68,7 @@

import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
@@ -364,7 +365,20 @@ public void runPipeline(PipelinePhase pipelinePhase, String sourcePluginType,
*/
private SparkCollection<Object> handleAutoJoin(JoinDefinition joinDefinition,
Map<String, SparkCollection<Object>> inputDataCollections) {
Iterator<JoinStage> stageIter = joinDefinition.getStages().iterator();
// sort stages to join so that broadcasts happen last. This is to ensure that the left side is not a broadcast
// so that we don't try to broadcast both sides of the join. It also causes less data to be shuffled for the
// non-broadcast joins.
List<JoinStage> joinOrder = new ArrayList<>(joinDefinition.getStages());
joinOrder.sort((s1, s2) -> {
if (s1.isBroadcast() && !s2.isBroadcast()) {
return 1;
} else if (!s1.isBroadcast() && s2.isBroadcast()) {
return -1;
}
return 0;
});

Iterator<JoinStage> stageIter = joinOrder.iterator();
JoinStage left = stageIter.next();
SparkCollection<Object> leftCollection = inputDataCollections.get(left.getStageName());
Schema leftSchema = left.getSchema();
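A small self-contained sketch (hypothetical stage names) of what this reordering does. It uses the same comparator logic as the driver and relies on List.sort being stable, so non-broadcast stages keep their relative order while broadcast stages move to the end.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class BroadcastOrderSketch {
  public static void main(String[] args) {
    // Stage names paired with a broadcast flag; "countries" and "users" are the hinted stages.
    List<String> broadcast = Arrays.asList("countries", "users");
    List<String> joinOrder = new ArrayList<>(Arrays.asList("countries", "purchases", "users", "interests"));

    // Same ordering as the pipeline driver: broadcast stages sort after non-broadcast ones.
    // List.sort is stable, so relative order within each group is preserved.
    joinOrder.sort((s1, s2) -> Boolean.compare(broadcast.contains(s1), broadcast.contains(s2)));

    // Prints [purchases, interests, countries, users]: the leftmost stage is never a broadcast
    // stage, so both sides of a join are never broadcast at once.
    System.out.println(joinOrder);
  }
}
```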
@@ -95,6 +95,9 @@ public void initialize() throws Exception {

SparkConf sparkConf = new SparkConf();
sparkConf.set("spark.speculation", "false");
// turn off auto-broadcast by default until we better understand the implications and can set this to a
// value that we are confident is safe.
sparkConf.set("spark.sql.autoBroadcastJoinThreshold", "-1");
Contributor commented:
Is there any constant class we can use for these config?

@albertshau (Contributor, Author) commented on May 28, 2020:
not that I know of. If there was, it could potentially change across different Spark versions too.

context.setSparkConf(sparkConf);

Map<String, String> properties = context.getSpecification().getProperties();
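For context, a rough standalone Spark 2.x sketch (outside CDAP, with hypothetical table names) of how the disabled auto-broadcast threshold interacts with the explicit functions.broadcast() hint added later in this change: size-based automatic broadcasting is off, but the explicit hint still asks Spark to broadcast the small side.

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.functions;

public class BroadcastHintSketch {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
        .appName("broadcast-hint-sketch")
        .master("local[*]")
        // Disable size-based automatic broadcasting, as the pipeline driver does above.
        .config("spark.sql.autoBroadcastJoinThreshold", "-1")
        .getOrCreate();

    // Hypothetical inputs; replace with real sources.
    Dataset<Row> purchases = spark.range(1000).withColumnRenamed("id", "user_id");
    Dataset<Row> users = spark.range(10).withColumnRenamed("id", "user_id");

    // Even with auto-broadcast disabled, an explicit functions.broadcast() hint still asks
    // Spark to broadcast the small side; Spark may ignore the hint for some join types.
    Dataset<Row> joined = purchases.join(functions.broadcast(users), "user_id");
    joined.explain();   // the plan should show a broadcast hash join for this equi-join

    spark.stop();
  }
}
```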
@@ -31,6 +31,7 @@
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.StructType;
import scala.collection.JavaConversions;
import scala.collection.Seq;
@@ -111,6 +112,9 @@ public SparkCollection<T> join(JoinRequest joinRequest) {
}
seenRequired = seenRequired || toJoin.isRequired();

if (toJoin.isBroadcast()) {
right = functions.broadcast(right);
}
joined = joined.join(right, joinOn, joinType);

/*
@@ -32,6 +32,7 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.catalyst.encoders.RowEncoder;
import org.apache.spark.sql.functions;
import org.apache.spark.sql.types.StructType;
import scala.collection.JavaConversions;
import scala.collection.Seq;
@@ -112,7 +113,10 @@ public SparkCollection<T> join(JoinRequest joinRequest) {
joinType = "outer";
}
seenRequired = seenRequired || toJoin.isRequired();


if (toJoin.isBroadcast()) {
right = functions.broadcast(right);
}
joined = joined.join(right, joinOn, joinType);

/*
@@ -68,14 +68,19 @@ public JoinDefinition define(AutoJoinerContext context) {
Map<String, JoinStage> inputStages = context.getInputStages();
List<JoinStage> from = new ArrayList<>(inputStages.size());
Set<String> required = new HashSet<>(conf.getRequired());
Set<String> broadcast = new HashSet<>(conf.getBroadcast());
List<JoinField> selectedFields = new ArrayList<>();
JoinCondition.OnKeys.Builder condition = JoinCondition.onKeys()
.setNullSafe(conf.isNullSafe());
for (String stageName : conf.getStages()) {
JoinStage stage = inputStages.get(stageName);
JoinStage.Builder stageBuilder = JoinStage.builder(inputStages.get(stageName));
if (!required.contains(stageName)) {
stage = JoinStage.builder(stage).isOptional().build();
stageBuilder.isOptional();
}
if (broadcast.contains(stageName)) {
stageBuilder.setBroadcast(true);
}
JoinStage stage = stageBuilder.build();
from.add(stage);

condition.addKey(new JoinKey(stageName, conf.getKey()));
@@ -110,6 +115,7 @@ public static class Conf extends PluginConfig {
@Nullable
private Boolean nullSafe;

private String broadcast;

List<String> getKey() {
return GSON.fromJson(key, LIST);
@@ -126,13 +132,23 @@ List<String> getRequired() {
boolean isNullSafe() {
return nullSafe == null ? true : nullSafe;
}

List<String> getBroadcast() {
return broadcast == null ? Collections.emptyList() : GSON.fromJson(broadcast, LIST);
}
}

public static ETLPlugin getPlugin(List<String> stages, List<String> key, List<String> required) {
return getPlugin(stages, key, required, Collections.emptyList());
}

public static ETLPlugin getPlugin(List<String> stages, List<String> key, List<String> required,
List<String> broadcast) {
Map<String, String> properties = new HashMap<>();
properties.put("stages", GSON.toJson(stages));
properties.put("required", GSON.toJson(required));
properties.put("key", GSON.toJson(key));
properties.put("broadcast", GSON.toJson(broadcast));
return new ETLPlugin(NAME, BatchJoiner.PLUGIN_TYPE, properties, null);
}

@@ -152,6 +168,7 @@ private static PluginClass getPluginClass() {
properties.put("required", new PluginPropertyField("required", "", "string", false, false));
properties.put("key", new PluginPropertyField("key", "", "string", true, false));
properties.put("nullSafe", new PluginPropertyField("nullSafe", "", "boolean", false, false));
properties.put("broadcast", new PluginPropertyField("broadcast", "", "string", false, false));
return new PluginClass(BatchJoiner.PLUGIN_TYPE, NAME, "", MockAutoJoiner.class.getName(), "conf", properties);
}
