Skip to content

Commit

Permalink
[DROOLS-372] fix jitting of constraint using Serializable
Browse files Browse the repository at this point in the history
  • Loading branch information
mariofusco committed Dec 10, 2013
1 parent b77856c commit d198e3a
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 19 deletions.
Expand Up @@ -92,6 +92,7 @@
import java.io.FileOutputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.ArrayList;
Expand Down Expand Up @@ -4730,7 +4731,7 @@ public void testWildcardImportForTypeFieldOldApi() {
StatefulKnowledgeSession ksession = kbase.newStatefulKnowledgeSession();
}

@Test @Ignore("Fixed with mvel 2.1.8.Final")
@Test
public void testTypeCheckInOr() {
// BZ-1029911
String str = "import org.drools.compiler.*;\n" +
Expand All @@ -4753,7 +4754,7 @@ public void testTypeCheckInOr() {
ksession.fireAllRules();
}

@Test @Ignore("Fixed with mvel 2.1.8.Final")
@Test
public void testDynamicNegativeSalienceWithSpace() {
// DROOLS-302
String str =
Expand Down Expand Up @@ -4799,5 +4800,35 @@ public void testJoinNoLoop() {

assertEquals(40, mario.getAge());
}

@Test
public void testConstraintOnSerializable() {
// DROOLS-372
String str =
"import org.drools.compiler.integrationtests.Misc2Test.SerializableValue\n" +
"rule R\n" +
"when\n" +
" SerializableValue( value == \"1\" )\n" +
"then\n" +
"end\n";

KnowledgeBase kbase = loadKnowledgeBaseFromString(str);
StatefulKnowledgeSession ksession = kbase.newStatefulKnowledgeSession();

ksession.insert(new SerializableValue("0"));
ksession.fireAllRules();
}

public static class SerializableValue {
private final Serializable value;

public SerializableValue(Serializable value) {
this.value = value;
}

public Serializable getValue() {
return value;
}
}
}

Expand Up @@ -19,6 +19,7 @@
import java.util.Map;
import java.util.Set;

import static java.lang.reflect.Modifier.isAbstract;
import static org.drools.core.util.ClassUtils.convertFromPrimitiveType;
import static org.drools.core.util.ClassUtils.convertPrimitiveNameToType;
import static org.drools.core.util.ClassUtils.convertToPrimitiveType;
Expand Down Expand Up @@ -603,6 +604,8 @@ protected final void push(Object obj, Class<?> type) {
mv.visitLdcInsn(classGenerator.toType((Class<?>) obj));
} else if (type == Character.class) {
invokeConstructor(Character.class, new Object[]{ obj.toString().charAt(0) }, char.class);
} else if (type.isInterface() || isAbstract(type.getModifiers())) {
push(obj, obj.getClass());
} else {
invokeConstructor(type, new Object[]{ obj.toString() }, String.class);
}
Expand Down
Expand Up @@ -186,7 +186,7 @@ private void jitBinary(SingleCondition singleCondition) {
Expression left = singleCondition.getLeft();
Expression right = singleCondition.getRight();
Class<?> commonType = singleCondition.getOperation().needsSameType() ?
findCommonClass(left.getType(), !left.canBeNull(), right.getType(), !right.canBeNull()) :
findCommonClass(left.getType(), !left.canBeNull(), right.getType(), !right.canBeNull(), singleCondition.getOperation().isEquality()) :
null;

if (commonType == Object.class && singleCondition.getOperation().isComparison()) {
Expand Down Expand Up @@ -769,6 +769,9 @@ private void jitMethodInvocation(MethodInvocation invocation, Class<?> currentCl
if (!firstInvocation) {
mv.visitVarInsn(ALOAD, 1);
}
if (!invocation.getReturnType().isAssignableFrom(currentClass)) {
cast(invocation.getReturnType());
}
return;
}

Expand Down Expand Up @@ -900,7 +903,7 @@ private int toOpCode(BooleanOperator op, Class<?> type) {
throw new RuntimeException("Unknown operator: " + op);
}

private Class<?> findCommonClass(Class<?> class1, boolean primitive1, Class<?> class2, boolean primitive2) {
private Class<?> findCommonClass(Class<?> class1, boolean primitive1, Class<?> class2, boolean primitive2, boolean forEquality) {
Class<?> result = null;
if (class1 == class2) {
result = class1;
Expand Down Expand Up @@ -931,9 +934,13 @@ private Class<?> findCommonClass(Class<?> class1, boolean primitive1, Class<?> c
result = findCommonClass(class2, class1, primitive1);
}
if (result == null) {
throw new RuntimeException( "Cannot find a common class between " + class1.getName() + " and " + class2.getName() +
" || " + class1.hashCode() + " vs " + class2.hashCode()
);
if (forEquality) {
return Object.class;
} else {
throw new RuntimeException( "Cannot find a common class between " + class1.getName() + " and " + class2.getName() +
" || " + class1.hashCode() + " vs " + class2.hashCode()
);
}
}
return result == Number.class ? Double.class : result;
}
Expand Down
@@ -1,5 +1,6 @@
package org.drools.core.rule.constraint;

import org.drools.core.factmodel.traits.Thing;
import org.drools.core.rule.Declaration;
import org.mvel2.Operator;
import org.mvel2.ParserContext;
Expand Down Expand Up @@ -205,11 +206,11 @@ private Expression analyzeNode(ASTNode node) {
expression.firstExpression = analyzeNode(main);
if (accessor instanceof DynamicGetAccessor) {
AccessorNode accessorNode = (AccessorNode)((DynamicGetAccessor)accessor).getSafeAccessor();
expression.addInvocation(analyzeAccessor(accessorNode, null));
expression.addInvocation(analyzeAccessorInvocation(accessorNode, node, null));
} else if (accessor instanceof AccessorNode) {
AccessorNode accessorNode = (AccessorNode)accessor;
while (accessorNode != null) {
expression.addInvocation(analyzeAccessor(accessorNode, null));
expression.addInvocation(analyzeAccessorInvocation(accessorNode, node, null));
accessorNode = accessorNode.getNextNode();
}
} else {
Expand Down Expand Up @@ -239,7 +240,7 @@ private Expression analyzeNode(ASTNode node) {
}
Class<?> variableType = getVariableType(variableName);
return new VariableExpression(variableName,
analyzeExpressionNode(((AccessorNode) accessor).getNextNode()),
analyzeExpressionNode(((AccessorNode) accessor).getNextNode(), node),
variableType != null ? variableType : node.getEgressType());
}

Expand All @@ -254,7 +255,7 @@ private Expression analyzeNode(ASTNode node) {
String variableName = (String)(variableAccessor.getProperty());
Class<?> variableType = getVariableType(variableName);
if (variableType != null) {
return new VariableExpression(variableName, analyzeExpressionNode(accessorNode), variableType);
return new VariableExpression(variableName, analyzeExpressionNode(accessorNode, node), variableType);
} else {
if (node.getLiteralValue() instanceof ParserContext) {
ParserContext pCtx = (ParserContext)node.getLiteralValue();
Expand All @@ -281,10 +282,10 @@ private Expression analyzeNode(ASTNode node) {
throw new RuntimeException("Null accessor on node: " + node);
}

return analyzeAccessor(accessor);
return analyzeNodeAccessor(accessor, node);
}

private Expression analyzeAccessor(Accessor accessor) {
private Expression analyzeNodeAccessor(Accessor accessor, ASTNode node) {
AccessorNode accessorNode;
if (accessor instanceof DynamicGetAccessor) {
accessorNode = (AccessorNode)((DynamicGetAccessor)accessor).getSafeAccessor();
Expand All @@ -304,7 +305,7 @@ private Expression analyzeAccessor(Accessor accessor) {
accessorNode = accessorNode.getNextNode();
}
} else {
return analyzeAccessor(accessorNode);
return analyzeNodeAccessor(accessorNode, node);
}
}

Expand All @@ -317,7 +318,7 @@ private Expression analyzeAccessor(Accessor accessor) {
}
}

return analyzeExpressionNode(accessorNode);
return analyzeExpressionNode(accessorNode, node);
}

private boolean isStaticAccessor(AccessorNode accessorNode) {
Expand Down Expand Up @@ -352,12 +353,12 @@ private EvaluatedExpression analyzeListCreation(ListCreator listCreator) {
return new EvaluatedExpression(invocation);
}

private EvaluatedExpression analyzeExpressionNode(AccessorNode accessorNode) {
private EvaluatedExpression analyzeExpressionNode(AccessorNode accessorNode, ASTNode containingNode) {
if (accessorNode == null) return null;
EvaluatedExpression expression = new EvaluatedExpression();
Invocation invocation = null;
while (accessorNode != null) {
invocation = analyzeAccessor(accessorNode, invocation);
invocation = analyzeAccessorInvocation(accessorNode, containingNode, invocation);
if (invocation != null) {
expression.addInvocation(invocation);
}
Expand All @@ -366,7 +367,7 @@ private EvaluatedExpression analyzeExpressionNode(AccessorNode accessorNode) {
return expression;
}

private Invocation analyzeAccessor(AccessorNode accessorNode, Invocation formerInvocation) {
private Invocation analyzeAccessorInvocation(AccessorNode accessorNode, ASTNode containingNode, Invocation formerInvocation) {
if (accessorNode instanceof GetterAccessor) {
return new MethodInvocation(((GetterAccessor)accessorNode).getMethod(), conditionClass);
}
Expand Down Expand Up @@ -450,7 +451,7 @@ private Invocation analyzeAccessor(AccessorNode accessorNode, Invocation formerI
}

if (accessorNode instanceof ThisValueAccessor) {
return new MethodInvocation(null);
return new ThisInvocation(accessorNode.getNextNode() == null ? containingNode.getEgressType() : Object.class);
}

throw new RuntimeException("Unknown AccessorNode type: " + accessorNode.getClass().getName());
Expand Down Expand Up @@ -911,6 +912,9 @@ private Method getMethodFromInterface(Class<?> clazz, Method method, String cond
return iMethod;
}
}
if (clazz != Thing.class) {
return null;
}
}
try {
return clazz.getMethod(method.getName(), method.getParameterTypes());
Expand All @@ -932,6 +936,20 @@ public Class<?> getReturnType() {
}
}

public static class ThisInvocation extends MethodInvocation {
private final Class<?> thisClass;

public ThisInvocation(Class<?> thisClass) {
super(null);
this.thisClass = thisClass;
}

@Override
public Class<?> getReturnType() {
return thisClass;
}
}

public static class ConstructorInvocation extends Invocation {
private final Constructor constructor;

Expand Down

0 comments on commit d198e3a

Please sign in to comment.