[SPARK-28698][SQL] Support user-specified output schema in to_avro
## What changes were proposed in this pull request?

The mapping from Spark schema to Avro schema is many-to-many (see https://spark.apache.org/docs/latest/sql-data-sources-avro.html#supported-types-for-spark-sql---avro-conversion), so the default mapping might not be what users want. For example, by default a "string" column is always written as the Avro "string" type, but users might want to write it out as an Avro "enum" instead.
With PR apache#21847, Spark already supports a user-specified schema in the batch writer.
The function `to_avro` should support a user-specified output schema as well, as sketched below.
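
For illustration (not part of this commit; assumes a DataFrame `df` with a nullable string column `suit`), a minimal sketch of the two overloads:

```scala
import org.apache.spark.sql.avro.functions.to_avro
import org.apache.spark.sql.functions.col

// Default mapping: the nullable string column is written with the derived
// Avro schema, the union ["string", "null"].
val asString = df.select(to_avro(col("suit")).alias("avro"))

// New overload: the same column is written against a user-specified schema,
// here a nullable Avro enum.
val suitSchema =
  """["null", {"type": "enum", "name": "Suit",
    |  "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]""".stripMargin
val asEnum = df.select(to_avro(col("suit"), suitSchema).alias("avro"))
```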

## How was this patch tested?

Unit test.

Closes apache#25419 from gengliangwang/to_avro.

Authored-by: Gengliang Wang <gengliang.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
gengliangwang authored and cloud-fan committed Aug 13, 2019
1 parent 3249c7a commit 48adc91
Showing 4 changed files with 76 additions and 12 deletions.
@@ -19,19 +19,24 @@ package org.apache.spark.sql.avro
 
 import java.io.ByteArrayOutputStream
 
+import org.apache.avro.Schema
 import org.apache.avro.generic.GenericDatumWriter
 import org.apache.avro.io.{BinaryEncoder, EncoderFactory}
 
 import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression}
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.types.{BinaryType, DataType}
 
-case class CatalystDataToAvro(child: Expression) extends UnaryExpression {
+case class CatalystDataToAvro(
+    child: Expression,
+    jsonFormatSchema: Option[String]) extends UnaryExpression {
 
   override def dataType: DataType = BinaryType
 
   @transient private lazy val avroType =
-    SchemaConverters.toAvroType(child.dataType, child.nullable)
+    jsonFormatSchema
+      .map(new Schema.Parser().parse)
+      .getOrElse(SchemaConverters.toAvroType(child.dataType, child.nullable))
 
   @transient private lazy val serializer =
     new AvroSerializer(child.dataType, avroType, child.nullable)
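
The core of the change is the schema resolution above: a user-supplied JSON schema, when present, is parsed and used verbatim; otherwise the Avro schema is derived from the Catalyst type as before. A standalone sketch of that resolution (my restatement outside the expression; assumes the avro and spark-avro jars are on the classpath, and uses the same `SchemaConverters` the expression uses):

```scala
import org.apache.avro.Schema
import org.apache.spark.sql.avro.SchemaConverters
import org.apache.spark.sql.types.StringType

val userJson: Option[String] = Some(
  """{"type": "enum", "name": "Suit",
    |  "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}""".stripMargin)

val avroType: Schema = userJson
  .map(new Schema.Parser().parse)  // user-specified schema wins...
  .getOrElse(SchemaConverters.toAvroType(StringType, nullable = false))  // ...else derive
```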
@@ -72,6 +72,19 @@ object functions {
    */
   @Experimental
   def to_avro(data: Column): Column = {
-    new Column(CatalystDataToAvro(data.expr))
+    new Column(CatalystDataToAvro(data.expr, None))
   }
+
+  /**
+   * Converts a column into binary of avro format.
+   *
+   * @param data the data column.
+   * @param jsonFormatSchema user-specified output avro schema in JSON string format.
+   *
+   * @since 3.0.0
+   */
+  @Experimental
+  def to_avro(data: Column, jsonFormatSchema: String): Column = {
+    new Column(CatalystDataToAvro(data.expr, Some(jsonFormatSchema)))
+  }
 }
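
Illustrative usage of the new overload together with `from_avro` from the same `functions` object (my example, not from the diff; assumes a running `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.avro.functions.{from_avro, to_avro}
import org.apache.spark.sql.functions.col

val suitSchema =
  """["null", {"type": "enum", "name": "Suit",
    |  "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]""".stripMargin

val df = spark.createDataFrame(Seq(Tuple1("SPADES"))).toDF("suit")

// Encode the string column as a nullable Avro enum, then decode it back.
val bytes = df.select(to_avro(col("suit"), suitSchema).alias("avro"))
val back = bytes.select(from_avro(col("avro"), suitSchema).alias("suit"))
```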
@@ -19,7 +19,7 @@ package org.apache.spark.sql.avro
 
 import org.apache.avro.Schema
 
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
 import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal}
@@ -38,12 +38,12 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
 
   private def checkResult(data: Literal, schema: String, expected: Any): Unit = {
     checkEvaluation(
-      AvroDataToCatalyst(CatalystDataToAvro(data), schema, Map.empty),
+      AvroDataToCatalyst(CatalystDataToAvro(data, None), schema, Map.empty),
       prepareExpectedResult(expected))
   }
 
   protected def checkUnsupportedRead(data: Literal, schema: String): Unit = {
-    val binary = CatalystDataToAvro(data)
+    val binary = CatalystDataToAvro(data, None)
     intercept[Exception] {
       AvroDataToCatalyst(binary, schema, Map("mode" -> "FAILFAST")).eval()
     }
@@ -209,4 +209,41 @@ class AvroCatalystDataConversionSuite extends SparkFunSuite
       checkUnsupportedRead(input, avroSchema)
     }
   }
+
+  test("user-specified output schema") {
+    val data = Literal("SPADES")
+    val jsonFormatSchema =
+      """
+        |{ "type": "enum",
+        |  "name": "Suit",
+        |  "symbols" : ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]
+        |}
+      """.stripMargin
+
+    val message = intercept[SparkException] {
+      AvroDataToCatalyst(
+        CatalystDataToAvro(
+          data,
+          None),
+        jsonFormatSchema,
+        options = Map.empty).eval()
+    }.getMessage
+    assert(message.contains("Malformed records are detected in record parsing."))
+
+    checkEvaluation(
+      AvroDataToCatalyst(
+        CatalystDataToAvro(
+          data,
+          Some(jsonFormatSchema)),
+        jsonFormatSchema,
+        options = Map.empty),
+      data.eval())
+  }
+
+  test("invalid user-specified output schema") {
+    val message = intercept[IncompatibleSchemaException] {
+      CatalystDataToAvro(Literal("SPADES"), Some("\"long\"")).eval()
+    }.getMessage
+    assert(message == "Cannot convert Catalyst type StringType to Avro type \"long\".")
+  }
 }
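
A note on the first assertion, with a sketch of why it holds (my reading, using the plain Avro Java API rather than code from the commit): `Literal("SPADES")` is non-nullable, so without a user-specified schema `to_avro` writes the value as a plain Avro "string", i.e. a zig-zag length byte 0x0c followed by the UTF-8 text. Decoding those bytes against the enum schema reads 0x0c as symbol index 6, which is out of range for the four symbols, and FAILFAST surfaces the parse error as the "Malformed records" `SparkException`:

```scala
import org.apache.avro.Schema
import org.apache.avro.generic.GenericDatumReader
import org.apache.avro.io.DecoderFactory

val enumSchema = new Schema.Parser().parse(
  """{"type": "enum", "name": "Suit",
    |  "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}""".stripMargin)

// Bytes of "SPADES" under the default derived schema, plain Avro "string":
// zig-zag length 6 = 0x0c, then the UTF-8 characters.
val stringBytes = Array[Byte](0x0c) ++ "SPADES".getBytes("UTF-8")

val decoder = DecoderFactory.get().binaryDecoder(stringBytes, null)
new GenericDatumReader[AnyRef](enumSchema).read(null, decoder)  // throws: index 6 is invalid
```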
21 changes: 15 additions & 6 deletions python/pyspark/sql/avro/functions.py
@@ -69,26 +69,35 @@ def from_avro(data, jsonFormatSchema, options={}):
 
 @ignore_unicode_prefix
 @since(3.0)
-def to_avro(data):
+def to_avro(data, jsonFormatSchema=""):
     """
     Converts a column into binary of avro format.
 
     Note: Avro is built-in but external data source module since Spark 2.4. Please deploy the
     application as per the deployment section of "Apache Avro Data Source Guide".
 
     :param data: the data column.
+    :param jsonFormatSchema: user-specified output avro schema in JSON string format.
 
     >>> from pyspark.sql import Row
     >>> from pyspark.sql.avro.functions import to_avro
-    >>> data = [(1, Row(name='Alice', age=2))]
-    >>> df = spark.createDataFrame(data, ("key", "value"))
-    >>> df.select(to_avro(df.value).alias("avro")).collect()
-    [Row(avro=bytearray(b'\\x00\\x00\\x04\\x00\\nAlice'))]
+    >>> data = ['SPADES']
+    >>> df = spark.createDataFrame(data, "string")
+    >>> df.select(to_avro(df.value).alias("suite")).collect()
+    [Row(suite=bytearray(b'\\x00\\x0cSPADES'))]
 
+    >>> jsonFormatSchema = '''["null", {"type": "enum", "name": "value",
+    ...     "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]'''
+    >>> df.select(to_avro(df.value, jsonFormatSchema).alias("suite")).collect()
+    [Row(suite=bytearray(b'\\x02\\x00'))]
     """
     sc = SparkContext._active_spark_context
     try:
-        jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(_to_java_column(data))
+        if jsonFormatSchema == "":
+            jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(_to_java_column(data))
+        else:
+            jc = sc._jvm.org.apache.spark.sql.avro.functions.to_avro(
+                _to_java_column(data), jsonFormatSchema)
     except TypeError as e:
         if str(e) == "'JavaPackage' object is not callable":
             _print_missing_jar("Avro", "avro", "avro", sc.version)
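
For reference (my addition; plain Avro Java API, using the same classes the commit already imports in `CatalystDataToAvro`), a sketch reproducing the doctest bytes. `b'\x00\x0cSPADES'` is the derived `["string", "null"]` union: branch 0 (`\x00`), zig-zag length 6 (`\x0c`), then the UTF-8 text. `b'\x02\x00'` is the user schema's union: branch 1 for the enum (zig-zag of 1 is `\x02`) and symbol index 0 for "SPADES" (`\x00`):

```scala
import java.io.ByteArrayOutputStream

import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericDatumWriter}
import org.apache.avro.io.EncoderFactory

val unionSchema = new Schema.Parser().parse(
  """["null", {"type": "enum", "name": "value",
    |  "symbols": ["SPADES", "HEARTS", "DIAMONDS", "CLUBS"]}]""".stripMargin)
val enumSchema = unionSchema.getTypes.get(1)

val out = new ByteArrayOutputStream()
val encoder = EncoderFactory.get().binaryEncoder(out, null)
// Union branch 1 (the enum) is written as zig-zag 0x02, symbol 0 as 0x00.
new GenericDatumWriter[AnyRef](unionSchema)
  .write(new GenericData.EnumSymbol(enumSchema, "SPADES"), encoder)
encoder.flush()
assert(out.toByteArray.sameElements(Array[Byte](0x02, 0x00)))  // b'\x02\x00'
```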
