diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8b68b0df35f48..d367eff845b88 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.util.collection.OpenHashSet import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ @@ -59,6 +58,7 @@ class Analyzer( ResolveReferences :: ResolveGroupingAnalytics :: ResolveSortReferences :: + ResolveGenerate :: ImplicitGenerate :: ResolveFunctions :: GlobalAggregates :: @@ -473,10 +473,47 @@ class Analyzer( */ object ImplicitGenerate extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case Project(Seq(Alias(g: Generator, _)), child) => - Generate(g, join = false, outer = false, None, child) + case Project(Seq(Alias(g: Generator, name)), child) => + Generate(g, join = false, outer = false, child, qualifier = None, name :: Nil, Nil) + case Project(Seq(MultiAlias(g: Generator, names)), child) => + Generate(g, join = false, outer = false, child, qualifier = None, names, Nil) } } + + object ResolveGenerate extends Rule[LogicalPlan] { + // Construct the output attributes for the generator, + // The output attribute names can be either specified or + // auto generated. + private def makeGeneratorOutput( + generator: Generator, + attributeNames: Seq[String], + qualifier: Option[String]): Array[Attribute] = { + val elementTypes = generator.elementTypes + + val raw = if (attributeNames.size == elementTypes.size) { + attributeNames.zip(elementTypes).map { + case (n, (t, nullable)) => AttributeReference(n, t, nullable)() + } + } else { + elementTypes.zipWithIndex.map { + // keep the default column names as Hive does _c0, _c1, _cN + case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)() + } + } + + qualifier.map(q => raw.map(_.withQualifiers(q :: Nil))).getOrElse(raw).toArray[Attribute] + } + + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case p: Generate if !p.child.resolved || !p.generator.resolved => p + case p: Generate if p.resolved == false => + // if the generator output names are not specified, we will use the default ones. + val gOutput = makeGeneratorOutput(p.generator, p.attributeNames, p.qualifier) + Generate( + p.generator, p.join, p.outer, p.child, p.qualifier, gOutput.map(_.name), gOutput) + } + } + } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index fa02111385c06..da324703c7337 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -38,6 +38,12 @@ trait CheckAnalysis { throw new AnalysisException(msg) } + def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = { + exprs.flatMap(_.collect { + case e: Generator => true + }).length >= 1 + } + def checkAnalysis(plan: LogicalPlan): Unit = { // We transform up and order the rules so as to catch the first possible failure instead // of the result of cascading resolution failures. @@ -107,6 +113,12 @@ trait CheckAnalysis { failAnalysis( s"unresolved operator ${operator.simpleString}") + case p @ Project(exprs, _) if containsMultipleGenerators(exprs) => + failAnalysis( + s"""Only a single table generating function is allowed in a SELECT clause, found: + | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin) + + case _ => // Analysis successful! } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 21c15ad14fd19..fc87b34bb1d3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -284,12 +284,13 @@ package object dsl { seed: Int = (math.random * 1000).toInt): LogicalPlan = Sample(fraction, withReplacement, seed, logicalPlan) + // TODO specify the output column names def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None): LogicalPlan = - Generate(generator, join, outer, None, logicalPlan) + alias: Option[String] = None): Generate = + Generate(generator, join, outer, logicalPlan, alias) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 860b72fad38b3..3257fbc67e368 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -42,47 +42,27 @@ abstract class Generator extends Expression { override type EvaluatedType = TraversableOnce[Row] - override lazy val dataType = - ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) + override def dataType: DataType = ??? override def nullable: Boolean = false /** - * Should be overridden by specific generators. Called only once for each instance to ensure - * that rule application does not change the output schema of a generator. + * The output element data types in structure of Seq[(DataType, Nullable)] */ - protected def makeOutput(): Seq[Attribute] - - private var _output: Seq[Attribute] = null - - def output: Seq[Attribute] = { - if (_output == null) { - _output = makeOutput() - } - _output - } + def elementTypes: Seq[(DataType, Boolean)] /** Should be implemented by child classes to perform specific Generators. */ override def eval(input: Row): TraversableOnce[Row] - - /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */ - override def makeCopy(newArgs: Array[AnyRef]): this.type = { - val copy = super.makeCopy(newArgs) - copy._output = _output - copy - } } /** * A generator that produces its output using the provided lambda function. */ case class UserDefinedGenerator( - schema: Seq[Attribute], + elementTypes: Seq[(DataType, Boolean)], function: Row => TraversableOnce[Row], children: Seq[Expression]) - extends Generator{ - - override protected def makeOutput(): Seq[Attribute] = schema + extends Generator { override def eval(input: Row): TraversableOnce[Row] = { val inputRow = new InterpretedProjection(children) @@ -95,30 +75,18 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. */ -case class Explode(attributeNames: Seq[String], child: Expression) +case class Explode(child: Expression) extends Generator with trees.UnaryNode[Expression] { override lazy val resolved = child.resolved && (child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType]) - private lazy val elementTypes = child.dataType match { + override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match { case ArrayType(et, containsNull) => (et, containsNull) :: Nil case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil } - // TODO: Move this pattern into Generator. - protected def makeOutput() = - if (attributeNames.size == elementTypes.size) { - attributeNames.zip(elementTypes).map { - case (n, (t, nullable)) => AttributeReference(n, t, nullable)() - } - } else { - elementTypes.zipWithIndex.map { - case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)() - } - } - override def eval(input: Row): TraversableOnce[Row] = { child.dataType match { case ArrayType(_, _) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index bcbcbeb31c7b5..afcb2ce8b9cb4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)( extends NamedExpression with trees.UnaryNode[Expression] { override type EvaluatedType = Any + // Alias(Generator, xx) need to be transformed into Generate(generator, ...) + override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator] override def eval(input: Row): Any = child.eval(input) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 93e69d409cb91..f5c71ee1da21c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -477,16 +477,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] { object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, - generate @ Generate(generator, join, outer, alias, grandChild)) => + case filter @ Filter(condition, g: Generate) => // Predicates that reference attributes produced by the `Generate` operator cannot // be pushed below the operator. val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { - conjunct => conjunct.references subsetOf grandChild.outputSet + conjunct => conjunct.references subsetOf g.child.outputSet } if (pushDown.nonEmpty) { val pushDownPredicate = pushDown.reduce(And) - val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild)) + val withPushdown = Generate(g.generator, join = g.join, outer = g.outer, + Filter(pushDownPredicate, g.child), g.qualifier, g.attributeNames, g.gOutput) stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown) } else { filter diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 17522976dc2c9..6acb83c5e32de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -40,34 +40,41 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. - * @param alias when set, this string is applied to the schema of the output of the transformation - * as a qualifier. + * @param child Children logical plan node + * @param qualifier Qualifier for the attributes of generator(UDTF) + * @param attributeNames the column names for the generator(UDTF), will be _c0, _c1 .. _cN if + * leave as default (empty) + * @param gOutput The output of Generator. */ case class Generate( generator: Generator, join: Boolean, outer: Boolean, - alias: Option[String], - child: LogicalPlan) + child: LogicalPlan, + qualifier: Option[String] = None, + attributeNames: Seq[String] = Nil, + gOutput: Seq[Attribute] = Nil) extends UnaryNode { - protected def generatorOutput: Seq[Attribute] = { - val output = alias - .map(a => generator.output.map(_.withQualifiers(a :: Nil))) - .getOrElse(generator.output) - if (join && outer) { - output.map(_.withNullability(true)) - } else { - output - } + override lazy val resolved: Boolean = { + generator.resolved && + childrenResolved && + attributeNames.length > 0 && + gOutput.map(_.name) == attributeNames } - override def output: Seq[Attribute] = - if (join) child.output ++ generatorOutput else generatorOutput + // we don't want the gOutput to be taken as part of the expressions + // as that will cause exceptions like unresolved attributes etc. + override def expressions: Seq[Expression] = generator :: Nil + + def output: Seq[Attribute] = { + if (join) child.output ++ gOutput else gOutput + } } case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 6e3d6b9263e86..9911c8ecb5966 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -92,7 +92,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter { assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved) - val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)()) + val explode = Explode(AttributeReference("a", IntegerType, nullable = true)()) assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved) assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index 1448098c770aa..45cf695d20b01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest { test("generate: predicate referenced no generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), true, false, Some("arr")) .where(('b >= 5) && ('a > 6)) } val optimized = Optimize(originalQuery.analyze) val correctAnswer = { testRelationWithArrayType .where(('b >= 5) && ('a > 6)) - .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze + .generate(Explode('c_arr), true, false, Some("arr")).analyze } comparePlans(optimized, correctAnswer) } test("generate: part of conjuncts referenced generated column") { - val generator = Explode(Seq("c"), 'c_arr) + val generator = Explode('c_arr) val originalQuery = { testRelationWithArrayType .generate(generator, true, false, Some("arr")) @@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest { test("generate: all conjuncts referenced generated column") { val originalQuery = { testRelationWithArrayType - .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")) + .generate(Explode('c_arr), true, false, Some("arr")) .where(('c > 6) || ('b > 5)).analyze } val optimized = Optimize(originalQuery) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 94ae2d65fd0e4..858b126f8df32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -711,12 +711,15 @@ class DataFrame private[sql]( */ def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val attributes = schema.toAttributes + + val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) } + val names = schema.toAttributes.map(_.name) + val rowFunction = f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row])) - val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) + val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) - Generate(generator, join = true, outer = false, None, logicalPlan) + Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil) } /** @@ -733,12 +736,16 @@ class DataFrame private[sql]( : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil + // TODO handle the metadata? + val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) } + val names = attributes.map(_.name) + def rowFunction(row: Row): TraversableOnce[Row] = { f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType))) } - val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) + val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) - Generate(generator, join = true, outer = false, None, logicalPlan) + Generate(generator, join = true, outer = false, logicalPlan, qualifier = None, names, Nil) } ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 12271048bb39c..98c0d4c9b5cf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -31,40 +31,29 @@ import org.apache.spark.sql.catalyst.expressions._ * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. + * @param output the output attributes of this node, which constructed in analysis phase, + * and we can not change it, as the parent node bound with it already. */ @DeveloperApi case class Generate( generator: Generator, join: Boolean, outer: Boolean, - child: SparkPlan) + child: SparkPlan, + output: Seq[Attribute]) extends UnaryNode { - // This must be a val since the generator output expr ids are not preserved by serialization. - protected val generatorOutput: Seq[Attribute] = { - if (join && outer) { - generator.output.map(_.withNullability(true)) - } else { - generator.output - } - } - - // This must be a val since the generator output expr ids are not preserved by serialization. - override val output = - if (join) child.output ++ generatorOutput else generatorOutput - val boundGenerator = BindReferences.bindReference(generator, child.output) override def execute(): RDD[Row] = { if (join) { child.execute().mapPartitions { iter => - val nullValues = Seq.fill(generator.output.size)(Literal(null)) + val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null)) // Used to produce rows with no matches when outer = true. val outerProjection = newProjection(child.output ++ nullValues, child.output) - val joinProjection = - newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput) + val joinProjection = newProjection(output, output) val joinedRow = new JoinedRow iter.flatMap {row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5b99e40c2f491..fee3008d18050 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -304,8 +304,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Except(planLater(left), planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil - case logical.Generate(generator, join, outer, _, child) => - execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil + case g @ logical.Generate(generator, join, outer, child, _, _, _) => + execution.Generate( + generator, join = join, outer = outer, planLater(child), g.output) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 7c6a7df2bd01e..c4a73b3004076 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -249,7 +249,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.CreateTables :: catalog.PreInsertionCasts :: ExtractPythonUdfs :: - ResolveUdtfsAlias :: sources.PreInsertCastAndRename :: Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 53a204b8c2932..e5ada7f5178e9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -725,12 +725,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText - Generate( - nodesToGenerator(clauses), - join = true, - outer = false, - Some(alias.toLowerCase), - withWhere) + val (generator, attributes) = nodesToGenerator(clauses) + Generate( + generator, + join = true, + outer = false, + withWhere, + Some(alias.toLowerCase), + attributes, + Nil) }.getOrElse(withWhere) // The projection of the query can either be a normal projection, an aggregation @@ -833,12 +836,15 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText - Generate( - nodesToGenerator(clauses), - join = true, - outer = isOuter.nonEmpty, - Some(alias.toLowerCase), - nodeToRelation(relationClause)) + val (generator, attributes) = nodesToGenerator(clauses) + Generate( + generator, + join = true, + outer = isOuter.nonEmpty, + nodeToRelation(relationClause), + Some(alias.toLowerCase), + attributes, + Nil) /* All relations, possibly with aliases or sampling clauses. */ case Token("TOK_TABREF", clauses) => @@ -1311,7 +1317,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C val explode = "(?i)explode".r - def nodesToGenerator(nodes: Seq[Node]): Generator = { + def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = { val function = nodes.head val attributes = nodes.flatMap { @@ -1321,7 +1327,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C function match { case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) => - Explode(attributes, nodeToExpr(child)) + (Explode(nodeToExpr(child)), attributes) case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => val functionInfo: FunctionInfo = @@ -1329,10 +1335,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - HiveGenericUdtf( + (HiveGenericUdtf( new HiveFunctionWrapper(functionClassName), - attributes, - children.map(nodeToExpr)) + children.map(nodeToExpr)), attributes) case a: ASTNode => throw new NotImplementedError( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index 47305571e579e..4b6f0ad75f54f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -66,7 +66,7 @@ private[hive] abstract class HiveFunctionRegistry } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { HiveUdaf(new HiveFunctionWrapper(functionClassName), children) } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children) + HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) } else { sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") } @@ -266,7 +266,6 @@ private[hive] case class HiveUdaf( */ private[hive] case class HiveGenericUdtf( funcWrapper: HiveFunctionWrapper, - aliasNames: Seq[String], children: Seq[Expression]) extends Generator with HiveInspectors { @@ -282,23 +281,8 @@ private[hive] case class HiveGenericUdtf( @transient protected lazy val udtInput = new Array[AnyRef](children.length) - protected lazy val outputDataTypes = outputInspector.getAllStructFieldRefs.map { - field => inspectorToDataType(field.getFieldObjectInspector) - } - - override protected def makeOutput() = { - // Use column names when given, otherwise _c1, _c2, ... _cn. - if (aliasNames.size == outputDataTypes.size) { - aliasNames.zip(outputDataTypes).map { - case (attrName, attrDataType) => - AttributeReference(attrName, attrDataType, nullable = true)() - } - } else { - outputDataTypes.zipWithIndex.map { - case (attrDataType, i) => - AttributeReference(s"_c$i", attrDataType, nullable = true)() - } - } + lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { + field => (inspectorToDataType(field.getFieldObjectInspector), true) } override def eval(input: Row): TraversableOnce[Row] = { @@ -333,22 +317,6 @@ private[hive] case class HiveGenericUdtf( } } -/** - * Resolve Udtfs Alias. - */ -private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case p @ Project(projectList, _) - if projectList.exists(_.isInstanceOf[MultiAlias]) && projectList.size != 1 => - throw new TreeNodeException(p, "only single Generator supported for SELECT clause") - - case Project(Seq(Alias(udtf @ HiveGenericUdtf(_, _, _), name)), child) => - Generate(udtf.copy(aliasNames = Seq(name)), join = false, outer = false, None, child) - case Project(Seq(MultiAlias(udtf @ HiveGenericUdtf(_, _, _), names)), child) => - Generate(udtf.copy(aliasNames = names), join = false, outer = false, None, child) - } -} - private[hive] case class HiveUdafFunction( funcWrapper: HiveFunctionWrapper, exprs: Seq[Expression], diff --git a/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 new file mode 100644 index 0000000000000..d00491fd7e5bb --- /dev/null +++ b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 @@ -0,0 +1 @@ +1 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 b/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 new file mode 100644 index 0000000000000..01e79c32a8c99 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 new file mode 100644 index 0000000000000..0c7520f2090dd --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 @@ -0,0 +1,3 @@ +86 val_86 +238 val_238 +311 val_311 diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 b/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 new file mode 100644 index 0000000000000..01e79c32a8c99 --- /dev/null +++ b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 @@ -0,0 +1,3 @@ +1 +2 +3 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 300b1f7920473..ac10b173307d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -27,7 +27,7 @@ import scala.util.Try import org.apache.hadoop.hive.conf.HiveConf.ConfVars import org.apache.spark.{SparkFiles, SparkException} -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive._ @@ -67,6 +67,40 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { } } + createQueryTest("insert table with generator with column name", + """ + | CREATE TABLE gen_tmp (key Int); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(array(1,2,3)) AS val FROM src LIMIT 3; + | SELECT key FROM gen_tmp ORDER BY key ASC; + """.stripMargin) + + createQueryTest("insert table with generator with multiple column names", + """ + | CREATE TABLE gen_tmp (key Int, value String); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(map(key, value)) as (k1, k2) FROM src LIMIT 3; + | SELECT key, value FROM gen_tmp ORDER BY key, value ASC; + """.stripMargin) + + createQueryTest("insert table with generator without column name", + """ + | CREATE TABLE gen_tmp (key Int); + | INSERT OVERWRITE TABLE gen_tmp + | SELECT explode(array(1,2,3)) FROM src LIMIT 3; + | SELECT key FROM gen_tmp ORDER BY key ASC; + """.stripMargin) + + test("multiple generator in projection") { + intercept[AnalysisException] { + sql("SELECT explode(map(key, value)), key FROM src").collect() + } + + intercept[AnalysisException] { + sql("SELECT explode(map(key, value)) as k1, k2, key FROM src").collect() + } + } + createQueryTest("! operator", """ |SELECT a FROM ( @@ -456,7 +490,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view2", "SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl") - createQueryTest("lateral view3", "FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX") @@ -478,6 +511,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("lateral view6", "SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v") + createQueryTest("Specify the udtf output", + "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t") + test("sampling") { sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s") sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s")