Skip to content

Commit

Permalink
[SPARK-35918][AVRO] Unify schema mismatch handling for read/write and…
Browse files Browse the repository at this point in the history
… enhance error messages

### What changes were proposed in this pull request?
This unifies the struct schema mismatch-handling logic of `AvroSerializer` and `AvroDeserializer`, pushing it into `AvroUtils`, which is used by both. The newly unified exception-handling logic is updated to provide more contextual information in error messages. Previously, when a schema mismatch was found, we would report only the first missing field, but there may be many others as well, which made it less clear what exactly was going wrong. Now, we will report on all missing fields.

### Why are the changes needed?
While working on apache#31490, we discussed that there is room for improvement in how schema mismatch errors are reported ([comment1](apache#31490 (comment)), [comment2](apache#31490 (comment))). Additionally, the logic between `AvroSerializer` and `AvroDeserializer` was quite similar for handling these issues, but didn't share common code, causing duplication and making it harder to see exactly what differences existed between the two.

### Does this PR introduce _any_ user-facing change?
Some error messages when matching Catalyst struct schemas against Avro record schemas now include more information.

### How was this patch tested?
New unit tests added.

Closes apache#33308 from xkrogen/xkrogen-SPARK-35918-avroserde-unify-better-error-messages.

Authored-by: Erik Krogen <xkrogen@apache.org>
Signed-off-by: Gengliang Wang <gengliang@apache.org>
  • Loading branch information
xkrogen authored and gengliangwang committed Aug 2, 2021
1 parent 951efb8 commit be06e41
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import java.math.BigDecimal
import java.nio.ByteBuffer

import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.Conversions.DecimalConversion
Expand All @@ -30,7 +29,7 @@ import org.apache.avro.Schema.Type._
import org.apache.avro.generic._
import org.apache.avro.util.Utf8

import org.apache.spark.sql.avro.AvroUtils.{toFieldDescription, toFieldStr}
import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField}
import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters}
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData}
Expand Down Expand Up @@ -352,39 +351,26 @@ private[sql] class AvroDeserializer(
avroPath: Seq[String],
catalystPath: Seq[String],
applyFilters: Int => Boolean): (CatalystDataUpdater, GenericRecord) => Boolean = {
val validFieldIndexes = ArrayBuffer.empty[Int]
val fieldWriters = ArrayBuffer.empty[(CatalystDataUpdater, Any) => Unit]

val avroSchemaHelper =
new AvroUtils.AvroSchemaHelper(avroType, avroPath, positionalFieldMatch)
val length = catalystType.length
var i = 0
while (i < length) {
val catalystField = catalystType.fields(i)
avroSchemaHelper.getAvroField(catalystField.name, i) match {
case Some(avroField) =>
validFieldIndexes += avroField.pos()

val baseWriter = newWriter(avroField.schema(), catalystField.dataType,
avroPath :+ avroField.name, catalystPath :+ catalystField.name)
val ordinal = i
val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
if (value == null) {
fieldUpdater.setNullAt(ordinal)
} else {
baseWriter(fieldUpdater, ordinal, value)
}

val avroSchemaHelper = new AvroUtils.AvroSchemaHelper(
avroType, catalystType, avroPath, catalystPath, positionalFieldMatch)

avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true)
// no need to validateNoExtraAvroFields since extra Avro fields are ignored

val (validFieldIndexes, fieldWriters) = avroSchemaHelper.matchedFields.map {
case AvroMatchedField(catalystField, ordinal, avroField) =>
val baseWriter = newWriter(avroField.schema(), catalystField.dataType,
avroPath :+ avroField.name, catalystPath :+ catalystField.name)
val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => {
if (value == null) {
fieldUpdater.setNullAt(ordinal)
} else {
baseWriter(fieldUpdater, ordinal, value)
}
fieldWriters += fieldWriter
case None if !catalystField.nullable =>
val fieldDescription =
toFieldDescription(catalystPath :+ catalystField.name, i, positionalFieldMatch)
throw new IncompatibleSchemaException(
s"Cannot find non-nullable $fieldDescription in Avro schema.")
case _ => // nothing to do
}
i += 1
}
}
(avroField.pos(), fieldWriter)
}.toArray.unzip

(fieldUpdater, record) => {
var i = 0
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import org.apache.avro.generic.GenericData.Record
import org.apache.avro.util.Utf8

import org.apache.spark.internal.Logging
import org.apache.spark.sql.avro.AvroUtils.{toFieldDescription, toFieldStr}
import org.apache.spark.sql.avro.AvroUtils.{toFieldStr, AvroMatchedField}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{SpecializedGetters, SpecificInternalRow}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
Expand Down Expand Up @@ -252,34 +252,19 @@ private[sql] class AvroSerializer(
catalystPath: Seq[String],
avroPath: Seq[String]): InternalRow => Record = {

val avroPathStr = toFieldStr(avroPath)
if (avroStruct.getType != RECORD) {
throw new IncompatibleSchemaException(s"$avroPathStr was not a RECORD")
}
val avroFields = avroStruct.getFields.asScala
if (avroFields.size != catalystStruct.length) {
throw new IncompatibleSchemaException(
s"Avro $avroPathStr schema length (${avroFields.size}) doesn't match " +
s"SQL ${toFieldStr(catalystPath)} schema length (${catalystStruct.length})")
}
val avroSchemaHelper =
new AvroUtils.AvroSchemaHelper(avroStruct, avroPath, positionalFieldMatch)

val (avroIndices: Array[Int], fieldConverters: Array[Converter]) =
catalystStruct.zipWithIndex.map { case (catalystField, catalystPos) =>
val avroField = avroSchemaHelper.getAvroField(catalystField.name, catalystPos) match {
case Some(f) => f
case None =>
val fieldDescription = toFieldDescription(
catalystPath :+ catalystField.name, catalystPos, positionalFieldMatch)
throw new IncompatibleSchemaException(
s"Cannot find $fieldDescription in Avro schema at $avroPathStr")
}
val avroSchemaHelper = new AvroUtils.AvroSchemaHelper(
avroStruct, catalystStruct, avroPath, catalystPath, positionalFieldMatch)

avroSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false)
avroSchemaHelper.validateNoExtraAvroFields()

val (avroIndices, fieldConverters) = avroSchemaHelper.matchedFields.map {
case AvroMatchedField(catalystField, _, avroField) =>
val converter = newConverter(catalystField.dataType,
resolveNullableType(avroField.schema(), catalystField.nullable),
catalystPath :+ catalystField.name, avroPath :+ avroField.name)
(avroField.pos(), converter)
}.toArray.unzip
}.toArray.unzip

val numFields = catalystStruct.length
row: InternalRow =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,18 +205,32 @@ private[sql] object AvroUtils extends Logging {
}
}

/**
 * Wrapper for a pair of matched fields, one Catalyst and one corresponding Avro field.
 *
 * @param catalystField The Catalyst side of the match.
 * @param catalystPosition The position recorded for `catalystField`; `AvroSchemaHelper` fills
 *                         this with the field's index within the Catalyst struct being matched.
 * @param avroField The Avro field that was matched to `catalystField`.
 */
case class AvroMatchedField(
catalystField: StructField,
catalystPosition: Int,
avroField: Schema.Field)

/**
* Wraps an Avro Schema object so that field lookups are faster.
* Helper class to perform field lookup/matching on Avro schemas.
*
* This will match `avroSchema` against `catalystSchema`, attempting to find a matching field in
* the Avro schema for each field in the Catalyst schema and vice-versa, respecting settings for
* case sensitivity. The match results can be accessed using the getter methods.
*
* @param avroSchema The schema in which to search for fields. Must be of type RECORD.
* @param catalystSchema The Catalyst schema to use for matching.
* @param avroPath The seq of parent field names leading to `avroSchema`.
* @param catalystPath The seq of parent field names leading to `catalystSchema`.
* @param positionalFieldMatch If true, perform field matching in a positional fashion
* (structural comparison between schemas, ignoring names);
* otherwise, perform field matching using field names.
*/
class AvroSchemaHelper(
avroSchema: Schema,
catalystSchema: StructType,
avroPath: Seq[String],
catalystPath: Seq[String],
positionalFieldMatch: Boolean) {
if (avroSchema.getType != Schema.Type.RECORD) {
throw new IncompatibleSchemaException(
Expand All @@ -228,6 +242,50 @@ private[sql] object AvroUtils extends Logging {
.groupBy(_.name.toLowerCase(Locale.ROOT))
.mapValues(_.toSeq) // toSeq needed for scala 2.13

/**
 * Every Catalyst field for which `getAvroField` finds an Avro counterpart, bundled with that
 * counterpart and with the Catalyst field's index. Entries follow the Catalyst schema order;
 * Catalyst fields without a counterpart are left out.
 */
val matchedFields: Seq[AvroMatchedField] =
  catalystSchema.zipWithIndex.flatMap { case (catalystField, catalystPos) =>
    val avroMatch = getAvroField(catalystField.name, catalystPos)
    avroMatch.map(avroField => AvroMatchedField(catalystField, catalystPos, avroField))
  }

/**
 * Check that each Catalyst field has an Avro counterpart, throwing
 * [[IncompatibleSchemaException]] for the first Catalyst field that has none and counts as
 * "extra".
 *
 * @param ignoreNullable If true, a nullable Catalyst field without an Avro counterpart is
 *                       tolerated and only non-nullable ones are treated as extra; if false,
 *                       any Catalyst field without an Avro counterpart is treated as extra.
 */
def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit = {
  for ((catalystField, position) <- catalystSchema.zipWithIndex) {
    // Perform the lookup first so that it always runs, whatever the nullability outcome.
    val unmatched = getAvroField(catalystField.name, position).isEmpty
    val countsAsExtra = !(ignoreNullable && catalystField.nullable)
    if (unmatched && countsAsExtra) {
      val message = if (positionalFieldMatch) {
        "Cannot find field at position " +
          s"$position of ${toFieldStr(avroPath)} from Avro schema (using positional matching)"
      } else {
        s"Cannot find ${toFieldStr(catalystPath :+ catalystField.name)} in Avro schema"
      }
      throw new IncompatibleSchemaException(message)
    }
  }
}

/**
 * Check that each Avro field was paired with some Catalyst field, throwing
 * [[IncompatibleSchemaException]] for the first Avro field encountered that is absent from
 * `matchedFields`.
 */
def validateNoExtraAvroFields(): Unit = {
  val matchedAvroFields = matchedFields.map(_.avroField)
  val unmatchedAvroFields = avroFieldArray.toSet -- matchedAvroFields
  for (extraField <- unmatchedAvroFields) {
    // The wording differs by matching mode: positional errors name the slot, by-name errors
    // name the full field path.
    val message = if (positionalFieldMatch) {
      s"Found field '${extraField.name()}' at position " +
        s"${extraField.pos()} of ${toFieldStr(avroPath)} from Avro schema but there is no " +
        s"match in the SQL schema at ${toFieldStr(catalystPath)} (using positional matching)"
    } else {
      s"Found ${toFieldStr(avroPath :+ extraField.name())} in Avro schema but there is no " +
        "match in the SQL schema"
    }
    throw new IncompatibleSchemaException(message)
  }
}

/**
* Extract a single field from the contained avro schema which has the desired field name,
* performing the matching with proper case sensitivity according to SQLConf.resolver.
Expand Down Expand Up @@ -261,21 +319,6 @@ private[sql] object AvroUtils extends Logging {
}
}

/**
 * Build a human-readable description of a field from its hierarchical names (rendered via
 * [[toFieldStr]]) and its position. Both pieces of information always appear; when
 * `positionalFieldMatch` is true the position leads and the name is parenthesized, otherwise
 * the name leads and the position is parenthesized.
 */
private[avro] def toFieldDescription(
    names: Seq[String],
    position: Int,
    positionalFieldMatch: Boolean): String = {
  val fieldStr = toFieldStr(names)
  if (positionalFieldMatch) {
    s"field at position $position ($fieldStr)"
  } else {
    s"$fieldStr (at position $position)"
  }
}

/**
* Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable
* string representing the field, like "field 'foo.bar'". If `names` is empty, the string
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.sql.avro

import org.apache.avro.SchemaBuilder

import org.apache.spark.sql.avro.AvroUtils.AvroMatchedField
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{SharedSparkSession, SQLTestUtils}
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
Expand All @@ -28,7 +29,7 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession {
val avroSchema = SchemaBuilder.builder().intType()

val msg = intercept[IncompatibleSchemaException] {
new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), false)
new AvroUtils.AvroSchemaHelper(avroSchema, StructType(Seq()), Seq(""), Seq(""), false)
}.getMessage
assert(msg.contains("Attempting to treat int as a RECORD"))
}
Expand All @@ -42,7 +43,8 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession {
)

val avroSchema = SchemaConverters.toAvroType(catalystSchema)
val helper = new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), false)
val helper =
new AvroUtils.AvroSchemaHelper(avroSchema, StructType(Seq()), Seq(""), Seq(""), false)
withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") {
assert(helper.getFieldByName("A").get.name() == "A")
assert(helper.getFieldByName("a").get.name() == "a")
Expand All @@ -69,8 +71,10 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession {
val catalystSchema = new StructType().add("foo", IntegerType).add("bar", StringType)
val avroSchema = SchemaConverters.toAvroType(catalystSchema)

val posHelper = new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), true)
val nameHelper = new AvroUtils.AvroSchemaHelper(avroSchema, Seq(""), false)
val posHelper =
new AvroUtils.AvroSchemaHelper(avroSchema, catalystSchema, Seq(""), Seq(""), true)
val nameHelper =
new AvroUtils.AvroSchemaHelper(avroSchema, catalystSchema, Seq(""), Seq(""), false)

for (name <- Seq("foo", "bar"); fieldPos <- Seq(0, 1)) {
assert(posHelper.getAvroField(name, fieldPos) === Some(avroSchema.getFields.get(fieldPos)))
Expand All @@ -82,4 +86,51 @@ class AvroSchemaHelperSuite extends SQLTestUtils with SharedSparkSession {
assert(posHelper.getAvroField("nonexist", 1).isDefined)
assert(nameHelper.getAvroField("nonexist", 1).isEmpty)
}

test("properly match fields between Avro and Catalyst schemas") {
  // Catalyst has two fields of its own followed by the two shared ones, so the shared fields
  // sit at Catalyst positions 2 and 3.
  val catalystFieldNames = Seq("catalyst1", "catalyst2", "shared1", "shared2")
  val catalystSchema =
    StructType(catalystFieldNames.map(fieldName => StructField(fieldName, IntegerType)))

  // Avro has the two shared fields followed by two fields of its own.
  val avroSchema = SchemaBuilder.record("toplevel").fields()
    .requiredInt("shared1")
    .requiredInt("shared2")
    .requiredInt("avro1")
    .requiredInt("avro2")
    .endRecord()

  val helper = new AvroUtils.AvroSchemaHelper(
    avroSchema, catalystSchema, Seq(""), Seq(""), false)

  val expectedMatches = Seq(
    AvroMatchedField(catalystSchema("shared1"), 2, avroSchema.getField("shared1")),
    AvroMatchedField(catalystSchema("shared2"), 3, avroSchema.getField("shared2")))
  assert(helper.matchedFields === expectedMatches)

  // 'avro1' and 'avro2' have no Catalyst counterpart.
  assertThrows[IncompatibleSchemaException] {
    helper.validateNoExtraAvroFields()
  }

  // The unmatched Catalyst fields are rejected only once nullable fields stop being ignored.
  helper.validateNoExtraCatalystFields(ignoreNullable = true)
  assertThrows[IncompatibleSchemaException] {
    helper.validateNoExtraCatalystFields(ignoreNullable = false)
  }
}

test("respect nullability settings for validateNoExtraCatalystFields") {
  // The Avro record only has 'bar', so the Catalyst field 'foo' never finds a by-name match.
  val avroSchema = SchemaBuilder.record("record").fields().requiredInt("bar").endRecord()

  // An unmatched non-nullable Catalyst field is rejected whatever `ignoreNullable` says.
  val catalystNonnull = new StructType().add("foo", IntegerType, nullable = false)
  val helperNonnull =
    new AvroUtils.AvroSchemaHelper(avroSchema, catalystNonnull, Seq(""), Seq(""), false)
  assertThrows[IncompatibleSchemaException] {
    helperNonnull.validateNoExtraCatalystFields(ignoreNullable = true)
  }
  assertThrows[IncompatibleSchemaException] {
    helperNonnull.validateNoExtraCatalystFields(ignoreNullable = false)
  }

  // The same field declared without `nullable = false` passes when `ignoreNullable` is true
  // and is rejected when it is false.
  val catalystNullable = new StructType().add("foo", IntegerType)
  val helperNullable =
    new AvroUtils.AvroSchemaHelper(avroSchema, catalystNullable, Seq(""), Seq(""), false)
  helperNullable.validateNoExtraCatalystFields(ignoreNullable = true)
  assertThrows[IncompatibleSchemaException] {
    helperNullable.validateNoExtraCatalystFields(ignoreNullable = false)
  }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -86,18 +86,18 @@ class AvroSerdeSuite extends SparkFunSuite {
// deserialize should have no issues when 'bar' is nullable but fail when it is nonnull
Deserializer.create(CATALYST_STRUCT, avro, BY_NAME)
assertFailedConversionMessage(avro, Deserializer, BY_NAME,
"Cannot find non-nullable field 'foo.bar' (at position 0) in Avro schema.",
"Cannot find field 'foo.bar' in Avro schema",
nonnullCatalyst)
assertFailedConversionMessage(avro, Deserializer, BY_POSITION,
"Cannot find non-nullable field at position 1 (field 'foo.baz') in Avro schema.",
"Cannot find field at position 1 of field 'foo' from Avro schema (using positional matching)",
extraNonnullCatalyst)

// serialize fails whether or not 'bar' is nullable
val expectMsg = "Cannot find field 'foo.bar' (at position 0) in Avro schema at field 'foo'"
assertFailedConversionMessage(avro, Serializer, BY_NAME, expectMsg)
assertFailedConversionMessage(avro, Serializer, BY_NAME, expectMsg, nonnullCatalyst)
val byNameMsg = "Cannot find field 'foo.bar' in Avro schema"
assertFailedConversionMessage(avro, Serializer, BY_NAME, byNameMsg)
assertFailedConversionMessage(avro, Serializer, BY_NAME, byNameMsg, nonnullCatalyst)
assertFailedConversionMessage(avro, Serializer, BY_POSITION,
"Avro field 'foo' schema length (1) doesn't match SQL field 'foo' schema length (2)",
"Cannot find field at position 1 of field 'foo' from Avro schema (using positional matching)",
extraNonnullCatalyst)
}

Expand All @@ -122,18 +122,28 @@ class AvroSerdeSuite extends SparkFunSuite {

test("Fail to convert for serialization with field count mismatch") {
// Note that this is allowed for deserialization, but not serialization
withFieldMatchType { fieldMatch =>
val tooManyFields =
createAvroSchemaWithTopLevelFields(_.optionalInt("foo").optionalLong("bar"))
assertFailedConversionMessage(tooManyFields, Serializer, fieldMatch,
"Avro top-level record schema length (2) " +
"doesn't match SQL top-level record schema length (1)")

val tooFewFields = createAvroSchemaWithTopLevelFields(f => f)
assertFailedConversionMessage(tooFewFields, Serializer, fieldMatch,
"Avro top-level record schema length (0) " +
"doesn't match SQL top-level record schema length (1)")
}
val tooManyFields =
createAvroSchemaWithTopLevelFields(_.optionalInt("foo").optionalLong("bar"))
assertFailedConversionMessage(tooManyFields, Serializer, BY_NAME,
"Found field 'bar' in Avro schema but there is no match in the SQL schema")
assertFailedConversionMessage(tooManyFields, Serializer, BY_POSITION,
"Found field 'bar' at position 1 of top-level record from Avro schema but there is no " +
"match in the SQL schema at top-level record (using positional matching)")

val tooManyFieldsNested =
createNestedAvroSchemaWithFields("foo", _.optionalInt("bar").optionalInt("baz"))
assertFailedConversionMessage(tooManyFieldsNested, Serializer, BY_NAME,
"Found field 'foo.baz' in Avro schema but there is no match in the SQL schema")
assertFailedConversionMessage(tooManyFieldsNested, Serializer, BY_POSITION,
s"Found field 'baz' at position 1 of field 'foo' from Avro schema but there is no match " +
s"in the SQL schema at field 'foo' (using positional matching)")

val tooFewFields = createAvroSchemaWithTopLevelFields(f => f)
assertFailedConversionMessage(tooFewFields, Serializer, BY_NAME,
"Cannot find field 'foo' in Avro schema")
assertFailedConversionMessage(tooFewFields, Serializer, BY_POSITION,
"Cannot find field at position 0 of top-level record from Avro schema " +
"(using positional matching)")
}

/**
Expand Down
Loading

0 comments on commit be06e41

Please sign in to comment.