Skip to content

Commit

Permalink
refine
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Jun 21, 2015
1 parent 9f07359 commit d18f401
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 41 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -135,22 +135,20 @@ class Analyzer(
* Replaces [[UnresolvedAlias]]s with concrete aliases.
*/
object ResolveAliases extends Rule[LogicalPlan] {
private def assignAliases(exprs: Seq[Expression]) = {
var i = -1
private def assignAliases(exprs: Seq[NamedExpression]) = {
// The `UnresolvedAlias`s appear only at the root of an expression tree, so we don't
// need to transform down the whole tree.
exprs.map {
case u @ UnresolvedAlias(child) =>
exprs.zipWithIndex.map {
case (u @ UnresolvedAlias(child), i) =>
child match {
case _: UnresolvedAttribute => u
case ne: NamedExpression => ne
case ev: ExtractValueWithStruct => Alias(ev, ev.field.name)()
case g: Generator if g.resolved && g.elementTypes.size > 1 => MultiAlias(g, Nil)
case e if !e.resolved => u
case other =>
i += 1
Alias(other, s"c$i")()
case other => Alias(other, s"_c$i")()
}
case (other, _) => other
}
}

Expand Down Expand Up @@ -611,7 +609,7 @@ class Analyzer(
/** Extracts a [[Generator]] expression and any names assigned by aliases to their output. */
private object AliasedGenerator {
def unapply(e: Expression): Option[(Generator, Seq[String])] = e match {
case Alias(g: Generator, name) if g.elementTypes.size > 1 =>
case Alias(g: Generator, name) if g.resolved && g.elementTypes.size > 1 =>
// Only a single alias name was given, but the TGF produces multiple output columns.
failAnalysis(
s"""Expect multiple names given for ${g.getClass.getName},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,6 @@ abstract class ExtractValueWithStruct extends ExtractValue {
self: Product =>

def field: StructField
override def foldable: Boolean = child.foldable
override def toString: String = s"$child.${field.name}"
}

Expand Down Expand Up @@ -127,7 +126,6 @@ case class GetArrayStructFields(
containsNull: Boolean) extends ExtractValueWithStruct {

override def dataType: DataType = ArrayType(field.dataType, containsNull)
override def nullable: Boolean = child.nullable

override def eval(input: InternalRow): Any = {
val baseValue = child.eval(input).asInstanceOf[Seq[InternalRow]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,10 +156,9 @@ object PartialAggregation {
partialEvaluations(new TreeNodeRef(e)).finalEvaluation

case e: Expression =>
namedGroupingExpressions
.find { case (k, v) => k semanticEquals e }
.map(_._2.toAttribute)
.getOrElse(e)
namedGroupingExpressions.collectFirst {
case (expr, ne) if expr semanticEquals e => ne.toAttribute
}.getOrElse(e)
}).asInstanceOf[Seq[NamedExpression]]

val partialComputation = namedGroupingExpressions.map(_._2) ++
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -252,14 +252,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
// One match, but we also need to extract the requested nested field.
case Seq((a, nestedFields)) =>
// The foldLeft adds ExtractValues for every remaining parts of the identifier,
// and aliases it with the last part of the identifier.
// and wraps it with UnresolvedAlias, which will be removed later.
// For example, consider "a.b.c", where "a" is resolved to an existing attribute.
// Then this will add ExtractValue("c", ExtractValue("b", a)), and alias
// the final expression as "c".
// Then this will add ExtractValue("c", ExtractValue("b", a)), and wrap it as
// UnresolvedAlias(ExtractValue("c", ExtractValue("b", a))).
val fieldExprs = nestedFields.foldLeft(a: Expression)((expr, fieldName) =>
ExtractValue(expr, Literal(fieldName), resolver))
val aliasName = nestedFields.last
Some(Alias(fieldExprs, aliasName)())
Some(UnresolvedAlias(fieldExprs))

// No matches.
case Seq() =>
Expand Down
37 changes: 15 additions & 22 deletions sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
Original file line number Diff line number Diff line change
Expand Up @@ -70,27 +70,27 @@ class GroupedData protected[sql](
groupingExprs: Seq[Expression],
private val groupType: GroupedData.GroupType) {

private[this] def toDF(aggExprs: Seq[NamedExpression]): DataFrame = {
private[this] def toDF(aggExprs: Seq[Expression]): DataFrame = {
val aggregates = if (df.sqlContext.conf.dataFrameRetainGroupColumns) {
val retainedExprs = groupingExprs.map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
}
retainedExprs ++ aggExprs
} else {
aggExprs
}
groupingExprs ++ aggExprs
} else {
aggExprs
}

val aliasedAgg = aggregates.map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
}
groupType match {
case GroupedData.GroupByType =>
DataFrame(
df.sqlContext, Aggregate(groupingExprs, aggregates, df.logicalPlan))
df.sqlContext, Aggregate(groupingExprs, aliasedAgg, df.logicalPlan))
case GroupedData.RollupType =>
DataFrame(
df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aggregates))
df.sqlContext, Rollup(groupingExprs, df.logicalPlan, aliasedAgg))
case GroupedData.CubeType =>
DataFrame(
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aggregates))
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
}
}

Expand All @@ -112,10 +112,7 @@ class GroupedData protected[sql](
namedExpr
}
}
toDF(columnExprs.map { c =>
val a = f(c)
Alias(a, a.prettyString)()
})
toDF(columnExprs.map(f))
}

private[this] def strToExpr(expr: String): (Expression => Expression) = {
Expand Down Expand Up @@ -169,8 +166,7 @@ class GroupedData protected[sql](
*/
def agg(exprs: Map[String, String]): DataFrame = {
toDF(exprs.map { case (colName, expr) =>
val a = strToExpr(expr)(df(colName).expr)
Alias(a, a.prettyString)()
strToExpr(expr)(df(colName).expr)
}.toSeq)
}

Expand Down Expand Up @@ -224,10 +220,7 @@ class GroupedData protected[sql](
*/
@scala.annotation.varargs
def agg(expr: Column, exprs: Column*): DataFrame = {
toDF((expr +: exprs).map(_.expr).map {
case expr: NamedExpression => expr
case expr: Expression => Alias(expr, expr.prettyString)()
})
toDF((expr +: exprs).map(_.expr))
}

/**
Expand Down

0 comments on commit d18f401

Please sign in to comment.