Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow alien value in MVEL-based derivations #1120

Merged
merged 1 commit into from
Mar 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@
import com.linkedin.feathr.common.tensor.TensorIterator;
import com.linkedin.feathr.common.types.ValueType;
import com.linkedin.feathr.common.util.CoercionUtils;
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext;
import org.mvel2.DataConversion;
import org.mvel2.integration.impl.SimpleValueResolver;

import java.util.Optional;


/**
* FeatureVariableResolver holds the FeatureValue object bound to a variable during MVEL expression evaluation,
* and resolves the value for that variable.
*/
public class FeatureVariableResolver extends SimpleValueResolver {
private FeatureValue _featureValue;

public FeatureVariableResolver(FeatureValue featureValue) {
private Optional<FeathrExpressionExecutionContext> _mvelContext = Optional.empty();
public FeatureVariableResolver(FeatureValue featureValue, Optional<FeathrExpressionExecutionContext> mvelContext) {
super(featureValue);
_featureValue = featureValue;
_mvelContext = mvelContext;
}

@Override
Expand All @@ -25,21 +30,27 @@ public Object getValue() {
return null;
}

Object fv = null;
switch (_featureValue.getFeatureType().getBasicType()) {
case NUMERIC:
return _featureValue.getAsNumeric();
fv = _featureValue.getAsNumeric(); break;
case TERM_VECTOR:
return getValueFromTermVector();
fv = getValueFromTermVector(); break;
case BOOLEAN:
case CATEGORICAL:
case CATEGORICAL_SET:
case DENSE_VECTOR:

case TENSOR:
return getValueFromTensor();

fv = getValueFromTensor(); break;
default:
throw new IllegalArgumentException("Unexpected feature type: " + _featureValue.getFeatureType().getBasicType());
throw new IllegalArgumentException("Unexpected feature type: " + _featureValue.getFeatureType().getBasicType());
}
// If there is any registered FeatureValue handler that can handle this feature value, return the converted value per request.
if (_mvelContext.isPresent() && _mvelContext.get().canConvertFromAny(fv)) {
return _mvelContext.get().convertFromAny(fv).head();
} else {
return fv;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ private[offline] object PostTransformationUtil {
featureType: FeatureTypes,
mvelContext: Option[FeathrExpressionExecutionContext]): Try[FeatureValue] = Try {
val args = Map(featureName -> Some(featureValue))
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val variableResolverFactory = new FeatureVariableResolverFactory(args, mvelContext)
val transformedValue = MvelContext.executeExpressionWithPluginSupportWithFactory(compiledExpression, featureValue, variableResolverFactory, mvelContext.orNull)
CoercionUtilsScala.coerceToFeatureValue(transformedValue, featureType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ private[offline] class MvelFeatureDerivationFunction(

override def getFeatures(inputs: Seq[Option[common.FeatureValue]]): Seq[Option[common.FeatureValue]] = {
val argMap = (parameterNames zip inputs).toMap
val variableResolverFactory = new FeatureVariableResolverFactory(argMap)
val variableResolverFactory = new FeatureVariableResolverFactory(argMap, mvelContext)

MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match {
case Some(value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[offline] class MvelFeatureDerivationFunction1(

override def getFeatures(inputs: Seq[Option[common.FeatureValue]]): Seq[Option[common.FeatureValue]] = {
val argMap = (parameterNames zip inputs).toMap
val variableResolverFactory = new FeatureVariableResolverFactory(argMap)
val variableResolverFactory = new FeatureVariableResolverFactory(argMap, mvelContext)

MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match {
case Some(value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ private[offline] class SimpleMvelDerivationFunction(expression: String, featureN
MvelContext.ensureInitialized()

// In order to prevent MVEL from barfing if a feature is null, we use a custom variable resolver that understands `Option`
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val variableResolverFactory = new FeatureVariableResolverFactory(args, mvelContext)

if (TestFwkUtils.IS_DEBUGGER_ENABLED) {
while(TestFwkUtils.DERIVED_FEATURE_COUNTER > 0) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package com.linkedin.feathr.offline.mvel

import com.linkedin.feathr.common.{FeatureValue, FeatureVariableResolver}
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import org.mvel2.integration.VariableResolver
import org.mvel2.integration.impl.BaseVariableResolverFactory

import java.util.Optional
import scala.collection.JavaConverters._

private[offline] class FeatureVariableResolverFactory(features: Map[String, Option[FeatureValue]]) extends BaseVariableResolverFactory {
variableResolvers = features.mapValues(x => new FeatureVariableResolver(x.orNull)).asInstanceOf[Map[String, VariableResolver]].asJava
private[offline] class FeatureVariableResolverFactory(features: Map[String, Option[FeatureValue]], mvelContext: Option[FeathrExpressionExecutionContext]) extends BaseVariableResolverFactory {

variableResolvers = features.mapValues(x => new FeatureVariableResolver(x.orNull, Optional.ofNullable(mvelContext.orNull))).asInstanceOf[Map[String, VariableResolver]].asJava

override def isTarget(name: String): Boolean = features.contains(name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,28 @@ class FeathrExpressionExecutionContext extends Serializable {
}
}

/**
* Check whether any registered converter can handle the conversion.
* @param inputValue input value to check
* @return whether it can be converted or not
*/
def canConvertFromAny(inputValue: AnyRef): Boolean = {
  // A null input has no runtime class to match against, so no converter can apply to it;
  // answer false instead of failing with a NullPointerException on getClass.
  Option(inputValue).exists { value =>
    // Look the runtime class up once rather than once per registered converter.
    val inputClass = value.getClass
    // `exists` stops at the first converter that accepts the class, so no intermediate
    // collection of matching converters is built.
    converters.value.exists { case (_, converter) => converter.canConvertFrom(inputClass) }
  }
}

/**
* Convert the input with every registered converter that can handle the conversion.
* @param inputValue input value to convert
* @return return all converted values produced by registered converters
*/
def convertFromAny(inputValue: AnyRef): List[AnyRef] = {
  // A null input cannot be matched to any converter, so yield an empty list instead of
  // failing with a NullPointerException on getClass.
  Option(inputValue).fold(List.empty[AnyRef]) { value =>
    // Look the runtime class up once rather than once per registered converter.
    val inputClass = value.getClass
    // Keep one converted value per registered converter that accepts the input's class.
    converters.value.collect {
      case (_, converter) if converter.canConvertFrom(inputClass) =>
        converter.convertFrom(value)
    }.toList
  }
}

/**
* Convert the input to output type using the registered converters
* @param in value to be converted
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
package com.linkedin.feathr.offline

import com.linkedin.feathr.common.FeatureTypes
import com.linkedin.feathr.offline.anchored.keyExtractor.AlienSourceKeyExtractorAdaptor
import com.linkedin.feathr.offline.client.plugins.FeathrUdfPluginContext
import com.linkedin.feathr.offline.derived.AlienDerivationFunctionAdaptor
Expand All @@ -9,7 +8,6 @@ import com.linkedin.feathr.offline.plugins.{AlienFeatureValue, AlienFeatureValue
import com.linkedin.feathr.offline.util.FeathrTestUtils
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{FloatType, StringType, StructField, StructType}
import org.testng.Assert.assertEquals
import org.testng.annotations.Test

class TestFeathrUdfPlugins extends FeathrIntegTest {
Expand All @@ -19,7 +17,7 @@ class TestFeathrUdfPlugins extends FeathrIntegTest {
private val mvelContext = new FeathrExpressionExecutionContext()

// todo - support udf plugins through FCM
@Test (enabled = false)
@Test (enabled = true)
def testMvelUdfPluginSupport: Unit = {
mvelContext.setupExecutorMvelContext(classOf[AlienFeatureValue], new AlienFeatureValueTypeAdaptor(), ss.sparkContext)
FeathrUdfPluginContext.registerUdfAdaptor(new AlienDerivationFunctionAdaptor(), ss.sparkContext)
Expand Down Expand Up @@ -113,8 +111,6 @@ class TestFeathrUdfPlugins extends FeathrIntegTest {
observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv",
mvelContext = Some(mvelContext))

val f8Type = df.fdsMetadata.header.get.featureInfoMap.filter(_._1.getFeatureName == "f8").head._2.featureType.getFeatureType
assertEquals(f8Type, FeatureTypes.NUMERIC)

val selectedColumns = Seq("a_id", "fA")
val filteredDf = df.data.select(selectedColumns.head, selectedColumns.tail: _*)
Expand Down
2 changes: 1 addition & 1 deletion gradle.properties
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
version=1.0.1-rc1
version=1.0.2-rc1
SONATYPE_AUTOMATIC_RELEASE=true
POM_ARTIFACT_ID=feathr_2.12