This repository has been archived by the owner on Dec 20, 2018. It is now read-only.

Add forceSchema option to output to specified schema
lindblombr committed Mar 25, 2017
1 parent c19f01a commit 7e6b342
Showing 3 changed files with 165 additions and 21 deletions.
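For readers who want to try the change, here is a minimal usage sketch (not part of the commit). It assumes Spark 2.x with this build of spark-avro on the classpath and an active SparkSession named spark; the schema string is only an illustration, while the option name forceSchema and the default record name topLevelRecord come from the code below.

import org.apache.spark.sql.SparkSession

// Hypothetical example: write a DataFrame with a hand-written Avro schema
// instead of the schema spark-avro would derive from the DataFrame itself.
val spark = SparkSession.builder().appName("force-schema-example").getOrCreate()
import spark.implicits._

val df = Seq((1L, Some("alice")), (2L, None: Option[String])).toDF("id", "name")

// "name" is a nullable union; the union-resolution logic added in this commit
// picks the non-null branch when converting each row.
val forcedSchema =
  """{
    |  "type": "record",
    |  "name": "topLevelRecord",
    |  "fields": [
    |    {"name": "id",   "type": "long"},
    |    {"name": "name", "type": ["null", "string"], "default": null}
    |  ]
    |}""".stripMargin

df.write
  .format("com.databricks.spark.avro")
  .option("forceSchema", forcedSchema)
  .save("/tmp/force-schema-output")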
156 changes: 142 additions & 14 deletions src/main/scala/com/databricks/spark/avro/AvroOutputWriter.scala
@@ -20,32 +20,46 @@ import java.io.{IOException, OutputStream}
import java.nio.ByteBuffer
import java.sql.Timestamp
import java.sql.Date
import java.util
import java.util.HashMap

import org.apache.hadoop.fs.Path

import scala.collection.immutable.Map
import org.apache.avro.generic.GenericData.Record
import org.apache.avro.generic.GenericRecord
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.mapred.AvroKey
import org.apache.avro.mapreduce.AvroKeyOutputFormat
import org.apache.hadoop.io.NullWritable
import org.apache.hadoop.mapreduce.{RecordWriter, TaskAttemptContext, TaskAttemptID}

import org.apache.log4j.Logger
import org.apache.spark.SPARK_VERSION
import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.datasources.OutputWriter
import org.apache.spark.sql.types._

import scala.collection.JavaConversions._

// NOTE: This class is instantiated and used on executor side only, no need to be serializable.
private[avro] class AvroOutputWriter(
path: String,
context: TaskAttemptContext,
schema: StructType,
recordName: String,
recordNamespace: String,
forceSchema: String) extends OutputWriter {

private val logger = Logger.getLogger(this.getClass)

private val forceAvroSchema = if (forceSchema.contentEquals("")) {
None
} else {
Option(new Schema.Parser().parse(forceSchema))
}
private lazy val converter = createConverterToAvro(
schema, recordName, recordNamespace, forceAvroSchema
)

/**
* Overrides the couple of methods responsible for generating the output streams / files so
@@ -55,7 +69,14 @@ private[avro] class AvroOutputWriter(
new AvroKeyOutputFormat[GenericRecord]() {

override def getDefaultWorkFile(context: TaskAttemptContext, extension: String): Path = {
if (SPARK_VERSION.startsWith("2.0")) {
val uniqueWriteJobId = context.getConfiguration.get("spark.sql.sources.writeJobUUID")
val taskAttemptId: TaskAttemptID = context.getTaskAttemptID
val split = taskAttemptId.getTaskID.getId
new Path(path, f"part-r-$split%05d-$uniqueWriteJobId$extension")
} else {
new Path(path)
}
}

@throws(classOf[IOException])
@@ -73,14 +94,19 @@

override def close(): Unit = recordWriter.close(context)

private def resolveForceSchemaUnion(schema: Schema, allowedTypes: List[Schema.Type]): Schema = {
schema.getTypes.find(allowedTypes contains _.getType).get
}

/**
* This function constructs a converter function for a given Spark SQL datatype. This is used in
* writing Avro records out to disk.
*/
private def createConverterToAvro(
dataType: DataType,
structName: String,
recordNamespace: String): (Any) => Any = {
recordNamespace: String,
forceAvroSchema: Option[Schema]): (Any) => Any = {
dataType match {
case BinaryType => (item: Any) => item match {
case null => null
@@ -94,7 +120,13 @@
case DateType => (item: Any) =>
if (item == null) null else item.asInstanceOf[Date].getTime
case ArrayType(elementType, _) =>
val elementConverter = if (forceAvroSchema.isDefined) {
createConverterToAvro(elementType, structName,
recordNamespace, Option(forceAvroSchema.get.getElementType))
} else {
createConverterToAvro(elementType, structName,
recordNamespace, forceAvroSchema)
}
(item: Any) => {
if (item == null) {
null
@@ -111,10 +143,16 @@
}
}
case MapType(StringType, valueType, _) =>
val valueConverter = if (forceAvroSchema.isDefined) {
createConverterToAvro(valueType, structName,
recordNamespace, Option(forceAvroSchema.get.getValueType))
} else {
createConverterToAvro(valueType, structName,
recordNamespace, forceAvroSchema)
}
(item: Any) => {
if (item == null) {
if (forceAvroSchema.isDefined) new HashMap[String, Any]() else null
} else {
val javaMap = new HashMap[String, Any]()
item.asInstanceOf[Map[String, Any]].foreach { case (key, value) =>
@@ -125,10 +163,86 @@
}
case structType: StructType =>
val builder = SchemaBuilder.record(structName).namespace(recordNamespace)
val schema: Schema = if (!forceAvroSchema.isDefined) {
SchemaConverters.convertStructToAvro(
structType, builder, recordNamespace)
} else {
if (forceAvroSchema.get.getType == Schema.Type.ARRAY) {
forceAvroSchema.get.getElementType
} else {
forceAvroSchema.get
}
}

val fieldConverters = structType.fields.map(
field => {
val fieldConvertSchema = if (forceAvroSchema.isDefined) {
val thisFieldSchema = schema.getField(field.name).schema
Option(
thisFieldSchema.getType match {
case Schema.Type.UNION => {
field.dataType match {
case MapType(StringType, vt, _) => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.MAP)
)
case ArrayType(et, _) => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.ARRAY)
)
case innerStructType: StructType => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.RECORD, Schema.Type.ARRAY)
)
case ByteType | ShortType | IntegerType => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.INT, Schema.Type.FIXED)
)
case LongType => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.FIXED, Schema.Type.LONG)
)
case FloatType => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.FLOAT)
)
case DoubleType => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.DOUBLE)
)
case StringType => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.STRING)
)
case BooleanType => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.BOOLEAN)
)
case _: DecimalType => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.STRING)
)
case TimestampType => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.LONG)
)
case DateType => resolveForceSchemaUnion(
thisFieldSchema,
List(Schema.Type.LONG)
)
case default => thisFieldSchema
}
}
case default => thisFieldSchema
}
)
} else {
forceAvroSchema
}
createConverterToAvro(field.dataType, field.name, recordNamespace, fieldConvertSchema)
}
)

(item: Any) => {
if (item == null) {
null
@@ -140,7 +254,21 @@

while (convertersIterator.hasNext) {
val converter = convertersIterator.next()
val fieldValue = rowIterator.next()
val fieldName = fieldNamesIterator.next()
try {
record.put(fieldName, converter(fieldValue))
} catch {
case ex: NullPointerException => {
// This can happen with forceAvroSchema conversion
if (forceAvroSchema.isDefined) {
logger.info(s"Trying to write field $fieldName which may be null? $fieldValue")
} else {
// Keep previous behavior when forceAvroSchema is not used
throw ex
}
}
}
}
record
}
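To make the union handling above concrete, a small standalone sketch (illustrative only, reusing the same find logic as resolveForceSchemaUnion): given a forced field schema of ["null", "long"] and a Spark LongType column, the non-null LONG branch is selected.

import org.apache.avro.Schema
import scala.collection.JavaConversions._

// Hypothetical illustration of the union resolution: parse a nullable union as
// it would appear in a forced schema, then keep the branch whose Avro type is
// allowed for the Spark column type (LongType allows FIXED and LONG above).
val fieldSchema = new Schema.Parser().parse("""["null", "long"]""")
val allowed = List(Schema.Type.FIXED, Schema.Type.LONG)
val resolved = fieldSchema.getTypes.find(t => allowed.contains(t.getType)).get
// resolved.getType == Schema.Type.LONG; only this branch is used when converting rows.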
src/main/scala/com/databricks/spark/avro/AvroOutputWriterFactory.scala
@@ -16,24 +16,33 @@

package com.databricks.spark.avro

import org.apache.avro.Schema
import org.apache.hadoop.mapreduce.TaskAttemptContext

import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFactory}
import org.apache.spark.sql.types.StructType

private[avro] class AvroOutputWriterFactory(
schema: StructType,
recordName: String,
recordNamespace: String,
forceSchema: String) extends OutputWriterFactory {

def getFileExtension(context: TaskAttemptContext): String = {
".avro"
}

def newInstance(
path: String,
bucketId: Option[Int],
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
newInstance(path, dataSchema, context)
}

def newInstance(
path: String,
dataSchema: StructType,
context: TaskAttemptContext): OutputWriter = {
new AvroOutputWriter(path, context, schema, recordName, recordNamespace, forceSchema)
}
}
11 changes: 9 additions & 2 deletions src/main/scala/com/databricks/spark/avro/DefaultSource.scala
@@ -112,10 +112,17 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
dataSchema: StructType): OutputWriterFactory = {
val recordName = options.getOrElse("recordName", "topLevelRecord")
val recordNamespace = options.getOrElse("recordNamespace", "")
val forceAvroSchema = options.getOrElse("forceSchema", "")
val build = SchemaBuilder.record(recordName).namespace(recordNamespace)
val outputAvroSchema = if (forceAvroSchema.contentEquals("")) {
SchemaConverters.convertStructToAvro(dataSchema, build, recordNamespace)
} else {
val parser = new Schema.Parser()
parser.parse(forceAvroSchema)
}

AvroJob.setOutputKeySchema(job, outputAvroSchema)

val AVRO_COMPRESSION_CODEC = "spark.sql.avro.compression.codec"
val AVRO_DEFLATE_LEVEL = "spark.sql.avro.deflate.level"
val COMPRESS_KEY = "mapred.output.compress"
@@ -142,7 +149,7 @@ private[avro] class DefaultSource extends FileFormat with DataSourceRegister {
log.error(s"unsupported compression codec $unknown")
}

new AvroOutputWriterFactory(dataSchema, recordName, recordNamespace, forceAvroSchema)
}

override def buildReader(
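A quick way to sanity-check a forceSchema value before passing it in is to run it through the same Avro parser that the write path above uses (a sketch only; the helper name is made up).

import org.apache.avro.Schema
import org.apache.avro.SchemaParseException

// Hypothetical pre-flight validation: Schema.Parser throws SchemaParseException
// on malformed JSON, so checking the string up front gives a clearer error than
// a failure during write preparation.
def validateForceSchema(json: String): Either[String, Schema] =
  try Right(new Schema.Parser().parse(json))
  catch { case e: SchemaParseException => Left(e.getMessage) }

validateForceSchema("""{"type":"record","name":"topLevelRecord","fields":[]}""") // Right(...)
validateForceSchema("""{"type":"recordd"}""")                                    // Left(...)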
