diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 19c6dc4d9a363..7b1d6f8528ca1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -135,22 +135,20 @@ class Analyzer(
    * Replaces [[UnresolvedAlias]]s with concrete aliases.
    */
   object ResolveAliases extends Rule[LogicalPlan] {
-    private def assignAliases(exprs: Seq[Expression]) = {
-      var i = -1
+    private def assignAliases(exprs: Seq[NamedExpression]) = {
       // The `UnresolvedAlias`s will appear only at root of a expression tree, we don't need
       // to transform down the whole tree.
-      exprs.map {
-        case u @ UnresolvedAlias(child) =>
+      exprs.zipWithIndex.map {
+        case (u @ UnresolvedAlias(child), i) =>
           child match {
             case _: UnresolvedAttribute => u
             case ne: NamedExpression => ne
             case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)()
             case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
             case e if !e.resolved => u
-            case other =>
-              i += 1
-              Alias(other, s"c$i")()
+            case other => Alias(other, s"_c$i")()
           }
+        case (other, _) => other
       }
     }
 
@@ -611,7 +609,7 @@ class Analyzer(
     /** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
     private object AliasedGenerator {
       def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
-        case Alias(g: Generator, name) if g.elementTypes.size > 1 =>
+        case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 =>
           // If not given the default names, and the TGF with multiple output columns
           failAnalysis(
             s"""Expect multiple names given for ${g.getClass.getName},
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
index 741b0c3870603..f65a107924ec5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala
@@ -98,7 +98,6 @@ abstract class ExtractValueWithStruct extends ExtractValue {
   self: Product =>
 
   def field: StructField
-  override def foldable: Boolean = child.foldable
   override def toString: String = s"$child.${field.name}"
 }
 
@@ -127,7 +126,6 @@ case class GetArrayStructFields(
     containsNull: Boolean) extends ExtractValueWithStruct {
 
   override def dataType: DataType = ArrayType(field.dataType, containsNull)
-  override def nullable: Boolean = child.nullable
 
   override def eval(input: InternalRow): Any = {
     val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 8d79585e65a36..179a348d5baac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -156,10 +156,9 @@ object PartialAggregation {
             partialEvaluations(new TreeNodeRef(e)).finalEvaluation
 
           case e: Expression =>
-            namedGroupingExpressions
-              .find { case (k, v) => k semanticEquals e }
-              .map(_._2.toAttribute)
-              .getOrElse(e)
+            namedGroupingExpressions.collectFirst {
+              case (expr, ne) if expr semanticEquals e => ne.toAttribute
+            }.getOrElse(e)
         }).asInstanceOf[Seq[NamedExpression]]
 
       val partialComputation = namedGroupingExpressions.map(_._2) ++
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 73e8a0c8effb7..b009a200b920f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -252,14 +252,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
       // One match, but we also need to extract the requested nested field.
       case Seq((a, nestedFields)) =>
         // The foldLeft adds ExtractValues for every remaining parts of the identifier,
-        // and aliases it with the last part of the identifier.
+        // and wrap it with UnresolvedAlias which will be removed later.
        // For example, consider "a.b.c", where "a" is resolved to an existing attribute.
-        // Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
-        // the final expression as "c".
+        // Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
+        // UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
         val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
           ExtractValue(expr, Literal(fieldName), resolver))
-        val aliasName = nestedFields.last
-        Some(Alias(fieldExprs, aliasName)())
+        Some(UnresolvedAlias(fieldExprs))
 
       // No matches.
       case Seq() =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
index 45b3e1bc627d5..859224d263ec2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -70,27 +70,27 @@ class GroupedData protected[sql](
     groupingExprs: Seq[Expression],
     private val groupType: GroupedData.GroupType) {
 
-  private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
+  private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
     val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
-      val retainedExprs = groupingExprs.map {
-        case expr: NamedExpression => expr
-        case expr: Expression => Alias(expr, expr.prettyString)()
-      }
-      retainedExprs ++ aggExprs
-    } else {
-      aggExprs
-    }
+        groupingExprs ++ aggExprs
+      } else {
+        aggExprs
+      }
+    val aliasedAgg = aggregates.map {
+      case expr: NamedExpression => expr
+      case expr: Expression => Alias(expr, expr.prettyString)()
+    }
 
     groupType match {
       case GroupedData.GroupByType =>
         DataFrame(
-          df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
+          df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
       case GroupedData.RollupType =>
         DataFrame(
-          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
+          df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
       case GroupedData.CubeType =>
         DataFrame(
-          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
+          df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
     }
   }
 
@@ -112,10 +112,7 @@ class GroupedData protected[sql](
         namedExpr
       }
     }
-    toDF(columnExprs.map { c =>
-      val a = f(c)
-      Alias(a, a.prettyString)()
-    })
+    toDF(columnExprs.map(f))
   }
 
   private[this] def strToExpr(expr: String): (Expression => Expression) = {
@@ -169,8 +166,7 @@ class GroupedData protected[sql](
    */
   def agg(exprs: Map[String, String]): DataFrame = {
     toDF(exprs.map { case (colName, expr) =>
-      val a = strToExpr(expr)(df(colName).expr)
-      Alias(a, a.prettyString)()
+      strToExpr(expr)(df(colName).expr)
     }.toSeq)
   }
 
@@ -224,10 +220,7 @@
    */
   @scala.annotation.varargs
   def agg(expr: Column, exprs: Column*): DataFrame = {
-    toDF((expr +: exprs).map(_.expr).map {
-      case expr: NamedExpression => expr
-      case expr: Expression => Alias(expr, expr.prettyString)()
-    })
+    toDF((expr +: exprs).map(_.expr))
   }
 
   /**
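The user-visible core of the ResolveAliases change is how unnamed output columns get their
names: the old code advanced a mutable counter only when it saw an unnamed expression
(producing c0, c1, ...), while the new code derives the name from the expression's position
in the projection list via zipWithIndex and uses the "_c" prefix. Below is a minimal,
self-contained sketch of that positional scheme; AliasSketch, Expr, Named and Unnamed are
hypothetical stand-ins for Catalyst's Expression hierarchy, not types from this patch.

object AliasSketch {
  sealed trait Expr
  case class Named(name: String) extends Expr    // stands in for a NamedExpression
  case class Unnamed(text: String) extends Expr  // stands in for any other Expression

  // Positional naming as in the new assignAliases: the unnamed expression
  // occupying slot i is named "_c<i>", regardless of what surrounds it.
  def assignAliases(exprs: Seq[Expr]): Seq[String] =
    exprs.zipWithIndex.map {
      case (Named(n), _)   => n
      case (Unnamed(_), i) => s"_c$i"
    }

  def main(args: Array[String]): Unit = {
    println(assignAliases(Seq(Named("key"), Unnamed("a + b"), Named("total"), Unnamed("a * 2"))))
    // prints: List(key, _c1, total, _c3)
  }
}

Under the old counter the two unnamed expressions above would have been named c0 and c1;
with zipWithIndex the name encodes the column's slot, so it is independent of how many of
the preceding expressions already carry explicit names.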