From 6ceb169608428a651d53c93bf73ca5ac53a6bde2 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 23 Jun 2015 01:50:31 -0700 Subject: [PATCH] [SPARK-8300] DataFrame hint for broadcast join. Users can now do ```scala left.join(broadcast(right), "joinKey") ``` to give the query planner a hint that "right" DataFrame is small and should be broadcasted. Author: Reynold Xin Closes #6751 from rxin/broadcastjoin-hint and squashes the following commits: 953eec2 [Reynold Xin] Code review feedback. 88752d8 [Reynold Xin] Fixed import. 8187b88 [Reynold Xin] [SPARK-8300] DataFrame hint for broadcast join. --- .../plans/logical/basicOperators.scala | 8 ++++++ .../spark/sql/execution/SparkStrategies.scala | 25 +++++++++++++------ .../org/apache/spark/sql/functions.scala | 17 +++++++++++++ .../apache/spark/sql/DataFrameJoinSuite.scala | 17 +++++++++++++ 4 files changed, 59 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index f8e5916d69f9c..7814e51628db6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -130,6 +130,14 @@ case class Join( } } +/** + * A hint for the optimizer that we should broadcast the `child` if used in a join operator. + */ +case class BroadcastHint(child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = child.output +} + + case class Except(left: LogicalPlan, right: LogicalPlan) extends BinaryNode { override def output: Seq[Attribute] = left.output } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 422992d019c7b..5c420eb9d761f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.columnar.{InMemoryColumnarTableScan, InMemoryRelation} import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand} @@ -52,6 +52,18 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Matches a plan whose output should be small enough to be used in broadcast join. + */ + object CanBroadcast { + def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { + case BroadcastHint(p) => Some(p) + case p if sqlContext.conf.autoBroadcastJoinThreshold > 0 && + p.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => Some(p) + case _ => None + } + } + /** * Uses the ExtractEquiJoinKeys pattern to find joins where at least some of the predicates can be * evaluated by matching hash keys. @@ -80,15 +92,11 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - right.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildRight) - case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, left, right) - if sqlContext.conf.autoBroadcastJoinThreshold > 0 && - left.statistics.sizeInBytes <= sqlContext.conf.autoBroadcastJoinThreshold => - makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) + case ExtractEquiJoinKeys(Inner, leftKeys, rightKeys, condition, CanBroadcast(left), right) => + makeBroadcastHashJoin(leftKeys, rightKeys, left, right, condition, joins.BuildLeft) // If the sort merge join option is set, we want to use sort merge join prior to hashjoin // for now let's support inner join first, then add outer join @@ -329,6 +337,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case e @ EvaluatePython(udf, child, _) => BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd) :: Nil + case BroadcastHint(child) => apply(child) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8cea826ae6921..38d9085a505fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import org.apache.spark.annotation.Experimental import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, Star} import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -565,6 +566,22 @@ object functions { array((colName +: colNames).map(col) : _*) } + /** + * Marks a DataFrame as small enough for use in broadcast joins. + * + * The following example marks the right DataFrame for broadcast hash join using `joinKey`. + * {{{ + * // left and right are DataFrames + * left.join(broadcast(right), "joinKey") + * }}} + * + * @group normal_funcs + * @since 1.5.0 + */ + def broadcast(df: DataFrame): DataFrame = { + DataFrame(df.sqlContext, BroadcastHint(df.logicalPlan)) + } + /** * Returns the first column that is not null. * {{{ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 6165764632c29..e1c6c706242d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ +import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions._ class DataFrameJoinSuite extends QueryTest { @@ -93,4 +94,20 @@ class DataFrameJoinSuite extends QueryTest { left.join(right, left("key") === right("key")), Row(1, 1, 1, 1) :: Row(2, 1, 2, 2) :: Nil) } + + test("broadcast join hint") { + val df1 = Seq((1, "1"), (2, "2")).toDF("key", "value") + val df2 = Seq((1, "1"), (2, "2")).toDF("key", "value") + + // equijoin - should be converted into broadcast join + val plan1 = df1.join(broadcast(df2), "key").queryExecution.executedPlan + assert(plan1.collect { case p: BroadcastHashJoin => p }.size === 1) + + // no join key -- should not be a broadcast join + val plan2 = df1.join(broadcast(df2)).queryExecution.executedPlan + assert(plan2.collect { case p: BroadcastHashJoin => p }.size === 0) + + // planner should not crash without a join + broadcast(df1).queryExecution.executedPlan + } }