Skip to content

Commit

Permalink
[DROOLS-727] support variable unification in accumulate patterns
Browse files Browse the repository at this point in the history
  • Loading branch information
sotty authored and mariofusco committed Feb 25, 2015
1 parent 9b0eef0 commit a2888da
Show file tree
Hide file tree
Showing 16 changed files with 305 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ public Object end(final String uri,
} else if ( localName.equals( "external-function" ) ) {
accumulate.addFunction( element.getAttribute( "evaluator" ),
null, // no support to bindings yet?
false,
new String[] { element.getAttribute( "expression" ) });
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3793,6 +3793,7 @@ private void fromAccumulate( PatternDescrBuilder< ? > pattern ) throws Recogniti
} else {
// accumulate functions
accumulateFunction( accumulate,
false,
null );
if ( state.failed ) return;
}
Expand All @@ -3818,8 +3819,16 @@ private void fromAccumulate( PatternDescrBuilder< ? > pattern ) throws Recogniti
* @throws RecognitionException
*/
private void accumulateFunctionBinding( AccumulateDescrBuilder< ? > accumulate ) throws RecognitionException {
String label = label( DroolsEditorType.IDENTIFIER_VARIABLE );
String label = null;
boolean unif = false;
if (input.LA(2) == DRL6Lexer.COLON) {
label = label(DroolsEditorType.IDENTIFIER_VARIABLE);
} else if (input.LA(2) == DRL6Lexer.UNIFY) {
label = unif(DroolsEditorType.IDENTIFIER_VARIABLE);
unif = true;
}
accumulateFunction( accumulate,
unif,
label );
}

Expand All @@ -3829,6 +3838,7 @@ private void accumulateFunctionBinding( AccumulateDescrBuilder< ? > accumulate )
* @throws RecognitionException
*/
private void accumulateFunction( AccumulateDescrBuilder< ? > accumulate,
boolean unif,
String label ) throws RecognitionException {
Token function = match( input,
DRL5Lexer.ID,
Expand All @@ -3843,6 +3853,7 @@ private void accumulateFunction( AccumulateDescrBuilder< ? > accumulate,
if ( state.backtracking == 0 ) {
accumulate.function( function.getText(),
label,
unif,
parameters.toArray( new String[parameters.size()] ) );
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4122,7 +4122,8 @@ private void fromAccumulate(PatternDescrBuilder<?> pattern) throws RecognitionEx
} else {
// accumulate functions
accumulateFunction(accumulate,
null);
false,
null);
if (state.failed)
return;
}
Expand All @@ -4148,10 +4149,18 @@ private void fromAccumulate(PatternDescrBuilder<?> pattern) throws RecognitionEx
* @param accumulate
* @throws org.antlr.runtime.RecognitionException
*/
private void accumulateFunctionBinding(AccumulateDescrBuilder<?> accumulate) throws RecognitionException {
String label = label(DroolsEditorType.IDENTIFIER_VARIABLE);
accumulateFunction(accumulate,
label);
private void accumulateFunctionBinding( AccumulateDescrBuilder<?> accumulate ) throws RecognitionException {
String label = null;
boolean unif = false;
if (input.LA(2) == DRL6Lexer.COLON) {
label = label(DroolsEditorType.IDENTIFIER_VARIABLE);
} else if (input.LA(2) == DRL6Lexer.UNIFY) {
label = unif(DroolsEditorType.IDENTIFIER_VARIABLE);
unif = true;
}
accumulateFunction( accumulate,
unif,
label );
}

/**
Expand All @@ -4160,7 +4169,8 @@ private void accumulateFunctionBinding(AccumulateDescrBuilder<?> accumulate) thr
* @throws org.antlr.runtime.RecognitionException
*/
private void accumulateFunction(AccumulateDescrBuilder<?> accumulate,
String label) throws RecognitionException {
boolean unif,
String label) throws RecognitionException {
Token function = match(input,
DRL6Lexer.ID,
null,
Expand All @@ -4176,6 +4186,7 @@ private void accumulateFunction(AccumulateDescrBuilder<?> accumulate,
if (state.backtracking == 0) {
accumulate.function(function.getText(),
label,
unif,
parameters.toArray(new String[parameters.size()]));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4140,7 +4140,8 @@ private void fromAccumulate(PatternDescrBuilder<?> pattern) throws RecognitionEx
} else {
// accumulate functions
accumulateFunction(accumulate,
null);
false,
null);
if (state.failed)
return;
}
Expand All @@ -4167,9 +4168,17 @@ private void fromAccumulate(PatternDescrBuilder<?> pattern) throws RecognitionEx
* @throws org.antlr.runtime.RecognitionException
*/
private void accumulateFunctionBinding(AccumulateDescrBuilder<?> accumulate) throws RecognitionException {
String label = label(DroolsEditorType.IDENTIFIER_VARIABLE);
accumulateFunction(accumulate,
label);
String label = null;
boolean unif = false;
if (input.LA(2) == DRL6Lexer.COLON) {
label = label(DroolsEditorType.IDENTIFIER_VARIABLE);
} else if (input.LA(2) == DRL6Lexer.UNIFY) {
label = unif(DroolsEditorType.IDENTIFIER_VARIABLE);
unif = true;
}
accumulateFunction( accumulate,
unif,
label );
}

/**
Expand All @@ -4178,7 +4187,8 @@ private void accumulateFunctionBinding(AccumulateDescrBuilder<?> accumulate) thr
* @throws org.antlr.runtime.RecognitionException
*/
private void accumulateFunction(AccumulateDescrBuilder<?> accumulate,
String label) throws RecognitionException {
boolean unif,
String label) throws RecognitionException {
Token function = match(input,
DRL6Lexer.ID,
null,
Expand All @@ -4194,6 +4204,7 @@ private void accumulateFunction(AccumulateDescrBuilder<?> accumulate,
if (state.backtracking == 0) {
accumulate.function(function.getText(),
label,
unif,
parameters.toArray(new String[parameters.size()]));
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,12 @@ public interface AccumulateDescrBuilder<P extends DescrBuilder< ?, ? >>
*
* @param name the name of the function being called. Mandatory non-null parameter.
* @param bind the name of the bound variable if there is one. Null if no binding should be made.
* @param parameters the array of parameters to the function.
*
* @return itself, so that it can be used as a fluent API
* @param isUnification true if the bound variable is expected to unify with the result of the acc function. false otherwise
*@param parameters the array of parameters to the function.
* @return itself, so that it can be used as a fluent API
*/
public AccumulateDescrBuilder<P> function( String name, String bind, String... parameters);
public AccumulateDescrBuilder<P> function( String name, String bind, boolean isUnification, String... parameters );

/**
* For accumulate CEs that use custom code blocks, this call
* sets the content of the init code block. Please node that the
Expand Down Expand Up @@ -125,4 +125,6 @@ public interface AccumulateDescrBuilder<P extends DescrBuilder< ?, ? >>
public AccumulateDescrBuilder<P> result( String expr );

public P end();

public AccumulateDescrBuilder<P> constraint( String constr );
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@
import org.drools.compiler.lang.api.PatternDescrBuilder;
import org.drools.compiler.lang.descr.AccumulateDescr;
import org.drools.compiler.lang.descr.AndDescr;
import org.drools.compiler.lang.descr.BaseDescr;
import org.drools.compiler.lang.descr.ConditionalElementDescr;
import org.drools.compiler.lang.descr.ExprConstraintDescr;
import org.drools.compiler.lang.descr.PatternDescr;

import java.util.List;

/**
* An implementation for the CollectDescrBuilder
Expand Down Expand Up @@ -63,9 +69,11 @@ public CEDescrBuilder<AccumulateDescrBuilder<P>, AndDescr> source() {

public AccumulateDescrBuilder<P> function( String name,
String bind,
boolean unif,
String... parameters ) {
descr.addFunction( name,
bind,
unif,
parameters );
return this;
}
Expand All @@ -89,4 +97,15 @@ public AccumulateDescrBuilder<P> result( String expr ) {
descr.setResultCode( expr );
return this;
}

@Override
public AccumulateDescrBuilder<P> constraint( String constr ) {
if ( parent instanceof PatternDescrBuilder ) {
( (PatternDescrBuilder) parent ).constraint( constr );
} else if ( parent instanceof CEDescrBuilder ) {
List<? extends BaseDescr> args = ((ConditionalElementDescr) parent.getDescr()).getDescrs();
( (PatternDescr) args.get( args.size() - 1 ) ).addConstraint( new ExprConstraintDescr( constr ) );
}
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,12 @@ public List<AccumulateFunctionCallDescr> getFunctions() {
}

public void addFunction( String function,
String bind,
String bind,
boolean unify,
String[] params ) {
addFunction( new AccumulateFunctionCallDescr( function,
bind,
bind,
unify,
params ) );
}

Expand Down Expand Up @@ -227,13 +229,16 @@ public static class AccumulateFunctionCallDescr

private final String function;
private final String bind;
private final boolean unification;
private final String[] params;

public AccumulateFunctionCallDescr(String function,
String bind,
boolean unify,
String[] params) {
this.function = function;
this.bind = bind;
this.unification = unify;
this.params = params;
}

Expand All @@ -249,6 +254,10 @@ public String[] getParams() {
return params;
}

public boolean isUnification() {
return unification;
}

@Override
public int hashCode() {
final int prime = 31;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,18 +248,19 @@ public RuleConditionElement build( RuleBuildContext context,
return rce;
}

boolean duplicateBindings = objectType instanceof ClassObjectType &&
String patternIdentifier = patternDescr.getIdentifier();
boolean duplicateBindings = patternIdentifier != null && objectType instanceof ClassObjectType &&
context.getDeclarationResolver().isDuplicated( context.getRule(),
patternDescr.getIdentifier(),
patternIdentifier,
((ClassObjectType) objectType).getClassName() );

Pattern pattern;
if ( !StringUtils.isEmpty( patternDescr.getIdentifier() ) && !duplicateBindings ) {
if ( !StringUtils.isEmpty( patternIdentifier ) && !duplicateBindings ) {

pattern = new Pattern( context.getNextPatternId(),
0, // offset is 0 by default
objectType,
patternDescr.getIdentifier(),
patternIdentifier,
patternDescr.isInternalFact() );
if ( objectType instanceof ClassObjectType ) {
// make sure PatternExtractor is wired up to correct ClassObjectType and set as a target for rewiring
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,16 @@
import org.drools.core.rule.Accumulate;
import org.drools.core.rule.Declaration;
import org.drools.core.rule.MultiAccumulate;
import org.drools.core.rule.MutableTypeConstraint;
import org.drools.core.rule.Pattern;
import org.drools.core.rule.RuleConditionElement;
import org.drools.core.rule.SingleAccumulate;
import org.drools.core.rule.constraint.MvelConstraint;
import org.drools.core.spi.Accumulator;
import org.drools.core.spi.Constraint;
import org.drools.core.spi.DeclarationScopeResolver;
import org.drools.core.spi.InternalReadAccessor;
import org.drools.core.util.index.IndexUtil;
import org.kie.api.runtime.rule.AccumulateFunction;

import java.util.Arrays;
Expand Down Expand Up @@ -150,7 +154,7 @@ private Accumulate buildExternalFunctionCall( final RuleBuildContext context,
return null;
}

bindReaderToDeclaration(context, accumDescr, pattern, fc, new ArrayElementReader(reader, index, function.getResultType()));
bindReaderToDeclaration(context, accumDescr, pattern, fc, new ArrayElementReader(reader, index, function.getResultType()), function.getResultType(), index);
accumulators[index++] = buildAccumulator(context, accumDescr, declsInScope, declCls, readLocalsFromTuple, sourceDeclArr, requiredDecl, fc, function);
}

Expand All @@ -164,7 +168,7 @@ private Accumulate buildExternalFunctionCall( final RuleBuildContext context,
return null;
}

bindReaderToDeclaration(context, accumDescr, pattern, fc, new SelfReferenceClassFieldReader( function.getResultType(), "this" ));
bindReaderToDeclaration(context, accumDescr, pattern, fc, new SelfReferenceClassFieldReader( function.getResultType(), "this" ), function.getResultType(), -1);
Accumulator accumulator = buildAccumulator(context, accumDescr, declsInScope, declCls, readLocalsFromTuple, sourceDeclArr, requiredDecl, fc, function);

return new SingleAccumulate( source,
Expand All @@ -173,17 +177,35 @@ private Accumulate buildExternalFunctionCall( final RuleBuildContext context,
}
}

private void bindReaderToDeclaration(RuleBuildContext context, AccumulateDescr accumDescr, Pattern pattern, AccumulateFunctionCallDescr fc, InternalReadAccessor readAccessor) {
// if there is a binding, create the binding
private void bindReaderToDeclaration( RuleBuildContext context, AccumulateDescr accumDescr, Pattern pattern, AccumulateFunctionCallDescr fc, InternalReadAccessor readAccessor, Class<?> resultType, int index ) {
if ( fc.getBind() != null ) {
if ( pattern.getDeclaration( fc.getBind() ) != null ) {
context.addError(new DescrBuildError(context.getParentDescr(),
accumDescr,
null,
"Duplicate declaration for variable '" + fc.getBind() + "' in the rule '" + context.getRule().getName() + "'"));
if ( context.getDeclarationResolver().isDuplicated( context.getRule(), fc.getBind(), resultType.getName() ) ) {
if ( ! fc.isUnification() ) {
context.addError( new DescrBuildError( context.getParentDescr(),
accumDescr,
null,
"Duplicate declaration for variable '" + fc.getBind() + "' in the rule '" + context.getRule().getName() + "'" ) );
} else {
Declaration inner = context.getDeclarationResolver().getDeclaration( context.getRule(), fc.getBind() );
Constraint c = new MvelConstraint( Arrays.asList( context.getPkg().getName() ),
index >= 0
? "this[ " + index + " ] == " + fc.getBind()
: "this == " + fc.getBind(),
new Declaration[] { inner },
null,
IndexUtil.ConstraintType.EQUAL,
context.getDeclarationResolver().getDeclaration( context.getRule(), fc.getBind() ),
index >= 0
? new ArrayElementReader( readAccessor, index, resultType )
: readAccessor,
true);
(( MutableTypeConstraint) c).setType( Constraint.ConstraintType.BETA );
pattern.addConstraint( c );
index++;
}
} else {
Declaration declr = pattern.addDeclaration( fc.getBind() );
declr.setReadAccessor(readAccessor);
declr.setReadAccessor( readAccessor );
}
}
}
Expand Down Expand Up @@ -253,7 +275,7 @@ private JavaAccumulatorFunctionExecutor generateFunctionCallCodeTemplate( final
final boolean readLocalsFromTuple ) {
final String className = "accumulateExpression" + context.getNextId();
final Map<String, Object> map = createVariableContext( className,
fc.getParams().length > 0 ? fc.getParams()[0] : "\"\"",
fc.getParams().length > 0 ? fc.getParams()[ 0 ] : "\"\"",
context,
previousDeclarations,
sourceDeclArr,
Expand Down
Loading

0 comments on commit a2888da

Please sign in to comment.