Skip to content

Commit

Permalink
[DROOLS-6892] Wrapper class unwrapped value is not indexed in exec-mo…
Browse files Browse the repository at this point in the history
…del (apache#4318)

* [DROOLS-6892] Wrapper class unwrapped value is not indexed in exec-model

* - remove syso
  • Loading branch information
tkobayas committed Apr 21, 2022
1 parent dc0c474 commit 68c4afa
Show file tree
Hide file tree
Showing 8 changed files with 209 additions and 20 deletions.
56 changes: 48 additions & 8 deletions drools-core/src/main/java/org/drools/core/util/ClassUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,16 @@ public static Class<?> findClass(String className, ClassLoader cl) {
return null;
}

// Used for exec-model DomainClassMetadata and index
public static List<String> getAccessiblePropertiesIncludingNonGetterValueMethod(Class<?> clazz) {
List<String> accessibleProperties = getAccessibleProperties(clazz);

// Add nonGetterValueMethods at last so property reactivity mask index isn't affected
accessibleProperties.addAll(getNonGetterValueMethods(clazz, accessibleProperties));
return accessibleProperties;
}

// Used for property reactivity
public static List<String> getAccessibleProperties( Class<?> clazz ) {
Set<PropertyInClass> props = new TreeSet<>();
for (Method m : clazz.getMethods()) {
Expand All @@ -445,12 +455,42 @@ public static List<String> getAccessibleProperties( Class<?> clazz ) {
}

List<String> accessibleProperties = new ArrayList<>();
for ( PropertyInClass setter : props ) {
accessibleProperties.add(setter.setter);
for ( PropertyInClass propInClass : props ) {
accessibleProperties.add(propInClass.prop);
}
return accessibleProperties;
}

public static List<String> getNonGetterValueMethods(Class<?> clazz, List<String> accessibleProperties) {
Set<PropertyInClass> nonGetterValueMethodInClassSet = new TreeSet<>();
for (Method m : clazz.getMethods()) {
String propName = getter2property(m.getName());
if (propName == null) {
String methodName = filterNonGetterValueMethod(m);
if (methodName != null) {
nonGetterValueMethodInClassSet.add(new PropertyInClass(methodName, m.getDeclaringClass()));
}
}
}

List<String> nonGetterValueMethods = new ArrayList<>();
for (PropertyInClass propInClass : nonGetterValueMethodInClassSet) {
if (!accessibleProperties.contains(propInClass.prop)) {
nonGetterValueMethods.add(propInClass.prop);
}
}
return nonGetterValueMethods;
}

private static String filterNonGetterValueMethod(Method m) {
String methodName = m.getName();
if (m.getParameterTypes().length == 0 && !m.getReturnType().equals(void.class) && methodName != "toString" && methodName != "hashCode") {
return m.getName(); // e.g. Person.calcAge(), Integer.intValue()
} else {
return null;
}
}

public static Field getField(Class<?> clazz, String field) {
try {
return clazz.getDeclaredField( field );
Expand Down Expand Up @@ -709,18 +749,18 @@ public static boolean isInterface(Class<?> clazz) {
}

private static class PropertyInClass implements Comparable {
private final String setter;
private final String prop;
private final Class<?> clazz;

private PropertyInClass( String setter, Class<?> clazz ) {
this.setter = setter;
private PropertyInClass( String prop, Class<?> clazz ) {
this.prop = prop;
this.clazz = clazz;
}

public int compareTo(Object o) {
PropertyInClass other = (PropertyInClass) o;
if (clazz == other.clazz) {
return setter.compareTo(other.setter);
return prop.compareTo(other.prop);
}
return clazz.isAssignableFrom(other.clazz) ? -1 : 1;
}
Expand All @@ -731,12 +771,12 @@ public boolean equals(Object obj) {
return false;
}
PropertyInClass other = (PropertyInClass) obj;
return clazz == other.clazz && setter.equals(other.setter);
return clazz == other.clazz && prop.equals(other.prop);
}

@Override
public int hashCode() {
return 29 * clazz.hashCode() + 31 * setter.hashCode();
return 29 * clazz.hashCode() + 31 * prop.hashCode();
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@
import static org.drools.modelcompiler.builder.generator.DslMethodNames.createDslTopLevelMethod;
import static org.drools.modelcompiler.builder.generator.QueryGenerator.QUERY_METHOD_PREFIX;
import static org.drools.modelcompiler.util.ClassUtil.asJavaSourceName;
import static org.drools.modelcompiler.util.ClassUtil.getAccessibleProperties;
import static org.drools.modelcompiler.util.ClassUtil.getAccessiblePropertiesIncludingNonGetterValueMethod;

public class PackageModel {

Expand Down Expand Up @@ -873,7 +873,7 @@ public String getDomainClassesMetadataSource() {
);
for (Class<?> domainClass : domainClasses) {
String domainClassSourceName = asJavaSourceName( domainClass );
List<String> accessibleProperties = getAccessibleProperties( domainClass );
List<String> accessibleProperties = getAccessiblePropertiesIncludingNonGetterValueMethod( domainClass );
accessibleProperties = accessibleProperties.stream().distinct().collect(Collectors.toList());
sb.append( " public static final " + DomainClassMetadata.class.getCanonicalName() + " " + domainClassSourceName + DOMAIN_CLASS_METADATA_INSTANCE + " = new " + domainClassSourceName+ "_Metadata();\n" );
sb.append( " private static class " + domainClassSourceName + "_Metadata implements " + DomainClassMetadata.class.getCanonicalName() + " {\n\n" );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toJavaParserType;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toStringLiteral;
import static org.drools.modelcompiler.builder.generator.drlxparse.ConstraintParser.toBigDecimalExpression;
import static org.drools.modelcompiler.util.ClassUtil.isAccessibleProperties;
import static org.drools.modelcompiler.util.ClassUtil.isAccessiblePropertiesIncludingNonGetterValueMethod;
import static org.drools.modelcompiler.util.ClassUtil.toRawClass;
import static org.drools.mvel.parser.printer.PrintUtil.printNode;

Expand Down Expand Up @@ -316,7 +316,7 @@ protected void addIndexedByDeclaration(TypedExpression left,
}

String getIndexIdArgument(SingleDrlxParseSuccess drlxParseResult, TypedExpression left) {
return isAccessibleProperties( drlxParseResult.getPatternType(), left.getFieldName() ) ?
return isAccessiblePropertiesIncludingNonGetterValueMethod( drlxParseResult.getPatternType(), left.getFieldName() ) ?
context.getPackageModel().getDomainClassName( drlxParseResult.getPatternType() ) + ".getPropertyIndex(\"" + left.getFieldName() + "\")" :
"-1";
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
import static org.drools.modelcompiler.builder.generator.DslMethodNames.VALUE_OF_CALL;
import static org.drools.modelcompiler.builder.generator.DslMethodNames.REACT_ON_CALL;
import static org.drools.modelcompiler.builder.generator.DslMethodNames.createDslTopLevelMethod;
import static org.drools.modelcompiler.util.lambdareplace.ReplaceTypeInLambda.replaceTypeInExprLambda;
import static org.drools.modelcompiler.util.lambdareplace.ReplaceTypeInLambda.replaceTypeInExprLambdaAndIndex;
import static org.drools.mvel.parser.printer.PrintUtil.printNode;

public class AccumulateVisitor {
Expand Down Expand Up @@ -355,7 +355,7 @@ public Optional<NewBinding> onSuccess(DrlxParseSuccess result) {

context.addDeclarationReplacing(new DeclarationSpec(bindingId, accumulateFunctionResultType));

context.getExpressions().forEach(expression -> replaceTypeInExprLambda(bindingId, accumulateFunctionResultType, expression));
context.getExpressions().forEach(expression -> replaceTypeInExprLambdaAndIndex(bindingId, accumulateFunctionResultType, expression));

List<String> ids = new ArrayList<>();
if (singleResult.getPatternBinding() != null) {
Expand Down Expand Up @@ -408,7 +408,7 @@ public Optional<NewBinding> onSuccess(DrlxParseSuccess drlxParseResult) {
singleResult.setExprBinding(bindExpressionVariable);

context.addDeclarationReplacing(new DeclarationSpec(singleResult.getPatternBinding(), exprRawClass));
context.getExpressions().forEach(expression -> replaceTypeInExprLambda(bindingId, exprRawClass, expression));
context.getExpressions().forEach(expression -> replaceTypeInExprLambdaAndIndex(bindingId, exprRawClass, expression));

functionDSL.addArgument(createAccSupplierExpr(accumulateFunction));
final MethodCallExpr newBindingFromBinary = AccumulateVisitor.this.buildBinding(bindExpressionVariable, singleResult.getUsedDeclarations(), singleResult.getExpr());
Expand All @@ -430,7 +430,7 @@ private void zeroParameterFunction(PatternDescr basePattern, MethodCallExpr func
functionDSL.addArgument(createAccSupplierExpr(accumulateFunction));
Class accumulateFunctionResultType = accumulateFunction.getResultType();
context.addDeclarationReplacing(new DeclarationSpec(bindingId, accumulateFunctionResultType));
context.getExpressions().forEach(expression -> replaceTypeInExprLambda(bindingId, accumulateFunctionResultType, expression));
context.getExpressions().forEach(expression -> replaceTypeInExprLambdaAndIndex(bindingId, accumulateFunctionResultType, expression));
}

private static MethodReferenceExpr createAccSupplierExpr(AccumulateFunction accumulateFunction) {
Expand Down Expand Up @@ -475,7 +475,7 @@ private void addBindingAsDeclaration(RuleContext context, String bindingId, Accu
context.addDeclarationReplacing(new DeclarationSpec(bindingId, accumulateFunctionResultType));
if (context.getExpressions().size() > 1) {
// replace the type of the lambda with the one resulting from the accumulate operation only in the pattern immediately before it
replaceTypeInExprLambda(bindingId, accumulateFunctionResultType, context.getExpressions().get(context.getExpressions().size()-2));
replaceTypeInExprLambdaAndIndex(bindingId, accumulateFunctionResultType, context.getExpressions().get(context.getExpressions().size()-2));
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,31 @@ public class ClassUtil {

private static final Map<Class<?>, List<String>> ACCESSIBLE_PROPS_CACHE = Collections.synchronizedMap( new WeakHashMap<>() );

private static final Map<Class<?>, List<String>> ACCESSIBLE_PROPS_CACHE_INCLUDING_NON_GETTER = Collections.synchronizedMap( new WeakHashMap<>() );

public static String asJavaSourceName( Class<?> clazz ) {
return clazz.getCanonicalName().replace( '.', '_' );
}

public static boolean isAccessibleProperties( Class<?> clazz, String prop ) {
return getAccessibleProperties( clazz ).contains( prop );
public static Class<?> javaSourceNameToClass(String javaSourceName) throws ClassNotFoundException {
String fqcn = javaSourceName.replace('_', '.');
return Class.forName(fqcn);
}

public static List<String> getAccessibleProperties( Class<?> clazz ) {
return ACCESSIBLE_PROPS_CACHE.computeIfAbsent( clazz, org.drools.core.util.ClassUtils::getAccessibleProperties );
}

public static boolean isAccessiblePropertiesIncludingNonGetterValueMethod( Class<?> clazz, String prop ) {
return getAccessiblePropertiesIncludingNonGetterValueMethod( clazz ).contains( prop );
}

// ACCESSIBLE_PROPS_CACHE_INCLUDING_NON_GETTER must contain the same order of props in ClassUtils.getAccessibleProperties() first. Then NON_GETTER methods are listed at the end.
// So index and property reactivity can share the same ACCESSIBLE_PROPS_CACHE_INCLUDING_NON_GETTER in DamainClassMetadata.getPropertyIndex()
public static List<String> getAccessiblePropertiesIncludingNonGetterValueMethod( Class<?> clazz ) {
return ACCESSIBLE_PROPS_CACHE_INCLUDING_NON_GETTER.computeIfAbsent( clazz, org.drools.core.util.ClassUtils::getAccessiblePropertiesIncludingNonGetterValueMethod );
}

public static Type boxTypePrimitive(Type type) {
if (type instanceof Class<?>) {
return MethodUtils.boxPrimitive((Class<?>)type);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,42 @@
import com.github.javaparser.ast.Node;
import com.github.javaparser.ast.body.Parameter;
import com.github.javaparser.ast.expr.Expression;
import com.github.javaparser.ast.expr.FieldAccessExpr;
import com.github.javaparser.ast.expr.LambdaExpr;
import com.github.javaparser.ast.expr.MethodCallExpr;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.drools.modelcompiler.builder.PackageModel.DOMAIN_CLASS_METADATA_INSTANCE;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toClassOrInterfaceType;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toVar;
import static org.drools.modelcompiler.builder.generator.DslMethodNames.ACCUMULATE_CALL;
import static org.drools.modelcompiler.builder.generator.DslMethodNames.ALPHA_INDEXED_BY_CALL;
import static org.drools.modelcompiler.builder.generator.DslMethodNames.BETA_INDEXED_BY_CALL;
import static org.drools.modelcompiler.builder.generator.DslMethodNames.BIND_CALL;
import static org.drools.modelcompiler.builder.generator.DslMethodNames.EVAL_EXPR_CALL;
import static org.drools.modelcompiler.builder.generator.DslMethodNames.EXPR_CALL;
import static org.drools.modelcompiler.builder.generator.DslMethodNames.PATTERN_CALL;
import static org.drools.modelcompiler.util.ClassUtil.asJavaSourceName;
import static org.drools.modelcompiler.util.ClassUtil.javaSourceNameToClass;

public class ReplaceTypeInLambda {

private static final Logger logger = LoggerFactory.getLogger(ReplaceTypeInLambda.class);

private ReplaceTypeInLambda() {

}

public static void replaceTypeInExprLambda(String bindingId, Class accumulateFunctionResultType, Expression expression) {
public static void replaceTypeInExprLambdaAndIndex(String bindingId, Class accumulateFunctionResultType, Expression expression) {
if (expression instanceof MethodCallExpr && (( MethodCallExpr ) expression).getNameAsString().equals( ACCUMULATE_CALL )) {
return;
}
replaceTypeInExprLambda(bindingId, accumulateFunctionResultType, expression);
replaceTypeInIndex(bindingId, accumulateFunctionResultType, expression);
}

private static void replaceTypeInExprLambda(String bindingId, Class accumulateFunctionResultType, Expression expression) {
expression.findAll(MethodCallExpr.class).forEach(mc -> {
if (mc.getArguments().stream().anyMatch(a -> a.toString().equals(toVar(bindingId)))) {
List<LambdaExpr> allLambdas = new ArrayList<>();
Expand Down Expand Up @@ -82,4 +96,46 @@ private static void replaceLambdaParameter(Class accumulateFunctionResultType, L
}
}
}

private static void replaceTypeInIndex(String bindingId, Class accumulateFunctionResultType, Expression expression) {
expression.findAll(MethodCallExpr.class)
.stream()
.filter(mce -> {
String methodName = mce.getName().asString();
return (methodName.equals(ALPHA_INDEXED_BY_CALL) || methodName.equals(BETA_INDEXED_BY_CALL));
})
.forEach(mce -> {
mce.getArguments()
.stream()
.filter(MethodCallExpr.class::isInstance)
.map(MethodCallExpr.class::cast)
.filter(argMce -> argMce.getName().asString().equals("getPropertyIndex"))
.map(MethodCallExpr::getScope)
.filter(Optional::isPresent)
.map(Optional::get)
.filter(FieldAccessExpr.class::isInstance)
.map(FieldAccessExpr.class::cast)
.forEach(fieldAccessExpr -> {
Class<?> domainClass = extractDomainClass(fieldAccessExpr.getName().asString());
if (domainClass != null && domainClass != accumulateFunctionResultType && domainClass.isAssignableFrom(accumulateFunctionResultType)) {
// e.g. from java_lang_Number_Metadata_INSTANCE to java_lang_Long_Metadata_INSTANCE
fieldAccessExpr.setName(asJavaSourceName(accumulateFunctionResultType) + DOMAIN_CLASS_METADATA_INSTANCE);
}
});
});
}

private static Class<?> extractDomainClass(String domainClassInstance) {
if (!domainClassInstance.endsWith(DOMAIN_CLASS_METADATA_INSTANCE)) {
return null;
}
String javaSourceName = domainClassInstance.substring(0, domainClassInstance.lastIndexOf(DOMAIN_CLASS_METADATA_INSTANCE));
try {
return javaSourceNameToClass(javaSourceName);
} catch (ClassNotFoundException e) {
logger.info("Class not found. Not an issue unless the generated code causes a compile error : domainClassInstance = {} ", domainClassInstance);
return null;
}
}

}
Loading

0 comments on commit 68c4afa

Please sign in to comment.