Skip to content

Commit

Permalink
Support for trim functions
Browse files Browse the repository at this point in the history
  • Loading branch information
pontusmelke committed Jul 9, 2018
1 parent 0ffcc95 commit bafd783
Show file tree
Hide file tree
Showing 4 changed files with 94 additions and 15 deletions.
Expand Up @@ -93,30 +93,24 @@ case class ToUpperFunction(argument: Expression) extends StringFunction(argument

case class LTrimFunction(argument: Expression) extends StringFunction(argument) {

override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = value match {
case t: TextValue => t.ltrim()
case _ => StringFunction.notAString(value)
}
override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue =
CypherFunctions.ltrim(value)

override def rewrite(f: (Expression) => Expression) = f(LTrimFunction(argument.rewrite(f)))
}

case class RTrimFunction(argument: Expression) extends StringFunction(argument) {

override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = value match {
case t: TextValue => t.rtrim()
case _ => StringFunction.notAString(value)
}
override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue =
CypherFunctions.rtrim(value)

override def rewrite(f: (Expression) => Expression) = f(RTrimFunction(argument.rewrite(f)))
}

case class TrimFunction(argument: Expression) extends StringFunction(argument) {

override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = value match {
case t: TextValue => t.trim()
case _ => StringFunction.notAString(value)
}
override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue =
CypherFunctions.trim(value)

override def rewrite(f: (Expression) => Expression) = f(TrimFunction(argument.rewrite(f)))
}
Expand Down
Expand Up @@ -456,9 +456,43 @@ public static TextValue left( AnyValue in, AnyValue endPos )
}
else
{
throw new CypherTypeException(
format("Expected a string value for `left`, but got: %s; consider converting it to a string with toString().", in),
null);
throw notAString( "left", in );
}
}

public static TextValue ltrim( AnyValue in )
{
if ( in instanceof TextValue )
{
return ((TextValue) in).ltrim();
}
else
{
throw notAString( "ltrim", in );
}
}

public static TextValue rtrim( AnyValue in )
{
if ( in instanceof TextValue )
{
return ((TextValue) in).rtrim();
}
else
{
throw notAString( "rtrim", in );
}
}

public static TextValue trim( AnyValue in )
{
if ( in instanceof TextValue )
{
return ((TextValue) in).trim();
}
else
{
throw notAString( "trim", in );
}
}

Expand Down Expand Up @@ -565,4 +599,13 @@ private static CypherTypeException needsNumbers( String method )
{
return new CypherTypeException( format( "%s requires numbers", method ), null );
}

private static CypherTypeException notAString( String method, AnyValue in )
{
return new CypherTypeException(
format("Expected a string value for `%s`, but got: %s; consider converting it to a string with toString().",
method, in), null);
}


}
Expand Up @@ -484,6 +484,24 @@ class IntermediateCodeGeneration(slots: SlotConfiguration) {
in.ir, endPos.ir)), in.nullable)
}

case functions.LTrim =>
for (in <- compile(c.args.head)) yield {
IntermediateExpression(
noValueCheck(in)(invokeStatic(method[CypherFunctions, TextValue, AnyValue]("ltrim"), in.ir)), in.nullable)
}

case functions.RTrim =>
for (in <- compile(c.args.head)) yield {
IntermediateExpression(
noValueCheck(in)(invokeStatic(method[CypherFunctions, TextValue, AnyValue]("rtrim"), in.ir)), in.nullable)
}

case functions.Trim =>
for (in <- compile(c.args.head)) yield {
IntermediateExpression(
noValueCheck(in)(invokeStatic(method[CypherFunctions, TextValue, AnyValue]("trim"), in.ir)), in.nullable)
}

case _ => None
}

Expand Down
Expand Up @@ -308,6 +308,30 @@ class CodeGenerationTest extends CypherFunSuite with AstConstructionTestSupport
compiled.evaluate(ctx, db, map(Array("a", "b"), Array(NO_VALUE, intValue(4)))) should equal(NO_VALUE)
}

test("ltrim function") {
val compiled = compile(function("ltrim", parameter("a")))

compiled.evaluate(ctx, db, map(Array("a"), Array(stringValue(" HELLO ")))) should
equal(stringValue("HELLO "))
compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE)
}

test("rtrim function") {
val compiled = compile(function("rtrim", parameter("a")))

compiled.evaluate(ctx, db, map(Array("a"), Array(stringValue(" HELLO ")))) should
equal(stringValue(" HELLO"))
compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE)
}

test("trim function") {
val compiled = compile(function("trim", parameter("a")))

compiled.evaluate(ctx, db, map(Array("a"), Array(stringValue(" HELLO ")))) should
equal(stringValue("HELLO"))
compiled.evaluate(ctx, db, map(Array("a"), Array(NO_VALUE))) should equal(NO_VALUE)
}

test("id on node") {
val compiled = compile(function("id", parameter("a")))

Expand Down

0 comments on commit bafd783

Please sign in to comment.