Skip to content

Commit

Permalink
Replaced subtractInts, subtractDoubles, subtractLongs with subtract
Browse files Browse the repository at this point in the history
Now that the types are on the `Expression` there is no longer any need to have
one method per type.
  • Loading branch information
pontusmelke authored and henriknyman committed Nov 10, 2016
1 parent 3dac538 commit e43395c
Show file tree
Hide file tree
Showing 14 changed files with 47 additions and 129 deletions.
34 changes: 8 additions & 26 deletions community/codegen/src/main/java/org/neo4j/codegen/Expression.java
Expand Up @@ -161,38 +161,20 @@ public void accept( ExpressionVisitor visitor )
};
}

public static Expression subtractInts( final Expression lhs, final Expression rhs )
public static Expression subtract( final Expression lhs, final Expression rhs )
{
return new Expression( INT )
{
@Override
public void accept( ExpressionVisitor visitor )
{
visitor.subtractInts( lhs, rhs );
}
};
}

public static Expression subtractLongs( final Expression lhs, final Expression rhs )
{
return new Expression( LONG )
if ( !lhs.type.equals( rhs.type ) )
{
@Override
public void accept( ExpressionVisitor visitor )
{
visitor.subtractLongs( lhs, rhs );
}
};
}

public static Expression subtractDoubles( final Expression lhs, final Expression rhs )
{
return new Expression( DOUBLE )
throw new IllegalArgumentException(
String.format( "Cannot subtract variable with different types. LHS %s, RHS %s", lhs.type.simpleName(),
rhs.type.simpleName() ));
}
return new Expression( lhs.type )
{
@Override
public void accept( ExpressionVisitor visitor )
{
visitor.subtractDoubles( lhs, rhs );
visitor.subtract( lhs, rhs );
}
};
}
Expand Down
Expand Up @@ -228,24 +228,7 @@ public void lte( Expression lhs, Expression rhs )
}

@Override
public void subtractInts( Expression lhs, Expression rhs )
{
sub( lhs, rhs );
}

@Override
public void subtractLongs( Expression lhs, Expression rhs )
{
sub( lhs, rhs );
}

@Override
public void subtractDoubles( Expression lhs, Expression rhs )
{
sub( lhs, rhs );
}

private void sub( Expression lhs, Expression rhs )
public void subtract( Expression lhs, Expression rhs )
{
result.append( "sub(" );
lhs.accept( this );
Expand Down
Expand Up @@ -61,11 +61,7 @@ public interface ExpressionVisitor

void lte( Expression lhs, Expression rhs );

void subtractInts( Expression lhs, Expression rhs );

void subtractLongs( Expression lhs, Expression rhs );

void subtractDoubles( Expression lhs, Expression rhs );
void subtract( Expression lhs, Expression rhs );

void multiplyLongs( Expression lhs, Expression rhs );

Expand Down
Expand Up @@ -55,6 +55,7 @@
import static org.objectweb.asm.Opcodes.FCMPG;
import static org.objectweb.asm.Opcodes.FCMPL;
import static org.objectweb.asm.Opcodes.FLOAD;
import static org.objectweb.asm.Opcodes.FSUB;
import static org.objectweb.asm.Opcodes.GETFIELD;
import static org.objectweb.asm.Opcodes.GETSTATIC;
import static org.objectweb.asm.Opcodes.GOTO;
Expand Down Expand Up @@ -424,7 +425,7 @@ public void lt( Expression lhs, Expression rhs )
}

@Override
public void lte( Expression lhs, Expression rhs)
public void lte( Expression lhs, Expression rhs )
{
assertSameType( lhs, rhs, "compare" );
numberOperation( lhs.type(),
Expand All @@ -436,27 +437,15 @@ public void lte( Expression lhs, Expression rhs)
}

@Override
public void subtractInts( Expression lhs, Expression rhs )
public void subtract( Expression lhs, Expression rhs )
{
lhs.accept( this );
rhs.accept( this );
methodVisitor.visitInsn( ISUB );
}

@Override
public void subtractLongs( Expression lhs, Expression rhs )
{
lhs.accept( this );
rhs.accept( this );
methodVisitor.visitInsn( LSUB );
}

@Override
public void subtractDoubles( Expression lhs, Expression rhs )
{
lhs.accept( this );
rhs.accept( this );
methodVisitor.visitInsn( DSUB );
numberOperation( lhs.type(),
() -> methodVisitor.visitInsn( ISUB ),
() -> methodVisitor.visitInsn( LSUB ),
() -> methodVisitor.visitInsn( FSUB ),
() -> methodVisitor.visitInsn( DSUB ) );
}

@Override
Expand Down Expand Up @@ -680,7 +669,7 @@ private void assertSameType( Expression lhs, Expression rhs, String operation )
{
if ( !lhs.type().equals( rhs.type() ) )
{
throw new IllegalArgumentException( String.format( "Can only %s values of the same type", operation ));
throw new IllegalArgumentException( String.format( "Can only %s values of the same type", operation ) );
}
}

Expand Down
Expand Up @@ -418,24 +418,7 @@ public void lte( Expression lhs, Expression rhs )
}

@Override
public void subtractInts( Expression lhs, Expression rhs )
{
sub( lhs, rhs);
}

@Override
public void subtractLongs( Expression lhs, Expression rhs )
{
sub( lhs, rhs);
}

@Override
public void subtractDoubles( Expression lhs, Expression rhs )
{
sub( lhs, rhs);
}

private void sub( Expression lhs, Expression rhs )
public void subtract( Expression lhs, Expression rhs )
{
lhs.accept( this );
append( " - " );
Expand Down
Expand Up @@ -64,8 +64,7 @@
import static org.neo4j.codegen.Expression.newInstance;
import static org.neo4j.codegen.Expression.not;
import static org.neo4j.codegen.Expression.or;
import static org.neo4j.codegen.Expression.subtractDoubles;
import static org.neo4j.codegen.Expression.subtractLongs;
import static org.neo4j.codegen.Expression.subtract;
import static org.neo4j.codegen.Expression.ternary;
import static org.neo4j.codegen.ExpressionTemplate.cast;
import static org.neo4j.codegen.ExpressionTemplate.load;
Expand Down Expand Up @@ -1387,18 +1386,7 @@ private <T> T subtractForType( Class<T> clazz, T lhs, T rhs ) throws Throwable
try ( CodeBlock block = simple.generateMethod( clazz, "sub",
param( clazz, "a" ), param( clazz, "b" ) ) )
{
if ( clazz == long.class )
{
block.returns( subtractLongs( block.load( "a" ), block.load( "b" ) ) );
}
else if ( clazz == double.class )
{
block.returns( subtractDoubles( block.load( "a" ), block.load( "b" ) ) );
}
else
{
fail( "adding " + clazz.getSimpleName() + " is not supported" );
}
block.returns( subtract( block.load( "a" ), block.load( "b" ) ) );
}

handle = simple.handle();
Expand Down
Expand Up @@ -91,11 +91,11 @@ trait MethodStructure[E] {
def loadVariable(varName: String): E

// arithmetic
def add(lhs: E, rhs: E): E
def subtract(lhs: E, rhs: E): E
def multiply(lhs: E, rhs: E): E
def divide(lhs: E, rhs: E): E
def modulus(lhs: E, rhs: E): E
def addExpression(lhs: E, rhs: E): E
def subtractExpression(lhs: E, rhs: E): E
def multiplyExpression(lhs: E, rhs: E): E
def divideExpression(lhs: E, rhs: E): E
def modulusExpression(lhs: E, rhs: E): E

// predicates
def threeValuedNotExpression(value: E): E
Expand Down
Expand Up @@ -24,7 +24,7 @@ import org.neo4j.cypher.internal.frontend.v3_2.symbols._

case class Addition(lhs: CodeGenExpression, rhs: CodeGenExpression) extends CodeGenExpression with BinaryOperator {

override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.add
override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.addExpression

override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable

Expand Down
Expand Up @@ -25,7 +25,7 @@ case class Division(lhs: CodeGenExpression, rhs: CodeGenExpression)
extends CodeGenExpression with BinaryOperator{

override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) =
structure.divide
structure.divideExpression

override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable

Expand Down
Expand Up @@ -24,7 +24,7 @@ import org.neo4j.cypher.internal.frontend.v3_2.symbols._

case class Modulo(lhs: CodeGenExpression, rhs: CodeGenExpression) extends CodeGenExpression with BinaryOperator {

override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.modulus
override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.modulusExpression
override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable

override def codeGenType(implicit context: CodeGenContext) = CodeGenType(CTFloat, ReferenceType)
Expand Down
Expand Up @@ -26,7 +26,7 @@ case class Multiplication(lhs: CodeGenExpression, rhs: CodeGenExpression)

override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable

override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.multiply
override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) = structure.multiplyExpression

override def name: String = "multiply"
}
Expand Up @@ -25,7 +25,7 @@ case class Subtraction(lhs: CodeGenExpression, rhs: CodeGenExpression)
extends CodeGenExpression with BinaryOperator {

override protected def generator[E](structure: MethodStructure[E])(implicit context: CodeGenContext) =
structure.subtract
structure.subtractExpression

override def nullable(implicit context: CodeGenContext) = lhs.nullable || rhs.nullable

Expand Down
Expand Up @@ -189,8 +189,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A

override def decreaseCounterAndCheckForZero(name: String): Expression = {
val local = locals(name)
generator.assign(local, subtractInts(local, constant(1)))

generator.assign(local, subtract(local, constant(1)))
equal(constant(0), local)
}

Expand Down Expand Up @@ -468,8 +467,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A
invoke(generator.load(tableVar), countingTablePut, generator.load(keyVar),
ternary(
equal(generator.load(countName), get(staticField[LongKeyIntValueTable, Int]("NULL"))),
constant(1),
Expression.add(generator.load(countName), constant(1))))))
constant(1), add(generator.load(countName), constant(1))))))

case LongsToCountTable =>
val countName = context.namer.newVarName()
Expand All @@ -489,7 +487,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A
ternaryOnNull(generator.load(countName),
invoke(boxInteger,
constant(1)), invoke(boxInteger,
Expression.add(
add(
invoke(generator.load(countName),
unboxInteger),
constant(1)))))))
Expand All @@ -504,7 +502,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A
generator.assign(times, invoke(generator.load(tableVar), countingTableGet, generator.load(keyVar)))
using(generator.whileLoop(gt(times, constant(0)))) { body =>
block(copy(generator = body))
body.assign(times, subtractInts(times, constant(1)))
body.assign(times, subtract(times, constant(1)))
}
case LongsToCountTable =>
val times = generator.declare(typeRef[Int], context.namer.newVarName())
Expand All @@ -524,7 +522,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A

using(generator.whileLoop(gt(times, constant(0)))) { body =>
block(copy(generator = body))
body.assign(times, subtractInts(times, constant(1)))
body.assign(times, subtract(times, constant(1)))
}

case tableType@LongToListTable(structure, localVars) =>
Expand Down
Expand Up @@ -189,7 +189,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A

override def decrementCounter(name: String) = {
val local = locals(name)
generator.assign(local, subtractInts(local, constant(1)))
generator.assign(local, subtract(local, constant(1)))
}

override def checkCounter(name: String, comparator: Comparator, value: Int): Expression = {
Expand Down Expand Up @@ -374,15 +374,15 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A

override def loadVariable(varName: String) = generator.load(varName)

override def add(lhs: Expression, rhs: Expression) = math(Methods.mathAdd, lhs, rhs)
override def addExpression(lhs: Expression, rhs: Expression) = math(Methods.mathAdd, lhs, rhs)

override def subtract(lhs: Expression, rhs: Expression) = math(Methods.mathSub, lhs, rhs)
override def subtractExpression(lhs: Expression, rhs: Expression) = math(Methods.mathSub, lhs, rhs)

override def multiply(lhs: Expression, rhs: Expression) = math(Methods.mathMul, lhs, rhs)
override def multiplyExpression(lhs: Expression, rhs: Expression) = math(Methods.mathMul, lhs, rhs)

override def divide(lhs: Expression, rhs: Expression) = math(Methods.mathDiv, lhs, rhs)
override def divideExpression(lhs: Expression, rhs: Expression) = math(Methods.mathDiv, lhs, rhs)

override def modulus(lhs: Expression, rhs: Expression) = math(Methods.mathMod, lhs, rhs)
override def modulusExpression(lhs: Expression, rhs: Expression) = math(Methods.mathMod, lhs, rhs)

private def math(method: MethodReference, lhs: Expression, rhs: Expression): Expression =
invoke(method, lhs, rhs)
Expand Down Expand Up @@ -472,8 +472,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A
invoke(generator.load(tableVar), countingTablePut, generator.load(keyVar),
ternary(
equal(generator.load(countName), get(staticField[LongKeyIntValueTable, Int]("NULL"))),
constant(1),
addInts(generator.load(countName), constant(1))))))
constant(1), add(generator.load(countName), constant(1))))))

case LongsToCountTable =>
val countName = context.namer.newVarName()
Expand All @@ -493,7 +492,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A
ternaryOnNull(generator.load(countName),
invoke(boxInteger,
constant(1)), invoke(boxInteger,
addInts(
add(
invoke(generator.load(countName),
unboxInteger),
constant(1)))))))
Expand All @@ -508,7 +507,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A
generator.assign(times, invoke(generator.load(tableVar), countingTableGet, generator.load(keyVar)))
using(generator.whileLoop(gt(times, constant(0)))) { body =>
block(copy(generator = body))
body.assign(times, subtractInts(times, constant(1)))
body.assign(times, subtract(times, constant(1)))
}
case LongsToCountTable =>
val times = generator.declare(typeRef[Int], context.namer.newVarName())
Expand All @@ -528,7 +527,7 @@ case class GeneratedMethodStructure(fields: Fields, generator: CodeBlock, aux: A

using(generator.whileLoop(gt(times, constant(0)))) { body =>
block(copy(generator = body))
body.assign(times, subtractInts(times, constant(1)))
body.assign(times, subtract(times, constant(1)))
}

case tableType@LongToListTable(structure, localVars) =>
Expand Down

0 comments on commit e43395c

Please sign in to comment.