Skip to content

Commit

Permalink
Allow alien value in MVEL-based derivations (#1120) (#1166)
Browse files Browse the repository at this point in the history
Add feature value wrapper for 3rdparity feature value compatibility
  • Loading branch information
jaymo001 committed May 11, 2023
1 parent 3a4ffad commit 5c02d26
Show file tree
Hide file tree
Showing 13 changed files with 98 additions and 316 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@
import com.linkedin.feathr.common.tensor.TensorIterator;
import com.linkedin.feathr.common.types.ValueType;
import com.linkedin.feathr.common.util.CoercionUtils;
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext;
import org.mvel2.DataConversion;
import org.mvel2.integration.impl.SimpleValueResolver;

import java.util.Optional;


/**
* FeatureVariableResolver takes a FeatureValue object for member variable during MVEL expression evaluation,
* and then resolve the value for that variable.
*/
public class FeatureVariableResolver extends SimpleValueResolver {
private FeatureValue _featureValue;

public FeatureVariableResolver(FeatureValue featureValue) {
private Optional<FeathrExpressionExecutionContext> _mvelContext = Optional.empty();
public FeatureVariableResolver(FeatureValue featureValue, Optional<FeathrExpressionExecutionContext> mvelContext) {
super(featureValue);
_featureValue = featureValue;
_mvelContext = mvelContext;
}

@Override
Expand All @@ -25,21 +30,27 @@ public Object getValue() {
return null;
}

Object fv = null;
switch (_featureValue.getFeatureType().getBasicType()) {
case NUMERIC:
return _featureValue.getAsNumeric();
fv = _featureValue.getAsNumeric(); break;
case TERM_VECTOR:
return getValueFromTermVector();
fv = getValueFromTermVector(); break;
case BOOLEAN:
case CATEGORICAL:
case CATEGORICAL_SET:
case DENSE_VECTOR:

case TENSOR:
return getValueFromTensor();

fv = getValueFromTensor(); break;
default:
throw new IllegalArgumentException("Unexpected feature type: " + _featureValue.getFeatureType().getBasicType());
throw new IllegalArgumentException("Unexpected feature type: " + _featureValue.getFeatureType().getBasicType());
}
// If there is any registered FeatureValue handler that can handle this feature value, return the converted value per request.
if (_mvelContext.isPresent() && _mvelContext.get().canConvertFromAny(fv)) {
return _mvelContext.get().convertFromAny(fv).head();
} else {
return fv;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ private[offline] object PostTransformationUtil {
featureType: FeatureTypes,
mvelContext: Option[FeathrExpressionExecutionContext]): Try[FeatureValue] = Try {
val args = Map(featureName -> Some(featureValue))
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val variableResolverFactory = new FeatureVariableResolverFactory(args, mvelContext)
val transformedValue = MvelContext.executeExpressionWithPluginSupportWithFactory(compiledExpression, featureValue, variableResolverFactory, mvelContext.orNull)
CoercionUtilsScala.coerceToFeatureValue(transformedValue, featureType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ private[offline] class MvelFeatureDerivationFunction(

override def getFeatures(inputs: Seq[Option[common.FeatureValue]]): Seq[Option[common.FeatureValue]] = {
val argMap = (parameterNames zip inputs).toMap
val variableResolverFactory = new FeatureVariableResolverFactory(argMap)
val variableResolverFactory = new FeatureVariableResolverFactory(argMap, mvelContext)

MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match {
case Some(value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[offline] class MvelFeatureDerivationFunction1(

override def getFeatures(inputs: Seq[Option[common.FeatureValue]]): Seq[Option[common.FeatureValue]] = {
val argMap = (parameterNames zip inputs).toMap
val variableResolverFactory = new FeatureVariableResolverFactory(argMap)
val variableResolverFactory = new FeatureVariableResolverFactory(argMap, mvelContext)

MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match {
case Some(value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ private[offline] class SimpleMvelDerivationFunction(expression: String, featureN
MvelContext.ensureInitialized()

// In order to prevent MVEL from barfing if a feature is null, we use a custom variable resolver that understands `Option`
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val variableResolverFactory = new FeatureVariableResolverFactory(args, mvelContext)

if (TestFwkUtils.IS_DEBUGGER_ENABLED) {
while(TestFwkUtils.DERIVED_FEATURE_COUNTER > 0) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package com.linkedin.feathr.offline.mvel

import com.linkedin.feathr.common.{FeatureValue, FeatureVariableResolver}
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import org.mvel2.integration.VariableResolver
import org.mvel2.integration.impl.BaseVariableResolverFactory

import java.util.Optional
import scala.collection.JavaConverters._

private[offline] class FeatureVariableResolverFactory(features: Map[String, Option[FeatureValue]]) extends BaseVariableResolverFactory {
variableResolvers = features.mapValues(x => new FeatureVariableResolver(x.orNull)).asInstanceOf[Map[String, VariableResolver]].asJava
private[offline] class FeatureVariableResolverFactory(features: Map[String, Option[FeatureValue]], mvelContext: Option[FeathrExpressionExecutionContext]) extends BaseVariableResolverFactory {

variableResolvers = features.mapValues(x => new FeatureVariableResolver(x.orNull, Optional.ofNullable(mvelContext.orNull))).asInstanceOf[Map[String, VariableResolver]].asJava

override def isTarget(name: String): Boolean = features.contains(name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class FeathrExpressionExecutionContext extends Serializable {
*/
def canConvert(toType: Class[_], convertFrom: Class[_]): Boolean = {
if (isAssignableFrom(toType, convertFrom)) return true
if (isAssignableFrom(classOf[FeatureValueWrapper[toType.type]], convertFrom)) return true
if (converters.value.contains(toType.getCanonicalName)) {
converters.value.get(toType.getCanonicalName).get.canConvertFrom(toNonPrimitiveType(convertFrom))
} else if (toType.isArray && canConvert(toType.getComponentType, convertFrom)) {
Expand All @@ -79,6 +80,28 @@ class FeathrExpressionExecutionContext extends Serializable {
}
}

/**
* Check if there is registered converters that can handle the conversion.
* @param inputValue input value to check
* @return whether it can be converted or not
*/
def canConvertFromAny(inputValue: AnyRef): Boolean = {
val result = converters.value.filter(converter => converter._2.canConvertFrom(inputValue.getClass))
result.nonEmpty
}

/**
* Convert the input Check if there is registered converters that can handle the conversion.
* @param inputValue input value to convert
* @return return all converted values produced by registered converters
*/
def convertFromAny(inputValue: AnyRef): List[AnyRef] = {
converters.value.collect {
case converter if converter._2.canConvertFrom(inputValue.getClass) =>
converter._2.convertFrom(inputValue)
}.toList
}

/**
* Convert the input to output type using the registered converters
* @param in value to be converted
Expand All @@ -88,6 +111,9 @@ class FeathrExpressionExecutionContext extends Serializable {
*/
def convert[T](in: Any, toType: Class[T]): T = {
if ((toType eq in.getClass) || toType.isAssignableFrom(in.getClass)) return in.asInstanceOf[T]
if (isAssignableFrom(classOf[FeatureValueWrapper[toType.type]], in.getClass)) {
return in.asInstanceOf[FeatureValueWrapper[_]].getFeatureValue().asInstanceOf[T]
}
val converter = if (converters.value != null) {
converters.value.get(toType.getCanonicalName).get
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.linkedin.feathr.offline.mvel.plugins

/**
* Trait that wraps a Frame or Feathr FeatureValue
* @tparam T FeatureValue type to be wrapped
*/
trait FeatureValueWrapper[T] {
// Get the wrapped feature value
def getFeatureValue(): T
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ private AlienFeatureValue(Float floatValue, String stringValue) {
this.floatValue = floatValue;
this.stringValue = stringValue;
}

public AlienFeatureValue() {
this.floatValue = null;
this.stringValue = null;
}
public static AlienFeatureValue fromFloat(float floatValue) {
return new AlienFeatureValue(floatValue, null);
}
Expand Down
Loading

0 comments on commit 5c02d26

Please sign in to comment.