From 3d57f34c152b22cb2742b87af20205ad9b6e5776 Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Mon, 4 Mar 2024 22:43:25 +0800 Subject: [PATCH] [SPARK-47131][SQL][COLLATION] String function support: contains, startswith, endswith ### What changes were proposed in this pull request? Refactor built-in string functions to support collation for: contains, startsWith, endsWith. ### Why are the changes needed? Add collation support for built-in string functions in Spark. ### Does this PR introduce _any_ user-facing change? Yes, users should now be able to use COLLATE within arguments for built-in string functions: CONTAINS, STARTSWITH, ENDSWITH in Spark SQL queries. ### How was this patch tested? Unit tests for: - string expressions (StringExpressionsSuite) - queries using "collate" (CollationSuite) ### Was this patch authored or co-authored using generative AI tooling? Yes. Closes #45216 from uros-db/string-functions. Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Wenchen Fan --- .../sql/catalyst/util/CollationFactory.java | 3 + .../apache/spark/unsafe/types/UTF8String.java | 44 ++++ .../main/resources/error/error-classes.json | 18 ++ ...onditions-datatype-mismatch-error-class.md | 4 + ...tions-unsupported-collation-error-class.md | 37 +++ docs/sql-error-conditions.md | 8 + .../apache/spark/sql/types/StringType.scala | 3 +- .../expressions/stringExpressions.scala | 89 +++++++- .../sql/catalyst/types/PhysicalDataType.scala | 4 +- .../sql/execution/columnar/ColumnType.scala | 3 +- .../org/apache/spark/sql/CollationSuite.scala | 216 +++++++++++++++++- 11 files changed, 411 insertions(+), 18 deletions(-) create mode 100644 docs/sql-error-conditions-unsupported-collation-error-class.md diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java index 
83cac849e848b..c0c011926be9c 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java @@ -105,6 +105,9 @@ public Collation( private static final Collation[] collationTable = new Collation[4]; private static final HashMap<String, Integer> collationNameToIdMap = new HashMap<>(); + public static final int DEFAULT_COLLATION_ID = 0; + public static final int LOWERCASE_COLLATION_ID = 1; + static { // Binary comparison. This is the default collation. // No custom comparators will be used for this collation. diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index bb794446472fe..4b9c18010162a 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -22,6 +22,7 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.Arrays; +import java.util.HashMap; import java.util.Map; import java.util.regex.Pattern; @@ -30,6 +31,7 @@ import com.esotericsoftware.kryo.io.Input; import com.esotericsoftware.kryo.io.Output; +import org.apache.spark.SparkException; import org.apache.spark.sql.catalyst.util.CollationFactory; import org.apache.spark.unsafe.Platform; import org.apache.spark.unsafe.UTF8StringBuilder; @@ -341,6 +343,21 @@ public boolean contains(final UTF8String substring) { return false; } + public boolean contains(final UTF8String substring, int collationId) throws SparkException { + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + return this.contains(substring); + } + if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) { + return this.toLowerCase().contains(substring.toLowerCase()); + } + // TODO: enable ICU collation support for "contains" (SPARK-47248) + Map<String, String> params = new
HashMap<>(); + params.put("functionName", "contains"); + params.put("collationName", CollationFactory.fetchCollation(collationId).collationName); + throw new SparkException("UNSUPPORTED_COLLATION.FOR_FUNCTION", + SparkException.constructMessageParams(params), null); + } + /** * Returns the byte at position `i`. */ @@ -355,14 +372,42 @@ public boolean matchAt(final UTF8String s, int pos) { return ByteArrayMethods.arrayEquals(base, offset + pos, s.base, s.offset, s.numBytes); } + private boolean matchAt(final UTF8String s, int pos, int collationId) { + if (s.numBytes + pos > numBytes || pos < 0) { + return false; + } + // `pos` and `s.numBytes` are byte counts, so slice by bytes; substring() indexes code points. + return fromAddress(base, offset + pos, s.numBytes).semanticCompare(s, collationId) == 0; + } + public boolean startsWith(final UTF8String prefix) { return matchAt(prefix, 0); } + public boolean startsWith(final UTF8String prefix, int collationId) { + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + return this.startsWith(prefix); + } + if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) { + return this.toLowerCase().startsWith(prefix.toLowerCase()); + } + return matchAt(prefix, 0, collationId); + } + public boolean endsWith(final UTF8String suffix) { return matchAt(suffix, numBytes - suffix.numBytes); } + public boolean endsWith(final UTF8String suffix, int collationId) { + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + return this.endsWith(suffix); + } + if (collationId == CollationFactory.LOWERCASE_COLLATION_ID) { + return this.toLowerCase().endsWith(suffix.toLowerCase()); + } + return matchAt(suffix, numBytes - suffix.numBytes, collationId); + } + /** * Returns the upper case of this string */ diff --git a/common/utils/src/main/resources/error/error-classes.json b/common/utils/src/main/resources/error/error-classes.json index 35e4feaf90b0a..493635d1f8d3f 100644 --- a/common/utils/src/main/resources/error/error-classes.json +++ b/common/utils/src/main/resources/error/error-classes.json @@ -696,6
+696,11 @@ "To convert values from to , you can use the functions instead." ] }, + "COLLATION_MISMATCH" : { + "message" : [ + "Collations <collationNameLeft> and <collationNameRight> are not compatible. Please use the same collation for both strings." + ] + }, "CREATE_MAP_KEY_DIFF_TYPES" : { "message" : [ "The given keys of function should all be the same type, but they are ." @@ -3768,6 +3773,19 @@ ], "sqlState" : "0A000" }, + "UNSUPPORTED_COLLATION" : { + "message" : [ + "Collation <collationName> is not supported for:" + ], + "subClass" : { + "FOR_FUNCTION" : { + "message" : [ + "function <functionName>. Please try to use a different collation." + ] + } + }, + "sqlState" : "0A000" + }, "UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY" : { "message" : [ "Unsupported data source type for direct query on files: " diff --git a/docs/sql-error-conditions-datatype-mismatch-error-class.md b/docs/sql-error-conditions-datatype-mismatch-error-class.md index 1d18836ac9e77..cd7feb9262f3a 100644 --- a/docs/sql-error-conditions-datatype-mismatch-error-class.md +++ b/docs/sql-error-conditions-datatype-mismatch-error-class.md @@ -76,6 +76,10 @@ If you have to cast `` to ``, you can set `` as `` to ``. To convert values from `` to ``, you can use the functions `` instead. +## COLLATION_MISMATCH + +Collations `<collationNameLeft>` and `<collationNameRight>` are not compatible. Please use the same collation for both strings. + ## CREATE_MAP_KEY_DIFF_TYPES The given keys of function `` should all be the same type, but they are ``. diff --git a/docs/sql-error-conditions-unsupported-collation-error-class.md b/docs/sql-error-conditions-unsupported-collation-error-class.md new file mode 100644 index 0000000000000..ae410a30317a1 --- /dev/null +++ b/docs/sql-error-conditions-unsupported-collation-error-class.md @@ -0,0 +1,37 @@ +--- +layout: global +title: UNSUPPORTED_COLLATION error class +displayTitle: UNSUPPORTED_COLLATION error class +license: | + Licensed to the Apache Software Foundation (ASF) under one or more + contributor license agreements.
See the NOTICE file distributed with + this work for additional information regarding copyright ownership. + The ASF licenses this file to You under the Apache License, Version 2.0 + (the "License"); you may not use this file except in compliance with + the License. You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +--- + + +[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported) + +Collation `<collationName>` is not supported for: + +This error class has the following derived error classes: + +## FOR_FUNCTION + +function `<functionName>`. Please try to use a different collation. + + diff --git a/docs/sql-error-conditions.md b/docs/sql-error-conditions.md index fbfea9bf57653..510f56f413c6c 100644 --- a/docs/sql-error-conditions.md +++ b/docs/sql-error-conditions.md @@ -2462,6 +2462,14 @@ For more details see [UNSUPPORTED_CALL](sql-error-conditions-unsupported-call-er The char/varchar type can't be used in the table schema. If you want Spark treat them as string type as same as Spark 3.0 and earlier, please set "spark.sql.legacy.charVarcharAsString" to "true".
+### [UNSUPPORTED_COLLATION](sql-error-conditions-unsupported-collation-error-class.html) + +[SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported) + +Collation `<collationName>` is not supported for: + +For more details see [UNSUPPORTED_COLLATION](sql-error-conditions-unsupported-collation-error-class.html) + ### UNSUPPORTED_DATASOURCE_FOR_DIRECT_QUERY [SQLSTATE: 0A000](sql-error-conditions-sqlstates.html#class-0A-feature-not-supported) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala index 3026139161cf7..313f525742ae9 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StringType.scala @@ -31,7 +31,7 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa /** * Returns whether assigned collation is the default spark collation (UCS_BASIC). */ - def isDefaultCollation: Boolean = collationId == StringType.DEFAULT_COLLATION_ID + def isDefaultCollation: Boolean = collationId == CollationFactory.DEFAULT_COLLATION_ID /** * Binary collation implies that strings are considered equal only if they are @@ -69,6 +69,5 @@ class StringType private(val collationId: Int) extends AtomicType with Serializa */ @Stable case object StringType extends StringType(0) { - val DEFAULT_COLLATION_ID = 0 def apply(collationId: Int): StringType = new StringType(collationId) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 82b5f628578e7..e6114ca277cad 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -24,7 +24,7 @@ import java.util.{HashMap, Locale, Map
=> JMap} import scala.collection.mutable.ArrayBuffer -import org.apache.spark.QueryContext +import org.apache.spark.{QueryContext, SparkException} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, FunctionRegistry, TypeCheckResult} import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.expressions.objects.StaticInvoke import org.apache.spark.sql.catalyst.trees.BinaryLike import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UPPER_OR_LOWER} -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData, TypeUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, TypeUtils} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -497,10 +497,32 @@ case class Lower(child: Expression) abstract class StringPredicate extends BinaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant { + final lazy val collationId: Int = left.dataType.asInstanceOf[StringType].collationId + def compare(l: UTF8String, r: UTF8String): Boolean override def inputTypes: Seq[DataType] = Seq(StringType, StringType) + override def checkInputDataTypes(): TypeCheckResult = { + val checkResult = super.checkInputDataTypes() + if (checkResult.isFailure) { + return checkResult + } + // Additional check needed for collation compatibility + val rightCollationId: Int = right.dataType.asInstanceOf[StringType].collationId + if (collationId != rightCollationId) { + DataTypeMismatch( + errorSubClass = "COLLATION_MISMATCH", + messageParameters = Map( + "collationNameLeft" -> CollationFactory.fetchCollation(collationId).collationName, + "collationNameRight" -> 
CollationFactory.fetchCollation(rightCollationId).collationName + ) + ) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + protected override def nullSafeEval(input1: Any, input2: Any): Any = compare(input1.asInstanceOf[UTF8String], input2.asInstanceOf[UTF8String]) @@ -586,9 +608,38 @@ object ContainsExpressionBuilder extends StringBinaryPredicateExpressionBuilderB } case class Contains(left: Expression, right: Expression) extends StringPredicate { - override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) + override def checkInputDataTypes(): TypeCheckResult = { + val checkResult = super.checkInputDataTypes() + if (checkResult.isFailure) { + return checkResult + } + // Additional check needed for collation support + if (!CollationFactory.fetchCollation(collationId).isBinaryCollation + && collationId != CollationFactory.LOWERCASE_COLLATION_ID) { + throw new SparkException( + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + messageParameters = Map( + "functionName" -> "contains", + "collationName" -> CollationFactory.fetchCollation(collationId).collationName), + cause = null + ) + } else { + TypeCheckResult.TypeCheckSuccess + } + } + override def compare(l: UTF8String, r: UTF8String): Boolean = { + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + l.contains(r) + } else { + l.contains(r, collationId) + } + } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + defineCodeGen(ctx, ev, (c1, c2) => s"$c1.contains($c2)") + } else { + defineCodeGen(ctx, ev, (c1, c2) => s"$c1.contains($c2, $collationId)") + } } override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): Contains = copy(left = newLeft, right = newRight) @@ -623,9 +674,20 @@ object StartsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilde } case class 
StartsWith(left: Expression, right: Expression) extends StringPredicate { - override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = { + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + l.startsWith(r) + } else { + l.startsWith(r, collationId) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + defineCodeGen(ctx, ev, (c1, c2) => s"$c1.startsWith($c2)") + } else { + defineCodeGen(ctx, ev, (c1, c2) => s"$c1.startsWith($c2, $collationId)") + } } override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): StartsWith = copy(left = newLeft, right = newRight) @@ -660,9 +722,20 @@ object EndsWithExpressionBuilder extends StringBinaryPredicateExpressionBuilderB } case class EndsWith(left: Expression, right: Expression) extends StringPredicate { - override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) + override def compare(l: UTF8String, r: UTF8String): Boolean = { + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + l.endsWith(r) + } else { + l.endsWith(r, collationId) + } + } + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") + if (CollationFactory.fetchCollation(collationId).isBinaryCollation) { + defineCodeGen(ctx, ev, (c1, c2) => s"$c1.endsWith($c2)") + } else { + defineCodeGen(ctx, ev, (c1, c2) => s"$c1.endsWith($c2, $collationId)") + } } override protected def withNewChildrenInternal( newLeft: Expression, newRight: Expression): EndsWith = copy(left = newLeft, right = newRight) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala index cc8008a9e11c4..0b0c36b27e71c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/PhysicalDataType.scala @@ -40,8 +40,8 @@ object PhysicalDataType { case ShortType => PhysicalShortType case IntegerType => PhysicalIntegerType case LongType => PhysicalLongType - case VarcharType(_) => PhysicalStringType(StringType.DEFAULT_COLLATION_ID) - case CharType(_) => PhysicalStringType(StringType.DEFAULT_COLLATION_ID) + case VarcharType(_) => PhysicalStringType(CollationFactory.DEFAULT_COLLATION_ID) + case CharType(_) => PhysicalStringType(CollationFactory.DEFAULT_COLLATION_ID) case s: StringType => PhysicalStringType(s.collationId) case FloatType => PhysicalFloatType case DoubleType => PhysicalDoubleType diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala index 06a9fe2b0b5b8..ccabaf8d8b120 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnType.scala @@ -25,6 +25,7 @@ import scala.annotation.tailrec import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.types.{PhysicalArrayType, PhysicalBinaryType, PhysicalBooleanType, PhysicalByteType, PhysicalCalendarIntervalType, PhysicalDataType, PhysicalDecimalType, PhysicalDoubleType, PhysicalFloatType, PhysicalIntegerType, PhysicalLongType, PhysicalMapType, PhysicalNullType, PhysicalShortType, PhysicalStringType, PhysicalStructType} +import org.apache.spark.sql.catalyst.util.CollationFactory import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ import 
org.apache.spark.unsafe.Platform @@ -492,7 +493,7 @@ private[columnar] trait DirectCopyColumnType[JvmType] extends ColumnType[JvmType } private[columnar] object STRING - extends NativeColumnType(PhysicalStringType(StringType.DEFAULT_COLLATION_ID), 8) + extends NativeColumnType(PhysicalStringType(CollationFactory.DEFAULT_COLLATION_ID), 8) with DirectCopyColumnType[UTF8String] { override def actualSize(row: InternalRow, ordinal: Int): Int = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 1c8a2b2495172..f68085f803522 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -185,6 +185,211 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { } } + test("checkCollation throws exception for incompatible collationIds") { + val left: String = "abc" // collate with 'UNICODE_CI' + val leftCollationName: String = "UNICODE_CI"; + var right: String = null // collate with 'UNICODE' + val rightCollationName: String = "UNICODE"; + // contains + right = left.substring(1, 2); + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT contains(collate('$left', '$leftCollationName')," + + s"collate('$right', '$rightCollationName'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"$leftCollationName", + "collationNameRight" -> s"$rightCollationName", + "sqlExpr" -> "\"contains(collate(abc), collate(b))\"" + ), + context = ExpectedContext(fragment = + s"contains(collate('abc', 'UNICODE_CI'),collate('b', 'UNICODE'))", + start = 7, stop = 68) + ) + // startsWith + right = left.substring(0, 1); + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT startsWith(collate('$left', '$leftCollationName')," + + 
s"collate('$right', '$rightCollationName'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"$leftCollationName", + "collationNameRight" -> s"$rightCollationName", + "sqlExpr" -> "\"startswith(collate(abc), collate(a))\"" + ), + context = ExpectedContext(fragment = + s"startsWith(collate('abc', 'UNICODE_CI'),collate('a', 'UNICODE'))", + start = 7, stop = 70) + ) + // endsWith + right = left.substring(2, 3); + checkError( + exception = intercept[ExtendedAnalysisException] { + spark.sql(s"SELECT endsWith(collate('$left', '$leftCollationName')," + + s"collate('$right', '$rightCollationName'))") + }, + errorClass = "DATATYPE_MISMATCH.COLLATION_MISMATCH", + sqlState = "42K09", + parameters = Map( + "collationNameLeft" -> s"$leftCollationName", + "collationNameRight" -> s"$rightCollationName", + "sqlExpr" -> "\"endswith(collate(abc), collate(c))\"" + ), + context = ExpectedContext(fragment = + s"endsWith(collate('abc', 'UNICODE_CI'),collate('c', 'UNICODE'))", + start = 7, stop = 68) + ) + } + + case class CollationTestCase[R](left: String, right: String, collation: String, expectedResult: R) + + test("Support contains string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("", "", "UCS_BASIC", true), + CollationTestCase("c", "", "UCS_BASIC", true), + CollationTestCase("", "c", "UCS_BASIC", false), + CollationTestCase("abcde", "c", "UCS_BASIC", true), + CollationTestCase("abcde", "C", "UCS_BASIC", false), + CollationTestCase("abcde", "bcd", "UCS_BASIC", true), + CollationTestCase("abcde", "BCD", "UCS_BASIC", false), + CollationTestCase("abcde", "fgh", "UCS_BASIC", false), + CollationTestCase("abcde", "FGH", "UCS_BASIC", false), + CollationTestCase("", "", "UNICODE", true), + CollationTestCase("c", "", "UNICODE", true), + CollationTestCase("", "c", "UNICODE", false), + CollationTestCase("abcde", "c", "UNICODE", true), + CollationTestCase("abcde", 
"C", "UNICODE", false), + CollationTestCase("abcde", "bcd", "UNICODE", true), + CollationTestCase("abcde", "BCD", "UNICODE", false), + CollationTestCase("abcde", "fgh", "UNICODE", false), + CollationTestCase("abcde", "FGH", "UNICODE", false), + CollationTestCase("", "", "UCS_BASIC_LCASE", true), + CollationTestCase("c", "", "UCS_BASIC_LCASE", true), + CollationTestCase("", "c", "UCS_BASIC_LCASE", false), + CollationTestCase("abcde", "c", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "C", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "bcd", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "BCD", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "fgh", "UCS_BASIC_LCASE", false), + CollationTestCase("abcde", "FGH", "UCS_BASIC_LCASE", false) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT contains(collate('${testCase.left}', '${testCase.collation}')," + + s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + }) + // Unsupported collations + checkError( + exception = intercept[SparkException] { + sql(s"SELECT contains(collate('abcde', 'UNICODE_CI')," + + s"collate('BCD', 'UNICODE_CI'))") + }, + errorClass = "UNSUPPORTED_COLLATION.FOR_FUNCTION", + sqlState = "0A000", + parameters = Map( + "functionName" -> "contains", + "collationName" -> "UNICODE_CI" + ) + ) + } + + test("Support startsWith string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("", "", "UCS_BASIC", true), + CollationTestCase("c", "", "UCS_BASIC", true), + CollationTestCase("", "c", "UCS_BASIC", false), + CollationTestCase("abcde", "a", "UCS_BASIC", true), + CollationTestCase("abcde", "A", "UCS_BASIC", false), + CollationTestCase("abcde", "abc", "UCS_BASIC", true), + CollationTestCase("abcde", "ABC", "UCS_BASIC", false), + CollationTestCase("abcde", "bcd", "UCS_BASIC", false), + CollationTestCase("abcde", "BCD", "UCS_BASIC", false), + CollationTestCase("", "", "UNICODE", true), 
+ CollationTestCase("c", "", "UNICODE", true), + CollationTestCase("", "c", "UNICODE", false), + CollationTestCase("abcde", "a", "UNICODE", true), + CollationTestCase("abcde", "A", "UNICODE", false), + CollationTestCase("abcde", "abc", "UNICODE", true), + CollationTestCase("abcde", "ABC", "UNICODE", false), + CollationTestCase("abcde", "bcd", "UNICODE", false), + CollationTestCase("abcde", "BCD", "UNICODE", false), + CollationTestCase("", "", "UCS_BASIC_LCASE", true), + CollationTestCase("c", "", "UCS_BASIC_LCASE", true), + CollationTestCase("", "c", "UCS_BASIC_LCASE", false), + CollationTestCase("abcde", "a", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "A", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "abc", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "ABC", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "bcd", "UCS_BASIC_LCASE", false), + CollationTestCase("abcde", "BCD", "UCS_BASIC_LCASE", false), + CollationTestCase("", "", "UNICODE_CI", true), + CollationTestCase("c", "", "UNICODE_CI", true), + CollationTestCase("", "c", "UNICODE_CI", false), + CollationTestCase("abcde", "a", "UNICODE_CI", true), + CollationTestCase("abcde", "A", "UNICODE_CI", true), + CollationTestCase("abcde", "abc", "UNICODE_CI", true), + CollationTestCase("abcde", "ABC", "UNICODE_CI", true), + CollationTestCase("abcde", "bcd", "UNICODE_CI", false), + CollationTestCase("abcde", "BCD", "UNICODE_CI", false) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT startswith(collate('${testCase.left}', '${testCase.collation}')," + + s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + }) + } + + test("Support endsWith string expression with Collation") { + // Supported collations + val checks = Seq( + CollationTestCase("", "", "UCS_BASIC", true), + CollationTestCase("c", "", "UCS_BASIC", true), + CollationTestCase("", "c", "UCS_BASIC", false), + CollationTestCase("abcde", "e", "UCS_BASIC", true), + 
CollationTestCase("abcde", "E", "UCS_BASIC", false), + CollationTestCase("abcde", "cde", "UCS_BASIC", true), + CollationTestCase("abcde", "CDE", "UCS_BASIC", false), + CollationTestCase("abcde", "bcd", "UCS_BASIC", false), + CollationTestCase("abcde", "BCD", "UCS_BASIC", false), + CollationTestCase("", "", "UNICODE", true), + CollationTestCase("c", "", "UNICODE", true), + CollationTestCase("", "c", "UNICODE", false), + CollationTestCase("abcde", "e", "UNICODE", true), + CollationTestCase("abcde", "E", "UNICODE", false), + CollationTestCase("abcde", "cde", "UNICODE", true), + CollationTestCase("abcde", "CDE", "UNICODE", false), + CollationTestCase("abcde", "bcd", "UNICODE", false), + CollationTestCase("abcde", "BCD", "UNICODE", false), + CollationTestCase("", "", "UCS_BASIC_LCASE", true), + CollationTestCase("c", "", "UCS_BASIC_LCASE", true), + CollationTestCase("", "c", "UCS_BASIC_LCASE", false), + CollationTestCase("abcde", "e", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "E", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "cde", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "CDE", "UCS_BASIC_LCASE", true), + CollationTestCase("abcde", "bcd", "UCS_BASIC_LCASE", false), + CollationTestCase("abcde", "BCD", "UCS_BASIC_LCASE", false), + CollationTestCase("", "", "UNICODE_CI", true), + CollationTestCase("c", "", "UNICODE_CI", true), + CollationTestCase("", "c", "UNICODE_CI", false), + CollationTestCase("abcde", "e", "UNICODE_CI", true), + CollationTestCase("abcde", "E", "UNICODE_CI", true), + CollationTestCase("abcde", "cde", "UNICODE_CI", true), + CollationTestCase("abcde", "CDE", "UNICODE_CI", true), + CollationTestCase("abcde", "bcd", "UNICODE_CI", false), + CollationTestCase("abcde", "BCD", "UNICODE_CI", false) + ) + checks.foreach(testCase => { + checkAnswer(sql(s"SELECT endswith(collate('${testCase.left}', '${testCase.collation}')," + + s"collate('${testCase.right}', '${testCase.collation}'))"), Row(testCase.expectedResult)) + }) + } + 
test("aggregates count respects collation") { Seq( ("ucs_basic", Seq("AAA", "aaa"), Seq(Row(1, "AAA"), Row(1, "aaa"))), @@ -234,6 +439,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { }.nonEmpty) } } + } test("create table with collation") { @@ -295,13 +501,13 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { sql(s"INSERT INTO $tableName VALUES ('AAA')") checkAnswer(sql(s"SELECT DISTINCT COLLATION(c1) FROM $tableName"), - Seq(Row(defaultCollation))) + Seq(Row(defaultCollation))) sql( - s""" - |ALTER TABLE $tableName - |ADD COLUMN c2 STRING COLLATE '$collationName' - |""".stripMargin) + s""" + |ALTER TABLE $tableName + |ADD COLUMN c2 STRING COLLATE '$collationName' + |""".stripMargin) sql(s"INSERT INTO $tableName VALUES ('aaa', 'aaa')") sql(s"INSERT INTO $tableName VALUES ('AAA', 'AAA')")