Skip to content

Commit

Permalink
shrink the commits
Browse files Browse the repository at this point in the history
  • Loading branch information
chenghao-intel committed Apr 14, 2015
1 parent 77eeb10 commit ca5e7f4
Show file tree
Hide file tree
Showing 26 changed files with 189 additions and 147 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
Expand Down Expand Up @@ -59,6 +58,7 @@ class Analyzer(
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolveSortReferences ::
ResolveGenerate ::
ImplicitGenerate ::
ResolveFunctions ::
GlobalAggregates ::
Expand Down Expand Up @@ -473,10 +473,47 @@ class Analyzer(
*/
object ImplicitGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case Project(Seq(Alias(g: Generator, _)), child) =>
Generate(g, join = false, outer = false, None, child)
case Project(Seq(Alias(g: Generator, name)), child) =>
Generate(g, join = false, outer = false, child, qualifier = None, name :: Nil, Nil)
case Project(Seq(MultiAlias(g: Generator, names)), child) =>
Generate(g, join = false, outer = false, child, qualifier = None, names, Nil)
}
}

/**
 * Fills in the generator output attributes of a [[Generate]] operator once its child and
 * its generator are resolved but the operator itself is not.
 */
object ResolveGenerate extends Rule[LogicalPlan] {

  /**
   * Builds the output attributes for `generator`.
   *
   * The supplied `attributeNames` are used when their count equals the number of element
   * types reported by the generator; in every other case the names are auto generated.
   * When a `qualifier` is given it is set on each produced attribute.
   */
  private def makeGeneratorOutput(
      generator: Generator,
      attributeNames: Seq[String],
      qualifier: Option[String]): Array[Attribute] = {
    val types = generator.elementTypes
    val namesMatch = attributeNames.size == types.size

    val unqualified =
      if (!namesMatch) {
        // keep the default column names as Hive does _c0, _c1, _cN
        types.zipWithIndex.map { case ((dataType, nullable), idx) =>
          AttributeReference(s"_c$idx", dataType, nullable)()
        }
      } else {
        attributeNames.zip(types).map { case (name, (dataType, nullable)) =>
          AttributeReference(name, dataType, nullable)()
        }
      }

    val attrs = qualifier match {
      case Some(q) => unqualified.map(_.withQualifiers(q :: Nil))
      case None => unqualified
    }
    attrs.toArray[Attribute]
  }

  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    // A Generate whose child or generator is still unresolved is left untouched; it is only
    // rewritten once both are resolved and the operator itself is not.
    case g: Generate if g.child.resolved && g.generator.resolved && !g.resolved =>
      // When no usable output names were specified, the default ones are filled in here.
      val generated = makeGeneratorOutput(g.generator, g.attributeNames, g.qualifier)
      Generate(
        g.generator, g.join, g.outer, g.child, g.qualifier, generated.map(_.name), generated)
  }
}

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ trait CheckAnalysis {
throw new AnalysisException(msg)
}

/**
 * Returns true when `collect` finds at least one [[Generator]] across the given expressions.
 *
 * NOTE(review): despite the name, a single generator is already enough for this to return
 * true -- confirm whether callers rely on that before changing the threshold.
 */
def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
  val generatorHits = exprs.flatMap { expr =>
    expr.collect { case _: Generator => true }
  }
  generatorHits.nonEmpty
}

def checkAnalysis(plan: LogicalPlan): Unit = {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
Expand Down Expand Up @@ -107,6 +113,12 @@ trait CheckAnalysis {
failAnalysis(
s"unresolved operator ${operator.simpleString}")

case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
failAnalysis(
s"""Only a single table generating function is allowed in a SELECT clause, found:
| ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)


case _ => // Analysis successful!
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,12 +284,13 @@ package object dsl {
seed: Int = (math.random * 1000).toInt): LogicalPlan =
Sample(fraction, withReplacement, seed, logicalPlan)

// TODO specify the output column names
def generate(
generator: Generator,
join: Boolean = false,
outer: Boolean = false,
alias: Option[String] = None): LogicalPlan =
Generate(generator, join, outer, None, logicalPlan)
alias: Option[String] = None): Generate =
Generate(generator, join, outer, logicalPlan, alias)

def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,47 +42,27 @@ abstract class Generator extends Expression {

override type EvaluatedType = TraversableOnce[Row]

override lazy val dataType =
ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
override def dataType: DataType = ???

override def nullable: Boolean = false

/**
* Should be overridden by specific generators. Called only once for each instance to ensure
* that rule application does not change the output schema of a generator.
* The output element data types in structure of Seq[(DataType, Nullable)]
*/
protected def makeOutput(): Seq[Attribute]

private var _output: Seq[Attribute] = null

def output: Seq[Attribute] = {
if (_output == null) {
_output = makeOutput()
}
_output
}
def elementTypes: Seq[(DataType, Boolean)]

/** Should be implemented by child classes to perform specific Generators. */
override def eval(input: Row): TraversableOnce[Row]

/** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
override def makeCopy(newArgs: Array[AnyRef]): this.type = {
val copy = super.makeCopy(newArgs)
copy._output = _output
copy
}
}

/**
* A generator that produces its output using the provided lambda function.
*/
case class UserDefinedGenerator(
schema: Seq[Attribute],
elementTypes: Seq[(DataType, Boolean)],
function: Row => TraversableOnce[Row],
children: Seq[Expression])
extends Generator{

override protected def makeOutput(): Seq[Attribute] = schema
extends Generator {

override def eval(input: Row): TraversableOnce[Row] = {
val inputRow = new InterpretedProjection(children)
Expand All @@ -95,30 +75,18 @@ case class UserDefinedGenerator(
/**
* Given an input array produces a sequence of rows for each value in the array.
*/
case class Explode(attributeNames: Seq[String], child: Expression)
case class Explode(child: Expression)
extends Generator with trees.UnaryNode[Expression] {

override lazy val resolved =
child.resolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])

private lazy val elementTypes = child.dataType match {
override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
case ArrayType(et, containsNull) => (et, containsNull) :: Nil
case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil
}

// TODO: Move this pattern into Generator.
protected def makeOutput() =
if (attributeNames.size == elementTypes.size) {
attributeNames.zip(elementTypes).map {
case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
}
} else {
elementTypes.zipWithIndex.map {
case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)()
}
}

override def eval(input: Row): TraversableOnce[Row] = {
child.dataType match {
case ArrayType(_, _) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)(
extends NamedExpression with trees.UnaryNode[Expression] {

override type EvaluatedType = Any
// Alias(Generator, xx) needs to be transformed into Generate(generator, ...)
override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]

override def eval(input: Row): Any = child.eval(input)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -477,16 +477,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter(condition,
generate @ Generate(generator, join, outer, alias, grandChild)) =>
case filter @ Filter(condition, g: Generate) =>
// Predicates that reference attributes produced by the `Generate` operator cannot
// be pushed below the operator.
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
conjunct => conjunct.references subsetOf grandChild.outputSet
conjunct => conjunct.references subsetOf g.child.outputSet
}
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
Filter(pushDownPredicate, g.child), g.qualifier, g.attributeNames, g.gOutput)
stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
} else {
filter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,34 +40,41 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
* @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty. `outer` has no effect when `join` is false.
* @param alias when set, this string is applied to the schema of the output of the transformation
* as a qualifier.
* @param child the child logical plan node
* @param qualifier the qualifier for the attributes of the generator (UDTF)
* @param attributeNames the column names for the generator (UDTF); these will be
* _c0, _c1 .. _cN if left as the default (empty)
* @param gOutput the output attributes of the generator
*/
case class Generate(
generator: Generator,
join: Boolean,
outer: Boolean,
alias: Option[String],
child: LogicalPlan)
child: LogicalPlan,
qualifier: Option[String] = None,
attributeNames: Seq[String] = Nil,
gOutput: Seq[Attribute] = Nil)
extends UnaryNode {

protected def generatorOutput: Seq[Attribute] = {
val output = alias
.map(a => generator.output.map(_.withQualifiers(a :: Nil)))
.getOrElse(generator.output)
if (join && outer) {
output.map(_.withNullability(true))
} else {
output
}
override lazy val resolved: Boolean = {
generator.resolved &&
childrenResolved &&
attributeNames.length > 0 &&
gOutput.map(_.name) == attributeNames
}

override def output: Seq[Attribute] =
if (join) child.output ++ generatorOutput else generatorOutput
// we don't want the gOutput to be taken as part of the expressions
// as that will cause exceptions like unresolved attributes etc.
override def expressions: Seq[Expression] = generator :: Nil

def output: Seq[Attribute] = {
if (join) child.output ++ gOutput else gOutput
}
}

case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {

assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)

val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)())
val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)

assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest {
test("generate: predicate referenced no generated column") {
val originalQuery = {
testRelationWithArrayType
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), true, false, Some("arr"))
.where(('b >= 5) && ('a > 6))
}
val optimized = Optimize(originalQuery.analyze)
val correctAnswer = {
testRelationWithArrayType
.where(('b >= 5) && ('a > 6))
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
.generate(Explode('c_arr), true, false, Some("arr")).analyze
}

comparePlans(optimized, correctAnswer)
}

test("generate: part of conjuncts referenced generated column") {
val generator = Explode(Seq("c"), 'c_arr)
val generator = Explode('c_arr)
val originalQuery = {
testRelationWithArrayType
.generate(generator, true, false, Some("arr"))
Expand Down Expand Up @@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest {
test("generate: all conjuncts referenced generated column") {
val originalQuery = {
testRelationWithArrayType
.generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
.generate(Explode('c_arr), true, false, Some("arr"))
.where(('c > 6) || ('b > 5)).analyze
}
val optimized = Optimize(originalQuery)
Expand Down
17 changes: 12 additions & 5 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
Original file line number Diff line number Diff line change
Expand Up @@ -711,12 +711,15 @@ class DataFrame private[sql](
*/
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
val attributes = schema.toAttributes

val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) }
val names = schema.toAttributes.map(_.name)

val rowFunction =
f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))

Generate(generator, join = true, outer = false, None, logicalPlan)
Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)
}

/**
Expand All @@ -733,12 +736,16 @@ class DataFrame private[sql](
: DataFrame = {
val dataType = ScalaReflection.schemaFor[B].dataType
val attributes = AttributeReference(outputColumn, dataType)() :: Nil
// TODO handle the metadata?
val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) }
val names = attributes.map(_.name)

def rowFunction(row: Row): TraversableOnce[Row] = {
f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType)))
}
val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)

Generate(generator, join = true, outer = false, None, logicalPlan)
Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil)
}

/////////////////////////////////////////////////////////////////////////////
Expand Down
Loading

0 comments on commit ca5e7f4

Please sign in to comment.