Skip to content

Commit

Permalink
Fixed subtraction
Browse files Browse the repository at this point in the history
  • Loading branch information
pontusmelke committed May 17, 2016
1 parent 6245904 commit 6af4bad
Show file tree
Hide file tree
Showing 14 changed files with 160 additions and 66 deletions.
Expand Up @@ -121,7 +121,19 @@ public void gt( Expression lhs, Expression rhs )
} }


@Override @Override
public void sub( Expression lhs, Expression rhs ) public void subtractInts( Expression lhs, Expression rhs )
{

}

@Override
public void subtractLongs( Expression lhs, Expression rhs )
{

}

@Override
public void subtractDoubles( Expression lhs, Expression rhs )
{ {


} }
Expand Down
28 changes: 26 additions & 2 deletions community/codegen/src/main/java/org/neo4j/codegen/Expression.java
Expand Up @@ -118,14 +118,38 @@ public void accept( ExpressionVisitor visitor )
}; };
} }


public static Expression sub( final Expression lhs, final Expression rhs ) public static Expression subtractInts( final Expression lhs, final Expression rhs )
{ {
return new Expression() return new Expression()
{ {
@Override @Override
public void accept( ExpressionVisitor visitor ) public void accept( ExpressionVisitor visitor )
{ {
visitor.sub( lhs, rhs ); visitor.subtractInts( lhs, rhs );
}
};
}

public static Expression subtractLongs( final Expression lhs, final Expression rhs )
{
return new Expression()
{
@Override
public void accept( ExpressionVisitor visitor )
{
visitor.subtractLongs( lhs, rhs );
}
};
}

public static Expression subtractDoubles( final Expression lhs, final Expression rhs )
{
return new Expression()
{
@Override
public void accept( ExpressionVisitor visitor )
{
visitor.subtractDoubles( lhs, rhs );
} }
}; };
} }
Expand Down
Expand Up @@ -189,7 +189,24 @@ public void gt( Expression lhs, Expression rhs )
} }


@Override @Override
public void sub( Expression lhs, Expression rhs ) public void subtractInts( Expression lhs, Expression rhs )
{
sub( lhs, rhs);
}

@Override
public void subtractLongs( Expression lhs, Expression rhs )
{
sub( lhs, rhs);
}

@Override
public void subtractDoubles( Expression lhs, Expression rhs )
{
sub( lhs, rhs);
}

private void sub( Expression lhs, Expression rhs )
{ {
result.append( "sub(" ); result.append( "sub(" );
lhs.accept( this ); lhs.accept( this );
Expand Down
Expand Up @@ -53,7 +53,11 @@ public interface ExpressionVisitor


void gt( Expression lhs, Expression rhs ); void gt( Expression lhs, Expression rhs );


void sub( Expression lhs, Expression rhs ); void subtractInts( Expression lhs, Expression rhs );

void subtractLongs( Expression lhs, Expression rhs );

void subtractDoubles( Expression lhs, Expression rhs );


void cast( TypeReference type, Expression expression ); void cast( TypeReference type, Expression expression );


Expand Down
Expand Up @@ -352,39 +352,27 @@ public void gt( Expression lhs, Expression rhs )
} }


@Override @Override
public void sub( Expression lhs, Expression rhs ) public void subtractInts( Expression lhs, Expression rhs )
{ {
TypeReference lhsType = findType( lhs ); lhs.accept( this );
TypeReference rhsType = findType( rhs ); rhs.accept( this );
methodVisitor.visitInsn( ISUB );
}


if ( !lhsType.equals( rhsType ) ) @Override
{ public void subtractLongs( Expression lhs, Expression rhs )
throw new IllegalStateException( "Cannot compare values of different types" ); {
} lhs.accept( this );
rhs.accept( this );
methodVisitor.visitInsn( LSUB );
}


@Override
public void subtractDoubles( Expression lhs, Expression rhs )
{
lhs.accept( this ); lhs.accept( this );
rhs.accept( this ); rhs.accept( this );
switch ( lhsType.simpleName() ) methodVisitor.visitInsn( DSUB );
{
case "int":
case "byte":
case "short":
case "char":
case "boolean":
methodVisitor.visitInsn( ISUB );
break;
case "long":
methodVisitor.visitInsn( LSUB );
break;
case "float":
methodVisitor.visitInsn( FSUB );
break;
case "double":
methodVisitor.visitInsn( DSUB );
break;
default:
throw new IllegalStateException( "Subtraction is only supported for primitive number types" );
}
} }


@Override @Override
Expand Down
Expand Up @@ -444,7 +444,24 @@ public void gt( Expression lhs, Expression rhs )
} }


@Override @Override
public void sub( Expression lhs, Expression rhs ) public void subtractInts( Expression lhs, Expression rhs )
{
sub( lhs, rhs);
}

@Override
public void subtractLongs( Expression lhs, Expression rhs )
{
sub( lhs, rhs);
}

@Override
public void subtractDoubles( Expression lhs, Expression rhs )
{
sub( lhs, rhs);
}

private void sub( Expression lhs, Expression rhs )
{ {
lhs.accept( this ); lhs.accept( this );
append( " - " ); append( " - " );
Expand Down
Expand Up @@ -65,7 +65,8 @@
import static org.neo4j.codegen.Expression.newInstance; import static org.neo4j.codegen.Expression.newInstance;
import static org.neo4j.codegen.Expression.not; import static org.neo4j.codegen.Expression.not;
import static org.neo4j.codegen.Expression.or; import static org.neo4j.codegen.Expression.or;
import static org.neo4j.codegen.Expression.sub; import static org.neo4j.codegen.Expression.subtractDoubles;
import static org.neo4j.codegen.Expression.subtractLongs;
import static org.neo4j.codegen.Expression.ternary; import static org.neo4j.codegen.Expression.ternary;
import static org.neo4j.codegen.ExpressionTemplate.cast; import static org.neo4j.codegen.ExpressionTemplate.cast;
import static org.neo4j.codegen.ExpressionTemplate.load; import static org.neo4j.codegen.ExpressionTemplate.load;
Expand Down Expand Up @@ -1065,16 +1066,13 @@ public void shouldHandleAddition() throws Throwable
{ {
assertThat( addForType( int.class, 17, 18 ), equalTo( 35 ) ); assertThat( addForType( int.class, 17, 18 ), equalTo( 35 ) );
assertThat( addForType( long.class, 17L, 18L ), equalTo( 35L ) ); assertThat( addForType( long.class, 17L, 18L ), equalTo( 35L ) );
assertThat( addForType( float.class, 17F, 18F ), equalTo( 35F ) );
assertThat( addForType( double.class, 17D, 18D ), equalTo( 35D ) ); assertThat( addForType( double.class, 17D, 18D ), equalTo( 35D ) );
} }


@Test @Test
public void shouldHandleSubtraction() throws Throwable public void shouldHandleSubtraction() throws Throwable
{ {
assertThat( subtractForType( int.class, 19, 18 ), equalTo( 1 ) );
assertThat( subtractForType( long.class, 19L, 18L ), equalTo( 1L ) ); assertThat( subtractForType( long.class, 19L, 18L ), equalTo( 1L ) );
assertThat( subtractForType( float.class, 19F, 18F ), equalTo( 1F ) );
assertThat( subtractForType( double.class, 19D, 18D ), equalTo( 1D ) ); assertThat( subtractForType( double.class, 19D, 18D ), equalTo( 1D ) );
} }


Expand Down Expand Up @@ -1128,7 +1126,18 @@ private <T> T subtractForType( Class<T> clazz, T lhs, T rhs ) throws Throwable
try ( CodeBlock block = simple.generateMethod( clazz, "sub", try ( CodeBlock block = simple.generateMethod( clazz, "sub",
param( clazz, "a" ), param( clazz, "b" ) ) ) param( clazz, "a" ), param( clazz, "b" ) ) )
{ {
block.returns( sub( block.load( "a" ), block.load( "b" ) ) ); if (clazz == long.class)
{
block.returns( subtractLongs( block.load( "a" ), block.load( "b" ) ) );
}
else if (clazz == double.class)
{
block.returns( subtractDoubles( block.load( "a" ), block.load( "b" ) ) );
}
else
{
fail( "adding " + clazz.getSimpleName() + " is not supported" );
}
} }


handle = simple.handle(); handle = simple.handle();
Expand Down
Expand Up @@ -78,7 +78,9 @@ trait MethodStructure[E] {
def add(lhs: E, rhs: E): E def add(lhs: E, rhs: E): E
def addIntegers(lhs: E, rhs: E): E def addIntegers(lhs: E, rhs: E): E
def addFloats(lhs: E, rhs: E): E def addFloats(lhs: E, rhs: E): E
def sub(lhs: E, rhs: E): E def subtract(lhs: E, rhs: E): E
def subtractIntegers(lhs: E, rhs: E): E
def subtractFloats(lhs: E, rhs: E): E
def mul(lhs: E, rhs: E): E def mul(lhs: E, rhs: E): E
def div(lhs: E, rhs: E): E def div(lhs: E, rhs: E): E
def mod(lhs: E, rhs: E): E def mod(lhs: E, rhs: E): E
Expand Down
Expand Up @@ -49,7 +49,7 @@ trait NumericalOpType {
override def cypherType(implicit context: CodeGenContext) = override def cypherType(implicit context: CodeGenContext) =
(lhs.cypherType, rhs.cypherType) match { (lhs.cypherType, rhs.cypherType) match {
case (CTInteger, CTInteger) => CTInteger case (CTInteger, CTInteger) => CTInteger
case (_: NumberType, _: NumberType) => CTFloat case (Number(_), Number(_)) => CTFloat
// Runtime we'll figure it out - can't store it in a primitive field unless we are 100% of the type // Runtime we'll figure it out - can't store it in a primitive field unless we are 100% of the type
case _ => CTAny case _ => CTAny
} }
Expand Down
Expand Up @@ -20,11 +20,28 @@
package org.neo4j.cypher.internal.compiler.v3_1.codegen.ir.expressions package org.neo4j.cypher.internal.compiler.v3_1.codegen.ir.expressions


import org.neo4j.cypher.internal.compiler.v3_1.codegen.{CodeGenContext, MethodStructure} import org.neo4j.cypher.internal.compiler.v3_1.codegen.{CodeGenContext, MethodStructure}
import org.neo4j.cypher.internal.frontend.v3_1.CypherTypeException
import org.neo4j.cypher.internal.frontend.v3_1.symbols._


case class Subtraction(lhs: CodeGenExpression, rhs: CodeGenExpression) case class Subtraction(lhs: CodeGenExpression, rhs: CodeGenExpression)
extends CodeGenExpression with BinaryOperator with NumericalOpType { extends CodeGenExpression with BinaryOperator with NumericalOpType {


override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.sub override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) =
(lhs.cypherType, rhs.cypherType) match {
//primitive cases
case (CTInteger, CTInteger) => structure.subtractIntegers
case (CTFloat, CTFloat) => structure.subtractFloats
case (CTInteger, CTFloat) => (l, r) => structure.subtractFloats(structure.toFloat(l), r)
case (CTFloat, CTInteger) => (l, r) => structure.subtractFloats(l, structure.toFloat(r))
case (CTBoolean, _) => throw new CypherTypeException(s"Cannot add a boolean and ${rhs.cypherType}")
case (_, CTBoolean) => throw new CypherTypeException(s"Cannot add a ${rhs.cypherType} and a boolean")

//reference cases
case (Number(t), _) => (l, r) => structure.subtract(structure.box(l, t), r)
case (_, Number(t)) => (l, r) => structure.subtract(l, structure.box(r, t))

case _ => structure.subtract
}


override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable
} }
Expand Up @@ -313,7 +313,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator


override def decreaseCounterAndCheckForZero(name: String): Expression = { override def decreaseCounterAndCheckForZero(name: String): Expression = {
val local = locals(name) val local = locals(name)
generator.assign(local, Expression.sub(local, Expression.constant(1))) generator.assign(local, Expression.subtractInts(local, Expression.constant(1)))
Expression.eq(Expression.constant(0), local) Expression.eq(Expression.constant(0), local)
} }


Expand Down Expand Up @@ -542,7 +542,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator
generator.assign(times, Expression.invoke(generator.load(tableVar), Methods.countingTableGet, generator.load(keyVar))) generator.assign(times, Expression.invoke(generator.load(tableVar), Methods.countingTableGet, generator.load(keyVar)))
using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body =>
block(copy(generator=body)) block(copy(generator=body))
body.assign(times, Expression.sub(times, Expression.constant(1))) body.assign(times, Expression.subtractInts(times, Expression.constant(1)))
} }
case LongsToCountTable => case LongsToCountTable =>
val times = generator.declare(typeRef[Int], context.namer.newVarName()) val times = generator.declare(typeRef[Int], context.namer.newVarName())
Expand All @@ -557,7 +557,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator


using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body =>
block(copy(generator=body)) block(copy(generator=body))
body.assign(times, Expression.sub(times, Expression.constant(1))) body.assign(times, Expression.subtractInts(times, Expression.constant(1)))
} }


case tableType@LongToListTable(structure,localVars) => case tableType@LongToListTable(structure,localVars) =>
Expand Down
Expand Up @@ -328,7 +328,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator


override def decreaseCounterAndCheckForZero(name: String): Expression = { override def decreaseCounterAndCheckForZero(name: String): Expression = {
val local = locals(name) val local = locals(name)
generator.assign(local, Expression.sub(local, Expression.constant(1))) generator.assign(local, Expression.subtractInts(local, Expression.constant(1)))
Expression.eq(Expression.constant(0), local) Expression.eq(Expression.constant(0), local)
} }


Expand Down Expand Up @@ -556,7 +556,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator
generator.assign(times, Expression.invoke(generator.load(tableVar), Methods.countingTableGet, generator.load(keyVar))) generator.assign(times, Expression.invoke(generator.load(tableVar), Methods.countingTableGet, generator.load(keyVar)))
using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body =>
block(copy(generator=body)) block(copy(generator=body))
body.assign(times, Expression.sub(times, Expression.constant(1))) body.assign(times, Expression.subtractInts(times, Expression.constant(1)))
} }
case LongsToCountTable => case LongsToCountTable =>
val times = generator.declare(typeRef[Int], context.namer.newVarName()) val times = generator.declare(typeRef[Int], context.namer.newVarName())
Expand All @@ -571,7 +571,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator


using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body =>
block(copy(generator=body)) block(copy(generator=body))
body.assign(times, Expression.sub(times, Expression.constant(1))) body.assign(times, Expression.subtractInts(times, Expression.constant(1)))
} }


case tableType@LongToListTable(structure,localVars) => case tableType@LongToListTable(structure,localVars) =>
Expand Down
Expand Up @@ -329,7 +329,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator


override def decreaseCounterAndCheckForZero(name: String): Expression = { override def decreaseCounterAndCheckForZero(name: String): Expression = {
val local = locals(name) val local = locals(name)
generator.assign(local, Expression.sub(local, Expression.constant(1))) generator.assign(local, Expression.subtractInts(local, Expression.constant(1)))
Expression.eq(Expression.constant(0), local) Expression.eq(Expression.constant(0), local)
} }


Expand Down Expand Up @@ -473,7 +473,11 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator


override def add(lhs: Expression, rhs: Expression) = math(Methods.mathAdd, lhs, rhs) override def add(lhs: Expression, rhs: Expression) = math(Methods.mathAdd, lhs, rhs)


override def sub(lhs: Expression, rhs: Expression) = math(Methods.mathSub, lhs, rhs) override def subtractIntegers(lhs: Expression, rhs: Expression) = Expression.subtractLongs(lhs, rhs)

override def subtractFloats(lhs: Expression, rhs: Expression) = Expression.subtractDoubles(lhs, rhs)

override def subtract(lhs: Expression, rhs: Expression) = math(Methods.mathSub, lhs, rhs)


override def mul(lhs: Expression, rhs: Expression) = math(Methods.mathMul, lhs, rhs) override def mul(lhs: Expression, rhs: Expression) = math(Methods.mathMul, lhs, rhs)


Expand Down Expand Up @@ -587,7 +591,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator
.invoke(generator.load(tableVar), Methods.countingTableGet, generator.load(keyVar))) .invoke(generator.load(tableVar), Methods.countingTableGet, generator.load(keyVar)))
using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body =>
block(copy(generator=body)) block(copy(generator=body))
body.assign(times, Expression.sub(times, Expression.constant(1))) body.assign(times, Expression.subtractInts(times, Expression.constant(1)))
} }
case LongsToCountTable => case LongsToCountTable =>
val times = generator.declare(typeRef[Int], context.namer.newVarName()) val times = generator.declare(typeRef[Int], context.namer.newVarName())
Expand All @@ -602,7 +606,7 @@ private case class Method(fields: Fields, generator: CodeBlock, aux:AuxGenerator


using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body => using(generator.whileLoop(Expression.gt(times, Expression.constant(0)))) { body =>
block(copy(generator=body)) block(copy(generator=body))
body.assign(times, Expression.sub(times, Expression.constant(1))) body.assign(times, Expression.subtractInts(times, Expression.constant(1)))
} }


case tableType@LongToListTable(structure,localVars) => case tableType@LongToListTable(structure,localVars) =>
Expand Down

0 comments on commit 6af4bad

Please sign in to comment.