Skip to content

Commit

Permalink
[BZ-1057000] fix jitting of a constraint containing an array creation
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Jan 23, 2014
1 parent 20da66c commit fb7c885
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 6 deletions.
Expand Up @@ -5015,4 +5015,73 @@ public void testStagedTupleLeak() throws Exception {
assertEquals(0, stagedRightTuples.insertSize());
assertNull(stagedRightTuples.getInsertFirst());
}

@Test
public void testJittingConstraintWithArrayParams() throws Exception {
// BZ-1057000
String str =
"import org.drools.compiler.integrationtests.Misc2Test.Strings\n" +
"\n" +
"global java.util.List allList;\n" +
"global java.util.List anyList;\n" +
"\n" +
"rule R_all when\n" +
" Strings( containsAll(\"1\", \"2\") )\n" +
"then\n" +
" allList.add(\"1\");\n" +
"end\n" +
"\n" +
"rule R_any when\n" +
" Strings( containsAny(new String[] {\"1\", \"2\"}) )\n" +
"then\n" +
" anyList.add(\"1\");\n" +
"end\n";

KnowledgeBase kbase = loadKnowledgeBaseFromString(str);
StatefulKnowledgeSession ksession = kbase.newStatefulKnowledgeSession();

List<String> allList = new ArrayList<String>();
ksession.setGlobal("allList", allList);
List<String> anyList = new ArrayList<String>();
ksession.setGlobal("anyList", anyList);

ksession.insert(new Strings("1", "2", "3"));
ksession.insert(new Strings("2", "3"));
ksession.fireAllRules();

assertEquals(1, allList.size());
assertEquals(2, anyList.size());
}

public static class Strings {
private final String[] strings;

public Strings(String... strings) {
this.strings = strings;
}

public boolean containsAny(String[] array) {
for (String candidate : array) {
for (String s : strings) {
if (candidate.equals(s)) {
return true;
}
}
}
return false;
}

public boolean containsAll(String... array) {
int counter = 0;
for (String candidate : array) {
for (String s : strings) {
if (candidate.equals(s)) {
counter++;
break;
}
}
}
return counter == array.length;
}
}
}
Expand Up @@ -26,6 +26,7 @@
import org.mvel2.compiler.ExecutableLiteral;
import org.mvel2.compiler.ExecutableStatement;
import org.mvel2.optimizers.dynamic.DynamicGetAccessor;
import org.mvel2.optimizers.impl.refl.collection.ArrayCreator;
import org.mvel2.optimizers.impl.refl.collection.ExprValueAccessor;
import org.mvel2.optimizers.impl.refl.collection.ListCreator;
import org.mvel2.optimizers.impl.refl.nodes.ArrayAccessor;
Expand All @@ -45,6 +46,7 @@
import org.mvel2.optimizers.impl.refl.nodes.ThisValueAccessor;
import org.mvel2.optimizers.impl.refl.nodes.VariableAccessor;

import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
Expand Down Expand Up @@ -294,6 +296,8 @@ private Expression analyzeNodeAccessor(Accessor accessor, ASTNode node) {
return analyzeNode(((CompiledExpression)accessor).getFirstNode());
} else if (accessor instanceof ListCreator) {
return analyzeListCreation(((ListCreator) accessor));
} else if (accessor instanceof ArrayCreator) {
return analyzeArrayCreation(((ArrayCreator) accessor));
} else {
throw new RuntimeException("Unknown accessor type: " + accessor);
}
Expand Down Expand Up @@ -334,22 +338,31 @@ private boolean isStaticAccessor(AccessorNode accessorNode) {
return false;
}

private Expression analyzeArrayCreation(ArrayCreator arrayCreator) {
Accessor[] accessors = getFieldValue(ArrayCreator.class, "template", (ArrayCreator) arrayCreator);
Class<?> type = arrayCreator.getKnownEgressType();
Class<?> arrayType = Array.newInstance(type, 0).getClass();
return getArrayCreationExpression( arrayType, type, accessors );
}

private EvaluatedExpression analyzeListCreation(ListCreator listCreator) {
Method listCreationMethod = null;
try {
listCreationMethod = Arrays.class.getMethod("asList", Object[].class);
} catch (NoSuchMethodException e) { }

Invocation invocation = new MethodInvocation(listCreationMethod);
invocation.addArgument( getArrayCreationExpression( Object[].class, Object.class, listCreator.getValues() ) );
return new EvaluatedExpression(invocation);
}

ArrayCreationExpression arrayExpression = new ArrayCreationExpression(Object[].class);
Accessor[] accessors = listCreator.getValues();
private ArrayCreationExpression getArrayCreationExpression(Class<?> arrayType, Class<?> type, Accessor[] accessors) {
ArrayCreationExpression arrayExpression = new ArrayCreationExpression(arrayType);
for (Accessor accessor : accessors) {
ExecutableStatement statement = ((ExprValueAccessor)accessor).getStmt();
arrayExpression.addItem(statementToExpression(statement, Object.class));
arrayExpression.addItem(statementToExpression(statement, type));
}
invocation.addArgument(arrayExpression);

return new EvaluatedExpression(invocation);
return arrayExpression;
}

private EvaluatedExpression analyzeExpressionNode(AccessorNode accessorNode, ASTNode containingNode, Class<?> variableType) {
Expand Down

0 comments on commit fb7c885

Please sign in to comment.