Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add COUNT_DISTINCT aggregation #594

Merged
merged 4 commits into from
Aug 22, 2022
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import com.linkedin.feathr.offline.transformation.FeatureColumnFormat.FeatureCol
import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils
import com.linkedin.feathr.offline.util.datetime.{DateTimeInterval, OfflineDateTimeUtils}
import com.linkedin.feathr.swj.{FactData, GroupBySpec, LateralViewParams, SlidingWindowFeature, WindowSpec}
import com.linkedin.feathr.swj.aggregate.{AggregationType, AvgAggregate, AvgPoolingAggregate, CountAggregate, LatestAggregate, MaxAggregate, MaxPoolingAggregate, MinAggregate, MinPoolingAggregate, SumAggregate}
import com.linkedin.feathr.swj.aggregate.{AggregationType, AvgAggregate, AvgPoolingAggregate, CountAggregate, CountDistinctAggregate, LatestAggregate, MaxAggregate, MaxPoolingAggregate, MinAggregate, MinPoolingAggregate, SumAggregate}
import org.apache.log4j.Logger
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.UserDefinedFunction
Expand Down Expand Up @@ -178,6 +178,7 @@ private[offline] object SlidingWindowFeatureUtils {
// In Feathr's use case, we want to treat the count aggregation as simple count of non-null items.
val rewrittenDef = s"CASE WHEN ${featureDef} IS NOT NULL THEN 1 ELSE 0 END"
new CountAggregate(rewrittenDef)
case AggregationType.COUNT_DISTINCT => new CountDistinctAggregate(featureDef)
case AggregationType.AVG => new AvgAggregate(featureDef)
case AggregationType.MAX => new MaxAggregate(featureDef)
case AggregationType.MIN => new MinAggregate(featureDef)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ package com.linkedin.feathr.swj.aggregate

object AggregationType extends Enumeration {
type AggregationType = Value
val SUM, COUNT, AVG, MAX, TIMESINCE, LATEST, DUMMY, MIN, MAX_POOLING, MIN_POOLING, AVG_POOLING, SUM_POOLING = Value
val SUM, COUNT, COUNT_DISTINCT, AVG, MAX, TIMESINCE, LATEST, DUMMY, MIN, MAX_POOLING, MIN_POOLING, AVG_POOLING, SUM_POOLING = Value
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package com.linkedin.feathr.swj.aggregate

import com.linkedin.feathr.swj.aggregate.AggregationType._
import org.apache.spark.sql.types._

/**
* COUNT_DISTINCT aggregation implementation.
*
* @param metricCol Name of the metric column or a Spark SQL column expression for derived metric
* that will be aggregated using COUNT_DISTINCT.
*/
class CountDistinctAggregate(val metricCol: String) extends AggregationSpec {
override def aggregation: AggregationType = COUNT_DISTINCT

override def metricName = "count_distinct_col"

override def isIncrementalAgg = false

override def isCalculateAggregateNeeded: Boolean = true

override def calculateAggregate(aggregate: Any, dataType: DataType): Any = {
if (aggregate == null) {
aggregate
} else {
dataType match {
case IntegerType => aggregate.asInstanceOf[Set[Int]].size
case LongType => aggregate.asInstanceOf[Set[Long]].size
case DoubleType => aggregate.asInstanceOf[Set[Double]].size
case FloatType => aggregate.asInstanceOf[Set[Float]].size
case StringType => aggregate.asInstanceOf[Set[String]].size
case _ => throw new RuntimeException(s"Invalid data type for COUNT_DISTINCT metric col $metricCol. " +
s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
}
}
}

override def agg(aggregate: Any, record: Any, dataType: DataType): Any = {
if (aggregate == null) {
Set(record)
} else if (record == null) {
aggregate
} else {
dataType match {
case IntegerType => aggregate.asInstanceOf[Set[Int]] + record.asInstanceOf[Int]
case LongType => aggregate.asInstanceOf[Set[Long]] + record.asInstanceOf[Long]
case DoubleType => aggregate.asInstanceOf[Set[Double]] + record.asInstanceOf[Double]
case FloatType => aggregate.asInstanceOf[Set[Float]] + record.asInstanceOf[Float]
case StringType=> aggregate.asInstanceOf[Set[String]] + record.asInstanceOf[String]
case _ => throw new RuntimeException(s"Invalid data type for COUNT_DISTINCT metric col $metricCol. " +
s"Only Int, Long, Double, Float, and String are supported, but got ${dataType.typeName}")
}
}
}

override def deagg(aggregate: Any, record: Any, dataType: DataType): Any = {
throw new RuntimeException("Method deagg for COUNT_DISTINCT aggregate is not implemented because COUNT_DISTINCT is " +
"not an incremental aggregation.")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -983,4 +983,87 @@ class SlidingWindowAggIntegTest extends FeathrIntegTest {

validateRows(dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField)), expectedRows)
}


@Test
def testSWACountDistinct(): Unit = {
val featureDefAsString =
"""
|sources: {
| swaSource: {
| location: { path: "generation/daily/" }
| isTimeSeries: true
| timeWindowParameters: {
| timestampColumn: "timestamp"
| timestampColumnFormat: "yyyy-MM-dd"
| }
| }
|}
|anchors: {
| swaAnchorWithKeyExtractor: {
| source: "swaSource"
| key: [x]
| features: {
| f: {
| def: "Id" // the column that contains the raw view count
| aggregation: COUNT
| window: 10d
| }
| g: {
| def: "Id" // the column that contains the raw view count
| aggregation: COUNT_DISTINCT
| window: 10d
| }
| }
| }
|}
""".stripMargin

val features = Seq("f", "g")
val keyField = "x"
val featureJoinAsString =
s"""
| settings: {
| joinTimeSettings: {
| timestampColumn: {
| def: timestamp
| format: yyyy-MM-dd
| }
| }
|}
|features: [
| {
| key: [$keyField],
| featureList: [${features.mkString(",")}]
| }
|]
""".stripMargin


/**
* Expected output:
* +--------+----+----+
* |x| f| g|
* +--------+----+----+
* | 1| 6| 2|
* | 2| 5| 2|
* | 3| 1| 1|
* +--------+----+----+
*/
val expectedSchema = StructType(
Seq(
StructField(keyField, LongType),
StructField(features.head, LongType), // f
StructField(features.last, LongType) // g
))

val expectedRows = Array(
new GenericRowWithSchema(Array(1, 6, 2), expectedSchema),
new GenericRowWithSchema(Array(2, 5, 2), expectedSchema),
new GenericRowWithSchema(Array(3, 1, 1), expectedSchema))
val dfs = runLocalFeatureJoinForTest(featureJoinAsString, featureDefAsString, "featuresWithFilterObs.avro.json").data
dfs.show()
jaymo001 marked this conversation as resolved.
Show resolved Hide resolved

validateRows(dfs.select(keyField, features: _*).collect().sortBy(row => row.getAs[Int](keyField)), expectedRows)
}
}