Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow alien value in MVEL-based derivations #1166

Merged
merged 1 commit into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,24 @@
import com.linkedin.feathr.common.tensor.TensorIterator;
import com.linkedin.feathr.common.types.ValueType;
import com.linkedin.feathr.common.util.CoercionUtils;
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext;
import org.mvel2.DataConversion;
import org.mvel2.integration.impl.SimpleValueResolver;

import java.util.Optional;


/**
* FeatureVariableResolver takes a FeatureValue object for member variable during MVEL expression evaluation,
* and then resolve the value for that variable.
*/
public class FeatureVariableResolver extends SimpleValueResolver {
private FeatureValue _featureValue;

public FeatureVariableResolver(FeatureValue featureValue) {
private Optional<FeathrExpressionExecutionContext> _mvelContext = Optional.empty();
public FeatureVariableResolver(FeatureValue featureValue, Optional<FeathrExpressionExecutionContext> mvelContext) {
super(featureValue);
_featureValue = featureValue;
_mvelContext = mvelContext;
}

@Override
Expand All @@ -25,21 +30,27 @@ public Object getValue() {
return null;
}

Object fv = null;
switch (_featureValue.getFeatureType().getBasicType()) {
case NUMERIC:
return _featureValue.getAsNumeric();
fv = _featureValue.getAsNumeric(); break;
case TERM_VECTOR:
return getValueFromTermVector();
fv = getValueFromTermVector(); break;
case BOOLEAN:
case CATEGORICAL:
case CATEGORICAL_SET:
case DENSE_VECTOR:

case TENSOR:
return getValueFromTensor();

fv = getValueFromTensor(); break;
default:
throw new IllegalArgumentException("Unexpected feature type: " + _featureValue.getFeatureType().getBasicType());
throw new IllegalArgumentException("Unexpected feature type: " + _featureValue.getFeatureType().getBasicType());
}
// If there is any registered FeatureValue handler that can handle this feature value, return the converted value per request.
if (_mvelContext.isPresent() && _mvelContext.get().canConvertFromAny(fv)) {
return _mvelContext.get().convertFromAny(fv).head();
} else {
return fv;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ private[offline] object PostTransformationUtil {
featureType: FeatureTypes,
mvelContext: Option[FeathrExpressionExecutionContext]): Try[FeatureValue] = Try {
val args = Map(featureName -> Some(featureValue))
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val variableResolverFactory = new FeatureVariableResolverFactory(args, mvelContext)
val transformedValue = MvelContext.executeExpressionWithPluginSupportWithFactory(compiledExpression, featureValue, variableResolverFactory, mvelContext.orNull)
CoercionUtilsScala.coerceToFeatureValue(transformedValue, featureType)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ private[offline] class MvelFeatureDerivationFunction(

override def getFeatures(inputs: Seq[Option[common.FeatureValue]]): Seq[Option[common.FeatureValue]] = {
val argMap = (parameterNames zip inputs).toMap
val variableResolverFactory = new FeatureVariableResolverFactory(argMap)
val variableResolverFactory = new FeatureVariableResolverFactory(argMap, mvelContext)

MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match {
case Some(value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ private[offline] class MvelFeatureDerivationFunction1(

override def getFeatures(inputs: Seq[Option[common.FeatureValue]]): Seq[Option[common.FeatureValue]] = {
val argMap = (parameterNames zip inputs).toMap
val variableResolverFactory = new FeatureVariableResolverFactory(argMap)
val variableResolverFactory = new FeatureVariableResolverFactory(argMap, mvelContext)

MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match {
case Some(value) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ private[offline] class SimpleMvelDerivationFunction(expression: String, featureN
MvelContext.ensureInitialized()

// In order to prevent MVEL from barfing if a feature is null, we use a custom variable resolver that understands `Option`
val variableResolverFactory = new FeatureVariableResolverFactory(args)
val variableResolverFactory = new FeatureVariableResolverFactory(args, mvelContext)

if (TestFwkUtils.IS_DEBUGGER_ENABLED) {
while(TestFwkUtils.DERIVED_FEATURE_COUNTER > 0) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
package com.linkedin.feathr.offline.mvel

import com.linkedin.feathr.common.{FeatureValue, FeatureVariableResolver}
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import org.mvel2.integration.VariableResolver
import org.mvel2.integration.impl.BaseVariableResolverFactory

import java.util.Optional
import scala.collection.JavaConverters._

private[offline] class FeatureVariableResolverFactory(features: Map[String, Option[FeatureValue]]) extends BaseVariableResolverFactory {
variableResolvers = features.mapValues(x => new FeatureVariableResolver(x.orNull)).asInstanceOf[Map[String, VariableResolver]].asJava
private[offline] class FeatureVariableResolverFactory(features: Map[String, Option[FeatureValue]], mvelContext: Option[FeathrExpressionExecutionContext]) extends BaseVariableResolverFactory {

variableResolvers = features.mapValues(x => new FeatureVariableResolver(x.orNull, Optional.ofNullable(mvelContext.orNull))).asInstanceOf[Map[String, VariableResolver]].asJava

override def isTarget(name: String): Boolean = features.contains(name)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class FeathrExpressionExecutionContext extends Serializable {
*/
def canConvert(toType: Class[_], convertFrom: Class[_]): Boolean = {
if (isAssignableFrom(toType, convertFrom)) return true
if (isAssignableFrom(classOf[FeatureValueWrapper[toType.type]], convertFrom)) return true
if (converters.value.contains(toType.getCanonicalName)) {
converters.value.get(toType.getCanonicalName).get.canConvertFrom(toNonPrimitiveType(convertFrom))
} else if (toType.isArray && canConvert(toType.getComponentType, convertFrom)) {
Expand All @@ -79,6 +80,28 @@ class FeathrExpressionExecutionContext extends Serializable {
}
}

/**
* Check if there is registered converters that can handle the conversion.
* @param inputValue input value to check
* @return whether it can be converted or not
*/
def canConvertFromAny(inputValue: AnyRef): Boolean = {
val result = converters.value.filter(converter => converter._2.canConvertFrom(inputValue.getClass))
result.nonEmpty
}

/**
* Convert the input Check if there is registered converters that can handle the conversion.
* @param inputValue input value to convert
* @return return all converted values produced by registered converters
*/
def convertFromAny(inputValue: AnyRef): List[AnyRef] = {
converters.value.collect {
case converter if converter._2.canConvertFrom(inputValue.getClass) =>
converter._2.convertFrom(inputValue)
}.toList
}

/**
* Convert the input to output type using the registered converters
* @param in value to be converted
Expand All @@ -88,6 +111,9 @@ class FeathrExpressionExecutionContext extends Serializable {
*/
def convert[T](in: Any, toType: Class[T]): T = {
if ((toType eq in.getClass) || toType.isAssignableFrom(in.getClass)) return in.asInstanceOf[T]
if (isAssignableFrom(classOf[FeatureValueWrapper[toType.type]], in.getClass)) {
return in.asInstanceOf[FeatureValueWrapper[_]].getFeatureValue().asInstanceOf[T]
}
val converter = if (converters.value != null) {
converters.value.get(toType.getCanonicalName).get
} else {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.linkedin.feathr.offline.mvel.plugins

/**
* Trait that wraps a Frame or Feathr FeatureValue
* @tparam T FeatureValue type to be wrapped
*/
trait FeatureValueWrapper[T] {
// Get the wrapped feature value
def getFeatureValue(): T
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@ private AlienFeatureValue(Float floatValue, String stringValue) {
this.floatValue = floatValue;
this.stringValue = stringValue;
}

public AlienFeatureValue() {
this.floatValue = null;
this.stringValue = null;
}
public static AlienFeatureValue fromFloat(float floatValue) {
return new AlienFeatureValue(floatValue, null);
}
Expand Down
Loading
Loading