Seq join bug fix (#1169)
* Seq join bug fix

* Address comments

* version bump

---------

Co-authored-by: Rakesh Kashyap Hanasoge Padmanabha <rkashyap@rkashyap-mn3.linkedin.biz>
rakeshkashyap123 and Rakesh Kashyap Hanasoge Padmanabha committed May 16, 2023
1 parent 6ce855c commit 402efe1
Showing 8 changed files with 49 additions and 13 deletions.
FeathrClient.scala

@@ -83,7 +83,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
def joinFeaturesWithSuppressedExceptions(joinConfig: FeatureJoinConfig, obsData: SparkFeaturizedDataset,
jobContext: JoinJobContext = JoinJobContext()): (SparkFeaturizedDataset, Map[String, String]) = {
(joinFeatures(joinConfig, obsData, jobContext), Map(SuppressedExceptionHandlerUtils.MISSING_DATA_EXCEPTION
-      -> SuppressedExceptionHandlerUtils.missingDataSuppressedExceptionMsgs))
+      -> SuppressedExceptionHandlerUtils.missingFeatures.mkString))
}

/**
@@ -231,7 +231,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:

val (joinedDF, header) = doJoinObsAndFeatures(joinConfig, jobContext, obsData)
(joinedDF, header, Map(SuppressedExceptionHandlerUtils.MISSING_DATA_EXCEPTION
-      -> SuppressedExceptionHandlerUtils.missingDataSuppressedExceptionMsgs))
+      -> SuppressedExceptionHandlerUtils.missingFeatures.mkString))
}

/**
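With this change, the suppressed-exception entry handed back to callers is derived from the shared missingFeatures set rather than an accumulated message string. A minimal, self-contained sketch of that map construction (the constant and feature names are copied from this diff; note that a bare mkString concatenates set elements with no separator, in unspecified order):

import scala.collection.mutable

object MissingFeatureHeaderSketch extends App {
  val MISSING_DATA_EXCEPTION = "missing_data_exception"
  val missingFeatures = mutable.Set("featureWithNull", "aEmbedding")

  // mkString with no separator runs the feature names together.
  val header = Map(MISSING_DATA_EXCEPTION -> missingFeatures.mkString)
  println(header) // e.g. Map(missing_data_exception -> aEmbeddingfeatureWithNull)
}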
DerivedFeatureEvaluator.scala

@@ -10,7 +10,7 @@ import com.linkedin.feathr.offline.join.algorithms.{SequentialJoinConditionBuild
import com.linkedin.feathr.offline.logical.FeatureGroups
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
-import com.linkedin.feathr.offline.util.FeaturizedDatasetUtils
+import com.linkedin.feathr.offline.util.{FeaturizedDatasetUtils, SuppressedExceptionHandlerUtils}
import com.linkedin.feathr.offline.{ErasedEntityTaggedFeature, FeatureDataFrame}
import com.linkedin.feathr.sparkcommon.FeatureDerivationFunctionSpark
import com.linkedin.feathr.{common, offline}
@@ -37,6 +37,9 @@ private[offline] class DerivedFeatureEvaluator(derivationStrategies: DerivationS
def evaluate(keyTag: Seq[Int], keyTagList: Seq[String], contextDF: DataFrame, derivedFeature: DerivedFeature): FeatureDataFrame = {
val tags = Some(keyTag.map(keyTagList).toList)
val producedFeatureColName = DataFrameColName.genFeatureColumnName(derivedFeature.producedFeatureNames.head, tags)
+    if (derivedFeature.consumedFeatureNames.exists(x => SuppressedExceptionHandlerUtils.missingFeatures.contains(x.getFeatureName))) {
+      SuppressedExceptionHandlerUtils.missingFeatures.add(derivedFeature.producedFeatureNames.head)
+    }

derivedFeature.derivation match {
case g: SeqJoinDerivationFunction =>
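The check added to evaluate() propagates missingness through derivations: if any feature a derived feature consumes was already suppressed, the produced feature is recorded as missing too, so later stages (such as the seq join fallback below) can react to it. A minimal sketch of the rule, with a hypothetical Derived case class standing in for Feathr's DerivedFeature:

import scala.collection.mutable

object MissingnessPropagationSketch extends App {
  // Hypothetical stand-in for DerivedFeature: the names it consumes and produces.
  case class Derived(consumedFeatureNames: Seq[String], producedFeatureName: String)

  val missingFeatures = mutable.Set("featureWithNull2")
  val derived = Derived(Seq("featureWithNull2"), "derived_featureWithNull2")

  // If any consumed feature is missing, mark the produced feature missing as well.
  if (derived.consumedFeatureNames.exists(missingFeatures.contains)) {
    missingFeatures += derived.producedFeatureName
  }
  println(missingFeatures) // e.g. HashSet(featureWithNull2, derived_featureWithNull2)
}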
SequentialJoinAsDerivation.scala

@@ -7,6 +7,7 @@ import com.linkedin.feathr.common.{FeatureAggregationType, FeatureValue}
import com.linkedin.feathr.offline.PostTransformationUtil
import com.linkedin.feathr.offline.anchored.feature.FeatureAnchorWithSource
import com.linkedin.feathr.offline.client.DataFrameColName
+import com.linkedin.feathr.offline.client.DataFrameColName.genFeatureColumnName
import com.linkedin.feathr.offline.derived.DerivedFeature
import com.linkedin.feathr.offline.derived.functions.SeqJoinDerivationFunction
import com.linkedin.feathr.offline.derived.strategies.SequentialJoinAsDerivation._
@@ -18,7 +19,7 @@ import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter.substituteDefaults
import com.linkedin.feathr.offline.transformation.{AnchorToDataSourceMapper, MvelDefinition}
-import com.linkedin.feathr.offline.util.{CoercionUtilsScala, DataFrameSplitterMerger, FeathrUtils, FeaturizedDatasetUtils}
+import com.linkedin.feathr.offline.util.{CoercionUtilsScala, DataFrameSplitterMerger, FeathrUtils, FeaturizedDatasetUtils, SuppressedExceptionHandlerUtils}
import com.linkedin.feathr.sparkcommon.{ComplexAggregation, SeqJoinCustomAggregation}
import org.apache.logging.log4j.LogManager
import org.apache.spark.sql.functions._
@@ -52,6 +53,22 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession,
val seqJoinDerivationFunction = derivationFunction
val baseFeatureName = seqJoinDerivationFunction.left.feature
val expansionFeatureName = seqJoinDerivationFunction.right.feature
+    val shouldAddDefault = FeathrUtils.getFeathrJobParam(ss, FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA).toBoolean
+
+    // If expansion feature is missing because of data issues, then we just copy the expansion feature column to be the final feature value.
+    if (shouldAddDefault && SuppressedExceptionHandlerUtils.missingFeatures.contains(expansionFeatureName)) {
+      val seqJoinedFeatureResult = df.withColumn(genFeatureColumnName(FEATURE_NAME_PREFIX + derivedFeature.producedFeatureNames.head),
+        col(genFeatureColumnName(FEATURE_NAME_PREFIX + expansionFeatureName)))
+
+      // Add the additional column with the keytags to mimic the exact behavior of the seq join flow, this will get dropped later.
+      val seqJoinFeatureResultWithRenamed = seqJoinedFeatureResult.withColumn(genFeatureColumnName(FEATURE_NAME_PREFIX
+        + derivedFeature.producedFeatureNames.head, Some(keyTags.map(keyTagList).toList)),
+        col(genFeatureColumnName(FEATURE_NAME_PREFIX + expansionFeatureName)))
+      val missingFeature = derivedFeature.producedFeatureNames.head
+      log.warn(s"Missing data for features ${missingFeature}. Default values will be populated for this column.")
+      SuppressedExceptionHandlerUtils.missingFeatures += missingFeature
+      return seqJoinFeatureResultWithRenamed
+    }
val aggregationFunction = seqJoinDerivationFunction.aggregation
val tagStrList = Some(keyTags.map(keyTagList).toList)
val outputKey = seqJoinDerivationFunction.left.outputKey
@@ -219,8 +236,6 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession,
val anchorDFMap1 = anchorToDataSourceMapper.getBasicAnchorDFMapForJoin(ss, Seq(featureAnchor), failOnMissingPartition)
val updatedAnchorDFMap = anchorDFMap1.filter(anchorEntry => anchorEntry._2.isDefined)
.map(anchorEntry => anchorEntry._1 -> anchorEntry._2.get)
-    // We dont need to check if the anchored feature's dataframes are missing (due to skip missing feature) as such
-    // seq join features have already been removed in the FeatureGroupsUpdater#getUpdatedFeatureGroupsWithoutInvalidPaths.
val featureInfo = FeatureTransformation.directCalculate(
anchorGroup: AnchorFeatureGroups,
updatedAnchorDFMap(featureAnchor),
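The block added above short-circuits the sequential join when ADD_DEFAULT_COL_FOR_MISSING_DATA is on and the expansion feature is already known to be missing: instead of joining, the (defaulted) expansion column is copied under the seq-join feature's output names, and the produced feature is itself recorded as missing. A simplified Spark sketch of that copy, using assumed column names in place of Feathr's DataFrameColName.genFeatureColumnName scheme:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

object SeqJoinFallbackSketch extends App {
  val spark = SparkSession.builder().master("local[1]").appName("seq-join-fallback").getOrCreate()
  import spark.implicits._

  // Observation rows with a defaulted expansion feature column (assumed name).
  val df = Seq(("a1", 1.0f)).toDF("a_id", "f_featureWithNull5")

  // Copy the expansion column as the seq-join output instead of performing the join.
  val out = df.withColumn("f_seqJoin_featureWithNull", col("f_featureWithNull5"))
  out.show(false)
  spark.stop()
}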
AnchoredFeatureJoinStep.scala

@@ -17,7 +17,7 @@ import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor
import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter.substituteDefaults
import com.linkedin.feathr.offline.transformation.DataFrameExt._
-import com.linkedin.feathr.offline.util.{DataFrameUtils, FeathrUtils, FeaturizedDatasetUtils}
+import com.linkedin.feathr.offline.util.{DataFrameUtils, FeathrUtils, FeaturizedDatasetUtils, SuppressedExceptionHandlerUtils}
import com.linkedin.feathr.offline.util.FeathrUtils.shouldCheckPoint
import org.apache.logging.log4j.LogManager
import org.apache.spark.sql.{DataFrame, SparkSession}
@@ -113,10 +113,12 @@ private[offline] class AnchoredFeatureJoinStep(
val containsFeature: Seq[Boolean] = anchorDFMap.map(y => y._1.selectedFeatures.contains(x)).toSeq
!containsFeature.contains(true)
})
log.warn(s"Missing data for features ${missingFeatures.mkString}. Default values will be populated for this column.")
SuppressedExceptionHandlerUtils.missingFeatures ++= missingFeatures
val missingAnchoredFeatures = ctx.featureGroups.allAnchoredFeatures.filter(featureName => missingFeatures.contains(featureName._1))
substituteDefaultsForDataMissingFeatures(ctx.sparkSession, observationDF, ctx.logicalPlan,
missingAnchoredFeatures)
-      }else observationDF
+      } else observationDF

val allAnchoredFeatures: Map[String, FeatureAnchorWithSource] = ctx.featureGroups.allAnchoredFeatures
val joinStages = ctx.logicalPlan.joinStages
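This step now records the anchored features whose source data failed to load before substituting defaults for them. The detection predicate mirrors the filter above: a requested feature is missing when no successfully loaded anchor selects it. A plain-Scala sketch, with simple collections standing in for anchorDFMap:

object MissingAnchoredFeatureSketch extends App {
  // Features selected by each anchor whose source loaded successfully.
  val loadedAnchors: Seq[Seq[String]] = Seq(Seq("featureWithNullSql"))
  val requested = Seq("featureWithNull", "featureWithNullSql")

  // A feature is missing if no loaded anchor contains it.
  val missingFeatures = requested.filter { x =>
    val containsFeature: Seq[Boolean] = loadedAnchors.map(_.contains(x))
    !containsFeature.contains(true)
  }
  println(missingFeatures) // List(featureWithNull)
}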
SlidingWindowAggregationJoiner.scala

@@ -179,7 +179,7 @@ private[offline] class SlidingWindowAggregationJoiner(
res.map(emptyFeatures.add)
val exceptionMsg = emptyFeatures.mkString
log.warn(s"Missing data for features ${emptyFeatures}. Default values will be populated for this column.")
-        SuppressedExceptionHandlerUtils.missingDataSuppressedExceptionMsgs += exceptionMsg
+        SuppressedExceptionHandlerUtils.missingFeatures ++= emptyFeatures
anchors.map(anchor => (anchor, originalSourceDf))
} else {
val sourceDF: DataFrame = preprocessedDf match {
AnchorToDataSourceMapper.scala

@@ -86,7 +86,6 @@ private[offline] class AnchorToDataSourceMapper(dataPathHandlers: List[DataPathH
} catch {
case e: Exception => if (shouldSkipFeature || shouldAddDefaultCol) None else throw e
}
-
anchorsWithDate.map(anchor => (anchor, timeSeriesSource))
})
}
SuppressedExceptionHandlerUtils.scala

@@ -1,9 +1,13 @@
package com.linkedin.feathr.offline.util

+import com.linkedin.feathr.offline.anchored.feature.FeatureAnchorWithSource
+
/**
* Util classes and methods to handle suppressed exceptions.
*/
object SuppressedExceptionHandlerUtils {
val MISSING_DATA_EXCEPTION = "missing_data_exception"
-  var missingDataSuppressedExceptionMsgs = ""
+
+  // Set of features that may be missing because of missing data.
+  var missingFeatures = scala.collection.mutable.Set.empty[String]
}
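Replacing the concatenated message string with a shared mutable Set deduplicates features that several join stages report missing, and leaves rendering (mkString) to the call sites. A small usage sketch of the new object's contract:

import scala.collection.mutable

object SuppressedExceptionSetSketch extends App {
  val missingFeatures = mutable.Set.empty[String]

  missingFeatures += "featureWithNull5"                       // from one stage
  missingFeatures ++= Seq("featureWithNull5", "aEmbedding")   // duplicate is ignored
  println(missingFeatures.size)           // 2
  println(missingFeatures.mkString(", ")) // e.g. featureWithNull5, aEmbedding
}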
AnchoredFeaturesIntegTest.scala

@@ -6,7 +6,7 @@ import com.linkedin.feathr.offline.config.location.SimplePath
import com.linkedin.feathr.offline.generation.SparkIOUtils
import com.linkedin.feathr.offline.job.PreprocessedDataFrameManager
import com.linkedin.feathr.offline.source.dataloader.{AvroJsonDataLoader, CsvDataLoader}
-import com.linkedin.feathr.offline.util.FeathrTestUtils
+import com.linkedin.feathr.offline.util.{FeathrTestUtils, SuppressedExceptionHandlerUtils}
import com.linkedin.feathr.offline.util.FeathrUtils.{ADD_DEFAULT_COL_FOR_MISSING_DATA, SKIP_MISSING_FEATURE, setFeathrJobParam}
import org.apache.spark.sql.Row
import org.apache.spark.sql.functions.col
@@ -409,7 +409,7 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
| key: a_id
| featureList: ["featureWithNull", "derived_featureWithNull", "featureWithNull2", "featureWithNull3", "featureWithNull4",
| "featureWithNull5", "derived_featureWithNull2", "featureWithNull6", "featureWithNull7", "derived_featureWithNull7"
| "aEmbedding", "memberEmbeddingAutoTZ", "aEmbedding", "featureWithNullSql"]
| "aEmbedding", "memberEmbeddingAutoTZ", "aEmbedding", "featureWithNullSql", "seqJoin_featureWithNull"]
| }
""".stripMargin,
featureDefAsString =
@@ -529,6 +529,14 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
| derived_featureWithNull: "featureWithNull * 2"
| derived_featureWithNull2: "featureWithNull2 * 2"
| derived_featureWithNull7: "featureWithNull7 * 2"
+        | seqJoin_featureWithNull: {
+        |   key: x
+        |   join: {
+        |     base: {key: x, feature: featureWithNull2}
+        |     expansion: {key: y, feature: featureWithNull5}
+        |   }
+        |   aggregation: "SUM"
+        | }
|}
""".stripMargin,
observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv")
@@ -550,6 +558,11 @@
assertEquals(featureList(0).getAs[Row]("derived_featureWithNull2"),
Row(mutable.WrappedArray.make(Array("")), mutable.WrappedArray.make(Array(2.0f))))
assertEquals(featureList(0).getAs[Row]("featureWithNullSql"), 1.0f)
assertEquals(featureList(0).getAs[Row]("seqJoin_featureWithNull"),
Row(mutable.WrappedArray.make(Array("")), mutable.WrappedArray.make(Array(1.0f))))
assertEquals(SuppressedExceptionHandlerUtils.missingFeatures,
Set("featureWithNull", "featureWithNull3", "featureWithNull5", "featureWithNull4", "featureWithNull7",
"aEmbedding", "featureWithNull6", "derived_featureWithNull", "seqJoin_featureWithNull"))
setFeathrJobParam(ADD_DEFAULT_COL_FOR_MISSING_DATA, "false")
}

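For reference, a toy illustration of the sequential-join semantics exercised by the new seqJoin_featureWithNull test feature (plain Scala maps standing in for the base and expansion feature data; this shows the intended behavior, not Feathr's implementation): the base feature's value supplies the key for the expansion feature, and multiple matches are reduced with the configured SUM aggregation.

object SeqJoinSemanticsSketch extends App {
  val base: Map[String, Seq[String]] = Map("x1" -> Seq("y1", "y2")) // key x -> expansion keys y
  val expansion: Map[String, Float] = Map("y1" -> 1.0f, "y2" -> 2.0f)

  // SUM aggregation over all expansion matches reached through the base feature.
  val seqJoined: Map[String, Float] = base.map { case (k, ys) =>
    k -> ys.flatMap(expansion.get).sum
  }
  println(seqJoined) // Map(x1 -> 3.0)
}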
