Navigation Menu

Skip to content
This repository has been archived by the owner on Jan 20, 2022. It is now read-only.

Commit

Permalink
PMML : minor upgrades from 4.0 to 4.1
Browse files Browse the repository at this point in the history
  • Loading branch information
sotty committed Sep 26, 2012
1 parent c47d35e commit f7a244e
Show file tree
Hide file tree
Showing 9 changed files with 144 additions and 63 deletions.
41 changes: 32 additions & 9 deletions drools-pmml/src/main/java/org/drools/pmml_4_1/PMML4Wrapper.java
Expand Up @@ -359,20 +359,24 @@ public String mapFunction(String functor, String... args) {
String ans = "(";
if ("+".equals(functor)) {
ans += args[0];
for (int j = 1; j < args.length; j++)
for (int j = 1; j < args.length; j++) {
ans += " + " + args[j];
}
} else if ("-".equals(functor)) {
ans += args[0];
for (int j = 1; j < args.length; j++)
for (int j = 1; j < args.length; j++) {
ans += " - " + args[j];
}
} else if ("*".equals(functor)) {
ans += args[0];
for (int j = 1; j < args.length; j++)
for (int j = 1; j < args.length; j++) {
ans += " * " + args[j];
}
} else if ("/".equals(functor)) {
ans += args[0];
if (ans.length() > 1)
if (ans.length() > 1) {
ans += " / " + args[1];
}
} else if ("identity".equals(functor)) {
ans += args[0];
} else if ("min".equals(functor)) {
Expand All @@ -381,12 +385,25 @@ public String mapFunction(String functor, String... args) {
ans += associativeNaryToBinary("Math.max",0,args);
} else if ("sum".equals(functor)) {
ans += args[0];
for (int j = 1; j < args.length; j++)
for (int j = 1; j < args.length; j++) {
ans += " + " + args[j];
}
} else if ("median".equals(functor)) {
if ( args.length % 2 == 0 ) {
ans += " 0.5 * " + args[ args.length / 2 - 1 ] + " + 0.5 * " + args[ args.length / 2 ] + " ";
} else {
ans += args[ args.length / 2 ];
}
} else if ("product".equals(functor)) {
ans += args[0];
for (int j = 1; j < args.length; j++) {
ans += " * " + args[j];
}
} else if ("avg".equals(functor)) {
ans += "(" + args[0];
for (int j = 1; j < args.length; j++)
for (int j = 1; j < args.length; j++) {
ans += " + " + args[j];
}
ans += ") / " + args.length;
} else if ("log10".equals(functor)) {
ans += "Math.log10(" + args[0] +")";
Expand All @@ -399,7 +416,11 @@ public String mapFunction(String functor, String... args) {
} else if ("exp".equals(functor)) {
ans += "Math.exp(" + args[0] +")";
} else if ("pow".equals(functor)) {
ans += "Math.pow(" + args[0] +","+ args[1] +")";
if ( "0".equals( args[0] ) && "0".equals( args[1] ) ) {
ans += "1";
} else {
ans += "Math.pow(" + args[0] +","+ args[1] +")";
}
} else if ("threshold".equals(functor)) {
ans += args[0] + " > " + args[1] + " ? 1 : 0";
} else if ("floor".equals(functor)) {
Expand Down Expand Up @@ -451,12 +472,14 @@ public String mapFunction(String functor, String... args) {
ans += "( ! " + args[0] + " )";
} else if ("and".equals(functor)) {
ans += args[0];
for (int j = 1; j < args.length; j++)
for (int j = 1; j < args.length; j++) {
ans += " && " + args[j];
}
} else if ("or".equals(functor)) {
ans += args[0];
for (int j = 1; j < args.length; j++)
for (int j = 1; j < args.length; j++) {
ans += " || " + args[j];
}
} else if ("if".equals(functor)) {
ans += args[0] + " ? " + args[1] + " : " + ( args.length > 2 ? args[2] : "null" );
} else {
Expand Down
Expand Up @@ -501,12 +501,16 @@ rule "processDerivedField_fieldRef"
dialect "mvel"
when
$fld : DerivedField( $ref : fieldRef )
FieldRef( this == $ref, $f : field)
FieldRef( this == $ref, $f : field, $miss : mapMissingTo )
TypeOfField( name == $f, $type : dataType )
then
HashMap map = new HashMap(7);
HashMap map = new HashMap( 7 );
map.put( "context", utils.context );
map.put( "name", utils.compactUpperCase( $fld.name ) );
map.put( "origField", utils.compactUpperCase( $f ) );
map.put( "mapsMissing", $miss != null );
map.put( "mapMissingTo", $miss );
map.put( "type", $type );
applyTemplate( "aliasedField.drlt", utils, registry, map, theory );
end

Expand Down Expand Up @@ -1069,6 +1073,9 @@ end






declare MatchContext
father : Apply @key
root : Apply @key
Expand All @@ -1088,7 +1095,7 @@ then
retract( $a );
Apply idApply = new Apply();
idApply.setFunction( "identity" );
idApply.getConstantsAndFieldRevesAndNormContinuouses().add( $a );
idApply.getConstantsAndFieldRevesAndNormContinuouses().add( $a );
modify ( $df ) {
setNormContinuous( null ),
setApply( idApply );
Expand Down Expand Up @@ -1192,6 +1199,13 @@ end










//**********************************************************************************************************
//
// MODELS
Expand Down
Expand Up @@ -29,23 +29,23 @@
*/
}

@declare{'applyRule'}
@declare{ 'applyRule' }
rule "fun_@{name}"
when
@code{ keys = exprFieldList.keySet() }
@foreach{ field : keys }
@{field}( valid == true, missing == false, @{exprFieldList.get(field)} : value
@if{ context != null } , context == @{format("string",context)} @end{} )
@{ field }( valid == true, missing == false, @{ exprFieldList.get( field ) } : value
@if{ context != null } , context == @{ format( "string", context ) } @end{} )
@end{}
then
@{name} x = new @{name}();
x.setValue((@{dataType}) @{funExpr});
x.setMissing(false);
x.setValid(true);
x.setName(@{format("string",name)});
x.setContext(@{context});
insertLogical(x);
@{name} x = new @{ name }();
x.setValue( ( @{ dataType } ) @{ funExpr } );
x.setMissing( false );
x.setValid( true );
x.setName( @{ format( "string", name ) } );
x.setContext( @{ context } );
insertLogical( x );
end
@end{}

@includeNamed{'applyRule'}
@includeNamed{ 'applyRule' }
Expand Up @@ -27,16 +27,16 @@
rule "aliasedField_@{origField}_to_@{name}"
when
$src : @{origField}( $m : missing, $v : valid, $val : value, $ctx : context
@if{ context != null } , context == @{format("string",context)} @end{} )
@if{ context != null } , context == @{ format( "string", context ) } @end{} )
then
System.out.println("Cloning " + $src );
@{name} x = new @{name}();
x.setValue($val);
x.setMissing($m);
x.setValid($v);
x.setName(@{format("string",name)});
x.setContext($ctx);
insertLogical(x);
// System.out.println("Cloning " + $src );
@{ name } x = new @{ name }();
x.setValue( @if{ mapsMissing } $m ? @{ format( type, mapMissingTo ) } : $val @else{} $val @end{} );
x.setMissing( @if{ mapsMissing } false @else{} $m @end{} );
x.setValid( $v );
x.setName( @{ format( "string", name ) } );
x.setContext( $ctx );
insertLogical( x );
end
@end{}

Expand Down
Expand Up @@ -40,15 +40,15 @@ public void testSchemaWithValidValues() throws Exception {
getKSession().fireAllRules();

getKSession().getWorkingMemoryEntryPoint("in_Feat1").insert(2.2);
getKSession().fireAllRules();
getKSession().fireAllRules();
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat1"),
true, false,"Test_MLP",2.2);
refreshKSession();
true, false,"Test_MLP",2.2);
refreshKSession();

getKSession().getWorkingMemoryEntryPoint("in_Feat2").insert(5);
getKSession().fireAllRules();
getKSession().fireAllRules();
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat2"),
true, false,"Test_MLP",5);
true, false,"Test_MLP",5);
}


Expand All @@ -60,15 +60,15 @@ public void testSchemaWithOutliers() throws Exception {


getKSession().getWorkingMemoryEntryPoint("in_Feat1").insert(0.24);
getKSession().fireAllRules();
getKSession().fireAllRules();
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat1"),
true, false,"Test_MLP",1.0);
refreshKSession();
true, false,"Test_MLP",1.0);
refreshKSession();

getKSession().getWorkingMemoryEntryPoint("in_Feat1").insert(999.9);
getKSession().fireAllRules();
getKSession().fireAllRules();
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat1"),
true, false,"Test_MLP",6.9);
true, false,"Test_MLP",6.9);



Expand All @@ -84,24 +84,24 @@ public void testSchemaWithInvalid() throws Exception {

//invalid as missing
getKSession().getWorkingMemoryEntryPoint("in_Feat1").insert(-37.0);
getKSession().fireAllRules();
getKSession().fireAllRules();
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat1"),
false,false,null,-37.0);
false,false,null,-37.0);
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat1"),
true, false,"Test_MLP",3.95);
refreshKSession();
true, false,"Test_MLP",3.95);
refreshKSession();



getKSession().getWorkingMemoryEntryPoint("in_Feat2").insert(-1);
getKSession().fireAllRules();
getKSession().fireAllRules();

System.err.println(reportWMObjects(getKSession()));
System.err.println(reportWMObjects(getKSession()));

checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat2"),
false,false,null,-1);
false,false,null,-1);
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat2"),
true, false,"Test_MLP",5);
true, false,"Test_MLP",5);

}

Expand All @@ -113,20 +113,43 @@ public void testSchemaWithMissing() throws Exception {
setKbase(getKSession().getKnowledgeBase());


getKSession().getWorkingMemoryEntryPoint("in_Feat2").insert(0);
getKSession().fireAllRules();
getKSession().getWorkingMemoryEntryPoint("in_Feat2").insert(0);
getKSession().fireAllRules();

System.err.println(reportWMObjects(getKSession()));
System.err.println(reportWMObjects(getKSession()));

checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat2"),
false,true,null,0);
false,true,null,0);
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat2"),
true, false,"Test_MLP",5);
true, false,"Test_MLP",5);

}



@Test
public void testSchemaWithMixedIntervalAndValues() throws Exception {
setKSession(getModelSession(source,VERBOSE));
setKbase(getKSession().getKnowledgeBase());

getKSession().fireAllRules();

getKSession().getWorkingMemoryEntryPoint("in_Feat3").insert(4.0);
getKSession().fireAllRules();
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat3"),
true, false,"Test_MLP",4.0);

getKSession().getWorkingMemoryEntryPoint("in_Feat3").insert(7.78);
getKSession().fireAllRules();
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat3"),
true, false,"Test_MLP",7.78);

getKSession().getWorkingMemoryEntryPoint("in_Feat3").insert(6.2);
getKSession().fireAllRules();
checkFirstDataFieldOfTypeStatus(getKbase().getFactType(packageName,"Feat3"),
false, false,"Test_MLP",6.2);

}



Expand Down
Expand Up @@ -58,22 +58,25 @@ public void testKonst() throws Exception {
}


@Test
@Test
public void testAlias() throws Exception {
FactType alias = getKbase().getFactType(packageName, "AliasAge");
assertNotNull(alias);
FactType alias = getKbase().getFactType( packageName, "AliasAge" );
FactType aliasmm = getKbase().getFactType( packageName, "AliasAgeMM" );
assertNotNull( alias );
assertNotNull( aliasmm );

getKSession().getWorkingMemoryEntryPoint("in_Age").insert(33);
getKSession().getWorkingMemoryEntryPoint( "in_Age" ).insert( 33 );
getKSession().fireAllRules();

checkFirstDataFieldOfTypeStatus(alias,true,false, null,33);
checkFirstDataFieldOfTypeStatus( alias, true, false, null, 33 );

refreshKSession();

getKSession().getWorkingMemoryEntryPoint("in_Age").insert(-1);
getKSession().getWorkingMemoryEntryPoint( "in_Age" ).insert( -1 );
getKSession().fireAllRules();

checkFirstDataFieldOfTypeStatus(alias,true,true, null,-1);
checkFirstDataFieldOfTypeStatus( alias, true, true, null, -1 );
checkFirstDataFieldOfTypeStatus( aliasmm, true, false, null, 99 );

}

Expand Down
Expand Up @@ -22,6 +22,12 @@
import org.junit.Before;
import org.junit.Test;

import javax.print.attribute.standard.MediaName;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.junit.Assert.assertEquals;

public class FunctionsTest extends DroolsAbstractPMMLTest {
Expand Down Expand Up @@ -64,14 +70,18 @@ public void testFunctionMapping() {
assertEquals("(Math.min(2,Math.min(3,4)))" , ctx.mapFunction("min","2","3","4"));
assertEquals("(Math.max(2,Math.max(3,4)))" , ctx.mapFunction("max","2","3","4"));
assertEquals("(2 + 3 + 4)" , ctx.mapFunction("sum","2","3","4"));
assertEquals("(2 * 3 * 4)" , ctx.mapFunction("product","2","3","4"));
assertEquals("((2 + 3 + 4) / 3)" , ctx.mapFunction("avg","2","3","4"));
assertEquals("(3)" , ctx.mapFunction("median","1","2","3","4","5"));
assertEquals("( 0.5 * 3 + 0.5 * 4 )" , ctx.mapFunction("median","1","2","3","4","5","6"));

assertEquals("(Math.log10(2))" , ctx.mapFunction("log10","2"));
assertEquals("(Math.log(2))" , ctx.mapFunction("ln","2"));
assertEquals("(Math.sqrt(2))" , ctx.mapFunction("sqrt","2"));
assertEquals("(Math.abs(2))" , ctx.mapFunction("abs","2"));
assertEquals("(Math.exp(2))" , ctx.mapFunction("exp","2"));
assertEquals("(Math.pow(2,3))" , ctx.mapFunction("pow","2","3"));
assertEquals("(1)" , ctx.mapFunction("pow","0","0"));
assertEquals("(2 > 3 ? 1 : 0)" , ctx.mapFunction("threshold","2","3"));
assertEquals("(Math.floor(2))" , ctx.mapFunction("floor","2"));
assertEquals("(Math.ceil(2))" , ctx.mapFunction("ceil","2"));
Expand Down Expand Up @@ -109,4 +119,5 @@ public void testFunctionMapping() {




}

0 comments on commit f7a244e

Please sign in to comment.