From a8a291f16cdfa86aff2ee0ad1047ea7f83b8f551 Mon Sep 17 00:00:00 2001 From: morazow Date: Thu, 29 Oct 2020 11:22:12 +0100 Subject: [PATCH] Added decimal precision check. --- .../scala/com/exasol/avro/AvroConverter.scala | 35 +++++++++++-------- .../exasol/avro/AvroLogicalTypesTest.scala | 20 +++++++++++ 2 files changed, 40 insertions(+), 15 deletions(-) diff --git a/src/main/scala/com/exasol/avro/AvroConverter.scala b/src/main/scala/com/exasol/avro/AvroConverter.scala index 66fc9c8..04ba88d 100644 --- a/src/main/scala/com/exasol/avro/AvroConverter.scala +++ b/src/main/scala/com/exasol/avro/AvroConverter.scala @@ -23,6 +23,7 @@ import org.apache.avro.util.Utf8 */ final class AvroConverter { + private[this] val EXASOL_DECIMAL_PRECISION = 36 private[this] lazy val decimalConverter = new DecimalConversion() private[this] lazy val timestampMillisConverter = new TimestampMillisConversion() private[this] lazy val timestampMicrosConverter = new TimestampMicrosConversion() @@ -77,13 +78,11 @@ final class AvroConverter { } } - private[this] def getIntValue(value: Any, field: Schema): Any = { - val logicalType = field.getLogicalType() - logicalType match { + private[this] def getIntValue(value: Any, field: Schema): Any = + field.getLogicalType() match { case _: LogicalTypes.Date => dateFromSinceEpoch(value.asInstanceOf[Int].longValue()) case _ => value } - } private[this] def dateFromSinceEpoch(days: Long): Date = { // scalastyle:off magic.number @@ -93,33 +92,39 @@ final class AvroConverter { new Date(millis) } - private[this] def getLongValue(value: Any, field: Schema): Any = { - val logicalType = field.getLogicalType() - logicalType match { + private[this] def getLongValue(value: Any, field: Schema): Any = + field.getLogicalType() match { case lt: LogicalTypes.TimestampMillis => Timestamp.from(timestampMillisConverter.fromLong(value.asInstanceOf[Long], field, lt)) case lt: LogicalTypes.TimestampMicros => Timestamp.from(timestampMicrosConverter.fromLong(value.asInstanceOf[Long], field, lt)) case _ => value } - } - private[this] def getFixedValue(value: Any, field: Schema): Any = { - val logicalType = field.getLogicalType() - logicalType match { + private[this] def getFixedValue(value: Any, field: Schema): Any = + field.getLogicalType() match { case lt: LogicalTypes.Decimal => + checkPrecision(lt) decimalConverter.fromFixed(value.asInstanceOf[GenericFixed], field, lt) case _ => getStringValue(value, field) } - } - private[this] def getBytesValue(value: Any, field: Schema): Any = { - val logicalType = field.getLogicalType() - logicalType match { + private[this] def getBytesValue(value: Any, field: Schema): Any = + field.getLogicalType() match { case lt: LogicalTypes.Decimal => + checkPrecision(lt) decimalConverter.fromBytes(value.asInstanceOf[ByteBuffer], field, lt) case _ => getStringValue(value, field) } + + private[this] def checkPrecision(logicalType: LogicalTypes.Decimal): Unit = { + val precision = logicalType.getPrecision() + if (precision > EXASOL_DECIMAL_PRECISION) { + throw new IllegalArgumentException( + s"Decimal precision ${precision.toString()} is larger than " + + s"maximum allowed precision ${EXASOL_DECIMAL_PRECISION.toString()}." + ) + } } private[this] def getStringValue(value: Any, field: Schema): String = diff --git a/src/test/scala/com/exasol/avro/AvroLogicalTypesTest.scala b/src/test/scala/com/exasol/avro/AvroLogicalTypesTest.scala index fd33a64..5ed9e22 100644 --- a/src/test/scala/com/exasol/avro/AvroLogicalTypesTest.scala +++ b/src/test/scala/com/exasol/avro/AvroLogicalTypesTest.scala @@ -131,4 +131,24 @@ class AvroLogicalTypesTest extends AnyFunSuite { } } + test("throws if avro decimal exceeds allowed precision") { + val schema = getLogicalSchema( + s"""|{ + | "type":"bytes", + | "logicalType":"decimal", + | "precision":40, + | "scale":0 + |}""".stripMargin + ) + val record = new GenericData.Record(schema) + val bigDecimal = new BigDecimal("1234567890123456789012345678901234567890") + val bytes = ByteBuffer.wrap(bigDecimal.unscaledValue().toByteArray()) + record.put("value", bytes) + val thrown = intercept[IllegalArgumentException] { + AvroRow(record) + } + val expected = "Decimal precision 40 is larger than maximum allowed precision 36." + assert(thrown.getMessage === expected) + } + }