Skip to content

Commit

Permalink
[DROOLS-7176] remove leaky abstractions in rule unit API (apache#2516)
Browse files Browse the repository at this point in the history
* [DROOLS-7176] remove leaky abstractions in rule unit API

* wip
  • Loading branch information
mariofusco committed Sep 28, 2022
1 parent 1ee67cc commit bb56b57
Show file tree
Hide file tree
Showing 30 changed files with 83 additions and 247 deletions.

This file was deleted.

4 changes: 4 additions & 0 deletions api/kogito-api/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@
<groupId>org.drools</groupId>
<artifactId>drools-ruleunits-api</artifactId>
</dependency>
<dependency>
<groupId>org.drools</groupId>
<artifactId>drools-ruleunits-impl</artifactId>
</dependency>
<dependency>
<groupId>org.kie</groupId>
<artifactId>kie-dmn-api</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import org.kie.kogito.KogitoConfig;

public interface RuleConfig extends org.drools.ruleunits.api.RuleConfig, KogitoConfig {
public interface RuleConfig extends KogitoConfig {

RuleEventListenerConfig ruleEventListeners();
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,13 @@
*/
package org.kie.kogito.rules;

public interface RuleEventListenerConfig extends org.drools.ruleunits.api.RuleEventListenerConfig {
import java.util.List;

import org.kie.api.event.rule.AgendaEventListener;
import org.kie.api.event.rule.RuleRuntimeEventListener;

public interface RuleEventListenerConfig {
List<AgendaEventListener> agendaListeners();

List<RuleRuntimeEventListener> ruleRuntimeListeners();
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import org.drools.ruleunits.api.conf.ClockType;
import org.drools.ruleunits.api.conf.EventProcessingType;

public class RuleUnitConfig extends org.drools.ruleunits.api.RuleUnitConfig {
public class RuleUnitConfig extends org.drools.ruleunits.impl.RuleUnitConfig {

public RuleUnitConfig(EventProcessingType eventProcessingType, ClockType clockType, Integer sessionPool) {
super(eventProcessingType, clockType, sessionPool);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@

import org.kie.kogito.KogitoEngine;

public interface RuleUnits extends org.drools.ruleunits.api.RuleUnits, KogitoEngine {
public interface RuleUnits extends org.drools.ruleunits.impl.RuleUnits, KogitoEngine {
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,13 @@
import java.util.HashMap;
import java.util.Map;

import org.drools.ruleunits.api.RuleUnit;
import org.drools.ruleunits.api.RuleUnitInstance;
import org.kie.kogito.rules.RuleUnits;

public abstract class AbstractRuleUnits implements RuleUnits {

private Map<String, RuleUnitInstance<?>> unitRegistry = new HashMap<>();

@Override
public <T extends org.drools.ruleunits.api.RuleUnitData> RuleUnit<T> create(Class<T> clazz) {
return (RuleUnit<T>) create(clazz.getCanonicalName());
}

protected abstract RuleUnit<?> create(String fqcn);

@Override
public void register(String name, RuleUnitInstance<?> unitInstance) {
if (name == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@
import org.drools.ruleunits.api.DataStore;
import org.drools.ruleunits.api.RuleUnit;
import org.drools.ruleunits.api.RuleUnitInstance;
import org.drools.ruleunits.api.RuleUnitQuery;
import org.drools.ruleunits.impl.AbstractRuleUnitInstance;
import org.drools.ruleunits.impl.InternalRuleUnit;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.EnumSource;
import org.junit.jupiter.params.provider.ValueSource;
Expand All @@ -53,7 +53,6 @@
import org.kie.kogito.rules.RuleUnits;

import static java.util.Arrays.asList;
import static java.util.stream.Collectors.toList;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
Expand Down Expand Up @@ -153,10 +152,10 @@ public void testRuleUnitQuery(SessionType sessionType) throws Exception {
RuleUnit<AdultUnit> unit = application.get(RuleUnits.class).create(AdultUnit.class);
RuleUnitInstance<AdultUnit> instance = unit.createInstance(adults);

Class<? extends RuleUnitQuery<List<String>>> queryClass = (Class<? extends RuleUnitQuery<List<String>>>) application.getClass()
Class<?> queryClass = application.getClass()
.getClassLoader().loadClass("org.kie.kogito.codegen.unit.AdultUnitQueryFindAdults");

List<String> results = instance.executeQuery(queryClass);
List<String> results = (List<String>) queryClass.getMethod("execute", RuleUnitInstance.class).invoke(null, instance);

assertEquals(2, results.size());
assertTrue(results.containsAll(asList("Mario", "Marilena")));
Expand All @@ -176,11 +175,7 @@ public void testRuleUnitQueryOnPrimitive(SessionType sessionType) throws Excepti
RuleUnit<AdultUnit> unit = application.get(RuleUnits.class).create(AdultUnit.class);
RuleUnitInstance<AdultUnit> instance = unit.createInstance(adults);

List<Integer> results = instance.executeQuery("FindAdultsAge")
.stream()
.map(m -> m.get("$age"))
.map(Integer.class::cast)
.collect(toList());
List<Object> results = instance.executeQuery("FindAdultsAge").toList("$age");

assertEquals(2, results.size());
assertTrue(results.containsAll(asList(45, 47)));
Expand All @@ -201,11 +196,7 @@ public void testRuleUnitQueryWithNoRules(SessionType sessionType) throws Excepti
RuleUnit<AdultUnit> unit = application.get(RuleUnits.class).create(AdultUnit.class);
RuleUnitInstance<AdultUnit> instance = unit.createInstance(adults);

List<Integer> results = instance.executeQuery("FindAdultsAge")
.stream()
.map(m -> m.get("$sum"))
.map(Integer.class::cast)
.collect(toList());
List<Object> results = instance.executeQuery("FindAdultsAge").toList("$sum");

assertEquals(1, results.size());
assertThat(results).containsExactlyInAnyOrder(99);
Expand All @@ -226,10 +217,10 @@ public void testRuleUnitExecutor(SessionType sessionType) throws Exception {
RuleUnit<AdultUnit> adultUnit = application.get(RuleUnits.class).create(AdultUnit.class);

AdultUnit adultData18 = new AdultUnit(persons, 18);
RuleUnitInstance<AdultUnit> adultUnitInstance18 = adultUnit.createInstance(adultData18, "adult18");
RuleUnitInstance<AdultUnit> adultUnitInstance18 = ((InternalRuleUnit) adultUnit).createInstance(adultData18, "adult18");

AdultUnit adultData21 = new AdultUnit(persons, 21);
RuleUnitInstance<AdultUnit> adultUnitInstance21 = adultUnit.createInstance(adultData21, "adult21");
RuleUnitInstance<AdultUnit> adultUnitInstance21 = ((InternalRuleUnit) adultUnit).createInstance(adultData21, "adult21");

RuleUnit<PersonsUnit> personsUnit = application.get(RuleUnits.class).create(PersonsUnit.class);
personsUnit.createInstance(new PersonsUnit(persons)).fire();
Expand Down Expand Up @@ -312,14 +303,10 @@ public void test2PatternsOopath(SessionType sessionType) throws Exception {
RuleUnit<AdultUnit> unit = application.get(RuleUnits.class).create(AdultUnit.class);
RuleUnitInstance<AdultUnit> instance = unit.createInstance(adults);

List<Person> results = instance.executeQuery("FindPeopleInMilano")
.stream()
.map(m -> m.get("$p"))
.map(Person.class::cast)
.collect(toList());
List<Object> results = instance.executeQuery("FindPeopleInMilano").toList("$p");

assertEquals(1, results.size());
assertEquals("Mario", results.get(0).getName());
assertEquals("Mario", ((Person) results.get(0)).getName());
}

@ParameterizedTest
Expand All @@ -346,7 +333,7 @@ public void testCep(SessionType sessionType) throws Exception {
stockUnit.getStockTicks().append(new StockTick("IBM", 1700, 170));
stockUnit.getStockTicks().append(new StockTick("IBM", 1500, 240));

ValueDrop valueDrop = (ValueDrop) instance.executeQuery("highestValueDrop", "IBM").get(0).get("$s");
ValueDrop valueDrop = (ValueDrop) instance.executeQuery("highestValueDrop", "IBM").iterator().next().get("$s");
assertEquals(300, valueDrop.getDropAmount());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ Map<String, KieBaseModel> kieBaseModels() {
void addRuleUnitConfig(RuleUnitDescription ruleUnitDescription, RuleUnitConfig overridingConfig) {
// merge config from the descriptor with configs from application.conf
// application.conf overrides any other config
org.drools.ruleunits.api.RuleUnitConfig config =
org.drools.ruleunits.impl.RuleUnitConfig config =
((AbstractRuleUnitDescription) ruleUnitDescription).getConfig()
.merged(overridingConfig);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.body.Parameter;
import com.github.javaparser.ast.body.VariableDeclarator;
import com.github.javaparser.ast.expr.ClassExpr;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.Name;
import com.github.javaparser.ast.expr.NameExpr;
Expand Down Expand Up @@ -117,7 +116,7 @@ private void generateQueryMethods(CompilationUnit cu, ClassOrInterfaceDeclaratio
.orElseThrow(() -> new NoSuchElementException("A method declaration doesn't contain a body!"))
.getStatement(1);
returnStatement.findAll(VariableDeclarator.class).forEach(decl -> setGeneric(decl.getType(), returnType));
returnStatement.findAll(ClassExpr.class).forEach(expr -> expr.setType(queryClassName));
returnStatement.findAll(MethodCallExpr.class).forEach(expr -> expr.setScope(new NameExpr(queryClassName)));

MethodDeclaration queryMethodSingle = clazz.getMethodsByName("executeQueryFirst").get(0);
queryMethodSingle.getParameter(0).setType(ruleUnit.getCanonicalName() + (hasDI ? "" : "DTO"));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@
import com.github.javaparser.ast.NodeList;
import com.github.javaparser.ast.body.ClassOrInterfaceDeclaration;
import com.github.javaparser.ast.body.ConstructorDeclaration;
import com.github.javaparser.ast.body.FieldDeclaration;
import com.github.javaparser.ast.body.MethodDeclaration;
import com.github.javaparser.ast.expr.AssignExpr;
import com.github.javaparser.ast.expr.CastExpr;
import com.github.javaparser.ast.expr.MethodCallExpr;
import com.github.javaparser.ast.expr.MethodReferenceExpr;
import com.github.javaparser.ast.expr.NameExpr;
import com.github.javaparser.ast.expr.ObjectCreationExpr;
import com.github.javaparser.ast.expr.StringLiteralExpr;
Expand Down Expand Up @@ -128,34 +128,20 @@ public GeneratedFile generate() {

cu.findAll(StringLiteralExpr.class).forEach(this::interpolateStrings);

FieldDeclaration ruleUnitDeclaration = clazz
.getFieldByName("instance")
.orElseThrow(() -> new NoSuchElementException("ClassOrInterfaceDeclaration doesn't contain a field named ruleUnit!"));
setGeneric(ruleUnitDeclaration.getElementType(), ruleUnit);

String returnType = getReturnType(clazz);
setGeneric(clazz.getImplementedTypes(0).getTypeArguments().get().get(0), returnType);
generateConstructors(clazz);
generateQueryMethod(cu, clazz, returnType);
generateQueryMethod(clazz, returnType);
clazz.getMembers().sort(new BodyDeclarationComparator());

return new GeneratedFile(QUERY_TYPE,
generatedFilePath(),
cu.toString());
}

private void generateConstructors(ClassOrInterfaceDeclaration clazz) {
for (ConstructorDeclaration c : clazz.getConstructors()) {
c.setName(targetClassName);
if (!c.getParameters().isEmpty()) {
setGeneric(c.getParameter(0).getType(), ruleUnit);
}
}
}

private void generateQueryMethod(CompilationUnit cu, ClassOrInterfaceDeclaration clazz, String returnType) {
private void generateQueryMethod(ClassOrInterfaceDeclaration clazz, String returnType) {
MethodDeclaration queryMethod = clazz.getMethodsByName("execute").get(0);
setGeneric(queryMethod.getType(), returnType);
setGeneric(queryMethod.getParameter(0).getType(), ruleUnit);
queryMethod.findAll(MethodReferenceExpr.class).forEach(mr -> mr.setScope(new NameExpr(targetClassName)));
}

private String getReturnType(ClassOrInterfaceDeclaration clazz) {
Expand Down

0 comments on commit bb56b57

Please sign in to comment.