Skip to content

Commit

Permalink
DROOLS-2284 exec model: split rule methods in separate classes, 10 ea…
Browse files Browse the repository at this point in the history
…ch (apache#1755)

* DROOLS-2284 exec model: split rule methods in separate classes, 10 each

WIP global as static

* WIP

* Rule methods splitted in a single separate class

* Split every 10 rule-methods
  • Loading branch information
tarilabs authored and mariofusco committed Feb 8, 2018
1 parent aa8cf5f commit 901bf75
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 50 deletions.
Expand Up @@ -6,17 +6,16 @@
import java.util.stream.Collectors; import java.util.stream.Collectors;


import org.drools.compiler.compiler.io.memory.MemoryFileSystem; import org.drools.compiler.compiler.io.memory.MemoryFileSystem;
import org.drools.javaparser.ast.CompilationUnit;
import org.drools.javaparser.ast.body.ClassOrInterfaceDeclaration; import org.drools.javaparser.ast.body.ClassOrInterfaceDeclaration;
import org.drools.javaparser.printer.PrettyPrinter; import org.drools.javaparser.printer.PrettyPrinter;
import org.drools.modelcompiler.builder.PackageModel.RuleSourceResult;


import static org.drools.core.util.StringUtils.generateUUID;
import static org.drools.modelcompiler.CanonicalKieModule.MODEL_FILE; import static org.drools.modelcompiler.CanonicalKieModule.MODEL_FILE;
import static org.drools.modelcompiler.builder.JavaParserCompiler.getPrettyPrinter; import static org.drools.modelcompiler.builder.JavaParserCompiler.getPrettyPrinter;


public class ModelWriter { public class ModelWriter {


private static final String RULES_FILE_NAME = "Rules";

public Result writeModel(MemoryFileSystem srcMfs, Collection<PackageModel> packageModels) { public Result writeModel(MemoryFileSystem srcMfs, Collection<PackageModel> packageModels) {
List<String> sourceFiles = new ArrayList<>(); List<String> sourceFiles = new ArrayList<>();
List<String> modelFiles = new ArrayList<>(); List<String> modelFiles = new ArrayList<>();
Expand All @@ -29,29 +28,37 @@ public Result writeModel(MemoryFileSystem srcMfs, Collection<PackageModel> packa


for (ClassOrInterfaceDeclaration generatedPojo : pkgModel.getGeneratedPOJOsSource()) { for (ClassOrInterfaceDeclaration generatedPojo : pkgModel.getGeneratedPOJOsSource()) {
final String source = JavaParserCompiler.toPojoSource(pkgModel.getName(), pkgModel.getImports(), generatedPojo); final String source = JavaParserCompiler.toPojoSource(pkgModel.getName(), pkgModel.getImports(), generatedPojo);
pkgModel.print( source ); pkgModel.sysout(source);
String pojoSourceName = "src/main/java/" + folderName + "/" + generatedPojo.getName() + ".java"; String pojoSourceName = "src/main/java/" + folderName + "/" + generatedPojo.getName() + ".java";
srcMfs.write( pojoSourceName, source.getBytes() ); srcMfs.write( pojoSourceName, source.getBytes() );
sourceFiles.add( pojoSourceName ); sourceFiles.add( pojoSourceName );
} }


String rulesFileName = generateRulesFileName(); RuleSourceResult rulesSourceResult = pkgModel.getRulesSource();
// main rules file:
String rulesFileName = pkgModel.getRulesFileName();
String rulesSourceName = "src/main/java/" + folderName + "/" + rulesFileName + ".java"; String rulesSourceName = "src/main/java/" + folderName + "/" + rulesFileName + ".java";
String rulesSource = pkgModel.getRulesSource( prettyPrinter, rulesFileName, pkgName ); String rulesSource = prettyPrinter.print(rulesSourceResult.getMainRuleClass());
pkgModel.print( rulesSource ); pkgModel.sysout(rulesSource);
byte[] rulesBytes = rulesSource.getBytes(); byte[] rulesBytes = rulesSource.getBytes();
srcMfs.write( rulesSourceName, rulesBytes ); srcMfs.write( rulesSourceName, rulesBytes );
modelFiles.add( pkgName + "." + rulesFileName ); modelFiles.add( pkgName + "." + rulesFileName );
sourceFiles.add( rulesSourceName ); sourceFiles.add( rulesSourceName );
// manage additional classes, please notice to not add to modelFiles.
for (CompilationUnit cu : rulesSourceResult.getSplitted()) {
String addFileName = cu.findFirst(ClassOrInterfaceDeclaration.class).get().getNameAsString();
String addSourceName = "src/main/java/" + folderName + "/" + addFileName + ".java";
String addSource = prettyPrinter.print(cu);
pkgModel.sysout(addSource);
byte[] addBytes = addSource.getBytes();
srcMfs.write(addSourceName, addBytes);
sourceFiles.add(addSourceName);
}
} }


return new Result(sourceFiles, modelFiles); return new Result(sourceFiles, modelFiles);
} }


private String generateRulesFileName() {
return RULES_FILE_NAME + generateUUID();
}

public void writeModelFile( List<String> modelSources, MemoryFileSystem trgMfs) { public void writeModelFile( List<String> modelSources, MemoryFileSystem trgMfs) {
final String pkgNames; final String pkgNames;
if(!modelSources.isEmpty()) { if(!modelSources.isEmpty()) {
Expand Down
Expand Up @@ -18,11 +18,13 @@


import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.HashSet; import java.util.HashSet;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Map.Entry;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;


Expand All @@ -44,7 +46,6 @@
import org.drools.javaparser.ast.stmt.BlockStmt; import org.drools.javaparser.ast.stmt.BlockStmt;
import org.drools.javaparser.ast.type.ClassOrInterfaceType; import org.drools.javaparser.ast.type.ClassOrInterfaceType;
import org.drools.javaparser.ast.type.Type; import org.drools.javaparser.ast.type.Type;
import org.drools.javaparser.printer.PrettyPrinter;
import org.drools.model.Global; import org.drools.model.Global;
import org.drools.model.Model; import org.drools.model.Model;
import org.drools.model.WindowReference; import org.drools.model.WindowReference;
Expand All @@ -54,11 +55,15 @@
import org.drools.modelcompiler.builder.generator.QueryParameter; import org.drools.modelcompiler.builder.generator.QueryParameter;
import org.kie.api.runtime.rule.AccumulateFunction; import org.kie.api.runtime.rule.AccumulateFunction;


import static org.drools.core.util.StringUtils.generateUUID;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toVar; import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.toVar;


public class PackageModel { public class PackageModel {


private static final String RULES_FILE_NAME = "Rules";

private final String name; private final String name;
private final String rulesFileName;


private Set<String> imports = new HashSet<>(); private Set<String> imports = new HashSet<>();


Expand All @@ -85,12 +90,22 @@ public class PackageModel {
private KnowledgeBuilderConfigurationImpl configuration; private KnowledgeBuilderConfigurationImpl configuration;
private Map<String, AccumulateFunction> accumulateFunctions; private Map<String, AccumulateFunction> accumulateFunctions;



public PackageModel(String name, KnowledgeBuilderConfigurationImpl configuration) { public PackageModel(String name, KnowledgeBuilderConfigurationImpl configuration) {
this.name = name; this.name = name;
this.rulesFileName = generateRulesFileName();
this.configuration = configuration; this.configuration = configuration;
exprIdGenerator = new DRLIdGenerator(); exprIdGenerator = new DRLIdGenerator();
} }


public String getRulesFileName() {
return rulesFileName;
}

private String generateRulesFileName() {
return RULES_FILE_NAME + generateUUID();
}

public KnowledgeBuilderConfigurationImpl getConfiguration() { public KnowledgeBuilderConfigurationImpl getConfiguration() {
return configuration; return configuration;
} }
Expand Down Expand Up @@ -189,39 +204,51 @@ public Map<String, AccumulateFunction> getAccumulateFunctions() {
return accumulateFunctions; return accumulateFunctions;
} }


public String getRulesSource(PrettyPrinter prettyPrinter, String className, String modelName) { public static class RuleSourceResult {
CompilationUnit cu = new CompilationUnit();
cu.setPackageDeclaration( name );


// fixed part private final CompilationUnit mainRuleClass;
cu.addImport(JavaParser.parseImport("import java.util.*;" )); private Collection<CompilationUnit> splitted = new ArrayList<>();
cu.addImport(JavaParser.parseImport("import org.drools.model.*;" ));
cu.addImport(JavaParser.parseImport("import static org.drools.model.DSL.*;" ));
cu.addImport(JavaParser.parseImport("import org.drools.model.Index.ConstraintType;"));
cu.addImport(JavaParser.parseImport("import java.time.*;"));
cu.addImport(JavaParser.parseImport("import java.time.format.*;"));
cu.addImport(JavaParser.parseImport("import java.text.*;"));
cu.addImport(JavaParser.parseImport("import org.drools.core.util.*;"));


// imports from DRL: public RuleSourceResult(CompilationUnit mainRuleClass) {
for ( String i : imports ) { this.mainRuleClass = mainRuleClass;
if ( i.equals(name+".*") ) {
continue; // skip same-package star import.
}
cu.addImport(JavaParser.parseImport("import "+i+";"));
} }

public CompilationUnit getMainRuleClass() {
return mainRuleClass;
}

/**
* Append additional class to source results.
* @param additionalCU
*/
public RuleSourceResult with(CompilationUnit additionalCU) {
splitted.add(additionalCU);
return this;
}

public Collection<CompilationUnit> getSplitted() {
return Collections.unmodifiableCollection(splitted);
}

}

public RuleSourceResult getRulesSource() {
CompilationUnit cu = new CompilationUnit();
cu.setPackageDeclaration( name );

manageImportForCompilationUnit(cu);


ClassOrInterfaceDeclaration rulesClass = cu.addClass(className); ClassOrInterfaceDeclaration rulesClass = cu.addClass(rulesFileName);
rulesClass.addImplementedType(Model.class); rulesClass.addImplementedType(Model.class);


BodyDeclaration<?> dateFormatter = JavaParser.parseBodyDeclaration( BodyDeclaration<?> dateFormatter = JavaParser.parseBodyDeclaration(
"private final static DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern(DateUtils.getDateFormatMask());\n"); "public final static DateTimeFormatter dateTimeFormatter = DateTimeFormatter.ofPattern(DateUtils.getDateFormatMask());\n");
rulesClass.addMember(dateFormatter); rulesClass.addMember(dateFormatter);


BodyDeclaration<?> getNameMethod = JavaParser.parseBodyDeclaration( BodyDeclaration<?> getNameMethod = JavaParser.parseBodyDeclaration(
" @Override\n" + " @Override\n" +
" public String getName() {\n" + " public String getName() {\n" +
" return \"" + modelName + "\";\n" + " return \"" + name + "\";\n" +
" }\n" " }\n"
); );
rulesClass.addMember(getNameMethod); rulesClass.addMember(getNameMethod);
Expand Down Expand Up @@ -276,7 +303,7 @@ public String getRulesSource(PrettyPrinter prettyPrinter, String className, Stri




for(Map.Entry<String, MethodCallExpr> windowReference : windowReferences.entrySet()) { for(Map.Entry<String, MethodCallExpr> windowReference : windowReferences.entrySet()) {
FieldDeclaration f = rulesClass.addField(WINDOW_REFERENCE_TYPE, windowReference.getKey()); FieldDeclaration f = rulesClass.addField(WINDOW_REFERENCE_TYPE, windowReference.getKey(), Modifier.PUBLIC, Modifier.STATIC, Modifier.FINAL);
f.getVariables().get(0).setInitializer(windowReference.getValue()); f.getVariables().get(0).setInitializer(windowReference.getValue());
} }


Expand All @@ -285,7 +312,7 @@ public String getRulesSource(PrettyPrinter prettyPrinter, String className, Stri
} }


for(Map.Entry<String, QueryGenerator.QueryDefWithType> queryDef: queryDefWithType.entrySet()) { for(Map.Entry<String, QueryGenerator.QueryDefWithType> queryDef: queryDefWithType.entrySet()) {
FieldDeclaration field = rulesClass.addField(queryDef.getValue().getQueryType(), queryDef.getKey(), Modifier.FINAL); FieldDeclaration field = rulesClass.addField(queryDef.getValue().getQueryType(), queryDef.getKey(), Modifier.PUBLIC, Modifier.STATIC, Modifier.FINAL);
field.getVariables().get(0).setInitializer(queryDef.getValue().getMethodCallExpr()); field.getVariables().get(0).setInitializer(queryDef.getValue().getMethodCallExpr());
} }


Expand All @@ -300,12 +327,6 @@ public String getRulesSource(PrettyPrinter prettyPrinter, String className, Stri
rulesClass.addMember(rulesListInitializer); rulesClass.addMember(rulesListInitializer);
BlockStmt rulesListInitializerBody = new BlockStmt(); BlockStmt rulesListInitializerBody = new BlockStmt();
rulesListInitializer.setBody(rulesListInitializerBody); rulesListInitializer.setBody(rulesListInitializerBody);
for ( String methodName : ruleMethods.keySet() ) {
NameExpr rulesFieldName = new NameExpr( "rules" );
MethodCallExpr add = new MethodCallExpr(rulesFieldName, "add");
add.addArgument( new MethodCallExpr(null, methodName) );
rulesListInitializerBody.addStatement( add );
}


for ( String methodName : queryMethods.keySet() ) { for ( String methodName : queryMethods.keySet() ) {
NameExpr rulesFieldName = new NameExpr( "queries" ); NameExpr rulesFieldName = new NameExpr( "queries" );
Expand Down Expand Up @@ -337,13 +358,60 @@ public String getRulesSource(PrettyPrinter prettyPrinter, String className, Stri


functions.forEach(rulesClass::addMember); functions.forEach(rulesClass::addMember);


RuleSourceResult results = new RuleSourceResult(cu);

// each method per Drlx parser result // each method per Drlx parser result
ruleMethods.values().forEach( rulesClass::addMember ); int count = 0; // I count which method it is.
int index = 0; // I decide which classIndex goes into.
Map<Integer, ClassOrInterfaceDeclaration> splitted = new LinkedHashMap<>();
for (Entry<String, MethodDeclaration> ruleMethodKV : ruleMethods.entrySet()) {
count++;
if (count % 10 == 0) {
index++;
}
ClassOrInterfaceDeclaration rulesMethodClass = splitted.computeIfAbsent(index, i -> {
CompilationUnit cuRulesMethod = new CompilationUnit();
results.with(cuRulesMethod);
cuRulesMethod.setPackageDeclaration(name);
manageImportForCompilationUnit(cuRulesMethod);
cuRulesMethod.addImport(JavaParser.parseImport("import static " + name + "." + rulesFileName + ".*;"));
String currentRulesMethodClassName = rulesFileName + "RuleMethods" + i;
ClassOrInterfaceDeclaration r = cuRulesMethod.addClass(currentRulesMethodClassName);
return r;
});
rulesMethodClass.addMember(ruleMethodKV.getValue());

// manage in main class init block:
NameExpr rulesFieldName = new NameExpr("rules");
MethodCallExpr add = new MethodCallExpr(rulesFieldName, "add");
add.addArgument(new MethodCallExpr(new NameExpr(rulesMethodClass.getNameAsString()), ruleMethodKV.getKey()));
rulesListInitializerBody.addStatement(add);
}

queryMethods.values().forEach(rulesClass::addMember); queryMethods.values().forEach(rulesClass::addMember);




// config.setColumnAlignFirstMethodChain(true); return results;
return prettyPrinter.print(cu); }

private void manageImportForCompilationUnit(CompilationUnit cu) {
// fixed part
cu.addImport(JavaParser.parseImport("import java.util.*;" ));
cu.addImport(JavaParser.parseImport("import org.drools.model.*;" ));
cu.addImport(JavaParser.parseImport("import static org.drools.model.DSL.*;" ));
cu.addImport(JavaParser.parseImport("import org.drools.model.Index.ConstraintType;"));
cu.addImport(JavaParser.parseImport("import java.time.*;"));
cu.addImport(JavaParser.parseImport("import java.time.format.*;"));
cu.addImport(JavaParser.parseImport("import java.text.*;"));
cu.addImport(JavaParser.parseImport("import org.drools.core.util.*;"));

// imports from DRL:
for ( String i : imports ) {
if ( i.equals(name+".*") ) {
continue; // skip same-package star import.
}
cu.addImport(JavaParser.parseImport("import "+i+";"));
}
} }


private static void addGlobalField(ClassOrInterfaceDeclaration classDeclaration, String packageName, String globalName, Class<?> globalClass) { private static void addGlobalField(ClassOrInterfaceDeclaration classDeclaration, String packageName, String globalName, Class<?> globalClass) {
Expand All @@ -358,12 +426,12 @@ private static void addGlobalField(ClassOrInterfaceDeclaration classDeclaration,
declarationOfCall.addArgument(new StringLiteralExpr(packageName)); declarationOfCall.addArgument(new StringLiteralExpr(packageName));
declarationOfCall.addArgument(new StringLiteralExpr(globalName)); declarationOfCall.addArgument(new StringLiteralExpr(globalName));


FieldDeclaration field = classDeclaration.addField(varType, toVar(globalName), Modifier.FINAL); FieldDeclaration field = classDeclaration.addField(varType, toVar(globalName), Modifier.PUBLIC, Modifier.STATIC, Modifier.FINAL);


field.getVariables().get(0).setInitializer(declarationOfCall); field.getVariables().get(0).setInitializer(declarationOfCall);
} }


public void print(String source) { public void sysout(String source) {
System.out.println("====="); System.out.println("=====");
System.out.println(source); System.out.println(source);
System.out.println("====="); System.out.println("=====");
Expand Down
Expand Up @@ -83,7 +83,6 @@


import static java.util.stream.Collectors.toList; import static java.util.stream.Collectors.toList;
import static java.util.stream.Collectors.toSet; import static java.util.stream.Collectors.toSet;

import static org.drools.javaparser.JavaParser.parseExpression; import static org.drools.javaparser.JavaParser.parseExpression;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.classToReferenceType; import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.classToReferenceType;
import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.generateLambdaWithoutParameters; import static org.drools.modelcompiler.builder.generator.DrlxParseUtil.generateLambdaWithoutParameters;
Expand Down Expand Up @@ -201,7 +200,8 @@ private static void processRule(KnowledgeBuilderImpl kbuilder, InternalKnowledge
} }


new ModelGeneratorVisitor(context, packageModel).visit(ruleDescr.getLhs()); new ModelGeneratorVisitor(context, packageModel).visit(ruleDescr.getLhs());
MethodDeclaration ruleMethod = new MethodDeclaration(EnumSet.of(Modifier.PRIVATE), RULE_TYPE, "rule_" + toId( ruleDescr.getName() ) ); final String ruleMethodName = "rule_" + toId(ruleDescr.getName());
MethodDeclaration ruleMethod = new MethodDeclaration(EnumSet.of(Modifier.PUBLIC, Modifier.STATIC), RULE_TYPE, ruleMethodName);


ruleMethod.setJavadocComment(" Rule name: " + ruleDescr.getName() + " "); ruleMethod.setJavadocComment(" Rule name: " + ruleDescr.getName() + " ");


Expand Down Expand Up @@ -242,7 +242,7 @@ private static void processRule(KnowledgeBuilderImpl kbuilder, InternalKnowledge
ruleVariablesBlock.addStatement(new AssignExpr(ruleVar, buildCall, AssignExpr.Operator.ASSIGN)); ruleVariablesBlock.addStatement(new AssignExpr(ruleVar, buildCall, AssignExpr.Operator.ASSIGN));


ruleVariablesBlock.addStatement( new ReturnStmt(RULE_CALL) ); ruleVariablesBlock.addStatement( new ReturnStmt(RULE_CALL) );
packageModel.putRuleMethod("rule_" + toId( ruleDescr.getName() ), ruleMethod); packageModel.putRuleMethod(ruleMethodName, ruleMethod);
} }


/** /**
Expand Down Expand Up @@ -468,7 +468,7 @@ private static MethodCallExpr onCall(Collection<String> usedArguments) {


if (!usedArguments.isEmpty()) { if (!usedArguments.isEmpty()) {
onCall = new MethodCallExpr(null, ON_CALL); onCall = new MethodCallExpr(null, ON_CALL);
usedArguments.stream().map( org.drools.modelcompiler.builder.generator.DrlxParseUtil::toVar).forEach(onCall::addArgument ); usedArguments.stream().map(DrlxParseUtil::toVar).forEach(onCall::addArgument );
} }
return onCall; return onCall;
} }
Expand Down

0 comments on commit 901bf75

Please sign in to comment.