From e2bfa0649815a46021c05f436235d459d1b4658a Mon Sep 17 00:00:00 2001 From: Luca Molteni Date: Sun, 8 Jul 2018 17:43:12 +0200 Subject: [PATCH] [DROOLS-2625] Support reverse in exec model's accumulator (#1976) * Support reverse * If not present in the reverse map take it from the memory --- .../constraints/LambdaAccumulator.java | 21 +++++++++++++++++-- .../modelcompiler/domain/TargetPolicy.java | 9 ++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/constraints/LambdaAccumulator.java b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/constraints/LambdaAccumulator.java index 2fa749cfd00..7dda335c1a9 100644 --- a/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/constraints/LambdaAccumulator.java +++ b/drools-model/drools-model-compiler/src/main/java/org/drools/modelcompiler/constraints/LambdaAccumulator.java @@ -2,7 +2,9 @@ import java.io.Serializable; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import org.drools.core.WorkingMemory; import org.drools.core.common.InternalFactHandle; @@ -16,6 +18,8 @@ public abstract class LambdaAccumulator implements Accumulator { private final org.kie.api.runtime.rule.AccumulateFunction accumulateFunction; protected final List sourceVariables; + private Map reverseSupport; + protected LambdaAccumulator(org.kie.api.runtime.rule.AccumulateFunction accumulateFunction, List sourceVariables) { this.accumulateFunction = accumulateFunction; @@ -40,11 +44,18 @@ public Serializable createContext() { @Override public void init(Object workingMemoryContext, Object context, Tuple leftTuple, Declaration[] declarations, WorkingMemory workingMemory) throws Exception { accumulateFunction.init((Serializable) context); + if(supportsReverse()) { + reverseSupport = new HashMap<>(); + } } @Override public void accumulate(Object workingMemoryContext, Object context, Tuple leftTuple, InternalFactHandle handle, Declaration[] declarations, Declaration[] innerDeclarations, WorkingMemory workingMemory) throws Exception { - accumulateFunction.accumulate((Serializable) context, getAccumulatedObject( declarations, innerDeclarations, handle, leftTuple, ( InternalWorkingMemory ) workingMemory )); + final Object accumulatedObject = getAccumulatedObject(declarations, innerDeclarations, handle, leftTuple, (InternalWorkingMemory) workingMemory); + if (supportsReverse()) { + reverseSupport.put(handle.getId(), accumulatedObject); + } + accumulateFunction.accumulate((Serializable) context, accumulatedObject); } protected abstract Object getAccumulatedObject( Declaration[] declarations, Declaration[] innerDeclarations, InternalFactHandle handle, Tuple tuple, InternalWorkingMemory wm ); @@ -56,7 +67,13 @@ public boolean supportsReverse() { @Override public void reverse(Object workingMemoryContext, Object context, Tuple leftTuple, InternalFactHandle handle, Declaration[] declarations, Declaration[] innerDeclarations, WorkingMemory workingMemory) throws Exception { - accumulateFunction.reverse((Serializable) context, getAccumulatedObject( declarations, innerDeclarations, handle, leftTuple, ( InternalWorkingMemory ) workingMemory )); + final Object accumulatedObject = reverseSupport.remove(handle.getId()); + if(accumulatedObject == null) { + final Object accumulatedObject2 = getAccumulatedObject(declarations, innerDeclarations, handle, leftTuple, (InternalWorkingMemory) workingMemory); + accumulateFunction.reverse((Serializable) context, accumulatedObject2); + } else { + accumulateFunction.reverse((Serializable) context, accumulatedObject); + } } @Override diff --git a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/domain/TargetPolicy.java b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/domain/TargetPolicy.java index c49685fddb7..848901974ad 100644 --- a/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/domain/TargetPolicy.java +++ b/drools-model/drools-model-compiler/src/test/java/org/drools/modelcompiler/domain/TargetPolicy.java @@ -47,5 +47,14 @@ public int getCoefficient() { public void setCoefficient(int coefficient) { this.coefficient = coefficient; } + + @Override + public String toString() { + return "TargetPolicy{" + + "customerCode='" + customerCode + '\'' + + ", productCode='" + productCode + '\'' + + ", coefficient=" + coefficient + + '}'; + } }