Skip to content

Commit

Permalink
Seq join bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Rakesh Kashyap Hanasoge Padmanabha committed May 16, 2023
1 parent 5c02d26 commit c860208
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import com.linkedin.feathr.common.{FeatureAggregationType, FeatureValue}
import com.linkedin.feathr.offline.PostTransformationUtil
import com.linkedin.feathr.offline.anchored.feature.FeatureAnchorWithSource
import com.linkedin.feathr.offline.client.DataFrameColName
import com.linkedin.feathr.offline.client.DataFrameColName.genFeatureColumnName
import com.linkedin.feathr.offline.derived.DerivedFeature
import com.linkedin.feathr.offline.derived.functions.SeqJoinDerivationFunction
import com.linkedin.feathr.offline.derived.strategies.SequentialJoinAsDerivation._
Expand All @@ -18,7 +19,7 @@ import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter.substituteDefaults
import com.linkedin.feathr.offline.transformation.{AnchorToDataSourceMapper, MvelDefinition}
import com.linkedin.feathr.offline.util.{CoercionUtilsScala, DataFrameSplitterMerger, FeathrUtils, FeaturizedDatasetUtils}
import com.linkedin.feathr.offline.util.{CoercionUtilsScala, DataFrameSplitterMerger, FeathrUtils, FeaturizedDatasetUtils, SuppressedExceptionHandlerUtils}
import com.linkedin.feathr.sparkcommon.{ComplexAggregation, SeqJoinCustomAggregation}
import org.apache.logging.log4j.LogManager
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -52,6 +53,22 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession,
val seqJoinDerivationFunction = derivationFunction
val baseFeatureName = seqJoinDerivationFunction.left.feature
val expansionFeatureName = seqJoinDerivationFunction.right.feature
val shouldAddDefault = FeathrUtils.getFeathrJobParam(ss, FeathrUtils.ADD_DEFAULT_COL_FOR_MISSING_DATA).toBoolean

// If expansion feature is missing because of data issues, then we just copy the expansion feature column to be the final feature value.
if (shouldAddDefault && SuppressedExceptionHandlerUtils.missingFeatures.contains(expansionFeatureName)) {
val seqJoinedFeatureResult = df.withColumn(genFeatureColumnName(FEATURE_NAME_PREFIX + derivedFeature.producedFeatureNames.head),
col(genFeatureColumnName(FEATURE_NAME_PREFIX + expansionFeatureName)))

// Add the additional column with the keytags to mimic the exact behavior of the seq join flow, this will get dropped later.
val seqJoinFeatureResultWithRenamed = seqJoinedFeatureResult.withColumn(genFeatureColumnName(FEATURE_NAME_PREFIX
+ derivedFeature.producedFeatureNames.head, Some(keyTags.map(keyTagList).toList)),
col(genFeatureColumnName(FEATURE_NAME_PREFIX + expansionFeatureName)))
val missingFeature = derivedFeature.producedFeatureNames.head
log.warn(s"Missing data for features ${missingFeature}. Default values will be populated for this column.")
SuppressedExceptionHandlerUtils.missingDataSuppressedExceptionMsgs += missingFeature
return seqJoinFeatureResultWithRenamed
}
val aggregationFunction = seqJoinDerivationFunction.aggregation
val tagStrList = Some(keyTags.map(keyTagList).toList)
val outputKey = seqJoinDerivationFunction.left.outputKey
Expand Down Expand Up @@ -219,8 +236,6 @@ private[offline] class SequentialJoinAsDerivation(ss: SparkSession,
val anchorDFMap1 = anchorToDataSourceMapper.getBasicAnchorDFMapForJoin(ss, Seq(featureAnchor), failOnMissingPartition)
val updatedAnchorDFMap = anchorDFMap1.filter(anchorEntry => anchorEntry._2.isDefined)
.map(anchorEntry => anchorEntry._1 -> anchorEntry._2.get)
// We dont need to check if the anchored feature's dataframes are missing (due to skip missing feature) as such
// seq join features have already been removed in the FeatureGroupsUpdater#getUpdatedFeatureGroupsWithoutInvalidPaths.
val featureInfo = FeatureTransformation.directCalculate(
anchorGroup: AnchorFeatureGroups,
updatedAnchorDFMap(featureAnchor),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor
import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter.substituteDefaults
import com.linkedin.feathr.offline.transformation.DataFrameExt._
import com.linkedin.feathr.offline.util.{DataFrameUtils, FeathrUtils, FeaturizedDatasetUtils}
import com.linkedin.feathr.offline.util.{DataFrameUtils, FeathrUtils, FeaturizedDatasetUtils, SuppressedExceptionHandlerUtils}
import com.linkedin.feathr.offline.util.FeathrUtils.shouldCheckPoint
import org.apache.logging.log4j.LogManager
import org.apache.spark.sql.{DataFrame, SparkSession}
Expand Down Expand Up @@ -113,10 +113,13 @@ private[offline] class AnchoredFeatureJoinStep(
val containsFeature: Seq[Boolean] = anchorDFMap.map(y => y._1.selectedFeatures.contains(x)).toSeq
!containsFeature.contains(true)
})
log.warn(s"Missing data for features ${missingFeatures.mkString}. Default values will be populated for this column.")
SuppressedExceptionHandlerUtils.missingDataSuppressedExceptionMsgs += missingFeatures.mkString
SuppressedExceptionHandlerUtils.missingFeatures ++= missingFeatures
val missingAnchoredFeatures = ctx.featureGroups.allAnchoredFeatures.filter(featureName => missingFeatures.contains(featureName._1))
substituteDefaultsForDataMissingFeatures(ctx.sparkSession, observationDF, ctx.logicalPlan,
missingAnchoredFeatures)
}else observationDF
} else observationDF

val allAnchoredFeatures: Map[String, FeatureAnchorWithSource] = ctx.featureGroups.allAnchoredFeatures
val joinStages = ctx.logicalPlan.joinStages
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,6 @@ private[offline] class AnchorToDataSourceMapper(dataPathHandlers: List[DataPathH
} catch {
case e: Exception => if (shouldSkipFeature || shouldAddDefaultCol) None else throw e
}

anchorsWithDate.map(anchor => (anchor, timeSeriesSource))
})
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package com.linkedin.feathr.offline.util

import com.linkedin.feathr.offline.anchored.feature.FeatureAnchorWithSource

/**
* Util classes and methods to handle suppressed exceptions.
*/
object SuppressedExceptionHandlerUtils {
val MISSING_DATA_EXCEPTION = "missing_data_exception"
var missingDataSuppressedExceptionMsgs = ""

// Set of features that may be missing because of missing data.
var missingFeatures = scala.collection.mutable.Set.empty[String]
}
Original file line number Diff line number Diff line change
Expand Up @@ -409,7 +409,7 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
| key: a_id
| featureList: ["featureWithNull", "derived_featureWithNull", "featureWithNull2", "featureWithNull3", "featureWithNull4",
| "featureWithNull5", "derived_featureWithNull2", "featureWithNull6", "featureWithNull7", "derived_featureWithNull7"
| "aEmbedding", "memberEmbeddingAutoTZ", "aEmbedding", "featureWithNullSql"]
| "aEmbedding", "memberEmbeddingAutoTZ", "aEmbedding", "featureWithNullSql", "seqJoin_featureWithNull"]
| }
""".stripMargin,
featureDefAsString =
Expand Down Expand Up @@ -529,6 +529,14 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
| derived_featureWithNull: "featureWithNull * 2"
| derived_featureWithNull2: "featureWithNull2 * 2"
| derived_featureWithNull7: "featureWithNull7 * 2"
| seqJoin_featureWithNull: {
| key: x
| join: {
| base: {key: x, feature: featureWithNull2}
| expansion: {key: y, feature: featureWithNull5}
| }
| aggregation: "SUM"
| }
|}
""".stripMargin,
observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv")
Expand All @@ -550,6 +558,8 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest {
assertEquals(featureList(0).getAs[Row]("derived_featureWithNull2"),
Row(mutable.WrappedArray.make(Array("")), mutable.WrappedArray.make(Array(2.0f))))
assertEquals(featureList(0).getAs[Row]("featureWithNullSql"), 1.0f)
assertEquals(featureList(0).getAs[Row]("seqJoin_featureWithNull"),
Row(mutable.WrappedArray.make(Array("")), mutable.WrappedArray.make(Array(1.0f))))
setFeathrJobParam(ADD_DEFAULT_COL_FOR_MISSING_DATA, "false")
}

Expand Down

0 comments on commit c860208

Please sign in to comment.