[SC-56391] Add char/varchar length check to Delta CONSTRAINT
This PR adds a char/varchar input string length check to Delta CONSTRAINT, so that it applies to all write paths. It also introduces a reserved constraint name that end users can't use: `__CHAR_VARCHAR_STRING_LENGTH_CHECK__`.

This PR also removes char-type padding from Delta INSERT, to be consistent with the other write paths. We can add the padding back in the future, once we have the infrastructure to do it.

Tested with new tests.

Before this PR, no char/varchar length check was performed when writing data to Delta tables via UPDATE/MERGE. The written data could exceed the declared char/varchar length, violating char/varchar semantics. This PR fixes the issue by adding the length check as a Delta CONSTRAINT, so that it applies to all write paths. It also removes the char padding logic from Delta INSERT, to be consistent with the other write paths, which only perform the length check.
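As a concrete illustration, a minimal sketch of the behavior change (assuming a Spark session `spark` with Delta Lake configured; table and column names are hypothetical):

```scala
// Sketch of the behavior change. Before this PR only INSERT enforced
// char/varchar lengths (and padded CHAR values); after it, the length
// check runs as a table CONSTRAINT on every write path.
spark.sql("CREATE TABLE t (c VARCHAR(4), d CHAR(4)) USING delta")

// UPDATE now fails the length check instead of silently writing 5 chars:
spark.sql("UPDATE t SET c = 'abcde'")

// INSERT still checks lengths, but no longer pads CHAR values:
spark.sql("INSERT INTO t VALUES ('ab', 'ab')")  // stores 'ab', not 'ab  '
```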

GitOrigin-RevId: 91e96ee9643bf1eba4d6ed159858d433e8357c95
cloud-fan authored and Yaohua628 committed Oct 28, 2021
1 parent 0b6e439 commit 685820b
Showing 9 changed files with 105 additions and 220 deletions.
201 changes: 0 additions & 201 deletions core/src/main/scala/org/apache/spark/sql/delta/CharVarcharUtils.scala

This file was deleted.

20 changes: 4 additions & 16 deletions core/src/main/scala/org/apache/spark/sql/delta/DeltaAnalysis.scala
@@ -64,8 +64,7 @@ class DeltaAnalysis(session: SparkSession)
         needsSchemaAdjustment(d.name(), a.query, r.schema) =>
       val projection = resolveQueryColumnsByOrdinal(a.query, r.output, d.name())
       if (projection != a.query) {
-        val cleanedTable = r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))
-        a.copy(query = projection, table = cleanedTable)
+        a.copy(query = projection)
       } else {
         a
       }
@@ -82,8 +81,7 @@ class DeltaAnalysis(session: SparkSession)
       val newDeleteExpr = o.deleteExpr.transformUp {
         case a: AttributeReference => aliases.getOrElse(a, a)
       }
-      val cleanedTable = r.copy(output = r.output.map(CharVarcharUtils.cleanAttrMetadata))
-      o.copy(deleteExpr = newDeleteExpr, query = projection, table = cleanedTable)
+      o.copy(deleteExpr = newDeleteExpr, query = projection)
     } else {
       o
     }
@@ -273,8 +271,7 @@ class DeltaAnalysis(session: SparkSession)
       case _ =>
         getCastFunction(attr, targetAttr.dataType)
     }
-    val strLenChecked = CharVarcharUtils.stringLengthCheck(expr, targetAttr)
-    Alias(strLenChecked, targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
+    Alias(expr, targetAttr.name)(explicitMetadata = Option(targetAttr.metadata))
   }

   /**
@@ -293,17 +290,8 @@ class DeltaAnalysis(session: SparkSession)
     // Now we should try our best to match everything that already exists, and leave the rest
     // for schema evolution to WriteIntoDelta
     val existingSchemaOutput = output.take(schema.length)
-    val rawSchema = getRawSchema(schema)
     existingSchemaOutput.map(_.name) != schema.map(_.name) ||
-      !SchemaUtils.isReadCompatible(rawSchema.asNullable, existingSchemaOutput.toStructType)
-  }
-
-  private def getRawSchema(schema: StructType): StructType = {
-    StructType(schema.map { field =>
-      CharVarcharUtils.getRawType(field.metadata).map {
-        rawType => field.copy(dataType = rawType)
-      }.getOrElse(field)
-    })
+      !SchemaUtils.isReadCompatible(schema.asNullable, existingSchemaOutput.toStructType)
   }

   // Get cast operation for the level of strictness in the schema a user asked for
@@ -193,6 +193,10 @@ object DeltaErrors
       s"constraint first.\nOld constraint:\n${oldExpr}")
   }

+  def invalidConstraintName(name: String): AnalysisException = {
+    new AnalysisException(s"Cannot use '$name' as the name of a CHECK constraint.")
+  }
+
   def checkConstraintNotBoolean(name: String, expr: String): AnalysisException = {
     new AnalysisException(s"CHECK constraint '$name' ($expr) should be a boolean expression.'")
   }
@@ -36,6 +36,7 @@ import org.apache.spark.sql.delta.stats.FileSizeHistogram

 import org.apache.spark.sql.{AnalysisException, SparkSession}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.CharVarcharUtils
 import org.apache.spark.util.{Clock, Utils}

 /** Record metrics about a successful commit. */
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.sql.delta._
 import org.apache.spark.sql.delta.actions.Protocol
 import org.apache.spark.sql.delta.catalog.DeltaTableV2
-import org.apache.spark.sql.delta.constraints.Constraints
+import org.apache.spark.sql.delta.constraints.{CharVarcharConstraint, Constraints}
 import org.apache.spark.sql.delta.schema.{SchemaMergingUtils, SchemaUtils}
 import org.apache.spark.sql.delta.schema.SchemaUtils.transformColumnsStructs
 import org.apache.spark.sql.delta.sources.DeltaSQLConf
@@ -533,6 +533,9 @@ case class AlterTableAddConstraintDeltaCommand(

   override def run(sparkSession: SparkSession): Seq[Row] = {
     val deltaLog = table.deltaLog
+    if (name == CharVarcharConstraint.INVARIANT_NAME) {
+      throw DeltaErrors.invalidConstraintName(name)
+    }
     recordDeltaOperation(deltaLog, "delta.ddl.alter.addConstraint") {
       val txn = startTransaction()
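Since the length check is registered under the reserved name, ALTER TABLE now rejects user constraints that try to claim it. A sketch of what this looks like (hypothetical table name; the message comes from `DeltaErrors.invalidConstraintName` above):

```scala
// Attempting to add a user CHECK constraint under the reserved name fails fast.
spark.sql("""
  ALTER TABLE t ADD CONSTRAINT __CHAR_VARCHAR_STRING_LENGTH_CHECK__
  CHECK (length(c) <= 4)
""")
// => AnalysisException: Cannot use '__CHAR_VARCHAR_STRING_LENGTH_CHECK__'
//    as the name of a CHECK constraint.
```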
@@ -0,0 +1,70 @@
/*
* Copyright (2021) The Delta Lake Project Authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.delta.constraints

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.types._

// Delta implements the char/varchar length check with CONSTRAINTS, and needs to generate a
// predicate expression that is different from the OSS Spark version.
object CharVarcharConstraint {
  final val INVARIANT_NAME = "__CHAR_VARCHAR_STRING_LENGTH_CHECK__"

  def stringConstraints(schema: StructType): Seq[Constraint] = {
    schema.flatMap { f =>
      val targetType = CharVarcharUtils.getRawType(f.metadata).getOrElse(f.dataType)
      val col = UnresolvedAttribute(Seq(f.name))
      checkStringLength(col, targetType).map { lengthCheckExpr =>
        Constraints.Check(INVARIANT_NAME, lengthCheckExpr)
      }
    }
  }

  private def checkStringLength(expr: Expression, dt: DataType): Option[Expression] = dt match {
    case VarcharType(length) =>
      Some(Or(IsNull(expr), LessThanOrEqual(Length(expr), Literal(length))))

    case CharType(length) =>
      checkStringLength(expr, VarcharType(length))

    case StructType(fields) =>
      fields.zipWithIndex.flatMap { case (f, i) =>
        checkStringLength(GetStructField(expr, i, Some(f.name)), f.dataType)
      }.reduceOption(And(_, _))

    case ArrayType(et, containsNull) =>
      checkStringLengthInArray(expr, et, containsNull)

    case MapType(kt, vt, valueContainsNull) =>
      (checkStringLengthInArray(MapKeys(expr), kt, false) ++
        checkStringLengthInArray(MapValues(expr), vt, valueContainsNull))
        .reduceOption(And(_, _))

    case _ => None
  }

  private def checkStringLengthInArray(
      arr: Expression, et: DataType, containsNull: Boolean): Option[Expression] = {
    val cleanedType = CharVarcharUtils.replaceCharVarcharWithString(et)
    val param = NamedLambdaVariable("x", cleanedType, containsNull)
    checkStringLength(param, et).map { checkExpr =>
      Or(IsNull(arr), ArrayForAll(arr, LambdaFunction(checkExpr, Seq(param))))
    }
  }
}
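To make the recursion in this new `CharVarcharConstraint.scala` file concrete, here is a sketch of calling `stringConstraints` on a one-column schema. It assumes Spark's `CharVarcharUtils.replaceCharVarcharWithStringInSchema` helper, which encodes varchar(n) as a string field with type metadata, the same form Delta stores:

```scala
import org.apache.spark.sql.catalyst.util.CharVarcharUtils
import org.apache.spark.sql.delta.constraints.CharVarcharConstraint
import org.apache.spark.sql.types._

// VARCHAR(4) is stored as StringType plus "__CHAR_VARCHAR_TYPE_STRING__"
// metadata; this helper produces exactly that encoding.
val schema = CharVarcharUtils.replaceCharVarcharWithStringInSchema(
  StructType(Seq(StructField("c", VarcharType(4)))))

// Yields one Check constraint named __CHAR_VARCHAR_STRING_LENGTH_CHECK__
// whose predicate is equivalent to: c IS NULL OR length(c) <= 4
val checks = CharVarcharConstraint.stringConstraints(schema)
```

For nested types the same predicate is pushed through struct fields, array elements (via `ArrayForAll` with a lambda), and map keys/values, with the per-field checks AND-ed together.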
@@ -66,8 +66,13 @@ object Constraints {
   def getAll(metadata: Metadata, spark: SparkSession): Seq[Constraint] = {
     val checkConstraints = getCheckConstraints(metadata, spark)
     val constraintsFromSchema = Invariants.getFromSchema(metadata.schema, spark)
+    val charVarcharLengthChecks = if (spark.sessionState.conf.charVarcharAsString) {
+      Nil
+    } else {
+      CharVarcharConstraint.stringConstraints(metadata.schema)
+    }

-    (checkConstraints ++ constraintsFromSchema).toSeq
+    (checkConstraints ++ constraintsFromSchema ++ charVarcharLengthChecks).toSeq
   }

   /** Get the expression text for a constraint with the given name, if present. */
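Note the escape hatch in `getAll`: when Spark is configured to treat char/varchar as plain strings, no length checks are generated at all. A sketch (the conf is Spark's existing `spark.sql.legacy.charVarcharAsString` flag, not something introduced by this PR):

```scala
// With the legacy flag on, char/varchar degrade to plain strings, so
// getAll() returns no __CHAR_VARCHAR_STRING_LENGTH_CHECK__ constraints
// and over-long writes are no longer rejected.
spark.conf.set("spark.sql.legacy.charVarcharAsString", "true")
spark.sql("UPDATE t SET c = 'abcde'")  // succeeds under the legacy flag
```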
@@ -19,7 +19,7 @@ package org.apache.spark.sql.delta.schema
 // scalastyle:off import.ordering.noEmptyLine
 import scala.collection.JavaConverters._

-import org.apache.spark.sql.delta.constraints.Constraints
+import org.apache.spark.sql.delta.constraints.{CharVarcharConstraint, Constraints}

 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute

@@ -41,6 +41,10 @@ object InvariantViolationException {
   def apply(
       constraint: Constraints.Check,
       values: Map[String, Any]): InvariantViolationException = {
+    if (constraint.name == CharVarcharConstraint.INVARIANT_NAME) {
+      return new InvariantViolationException("Exceeds char/varchar type length limitation")
+    }
+
     val valueLines = values.map {
       case (column, value) =>
         s" - $column : $value"
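End to end, a violating write now surfaces this dedicated message instead of the generic per-column value dump. A sketch using MERGE, one of the write paths that previously skipped the check (hypothetical table, as before):

```scala
// Any write path now triggers the constraint; e.g. a MERGE that inserts a
// 5-character value into the VARCHAR(4) column fails with the dedicated
// char/varchar message rather than the generic CHECK-constraint output.
spark.sql("""
  MERGE INTO t USING (SELECT 'abcde' AS c) s
  ON t.c = s.c
  WHEN NOT MATCHED THEN INSERT (c) VALUES (s.c)
""")
// => InvariantViolationException: Exceeds char/varchar type length limitation
```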