Skip to content

Commit

Permalink
[DROOLS-1175][DROOLS-1242] infer numeric type for sum expression in a…
Browse files Browse the repository at this point in the history
…n accumulate pattern + make accumulate functions null safe (#867)
  • Loading branch information
mariofusco committed Aug 9, 2016
1 parent 1615518 commit 42d7be9
Show file tree
Hide file tree
Showing 40 changed files with 786 additions and 280 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,6 @@ public KnowledgeBuilderConfigurationImpl(ClassLoader... classLoaders) {

/**
* Programmatic properties file, added with lease precedence
* @param properties
*/
public KnowledgeBuilderConfigurationImpl(Properties properties) {
init(properties,
Expand All @@ -170,8 +169,6 @@ public KnowledgeBuilderConfigurationImpl(Properties properties) {

/**
* Programmatic properties file, added with lease precedence
* @param classLoaders
* @param properties
*/
public KnowledgeBuilderConfigurationImpl(Properties properties,
ClassLoader... classLoaders) {
Expand Down Expand Up @@ -537,21 +534,6 @@ private void buildAccumulateFunctionsMap() {
}
}

/**
* This method is deprecated and will be removed
* @return
*
* @deprecated
*/
public Map<String, String> getAccumulateFunctionsMap() {
Map<String, String> result = new HashMap<String, String>();
for (Map.Entry<String, AccumulateFunction> entry : this.accumulateFunctions.entrySet()) {
result.put(entry.getKey(),
entry.getValue().getClass().getName());
}
return result;
}

public void addAccumulateFunction(String identifier,
String className) {
this.accumulateFunctions.put(identifier,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
*/
public class MemoryResourceReader implements ResourceReader {

private Map resources;
private Map<String, byte[]> resources;

private Set<String> modifiedResourcesSinceLastMark;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.drools.compiler.rule.builder.RuleBuildContext;
import org.drools.compiler.rule.builder.RuleConditionBuilder;
import org.drools.compiler.rule.builder.dialect.java.parser.JavaLocalDeclarationDescr;
import org.drools.compiler.rule.builder.dialect.mvel.MVELExprAnalyzer;
import org.drools.compiler.rule.builder.util.PackageBuilderUtil;
import org.drools.core.base.accumulators.JavaAccumulatorFunctionExecutor;
import org.drools.core.base.extractors.ArrayElementReader;
Expand All @@ -47,6 +48,8 @@
import org.drools.core.util.index.IndexUtil;
import org.kie.api.runtime.rule.AccumulateFunction;

import java.math.BigDecimal;
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
Expand Down Expand Up @@ -124,12 +127,12 @@ public RuleConditionElement build( final RuleBuildContext context,
return accumulate;
}

private Accumulate buildExternalFunctionCall( final RuleBuildContext context,
final AccumulateDescr accumDescr,
final RuleConditionElement source,
private Accumulate buildExternalFunctionCall( RuleBuildContext context,
AccumulateDescr accumDescr,
RuleConditionElement source,
Map<String, Declaration> declsInScope,
Map<String, Class< ? >> declCls,
final boolean readLocalsFromTuple) {
boolean readLocalsFromTuple) {
// list of functions to build
final List<AccumulateFunctionCallDescr> funcCalls = accumDescr.getFunctions();
// list of available source declarations
Expand All @@ -150,7 +153,7 @@ private Accumulate buildExternalFunctionCall( final RuleBuildContext context,

int index = 0;
for ( AccumulateFunctionCallDescr fc : funcCalls ) {
AccumulateFunction function = getAccumulateFunction(context, accumDescr, fc);
AccumulateFunction function = getAccumulateFunction(context, accumDescr, fc, source, declCls);
if (function == null) {
return null;
}
Expand All @@ -164,7 +167,7 @@ private Accumulate buildExternalFunctionCall( final RuleBuildContext context,
accumulators );
} else {
AccumulateFunctionCallDescr fc = accumDescr.getFunctions().get(0);
AccumulateFunction function = getAccumulateFunction(context, accumDescr, fc);
AccumulateFunction function = getAccumulateFunction(context, accumDescr, fc, source, declCls);
if (function == null) {
return null;
}
Expand Down Expand Up @@ -221,22 +224,50 @@ private void bindReaderToDeclaration( RuleBuildContext context, AccumulateDescr
}
}

private AccumulateFunction getAccumulateFunction(RuleBuildContext context, AccumulateDescr accumDescr, AccumulateFunctionCallDescr fc) {
private AccumulateFunction getAccumulateFunction(RuleBuildContext context,
AccumulateDescr accumDescr,
AccumulateFunctionCallDescr fc,
RuleConditionElement source,
Map<String, Class< ? >> declCls) {
String functionName = getFunctionName( context, fc, source, declCls );

// find the corresponding function
AccumulateFunction function = context.getConfiguration().getAccumulateFunction( fc.getFunction() );
AccumulateFunction function = context.getConfiguration().getAccumulateFunction( functionName );
if( function == null ) {
// might have been imported in the package
function = context.getKnowledgeBuilder().getPackage().getAccumulateFunctions().get(fc.getFunction());
function = context.getKnowledgeBuilder().getPackage().getAccumulateFunctions().get( functionName );
}
if ( function == null ) {
context.addError( new DescrBuildError( accumDescr,
context.getRuleDescr(),
null,
"Unknown accumulate function: '" + fc.getFunction() + "' on rule '" + context.getRuleDescr().getName() + "'. All accumulate functions must be registered before building a resource." ) );
"Unknown accumulate function: '" + functionName + "' on rule '" + context.getRuleDescr().getName() + "'. All accumulate functions must be registered before building a resource." ) );
}
return function;
}

private String getFunctionName( RuleBuildContext context, AccumulateFunctionCallDescr fc, RuleConditionElement source, Map<String, Class<?>> declCls ) {
String functionName = fc.getFunction();
if (functionName.equals( "sum" )) {
Class<?> exprClass = MVELExprAnalyzer.getExpressionType( context, declCls, source, fc.getParams()[0] );
if (exprClass == int.class || exprClass == Integer.class) {
functionName = "sumI";
} else if (exprClass == long.class || exprClass == Long.class) {
functionName = "sumL";
} else if (exprClass == BigInteger.class) {
functionName = "sumBI";
} else if (exprClass == BigDecimal.class) {
functionName = "sumBD";
}
} else if (functionName.equals( "average" )) {
Class<?> exprClass = MVELExprAnalyzer.getExpressionType( context, declCls, source, fc.getParams()[0] );
if (exprClass == BigDecimal.class) {
functionName = "averageBD";
}
}
return functionName;
}

private Accumulator buildAccumulator(RuleBuildContext context, AccumulateDescr accumDescr, Map<String, Declaration> declsInScope, Map<String, Class<?>> declCls, boolean readLocalsFromTuple, Declaration[] sourceDeclArr, Set<Declaration> requiredDecl, AccumulateFunctionCallDescr fc, AccumulateFunction function) {
// analyze the expression
final JavaAnalysisResult analysis = (JavaAnalysisResult) context.getDialect().analyzeBlock( context,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,6 @@ public class MVELDialect
initBuilder();
}

private static final MVELExprAnalyzer analyzer = new MVELExprAnalyzer();

private final Map interceptors = MVELCompilationUnit.INTERCEPTORS;

protected List<KnowledgeBuilderResult> results;
Expand Down Expand Up @@ -509,12 +507,12 @@ public AnalysisResult analyzeExpression(final PackageBuildContext context,
BaseDescr temp = context.getParentDescr();
context.setParentDescr( descr );
try {
result = analyzer.analyzeExpression( context,
(String) content,
availableIdentifiers,
localTypes,
"drools",
KnowledgeHelper.class );
result = MVELExprAnalyzer.analyzeExpression( context,
(String) content,
availableIdentifiers,
localTypes,
"drools",
KnowledgeHelper.class );
} catch ( final Exception e ) {
DialectUtil.copyErrorLocation( e, descr );
context.addError( new DescrBuildError( context.getParentDescr(),
Expand Down Expand Up @@ -547,12 +545,12 @@ public AnalysisResult analyzeBlock(final PackageBuildContext context,
String contextIndeifier,
Class kcontextClass) {

return analyzer.analyzeExpression( context,
text,
availableIdentifiers,
localTypes,
contextIndeifier,
kcontextClass );
return MVELExprAnalyzer.analyzeExpression( context,
text,
availableIdentifiers,
localTypes,
contextIndeifier,
kcontextClass );
}

public MVELCompilationUnit getMVELCompilationUnit(final String expression,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,17 @@
import org.drools.compiler.rule.builder.RuleBuildContext;
import org.drools.compiler.rule.builder.dialect.DialectUtil;
import org.drools.core.base.EvaluatorWrapper;
import org.drools.core.rule.Declaration;
import org.drools.core.rule.MVELDialectRuntimeData;
import org.drools.core.rule.RuleConditionElement;
import org.kie.api.definition.rule.Rule;
import org.mvel2.MVEL;
import org.mvel2.ParserConfiguration;
import org.mvel2.ParserContext;
import org.mvel2.optimizers.OptimizerFactory;
import org.mvel2.util.PropertyTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.Collections;
import java.util.HashMap;
Expand All @@ -45,6 +49,8 @@
*/
public class MVELExprAnalyzer {

private static final Logger log = LoggerFactory.getLogger( MVELExprAnalyzer.class );

static {
// always use mvel reflective optimizer
OptimizerFactory.setDefaultOptimizer(OptimizerFactory.SAFE_REFLECTIVE);
Expand All @@ -71,12 +77,12 @@ public MVELExprAnalyzer() {
* If an error occurs in the parser.
*/
@SuppressWarnings("unchecked")
public MVELAnalysisResult analyzeExpression(final PackageBuildContext context,
final String expr,
final BoundIdentifiers availableIdentifiers,
final Map<String, Class< ? >> localTypes,
String contextIndeifier,
Class kcontextClass) {
public static MVELAnalysisResult analyzeExpression(final PackageBuildContext context,
final String expr,
final BoundIdentifiers availableIdentifiers,
final Map<String, Class< ? >> localTypes,
String contextIndeifier,
Class kcontextClass) {
if ( expr.trim().length() <= 0 ) {
MVELAnalysisResult result = analyze( (Set<String>) Collections.EMPTY_SET, availableIdentifiers );
result.setMvelVariables( new HashMap<String, Class< ? >>() );
Expand Down Expand Up @@ -246,8 +252,8 @@ public MVELAnalysisResult analyzeExpression(final PackageBuildContext context,
* @throws RecognitionException
* If an error occurs in the parser.
*/
private MVELAnalysisResult analyze(final Set<String> identifiers,
final BoundIdentifiers availableIdentifiers) {
private static MVELAnalysisResult analyze(final Set<String> identifiers,
final BoundIdentifiers availableIdentifiers) {

MVELAnalysisResult result = new MVELAnalysisResult();
result.setIdentifiers( identifiers );
Expand Down Expand Up @@ -290,4 +296,28 @@ private MVELAnalysisResult analyze(final Set<String> identifiers,

return result;
}

public static Class<?> getExpressionType(PackageBuildContext context,
Map<String, Class< ? >> declCls,
RuleConditionElement source,
String expression) {
MVELDialectRuntimeData data = ( MVELDialectRuntimeData) context.getPkg().getDialectRuntimeRegistry().getDialectData( "mvel" );
ParserConfiguration conf = data.getParserConfiguration();
conf.setClassLoader( context.getKnowledgeBuilder().getRootClassLoader() );
ParserContext pctx = new ParserContext( conf );
pctx.setStrongTyping(true);
pctx.setStrictTypeEnforcement(true);
for (Map.Entry<String, Class< ? >> entry : declCls.entrySet()) {
pctx.addInput(entry.getKey(), entry.getValue());
}
for (Declaration decl : source.getOuterDeclarations().values()) {
pctx.addInput(decl.getBindingName(), decl.getDeclarationClass());
}
try {
return MVEL.analyze( expression, pctx );
} catch (Exception e) {
log.warn( "Unable to parse expression: " + expression, e );
}
return null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@ drools.dialect.mvel = org.drools.compiler.rule.builder.dialect.mvel.MVELDialectC
drools.dialect.mvel.strict = true
drools.dialect.mvel.langLevel = 4

drools.accumulate.function.average = org.drools.core.base.accumulators.AverageAccumulateFunction
drools.accumulate.function.max = org.drools.core.base.accumulators.MaxAccumulateFunction
drools.accumulate.function.min = org.drools.core.base.accumulators.MinAccumulateFunction
drools.accumulate.function.count = org.drools.core.base.accumulators.CountAccumulateFunction
drools.accumulate.function.sum = org.drools.core.base.accumulators.SumAccumulateFunction
drools.accumulate.function.collectList = org.drools.core.base.accumulators.CollectListAccumulateFunction
drools.accumulate.function.collectSet = org.drools.core.base.accumulators.CollectSetAccumulateFunction
drools.accumulate.function.sumBD = org.drools.core.base.accumulators.BigDecimalSumAccumulateFunction
drools.accumulate.function.average = org.drools.core.base.accumulators.AverageAccumulateFunction
drools.accumulate.function.averageBD = org.drools.core.base.accumulators.BigDecimalAverageAccumulateFunction
drools.accumulate.function.sum = org.drools.core.base.accumulators.SumAccumulateFunction
drools.accumulate.function.sumI = org.drools.core.base.accumulators.IntegerSumAccumulateFunction
drools.accumulate.function.sumL = org.drools.core.base.accumulators.LongSumAccumulateFunction
drools.accumulate.function.sumBI = org.drools.core.base.accumulators.BigIntegerSumAccumulateFunction
drools.accumulate.function.sumBD = org.drools.core.base.accumulators.BigDecimalSumAccumulateFunction

drools.evaluator.coincides = org.drools.core.base.evaluators.CoincidesEvaluatorDefinition
drools.evaluator.before = org.drools.core.base.evaluators.BeforeEvaluatorDefinition
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -624,17 +624,17 @@ public void testAccnSharingWithMixedDormantAndActive() {
String str = "package org.kie.test \n" +
"\n" +
"rule rule1 @Propagation(EAGER) when\n" +
" $s1 : Double() from accumulate( $i : Integer(), sum ( $i ) ) " +
" $s1 : Integer() from accumulate( $i : Integer(), sum ( $i ) ) " +
"then\n" +
"end\n" +
"rule rule2 @Propagation(EAGER) when\n" +
" $s1 : Double() from accumulate( $i : Integer(), sum ( $i ) ) " +
" $s1 : Integer() from accumulate( $i : Integer(), sum ( $i ) ) " +
" eval( 1 == 1 ) \n" +
"then\n" +
"end\n" +
"rule rule3 salience 10 when\n" +
" eval( 1 == 1 ) \n" +
" $s1 : Double() from accumulate( $i : Integer(), sum ( $i ) ) " +
" $s1 : Integer() from accumulate( $i : Integer(), sum ( $i ) ) " +
" eval( 1 == 1 ) \n" +
"then\n" +
" kcontext.getKieRuntime().halt();\n" +
Expand All @@ -658,7 +658,7 @@ public void testAccnSharingWithMixedDormantAndActive() {
list.add( act.getRule().getName() + ":" + act.getDeclarationValue( "$s1" ) + ":" + act.isQueued() );
}

assertContains( new String[]{"rule1:6.0:true", "rule2:6.0:true", "rule3:6.0:false"},
assertContains( new String[]{"rule1:6:true", "rule2:6:true", "rule3:6:false"},
list );
}

Expand Down
Loading

0 comments on commit 42d7be9

Please sign in to comment.