diff --git a/README.md b/README.md index 90fec4e..e69995b 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ These plugins help with project development. | [SBT PGP][sbt-pgp-link] | PGP plugin for `sbt` | BSD 3-Clause License | | [SBT Git][sbt-git-link] | A plugin for Git integration, used to version the release jars | BSD 2-Clause License | -[travis-badge]: https://img.shields.io/travis/exasol/import-export-udf-common-scala/master.svg?logo=travis +[travis-badge]: https://img.shields.io/travis/com/exasol/import-export-udf-common-scala/master.svg?logo=travis [travis-link]: https://travis-ci.com/exasol/import-export-udf-common-scala [coveralls-badge]: https://coveralls.io/repos/github/exasol/import-export-udf-common-scala/badge.svg?branch=master [coveralls-link]: https://coveralls.io/github/exasol/import-export-udf-common-scala?branch=master diff --git a/doc/changes/changes_0.1.1.md b/doc/changes/changes_0.2.0.md similarity index 69% rename from doc/changes/changes_0.1.1.md rename to doc/changes/changes_0.2.0.md index e280a1c..6f17934 100644 --- a/doc/changes/changes_0.1.1.md +++ b/doc/changes/changes_0.2.0.md @@ -1,9 +1,10 @@ -# Import Export UDF Common Scala 0.1.1, released 2020-10-DD +# Import Export UDF Common Scala 0.2.0, released 2020-10-DD ## Features * #9: Added SLF4J Logging Library as Common Dependency (PR #10) -* #11: Added Support for Complex Avro Types (Array, Map, Record) (PR #12) +* #11: Added Support for Avro Complex Types (Array, Map, Record) (PR #12) +* #13: Added Support for Avro Logical Types (BigDecimal, Date, Timestamp) (PR #14) ## Documentation @@ -19,10 +20,10 @@ ### Test Dependency Updates -* Updated `org.mockito:mockito-core` from `3.5.10` to `3.5.15`. +* Updated `org.mockito:mockito-core` from `3.5.10` to `3.6.0`. ### Plugin Updates * Updated `com.github.cb372:sbt-explicit-dependencies` from `0.2.13` to `0.2.15`. -* Updated `org.wartremover:sbt-wartremover` from `2.4.10` to `2.4.11`. -* Updated `org.wartremover:sbt-wartremover-contib` from `1.3.8` to `1.3.9`. +* Updated `org.wartremover:sbt-wartremover` from `2.4.10` to `2.4.12`. +* Updated `org.wartremover:sbt-wartremover-contib` from `1.3.8` to `1.3.10`. diff --git a/project/Dependencies.scala b/project/Dependencies.scala index 47554cf..d1b740d 100644 --- a/project/Dependencies.scala +++ b/project/Dependencies.scala @@ -15,7 +15,7 @@ object Dependencies { // Test dependencies versions private val ScalaTestVersion = "3.2.2" private val ScalaTestPlusVersion = "1.0.0-M2" - private val MockitoCoreVersion = "3.5.15" + private val MockitoCoreVersion = "3.6.0" val ExasolResolvers: Seq[Resolver] = Seq( "Exasol Releases" at "https://maven.exasol.com/artifactory/exasol-releases" diff --git a/project/plugins.sbt b/project/plugins.sbt index 2cccd1c..d1d0b10 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1,10 +1,10 @@ // Adds a `wartremover` a flexible Scala code linting tool // http://github.com/puffnfresh/wartremover -addSbtPlugin("org.wartremover" % "sbt-wartremover" % "2.4.11") +addSbtPlugin("org.wartremover" % "sbt-wartremover" % "2.4.12") // Adds Contrib Warts // http://github.com/wartremover/wartremover-contrib/ -addSbtPlugin("org.wartremover" % "sbt-wartremover-contrib" % "1.3.9") +addSbtPlugin("org.wartremover" % "sbt-wartremover-contrib" % "1.3.10") // Adds most common doc api mappings // https://github.com/ThoughtWorksInc/sbt-api-mappings diff --git a/src/main/scala/com/exasol/avro/AvroConverter.scala b/src/main/scala/com/exasol/avro/AvroConverter.scala new file mode 100644 index 0000000..04ba88d --- /dev/null +++ b/src/main/scala/com/exasol/avro/AvroConverter.scala @@ -0,0 +1,205 @@ +package com.exasol.common.avro + +import java.nio.ByteBuffer +import java.sql.Date +import java.sql.Timestamp +import java.time._ +import java.util.{Map => JMap} +import java.util.Collection + +import com.exasol.common.json.JsonMapper + +import org.apache.avro.Conversions.DecimalConversion +import org.apache.avro.LogicalTypes +import org.apache.avro.Schema +import org.apache.avro.data.TimeConversions.TimestampMicrosConversion +import org.apache.avro.data.TimeConversions.TimestampMillisConversion +import org.apache.avro.generic.GenericFixed +import org.apache.avro.generic.IndexedRecord +import org.apache.avro.util.Utf8 + +/** + * Avro data type converter helper class. + */ +final class AvroConverter { + + private[this] val EXASOL_DECIMAL_PRECISION = 36 + private[this] lazy val decimalConverter = new DecimalConversion() + private[this] lazy val timestampMillisConverter = new TimestampMillisConversion() + private[this] lazy val timestampMicrosConverter = new TimestampMicrosConversion() + + /** + * Converts Avro schema field value into a Java datatype. + * + * If Avro value is a complex datatype, then it is converted to the + * JSON string. + * + * @param value Avro record field value + * @param schema Avro record field schema + * @return A regular Java data types + */ + def convert(value: Any, schema: Schema): Any = { + val fieldValue = getAvroValue(value, schema) + if (isPrimitiveAvroType(schema.getType())) { + fieldValue + } else { + JsonMapper.toJson(fieldValue) + } + } + + private[this] def isPrimitiveAvroType(avroType: Schema.Type): Boolean = + avroType match { + case Schema.Type.ARRAY => false + case Schema.Type.MAP => false + case Schema.Type.RECORD => false + case _ => true + } + + @SuppressWarnings(Array("org.wartremover.warts.Return", "org.wartremover.warts.ToString")) + private[this] def getAvroValue(value: Any, field: Schema): Any = { + if (value == null) { + return null // scalastyle:ignore return + } + field.getType() match { + case Schema.Type.NULL => value + case Schema.Type.BOOLEAN => value + case Schema.Type.INT => getIntValue(value, field) + case Schema.Type.LONG => getLongValue(value, field) + case Schema.Type.FLOAT => value + case Schema.Type.DOUBLE => value + case Schema.Type.STRING => getStringValue(value, field) + case Schema.Type.FIXED => getFixedValue(value, field) + case Schema.Type.BYTES => getBytesValue(value, field) + case Schema.Type.ENUM => value.toString + case Schema.Type.UNION => getUnionValue(value, field) + case Schema.Type.ARRAY => getArrayValue(value, field) + case Schema.Type.MAP => getMapValue(value, field) + case Schema.Type.RECORD => getRecordValue(value) + } + } + + private[this] def getIntValue(value: Any, field: Schema): Any = + field.getLogicalType() match { + case _: LogicalTypes.Date => dateFromSinceEpoch(value.asInstanceOf[Int].longValue()) + case _ => value + } + + private[this] def dateFromSinceEpoch(days: Long): Date = { + // scalastyle:off magic.number + val date = LocalDateTime.of(1970, 1, 1, 0, 0, 0).plusDays(days) + // scalastyle:on + val millis = date.atZone(ZoneId.systemDefault).toInstant().toEpochMilli() + new Date(millis) + } + + private[this] def getLongValue(value: Any, field: Schema): Any = + field.getLogicalType() match { + case lt: LogicalTypes.TimestampMillis => + Timestamp.from(timestampMillisConverter.fromLong(value.asInstanceOf[Long], field, lt)) + case lt: LogicalTypes.TimestampMicros => + Timestamp.from(timestampMicrosConverter.fromLong(value.asInstanceOf[Long], field, lt)) + case _ => value + } + + private[this] def getFixedValue(value: Any, field: Schema): Any = + field.getLogicalType() match { + case lt: LogicalTypes.Decimal => + checkPrecision(lt) + decimalConverter.fromFixed(value.asInstanceOf[GenericFixed], field, lt) + case _ => getStringValue(value, field) + } + + private[this] def getBytesValue(value: Any, field: Schema): Any = + field.getLogicalType() match { + case lt: LogicalTypes.Decimal => + checkPrecision(lt) + decimalConverter.fromBytes(value.asInstanceOf[ByteBuffer], field, lt) + case _ => getStringValue(value, field) + } + + private[this] def checkPrecision(logicalType: LogicalTypes.Decimal): Unit = { + val precision = logicalType.getPrecision() + if (precision > EXASOL_DECIMAL_PRECISION) { + throw new IllegalArgumentException( + s"Decimal precision ${precision.toString()} is larger than " + + s"maximum allowed precision ${EXASOL_DECIMAL_PRECISION.toString()}." + ) + } + } + + private[this] def getStringValue(value: Any, field: Schema): String = + value match { + case str: String => str + case utf: Utf8 => utf.toString + case byteBuffer: ByteBuffer => new String(byteBuffer.array) + case arrayByte: Array[Byte] => new String(arrayByte) + case fixed: GenericFixed => new String(fixed.bytes()) + case _ => + throw new IllegalArgumentException( + s"Avro ${field.getName} type cannot be converted to string!" + ) + } + + private[this] def getUnionValue(value: Any, field: Schema): Any = { + val types = field.getTypes() + val typesSize = types.size() + typesSize match { + case 1 => getAvroValue(value, types.get(0)) + case 2 => + if (types.get(0).getType() == Schema.Type.NULL) { + getAvroValue(value, types.get(1)) + } else if (types.get(1).getType() == Schema.Type.NULL) { + getAvroValue(value, types.get(0)) + } else { + throw new IllegalArgumentException( + "Avro Union type should contain a primitive and null!" + ) + } + case _ => + throw new IllegalArgumentException("Avro Union type should contain a primitive and null!") + } + } + + private[this] def getArrayValue(value: Any, field: Schema): Array[Any] = value match { + case array: Array[_] => array.map(getAvroValue(_, field.getElementType())) + case list: Collection[_] => + val result = new Array[Any](list.size) + var i = 0 + list.stream().forEach { element => + val _ = result.update(i, getAvroValue(element, field.getElementType())) + i += 1 + } + result + case other => + throw new IllegalArgumentException( + s"Unsupported Avro Array type '${other.getClass.getName()}'." + ) + } + + private[this] def getMapValue(map: Any, field: Schema): JMap[String, Any] = { + val result = new java.util.HashMap[String, Any]() + map.asInstanceOf[JMap[String, _]].forEach { (key, value) => + val _ = result.put(key, getAvroValue(value, field.getValueType())) + } + result + } + + private[this] def getRecordValue(value: Any): JMap[String, Any] = value match { + case record: IndexedRecord => + val size = record.getSchema().getFields().size + val fields = record.getSchema().getFields() + val result = new java.util.HashMap[String, Any]() + var i = 0 + while (i < size) { + val _ = + result.put(fields.get(i).name, getAvroValue(record.get(i), fields.get(i).schema)) + i += 1 + } + result + case other => + throw new IllegalArgumentException( + s"Unsupported Avro Record type '${other.getClass.getName()}'." + ) + } + +} diff --git a/src/main/scala/com/exasol/avro/AvroRow.scala b/src/main/scala/com/exasol/avro/AvroRow.scala index 6bcb8cb..4a60dfd 100644 --- a/src/main/scala/com/exasol/avro/AvroRow.scala +++ b/src/main/scala/com/exasol/avro/AvroRow.scala @@ -1,17 +1,8 @@ package com.exasol.common.avro -import java.nio.ByteBuffer -import java.util.{Map => JMap} -import java.util.Collection - import com.exasol.common.data.Row -import com.exasol.common.json.JsonMapper -import org.apache.avro.Schema -import org.apache.avro.generic.GenericFixed import org.apache.avro.generic.GenericRecord -import org.apache.avro.generic.IndexedRecord -import org.apache.avro.util.Utf8 /** * A factory method that creates [[com.exasol.common.data.Row]] @@ -22,132 +13,18 @@ object AvroRow { /** * Converts an Avro record into an internal [[com.exasol.common.data.Row]]. * - * @param avroRecord a generic Avro record + * @param record a generic Avro record * @return a Row representation of the given Avro record */ - def apply(avroRecord: GenericRecord): Row = { - val fields = avroRecord.getSchema().getFields() + def apply(record: GenericRecord): Row = { + val fields = record.getSchema().getFields() val size = fields.size() val values = Array.ofDim[Any](size) + val converter = new AvroConverter() for { i <- 0 until size } { - values.update(i, getAvroFieldValue(fields.get(i).schema(), avroRecord.get(i))) + values.update(i, converter.convert(record.get(i), fields.get(i).schema())) } Row(values.toSeq) } - private[this] def getAvroFieldValue(schema: Schema, value: Any): Any = { - val fieldValue = getAvroValue(value, schema) - if (isPrimitiveAvroType(schema.getType())) { - fieldValue - } else { - JsonMapper.toJson(fieldValue) - } - } - - private[this] def isPrimitiveAvroType(avroType: Schema.Type): Boolean = - avroType match { - case Schema.Type.ARRAY => false - case Schema.Type.MAP => false - case Schema.Type.RECORD => false - case _ => true - } - - @SuppressWarnings(Array("org.wartremover.warts.Return", "org.wartremover.warts.ToString")) - private[this] def getAvroValue(value: Any, field: Schema): Any = { - if (value == null) { - return null // scalastyle:ignore return - } - field.getType() match { - case Schema.Type.NULL => value - case Schema.Type.BOOLEAN => value - case Schema.Type.INT => value - case Schema.Type.LONG => value - case Schema.Type.FLOAT => value - case Schema.Type.DOUBLE => value - case Schema.Type.STRING => getStringValue(value, field) - case Schema.Type.FIXED => getStringValue(value, field) - case Schema.Type.BYTES => getStringValue(value, field) - case Schema.Type.ENUM => value.toString - case Schema.Type.UNION => getUnionValue(value, field) - case Schema.Type.ARRAY => getArrayValue(value, field) - case Schema.Type.MAP => getMapValue(value, field) - case Schema.Type.RECORD => getRecordValue(value) - } - } - - private[this] def getStringValue(value: Any, field: Schema): String = - value match { - case str: String => str - case utf: Utf8 => utf.toString - case byteBuffer: ByteBuffer => new String(byteBuffer.array) - case arrayByte: Array[Byte] => new String(arrayByte) - case fixed: GenericFixed => new String(fixed.bytes()) - case _ => - throw new IllegalArgumentException( - s"Avro ${field.getName} type cannot be converted to string!" - ) - } - - private[this] def getUnionValue(value: Any, field: Schema): Any = { - val types = field.getTypes() - val typesSize = types.size() - typesSize match { - case 1 => getAvroValue(value, types.get(0)) - case 2 => - if (types.get(0).getType() == Schema.Type.NULL) { - getAvroValue(value, types.get(1)) - } else if (types.get(1).getType() == Schema.Type.NULL) { - getAvroValue(value, types.get(0)) - } else { - throw new IllegalArgumentException( - "Avro Union type should contain a primitive and null!" - ) - } - case _ => - throw new IllegalArgumentException("Avro Union type should contain a primitive and null!") - } - } - - private[this] def getArrayValue(value: Any, field: Schema): Array[Any] = value match { - case array: Array[_] => array.map(getAvroValue(_, field.getElementType())) - case list: Collection[_] => - val result = new Array[Any](list.size) - var i = 0 - list.stream().forEach { element => - val _ = result.update(i, getAvroValue(element, field.getElementType())) - i += 1 - } - result - case other => - throw new IllegalArgumentException( - s"Unsupported Avro Array type '${other.getClass.getName()}'." - ) - } - - private[this] def getMapValue(map: Any, field: Schema): JMap[String, Any] = { - val result = new java.util.HashMap[String, Any]() - map.asInstanceOf[JMap[String, _]].forEach { (key, value) => - val _ = result.put(key, getAvroValue(value, field.getValueType())) - } - result - } - - private[this] def getRecordValue(value: Any): JMap[String, Any] = value match { - case record: IndexedRecord => - val size = record.getSchema().getFields().size - val fields = record.getSchema().getFields() - val result = new java.util.HashMap[String, Any]() - var i = 0 - while (i < size) { - val _ = - result.put(fields.get(i).name, getAvroValue(record.get(i), fields.get(i).schema)) - i += 1 - } - result - case other => - throw new IllegalArgumentException( - s"Unsupported Avro Record type '${other.getClass.getName()}'." - ) - } - } diff --git a/src/test/scala/com/exasol/avro/AvroLogicalTypesTest.scala b/src/test/scala/com/exasol/avro/AvroLogicalTypesTest.scala new file mode 100644 index 0000000..5ed9e22 --- /dev/null +++ b/src/test/scala/com/exasol/avro/AvroLogicalTypesTest.scala @@ -0,0 +1,154 @@ +package com.exasol.common.avro + +import java.math.BigDecimal +import java.nio.ByteBuffer +import java.sql.Date +import java.sql.Timestamp + +import com.exasol.common.data.Row + +import org.apache.avro.Conversions +import org.apache.avro.LogicalTypes +import org.apache.avro.Schema +import org.apache.avro.generic.GenericData +import org.scalatest.funsuite.AnyFunSuite + +class AvroLogicalTypesTest extends AnyFunSuite { + + private[this] def getLogicalSchema(avroType: String): Schema = + new Schema.Parser() + .parse( + s"""|{ + | "type": "record", + | "namespace": "com.exasol.avro.Types", + | "name": "LogicalTypesRecord", + | "fields": [{ + | "name": "value", + | "type": $avroType + | }] + |} + |""".stripMargin + ) + + test("parse avro int with date logical type as Java SQL date type") { + val daysSinceEpoch = Seq(-719164, -70672, -21060, -365, -1, 0, 7252, 17317, 17937) + val expectedDates = Seq( + "0001-01-01", + "1776-07-04", + "1912-05-05", + "1969-01-01", + "1969-12-31", + "1970-01-01", + "1989-11-09", + "2017-05-31", + "2019-02-10" + ) + val schema = getLogicalSchema("""{"type":"int","logicalType":"date"}""") + daysSinceEpoch.zipWithIndex.foreach { + case (days, i) => + val record = new GenericData.Record(schema) + record.put("value", days) + assert(AvroRow(record).getAs[Date](0).toString() === expectedDates(i)) + } + } + + private[this] val milliseconds = Seq(-15854399877000L, 1603927542000L, 0L) + + test("parse avro long with timestamp-millis as Java SQL timestamp type") { + val schema = getLogicalSchema("""{"type":"long","logicalType":"timestamp-millis"}""") + milliseconds.foreach { + case millis => + val record = new GenericData.Record(schema) + record.put("value", millis) + assert(AvroRow(record).getAs[Timestamp](0) === new Timestamp(millis)) + } + } + + test("parse avro long with timestamp-micros as Java SQL timestamp type") { + val schema = getLogicalSchema("""{"type":"long","logicalType":"timestamp-micros"}""") + milliseconds.foreach { + case millis => + val record = new GenericData.Record(schema) + record.put("value", millis * 1000L + 13) + val expected = new Timestamp(millis) + expected.setNanos(13000) + assert(AvroRow(record).getAs[Timestamp](0) === expected) + } + } + + private[this] val precision = 4 + private[this] val scale = 2 + private[this] val decimals = Map( + "3.14" -> "3.14", + "2.01" -> "2.01", + "1.2" -> "1.20", + "0.5" -> "0.50", + "-1" -> "-1.00", + "-2.31" -> "-2.31" + ) + + test("parse avro bytes with decimal as big decimal type") { + val schema = getLogicalSchema( + s"""|{ + | "type":"bytes", + | "logicalType":"decimal", + | "precision":4, + | "scale":2 + |}""".stripMargin + ) + decimals.foreach { + case (given, expected) => + val record = new GenericData.Record(schema) + val bytes = ByteBuffer.wrap( + new BigDecimal(given).setScale(scale).unscaledValue().toByteArray() + ) + record.put("value", bytes) + assert(AvroRow(record) === Row(Seq(new BigDecimal(expected)))) + } + } + + test("parse avro fixed with decimal as big decimal type") { + val schema = getLogicalSchema( + s"""|{ + | "name":"fixed", + | "type":"fixed", + | "size":5, + | "logicalType":"decimal", + | "precision":4, + | "scale":2 + |}""".stripMargin + ) + decimals.foreach { + case (given, expected) => + val record = new GenericData.Record(schema) + val fixed = new Conversions.DecimalConversion().toFixed( + new BigDecimal(given).setScale(scale), + schema.getField("value").schema(), + LogicalTypes.decimal(precision, scale) + ) + record.put("value", fixed) + assert(AvroRow(record) === Row(Seq(new BigDecimal(expected)))) + } + } + + test("throws if avro decimal exceeds allowed precision") { + val schema = getLogicalSchema( + s"""|{ + | "type":"bytes", + | "logicalType":"decimal", + | "precision":40, + | "scale":0 + |}""".stripMargin + ) + val record = new GenericData.Record(schema) + val bigDecimal = new BigDecimal("1234567890123456789012345678901234567890") + val bytes = ByteBuffer.wrap(bigDecimal.unscaledValue().toByteArray()) + record.put("value", bytes) + val thrown = intercept[IllegalArgumentException] { + AvroRow(record) + } + val expected = "Decimal precision 40 is larger than maximum allowed precision 36." + assert(thrown.getMessage === expected) + } + +}