Skip to content

Commit

Permalink
implement forall
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Dec 14, 2017
1 parent b09cfc5 commit 831d81c
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 43 deletions.
Expand Up @@ -15,7 +15,7 @@ default List<Condition> getSubConditions() {

enum Type {
PATTERN( false ), QUERY( false ), ACCUMULATE( false ), TEMPORAL( false ), OOPATH( false ),
OR( true ), AND( true ), NOT( false ), EXISTS( false ), CONSEQUENCE( false );
OR( true ), AND( true ), NOT( false ), EXISTS( false ), FORALL( false ), CONSEQUENCE( false );

private final boolean composite;

Expand Down
Expand Up @@ -200,6 +200,10 @@ public static <T, U> ExprViewItem<T> exists(Variable<T> var1, Variable<U> var2,
return exists(new Expr2ViewItemImpl<T, U>( var1, var2, predicate) );
}

public static ExprViewItem forall(ExprViewItem expression, ExprViewItem... expressions) {
return new ExistentialExprViewItem( Condition.Type.FORALL, and( expression, expressions) );
}

public static <T> ExprViewItem<T> accumulate(ExprViewItem<T> expr, AccumulateFunction<T, ?, ?>... functions) {
return new AccumulateExprViewItem(expr, functions);
}
Expand Down
Expand Up @@ -261,6 +261,16 @@ private RuleConditionElement conditionToElement( RuleContext ctx, Condition cond
ge.addChild( conditionToElement( ctx, condition.getSubConditions().get(0) ) );
return ge;
}
case FORALL: {
Condition innerCondition = condition.getSubConditions().get(0);
Pattern basePattern = (Pattern) conditionToElement( ctx, innerCondition.getSubConditions().get(0) );
List<Pattern> remainingPatterns = new ArrayList<>();
for (int i = 1; i < innerCondition.getSubConditions().size(); i++) {
remainingPatterns.add( (Pattern) conditionToElement( ctx, innerCondition.getSubConditions().get(i) ) );
}
Forall forall = new Forall(basePattern, remainingPatterns);
return forall;
}
case CONSEQUENCE:
if (condition instanceof NamedConsequenceImpl) {
NamedConsequenceImpl consequence = (NamedConsequenceImpl) condition;
Expand Down
Expand Up @@ -16,11 +16,9 @@

package org.drools.modelcompiler.builder.generator;

import org.drools.compiler.compiler.DrlExprParser;
import org.drools.compiler.lang.descr.*;
import org.drools.core.definitions.InternalKnowledgePackage;
import org.drools.core.rule.Behavior;
import org.drools.core.rule.Pattern;
import org.drools.core.time.TimeUtils;
import org.drools.core.util.ClassUtils;
import org.drools.core.util.index.IndexUtil;
Expand All @@ -42,21 +40,20 @@
import org.drools.javaparser.ast.type.Type;
import org.drools.javaparser.ast.type.TypeParameter;
import org.drools.javaparser.ast.type.UnknownType;
import org.drools.model.*;
import org.drools.model.BitMask;
import org.drools.model.Query;
import org.drools.model.Rule;
import org.drools.model.Variable;
import org.drools.modelcompiler.builder.PackageModel;
import org.drools.modelcompiler.builder.RuleDescrImpl;
import org.kie.internal.builder.conf.LanguageLevelOption;

import java.util.*;
import java.util.Map.Entry;
import java.util.function.Consumer;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import static org.drools.javaparser.printer.PrintUtil.toDrlx;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.generateLambdaWithoutParameters;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.parseBlock;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toVar;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.*;
import static org.drools.modelcompiler.builder.generator.StringUtil.toId;

public class ModelGenerator {
Expand Down Expand Up @@ -360,15 +357,17 @@ private static void visit(RuleContext context, PackageModel packageModel, BaseDe
if ( descr instanceof AndDescr) {
visit(context, packageModel, ( (AndDescr) descr ));
} else if ( descr instanceof OrDescr) {
visit( context, packageModel, ( (OrDescr) descr ));
visit( context, packageModel, ( (OrDescr) descr ), "or");
} else if ( descr instanceof PatternDescr && ((PatternDescr)descr).getSource() instanceof AccumulateDescr) {
visit( context, packageModel, ( (AccumulateDescr)((PatternDescr) descr).getSource() ));
} else if ( descr instanceof PatternDescr ) {
visit( context, packageModel, ( (PatternDescr) descr ));
} else if ( descr instanceof NotDescr) {
visit( context, packageModel, ( (NotDescr) descr ));
visit( context, packageModel, ( (NotDescr) descr ), "not");
} else if ( descr instanceof ExistsDescr) {
visit( context, packageModel, ( (ExistsDescr) descr ));
visit( context, packageModel, ( (ExistsDescr) descr ), "exists");
} else if ( descr instanceof ForallDescr) {
visit( context, packageModel, ( (ForallDescr) descr ), "forall");
} else if ( descr instanceof QueryDescr) {
visit( context, packageModel, ( (QueryDescr) descr ));
} else if ( descr instanceof NamedConsequenceDescr) {
Expand All @@ -395,6 +394,16 @@ private static void visit(RuleContext context, PackageModel packageModel, QueryD
visit(context, packageModel, descr.getLhs());
}

private static void visit( RuleContext context, PackageModel packageModel, ConditionalElementDescr descr, String methodName ) {
final MethodCallExpr ceDSL = new MethodCallExpr(null, methodName);
context.addExpression(ceDSL);
context.pushExprPointer( ceDSL::addArgument );
for (BaseDescr subDescr : descr.getDescrs()) {
visit(context, packageModel, subDescr );
}
context.popExprPointer();
}

private static void visit( RuleContext context, PackageModel packageModel, AccumulateDescr descr ) {
final MethodCallExpr accumulateDSL = new MethodCallExpr(null, "accumulate");
context.addExpression(accumulateDSL);
Expand Down Expand Up @@ -450,27 +459,6 @@ private static Class<?> getReturnTypeForAggregateFunction(String functionName, C
}
}

private static void visit( RuleContext context, PackageModel packageModel, NotDescr descr ) {
final MethodCallExpr notDSL = new MethodCallExpr(null, "not");
context.addExpression(notDSL);
context.pushExprPointer( notDSL::addArgument );
for (BaseDescr subDescr : descr.getDescrs()) {
visit(context, packageModel, subDescr );
}
context.popExprPointer();
}

private static void visit( RuleContext context, PackageModel packageModel, ExistsDescr descr ) {
final MethodCallExpr existsDSL = new MethodCallExpr(null, "exists");
context.addExpression(existsDSL);
context.pushExprPointer( existsDSL::addArgument );
for (Object subDescr : descr.getDescrs()) {
if(subDescr instanceof BaseDescr)
visit(context, packageModel, (BaseDescr)subDescr );
}
context.popExprPointer();
}

private static void visit(RuleContext context, PackageModel packageModel, AndDescr descr) {
// if it's the first (implied) `and` wrapping the first level of patterns, skip adding it to the DSL.
if ( context.getExprPointerLevel() != 1 ) {
Expand All @@ -487,16 +475,6 @@ private static void visit(RuleContext context, PackageModel packageModel, AndDes
}
}

private static void visit( RuleContext context, PackageModel packageModel, OrDescr descr ) {
final MethodCallExpr orDSL = new MethodCallExpr(null, "or");
context.addExpression(orDSL);
context.pushExprPointer( orDSL::addArgument );
for (BaseDescr subDescr : descr.getDescrs()) {
visit(context, packageModel, subDescr );
}
context.popExprPointer();
}

private static void visit(RuleContext context, PackageModel packageModel, PatternDescr pattern ) {
String className = pattern.getObjectType();

Expand Down
Expand Up @@ -501,6 +501,30 @@ public void testExists() {
assertEquals( "ok", results.iterator().next().getValue() );
}

@Test
public void testForall() {
String str =
"import " + Person.class.getCanonicalName() + ";" +
"import " + Result.class.getCanonicalName() + ";" +
"rule R when\n" +
" forall( $p : Person( name.length == 5 ) " +
" Person( this == $p, age > 40 ) )\n" +
"then\n" +
" insert(new Result(\"ok\"));\n" +
"end";

KieSession ksession = getKieSession( str );

ksession.insert( new Person( "Mario", 41 ) );
ksession.insert( new Person( "Mark", 39 ) );
ksession.insert( new Person( "Edson", 42 ) );
ksession.fireAllRules();

Collection<Result> results = getObjects( ksession, Result.class );
assertEquals( 1, results.size() );
assertEquals( "ok", results.iterator().next().getValue() );
}

@Test
public void testExistsEmptyPredicate() {
String str =
Expand Down
Expand Up @@ -218,6 +218,34 @@ public void testNot() {
assertEquals("Oldest person is Mario", result.getValue());
}

@Test
public void testForall() {
Variable<Person> p1V = declarationOf( type( Person.class ) );
Variable<Person> p2V = declarationOf( type( Person.class ) );

Rule rule = rule("not")
.build(
forall( expr( "exprA", p1V, p -> p.getName().length() == 5 ),
expr( "exprB", p2V, p1V, (p2, p1) -> p2 == p1 ),
expr( "exprC", p2V, p -> p.getAge() > 40 ) ),
execute(drools -> drools.insert( new Result("ok") ))
);

Model model = new ModelImpl().addRule( rule );
KieBase kieBase = KieBaseBuilder.createKieBaseFromModel( model );

KieSession ksession = kieBase.newKieSession();

ksession.insert( new Person( "Mario", 41 ) );
ksession.insert( new Person( "Mark", 39 ) );
ksession.insert( new Person( "Edson", 42 ) );
ksession.fireAllRules();

Collection<Result> results = getObjects( ksession, Result.class );
assertEquals( 1, results.size() );
assertEquals( "ok", results.iterator().next().getValue() );
}

@Test
public void testAccumulate1() {
Result result = new Result();
Expand Down

0 comments on commit 831d81c

Please sign in to comment.