Skip to content

Commit

Permalink
[DROOLS-1026] Allow FromNodes sharing
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Mar 15, 2016
1 parent 1b3c16b commit b9b0618
Show file tree
Hide file tree
Showing 27 changed files with 448 additions and 208 deletions.
Expand Up @@ -96,29 +96,28 @@ protected StatelessKnowledgeSession createStatelessKnowledgeSession(KnowledgeBas
} }


protected KnowledgeBase loadKnowledgeBaseFromString(String... drlContentStrings) { protected KnowledgeBase loadKnowledgeBaseFromString(String... drlContentStrings) {
return loadKnowledgeBaseFromString(null, null, phreak, null, return loadKnowledgeBaseFromString(null, null, phreak, drlContentStrings);
drlContentStrings);
} }


protected KnowledgeBase loadKnowledgeBaseFromString(NodeFactory nodeFactory, String... drlContentStrings) { protected KnowledgeBase loadKnowledgeBaseFromString(NodeFactory nodeFactory, String... drlContentStrings) {
return loadKnowledgeBaseFromString(null, null, phreak, nodeFactory, return loadKnowledgeBaseFromString(null, null, phreak, nodeFactory, drlContentStrings);
drlContentStrings);
} }


protected KnowledgeBase loadKnowledgeBaseFromString(RuleEngineOption phreak, String... drlContentStrings) { protected KnowledgeBase loadKnowledgeBaseFromString(RuleEngineOption phreak, String... drlContentStrings) {
return loadKnowledgeBaseFromString(null, null, phreak, null, return loadKnowledgeBaseFromString(null, null, phreak, drlContentStrings);
drlContentStrings);
} }


protected KnowledgeBase loadKnowledgeBaseFromString(KnowledgeBuilderConfiguration config, String... drlContentStrings) { protected KnowledgeBase loadKnowledgeBaseFromString(KnowledgeBuilderConfiguration config, String... drlContentStrings) {
return loadKnowledgeBaseFromString(config, null, phreak, null, return loadKnowledgeBaseFromString(config, null, phreak, drlContentStrings);
drlContentStrings);
} }


protected KnowledgeBase loadKnowledgeBaseFromString( protected KnowledgeBase loadKnowledgeBaseFromString(
KieBaseConfiguration kBaseConfig, String... drlContentStrings) { KieBaseConfiguration kBaseConfig, String... drlContentStrings) {
return loadKnowledgeBaseFromString(null, kBaseConfig, phreak, null, return loadKnowledgeBaseFromString(null, kBaseConfig, phreak, drlContentStrings);
drlContentStrings); }

protected KnowledgeBase loadKnowledgeBaseFromString( KnowledgeBuilderConfiguration config, KieBaseConfiguration kBaseConfig, RuleEngineOption phreak, String... drlContentStrings) {
return loadKnowledgeBaseFromString( config, kBaseConfig, phreak, (NodeFactory)null, drlContentStrings);
} }


protected KnowledgeBase loadKnowledgeBaseFromString( KnowledgeBuilderConfiguration config, KieBaseConfiguration kBaseConfig, RuleEngineOption phreak, NodeFactory nodeFactory, String... drlContentStrings) { protected KnowledgeBase loadKnowledgeBaseFromString( KnowledgeBuilderConfiguration config, KieBaseConfiguration kBaseConfig, RuleEngineOption phreak, NodeFactory nodeFactory, String... drlContentStrings) {
Expand Down
Expand Up @@ -2983,13 +2983,11 @@ private KieSession getKieSessionFromResources( String... classPathResources ) {
} }


private KieBase loadKieBaseFromString( String... drlContentStrings ) { private KieBase loadKieBaseFromString( String... drlContentStrings ) {
return loadKnowledgeBaseFromString( null, null, phreak, null, return loadKnowledgeBaseFromString( null, null, phreak, drlContentStrings );
drlContentStrings );
} }


private KieSession getKieSessionFromContentStrings( String... drlContentStrings ) { private KieSession getKieSessionFromContentStrings( String... drlContentStrings ) {
KieBase kbase = loadKnowledgeBaseFromString( null, null, phreak, null, KieBase kbase = loadKnowledgeBaseFromString( null, null, phreak, drlContentStrings );
drlContentStrings );
return kbase.newKieSession(); return kbase.newKieSession();
} }


Expand Down
Expand Up @@ -110,7 +110,7 @@ public void testModifyWithLiaToFrom() {
str += "global java.util.List list \n"; str += "global java.util.List list \n";
str += "rule x1 \n"; str += "rule x1 \n";
str += "when \n"; str += "when \n";
str += " $pe : Person() from list\n"; str += " $pe : Person() from list\n";
str += "then \n"; str += "then \n";
str += "end \n"; str += "end \n";
str += "rule x2 \n"; str += "rule x2 \n";
Expand All @@ -120,7 +120,7 @@ public void testModifyWithLiaToFrom() {
str += "end \n"; str += "end \n";
str += "rule x3 \n"; str += "rule x3 \n";
str += "when \n"; str += "when \n";
str += " $ch : Cheese() from list\n"; str += " $ch : Cheese() from list\n";
str += "then \n"; str += "then \n";
str += "end \n"; str += "end \n";
str += "rule x4 \n"; str += "rule x4 \n";
Expand All @@ -141,11 +141,10 @@ public void testModifyWithLiaToFrom() {


LeftTupleSink[] sinks = liaNode.getSinkPropagator().getSinks(); LeftTupleSink[] sinks = liaNode.getSinkPropagator().getSinks();


assertEquals(2, sinks.length );
assertEquals(0, sinks[0].getLeftInputOtnId().getId() ); assertEquals(0, sinks[0].getLeftInputOtnId().getId() );
assertEquals(1, sinks[1].getLeftInputOtnId().getId() ); assertEquals(1, sinks[1].getLeftInputOtnId().getId() );
assertEquals(2, sinks[2].getLeftInputOtnId().getId() ); }
assertEquals(3, sinks[3].getLeftInputOtnId().getId() );
}


@Test @Test
public void testModifyWithLiaToAcc() { public void testModifyWithLiaToAcc() {
Expand Down
@@ -0,0 +1,169 @@
/*
* Copyright 2016 Red Hat, Inc. and/or its affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.drools.compiler.integrationtests;

import org.drools.compiler.Cheese;
import org.drools.compiler.Cheesery;
import org.drools.core.base.ClassObjectType;
import org.drools.core.impl.InternalKnowledgeBase;
import org.drools.core.reteoo.EntryPointNode;
import org.drools.core.reteoo.LeftInputAdapterNode;
import org.drools.core.reteoo.LeftTupleSink;
import org.drools.core.reteoo.ObjectTypeNode;
import org.junit.Test;
import org.kie.api.KieBase;
import org.kie.api.io.ResourceType;
import org.kie.api.runtime.KieSession;
import org.kie.api.runtime.rule.FactHandle;
import org.kie.internal.utils.KieHelper;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import static org.junit.Assert.assertEquals;

public class FromTest {

public static class ListsContainer {
public List<String> getList1() {
return Arrays.asList( "a", "bb", "ccc" );
}
public List<String> getList2() {
return Arrays.asList( "1", "22", "333" );
}
}

@Test
public void testFromSharing() {
String drl =
"import " + ListsContainer.class.getCanonicalName() + "\n" +
"global java.util.List output1;\n" +
"global java.util.List output2;\n" +
"rule R1 when\n" +
" ListsContainer( $list : list1 )\n" +
" $s : String( length == 2 ) from $list\n" +
"then\n" +
" output1.add($s);\n" +
"end\n" +
"rule R2 when\n" +
" ListsContainer( $list : list2 )\n" +
" $s : String( length == 2 ) from $list\n" +
"then\n" +
" output2.add($s);\n" +
"end\n" +
"rule R3 when\n" +
" ListsContainer( $list : list2 )\n" +
" $s : String( length == 2 ) from $list\n" +
"then\n" +
" output2.add($s);\n" +
"end\n";

KieBase kbase = new KieHelper().addContent( drl, ResourceType.DRL ).build();
KieSession ksession = kbase.newKieSession();

List<String> output1 = new ArrayList<String>();
ksession.setGlobal( "output1", output1 );
List<String> output2 = new ArrayList<String>();
ksession.setGlobal( "output2", output2 );

FactHandle fh = ksession.insert( new ListsContainer() );
ksession.fireAllRules();

assertEquals("bb", output1.get( 0 ));
assertEquals("22", output2.get( 0 ));
assertEquals("22", output2.get( 1 ));

EntryPointNode epn = ( (InternalKnowledgeBase)kbase ).getRete().getEntryPointNodes().values().iterator().next();
ObjectTypeNode otn = epn.getObjectTypeNodes().get( new ClassObjectType( ListsContainer.class ) );
LeftInputAdapterNode lian = (LeftInputAdapterNode)otn.getObjectSinkPropagator().getSinks()[0];

// There are only 2 FromNodes since R2 and R3 are sharing the second From
LeftTupleSink[] sinks = lian.getSinkPropagator().getSinks();
assertEquals( 2, sinks.length );

// The first from has R1 has sink
assertEquals( 1, sinks[0].getSinkPropagator().size() );

// The second from has both R2 and R3 as sinks
assertEquals( 2, sinks[1].getSinkPropagator().size() );
}

@Test
public void testFromSharingWithAccumulate() {
String drl =
"package org.drools.compiler\n" +
"\n" +
"import java.util.List;\n" +
"import java.util.ArrayList;\n" +
"\n" +
"global java.util.List output1;\n" +
"global java.util.List output2;\n" +
"\n" +
"rule R1\n" +
" when\n" +
" $cheesery : Cheesery()\n" +
" $list : List( ) from accumulate( $cheese : Cheese( ) from $cheesery.getCheeses(),\n" +
" init( List l = new ArrayList(); ),\n" +
" action( l.add( $cheese ); )\n" +
" result( l ) )\n" +
" then\n" +
" output1.add( $list );\n" +
"end\n" +
"rule R2\n" +
" when\n" +
" $cheesery : Cheesery()\n" +
" $list : List( ) from accumulate( $cheese : Cheese( ) from $cheesery.getCheeses(),\n" +
" init( List l = new ArrayList(); ),\n" +
" action( l.add( $cheese ); )\n" +
" result( l ) )\n" +
" then\n" +
" output2.add( $list );\n" +
"end\n";

KieBase kbase = new KieHelper().addContent( drl, ResourceType.DRL ).build();
KieSession ksession = kbase.newKieSession();

List<?> output1 = new ArrayList<Object>();
ksession.setGlobal( "output1", output1 );
List<?> output2 = new ArrayList<Object>();
ksession.setGlobal( "output2", output2 );

Cheesery cheesery = new Cheesery();
cheesery.addCheese( new Cheese( "stilton", 8 ) );
cheesery.addCheese( new Cheese( "provolone", 8 ) );

FactHandle cheeseryHandle = ksession.insert( cheesery );

ksession.fireAllRules();
assertEquals( 1, output1.size() );
assertEquals( 2, ( (List) output1.get( 0 ) ).size() );
assertEquals( 1, output2.size() );
assertEquals( 2, ( (List) output2.get( 0 ) ).size() );

output1.clear();
output2.clear();

ksession.update( cheeseryHandle, cheesery );
ksession.fireAllRules();

assertEquals( 1, output1.size() );
assertEquals( 2, ( (List) output1.get( 0 ) ).size() );
assertEquals( 1, output2.size() );
assertEquals( 2, ( (List) output2.get( 0 ) ).size() );
}
}
Expand Up @@ -48,13 +48,14 @@
import org.kie.api.runtime.rule.ViewChangedEventListener; import org.kie.api.runtime.rule.ViewChangedEventListener;
import org.kie.internal.KnowledgeBase; import org.kie.internal.KnowledgeBase;
import org.kie.internal.runtime.StatefulKnowledgeSession; import org.kie.internal.runtime.StatefulKnowledgeSession;
import static org.drools.compiler.integrationtests.incrementalcompilation.IncrementalCompilationTest.rulestoMap;


import java.util.ArrayList; import java.util.ArrayList;
import java.util.HashMap; import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;


import static org.drools.core.util.DroolsTestUtil.rulestoMap;



public class IndexingTest extends CommonTestMethodBase { public class IndexingTest extends CommonTestMethodBase {


Expand Down
Expand Up @@ -5,18 +5,24 @@
import org.drools.compiler.Person; import org.drools.compiler.Person;
import org.drools.core.base.ClassObjectType; import org.drools.core.base.ClassObjectType;
import org.drools.core.impl.KnowledgeBaseImpl; import org.drools.core.impl.KnowledgeBaseImpl;
import org.drools.core.reteoo.*; import org.drools.core.reteoo.AlphaNode;
import org.drools.core.reteoo.JoinNode;
import org.drools.core.reteoo.LeftInputAdapterNode;
import org.drools.core.reteoo.MethodCountingAlphaNode;
import org.drools.core.reteoo.MethodCountingLeftInputAdapterNode;
import org.drools.core.reteoo.MethodCountingObjectTypeNode;
import org.drools.core.reteoo.ObjectTypeNode;
import org.drools.core.reteoo.RuleTerminalNode;
import org.drools.core.reteoo.builder.MethodCountingNodeFactory; import org.drools.core.reteoo.builder.MethodCountingNodeFactory;
import org.drools.core.reteoo.builder.NodeFactory; import org.drools.core.reteoo.builder.NodeFactory;
import org.junit.Test; import org.junit.Test;
import org.kie.api.definition.rule.Rule; import org.kie.api.definition.rule.Rule;
import org.kie.internal.KnowledgeBase; import org.kie.internal.KnowledgeBase;


import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;


import static org.drools.compiler.integrationtests.IncrementalCompilationTest.rulestoMap; import static org.drools.core.util.DroolsTestUtil.rulestoMap;


public class SharingTest extends CommonTestMethodBase { public class SharingTest extends CommonTestMethodBase {


Expand All @@ -30,19 +36,20 @@ public class SharingTest extends CommonTestMethodBase {
private JoinNode joinNode; private JoinNode joinNode;





public void setupKnowledgeBase() throws Exception { public void setupKnowledgeBase() throws Exception {
NodeFactory nodeFactory = MethodCountingNodeFactory.getInstance(); NodeFactory nodeFactory = MethodCountingNodeFactory.getInstance();
kbase = loadKnowledgeBaseFromString( nodeFactory, getRules()); kbase = loadKnowledgeBaseFromString( nodeFactory, getRules());
rules = rulestoMap(kbase); rules = rulestoMap(kbase);
otn = getObjectTypeNode(kbase, Person.class ); otn = getObjectTypeNode(kbase, Person.class );
alphaNode_1 = ( AlphaNode ) otn.getSinkPropagator().getSinks()[0]; //AlphaNode name == "Mark" alphaNode_1 = ( AlphaNode ) otn.getObjectSinkPropagator().getSinks()[0]; //AlphaNode name == "Mark"
alphaNode_2 = (AlphaNode) alphaNode_1.getSinkPropagator().getSinks()[0]; // 2nd level (age = 50) alphaNode_2 = (AlphaNode) alphaNode_1.getObjectSinkPropagator().getSinks()[0]; // 2nd level (age = 50)


lian_1 = (LeftInputAdapterNode) alphaNode_1.getSinkPropagator().getSinks()[1]; lian_1 = (LeftInputAdapterNode) alphaNode_1.getObjectSinkPropagator().getSinks()[1];
lian_2 = (LeftInputAdapterNode) alphaNode_2.getSinkPropagator().getSinks()[0]; lian_2 = (LeftInputAdapterNode) alphaNode_2.getObjectSinkPropagator().getSinks()[0];


AlphaNode an = ( AlphaNode ) otn.getSinkPropagator().getSinks()[1]; // name = "John" AlphaNode an = ( AlphaNode ) otn.getObjectSinkPropagator().getSinks()[1]; // name = "John"
LeftInputAdapterNode lian =(LeftInputAdapterNode) an.getSinkPropagator().getSinks()[1]; LeftInputAdapterNode lian =(LeftInputAdapterNode) an.getObjectSinkPropagator().getSinks()[1];
joinNode = (JoinNode) lian.getSinkPropagator().getSinks()[0]; //this == $personCheese joinNode = (JoinNode) lian.getSinkPropagator().getSinks()[0]; //this == $personCheese
} }


Expand Down Expand Up @@ -101,19 +108,18 @@ private String getRules() {
return drl; return drl;
} }


//@Test(timeout=10000) @Test
@Test()
public void testOTNSharing() throws Exception { public void testOTNSharing() throws Exception {
setupKnowledgeBase(); setupKnowledgeBase();
assertEquals( 2, otn.getSinkPropagator().size() ); assertEquals( 2, otn.getObjectSinkPropagator().size() );
assertEquals(7, otn.getAssociationsSize()); assertEquals(7, otn.getAssociationsSize());
} }


@Test @Test
public void testAlphaNodeSharing() throws Exception { public void testAlphaNodeSharing() throws Exception {
setupKnowledgeBase(); setupKnowledgeBase();


assertEquals( 2, alphaNode_1.getSinkPropagator().size()); assertEquals( 2, alphaNode_1.getObjectSinkPropagator().size());
assertEquals( 4, alphaNode_1.getAssociationsSize() ); assertEquals( 4, alphaNode_1.getAssociationsSize() );
assertTrue( alphaNode_1.isAssociatedWith(rules.get("r1"))); assertTrue( alphaNode_1.isAssociatedWith(rules.get("r1")));
assertTrue( alphaNode_1.isAssociatedWith(rules.get("r2"))); assertTrue( alphaNode_1.isAssociatedWith(rules.get("r2")));
Expand All @@ -124,7 +130,7 @@ public void testAlphaNodeSharing() throws Exception {
assertEquals(6,countingMap.get("thisNodeEquals").intValue()); assertEquals(6,countingMap.get("thisNodeEquals").intValue());


//Check 2nd level of sharing (age = 50) //Check 2nd level of sharing (age = 50)
assertEquals( 1, alphaNode_2.getSinkPropagator().size() ); assertEquals( 1, alphaNode_2.getObjectSinkPropagator().size() );
assertEquals( 2, alphaNode_2.getAssociationsSize() ); assertEquals( 2, alphaNode_2.getAssociationsSize() );
assertTrue( alphaNode_2.isAssociatedWith(rules.get("r3"))); assertTrue( alphaNode_2.isAssociatedWith(rules.get("r3")));
assertTrue( alphaNode_2.isAssociatedWith(rules.get("r4"))); assertTrue( alphaNode_2.isAssociatedWith(rules.get("r4")));
Expand Down
Expand Up @@ -47,7 +47,7 @@
import java.util.Map; import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicInteger;


import static org.drools.compiler.integrationtests.incrementalcompilation.IncrementalCompilationTest.rulestoMap; import static org.drools.core.util.DroolsTestUtil.rulestoMap;
import static org.junit.Assert.*; import static org.junit.Assert.*;


public class AddRemoveRulesTest extends AbstractAddRemoveRulesTest { public class AddRemoveRulesTest extends AbstractAddRemoveRulesTest {
Expand Down

0 comments on commit b9b0618

Please sign in to comment.