From 68c4afa383e6e83e07428855bbcd2b0590161056 Mon Sep 17 00:00:00 2001 From: Toshiya Kobayashi Date: Thu, 21 Apr 2022 17:21:59 +0900 Subject: [PATCH] [DROOLS-6892] Wrapper class unwrapped value is not indexed in exec-model (#4318) * [DROOLS-6892] Wrapper class unwrapped value is not indexed in exec-model * - remove syso --- .../java/org/drools/core/util/ClassUtils.java | 56 ++++++++++++-- .../modelcompiler/builder/PackageModel.java | 4 +- .../expression/AbstractExpressionBuilder.java | 4 +- .../visitor/accumulate/AccumulateVisitor.java | 10 +-- .../drools/modelcompiler/util/ClassUtil.java | 17 ++++- .../lambdareplace/ReplaceTypeInLambda.java | 58 +++++++++++++- .../org/drools/modelcompiler/IndexTest.java | 76 +++++++++++++++++++ .../drools/modelcompiler/domain/Person.java | 4 + 8 files changed, 209 insertions(+), 20 deletions(-) diff --git a/drools-core/src/main/java/org/drools/core/util/ClassUtils.java b/drools-core/src/main/java/org/drools/core/util/ClassUtils.java index 2dc794dc9aa..63e3580b8a7 100644 --- a/drools-core/src/main/java/org/drools/core/util/ClassUtils.java +++ b/drools-core/src/main/java/org/drools/core/util/ClassUtils.java @@ -425,6 +425,16 @@ public static Class findClass(String className, ClassLoader cl) { return null; } + // Used for exec-model DomainClassMetadata and index + public static List getAccessiblePropertiesIncludingNonGetterValueMethod(Class clazz) { + List accessibleProperties = getAccessibleProperties(clazz); + + // Add nonGetterValueMethods at last so property reactivity mask index isn't affected + accessibleProperties.addAll(getNonGetterValueMethods(clazz, accessibleProperties)); + return accessibleProperties; + } + + // Used for property reactivity public static List getAccessibleProperties( Class clazz ) { Set props = new TreeSet<>(); for (Method m : clazz.getMethods()) { @@ -445,12 +455,42 @@ public static List getAccessibleProperties( Class clazz ) { } List accessibleProperties = new ArrayList<>(); - for ( PropertyInClass setter : props ) { - accessibleProperties.add(setter.setter); + for ( PropertyInClass propInClass : props ) { + accessibleProperties.add(propInClass.prop); } return accessibleProperties; } + public static List getNonGetterValueMethods(Class clazz, List accessibleProperties) { + Set nonGetterValueMethodInClassSet = new TreeSet<>(); + for (Method m : clazz.getMethods()) { + String propName = getter2property(m.getName()); + if (propName == null) { + String methodName = filterNonGetterValueMethod(m); + if (methodName != null) { + nonGetterValueMethodInClassSet.add(new PropertyInClass(methodName, m.getDeclaringClass())); + } + } + } + + List nonGetterValueMethods = new ArrayList<>(); + for (PropertyInClass propInClass : nonGetterValueMethodInClassSet) { + if (!accessibleProperties.contains(propInClass.prop)) { + nonGetterValueMethods.add(propInClass.prop); + } + } + return nonGetterValueMethods; + } + + private static String filterNonGetterValueMethod(Method m) { + String methodName = m.getName(); + if (m.getParameterTypes().length == 0 && !m.getReturnType().equals(void.class) && methodName != "toString" && methodName != "hashCode") { + return m.getName(); // e.g. Person.calcAge(), Integer.intValue() + } else { + return null; + } + } + public static Field getField(Class clazz, String field) { try { return clazz.getDeclaredField( field ); @@ -709,18 +749,18 @@ public static boolean isInterface(Class clazz) { } private static class PropertyInClass implements Comparable { - private final String setter; + private final String prop; private final Class clazz; - private PropertyInClass( String setter, Class clazz ) { - this.setter = setter; + private PropertyInClass( String prop, Class clazz ) { + this.prop = prop; this.clazz = clazz; } public int compareTo(Object o) { PropertyInClass other = (PropertyInClass) o; if (clazz == other.clazz) { - return setter.compareTo(other.setter); + return prop.compareTo(other.prop); } return clazz.isAssignableFrom(other.clazz) ? -1 : 1; } @@ -731,12 +771,12 @@ public boolean equals(Object obj) { return false; } PropertyInClass other = (PropertyInClass) obj; - return clazz == other.clazz && setter.equals(other.setter); + return clazz == other.clazz && prop.equals(other.prop); } @Override public int hashCode() { - return 29 * clazz.hashCode() + 31 * setter.hashCode(); + return 29 * clazz.hashCode() + 31 * prop.hashCode(); } } diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/PackageModel.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/PackageModel.java index 8e651cfff26..123c4221a60 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/PackageModel.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/PackageModel.java @@ -102,7 +102,7 @@ import static org.drools.modelcompiler.builder.generator.DslMethodNames.createDslTopLevelMethod; import static org.drools.modelcompiler.builder.generator.QueryGenerator.QUERY_METHOD_PREFIX; import static org.drools.modelcompiler.util.ClassUtil.asJavaSourceName; -import static org.drools.modelcompiler.util.ClassUtil.getAccessibleProperties; +import static org.drools.modelcompiler.util.ClassUtil.getAccessiblePropertiesIncludingNonGetterValueMethod; public class PackageModel { @@ -873,7 +873,7 @@ public String getDomainClassesMetadataSource() { ); for (Class domainClass : domainClasses) { String domainClassSourceName = asJavaSourceName( domainClass ); - List accessibleProperties = getAccessibleProperties( domainClass ); + List accessibleProperties = getAccessiblePropertiesIncludingNonGetterValueMethod( domainClass ); accessibleProperties = accessibleProperties.stream().distinct().collect(Collectors.toList()); sb.append( " public static final " + DomainClassMetadata.class.getCanonicalName() + " " + domainClassSourceName + DOMAIN_CLASS_METADATA_INSTANCE + " = new " + domainClassSourceName+ "_Metadata();\n" ); sb.append( " private static class " + domainClassSourceName + "_Metadata implements " + DomainClassMetadata.class.getCanonicalName() + " {\n\n" ); diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expression/AbstractExpressionBuilder.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expression/AbstractExpressionBuilder.java index 9d63fc5160c..3f44089a3ab 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expression/AbstractExpressionBuilder.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/expression/AbstractExpressionBuilder.java @@ -66,7 +66,7 @@ import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toJavaParserType; import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toStringLiteral; import static org.drools.modelcompiler.builder.generator.drlxparse.ConstraintParser.toBigDecimalExpression; -import static org.drools.modelcompiler.util.ClassUtil.isAccessibleProperties; +import static org.drools.modelcompiler.util.ClassUtil.isAccessiblePropertiesIncludingNonGetterValueMethod; import static org.drools.modelcompiler.util.ClassUtil.toRawClass; import static org.drools.mvel.parser.printer.PrintUtil.printNode; @@ -316,7 +316,7 @@ protected void addIndexedByDeclaration(TypedExpression left, } String getIndexIdArgument(SingleDrlxParseSuccess drlxParseResult, TypedExpression left) { - return isAccessibleProperties( drlxParseResult.getPatternType(), left.getFieldName() ) ? + return isAccessiblePropertiesIncludingNonGetterValueMethod( drlxParseResult.getPatternType(), left.getFieldName() ) ? context.getPackageModel().getDomainClassName( drlxParseResult.getPatternType() ) + ".getPropertyIndex(\"" + left.getFieldName() + "\")" : "-1"; } diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/visitor/accumulate/AccumulateVisitor.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/visitor/accumulate/AccumulateVisitor.java index 88a19839717..50bb6e59767 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/visitor/accumulate/AccumulateVisitor.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/builder/generator/visitor/accumulate/AccumulateVisitor.java @@ -81,7 +81,7 @@ import static org.drools.modelcompiler.builder.generator.DslMethodNames.VALUE_OF_CALL; import static org.drools.modelcompiler.builder.generator.DslMethodNames.REACT_ON_CALL; import static org.drools.modelcompiler.builder.generator.DslMethodNames.createDslTopLevelMethod; -import static org.drools.modelcompiler.util.lambdareplace.ReplaceTypeInLambda.replaceTypeInExprLambda; +import static org.drools.modelcompiler.util.lambdareplace.ReplaceTypeInLambda.replaceTypeInExprLambdaAndIndex; import static org.drools.mvel.parser.printer.PrintUtil.printNode; public class AccumulateVisitor { @@ -355,7 +355,7 @@ public Optional onSuccess(DrlxParseSuccess result) { context.addDeclarationReplacing(new DeclarationSpec(bindingId, accumulateFunctionResultType)); - context.getExpressions().forEach(expression -> replaceTypeInExprLambda(bindingId, accumulateFunctionResultType, expression)); + context.getExpressions().forEach(expression -> replaceTypeInExprLambdaAndIndex(bindingId, accumulateFunctionResultType, expression)); List ids = new ArrayList<>(); if (singleResult.getPatternBinding() != null) { @@ -408,7 +408,7 @@ public Optional onSuccess(DrlxParseSuccess drlxParseResult) { singleResult.setExprBinding(bindExpressionVariable); context.addDeclarationReplacing(new DeclarationSpec(singleResult.getPatternBinding(), exprRawClass)); - context.getExpressions().forEach(expression -> replaceTypeInExprLambda(bindingId, exprRawClass, expression)); + context.getExpressions().forEach(expression -> replaceTypeInExprLambdaAndIndex(bindingId, exprRawClass, expression)); functionDSL.addArgument(createAccSupplierExpr(accumulateFunction)); final MethodCallExpr newBindingFromBinary = AccumulateVisitor.this.buildBinding(bindExpressionVariable, singleResult.getUsedDeclarations(), singleResult.getExpr()); @@ -430,7 +430,7 @@ private void zeroParameterFunction(PatternDescr basePattern, MethodCallExpr func functionDSL.addArgument(createAccSupplierExpr(accumulateFunction)); Class accumulateFunctionResultType = accumulateFunction.getResultType(); context.addDeclarationReplacing(new DeclarationSpec(bindingId, accumulateFunctionResultType)); - context.getExpressions().forEach(expression -> replaceTypeInExprLambda(bindingId, accumulateFunctionResultType, expression)); + context.getExpressions().forEach(expression -> replaceTypeInExprLambdaAndIndex(bindingId, accumulateFunctionResultType, expression)); } private static MethodReferenceExpr createAccSupplierExpr(AccumulateFunction accumulateFunction) { @@ -475,7 +475,7 @@ private void addBindingAsDeclaration(RuleContext context, String bindingId, Accu context.addDeclarationReplacing(new DeclarationSpec(bindingId, accumulateFunctionResultType)); if (context.getExpressions().size() > 1) { // replace the type of the lambda with the one resulting from the accumulate operation only in the pattern immediately before it - replaceTypeInExprLambda(bindingId, accumulateFunctionResultType, context.getExpressions().get(context.getExpressions().size()-2)); + replaceTypeInExprLambdaAndIndex(bindingId, accumulateFunctionResultType, context.getExpressions().get(context.getExpressions().size()-2)); } } } diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/util/ClassUtil.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/util/ClassUtil.java index d6a7e5fa584..33ff81e3d10 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/util/ClassUtil.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/util/ClassUtil.java @@ -30,18 +30,31 @@ public class ClassUtil { private static final Map, List> ACCESSIBLE_PROPS_CACHE = Collections.synchronizedMap( new WeakHashMap<>() ); + private static final Map, List> ACCESSIBLE_PROPS_CACHE_INCLUDING_NON_GETTER = Collections.synchronizedMap( new WeakHashMap<>() ); + public static String asJavaSourceName( Class clazz ) { return clazz.getCanonicalName().replace( '.', '_' ); } - public static boolean isAccessibleProperties( Class clazz, String prop ) { - return getAccessibleProperties( clazz ).contains( prop ); + public static Class javaSourceNameToClass(String javaSourceName) throws ClassNotFoundException { + String fqcn = javaSourceName.replace('_', '.'); + return Class.forName(fqcn); } public static List getAccessibleProperties( Class clazz ) { return ACCESSIBLE_PROPS_CACHE.computeIfAbsent( clazz, org.drools.core.util.ClassUtils::getAccessibleProperties ); } + public static boolean isAccessiblePropertiesIncludingNonGetterValueMethod( Class clazz, String prop ) { + return getAccessiblePropertiesIncludingNonGetterValueMethod( clazz ).contains( prop ); + } + + // ACCESSIBLE_PROPS_CACHE_INCLUDING_NON_GETTER must contain the same order of props in ClassUtils.getAccessibleProperties() first. Then NON_GETTER methods are listed at the end. + // So index and property reactivity can share the same ACCESSIBLE_PROPS_CACHE_INCLUDING_NON_GETTER in DamainClassMetadata.getPropertyIndex() + public static List getAccessiblePropertiesIncludingNonGetterValueMethod( Class clazz ) { + return ACCESSIBLE_PROPS_CACHE_INCLUDING_NON_GETTER.computeIfAbsent( clazz, org.drools.core.util.ClassUtils::getAccessiblePropertiesIncludingNonGetterValueMethod ); + } + public static Type boxTypePrimitive(Type type) { if (type instanceof Class) { return MethodUtils.boxPrimitive((Class)type); diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/util/lambdareplace/ReplaceTypeInLambda.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/util/lambdareplace/ReplaceTypeInLambda.java index 903b2168729..8e44b9462df 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/util/lambdareplace/ReplaceTypeInLambda.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/util/lambdareplace/ReplaceTypeInLambda.java @@ -24,28 +24,42 @@ import com.github.javaparser.ast.Node; import com.github.javaparser.ast.body.Parameter; import com.github.javaparser.ast.expr.Expression; +import com.github.javaparser.ast.expr.FieldAccessExpr; import com.github.javaparser.ast.expr.LambdaExpr; import com.github.javaparser.ast.expr.MethodCallExpr; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import static org.drools.modelcompiler.builder.PackageModel.DOMAIN_CLASS_METADATA_INSTANCE; import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toClassOrInterfaceType; import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toVar; import static org.drools.modelcompiler.builder.generator.DslMethodNames.ACCUMULATE_CALL; +import static org.drools.modelcompiler.builder.generator.DslMethodNames.ALPHA_INDEXED_BY_CALL; +import static org.drools.modelcompiler.builder.generator.DslMethodNames.BETA_INDEXED_BY_CALL; import static org.drools.modelcompiler.builder.generator.DslMethodNames.BIND_CALL; import static org.drools.modelcompiler.builder.generator.DslMethodNames.EVAL_EXPR_CALL; import static org.drools.modelcompiler.builder.generator.DslMethodNames.EXPR_CALL; import static org.drools.modelcompiler.builder.generator.DslMethodNames.PATTERN_CALL; +import static org.drools.modelcompiler.util.ClassUtil.asJavaSourceName; +import static org.drools.modelcompiler.util.ClassUtil.javaSourceNameToClass; public class ReplaceTypeInLambda { + private static final Logger logger = LoggerFactory.getLogger(ReplaceTypeInLambda.class); + private ReplaceTypeInLambda() { } - public static void replaceTypeInExprLambda(String bindingId, Class accumulateFunctionResultType, Expression expression) { + public static void replaceTypeInExprLambdaAndIndex(String bindingId, Class accumulateFunctionResultType, Expression expression) { if (expression instanceof MethodCallExpr && (( MethodCallExpr ) expression).getNameAsString().equals( ACCUMULATE_CALL )) { return; } + replaceTypeInExprLambda(bindingId, accumulateFunctionResultType, expression); + replaceTypeInIndex(bindingId, accumulateFunctionResultType, expression); + } + private static void replaceTypeInExprLambda(String bindingId, Class accumulateFunctionResultType, Expression expression) { expression.findAll(MethodCallExpr.class).forEach(mc -> { if (mc.getArguments().stream().anyMatch(a -> a.toString().equals(toVar(bindingId)))) { List allLambdas = new ArrayList<>(); @@ -82,4 +96,46 @@ private static void replaceLambdaParameter(Class accumulateFunctionResultType, L } } } + + private static void replaceTypeInIndex(String bindingId, Class accumulateFunctionResultType, Expression expression) { + expression.findAll(MethodCallExpr.class) + .stream() + .filter(mce -> { + String methodName = mce.getName().asString(); + return (methodName.equals(ALPHA_INDEXED_BY_CALL) || methodName.equals(BETA_INDEXED_BY_CALL)); + }) + .forEach(mce -> { + mce.getArguments() + .stream() + .filter(MethodCallExpr.class::isInstance) + .map(MethodCallExpr.class::cast) + .filter(argMce -> argMce.getName().asString().equals("getPropertyIndex")) + .map(MethodCallExpr::getScope) + .filter(Optional::isPresent) + .map(Optional::get) + .filter(FieldAccessExpr.class::isInstance) + .map(FieldAccessExpr.class::cast) + .forEach(fieldAccessExpr -> { + Class domainClass = extractDomainClass(fieldAccessExpr.getName().asString()); + if (domainClass != null && domainClass != accumulateFunctionResultType && domainClass.isAssignableFrom(accumulateFunctionResultType)) { + // e.g. from java_lang_Number_Metadata_INSTANCE to java_lang_Long_Metadata_INSTANCE + fieldAccessExpr.setName(asJavaSourceName(accumulateFunctionResultType) + DOMAIN_CLASS_METADATA_INSTANCE); + } + }); + }); + } + + private static Class extractDomainClass(String domainClassInstance) { + if (!domainClassInstance.endsWith(DOMAIN_CLASS_METADATA_INSTANCE)) { + return null; + } + String javaSourceName = domainClassInstance.substring(0, domainClassInstance.lastIndexOf(DOMAIN_CLASS_METADATA_INSTANCE)); + try { + return javaSourceNameToClass(javaSourceName); + } catch (ClassNotFoundException e) { + logger.info("Class not found. Not an issue unless the generated code causes a compile error : domainClassInstance = {} ", domainClassInstance); + return null; + } + } + } diff --git a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/IndexTest.java b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/IndexTest.java index 05f9ec9da4b..fc4efd634fc 100644 --- a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/IndexTest.java +++ b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/IndexTest.java @@ -22,6 +22,7 @@ import org.drools.core.impl.InternalKnowledgeBase; import org.drools.core.reteoo.AlphaNode; import org.drools.core.reteoo.BetaNode; +import org.drools.core.reteoo.CompositeObjectSinkAdapter; import org.drools.core.reteoo.EntryPointNode; import org.drools.core.reteoo.ObjectSink; import org.drools.core.reteoo.ObjectTypeNode; @@ -368,4 +369,79 @@ public void testBetaIndexOn4ValuesOnLeftTuple() { assertEquals( 1, ksession.fireAllRules() ); } + + @Test + public void testAlphaIndexHashed() { + String str = + "import " + Person.class.getCanonicalName() + ";" + + "rule R1 when\n" + + " Person( age == 10 )\n" + + "then\n" + + "end\n" + + "rule R2 when\n" + + " Person( age == 20 )\n" + + "then\n" + + "end\n" + + "rule R3 when\n" + + " Person( age == 30 )\n" + + "then\n" + + "end\n"; + + KieSession ksession = getKieSession(str); + + assertHashIndex(ksession, Person.class, 3); + } + + @Test + public void testAlphaIndexHashedNonGetter() { + String str = + "import " + Person.class.getCanonicalName() + ";" + + "rule R1 when\n" + + " Person( calcAge == 10 )\n" + + "then\n" + + "end\n" + + "rule R2 when\n" + + " Person( calcAge == 20 )\n" + + "then\n" + + "end\n" + + "rule R3 when\n" + + " Person( calcAge == 30 )\n" + + "then\n" + + "end\n"; + + KieSession ksession = getKieSession(str); + + assertHashIndex(ksession, Person.class, 3); + } + + private void assertHashIndex(KieSession ksession, Class factClass, int expectedHashedSinkMapSize) { + EntryPointNode epn = ((InternalKnowledgeBase) ksession.getKieBase()).getRete().getEntryPointNodes().values().iterator().next(); + ObjectTypeNode otn = epn.getObjectTypeNodes().get(new ClassObjectType(factClass)); + CompositeObjectSinkAdapter compositeObjectSinkAdapter = (CompositeObjectSinkAdapter) otn.getObjectSinkPropagator(); + + assertNotNull(compositeObjectSinkAdapter.getHashedSinkMap()); + assertEquals(expectedHashedSinkMapSize, compositeObjectSinkAdapter.getHashedSinkMap().size()); + } + + @Test + public void testAlphaIndexHashedPrimitiveWrapper() { + String str = + "import " + Integer.class.getCanonicalName() + ";\n" + + "rule R1 when\n" + + " Integer( intValue == 10 )\n" + + "then\n" + + "end\n" + + "rule R2 when\n" + + " Integer( intValue == 20 )\n" + + "then\n" + + "end\n" + + "rule R3 when\n" + + " Integer( intValue == 30 )\n" + + "then\n" + + "end\n"; + + KieSession ksession = getKieSession(str); + + assertHashIndex(ksession, Integer.class, 3); + } } diff --git a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/domain/Person.java b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/domain/Person.java index 529c52d48ae..61e4ed0df4f 100644 --- a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/domain/Person.java +++ b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/domain/Person.java @@ -108,6 +108,10 @@ public int getAge() { return age; } + public int calcAge() { + return age; + } + public Integer getAgeBoxed() { return age; }