Skip to content

Commit

Permalink
Added decimal precision check.
Browse files Browse the repository at this point in the history
  • Loading branch information
morazow committed Oct 29, 2020
1 parent 993636e commit a8a291f
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 15 deletions.
35 changes: 20 additions & 15 deletions src/main/scala/com/exasol/avro/AvroConverter.scala
Expand Up @@ -23,6 +23,7 @@ import org.apache.avro.util.Utf8
*/
final class AvroConverter {

private[this] val EXASOL_DECIMAL_PRECISION = 36
private[this] lazy val decimalConverter = new DecimalConversion()
private[this] lazy val timestampMillisConverter = new TimestampMillisConversion()
private[this] lazy val timestampMicrosConverter = new TimestampMicrosConversion()
Expand Down Expand Up @@ -77,13 +78,11 @@ final class AvroConverter {
}
}

private[this] def getIntValue(value: Any, field: Schema): Any = {
val logicalType = field.getLogicalType()
logicalType match {
private[this] def getIntValue(value: Any, field: Schema): Any =
field.getLogicalType() match {
case _: LogicalTypes.Date => dateFromSinceEpoch(value.asInstanceOf[Int].longValue())
case _ => value
}
}

private[this] def dateFromSinceEpoch(days: Long): Date = {
// scalastyle:off magic.number
Expand All @@ -93,33 +92,39 @@ final class AvroConverter {
new Date(millis)
}

private[this] def getLongValue(value: Any, field: Schema): Any = {
val logicalType = field.getLogicalType()
logicalType match {
private[this] def getLongValue(value: Any, field: Schema): Any =
field.getLogicalType() match {
case lt: LogicalTypes.TimestampMillis =>
Timestamp.from(timestampMillisConverter.fromLong(value.asInstanceOf[Long], field, lt))
case lt: LogicalTypes.TimestampMicros =>
Timestamp.from(timestampMicrosConverter.fromLong(value.asInstanceOf[Long], field, lt))
case _ => value
}
}

private[this] def getFixedValue(value: Any, field: Schema): Any = {
val logicalType = field.getLogicalType()
logicalType match {
private[this] def getFixedValue(value: Any, field: Schema): Any =
field.getLogicalType() match {
case lt: LogicalTypes.Decimal =>
checkPrecision(lt)
decimalConverter.fromFixed(value.asInstanceOf[GenericFixed], field, lt)
case _ => getStringValue(value, field)
}
}

private[this] def getBytesValue(value: Any, field: Schema): Any = {
val logicalType = field.getLogicalType()
logicalType match {
private[this] def getBytesValue(value: Any, field: Schema): Any =
field.getLogicalType() match {
case lt: LogicalTypes.Decimal =>
checkPrecision(lt)
decimalConverter.fromBytes(value.asInstanceOf[ByteBuffer], field, lt)
case _ => getStringValue(value, field)
}

private[this] def checkPrecision(logicalType: LogicalTypes.Decimal): Unit = {
val precision = logicalType.getPrecision()
if (precision > EXASOL_DECIMAL_PRECISION) {
throw new IllegalArgumentException(
s"Decimal precision ${precision.toString()} is larger than " +
s"maximum allowed precision ${EXASOL_DECIMAL_PRECISION.toString()}."
)
}
}

private[this] def getStringValue(value: Any, field: Schema): String =
Expand Down
20 changes: 20 additions & 0 deletions src/test/scala/com/exasol/avro/AvroLogicalTypesTest.scala
Expand Up @@ -131,4 +131,24 @@ class AvroLogicalTypesTest extends AnyFunSuite {
}
}

test("throws if avro decimal exceeds allowed precision") {
val schema = getLogicalSchema(
s"""|{
| "type":"bytes",
| "logicalType":"decimal",
| "precision":40,
| "scale":0
|}""".stripMargin
)
val record = new GenericData.Record(schema)
val bigDecimal = new BigDecimal("1234567890123456789012345678901234567890")
val bytes = ByteBuffer.wrap(bigDecimal.unscaledValue().toByteArray())
record.put("value", bytes)
val thrown = intercept[IllegalArgumentException] {
AvroRow(record)
}
val expected = "Decimal precision 40 is larger than maximum allowed precision 36."
assert(thrown.getMessage === expected)
}

}

0 comments on commit a8a291f

Please sign in to comment.