Skip to content

Commit

Permalink
Merge pull request apache#4 from marmbrus/pr/6885
Browse files Browse the repository at this point in the history
Add simple resolver
  • Loading branch information
JoshRosen committed Jun 18, 2015
2 parents c60a44d + d9ab1e4 commit a80f9b0
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 11 deletions.
Expand Up @@ -18,9 +18,13 @@
package org.apache.spark.sql.execution

import org.apache.spark.sql.catalyst.expressions.{BoundReference, Ascending, SortOrder}
import org.apache.spark.sql.catalyst.dsl.expressions._

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType}

class SortSuite extends SparkPlanTest {
import TestSQLContext.implicits.localSeqToDataFrameHolder

test("basic sorting using ExternalSort") {

Expand All @@ -30,16 +34,14 @@ class SortSuite extends SparkPlanTest {
("World", 8)
)

val sortOrder = Seq(
SortOrder(BoundReference(0, StringType, nullable = false), Ascending),
SortOrder(BoundReference(1, IntegerType, nullable = false), Ascending)
)

checkAnswer(
input,
(child: SparkPlan) => new ExternalSort(sortOrder, global = false, child),
input.sorted
)
input.toDF("a", "b"),
ExternalSort('a.asc :: 'b.asc :: Nil, global = false, _: SparkPlan),
input.sorted)

checkAnswer(
input.toDF("a", "b"),
ExternalSort('b.asc :: 'a.asc :: Nil, global = false, _: SparkPlan),
input.sortBy(t => (t._2, t._1)))
}
}
Expand Up @@ -21,9 +21,13 @@ import scala.util.control.NonFatal
import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.SparkFunSuite

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.BoundReference
import org.apache.spark.sql.catalyst.util._

import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.{Row, DataFrame}
import org.apache.spark.sql.catalyst.util._

/**
* Base class for writing tests for individual physical operators. For an example of how this
Expand All @@ -48,6 +52,24 @@ class SparkPlanTest extends SparkFunSuite {
}
}

/**
 * Runs the given physical plan and verifies that its output matches the expected rows.
 * @param input the input data to be used.
 * @param planFunction a function which accepts the input SparkPlan and uses it to instantiate
 *                     the physical operator that's being tested.
 * @param expectedAnswer the expected result in a [[Seq]] of [[Product]]s.
 */
protected def checkAnswer[A <: Product : TypeTag](
    input: DataFrame,
    planFunction: SparkPlan => SparkPlan,
    expectedAnswer: Seq[A]): Unit = {
  // Convert each expected tuple / case class into a Row before comparing.
  val expectedRows = expectedAnswer.map(Row.fromTuple)
  // The shared checker yields Some(description) on mismatch and None on success,
  // so any produced message is reported as a test failure.
  for (errorMessage <- SparkPlanTest.checkAnswer(input, planFunction, expectedRows)) {
    fail(errorMessage)
  }
}

/**
* Runs the plan and makes sure the answer matches the expected result.
* @param input the input data to be used.
Expand Down Expand Up @@ -87,6 +109,23 @@ object SparkPlanTest {

val outputPlan = planFunction(input.queryExecution.sparkPlan)

// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
val resolvedPlan = outputPlan transform {
  case plan: SparkPlan =>
    // Map each attribute name in the concatenated child output to a BoundReference
    // carrying its ordinal position, data type and nullability.
    val inputMap = plan.children.flatMap(_.output).zipWithIndex.map {
      case (a, i) =>
        (a.name, BoundReference(i, a.dataType, a.nullable))
    }.toMap

    plan.transformExpressions {
      case UnresolvedAttribute(Seq(u)) =>
        // Map.getOrElse takes its default by-name, so sys.error only fires on a miss.
        inputMap.getOrElse(u,
          sys.error(s"Invalid Test: Cannot resolve $u given input ${inputMap}"))
    }
}

def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
// Converts data to types that we can do equality comparison using Scala collections.
// For BigDecimal type, the Scala type has a better definition of equality test (similar to
Expand All @@ -105,7 +144,7 @@ object SparkPlanTest {
}

val sparkAnswer: Seq[Row] = try {
outputPlan.executeCollect().toSeq
resolvedPlan.executeCollect().toSeq
} catch {
case NonFatal(e) =>
val errorMessage =
Expand Down

0 comments on commit a80f9b0

Please sign in to comment.