diff --git a/src/main/scala/org/apache/spark/sql/sources/druid/AggregateTransform.scala b/src/main/scala/org/apache/spark/sql/sources/druid/AggregateTransform.scala index cc3ccf1..1ddaba0 100644 --- a/src/main/scala/org/apache/spark/sql/sources/druid/AggregateTransform.scala +++ b/src/main/scala/org/apache/spark/sql/sources/druid/AggregateTransform.scala @@ -324,6 +324,24 @@ trait AggregateTransform { case _ => Seq() } + private def addCountAgg(aggExp: AggregateExpression, + dqb : DruidQueryBuilder) : DruidQueryBuilder = { + val a = dqb.nextAlias + + /* + * if the Druid Index has a count metric then translate to a sum of the count, + * else translate to a count metric. + */ + val druidAggFunc = if (dqb.drInfo.druidDS.metrics.contains("count")) { + "longSum" + } else { + "count" + } + + dqb.aggregate(FunctionAggregationSpec(druidAggFunc, a, "count")). + outputAttribute(a, aggExp, aggExp.dataType, LongType) + } + def aggregateExpression(dqb: DruidQueryBuilder, aggExp: AggregateExpression)( implicit expandOpProjection: Seq[Expression], aEExprIdToPos: Map[ExprId, Int], @@ -332,17 +350,10 @@ trait AggregateTransform { val nativeAgg = new DruidNativeAggregator(dqb, aggExp, expandOpProjection, aEExprIdToPos, aEToLiteralExpr) (aggExp, aggExp.aggregateFunction) match { - case (_, Count(Literal(1, IntegerType) :: Nil)) => { - val a = dqb.nextAlias - Some(dqb.aggregate(FunctionAggregationSpec("longSum", a, "count")). - outputAttribute(a, aggExp, aggExp.dataType, LongType)) + case (_, Count(Literal(1, IntegerType) :: Nil)) | + (_, Count(AttributeReference("1", _, _, _) :: Nil)) => { + Some(addCountAgg(aggExp, dqb)) } - case (_, Count(AttributeReference("1", _, _, _) :: Nil)) => { - val a = dqb.nextAlias - Some(dqb.aggregate(FunctionAggregationSpec("longSum", a, "count")). - outputAttribute(a, aggExp, aggExp.dataType, LongType)) - } - // TODO: // Instead of JS rewriting AVG as Sum, Cnt, Sum/Cnt // the expression should be rewritten generically. Introduce