Skip to content

Commit

Permalink
[SPARK-47131][SQL][COLLATION] String function support: contains, star…
Browse files Browse the repository at this point in the history
…tswith, endswith

### What changes were proposed in this pull request?
Refactor built-in string functions to support collation for: contains, startsWith, endsWith.

### Why are the changes needed?
Add collation support for built-in string functions in Spark.

### Does this PR introduce _any_ user-facing change?
Yes, users should now be able to use COLLATE within arguments for built-in string functions: CONTAINS, STARTSWITH, ENDSWITH in Spark SQL queries.

### How was this patch tested?
Unit tests for:
- string expressions (StringExpressionsSuite)
- queries using "collate" (CollationSuite)

### Was this patch authored or co-authored using generative AI tooling?
Yes.

Closes apache#45216 from uros-db/string-functions.

Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
uros-db authored and pull[bot] committed Apr 11, 2024
1 parent 3684800 commit 3d57f34
Show file tree
Hide file tree
Showing 11 changed files with 411 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ public Collation(
private static final Collation[] collationTable = new Collation[4];
private static final HashMap<String, Integer> collationNameToIdMap = new HashMap<>();

public static final int DEFAULT_COLLATION_ID = 0;
public static final int LOWERCASE_COLLATION_ID = 1;

static {
// Binary comparison. This is the default collation.
// No custom comparators will be used for this collation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;

Expand All @@ -30,6 +31,7 @@
import com.esotericsoftware.kryo.io.Input;
import com.esotericsoftware.kryo.io.Output;

import org.apache.spark.SparkException;
import org.apache.spark.sql.catalyst.util.CollationFactory;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.UTF8StringBuilder;
Expand Down Expand Up @@ -341,6 +343,21 @@ public boolean contains(final UTF8String substring) {
return false;
}

/**
 * Collation-aware variant of {@link #contains(UTF8String)}.
 *
 * Binary collations delegate to the byte-wise {@code contains}; the lowercase collation
 * compares the lower-cased forms of both strings. Every other collation is rejected.
 *
 * @param substring the string to search for
 * @param collationId id of the collation, resolved through {@code CollationFactory}
 * @return true if this string contains {@code substring} under the given collation
 * @throws SparkException with error class UNSUPPORTED_COLLATION.FOR_FUNCTION when the
 *         collation is neither binary nor the lowercase collation
 */
public boolean contains(final UTF8String substring, int collationId) throws SparkException {
  final boolean binaryCollation = CollationFactory.fetchCollation(collationId).isBinaryCollation;
  if (binaryCollation) {
    return this.contains(substring);
  }
  if (collationId != CollationFactory.LOWERCASE_COLLATION_ID) {
    // TODO: enable ICU collation support for "contains" (SPARK-47248)
    final Map<String, String> messageParams = new HashMap<>();
    messageParams.put("functionName", "contains");
    messageParams.put("collationName",
      CollationFactory.fetchCollation(collationId).collationName);
    throw new SparkException("UNSUPPORTED_COLLATION.FOR_FUNCTION",
      SparkException.constructMessageParams(messageParams), null);
  }
  final UTF8String loweredThis = this.toLowerCase();
  return loweredThis.contains(substring.toLowerCase());
}

/**
* Returns the byte at position `i`.
*/
Expand All @@ -355,14 +372,41 @@ public boolean matchAt(final UTF8String s, int pos) {
return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes);
}

/**
 * Collation-aware counterpart of {@link #matchAt(UTF8String, int)}.
 *
 * Compares the {@code s.numBytes} bytes of this string starting at byte offset {@code pos}
 * with {@code s} under the given collation.
 *
 * @param s the string to match
 * @param pos byte offset into this string at which the match must start
 * @param collationId id of the collation used for the comparison
 * @return true if the slice compares equal to {@code s}; false if the slice would fall
 *         outside this string
 */
private boolean matchAt(final UTF8String s, int pos, int collationId) {
  if (s.numBytes + pos > numBytes || pos < 0) {
    return false;
  }
  // `pos` and `s.numBytes` are BYTE quantities (callers pass e.g. numBytes - suffix.numBytes),
  // whereas substring(start, until) indexes by code point. Using substring here selects the
  // wrong slice as soon as the string contains multi-byte characters, so take a byte-addressed
  // view instead, mirroring the byte-wise matchAt(s, pos).
  final UTF8String candidate = UTF8String.fromAddress(base, offset + pos, s.numBytes);
  return candidate.semanticCompare(s, collationId) == 0;
}

/**
 * Returns whether this string begins with {@code prefix}, by delegating to
 * {@code matchAt(prefix, 0)}.
 *
 * @param prefix the candidate prefix
 * @return true if {@code prefix} matches at offset 0 of this string
 */
public boolean startsWith(final UTF8String prefix) {
return matchAt(prefix, 0);
}

/**
 * Collation-aware variant of {@link #startsWith(UTF8String)}.
 *
 * Binary collations use the byte-wise check, the lowercase collation compares lower-cased
 * copies of both strings, and any other collation goes through the collation-aware
 * {@code matchAt} at offset 0.
 *
 * @param prefix the candidate prefix
 * @param collationId id of the collation, resolved through {@code CollationFactory}
 * @return true if this string starts with {@code prefix} under the given collation
 */
public boolean startsWith(final UTF8String prefix, int collationId) {
  final boolean binaryCollation = CollationFactory.fetchCollation(collationId).isBinaryCollation;
  if (binaryCollation) {
    return this.startsWith(prefix);
  }
  return collationId == CollationFactory.LOWERCASE_COLLATION_ID
    ? this.toLowerCase().startsWith(prefix.toLowerCase())
    : matchAt(prefix, 0, collationId);
}

/**
 * Returns whether this string ends with {@code suffix}.
 *
 * @param suffix the candidate suffix
 * @return true if {@code suffix} matches at the offset where it would have to start
 */
public boolean endsWith(final UTF8String suffix) {
  // Offset at which the suffix must begin for it to end exactly at the end of this string.
  final int suffixStart = numBytes - suffix.numBytes;
  return matchAt(suffix, suffixStart);
}

/**
 * Collation-aware variant of {@link #endsWith(UTF8String)}.
 *
 * Binary collations use the byte-wise check, the lowercase collation compares lower-cased
 * copies of both strings, and any other collation goes through the collation-aware
 * {@code matchAt} at the offset where the suffix would have to start.
 *
 * @param suffix the candidate suffix
 * @param collationId id of the collation, resolved through {@code CollationFactory}
 * @return true if this string ends with {@code suffix} under the given collation
 */
public boolean endsWith(final UTF8String suffix, int collationId) {
  final boolean binaryCollation = CollationFactory.fetchCollation(collationId).isBinaryCollation;
  if (binaryCollation) {
    return this.endsWith(suffix);
  }
  return collationId == CollationFactory.LOWERCASE_COLLATION_ID
    ? this.toLowerCase().endsWith(suffix.toLowerCase())
    : matchAt(suffix, numBytes - suffix.numBytes, collationId);
}

/**
* Returns the upper case of this string
*/
Expand Down
18 changes: 18 additions & 0 deletions common/utils/src/main/resources/error/error-classes.json
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,11 @@
"To convert values from <srcType> to <targetType>, you can use the functions <functionNames> instead."
]
},
"COLLATION_MISMATCH" : {
"message" : [
"Collations <collationNameLeft> and <collationNameRight> are not compatible. Please use the same collation for both strings."
]
},
"CREATE_MAP_KEY_DIFF_TYPES" : {
"message" : [
"The given keys of function <functionName> should all be the same type, but they are <dataType>."
Expand Down Expand Up @@ -3768,6 +3773,19 @@
],
"sqlState" : "0A000"
},
"UNSUPPORTED_COLLATION" : {
"message" : [
"Collation <collationName> is not supported for:"
],
"subClass" : {
"FOR_FUNCTION" : {
"message" : [
"function <functionName>. Please try to use a different collation."
]
}
},
"sqlState" : "0A000"
},
"UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY" : {
"message" : [
"Unsupported data source type for direct query on files: <dataSourceType>"
Expand Down
4 changes: 4 additions & 0 deletions docs/sql-error-conditions-datatype-mismatch-error-class.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ If you have to cast `<srcType>` to `<targetType>`, you can set `<config>` as `<c
cannot cast `<srcType>` to `<targetType>`.
To convert values from `<srcType>` to `<targetType>`, you can use the functions `<functionNames>` instead.

## COLLATION_MISMATCH

Collations `<collationNameLeft>` and `<collationNameRight>` are not compatible. Please use the same collation for both strings.

## CREATE_MAP_KEY_DIFF_TYPES

The given keys of function `<functionName>` should all be the same type, but they are `<dataType>`.
Expand Down
37 changes: 37 additions & 0 deletions docs/sql-error-conditions-unsupported-collation-error-class.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
---
layout: global
title: UNSUPPORTED_COLLATION error class
displayTitle: UNSUPPORTED_COLLATION error class
license: |
Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
---

<!--
DO NOT EDIT THIS FILE.
It was generated automatically by `org.apache.spark.SparkThrowableSuite`.
-->

[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported)

Collation `<collationName>` is not supported for:

This error class has the following derived error classes:

## FOR_FUNCTION

function `<functionName>`. Please try to use a different collation.


8 changes: 8 additions & 0 deletions docs/sql-error-conditions.md
Original file line number Diff line number Diff line change
Expand Up @@ -2462,6 +2462,14 @@ For more details see [UNSUPPORTED_CALL](sql-error-conditions-unsupported-call-er
The char/varchar type can't be used in the table schema.
If you want Spark treat them as string type as same as Spark 3.0 and earlier, please set "spark.sql.legacy.charVarcharAsString" to "true".

### [UNSUPPORTED_COLLATION](sql-error-conditions-unsupported-collation-error-class.html)

[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported)

Collation `<collationName>` is not supported for:

For more details see [UNSUPPORTED_COLLATION](sql-error-conditions-unsupported-collation-error-class.html)

### UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY

[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
/**
* Returns whether the assigned collation is the default Spark collation (UCS_BASIC).
*/
def isDefaultCollation: Boolean = collationId == StringType.DEFAULT_COLLATION_ID
def isDefaultCollation: Boolean = collationId == CollationFactory.DEFAULT_COLLATION_ID

/**
* Binary collation implies that strings are considered equal only if they are
Expand Down Expand Up @@ -69,6 +69,5 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa
*/
@Stable
case object StringType extends StringType(0) {
val DEFAULT_COLLATION_ID = 0
def apply(collationId: Int): StringType = new StringType(collationId)
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import java.util.{HashMap, Locale, Map => JMap}

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.QueryContext
import org.apache.spark.{QueryContext, SparkException}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
Expand All @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke
import org.apache.spark.sql.catalyst.trees.BinaryLike
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER}
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, TypeUtils}
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
Expand Down Expand Up @@ -497,10 +497,32 @@ case class Lower(child: Expression)
abstract class StringPredicate extends BinaryExpression
with Predicate with ImplicitCastInputTypes with NullIntolerant {

final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId

def compare(l: UTF8String, r: UTF8String): Boolean

override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

/**
 * Runs the inherited input type check and, if it passes, additionally requires both
 * operands to carry the same collation.
 *
 * @return the inherited failure if there is one, a `COLLATION_MISMATCH` data type mismatch
 *         when the left and right collation ids differ, and success otherwise
 */
override def checkInputDataTypes(): TypeCheckResult = {
  val inheritedCheck = super.checkInputDataTypes()
  if (inheritedCheck.isFailure) {
    inheritedCheck
  } else {
    // Additional check needed for collation compatibility
    val rightCollationId: Int = right.dataType.asInstanceOf[StringType].collationId
    if (collationId == rightCollationId) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      DataTypeMismatch(
        errorSubClass = "COLLATION_MISMATCH",
        messageParameters = Map(
          "collationNameLeft" -> CollationFactory.fetchCollation(collationId).collationName,
          "collationNameRight" -> CollationFactory.fetchCollation(rightCollationId).collationName
        )
      )
    }
  }
}

protected override def nullSafeEval(input1: Any, input2: Any): Any =
compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String])

Expand Down Expand Up @@ -586,9 +608,38 @@ object ContainsExpressionBuilder extends StringBinaryPredicateExpressionBuilderB
}

case class Contains(left: Expression, right: Expression) extends StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r)
/**
 * Runs the inherited checks, then rejects collations that `contains` cannot handle.
 *
 * Only binary collations and the lowercase collation are accepted; any other collation
 * raises a `SparkException` with error class `UNSUPPORTED_COLLATION.FOR_FUNCTION`.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  val inheritedCheck = super.checkInputDataTypes()
  if (inheritedCheck.isFailure) {
    inheritedCheck
  } else {
    // Additional check needed for collation support
    val isSupported =
      CollationFactory.fetchCollation(collationId).isBinaryCollation ||
        collationId == CollationFactory.LOWERCASE_COLLATION_ID
    if (isSupported) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      throw new SparkException(
        errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION",
        messageParameters = Map(
          "functionName" -> "contains",
          "collationName" -> CollationFactory.fetchCollation(collationId).collationName),
        cause = null
      )
    }
  }
}
/**
 * Evaluates `l contains r`, using the byte-wise check for binary collations and the
 * collation-aware overload otherwise.
 */
override def compare(l: UTF8String, r: UTF8String): Boolean = {
  val binaryCollation = CollationFactory.fetchCollation(collationId).isBinaryCollation
  if (!binaryCollation) l.contains(r, collationId) else l.contains(r)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)")
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.contains($c2)")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.contains($c2, $collationId)")
}
}
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight)
Expand Down Expand Up @@ -623,9 +674,20 @@ object StartsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilde
}

case class StartsWith(left: Expression, right: Expression) extends StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r)
/**
 * Evaluates `l startsWith r`, using the byte-wise check for binary collations and the
 * collation-aware overload otherwise.
 */
override def compare(l: UTF8String, r: UTF8String): Boolean = {
  val binaryCollation = CollationFactory.fetchCollation(collationId).isBinaryCollation
  if (!binaryCollation) l.startsWith(r, collationId) else l.startsWith(r)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)")
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.startsWith($c2)")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.startsWith($c2, $collationId)")
}
}
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight)
Expand Down Expand Up @@ -660,9 +722,20 @@ object EndsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilderB
}

case class EndsWith(left: Expression, right: Expression) extends StringPredicate {
override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r)
/**
 * Evaluates `l endsWith r`, using the byte-wise check for binary collations and the
 * collation-aware overload otherwise.
 */
override def compare(l: UTF8String, r: UTF8String): Boolean = {
  val binaryCollation = CollationFactory.fetchCollation(collationId).isBinaryCollation
  if (!binaryCollation) l.endsWith(r, collationId) else l.endsWith(r)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)")
if (CollationFactory.fetchCollation(collationId).isBinaryCollation) {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.endsWith($c2)")
} else {
defineCodeGen(ctx, ev, (c1, c2) => s"$c1.endsWith($c2, $collationId)")
}
}
override protected def withNewChildrenInternal(
newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ object PhysicalDataType {
case ShortType => PhysicalShortType
case IntegerType => PhysicalIntegerType
case LongType => PhysicalLongType
case VarcharType(_) => PhysicalStringType(StringType.DEFAULT_COLLATION_ID)
case CharType(_) => PhysicalStringType(StringType.DEFAULT_COLLATION_ID)
case VarcharType(_) => PhysicalStringType(CollationFactory.DEFAULT_COLLATION_ID)
case CharType(_) => PhysicalStringType(CollationFactory.DEFAULT_COLLATION_ID)
case s: StringType => PhysicalStringType(s.collationId)
case FloatType => PhysicalFloatType
case DoubleType => PhysicalDoubleType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import scala.annotation.tailrec
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalBinaryType, PhysicalBooleanType, PhysicalByteType, PhysicalCalendarIntervalType, PhysicalDataType, PhysicalDecimalType, PhysicalDoubleType, PhysicalFloatType, PhysicalIntegerType, PhysicalLongType, PhysicalMapType, PhysicalNullType, PhysicalShortType, PhysicalStringType, PhysicalStructType}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.errors.ExecutionErrors
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
Expand Down Expand Up @@ -492,7 +493,7 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType
}

private[columnar] object STRING
extends NativeColumnType(PhysicalStringType(StringType.DEFAULT_COLLATION_ID), 8)
extends NativeColumnType(PhysicalStringType(CollationFactory.DEFAULT_COLLATION_ID), 8)
with DirectCopyColumnType[UTF8String] {

override def actualSize(row: InternalRow, ordinal: Int): Int = {
Expand Down
Loading

0 comments on commit 3d57f34

Please sign in to comment.