Browse files

PMML : Fix regressions

  • Loading branch information...
1 parent 7b2f607 commit 56885cf4495e8f4bc30dfec87d506f75b7be3796 @sotty sotty committed Oct 9, 2013
View
6 drools-pmml/src/main/java/org/drools/pmml/pmml_4_1/PMML4Compiler.java
@@ -534,7 +534,7 @@ public String compile(InputStream source, Map<String,PackageRegistry> registries
}
}
- if ( getResults().isEmpty() ) {
+ if ( visitorBuildResults.isEmpty() && results.isEmpty() ) {
return generateTheory( pmml );
} else {
return null;
@@ -548,6 +548,10 @@ public String compile(InputStream source, Map<String,PackageRegistry> registries
}
+ public void clearResults() {
+ this.results.clear();
+ }
+
public void dump( String s, OutputStream ostream ) {
// write to outstream
Writer writer = null;
View
4 drools-pmml/src/main/java/org/drools/pmml_4_0/PMML4Compiler.java
@@ -27,4 +27,8 @@ public String compile( InputStream inputStream, Map<String, PackageRegistry> str
return compiler.getResults();
}
+ public void clearResults() {
+ compiler.clearResults();
+ }
+
}
View
3 ...-pmml/src/main/resources/org/drools/pmml/pmml_4_1/templates/models/neural/neuralFire.drlt
@@ -54,6 +54,9 @@ when
@if{ needsNormal } normalized == true, @end{}
$index : index,
$val : value != null )
+ accumulate( $c : Charge( context == @{ ctx }, index == $neur.index, $in : value ),
+ $num : count( $c );
+ $num == $neur.fanIn )
then
Stym y = new Stym();
y.setContext( @{ctx} );
View
38 drools-pmml/src/test/java/org/drools/pmml/pmml_4_1/predictive/models/NeuralNetworkTest.java
@@ -30,11 +30,17 @@
import org.drools.io.impl.ClassPathResource;
import org.drools.pmml.pmml_4_1.DroolsAbstractPMMLTest;
import org.drools.pmml.pmml_4_1.ModelMarker;
+import org.drools.pmml.pmml_4_1.PMML4Compiler;
import org.drools.runtime.StatefulKnowledgeSession;
+import org.drools.runtime.rule.FactHandle;
+import org.drools.runtime.rule.QueryResults;
import org.drools.runtime.rule.Variable;
import org.junit.After;
import org.junit.Test;
+import java.util.Collection;
+import java.util.Iterator;
+
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@@ -162,6 +168,38 @@ public void testCold() throws Exception {
Assert.assertEquals( 0.44, queryDoubleField( "Cold", "MockCold" ), 1e-6 );
}
+ @Test
+ public void testClearOutput() throws Exception {
+ setKSession( getModelSession( source7, VERBOSE ) );
+ setKbase( getKSession().getKnowledgeBase() );
+
+ getKSession().fireAllRules(); //init model
+
+ getKSession().getWorkingMemoryEntryPoint( "in_Temp" ).insert( 28.0 );
+
+ getKSession().fireAllRules();
+
+ Assert.assertEquals( 0.44, queryDoubleField( "Cold", "MockCold" ), 1e-6 );
+
+ for ( Object o : getKSession().getObjects() ) {
+ System.out.println( o );
+ }
+
+ FactType tempKlass = getKSession().getKnowledgeBase().getFactType( "org.drools.pmml.pmml_4_1.test", "Temp" );
+ Collection temps = getKSession().getObjects( new ClassObjectFilter( tempKlass.getFactClass() ) );
+ Iterator iter = temps.iterator();
+ Object temp = iter.next();
+
+ if ( tempKlass.get( temp, "value" ) != null ) {
+ temp = iter.next();
+ }
+ getKSession().retract( getKSession().getFactHandle( temp ) );
+ getKSession().fireAllRules();
+
+ QueryResults results = getKSession().getQueryResults( "Cold", "MockCold", Variable.v );
+ assertEquals( 0, results.size() );
+ }
+
@Test
public void testPTSD() throws Exception {

0 comments on commit 56885cf

Please sign in to comment.