This repository has been archived by the owner on Dec 20, 2018. It is now read-only.

Add forceSchema option to output to specified schema #222

Open · wants to merge 2 commits into master
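For orientation, here is a hedged sketch of how the new option would be used from the DataFrame writer API. The option key "forceSchema" and its empty-string default come from the DefaultSource change below; the SparkSession setup, input path, and target schema JSON are illustrative assumptions, not part of this PR.

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("force-schema-example").getOrCreate()
// Any DataFrame whose fields line up with the target Avro schema (hypothetical input path).
val df = spark.read.json("/tmp/input.json")

// Hypothetical Avro schema; in practice this would typically be read from an .avsc file.
val targetSchemaJson =
  """{"type": "record", "name": "topLevelRecord", "fields": [
    |  {"name": "id", "type": ["null", "string"]},
    |  {"name": "count", "type": "int", "default": 0}
    |]}""".stripMargin

df.write
  .format("com.databricks.spark.avro")
  .option("forceSchema", targetSchemaJson) // omit the option (or pass "") to keep the schema derived from the DataFrame
  .save("/tmp/output-avro")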
153 changes: 135 additions & 18 deletions src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
@@ -20,33 +20,49 @@ import java.io.{IOException, OutputStream}
import java.nio.ByteBuffer
import java.sql.Timestamp
import java.sql.Date
import java.util
import java.util.HashMap

import org.apache.hadoop.fs.Path

import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import scala.collection.immutable.Map

import org.apache.avro.generic.GenericData.Record
import org.apache.avro.generic.GenericRecord
import org.apache.avro.generic.{GenericData, GenericRecord, GenericRecordBuilder}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.mapred.AvroKey
import org.apache.avro.mapreduce.AvroKeyOutputFormat
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, TaskAttemptID}

import org.apache.log4j.Logger
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types._

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
private[avro] class AvroOutputWriter(
path: String,
context: TaskAttemptContext,
schema: StructType,
recordName: String,
recordNamespace: String) extends OutputWriter {
recordNamespace: String,
forceSchema: String) extends OutputWriter {

private val logger = Logger.getLogger(this.getClass)

private val forceAvroSchema = if (forceSchema.contentEquals("")) {
None
} else {
Option(new Schema.Parser().parse(forceSchema))
}
private lazy val converter = createConverterToAvro(
schema, recordName, recordNamespace, forceAvroSchema
)

private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace)
// Copy of the old conversion logic after the API change in SPARK-19085
private lazy val internalRowConverter =
CatalystTypeConverters.createToScalaConverter(schema).asInstanceOf[InternalRow => Row]
@@ -83,28 +99,79 @@ private[avro] class AvroOutputWriter(

override def close(): Unit = recordWriter.close(context)

private def resolveStructTypeToAvroUnion(schema: Schema, dataType: String): Schema = {
val allowedAvroTypes = dataType match {
case "boolean" => List(Schema.Type.BOOLEAN)
case "integer" => List(Schema.Type.INT)
case "long" => List(Schema.Type.LONG)
case "float" => List(Schema.Type.FLOAT)
case "double" => List(Schema.Type.DOUBLE)
case "binary" => List(Schema.Type.BYTES, Schema.Type.FIXED)
case "array" => List(Schema.Type.ARRAY)
case "map" => List(Schema.Type.MAP)
case "string" => List(Schema.Type.STRING, Schema.Type.ENUM)
case "struct" => List(Schema.Type.ARRAY, Schema.Type.RECORD)
case default => {
throw new RuntimeException(
s"Cannot map SparkSQL type '$dataType' against Avro schema '$schema'"
)
}
}
schema.getTypes.find(allowedAvroTypes contains _.getType).get
}

/**
* This function constructs a converter function for a given Spark SQL data type. It is used
* when writing Avro records out to disk.
*/
private def createConverterToAvro(
dataType: DataType,
structName: String,
recordNamespace: String): (Any) => Any = {
recordNamespace: String,
forceAvroSchema: Option[Schema]): (Any) => Any = {
dataType match {
case BinaryType => (item: Any) => item match {
case null => null
case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
case bytes: Array[Byte] => if (forceAvroSchema.isDefined) {
// Handle mapping from binary => bytes|fixed w/ forceSchema
forceAvroSchema.get.getType match {
case Schema.Type.BYTES => ByteBuffer.wrap(bytes)
case Schema.Type.FIXED => new GenericData.Fixed(
forceAvroSchema.get, bytes
)
case default => bytes
}
} else {
ByteBuffer.wrap(bytes)
}
}
case ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | StringType | BooleanType => identity
FloatType | DoubleType | BooleanType => identity
case StringType => (item: Any) => if (forceAvroSchema.isDefined) {
// Handle case when forcing schema where this string should map
// to an ENUM
forceAvroSchema.get.getType match {
case Schema.Type.ENUM => new GenericData.EnumSymbol(
forceAvroSchema.get, item.toString
)
case default => item
}
} else {
item
}
case _: DecimalType => (item: Any) => if (item == null) null else item.toString
case TimestampType => (item: Any) =>
if (item == null) null else item.asInstanceOf[Timestamp].getTime
case DateType => (item: Any) =>
if (item == null) null else item.asInstanceOf[Date].getTime
case ArrayType(elementType, _) =>
val elementConverter = createConverterToAvro(elementType, structName, recordNamespace)
val elementConverter = if (forceAvroSchema.isDefined) {
createConverterToAvro(elementType, structName,
recordNamespace, Option(forceAvroSchema.get.getElementType))
} else {
createConverterToAvro(elementType, structName,
recordNamespace, forceAvroSchema)
}
(item: Any) => {
if (item == null) {
null
@@ -117,14 +184,20 @@
targetArray(idx) = elementConverter(sourceArray(idx))
idx += 1
}
targetArray
targetArray.toSeq.asJava
}
}
case MapType(StringType, valueType, _) =>
val valueConverter = createConverterToAvro(valueType, structName, recordNamespace)
val valueConverter = if (forceAvroSchema.isDefined) {
createConverterToAvro(valueType, structName,
recordNamespace, Option(forceAvroSchema.get.getValueType))
} else {
createConverterToAvro(valueType, structName,
recordNamespace, forceAvroSchema)
}
(item: Any) => {
if (item == null) {
null
if (forceAvroSchema.isDefined) new HashMap[String, Any]() else null
} else {
val javaMap = new HashMap[String, Any]()
item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
@@ -135,10 +208,36 @@ private[avro] class AvroOutputWriter(
}
case structType: StructType =>
val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
val schema: Schema = SchemaConverters.convertStructToAvro(
structType, builder, recordNamespace)
val fieldConverters = structType.fields.map(field =>
createConverterToAvro(field.dataType, field.name, recordNamespace))
val schema: Schema = if (!forceAvroSchema.isDefined) {
SchemaConverters.convertStructToAvro(
structType, builder, recordNamespace)
} else {
if (forceAvroSchema.get.getType == Schema.Type.ARRAY) {
forceAvroSchema.get.getElementType
} else {
forceAvroSchema.get
}
}

val fieldConverters = structType.fields.map (
field => {
val fieldConvertSchema = if (forceAvroSchema.isDefined) {
val thisFieldSchema = schema.getField(field.name).schema
Option(
thisFieldSchema.getType match {
case Schema.Type.UNION => {
resolveStructTypeToAvroUnion(thisFieldSchema, field.dataType.typeName)
}
case default => thisFieldSchema
}
)
} else {
forceAvroSchema
}
createConverterToAvro(field.dataType, field.name, recordNamespace, fieldConvertSchema)
}
)

(item: Any) => {
if (item == null) {
null
@@ -150,9 +249,27 @@

while (convertersIterator.hasNext) {
val converter = convertersIterator.next()
record.put(fieldNamesIterator.next(), converter(rowIterator.next()))
val fieldValue = rowIterator.next()
val fieldName = fieldNamesIterator.next()
try {
record.put(fieldName, converter(fieldValue))
} catch {
case ex: NullPointerException => {
// This can happen during forceAvroSchema conversion
if (forceAvroSchema.isDefined) {
logger.debug(s"Could not write field $fieldName (value: $fieldValue); leaving it unset")
} else {
// Keep previous behavior when forceAvroSchema is not used
throw ex
}
}
}
}
if (forceAvroSchema.isDefined) {
new GenericRecordBuilder(record).build()
} else {
record
}
record
}
}
}
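A minimal sketch (not part of the diff) of what resolveStructTypeToAvroUnion does: given an Avro union and a Spark SQL type name, it returns the first union branch whose Avro type is an allowed target for that Spark type. The record schema below is a made-up example.

import org.apache.avro.Schema
import scala.collection.JavaConverters._

// Union schema ["null", "string"] taken from a hypothetical record field.
val idField = new Schema.Parser().parse(
  """{"type": "record", "name": "Example", "fields": [
    |  {"name": "id", "type": ["null", "string"]}
    |]}""".stripMargin
).getField("id").schema()

// A Spark StringType field is allowed to map to STRING or ENUM, so the "string"
// branch is chosen and the "null" branch is passed over.
val resolved = idField.getTypes.asScala
  .find(s => List(Schema.Type.STRING, Schema.Type.ENUM).contains(s.getType))
  .get
println(resolved) // prints "string"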
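When a schema is forced, the writer finishes by passing the record through new GenericRecordBuilder(record).build(). A plausible reading is that this fills any field left unset (for example, after the NullPointerException path above) from the schema's default value. A small illustration with a hypothetical schema:

import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericRecordBuilder}

val schema = new Schema.Parser().parse(
  """{"type": "record", "name": "Example", "fields": [
    |  {"name": "count", "type": "int", "default": 0}
    |]}""".stripMargin)

val partial = new GenericData.Record(schema) // "count" is never put()
val complete = new GenericRecordBuilder(partial).build()
println(complete.get("count")) // prints 0, taken from the schema default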
src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala
@@ -16,24 +16,33 @@

package com.databricks.spark.avro

import org.apache.avro.Schema
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

private[avro] class AvroOutputWriterFactory(
schema: StructType,
recordName: String,
recordNamespace: String) extends OutputWriterFactory {
recordNamespace: String,
forceSchema: String) extends OutputWriterFactory {

override def getFileExtension(context: TaskAttemptContext): String = {
def getFileExtension(context: TaskAttemptContext): String = {
".avro"
}

override def newInstance(
def newInstance(
path: String,
bucketId: Option[Int],
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
newInstance(path, dataSchema, context)
}

def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new AvroOutputWriter(path, context, schema, recordName, recordNamespace)
new AvroOutputWriter(path, context, schema, recordName, recordNamespace, forceSchema)
}
}
11 changes: 9 additions & 2 deletions src/main/scala/com/databricks/spark/avro/DefaultSource.scala
@@ -112,10 +112,17 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
dataSchema: StructType): OutputWriterFactory = {
val recordName = options.getOrElse("recordName", "topLevelRecord")
val recordNamespace = options.getOrElse("recordNamespace", "")
val forceAvroSchema = options.getOrElse("forceSchema", "")
val build = SchemaBuilder.record(recordName).namespace(recordNamespace)
val outputAvroSchema = SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace)
val outputAvroSchema = if (forceAvroSchema.contentEquals("")) {
SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace)
} else {
val parser = new Schema.Parser()
parser.parse(forceAvroSchema)
}

AvroJob.setOutputKeySchema(job, outputAvroSchema)

val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
val COMPRESS_KEY = "mapred.output.compress"
@@ -142,7 +149,7 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
log.error(s"unsupported compression codec $unknown")
}

new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace)
new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace, forceAvroSchema)
}

override def buildReader(
Binary file added src/test/resources/messy.avro