This repository has been archived by the owner on Dec 20, 2018. It is now read-only.

Commit

Enable forceSchema option for writer
lindblombr committed Apr 2, 2017
1 parent c19f01a commit 6f88549
Showing 6 changed files with 464 additions and 25 deletions.
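For context, the new writer option is named `forceSchema` and defaults to an empty string (see the DefaultSource change below); when set, its value is parsed as an Avro schema and used in place of the schema derived from the DataFrame. A minimal usage sketch, assuming the standard spark-avro data source name and a purely illustrative schema, column names, and output path:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().appName("force-schema-example").getOrCreate()
    // Two string columns; without forceSchema, `color` would be written as a plain Avro string.
    val df = spark.createDataFrame(Seq(("a", "RED"), ("b", "BLUE"))).toDF("id", "color")

    // Illustrative forced schema: the record name matches the default
    // recordName ("topLevelRecord") and maps `color` to an Avro enum.
    val forcedSchema =
      """{
        |  "type": "record", "name": "topLevelRecord",
        |  "fields": [
        |    {"name": "id", "type": "string"},
        |    {"name": "color",
        |     "type": {"type": "enum", "name": "Color", "symbols": ["RED", "BLUE"]}}
        |  ]
        |}""".stripMargin

    df.write
      .format("com.databricks.spark.avro")  // spark-avro data source
      .option("forceSchema", forcedSchema)  // option added by this commit
      .save("/tmp/colors_avro")             // hypothetical output path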
152 changes: 134 additions & 18 deletions src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
@@ -20,32 +20,47 @@ import java.io.{IOException, OutputStream}
import java.nio.ByteBuffer
import java.sql.Timestamp
import java.sql.Date
import java.util
import java.util.HashMap

import org.apache.hadoop.fs.Path
import scala.collection.immutable.Map

import scala.collection.immutable.Map
import org.apache.avro.generic.GenericData.Record
import org.apache.avro.generic.GenericRecord
import org.apache.avro.generic.{GenericData, GenericRecord, GenericRecordBuilder}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.mapred.AvroKey
import org.apache.avro.mapreduce.AvroKeyOutputFormat
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, TaskAttemptID}

import org.apache.log4j.Logger
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types._

import scala.collection.JavaConversions._
import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer

// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
private[avro] class AvroOutputWriter(
path: String,
context: TaskAttemptContext,
schema: StructType,
recordName: String,
recordNamespace: String) extends OutputWriter {
recordNamespace: String,
forceSchema: String) extends OutputWriter {

private lazy val converter = createConverterToAvro(schema, recordName, recordNamespace)
private val logger = Logger.getLogger(this.getClass)

private val forceAvroSchema = if (forceSchema.contentEquals("")) {
None
} else {
Option(new Schema.Parser().parse(forceSchema))
}
private lazy val converter = createConverterToAvro(
schema, recordName, recordNamespace, forceAvroSchema
)

/**
* Overrides the couple of methods responsible for generating the output streams / files so
@@ -73,28 +88,79 @@ private[avro] class AvroOutputWriter(

override def close(): Unit = recordWriter.close(context)

private def resolveStructTypeToAvroUnion(schema:Schema, dataType:String): Schema = {
val allowedAvroTypes = dataType match {
case "boolean" => List(Schema.Type.BOOLEAN)
case "integer" => List(Schema.Type.INT)
case "long" => List(Schema.Type.LONG)
case "float" => List(Schema.Type.FLOAT)
case "double" => List(Schema.Type.DOUBLE)
case "binary" => List(Schema.Type.BYTES, Schema.Type.FIXED)
case "array" => List(Schema.Type.ARRAY)
case "map" => List(Schema.Type.MAP)
case "string" => List(Schema.Type.STRING, Schema.Type.ENUM)
case "struct" => List(Schema.Type.ARRAY, Schema.Type.RECORD)
case default => {
throw new RuntimeException(
s"Cannot map SparkSQL type '$dataType' against Avro schema '$schema'"
)
}
}
schema.getTypes.find (allowedAvroTypes contains _.getType).get
}

/**
* This function constructs converter function for a given sparkSQL datatype. This is used in
* writing Avro records out to disk
*/
private def createConverterToAvro(
dataType: DataType,
structName: String,
recordNamespace: String): (Any) => Any = {
recordNamespace: String,
forceAvroSchema: Option[Schema]): (Any) => Any = {
dataType match {
case BinaryType => (item: Any) => item match {
case null => null
case bytes: Array[Byte] => ByteBuffer.wrap(bytes)
case bytes: Array[Byte] => if (forceAvroSchema.isDefined) {
// Handle mapping from binary => bytes|fixed w/ forceSchema
forceAvroSchema.get.getType match {
case Schema.Type.BYTES => ByteBuffer.wrap(bytes)
case Schema.Type.FIXED => new GenericData.Fixed(
forceAvroSchema.get, bytes
)
case default => bytes
}
} else {
ByteBuffer.wrap(bytes)
}
}
case ByteType | ShortType | IntegerType | LongType |
FloatType | DoubleType | StringType | BooleanType => identity
FloatType | DoubleType | BooleanType => identity
case StringType => (item: Any) => if (forceAvroSchema.isDefined) {
// Handle case when forcing schema where this string should map
// to an ENUM
forceAvroSchema.get.getType match {
case Schema.Type.ENUM => new GenericData.EnumSymbol(
forceAvroSchema.get, item.toString
)
case default => item
}
} else {
item
}
case _: DecimalType => (item: Any) => if (item == null) null else item.toString
case TimestampType => (item: Any) =>
if (item == null) null else item.asInstanceOf[Timestamp].getTime
case DateType => (item: Any) =>
if (item == null) null else item.asInstanceOf[Date].getTime
case ArrayType(elementType, _) =>
val elementConverter = createConverterToAvro(elementType, structName, recordNamespace)
val elementConverter = if (forceAvroSchema.isDefined) {
createConverterToAvro(elementType, structName,
recordNamespace, Option(forceAvroSchema.get.getElementType))
} else {
createConverterToAvro(elementType, structName,
recordNamespace, forceAvroSchema)
}
(item: Any) => {
if (item == null) {
null
@@ -107,14 +173,20 @@ private[avro] class AvroOutputWriter(
targetArray(idx) = elementConverter(sourceArray(idx))
idx += 1
}
targetArray
targetArray.toSeq.asJava
}
}
case MapType(StringType, valueType, _) =>
val valueConverter = createConverterToAvro(valueType, structName, recordNamespace)
val valueConverter = if (forceAvroSchema.isDefined) {
createConverterToAvro(valueType, structName,
recordNamespace, Option(forceAvroSchema.get.getValueType))
} else {
createConverterToAvro(valueType, structName,
recordNamespace, forceAvroSchema)
}
(item: Any) => {
if (item == null) {
null
if (forceAvroSchema.isDefined) new HashMap[String, Any]() else null
} else {
val javaMap = new HashMap[String, Any]()
item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
@@ -125,10 +197,36 @@ private[avro] class AvroOutputWriter(
}
case structType: StructType =>
val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
val schema: Schema = SchemaConverters.convertStructToAvro(
structType, builder, recordNamespace)
val fieldConverters = structType.fields.map(field =>
createConverterToAvro(field.dataType, field.name, recordNamespace))
val schema: Schema = if (!forceAvroSchema.isDefined) {
SchemaConverters.convertStructToAvro(
structType, builder, recordNamespace)
} else {
if (forceAvroSchema.get.getType == Schema.Type.ARRAY) {
forceAvroSchema.get.getElementType
} else {
forceAvroSchema.get
}
}

val fieldConverters = structType.fields.map (
field => {
val fieldConvertSchema = if (forceAvroSchema.isDefined) {
val thisFieldSchema = schema.getField(field.name).schema
Option(
thisFieldSchema.getType match {
case Schema.Type.UNION => {
resolveStructTypeToAvroUnion(thisFieldSchema, field.dataType.typeName)
}
case default => thisFieldSchema
}
)
} else {
forceAvroSchema
}
createConverterToAvro(field.dataType, field.name, recordNamespace, fieldConvertSchema)
}
)

(item: Any) => {
if (item == null) {
null
@@ -140,9 +238,27 @@ private[avro] class AvroOutputWriter(

while (convertersIterator.hasNext) {
val converter = convertersIterator.next()
record.put(fieldNamesIterator.next(), converter(rowIterator.next()))
val fieldValue = rowIterator.next()
val fieldName = fieldNamesIterator.next()
try {
record.put(fieldName, converter(fieldValue))
} catch {
case ex:NullPointerException => {
// This can happen with forceAvroSchema conversion
if (forceAvroSchema.isDefined) {
logger.debug(s"Trying to write field $fieldName which may be null? $fieldValue")
} else {
// Keep previous behavior when forceAvroSchema is not used
throw ex
}
}
}
}
if(forceAvroSchema.isDefined) {
new GenericRecordBuilder(record).build()
} else {
record
}
record
}
}
}
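The `resolveStructTypeToAvroUnion` helper added above selects the branch of an Avro union schema that is compatible with a given Spark SQL type name. A standalone sketch of that lookup for a nullable string column (the union schema here is illustrative, not taken from the commit):

    import org.apache.avro.Schema
    import scala.collection.JavaConverters._

    // A nullable Spark StringType column is commonly represented as the union ["null", "string"].
    val union = Schema.createUnion(
      List(Schema.create(Schema.Type.NULL), Schema.create(Schema.Type.STRING)).asJava)

    // For dataType "string" the helper allows STRING or ENUM branches and
    // returns the first match, i.e. the "string" branch rather than "null".
    val allowed = List(Schema.Type.STRING, Schema.Type.ENUM)
    val resolved = union.getTypes.asScala.find(s => allowed.contains(s.getType)).get
    assert(resolved.getType == Schema.Type.STRING)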
src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala
@@ -16,24 +16,33 @@

package com.databricks.spark.avro

import org.apache.avro.Schema
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

private[avro] class AvroOutputWriterFactory(
schema: StructType,
recordName: String,
recordNamespace: String) extends OutputWriterFactory {
recordNamespace: String,
forceSchema: String) extends OutputWriterFactory {

override def getFileExtension(context: TaskAttemptContext): String = {
def getFileExtension(context: TaskAttemptContext): String = {
".avro"
}

override def newInstance(
def newInstance(
path: String,
bucketId: Option[Int],
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
newInstance(path, dataSchema, context)
}

def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new AvroOutputWriter(path, context, schema, recordName, recordNamespace)
new AvroOutputWriter(path, context, schema, recordName, recordNamespace, forceSchema)
}
}
11 changes: 9 additions & 2 deletions src/main/scala/com/databricks/spark/avro/DefaultSource.scala
@@ -112,10 +112,17 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
dataSchema: StructType): OutputWriterFactory = {
val recordName = options.getOrElse("recordName", "topLevelRecord")
val recordNamespace = options.getOrElse("recordNamespace", "")
val forceAvroSchema = options.getOrElse("forceSchema", "")
val build = SchemaBuilder.record(recordName).namespace(recordNamespace)
val outputAvroSchema = SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace)
val outputAvroSchema = if (forceAvroSchema.contentEquals("")) {
SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace)
} else {
val parser = new Schema.Parser()
parser.parse(forceAvroSchema)
}

AvroJob.setOutputKeySchema(job, outputAvroSchema)

val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
val COMPRESS_KEY = "mapred.output.compress"
@@ -142,7 +149,7 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
log.error(s"unsupported compression codec $unknown")
}

new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace)
new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace, forceAvroSchema)
}

override def buildReader(
Binary file added src/test/resources/messy.avro
Binary file not shown.
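One caveat worth noting: when `forceSchema` is in effect, the writer above catches `NullPointerException` from `record.put` and only logs it at debug level, so a forced schema that does not line up with the DataFrame can silently drop values. A caller-side pre-flight check is one way to guard against that; the helper below is an assumption for illustration, not part of this commit:

    import org.apache.avro.Schema
    import org.apache.spark.sql.types.StructType

    // Hypothetical sanity check: every DataFrame column should have a matching
    // top-level field in the forced Avro schema before writing with
    // .option("forceSchema", ...); Schema.getField returns null for missing fields.
    def checkForcedSchema(forcedJson: String, sparkSchema: StructType): Unit = {
      val avroSchema = new Schema.Parser().parse(forcedJson)
      val missing = sparkSchema.fieldNames.filter(name => avroSchema.getField(name) == null)
      require(missing.isEmpty, s"forceSchema is missing fields: ${missing.mkString(", ")}")
    }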
