[SPARK-5938] [SPARK-5443] [SQL] Improve JsonRDD performance
This patch comprises a few related pieces of work:

* Schema inference is performed directly on the JSON token stream
* `String => Row` conversion populates Spark SQL structures without intermediate types
* Projection pushdown is implemented via CatalystScan for DataFrame queries
* The legacy parser remains available by setting `spark.sql.json.useJacksonStreamingAPI` to `false` (see the sketch below)
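
For reference, switching between the two code paths can be sketched as follows in a spark-shell style session (the file path and the `sqlContext` setup are illustrative, not part of the patch):

```scala
// Default: the new Jackson streaming path is enabled.
val dfNew = sqlContext.jsonFile("/tmp/lastfm.json")

// Opt back into the legacy JsonRDD-based parser.
sqlContext.setConf("spark.sql.json.useJacksonStreamingAPI", "false")
val dfLegacy = sqlContext.jsonFile("/tmp/lastfm.json")

// Restore the default.
sqlContext.setConf("spark.sql.json.useJacksonStreamingAPI", "true")
```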

Performance improvements depend on the schema and the queries being executed, but the new code path should be faster across the board. Below are benchmarks using the last.fm Million Song dataset:

```
Command                                            | Baseline | Patched
---------------------------------------------------|----------|--------
import sqlContext.implicits._                      |          |
val df = sqlContext.jsonFile("/tmp/lastfm.json")   |    70.0s |   14.6s
df.count()                                         |    28.8s |    6.2s
df.rdd.count()                                     |    35.3s |   21.5s
df.where($"artist" === "Robert Hood").collect()    |    28.3s |   16.9s
```
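
A minimal sketch of how these timings could be reproduced in a spark-shell session; the `time` helper is illustrative and is not the harness used for the table above:

```scala
import sqlContext.implicits._

// Crude wall-clock timer, adequate for coarse-grained comparisons like the table above.
def time[A](label: String)(body: => A): A = {
  val start = System.nanoTime()
  val result = body
  println(f"$label%-45s ${(System.nanoTime() - start) / 1e9}%.1fs")
  result
}

val df = time("jsonFile + schema inference")(sqlContext.jsonFile("/tmp/lastfm.json"))
time("df.count()")(df.count())
time("df.rdd.count()")(df.rdd.count())
time("filter on artist + collect")(df.where($"artist" === "Robert Hood").collect())
```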

To prepare this dataset for benchmarking, follow these steps:

```
# Fetch the datasets from http://labrosa.ee.columbia.edu/millionsong/lastfm
wget http://labrosa.ee.columbia.edu/millionsong/sites/default/files/lastfm/lastfm_test.zip \
     http://labrosa.ee.columbia.edu/millionsong/sites/default/files/lastfm/lastfm_train.zip

# Decompress and combine, then pipe through `jq -c` to ensure there is one record per line
unzip -p lastfm_test.zip | jq -c . > lastfm.json
unzip -p lastfm_train.zip | jq -c . >> lastfm.json
```

Author: Nathan Howell <nhowell@godaddy.com>

Closes apache#5801 from NathanHowell/json-performance and squashes the following commits:

26fea31 [Nathan Howell] Recreate the baseRDD for each scan operation
a7ebeb2 [Nathan Howell] Increase coverage of inserts into a JSONRelation
e06a1dd [Nathan Howell] Add comments to the `useJacksonStreamingAPI` config flag
6822712 [Nathan Howell] Split up JsonRDD2 into multiple objects
fa8234f [Nathan Howell] Wrap long lines
b31917b [Nathan Howell] Rename `useJsonRDD2` to `useJacksonStreamingAPI`
15c5d1b [Nathan Howell] JSONRelation's baseRDD need not be lazy
f8add6e [Nathan Howell] Add comments on lack of support for precision and scale DecimalTypes
fa0be47 [Nathan Howell] Remove unused default case in the field parser
80dba17 [Nathan Howell] Add comments regarding null handling and empty strings
842846d [Nathan Howell] Point the empty schema inference test at JsonRDD2
ab6ee87 [Nathan Howell] Add projection pushdown support to JsonRDD/JsonRDD2
f636c14 [Nathan Howell] Enable JsonRDD2 by default, add a flag to switch back to JsonRDD
0bbc445 [Nathan Howell] Improve JSON parsing and type inference performance
7ca70c1 [Nathan Howell] Eliminate arrow pattern, replace with pattern matches
Nathan Howell authored and yhuai committed May 7, 2015
1 parent 9cfa9a5 commit 2d6612c
Showing 13 changed files with 715 additions and 128 deletions.
@@ -26,33 +26,36 @@ object HiveTypeCoercion {
// See https://cwiki.apache.org/confluence/display/Hive/LanguageManual+Types.
// The conversion for integral and floating point types have a linear widening hierarchy:
private val numericPrecedence =
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType, DecimalType.Unlimited)
IndexedSeq(
ByteType,
ShortType,
IntegerType,
LongType,
FloatType,
DoubleType,
DecimalType.Unlimited)

/**
* Find the tightest common type of two types that might be used in a binary expression.
* This handles all numeric types except fixed-precision decimals interacting with each other or
* with primitive types, because in that case the precision and scale of the result depends on
* the operation. Those rules are implemented in [[HiveTypeCoercion.DecimalPrecision]].
*/
def findTightestCommonType(t1: DataType, t2: DataType): Option[DataType] = {
val valueTypes = Seq(t1, t2).filter(t => t != NullType)
if (valueTypes.distinct.size > 1) {
// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
if (numericPrecedence.contains(t1) && numericPrecedence.contains(t2)) {
Some(numericPrecedence.filter(t => t == t1 || t == t2).last)
} else if (t1.isInstanceOf[DecimalType] && t2.isInstanceOf[DecimalType]) {
// Fixed-precision decimals can up-cast into unlimited
if (t1 == DecimalType.Unlimited || t2 == DecimalType.Unlimited) {
Some(DecimalType.Unlimited)
} else {
None
}
} else {
None
}
} else {
Some(if (valueTypes.size == 0) NullType else valueTypes.head)
}
val findTightestCommonType: (DataType, DataType) => Option[DataType] = {
case (t1, t2) if t1 == t2 => Some(t1)
case (NullType, t1) => Some(t1)
case (t1, NullType) => Some(t1)

// Promote numeric types to the highest of the two and all numeric types to unlimited decimal
case (t1, t2) if Seq(t1, t2).forall(numericPrecedence.contains) =>
val index = numericPrecedence.lastIndexWhere(t => t == t1 || t == t2)
Some(numericPrecedence(index))

// Fixed-precision decimals can up-cast into unlimited
case (DecimalType.Unlimited, _: DecimalType) => Some(DecimalType.Unlimited)
case (_: DecimalType, DecimalType.Unlimited) => Some(DecimalType.Unlimited)

case _ => None
}
}
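
For intuition (illustrative, not part of the diff), the rewritten `findTightestCommonType` is expected to behave the same as the original; a few expected evaluations, assuming the catalyst types and `HiveTypeCoercion` are on the classpath:

```scala
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.types._

// Numeric types widen along numericPrecedence.
HiveTypeCoercion.findTightestCommonType(IntegerType, LongType)   // Some(LongType)
HiveTypeCoercion.findTightestCommonType(LongType, FloatType)     // Some(FloatType)
// NullType defers to the other operand.
HiveTypeCoercion.findTightestCommonType(NullType, StringType)    // Some(StringType)
// Fixed-precision decimals only up-cast to the unlimited decimal.
HiveTypeCoercion.findTightestCommonType(DecimalType(10, 2), DecimalType.Unlimited) // Some(DecimalType.Unlimited)
// Everything else is left for other coercion rules.
HiveTypeCoercion.findTightestCommonType(StringType, IntegerType) // None
```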

@@ -134,6 +134,10 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru
throw new IllegalArgumentException(s"""Field "$name" does not exist."""))
}

private[sql] def getFieldIndex(name: String): Option[Int] = {
nameToIndex.get(name)
}

protected[sql] def toAttributes: Seq[AttributeReference] =
map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)())

4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -42,7 +42,7 @@ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{EvaluatePython, ExplainCommand, LogicalRDD}
import org.apache.spark.sql.jdbc.JDBCWriteDetails
import org.apache.spark.sql.json.JsonRDD
import org.apache.spark.sql.json.{JacksonGenerator, JsonRDD}
import org.apache.spark.sql.types._
import org.apache.spark.sql.sources.{ResolvedDataSource, CreateTableUsingAsSelect}
import org.apache.spark.util.Utils
@@ -1415,7 +1415,7 @@ class DataFrame private[sql](
new Iterator[String] {
override def hasNext: Boolean = iter.hasNext
override def next(): String = {
JsonRDD.rowToJSON(rowSchema, gen)(iter.next())
JacksonGenerator(rowSchema, gen)(iter.next())
gen.flush()

val json = writer.toString
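For context (illustrative, not part of the diff): the hunk above backs `DataFrame.toJSON`, whose user-facing behavior is unchanged. Assuming `df` is a DataFrame such as one loaded via `sqlContext.jsonFile`:

```scala
// Serialize each Row back to a JSON string, now driven by JacksonGenerator internally.
val jsonLines: org.apache.spark.rdd.RDD[String] = df.toJSON
jsonLines.take(3).foreach(println)
```
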
8 changes: 8 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -73,6 +73,8 @@ private[spark] object SQLConf {

val USE_SQL_SERIALIZER2 = "spark.sql.useSerializer2"

val USE_JACKSON_STREAMING_API = "spark.sql.json.useJacksonStreamingAPI"

object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
}
@@ -166,6 +168,12 @@ private[sql] class SQLConf extends Serializable {

private[spark] def useSqlSerializer2: Boolean = getConf(USE_SQL_SERIALIZER2, "true").toBoolean

/**
* Selects between the new (true) and old (false) JSON handlers, to be removed in Spark 1.5.0
*/
private[spark] def useJacksonStreamingAPI: Boolean =
getConf(USE_JACKSON_STREAMING_API, "true").toBoolean

/**
* Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
* a broadcast value during the physical executions of join operations. Setting this to -1
34 changes: 21 additions & 13 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -659,13 +659,17 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jsonRDD(json: RDD[String], schema: StructType): DataFrame = {
val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
val appliedSchema =
Option(schema).getOrElse(
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
createDataFrame(rowRDD, appliedSchema, needsConversion = false)
if (conf.useJacksonStreamingAPI) {
baseRelationToDataFrame(new JSONRelation(() => json, None, 1.0, Some(schema))(this))
} else {
val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
val appliedSchema =
Option(schema).getOrElse(
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, 1.0, columnNameOfCorruptJsonRecord)))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
createDataFrame(rowRDD, appliedSchema, needsConversion = false)
}
}

/**
@@ -689,12 +693,16 @@ class SQLContext(@transient val sparkContext: SparkContext)
*/
@Experimental
def jsonRDD(json: RDD[String], samplingRatio: Double): DataFrame = {
val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
val appliedSchema =
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
createDataFrame(rowRDD, appliedSchema, needsConversion = false)
if (conf.useJacksonStreamingAPI) {
baseRelationToDataFrame(new JSONRelation(() => json, None, samplingRatio, None)(this))
} else {
val columnNameOfCorruptJsonRecord = conf.columnNameOfCorruptRecord
val appliedSchema =
JsonRDD.nullTypeToStringType(
JsonRDD.inferSchema(json, samplingRatio, columnNameOfCorruptJsonRecord))
val rowRDD = JsonRDD.jsonStringToRow(json, appliedSchema, columnNameOfCorruptJsonRecord)
createDataFrame(rowRDD, appliedSchema, needsConversion = false)
}
}

/**
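Illustrative usage of the two `jsonRDD` overloads touched above (spark-shell style; the sample records are made up):

```scala
import org.apache.spark.sql.types._

val records = sc.parallelize(Seq(
  """{"name": "Alice", "age": 30}""",
  """{"name": "Bob"}"""))

// Sampling-ratio overload: the schema is inferred from the data.
val inferred = sqlContext.jsonRDD(records, 1.0)

// Explicit-schema overload: inference is skipped entirely.
val schema = StructType(Seq(
  StructField("name", StringType, nullable = true),
  StructField("age", LongType, nullable = true)))
val typed = sqlContext.jsonRDD(records, schema)
```
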
171 changes: 171 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/json/InferSchema.scala
@@ -0,0 +1,171 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.json

import com.fasterxml.jackson.core._

import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion
import org.apache.spark.sql.json.JacksonUtils.nextUntil
import org.apache.spark.sql.types._

private[sql] object InferSchema {
/**
* Infer the type of a collection of json records in three stages:
* 1. Infer the type of each record
* 2. Merge types by choosing the lowest type necessary to cover equal keys
* 3. Replace any remaining null fields with string, the top type
*/
def apply(
json: RDD[String],
samplingRatio: Double = 1.0,
columnNameOfCorruptRecords: String): StructType = {
require(samplingRatio > 0, s"samplingRatio ($samplingRatio) should be greater than 0")
val schemaData = if (samplingRatio > 0.99) {
json
} else {
json.sample(withReplacement = false, samplingRatio, 1)
}

// perform schema inference on each row and merge afterwards
schemaData.mapPartitions { iter =>
val factory = new JsonFactory()
iter.map { row =>
try {
val parser = factory.createParser(row)
parser.nextToken()
inferField(parser)
} catch {
case _: JsonParseException =>
StructType(Seq(StructField(columnNameOfCorruptRecords, StringType)))
}
}
}.treeAggregate[DataType](StructType(Seq()))(compatibleRootType, compatibleRootType) match {
case st: StructType => nullTypeToStringType(st)
}
}

/**
* Infer the type of a json document from the parser's token stream
*/
private def inferField(parser: JsonParser): DataType = {
import com.fasterxml.jackson.core.JsonToken._
parser.getCurrentToken match {
case null | VALUE_NULL => NullType

case FIELD_NAME =>
parser.nextToken()
inferField(parser)

case VALUE_STRING if parser.getTextLength < 1 =>
// Zero length strings and nulls have special handling to deal
// with JSON generators that do not distinguish between the two.
// To accurately infer types for empty strings that are really
// meant to represent nulls we assume that the two are isomorphic
// but will defer treating null fields as strings until all the
// record fields' types have been combined.
NullType

case VALUE_STRING => StringType
case START_OBJECT =>
val builder = Seq.newBuilder[StructField]
while (nextUntil(parser, END_OBJECT)) {
builder += StructField(parser.getCurrentName, inferField(parser), nullable = true)
}

StructType(builder.result().sortBy(_.name))

case START_ARRAY =>
// If this JSON array is empty, we use NullType as a placeholder.
// If this array is not empty in other JSON objects, we can resolve
// the type as we pass through all JSON objects.
var elementType: DataType = NullType
while (nextUntil(parser, END_ARRAY)) {
elementType = compatibleType(elementType, inferField(parser))
}

ArrayType(elementType)

case VALUE_NUMBER_INT | VALUE_NUMBER_FLOAT =>
import JsonParser.NumberType._
parser.getNumberType match {
// For Integer values, use LongType by default.
case INT | LONG => LongType
// Since we do not have a data type backed by BigInteger,
// when we see a Java BigInteger, we use DecimalType.
case BIG_INTEGER | BIG_DECIMAL => DecimalType.Unlimited
case FLOAT | DOUBLE => DoubleType
}

case VALUE_TRUE | VALUE_FALSE => BooleanType
}
}

private def nullTypeToStringType(struct: StructType): StructType = {
val fields = struct.fields.map {
case StructField(fieldName, dataType, nullable, _) =>
val newType = dataType match {
case NullType => StringType
case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull)
case ArrayType(struct: StructType, containsNull) =>
ArrayType(nullTypeToStringType(struct), containsNull)
case struct: StructType => nullTypeToStringType(struct)
case other: DataType => other
}

StructField(fieldName, newType, nullable)
}

StructType(fields)
}

/**
* Remove top-level ArrayType wrappers and merge the remaining schemas
*/
private def compatibleRootType: (DataType, DataType) => DataType = {
case (ArrayType(ty1, _), ty2) => compatibleRootType(ty1, ty2)
case (ty1, ArrayType(ty2, _)) => compatibleRootType(ty1, ty2)
case (ty1, ty2) => compatibleType(ty1, ty2)
}

/**
* Returns the most general data type for two given data types.
*/
private[json] def compatibleType(t1: DataType, t2: DataType): DataType = {
HiveTypeCoercion.findTightestCommonType(t1, t2).getOrElse {
// t1 or t2 is a StructType, ArrayType, or an unexpected type.
(t1, t2) match {
case (other: DataType, NullType) => other
case (NullType, other: DataType) => other
case (StructType(fields1), StructType(fields2)) =>
val newFields = (fields1 ++ fields2).groupBy(field => field.name).map {
case (name, fieldTypes) =>
val dataType = fieldTypes.view.map(_.dataType).reduce(compatibleType)
StructField(name, dataType, nullable = true)
}
StructType(newFields.toSeq.sortBy(_.name))

case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) =>
ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2)

// Anything else falls back to StringType, since every JSON value has a string representation.
case (_, _) => StringType
}
}
}
}
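
As a sanity check on the three inference stages described in the scaladoc above, an illustrative example (the object is `private[sql]`, so this would live in a test under `org.apache.spark.sql.json`; the sample records and the `"_corrupt_record"` column name are assumptions):

```scala
import org.apache.spark.sql.types._

val sample = sc.parallelize(Seq(
  """{"a": 1, "b": "x"}""",
  """{"a": 2.5, "c": [1, 2]}"""))

val schema = InferSchema(sample, 1.0, "_corrupt_record")
// Expected merged schema:
//   a -> DoubleType          (LongType merged with DoubleType widens to DoubleType)
//   b -> StringType
//   c -> ArrayType(LongType)
```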
