add bucketed_sum aggregation (#1168)
jaymo001 committed May 15, 2023
1 parent 5c02d26 commit 6ce855c
Showing 14 changed files with 116 additions and 101 deletions.
@@ -22,6 +22,8 @@ private[offline] object WindowTimeUnit extends Enumeration {
case H => Duration.ofHours(timeWindowStr.dropRight(1).trim.toLong)
case M => Duration.ofMinutes(timeWindowStr.dropRight(1).trim.toLong)
case S => Duration.ofSeconds(timeWindowStr.dropRight(1).trim.toLong)
case Y => Duration.ofDays(365 * timeWindowStr.dropRight(1).trim.toLong) // a year is approximated as 365 days
case W => Duration.ofDays(7 * timeWindowStr.dropRight(1).trim.toLong)
case _ => Duration.ofSeconds(0)
}
} catch {
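As a quick illustration of the two new units, here is a minimal, self-contained sketch (the parse helper name and the D case are assumptions; WindowTimeUnit's real entry point is outside this hunk). A year is approximated as 365 days and a week as 7:

import java.time.Duration

// hypothetical stand-alone parser mirroring the match above
object WindowTimeUnitSketch {
  def parse(timeWindowStr: String): Duration = {
    val amount = timeWindowStr.dropRight(1).trim.toLong
    timeWindowStr.last.toUpper match {
      case 'D' => Duration.ofDays(amount)    // assumed case, hidden by the collapsed context
      case 'H' => Duration.ofHours(amount)
      case 'M' => Duration.ofMinutes(amount)
      case 'S' => Duration.ofSeconds(amount)
      case 'Y' => Duration.ofDays(365 * amount) // leap years ignored
      case 'W' => Duration.ofDays(7 * amount)
      case _   => Duration.ofSeconds(0)
    }
  }
}

// e.g. WindowTimeUnitSketch.parse("2w") == Duration.ofDays(14)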
@@ -4,7 +4,7 @@ import com.fasterxml.jackson.annotation.JsonProperty
import com.linkedin.feathr.common.exception.{ErrorLabel, FeathrConfigException}
import com.linkedin.feathr.offline.config.{ComplexAggregationFeature, TimeWindowFeatureDefinition}
import com.linkedin.feathr.offline.generation.aggregations._
import com.linkedin.feathr.offline.swa.SlidingWindowFeatureUtils.convertFeathrDefToSwjDef
import com.linkedin.feathr.offline.swa.SlidingWindowFeatureUtils.{convertFeathrDefToSwjDef, isBucketedFunction}
import com.linkedin.feathr.sparkcommon.SimpleAnchorExtractorSpark
import com.linkedin.feathr.swj.aggregate.AggregationType
import com.typesafe.config.ConfigFactory
@@ -61,7 +61,7 @@ private[offline] class TimeWindowConfigurableAnchorExtractor(@JsonProperty("feat
*/
override def aggregateAsColumns(groupedDataFrame: DataFrame): Seq[(String, Column)] = {
val columnPairs = aggFeatures.collect {
case (featureName, featureDef) if !featureDef.timeWindowFeatureDefinition.aggregationType.toString.startsWith("BUCKETED_") =>
case (featureName, featureDef) if !isBucketedFunction(featureDef.timeWindowFeatureDefinition.aggregationType) =>
// for basic sliding window aggregation
// no complex aggregation will be defined
if (featureDef.swaFeature.lateralView.isDefined) {
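The string-prefix test above is replaced by the shared isBucketedFunction helper. A stand-alone sketch of that check follows; AggregationTypeSketch below is a stand-in, not the real com.linkedin.feathr.swj.aggregate.AggregationType enumeration:

// bucketed aggregations are recognized purely by the BUCKETED prefix of the enum name
object AggregationTypeSketch extends Enumeration {
  val SUM, COUNT, BUCKETED_SUM, BUCKETED_COUNT_DISTINCT = Value
}

object BucketedCheckSketch {
  def isBucketedFunction(agg: AggregationTypeSketch.Value): Boolean =
    agg.toString.startsWith("BUCKETED")

  def main(args: Array[String]): Unit = {
    // split an anchor's aggregation types the way aggregateAsColumns filters them
    val (bucketed, plain) = AggregationTypeSketch.values.toSeq.partition(isBucketedFunction)
    println(s"bucketed: $bucketed, plain: $plain")
  }
}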
@@ -5,6 +5,7 @@ import com.linkedin.feathr.offline.{FeatureDataFrame, FeatureDataWithJoinKeys}
import com.linkedin.feathr.offline.client.DataFrameColName
import com.linkedin.feathr.offline.job.FeatureTransformation
import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter
import com.linkedin.feathr.offline.util.FeathrUtils
import org.apache.spark.sql.SparkSession

/**
@@ -43,7 +44,12 @@ private[offline] class FeatureGenDefaultsSubstituter() {
withDefaultDF,
featuresWithKeys.keys.map(FeatureTransformation.FEATURE_NAME_PREFIX + DataFrameColName.getEncodedFeatureRefStrForColName(_)).toSeq)
// If there are multiple rows with the same join key, optionally keep a single record for the duplicates (same behavior as the feature-join API)
val withoutDupDF = withNullsDroppedDF.dropDuplicates(joinKeys)
val dropDuplicate = FeathrUtils.getFeathrJobParam(ss, FeathrUtils.DROP_DUPLICATE_ROWS_FOR_KEYS_IN_FEATURE_GENERATION).toBoolean
val withoutDupDF = if (dropDuplicate) {
withNullsDroppedDF.dropDuplicates(joinKeys)
} else {
withNullsDroppedDF
}
// Return features processed in this iteration
featuresWithKeys.map(f => (f._1, (FeatureDataFrame(withoutDupDF, inferredTypeConfig), joinKeys)))
}
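Deduplication on the join keys is now gated by a job parameter. The sketch below shows the shape of that change with a local DataFrame; the hard-coded flag stands in for the value read through FeathrUtils.getFeathrJobParam:

import org.apache.spark.sql.SparkSession

object DedupFlagSketch {
  def main(args: Array[String]): Unit = {
    val ss = SparkSession.builder().master("local[1]").appName("dedup-sketch").getOrCreate()
    import ss.implicits._
    val df = Seq(("k1", 1.0), ("k1", 2.0), ("k2", 3.0)).toDF("joinKey", "feature")
    // stands in for FeathrUtils.getFeathrJobParam(ss, DROP_DUPLICATE_ROWS_FOR_KEYS_IN_FEATURE_GENERATION).toBoolean
    val dropDuplicate = true
    val withoutDupDF = if (dropDuplicate) df.dropDuplicates(Seq("joinKey")) else df
    withoutDupDF.show() // with the flag on, a single row per joinKey survives
    ss.stop()
  }
}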
@@ -5,7 +5,10 @@ import com.linkedin.feathr.common.{DateTimeParam, DateTimeUtils, JoiningFeatureP
import com.linkedin.feathr.offline.anchored.anchorExtractor.TimeWindowConfigurableAnchorExtractor
import com.linkedin.feathr.offline.job.FeatureGenSpec
import com.linkedin.feathr.offline.logical.FeatureGroups
import com.linkedin.feathr.offline.util.FeathrUtils
import org.apache.spark.sql.SparkSession

import java.time.Duration
import scala.annotation.tailrec
import scala.collection.convert.wrapAll._

@@ -100,15 +103,19 @@ private[offline] object FeatureGenKeyTagAnalyzer extends FeatureGenKeyTagAnalyzer {
featureGenSpec: FeatureGenSpec,
featureGroups: FeatureGroups): Seq[JoiningFeatureParams] = {
val refTime = featureGenSpec.dateTimeParam
val ss = SparkSession.builder().getOrCreate()
val expand_days = FeathrUtils.getFeathrJobParam(ss, FeathrUtils.EXPAND_DAYS_IN_FEATURE_GENERATION_CUTOFF_TIME).toInt
taggedFeature.map(f => {
val featureName = f.getFeatureName
val featureAnchorWithSource = featureGroups.allAnchoredFeatures(featureName)
val dateParam = featureAnchorWithSource.featureAnchor.extractor match {
case extractor: TimeWindowConfigurableAnchorExtractor =>
val aggFeature = extractor.features(featureName)
val dateTimeParam = DateTimeParam.shiftStartTime(refTime, aggFeature.window)
DateTimeUtils.toDateParam(dateTimeParam)
case _ =>
val dateTimeShifted = DateTimeParam.shiftStartTime(refTime, aggFeature.window)
val dateTimeParamExpandStart = DateTimeParam.shiftStartTime(dateTimeShifted, Duration.ofDays(expand_days*2))
val dateTimeParamExpandEnd = DateTimeParam.shiftEndTime(dateTimeParamExpandStart, Duration.ofDays(expand_days).negated())
DateTimeUtils.toDateParam(dateTimeParamExpandEnd)
case _ =>
featureGenSpec.dateParam
}
new JoiningFeatureParams(f.getKeyTag, f.getFeatureName, Option(dateParam))
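The net effect of the new shift arithmetic can be traced with plain java.time, assuming shiftStartTime moves the interval start earlier by the given duration and that passing a negated duration to shiftEndTime moves the end later; DateTimeParam itself is not reproduced here:

import java.time.{Duration, LocalDateTime}

object ExpandCutoffSketch {
  def main(args: Array[String]): Unit = {
    val expandDays = 1                  // stands in for EXPAND_DAYS_IN_FEATURE_GENERATION_CUTOFF_TIME
    val window     = Duration.ofDays(7) // the feature's aggregation window
    val refTime    = LocalDateTime.of(2023, 5, 15, 0, 0)

    // start moves back by the window, then by 2 * expandDays more;
    // the end moves forward by expandDays (the negated shift)
    val start = refTime.minus(window).minus(Duration.ofDays(2L * expandDays))
    val end   = refTime.plus(Duration.ofDays(expandDays.toLong))
    println(s"read source partitions from $start to $end")
  }
}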
@@ -15,6 +15,7 @@ import com.linkedin.feathr.offline.join.DataFrameKeyCombiner
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.accessor.{DataSourceAccessor, NonTimeBasedDataSourceAccessor, TimeBasedDataSourceAccessor}
import com.linkedin.feathr.offline.swa.SlidingWindowFeatureUtils
import com.linkedin.feathr.offline.swa.SlidingWindowFeatureUtils.isBucketedFunction
import com.linkedin.feathr.offline.transformation.FeatureColumnFormat.FeatureColumnFormat
import com.linkedin.feathr.offline.transformation._
import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils.tensorTypeToDataFrameSchema
@@ -207,7 +208,7 @@ private[offline] object FeatureTransformation {
val featureTypeConfigs = featureAnchorWithSource.featureAnchor.featureTypeConfigs
val transformedFeatureData: TransformedResult = featureAnchorWithSource.featureAnchor.extractor match {
case transformer: TimeWindowConfigurableAnchorExtractor =>
val nonBucketedFeatures = transformer.features.map(_._2.aggregationType).filter(agg => agg == AggregationType.BUCKETED_COUNT_DISTINCT)
val bucketedFeatures = transformer.features.map(_._2.aggregationType).filter(agg => isBucketedFunction(agg))
if (!(bucketedFeatures.size != transformer.features.size || transformer.features.isEmpty)) {
throw new FeathrFeatureTransformationException(
ErrorLabel.FEATHR_USER_ERROR,
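One plausible reading of the truncated guard above is that an anchor consisting solely of bucketed aggregations is rejected on this code path. A heavily simplified sketch of that reading, with require standing in for FeathrFeatureTransformationException (whose message is cut off in this diff):

object BucketedGuardSketch {
  def validateAggregations(aggTypeNames: Seq[String]): Unit = {
    val bucketed = aggTypeNames.filter(_.startsWith("BUCKETED"))
    // throws when every aggregation is bucketed and the anchor is non-empty
    require(bucketed.size != aggTypeNames.size || aggTypeNames.isEmpty,
      s"all ${aggTypeNames.size} aggregations are bucketed, which this code path rejects")
  }
}

// BucketedGuardSketch.validateAggregations(Seq("SUM", "BUCKETED_SUM")) // passes
// BucketedGuardSketch.validateAggregations(Seq("BUCKETED_SUM"))        // throws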
@@ -81,7 +81,9 @@ private[offline] class PathPartitionedTimeSeriesSourceAccessor(
throw new FeathrInputDataException(
ErrorLabel.FEATHR_USER_ERROR,
s"Trying to create TimeSeriesSource but no data " +
s"is found to create source data. Source path: ${source.path}, source type: ${source.sourceType}")
s"is found to create source data. Source path: ${source.path}, source type: ${source.sourceType}." +
s"Try to get dataframe from interval ${timeIntervalOpt}, " +
s"but source dataset time has interval ${datePartitions.map(_.dateInterval.toString).mkString(",")} ")
}
selectedPartitions
}
@@ -110,7 +110,7 @@ private[offline] class BatchDataLoader(ss: SparkSession, location: DataLocation,
} else {
// Throwing exception to avoid dataLoaderHandler hook exception from being suppressed.
throw new FeathrInputDataException(ErrorLabel.FEATHR_USER_ERROR, s"Failed to load ${dataPath} after ${initialNumOfRetries} retries" +
s" and retry time of ${retryWaitTime}ms.")
s" and retry time of ${retryWaitTime}ms. Error message: ${e.getMessage}")
}
}
}
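The amended message now surfaces the underlying error rather than swallowing it. A generic retry-with-wait sketch of the loader's behavior, with illustrative stand-ins for initialNumOfRetries and retryWaitTime:

object RetrySketch {
  def withRetries[T](retriesLeft: Int, waitMs: Long)(op: () => T): T = {
    val attempt: Either[Exception, T] =
      try Right(op()) catch { case e: Exception => Left(e) }
    attempt match {
      case Right(value) => value
      case Left(_) if retriesLeft > 0 =>
        Thread.sleep(waitMs)
        withRetries(retriesLeft - 1, waitMs)(op)
      case Left(e) =>
        // include the underlying cause in the final failure, as the amended message does
        throw new RuntimeException(s"Failed after retries. Error message: ${e.getMessage}", e)
    }
  }
}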
@@ -13,14 +13,14 @@ import com.linkedin.feathr.offline.transformation.FeatureColumnFormat
import com.linkedin.feathr.offline.transformation.FeatureColumnFormat.FeatureColumnFormat
import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils
import com.linkedin.feathr.offline.util.datetime.{DateTimeInterval, OfflineDateTimeUtils}
import com.linkedin.feathr.swj.aggregate.AggregationType.AggregationType
import com.linkedin.feathr.swj.{FactData, GroupBySpec, LateralViewParams, SlidingWindowFeature, WindowSpec}
import com.linkedin.feathr.swj.aggregate.{AggregationType, AvgAggregate, AvgPoolingAggregate, CountAggregate, CountDistinctAggregate, DummyAggregate, LatestAggregate, MaxAggregate, MaxPoolingAggregate, MinAggregate, MinPoolingAggregate, SumAggregate}
import org.apache.logging.log4j.LogManager
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.util.sketch.BloomFilter

import java.text.SimpleDateFormat
import java.time._

@@ -41,6 +41,14 @@ private[offline] object SlidingWindowFeatureUtils {
val DEFAULT_TIME_DELAY = "Default-time-delay"
val TIMESTAMP_PARTITION_COLUMN = "__feathr_timestamp_column_from_partition"

/**
* Check whether an aggregation function is a bucketed aggregation.
* @param aggregateFunction the aggregation type to check
* @return true if the aggregation type's name starts with "BUCKETED"
*/
def isBucketedFunction(aggregateFunction: AggregationType): Boolean = {
aggregateFunction.toString.startsWith("BUCKETED")
}

/**
* Check if an anchor contains window aggregate features.
* Note: if an anchor contains window aggregate features, it will not contain other non-aggregate features.
@@ -187,6 +195,7 @@ private[offline] object SlidingWindowFeatureUtils {
case AggregationType.MIN_POOLING => new MinPoolingAggregate(featureDef)
case AggregationType.AVG_POOLING => new AvgPoolingAggregate(featureDef)
case AggregationType.BUCKETED_COUNT_DISTINCT => new DummyAggregate(featureDef)
case AggregationType.BUCKETED_SUM => new DummyAggregate(featureDef)
}
swj.SlidingWindowFeature(featureName, aggregationSpec, windowSpec, filter, groupBySpec, lateralViewParams)
}
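BUCKETED_SUM joins BUCKETED_COUNT_DISTINCT in mapping to DummyAggregate, which suggests bucketed results are produced outside the sliding-window-join aggregator itself, with the dummy keeping the SWJ plumbing uniform. A stand-alone sketch of that placeholder pattern; AggregateSketch is hypothetical, not the real swj aggregate interface:

trait AggregateSketch { def agg(acc: Double, value: Double): Double }

// a real aggregate folds each value into the accumulator
final class SumAggregateSketch extends AggregateSketch {
  def agg(acc: Double, value: Double): Double = acc + value
}

// the dummy leaves the accumulator untouched; bucketed results are presumably
// computed by a separate pass after the sliding window join
final class DummyAggregateSketch extends AggregateSketch {
  def agg(acc: Double, value: Double): Double = acc
}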