diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/evaluator/DerivedFeatureGenStage.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/evaluator/DerivedFeatureGenStage.scala index ebb6b2809..571bdc3b8 100644 --- a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/evaluator/DerivedFeatureGenStage.scala +++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/evaluator/DerivedFeatureGenStage.scala @@ -36,24 +36,24 @@ private[offline] class DerivedFeatureGenStage(featureGroups: FeatureGroups, logi * @return all evaluated features including the ones processed here. */ override def evaluate(features: Seq[String], keyTags: Seq[Int], context: FeatureDataWithJoinKeys): FeatureDataWithJoinKeys = { - val featureToDerivationMap = features.map(f => (featureGroups.allDerivedFeatures(f), f)).toMap - featureToDerivationMap.foldLeft(context)((accumulator: FeatureDataWithJoinKeys, currFeatureDerivation) => { - val (derivation, derivedFeatureName) = currFeatureDerivation - val featureColumnName = DataFrameColName.genFeatureColumnName(derivedFeatureName) - // Compute the base DataFrame that can be used to compute the derived feature. - val BaseDataFrameMetadata(baseFeatureDataFrame, joinKeys, featuresOnBaseDf) = - evaluateBaseDataFrameForDerivation(derivedFeatureName, derivation, accumulator) - val derivedFeatureDataFrame = - if (baseFeatureDataFrame.df.columns.contains(featureColumnName)) { // if DataFrame already has the feature, no need to apply derivations - baseFeatureDataFrame - } else { - derivedFeatureUtils.evaluate(keyTags, logicalPlan.keyTagIntsToStrings, baseFeatureDataFrame.df, derivation) - } - val columnRenamedDf = dropFrameTagsAndRenameColumn(derivedFeatureDataFrame.df, featureColumnName) - // Update featureTypeMap and features on DataFrame metadata - val updatedFeatureTypeMap = baseFeatureDataFrame.inferredFeatureType ++ derivedFeatureDataFrame.inferredFeatureType - val updatedFeaturesOnDf = featuresOnBaseDf :+ derivedFeatureName - accumulator ++ updatedFeaturesOnDf.map(f => f -> (FeatureDataFrame(columnRenamedDf, updatedFeatureTypeMap), joinKeys)).toMap + features.map(f => (featureGroups.allDerivedFeatures(f), f)) + .foldLeft(context)((accumulator: FeatureDataWithJoinKeys, currFeatureDerivation) => { + val (derivation, derivedFeatureName) = currFeatureDerivation + val featureColumnName = DataFrameColName.genFeatureColumnName(derivedFeatureName) + // Compute the base DataFrame that can be used to compute the derived feature. + val BaseDataFrameMetadata(baseFeatureDataFrame, joinKeys, featuresOnBaseDf) = + evaluateBaseDataFrameForDerivation(derivedFeatureName, derivation, accumulator) + val derivedFeatureDataFrame = + if (baseFeatureDataFrame.df.columns.contains(featureColumnName)) { // if DataFrame already has the feature, no need to apply derivations + baseFeatureDataFrame + } else { + derivedFeatureUtils.evaluate(keyTags, logicalPlan.keyTagIntsToStrings, baseFeatureDataFrame.df, derivation) + } + val columnRenamedDf = dropFrameTagsAndRenameColumn(derivedFeatureDataFrame.df, featureColumnName) + // Update featureTypeMap and features on DataFrame metadata + val updatedFeatureTypeMap = baseFeatureDataFrame.inferredFeatureType ++ derivedFeatureDataFrame.inferredFeatureType + val updatedFeaturesOnDf = featuresOnBaseDf :+ derivedFeatureName + accumulator ++ updatedFeaturesOnDf.map(f => f -> (FeatureDataFrame(columnRenamedDf, updatedFeatureTypeMap), joinKeys)).toMap }) }