[SPARK-22103] Move HashAggregateExec parent consume to a separate function in codegen

## What changes were proposed in this pull request?

HashAggregateExec codegen uses two paths: one over the fast hash table and a generic one.
It generates code for iterating over both, and each path generates the consume code of the parent operator, so that code is expanded twice.
This leads to a long generated function that can be a problem for the compiler (see e.g. SPARK-21603).
I propose to remove the double expansion by generating the consume code in a helper function that both iteration loops can simply call.
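Schematically, the change to the generated code looks like this (a minimal, self-contained sketch, not actual Spark output; the helper name follows the fresh-name pattern used in the diff below):

```scala
object ConsumeHelperSketch extends App {
  val consumeCode = "/* parent operator's consume code, potentially large */"

  // Before: the parent's consume code is expanded inside both generated loops.
  val before =
    s"""while (fastHashMapIter.hasNext()) { $consumeCode }
       |while (regularHashMapIter.hasNext()) { $consumeCode }""".stripMargin

  // After: it is emitted once as a helper; each loop emits a one-line call.
  val after =
    s"""private void doAggregateWithKeysOutput(UnsafeRow key, UnsafeRow buffer)
       |    throws java.io.IOException { $consumeCode }
       |while (fastHashMapIter.hasNext()) { doAggregateWithKeysOutput(k, b); }
       |while (regularHashMapIter.hasNext()) { doAggregateWithKeysOutput(k, b); }""".stripMargin

  println(before); println(after)
}
```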

One difficulty in separating the `consume` code into a helper function was that a number of places relied on being inside the scope of an outer `produce` loop, e.g. using `continue` to jump out of it.
I replaced such code flows with nested scopes. The compiler should handle that code identically, while the generated `consume` code no longer depends on assumptions outside its own scope.
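For example (mirroring the `FilterExec` change below; a sketch of the generated Java held in Scala strings, not verbatim output), a row-skipping `continue` that used to rely on an enclosing produce loop now targets a local dummy scope:

```scala
object NestedScopeSketch extends App {
  // Before: `continue;` assumes an outer processing loop around the consume code.
  val before =
    """if (isNull1 || !value2) continue;
      |/* consume code of the parent operator */""".stripMargin

  // After: `do { ... } while(false);` gives `continue` a local target, so the
  // consume code no longer depends on anything outside its own scope.
  val after =
    """do {
      |  if (isNull1 || !value2) continue;
      |  /* consume code of the parent operator */
      |} while(false);""".stripMargin

  println(before); println(after)
}
```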

## How was this patch tested?

Existing test coverage.

Author: Juliusz Sompolski <julek@databricks.com>

Closes apache#19324 from juliuszsompolski/aggrconsumecodegen.
juliuszsompolski authored and gatorsmile committed Sep 25, 2017
1 parent 2c5b9b1 commit 038b185
Showing 5 changed files with 164 additions and 89 deletions.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -242,6 +242,9 @@ class CodegenContext {
   private val classFunctions: mutable.Map[String, mutable.Map[String, String]] =
     mutable.Map(outerClassName -> mutable.Map.empty[String, String])
 
+  // Verbatim extra code to be added to the OuterClass.
+  private val extraCode: mutable.ListBuffer[String] = mutable.ListBuffer[String]()
+
   // Returns the size of the most recently added class.
   private def currClassSize(): Int = classSize(classes.head._1)
@@ -328,6 +331,22 @@ class CodegenContext {
     (inlinedFunctions ++ initNestedClasses ++ declareNestedClasses).mkString("\n")
   }
 
+  /**
+   * Emits any source code added with addExtraCode
+   */
+  def emitExtraCode(): String = {
+    extraCode.mkString("\n")
+  }
+
+  /**
+   * Add extra source code to the outermost generated class.
+   * @param code verbatim source code to be added.
+   */
+  def addExtraCode(code: String): Unit = {
+    extraCode.append(code)
+    classSize(outerClassName) += code.length
+  }
+
   final val JAVA_BOOLEAN = "boolean"
   final val JAVA_BYTE = "byte"
   final val JAVA_SHORT = "short"
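A sketch of how the two new methods pair up (illustrative usage under stated assumptions; the fast-hash-map class name is made up, and the template fragment is simplified from the WholeStageCodegenExec change below):

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

val ctx = new CodegenContext
// An operator registers verbatim class-level code while producing its body;
// HashAggregateExec uses this for the generated fast hash map class.
ctx.addExtraCode("public class agg_FastHashMap_0 { /* ... */ }")

// The outer class template later splices everything back in, at class level:
val outerClassMembers =
  s"""
     |${ctx.emitExtraCode()}          // verbatim extra code, in insertion order
     |${ctx.declareAddedFunctions()}  // functions registered via addNewFunction
   """.stripMargin
```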
sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala
@@ -197,11 +197,14 @@ trait CodegenSupport extends SparkPlan {
    *
    * This should be override by subclass to support codegen.
    *
-   * For example, Filter will generate the code like this:
+   * Note: The operator should not assume the existence of an outer processing loop,
+   * which it can jump from with "continue;"!
+   *
+   * For example, filter could generate this:
    *   # code to evaluate the predicate expression, result is isNull1 and value2
-   *   if (isNull1 || !value2) continue;
-   *   # call consume(), which will call parent.doConsume()
+   *   if (!isNull1 && value2) {
+   *     # call consume(), which will call parent.doConsume()
+   *   }
    *
    * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input).
    */
@@ -329,6 +332,15 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
   def doCodeGen(): (CodegenContext, CodeAndComment) = {
     val ctx = new CodegenContext
     val code = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+
+    // main next function.
+    ctx.addNewFunction("processNext",
+      s"""
+        protected void processNext() throws java.io.IOException {
+          ${code.trim}
+        }
+      """, inlineToOuterClass = true)
+
     val source = s"""
       public Object generate(Object[] references) {
         return new GeneratedIterator(references);
@@ -352,9 +364,7 @@ case class WholeStageCodegenExec(child: SparkPlan) extends UnaryExecNode with CodegenSupport {
         ${ctx.initPartition()}
       }
 
-      protected void processNext() throws java.io.IOException {
-        ${code.trim}
-      }
+      ${ctx.emitExtraCode()}
 
       ${ctx.declareAddedFunctions()}
     }
sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -425,12 +425,14 @@ case class HashAggregateExec(
 
   /**
    * Generate the code for output.
+   * @return function name for the result code.
    */
-  private def generateResultCode(
-      ctx: CodegenContext,
-      keyTerm: String,
-      bufferTerm: String,
-      plan: String): String = {
+  private def generateResultFunction(ctx: CodegenContext): String = {
+    val funcName = ctx.freshName("doAggregateWithKeysOutput")
+    val keyTerm = ctx.freshName("keyTerm")
+    val bufferTerm = ctx.freshName("bufferTerm")
+
+    val body =
     if (modes.contains(Final) || modes.contains(Complete)) {
       // generate output using resultExpressions
       ctx.currentVars = null
@@ -462,18 +464,36 @@ case class HashAggregateExec(
         $evaluateAggResults
         ${consume(ctx, resultVars)}
        """
+
     } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
-      // This should be the last operator in a stage, we should output UnsafeRow directly
-      val joinerTerm = ctx.freshName("unsafeRowJoiner")
-      ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
-        s"$joinerTerm = $plan.createUnsafeJoiner();")
-      val resultRow = ctx.freshName("resultRow")
+      // resultExpressions are Attributes of groupingExpressions and aggregateBufferAttributes.
+      assert(resultExpressions.forall(_.isInstanceOf[Attribute]))
+      assert(resultExpressions.length ==
+        groupingExpressions.length + aggregateBufferAttributes.length)
+
+      ctx.currentVars = null
+
+      ctx.INPUT_ROW = keyTerm
+      val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateKeyVars = evaluateVariables(keyVars)
+
+      ctx.INPUT_ROW = bufferTerm
+      val resultBufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) =>
+        BoundReference(i, e.dataType, e.nullable).genCode(ctx)
+      }
+      val evaluateResultBufferVars = evaluateVariables(resultBufferVars)
+
+      ctx.currentVars = keyVars ++ resultBufferVars
+      val inputAttrs = resultExpressions.map(_.toAttribute)
+      val resultVars = resultExpressions.map { e =>
+        BindReferences.bindReference(e, inputAttrs).genCode(ctx)
+      }
       s"""
-       UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
-       ${consume(ctx, null, resultRow)}
+       $evaluateKeyVars
+       $evaluateResultBufferVars
+       ${consume(ctx, resultVars)}
        """
+
     } else {
       // generate result based on grouping key
       ctx.INPUT_ROW = keyTerm
@@ -483,6 +503,13 @@ case class HashAggregateExec(
       }
       consume(ctx, eval)
     }
+    ctx.addNewFunction(funcName,
+      s"""
+        private void $funcName(UnsafeRow $keyTerm, UnsafeRow $bufferTerm)
+            throws java.io.IOException {
+          $body
+        }
+       """)
   }
 
   /**
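Note that `ctx.addNewFunction(funcName, ...)` is the method's final expression; as the `val doAggFuncName = ctx.addNewFunction(doAgg, ...)` call below shows, `addNewFunction` returns the (possibly class-qualified) function name, which `generateResultFunction` hands back to its caller. A sketch of the resulting call-site pattern, using the fresh names from the surrounding code:

```scala
val outputFunc = generateResultFunction(ctx)  // e.g. "doAggregateWithKeysOutput"
// Each generated iteration loop then emits only a one-line call:
val loopBody =
  s"""
     |UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
     |UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
     |$outputFunc($keyTerm, $bufferTerm);
   """.stripMargin
```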
@@ -581,11 +608,6 @@ case class HashAggregateExec(
     val iterTerm = ctx.freshName("mapIter")
     ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
 
-    val doAgg = ctx.freshName("doAggregateWithKeys")
-    val peakMemory = metricTerm(ctx, "peakMemory")
-    val spillSize = metricTerm(ctx, "spillSize")
-    val avgHashProbe = metricTerm(ctx, "avgHashProbe")
-
     def generateGenerateCode(): String = {
       if (isFastHashMapEnabled) {
         if (isVectorizedHashMapEnabled) {
@@ -599,10 +621,14 @@ case class HashAggregateExec(
         }
       } else ""
     }
+    ctx.addExtraCode(generateGenerateCode())
 
+    val doAgg = ctx.freshName("doAggregateWithKeys")
+    val peakMemory = metricTerm(ctx, "peakMemory")
+    val spillSize = metricTerm(ctx, "spillSize")
+    val avgHashProbe = metricTerm(ctx, "avgHashProbe")
     val doAggFuncName = ctx.addNewFunction(doAgg,
       s"""
-        ${generateGenerateCode}
         private void $doAgg() throws java.io.IOException {
           $hashMapTerm = $thisPlan.createHashMap();
           ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
@@ -618,7 +644,7 @@ case class HashAggregateExec(
     // generate code for output
     val keyTerm = ctx.freshName("aggKey")
     val bufferTerm = ctx.freshName("aggBuffer")
-    val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
+    val outputFunc = generateResultFunction(ctx)
     val numOutput = metricTerm(ctx, "numOutputRows")
 
     // The child could change `copyResult` to true, but we had already consumed all the rows,
@@ -641,7 +667,7 @@ case class HashAggregateExec(
         $numOutput.add(1);
         UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
         UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
-        $outputCode
+        $outputFunc($keyTerm, $bufferTerm);
 
         if (shouldStop()) return;
       }
@@ -654,18 +680,23 @@ case class HashAggregateExec(
       val row = ctx.freshName("fastHashMapRow")
       ctx.currentVars = null
       ctx.INPUT_ROW = row
-      var schema: StructType = groupingKeySchema
-      bufferSchema.foreach(i => schema = schema.add(i))
-      val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex
-        .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) })
+      val generateKeyRow = GenerateUnsafeProjection.createCode(ctx,
+        groupingKeySchema.toAttributes.zipWithIndex
+        .map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) }
+      )
+      val generateBufferRow = GenerateUnsafeProjection.createCode(ctx,
+        bufferSchema.toAttributes.zipWithIndex
+        .map { case (attr, i) =>
+          BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable) })
       s"""
       | while ($iterTermForFastHashMap.hasNext()) {
      |   $numOutput.add(1);
      |   org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
      |     (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
      |     $iterTermForFastHashMap.next();
-     |   ${generateRow.code}
-     |   ${consume(ctx, Seq.empty, {generateRow.value})}
+     |   ${generateKeyRow.code}
+     |   ${generateBufferRow.code}
+     |   $outputFunc(${generateKeyRow.value}, ${generateBufferRow.value});
      |
      |   if (shouldStop()) return;
      | }
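Why the single joined projection became two: the shared output function takes the grouping key and the aggregation buffer as separate UnsafeRows, so the vectorized row can no longer be handed to consume as one combined row. Key columns sit first in that row and buffer columns follow, hence the index offset (a condensed restatement of the code above):

```scala
// Key columns occupy indices [0, groupingKeySchema.length); buffer columns
// follow, so their BoundReferences are shifted by groupingKeySchema.length.
val keyRefs = groupingKeySchema.toAttributes.zipWithIndex.map {
  case (attr, i) => BoundReference(i, attr.dataType, attr.nullable)
}
val bufferRefs = bufferSchema.toAttributes.zipWithIndex.map {
  case (attr, i) =>
    BoundReference(groupingKeySchema.length + i, attr.dataType, attr.nullable)
}
```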
@@ -692,7 +723,7 @@ case class HashAggregateExec(
         $numOutput.add(1);
         UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
         UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
-        $outputCode
+        $outputFunc($keyTerm, $bufferTerm);
 
         if (shouldStop()) return;
       }
sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -201,11 +201,14 @@ case class FilterExec(condition: Expression, child: SparkPlan)
       ev
     }
 
+    // Note: wrap in "do { } while(false);", so the generated checks can jump out with "continue;"
     s"""
-       |$generated
-       |$nullChecks
-       |$numOutput.add(1);
-       |${consume(ctx, resultVars)}
+       |do {
+       |  $generated
+       |  $nullChecks
+       |  $numOutput.add(1);
+       |  ${consume(ctx, resultVars)}
+       |} while(false);
      """.stripMargin
   }

@@ -316,9 +319,10 @@ case class SampleExec(
       """.stripMargin.trim)
 
     s"""
-       | if ($sampler.sample() == 0) continue;
-       | $numOutput.add(1);
-       | ${consume(ctx, input)}
+       | if ($sampler.sample() != 0) {
+       |   $numOutput.add(1);
+       |   ${consume(ctx, input)}
+       | }
      """.stripMargin.trim
   }
 }