-
Notifications
You must be signed in to change notification settings - Fork 2.3k
/
StringFunctions.scala
218 lines (164 loc) · 8.44 KB
/
StringFunctions.scala
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
/*
* Copyright (c) 2002-2018 "Neo4j,"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.cypher.internal.runtime.interpreted.commands.expressions
import org.neo4j.cypher.internal.runtime.interpreted.ExecutionContext
import org.neo4j.cypher.internal.runtime.interpreted.pipes.QueryState
import org.neo4j.cypher.operations.CypherFunctions
import org.neo4j.values._
import org.neo4j.values.storable.Values.NO_VALUE
import org.neo4j.values.storable._
import org.neo4j.values.virtual.VirtualValues
import org.opencypher.v9_0.util.symbols._
import org.opencypher.v9_0.util.{CypherTypeException, ParameterWrongTypeException}
/**
 * Shared base class for the single-argument string expressions defined in this file.
 *
 * It is a [[NullInNullOutExpression]] over `arg`, declares `CTString` as its inner expected
 * type, and exposes `arg` as both its only argument and its only source of symbol-table
 * dependencies.
 *
 * @param arg the expression this function operates on
 */
abstract class StringFunction(arg: Expression) extends NullInNullOutExpression(arg) {

  /** The symbol-table dependencies are exactly those of the wrapped argument. */
  override def symbolTableDependencies = arg.symbolTableDependencies

  /** The wrapped argument is the only child expression. */
  override def arguments = Seq(arg)

  /** Type this function expects from its argument. */
  def innerExpectedType = CTString
}
object StringFunction {

  /**
   * Always throws a [[CypherTypeException]] stating that a string was expected but `a` was found.
   *
   * NOTE(review): the first `%s` is filled with this object's own `toString`, not with the name
   * of the function that received the bad value. That looks unintended; confirm before changing
   * the message.
   *
   * @param a the offending value; its `toString` is embedded in the message
   * @return never returns normally
   */
  def notAString(a: Any): Nothing = {
    val template =
      "Expected a string value for %s, but got: %s; consider converting it to a string with toString()."
    throw new CypherTypeException(template.format(this.toString, a.toString))
  }
}
/**
 * Function object that turns an [[AnyValue]] into a JVM `String`.
 *
 *  - `NO_VALUE` becomes `null`,
 *  - a `TextValue` becomes its `stringValue()`,
 *  - anything else is rejected through `StringFunction.notAString`.
 */
case object asString extends (AnyValue => String) {
  override def apply(a: AnyValue): String =
    a match {
      case NO_VALUE        => null
      case text: TextValue => text.stringValue()
      case _               => StringFunction.notAString(a)
    }
}
/**
 * Converts the value of `argument` to a string value.
 *
 * Integral, floating-point and boolean values are rendered through the `toString` of their
 * primitive value; temporal, duration and point values through their own `toString`; text
 * values are returned unchanged.
 *
 * @param argument the expression whose value is converted
 */
case class ToStringFunction(argument: Expression) extends StringFunction(argument) {

  /**
   * Converts the already-evaluated argument.
   *
   * The match is on `value`, like the other string functions in this file, rather than on a
   * fresh `argument(m, state)` call, so the argument expression is not evaluated a second time.
   *
   * @param value the evaluated argument handed to this function
   * @param m     the current execution context (not used here)
   * @param state the current query state (not used here)
   * @return a text value for every supported input type
   * @throws ParameterWrongTypeException if `value` is of none of the supported types
   */
  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = value match {
    case v: IntegralValue       => Values.stringValue(v.longValue().toString)
    case v: FloatingPointValue  => Values.stringValue(v.doubleValue().toString)
    case v: TextValue           => v
    case v: BooleanValue        => Values.stringValue(v.booleanValue().toString)
    case v: TemporalValue[_, _] => Values.stringValue(v.toString)
    case v: DurationValue       => Values.stringValue(v.toString)
    case v: PointValue          => Values.stringValue(v.toString)
    case v =>
      throw new ParameterWrongTypeException("Expected a String, Number, Boolean, Temporal or Duration, got: " + v.toString)
  }

  /** Rewrites the argument first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression): Expression = f(ToStringFunction(argument.rewrite(f)))
}
/**
 * Returns the `toLower` form of a text value; any other value is rejected through
 * `StringFunction.notAString`.
 *
 * @param argument the expression producing the text
 */
case class ToLowerFunction(argument: Expression) extends StringFunction(argument) {

  /** Rewrites the argument first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression): Expression = {
    val rewrittenArgument = argument.rewrite(f)
    f(ToLowerFunction(rewrittenArgument))
  }

  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue =
    value match {
      case text: TextValue => text.toLower
      case _               => StringFunction.notAString(value)
    }
}
/**
 * Returns the `toUpper` form of a text value; any other value is rejected through
 * `StringFunction.notAString`.
 *
 * @param argument the expression producing the text
 */
case class ToUpperFunction(argument: Expression) extends StringFunction(argument) {

  /** Rewrites the argument first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression): Expression = {
    val rewrittenArgument = argument.rewrite(f)
    f(ToUpperFunction(rewrittenArgument))
  }

  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue =
    value match {
      case text: TextValue => text.toUpper
      case _               => StringFunction.notAString(value)
    }
}
/**
 * String function whose computation is delegated to `CypherFunctions.ltrim`.
 *
 * @param argument the expression producing the value to trim
 */
case class LTrimFunction(argument: Expression) extends StringFunction(argument) {

  /** Rewrites the argument first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression): Expression = {
    val rewrittenArgument = argument.rewrite(f)
    f(LTrimFunction(rewrittenArgument))
  }

  /** Hands the evaluated argument straight to `CypherFunctions.ltrim`. */
  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = {
    CypherFunctions.ltrim(value)
  }
}
/**
 * String function whose computation is delegated to `CypherFunctions.rtrim`.
 *
 * @param argument the expression producing the value to trim
 */
case class RTrimFunction(argument: Expression) extends StringFunction(argument) {

  /** Rewrites the argument first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression): Expression = {
    val rewrittenArgument = argument.rewrite(f)
    f(RTrimFunction(rewrittenArgument))
  }

  /** Hands the evaluated argument straight to `CypherFunctions.rtrim`. */
  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = {
    CypherFunctions.rtrim(value)
  }
}
/**
 * String function whose computation is delegated to `CypherFunctions.trim`.
 *
 * @param argument the expression producing the value to trim
 */
case class TrimFunction(argument: Expression) extends StringFunction(argument) {

  /** Rewrites the argument first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression): Expression = {
    val rewrittenArgument = argument.rewrite(f)
    f(TrimFunction(rewrittenArgument))
  }

  /** Hands the evaluated argument straight to `CypherFunctions.trim`. */
  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = {
    CypherFunctions.trim(value)
  }
}
/**
 * Takes a substring of a text value, starting at `start` and optionally limited to `length`.
 *
 * @param orig   the expression producing the text
 * @param start  the expression producing the start index (converted with `asInt`)
 * @param length optional expression producing the substring length (converted with `asInt`)
 */
case class SubstringFunction(orig: Expression, start: Expression, length: Option[Expression])
  extends NullInNullOutExpression(orig) with NumericHelper {

  /**
   * Evaluates `start` first and, only when present, `length`; then calls the matching
   * `substring` overload on the text. Non-text values are rejected through
   * `StringFunction.notAString`.
   */
  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = value match {
    case text: TextValue =>
      val from = asInt(start(m, state)).value()
      length match {
        case Some(lengthExpr) => text.substring(from, asInt(lengthExpr(m, state)).value())
        case None             => text.substring(from)
      }
    case _ => StringFunction.notAString(value)
  }

  /** The text, the start index and, when given, the length. */
  override def arguments = Seq(orig, start) ++ length

  /** Rewrites every child expression first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression) = {
    val rewrittenLength = length.map(_.rewrite(f))
    f(SubstringFunction(orig.rewrite(f), start.rewrite(f), rewrittenLength))
  }

  /** Union of the dependencies of `orig`, `start` and, when given, `length`. */
  override def symbolTableDependencies = {
    val required = orig.symbolTableDependencies ++ start.symbolTableDependencies
    length.fold(required)(lengthExpr => required ++ lengthExpr.symbolTableDependencies)
  }
}
/**
 * Replaces occurrences of `search` in a text value with `replaceWith`.
 *
 * @param orig        the expression producing the text
 * @param search      the expression producing the string to look for
 * @param replaceWith the expression producing the replacement string
 */
case class ReplaceFunction(orig: Expression, search: Expression, replaceWith: Expression)
  extends NullInNullOutExpression(orig) {

  /**
   * Evaluates both `search` and `replaceWith` through `asString`. If either comes back as
   * `null` the result is `NO_VALUE`; otherwise the text's `replace` is called. Non-text values
   * are rejected through `StringFunction.notAString`.
   */
  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = value match {
    case text: TextValue =>
      val needle = asString(search(m, state))
      val replacement = asString(replaceWith(m, state))
      if (needle != null && replacement != null) text.replace(needle, replacement)
      else NO_VALUE
    case _ => StringFunction.notAString(value)
  }

  /** The text, the search string and the replacement, in that order. */
  override def arguments = Seq(orig, search, replaceWith)

  /** Rewrites every child expression first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression) = {
    val rebuilt = ReplaceFunction(orig.rewrite(f), search.rewrite(f), replaceWith.rewrite(f))
    f(rebuilt)
  }

  /** Union of the dependencies of all three child expressions. */
  override def symbolTableDependencies = {
    val textAndSearch = orig.symbolTableDependencies ++ search.symbolTableDependencies
    textAndSearch ++ replaceWith.symbolTableDependencies
  }
}
/**
 * Splits a text value around a separator.
 *
 * @param orig      the expression producing the text
 * @param separator the expression producing the separator string
 */
case class SplitFunction(orig: Expression, separator: Expression)
  extends NullInNullOutExpression(orig) {

  /**
   * An empty text yields a list holding just the empty string, without evaluating the
   * separator. Otherwise the separator is evaluated through `asString`: a `null` separator
   * gives `NO_VALUE`, anything else is passed to the text's `split`. Non-text values are
   * rejected through `StringFunction.notAString`.
   */
  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = value match {
    case text: TextValue =>
      if (text.length() == 0) {
        VirtualValues.list(Values.EMPTY_STRING)
      } else {
        val sep = asString(separator(m, state))
        if (sep != null) text.split(sep) else NO_VALUE
      }
    case _ => StringFunction.notAString(value)
  }

  /** The text followed by the separator. */
  override def arguments = Seq(orig, separator)

  /** Rewrites both child expressions first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression) = {
    val rebuilt = SplitFunction(orig.rewrite(f), separator.rewrite(f))
    f(rebuilt)
  }

  /** Union of the dependencies of the text and the separator. */
  override def symbolTableDependencies = {
    val textDependencies = orig.symbolTableDependencies
    textDependencies ++ separator.symbolTableDependencies
  }
}
/**
 * Two-argument string function whose computation is delegated to `CypherFunctions.left`.
 *
 * @param orig   the expression producing the text
 * @param length the expression producing the length passed to `CypherFunctions.left`
 */
case class LeftFunction(orig: Expression, length: Expression)
  extends NullInNullOutExpression(orig) with NumericHelper {

  /** Evaluates `length` and hands it, with the evaluated text, to `CypherFunctions.left`. */
  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = {
    val lengthValue = length(m, state)
    CypherFunctions.left(value, lengthValue)
  }

  /** The text followed by the length. */
  override def arguments = Seq(orig, length)

  /** Rewrites both child expressions first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression) = {
    val rebuilt = LeftFunction(orig.rewrite(f), length.rewrite(f))
    f(rebuilt)
  }

  /** Union of the dependencies of the text and the length. */
  override def symbolTableDependencies = {
    val textDependencies = orig.symbolTableDependencies
    textDependencies ++ length.symbolTableDependencies
  }
}
/**
 * Takes the trailing part of a text value.
 *
 * @param orig   the expression producing the text
 * @param length the expression producing the number of trailing characters (converted with `asInt`)
 */
case class RightFunction(orig: Expression, length: Expression)
  extends NullInNullOutExpression(orig) with NumericHelper {

  /**
   * Computes the start index as text length minus the requested length and returns the
   * substring from there. Non-text values are rejected through `StringFunction.notAString`.
   *
   * @throws IndexOutOfBoundsException if the requested length is negative
   */
  override def compute(value: AnyValue, m: ExecutionContext, state: QueryState): AnyValue = value match {
    case text: TextValue =>
      val requested = asInt(length(m, state)).value()
      if (requested < 0) {
        throw new IndexOutOfBoundsException("negative length")
      }
      // Be lenient when the requested length exceeds the text: clamp the start index to 0.
      val from = Math.max(0, text.length - requested)
      text.substring(from)
    case _ => StringFunction.notAString(value)
  }

  /** The text followed by the length. */
  override def arguments = Seq(orig, length)

  /** Rewrites both child expressions first, then applies `f` to the rebuilt function. */
  override def rewrite(f: (Expression) => Expression) = {
    val rebuilt = RightFunction(orig.rewrite(f), length.rewrite(f))
    f(rebuilt)
  }

  /** Union of the dependencies of the text and the length. */
  override def symbolTableDependencies = {
    val textDependencies = orig.symbolTableDependencies
    textDependencies ++ length.symbolTableDependencies
  }
}