Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
drools-scorecards: making imports row optional in the excel
drools-scorecards: initial changes for introducing multiple scoring strategies

drools-scorecards: Changes for multiple scoring strategies (AGGREGATE, AVERAGE, MINIMUM, MAXIMUM)

drools-scorecards: Addition of tests for all scoring strategies

additions to fully implement reasonCodeAlgorithm as per PMML 4.1 Scorecard spec
  • Loading branch information
vinodkiran authored and sotty committed Feb 7, 2014
1 parent 88f4f7f commit afc6f90
Show file tree
Hide file tree
Showing 15 changed files with 734 additions and 143 deletions.
Expand Up @@ -56,35 +56,10 @@ public void setCalculatedScore(double calculatedScore) {
this.calculatedScore = calculatedScore;
}

public void sortReasonCodes() {

}

// public void addPartialScore(int partialScore) {
// this.calculatedScore += partialScore;
// }
//
// public void setInitialScore(int initialScore) {
// this.calculatedScore = initialScore;
// }

public void setInitialScore(double initialScore) {
this.calculatedScore = initialScore;
}

// public void addPartialScore(double partialScore) {
// this.calculatedScore += partialScore;
// }
//
// public void addPartialScore(String field, double partialScore, String reasonCode) {
// this.calculatedScore += partialScore;
// reasonCodes.add(reasonCode);
// }

// public void addReasonCode(String reasonCode){
// reasonCodes.add(reasonCode);
// }
//
public List<String> getReasonCodes() {
return Collections.unmodifiableList(reasonCodes);
}
Expand All @@ -93,26 +68,40 @@ public void setReasonCodes(List<String> reasonCodes) {
this.reasonCodes = reasonCodes;
}

public void sortReasonCodes(List<PartialScore> partialScores) {
public void sortAndSetReasonCodes(List<PartialScore> partialScores) {
sortAndSetReasonCodes(reasonCodeAlgorithm, partialScores);
}

public void sortAndSetReasonCodes(int reasonCodeAlgorithm, List<PartialScore> partialScores) {
setReasonCodeAlgorithm(reasonCodeAlgorithm);
TreeMap<Double, String> distanceMap = new TreeMap<Double, String>();
for (PartialScore partialScore : partialScores ){
if (baselineScoreMap.get(partialScore.getCharacteristic()) != null ) {
double baseline = baselineScoreMap.get(partialScore.getCharacteristic());
double distance = 0;
if (getReasonCodeAlgorithm() == REASON_CODE_ALGORITHM_POINTSABOVE) {
distance = (baseline - partialScore.getScore())+partialScore.getPosition();
double baseline = partialScore.getBaselineScore();
double distance = 0;
if (getReasonCodeAlgorithm() == REASON_CODE_ALGORITHM_POINTSABOVE) {
distance = (baseline - partialScore.getScore())+partialScore.getPosition();
if( distance >= baseline) {
distanceMap.put(distance, partialScore.getReasoncode());
} else if (getReasonCodeAlgorithm() == REASON_CODE_ALGORITHM_POINTSBELOW){
distance = (partialScore.getScore()-baseline)+partialScore.getPosition();
}
} else if (getReasonCodeAlgorithm() == REASON_CODE_ALGORITHM_POINTSBELOW){
distance = (partialScore.getScore()-baseline)+partialScore.getPosition();
if( distance <= baseline) {
distanceMap.put(distance, partialScore.getReasoncode());
}
}
}

List<String> reasonCodes = new ArrayList<String>();
for ( Double distance : distanceMap.descendingKeySet()) {
System.out.println(distance+" "+distanceMap.get(distance));
reasonCodes.add(distanceMap.get(distance));
}
while (reasonCodes.size() < partialScores.size()){
reasonCodes.add(reasonCodes.get(reasonCodes.size()-1));
}
setReasonCodes(reasonCodes);
}


public DroolsScorecard() {
}
}
Expand Up @@ -20,13 +20,21 @@
public class PartialScore extends BaselineScore implements Serializable {
protected String reasoncode;
protected int position;
protected double baselineScore;

public PartialScore(String scorecardName, String characteristic, double score, String reasoncode, int position) {
super(scorecardName, characteristic, score);
this.reasoncode = reasoncode;
this.position = position;
}

public PartialScore(String scorecardName, String characteristic, double score, String reasoncode, double baselineScore, int position) {
super(scorecardName, characteristic, score);
this.reasoncode = reasoncode;
this.position = position;
this.baselineScore = baselineScore;
}

public PartialScore(String scorecardName, String characteristic, double score) {
super(scorecardName, characteristic, score);
this.scorecardName = scorecardName;
Expand All @@ -45,4 +53,16 @@ public String getReasoncode() {
public void setReasoncode(String reasoncode) {
this.reasoncode = reasoncode;
}

public double getBaselineScore() {
return baselineScore;
}

public void setBaselineScore(double baselineScore) {
this.baselineScore = baselineScore;
}

public void setPosition(int position) {
this.position = position;
}
}
@@ -0,0 +1,23 @@
/*
* Copyright 2012 JBoss Inc
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.drools.scorecards;

public enum ScoringStrategy {
AGGREGATE_SCORE, AVERAGE_SCORE, MAXIMUM_SCORE, MINIMUM_SCORE,
WEIGHTED_AGGREGATE_SCORE, WEIGHTED_AVERAGE_SCORE, WEIGHTED_MAXIMUM_SCORE, WEIGHTED_MINIMUM_SCORE

}
Expand Up @@ -17,6 +17,7 @@

import org.dmg.pmml.pmml_4_1.descr.*;
import org.drools.core.util.StringUtils;
import org.drools.scorecards.ScoringStrategy;
import org.drools.scorecards.parser.xls.XLSKeywords;
import org.drools.scorecards.pmml.PMMLExtensionNames;
import org.drools.scorecards.pmml.PMMLOperators;
Expand Down Expand Up @@ -100,11 +101,15 @@ public String emitDRL( PMML pmml ) {
private void addImports( PMML pmml,
Package aPackage ) {
String importsFromDelimitedString = ScorecardPMMLUtils.getExtensionValue( pmml.getHeader().getExtensions(), PMMLExtensionNames.SCORECARD_IMPORTS );
if ( !( importsFromDelimitedString == null || importsFromDelimitedString.isEmpty() ) ) {
for ( String importStatement : importsFromDelimitedString.split( "," ) ) {
if ( StringUtils.isEmpty(importsFromDelimitedString) ) {
Import imp = new Import();
imp.setClassName("java.util.*");
aPackage.addImport(imp);
} else {
for (String importStatement : importsFromDelimitedString.split(",")) {
Import imp = new Import();
imp.setClassName( importStatement );
aPackage.addImport( imp );
imp.setClassName(importStatement);
aPackage.addImport(imp);
}
}
Import defaultScorecardImport = new Import();
Expand Down Expand Up @@ -155,9 +160,9 @@ protected List<Rule> createRuleList( PMML pmmlDocument ) {
if ( desc != null ) {
rule.setDescription( desc );
}
attributePosition++;
populateLHS( rule, pmmlDocument, scorecard, c, scoreAttribute );
populateLHS(rule, pmmlDocument, scorecard, c, scoreAttribute);
populateRHS( rule, pmmlDocument, scorecard, c, scoreAttribute, attributePosition );
attributePosition++;
ruleList.add( rule );
}
}
Expand All @@ -176,7 +181,9 @@ protected void createInitialRule( List<Rule> ruleList,
rule.setDescription( "set the initial score" );

Condition condition = createInitialRuleCondition( scorecard, objectClass );
rule.addCondition( condition );
if ( condition != null) {
rule.addCondition(condition);
}
if ( scorecard.getInitialScore() > 0 ) {
Consequence consequence = new Consequence();
//consequence.setSnippet("$sc.setInitialScore(" + scorecard.getInitialScore() + ");");
Expand All @@ -201,16 +208,16 @@ protected void createInitialRule( List<Rule> ruleList,
}
}
}
if ( scorecard.getReasonCodeAlgorithm() != null ) {
Consequence consequence = new Consequence();
if ( "pointsAbove".equalsIgnoreCase( scorecard.getReasonCodeAlgorithm() ) ) {
//TODO: ReasonCode Algorithm
consequence.setSnippet( "//$sc.setReasonCodeAlgorithm(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSABOVE);" );
} else if ( "pointsBelow".equalsIgnoreCase( scorecard.getReasonCodeAlgorithm() ) ) {
consequence.setSnippet( "//$sc.setReasonCodeAlgorithm(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSBELOW);" );
}
rule.addConsequence( consequence );
}
// if ( scorecard.getReasonCodeAlgorithm() != null ) {
// Consequence consequence = new Consequence();
// if ( "pointsAbove".equalsIgnoreCase( scorecard.getReasonCodeAlgorithm() ) ) {
// //TODO: ReasonCode Algorithm
// consequence.setSnippet( "//$sc.setReasonCodeAlgorithm(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSABOVE);" );
// } else if ( "pointsBelow".equalsIgnoreCase( scorecard.getReasonCodeAlgorithm() ) ) {
// consequence.setSnippet( "//$sc.setReasonCodeAlgorithm(DroolsScorecard.REASON_CODE_ALGORITHM_POINTSBELOW);" );
// }
// rule.addConsequence( consequence );
// }
}
ruleList.add( rule );
}
Expand Down Expand Up @@ -331,13 +338,21 @@ protected void populateRHS( Rule rule,
String setter = "insertLogical(new PartialScore(\"";
String field = ScorecardPMMLUtils.extractFieldNameFromCharacteristic( c );

stringBuilder.append( setter ).append( objectClass ).append( "\",\"" ).append( field ).append( "\"," ).append( scoreAttribute.getPartialScore() );
//stringBuilder.append( setter ).append( objectClass ).append( "\",\"" ).append( field ).append( "\"," ).append( scoreAttribute.getPartialScore() );
ScoringStrategy scoringStrategy = getScoringStrategy(scorecard);
if ( scoringStrategy.toString().startsWith("WEIGHTED")) {
String weight = ScorecardPMMLUtils.getExtensionValue(scoreAttribute.getExtensions(), PMMLExtensionNames.CHARACTERTISTIC_WEIGHT);
stringBuilder.append(setter).append(objectClass).append("\",\"").append(field).append("\",(").append(scoreAttribute.getPartialScore()).append("*").append(weight).append(")");
} else {
stringBuilder.append(setter).append(objectClass).append("\",\"").append(field).append("\",").append(scoreAttribute.getPartialScore());
}
if ( scorecard.isUseReasonCodes() ) {
String reasonCode = scoreAttribute.getReasonCode();
if ( reasonCode == null || StringUtils.isEmpty( reasonCode ) ) {
reasonCode = c.getReasonCode();
}
stringBuilder.append( ",\"" ).append( reasonCode ).append( "\", " ).append( position );
stringBuilder.append(",\"").append(reasonCode).append("\", ").append(c.getBaselineScore());
stringBuilder.append(",").append(position);
}
stringBuilder.append( "));" );
consequence.setSnippet( stringBuilder.toString() );
Expand All @@ -350,7 +365,30 @@ protected void createSummationRules( List<Rule> ruleList,
Rule calcTotalRule = new Rule( objectClass + "_calculateTotalScore", 1, 1 );
StringBuilder stringBuilder = new StringBuilder();
Condition condition = new Condition();
stringBuilder.append( "$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"" ).append( objectClass ).append( "\", $partialScore:score), sum($partialScore))" );
//stringBuilder.append( "$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"" ).append( objectClass ).append( "\", $partialScore:score), sum($partialScore))" );
ScoringStrategy strategy = getScoringStrategy(scorecard);
switch (strategy) {
case WEIGHTED_AGGREGATE_SCORE:
case AGGREGATE_SCORE: {
stringBuilder.append("$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"").append(objectClass).append("\", $partialScore:score), sum($partialScore))");
break;
}
case WEIGHTED_AVERAGE_SCORE:
case AVERAGE_SCORE:{
stringBuilder.append("$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"").append(objectClass).append("\", $partialScore:score), average($partialScore))");
break;
}
case WEIGHTED_MAXIMUM_SCORE:
case MAXIMUM_SCORE:{
stringBuilder.append("$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"").append(objectClass).append("\", $partialScore:score), max($partialScore))");
break;
}
case WEIGHTED_MINIMUM_SCORE:
case MINIMUM_SCORE:{
stringBuilder.append("$calculatedScore : Double() from accumulate (PartialScore(scorecardName ==\"").append(objectClass).append("\", $partialScore:score), min($partialScore))");
break;
}
}
condition.setSnippet( stringBuilder.toString() );
calcTotalRule.addCondition( condition );
if ( scorecard.getInitialScore() > 0 ) {
Expand All @@ -368,7 +406,8 @@ protected void createSummationRules( List<Rule> ruleList,
rule.setDescription( "collect and sort the reason codes as per the specified algorithm" );
condition = new Condition();
stringBuilder = new StringBuilder();
stringBuilder.append( "$reasons : List() from accumulate ( PartialScore(scorecardName == \"" ).append( objectClass ).append( "\", $reasonCode : reasoncode ); collectList($reasonCode) )" );
// stringBuilder.append("$reasons : List() from accumulate ( PartialScore(scorecardName == \"").append(objectClass).append("\", $reasonCode : reasoncode ); collectList($reasonCode) )");
stringBuilder.append("$partialScoresList : List() from collect ( PartialScore(scorecardName == \"").append(objectClass).append("\"))");
condition.setSnippet( stringBuilder.toString() );
rule.addCondition( condition );
ruleList.add( rule );
Expand All @@ -381,32 +420,22 @@ protected void createSummationRules( List<Rule> ruleList,
addAdditionalSummationConsequence( calcTotalRule, scorecard );
}

protected abstract void addDeclaredTypeContents( PMML pmmlDocument,
StringBuilder stringBuilder,
Scorecard scorecard );

protected abstract void internalEmitDRL( PMML pmml,
List<Rule> ruleList,
Package aPackage );

protected abstract void addLHSConditions( Rule rule,
PMML pmmlDocument,
Scorecard scorecard,
Characteristic c,
Attribute scoreAttribute );

protected abstract void addAdditionalReasonCodeConsequence( Rule rule,
Scorecard scorecard );

protected abstract void addAdditionalReasonCodeCondition( Rule rule,
Scorecard scorecard );

protected abstract void addAdditionalSummationConsequence( Rule rule,
Scorecard scorecard );

protected abstract void addAdditionalSummationCondition( Rule rule,
Scorecard scorecard );
protected ScoringStrategy getScoringStrategy(Scorecard scorecard) {
ScoringStrategy strategy = ScoringStrategy.AGGREGATE_SCORE;
String scoringStrategyName = ScorecardPMMLUtils.getExtensionValue(scorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), PMMLExtensionNames.SCORECARD_SCORING_STRATEGY);
if ( !StringUtils.isEmpty(scoringStrategyName)) {
strategy = ScoringStrategy.valueOf(scoringStrategyName);
}
return strategy;
}

protected abstract Condition createInitialRuleCondition( Scorecard scorecard,
String objectClass );
protected abstract void addDeclaredTypeContents( PMML pmmlDocument, StringBuilder stringBuilder, Scorecard scorecard );
protected abstract void internalEmitDRL( PMML pmml, List<Rule> ruleList, Package aPackage );
protected abstract void addLHSConditions( Rule rule, PMML pmmlDocument, Scorecard scorecard,
Characteristic c, Attribute scoreAttribute );
protected abstract void addAdditionalReasonCodeConsequence( Rule rule, Scorecard scorecard );
protected abstract void addAdditionalReasonCodeCondition( Rule rule, Scorecard scorecard );
protected abstract void addAdditionalSummationConsequence( Rule rule, Scorecard scorecard );
protected abstract void addAdditionalSummationCondition( Rule rule, Scorecard scorecard );
protected abstract Condition createInitialRuleCondition( Scorecard scorecard, String objectClass );
}

0 comments on commit afc6f90

Please sign in to comment.