Skip to content

Commit

Permalink
Allow alien value in MVEL-based derivations (feathr-ai#1120)
Browse files Browse the repository at this point in the history
  • Loading branch information
jaymo001 committed May 11, 2023
1 parent 3a4ffad commit ff14050
Show file tree
Hide file tree
Showing 9 changed files with 51 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@
import com.linkedin.feathr.common.tensor.TensorIterator;
import com.linkedin.feathr.common.types.ValueType;
import com.linkedin.feathr.common.util.CoercionUtils;
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext;
import org.mvel2.DataConversion;
import org.mvel2.integration.impl.SimpleValueResolver;

import java.util.Optional;


/**
* FeatureVariableResolver takes a FeatureValue object for a member variable during MVEL expression evaluation,
* and then resolves the value for that variable.
*/
public class FeatureVariableResolver extends SimpleValueResolver {
private FeatureValue _featureValue;

public FeatureVariableResolver(FeatureValue featureValue) {
private Optional<FeathrExpressionExecutionContext> _mvelContext = Optional.empty();
public FeatureVariableResolver(FeatureValue featureValue, Optional<FeathrExpressionExecutionContext> mvelContext) {
super(featureValue);
_featureValue = featureValue;
_mvelContext = mvelContext;
}

@Override
Expand All @@ -25,21 +30,27 @@ public Object getValue() {
return null;
}

Object fv = null;
switch (_featureValue.getFeatureType().getBasicType()) {
case NUMERIC:
return _featureValue.getAsNumeric();
fv = _featureValue.getAsNumeric(); break;
case TERM_VECTOR:
return getValueFromTermVector();
fv = getValueFromTermVector(); break;
case BOOLEAN:
case CATEGORICAL:
case CATEGORICAL_SET:
case DENSE_VECTOR:

case TENSOR:
return getValueFromTensor();

fv = getValueFromTensor(); break;
default:
throw new IllegalArgumentException("Unexpected feature type: " + _featureValue.getFeatureType().getBasicType());
throw new IllegalArgumentException("Unexpected feature type: " + _featureValue.getFeatureType().getBasicType());
}
// If there is any registered FeatureValue handler that can handle this feature value, return the converted value per request.
if (_mvelContext.isPresent() && _mvelContext.get().canConvertFromAny(fv)) {
return _mvelContext.get().convertFromAny(fv).head();
} else {
return fv;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ private[offline] object PostTransformationUtil {
featureType: FeatureTypes,
mvelContext: Option[FeathrExpressionExecutionContext]): Try[FeatureValue] = Try {
val args = Map(featureName -> Some(featureValue))
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val variableResolverFactory = new FeatureVariableResolverFactory(args, mvelContext)
val transformedValue = MvelContext.executeExpressionWithPluginSupportWithFactory(compiledExpression, featureValue, variableResolverFactory, mvelContext.orNull)
CoercionUtilsScala.coerceToFeatureValue(transformedValue, featureType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ private[offline] class MvelFeatureDerivationFunction(

override def getFeatures(inputs: Seq[Option[common.FeatureValue]]): Seq[Option[common.FeatureValue]] = {
val argMap = (parameterNames zip inputs).toMap
val variableResolverFactory = new FeatureVariableResolverFactory(argMap)
val variableResolverFactory = new FeatureVariableResolverFactory(argMap, mvelContext)

MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match {
case Some(value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[offline] class MvelFeatureDerivationFunction1(

override def getFeatures(inputs: Seq[Option[common.FeatureValue]]): Seq[Option[common.FeatureValue]] = {
val argMap = (parameterNames zip inputs).toMap
val variableResolverFactory = new FeatureVariableResolverFactory(argMap)
val variableResolverFactory = new FeatureVariableResolverFactory(argMap, mvelContext)

MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match {
case Some(value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ private[offline] class SimpleMvelDerivationFunction(expression: String, featureN
MvelContext.ensureInitialized()

// In order to prevent MVEL from barfing if a feature is null, we use a custom variable resolver that understands `Option`
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val variableResolverFactory = new FeatureVariableResolverFactory(args, mvelContext)

if (TestFwkUtils.IS_DEBUGGER_ENABLED) {
while(TestFwkUtils.DERIVED_FEATURE_COUNTER > 0) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package com.linkedin.feathr.offline.mvel

import com.linkedin.feathr.common.{FeatureValue, FeatureVariableResolver}
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import org.mvel2.integration.VariableResolver
import org.mvel2.integration.impl.BaseVariableResolverFactory

import java.util.Optional
import scala.collection.JavaConverters._

private[offline] class FeatureVariableResolverFactory(features: Map[String, Option[FeatureValue]]) extends BaseVariableResolverFactory {
variableResolvers = features.mapValues(x => new FeatureVariableResolver(x.orNull)).asInstanceOf[Map[String, VariableResolver]].asJava
private[offline] class FeatureVariableResolverFactory(features: Map[String, Option[FeatureValue]], mvelContext: Option[FeathrExpressionExecutionContext]) extends BaseVariableResolverFactory {

variableResolvers = features.mapValues(x => new FeatureVariableResolver(x.orNull, Optional.ofNullable(mvelContext.orNull))).asInstanceOf[Map[String, VariableResolver]].asJava

override def isTarget(name: String): Boolean = features.contains(name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,28 @@ class FeathrExpressionExecutionContext extends Serializable {
}
}

/**
 * Check whether at least one registered converter can handle the runtime class of the input value.
 * A null input has no runtime class to test and is reported as not convertible.
 * @param inputValue input value to check, may be null
 * @return true if at least one registered converter accepts the input's class, false otherwise
 */
def canConvertFromAny(inputValue: AnyRef): Boolean = {
  // Guard null first: calling getClass on a null reference would throw a NullPointerException.
  inputValue != null && {
    val inputClass = inputValue.getClass
    // exists stops at the first matching converter instead of building a filtered copy.
    converters.value.exists { case (_, handler) => handler.canConvertFrom(inputClass) }
  }
}

/**
 * Convert the input value with every registered converter that can handle its runtime class.
 * @param inputValue input value to convert, may be null
 * @return all converted values produced by the matching converters; an empty list when the
 *         input is null or no registered converter accepts its class
 */
def convertFromAny(inputValue: AnyRef): List[AnyRef] = {
  if (inputValue == null) {
    // A null input has no runtime class to match against; getClass would throw here.
    List.empty
  } else {
    val inputClass = inputValue.getClass
    converters.value.collect {
      case (_, handler) if handler.canConvertFrom(inputClass) => handler.convertFrom(inputValue)
    }.toList
  }
}

/**
* Convert the input to output type using the registered converters
* @param in value to be converted
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.linkedin.feathr.offline

import com.linkedin.feathr.common.FeatureTypes
import com.linkedin.feathr.offline.anchored.keyExtractor.AlienSourceKeyExtractorAdaptor
import com.linkedin.feathr.offline.client.plugins.FeathrUdfPluginContext
import com.linkedin.feathr.offline.derived.AlienDerivationFunctionAdaptor
Expand All @@ -9,7 +8,6 @@ import com.linkedin.feathr.offline.plugins.{AlienFeatureValue, AlienFeatureValue
import com.linkedin.feathr.offline.util.FeathrTestUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{FloatType, StringType, StructField, StructType}
import org.testng.Assert.assertEquals
import org.testng.annotations.Test

class TestFeathrUdfPlugins extends FeathrIntegTest {
Expand All @@ -19,7 +17,7 @@ class TestFeathrUdfPlugins extends FeathrIntegTest {
private val mvelContext = new FeathrExpressionExecutionContext()

// todo - support udf plugins through FCM
@Test (enabled = false)
@Test (enabled = true)
def testMvelUdfPluginSupport: Unit = {
mvelContext.setupExecutorMvelContext(classOf[AlienFeatureValue], new AlienFeatureValueTypeAdaptor(), ss.sparkContext)
FeathrUdfPluginContext.registerUdfAdaptor(new AlienDerivationFunctionAdaptor(), ss.sparkContext)
Expand Down Expand Up @@ -113,8 +111,6 @@ class TestFeathrUdfPlugins extends FeathrIntegTest {
observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv",
mvelContext = Some(mvelContext))

val f8Type = df.fdsMetadata.header.get.featureInfoMap.filter(_._1.getFeatureName == "f8").head._2.featureType.getFeatureType
assertEquals(f8Type, FeatureTypes.NUMERIC)

val selectedColumns = Seq("a_id", "fA")
val filteredDf = df.data.select(selectedColumns.head, selectedColumns.tail: _*)
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
version=1.0.4-rc2
version=1.0.4-rc3
SONATYPE_AUTOMATIC_RELEASE=true
POM_ARTIFACT_ID=feathr_2.12

0 comments on commit ff14050

Please sign in to comment.