diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index f0bd3cbd985da..93bfc25bca855 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -313,12 +313,15 @@ class StructField(DataType): """ - def __init__(self, name, dataType, nullable): + def __init__(self, name, dataType, nullable, metadata=None): """Creates a StructField :param name: the name of this field. :param dataType: the data type of this field. :param nullable: indicates whether values of this field can be null. + :param metadata: metadata of this field, which is a map from strings + to simple types that can be serialized to JSON + automatically >>> (StructField("f1", StringType, True) ... == StructField("f1", StringType, True)) @@ -330,6 +333,7 @@ def __init__(self, name, dataType, nullable): self.name = name self.dataType = dataType self.nullable = nullable + self.metadata = metadata or {} def __repr__(self): return "StructField(%s,%s,%s)" % (self.name, self.dataType, @@ -338,13 +342,15 @@ def __repr__(self): def jsonValue(self): return {"name": self.name, "type": self.dataType.jsonValue(), - "nullable": self.nullable} + "nullable": self.nullable, + "metadata": self.metadata} @classmethod def fromJson(cls, json): return StructField(json["name"], _parse_datatype_json_value(json["type"]), - json["nullable"]) + json["nullable"], + json["metadata"]) class StructType(DataType): @@ -423,7 +429,8 @@ def _parse_datatype_json_string(json_string): ... StructField("simpleArray", simple_arraytype, True), ... StructField("simpleMap", simple_maptype, True), ... StructField("simpleStruct", simple_structtype, True), - ... StructField("boolean", BooleanType(), False)]) + ... StructField("boolean", BooleanType(), False), + ... StructField("withMeta", DoubleType(), False, {"name": "age"})]) >>> check_datatype(complex_structtype) True >>> # Complex ArrayType. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d76c743d3f652..75923d9e8d729 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -46,7 +46,7 @@ object ScalaReflection { /** Returns a Sequence of attributes for the given case class type. */ def attributesFor[T: TypeTag]: Seq[Attribute] = schemaFor[T] match { case Schema(s: StructType, _) => - s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + s.fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) } /** Returns a catalyst DataType and its nullability for the given Scala Type using reflection.
*/ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 1eb260efa6387..39b120e8de485 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.types.{DataType, FractionalType, IntegralType, NumericType, NativeType} +import org.apache.spark.sql.catalyst.util.Metadata abstract class Expression extends TreeNode[Expression] { self: Product => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index 9c865254e0be9..ab0701fd9a80b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -43,7 +43,7 @@ abstract class Generator extends Expression { override type EvaluatedType = TraversableOnce[Row] override lazy val dataType = - ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable)))) + ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata)))) override def nullable = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index fe13a661f6f7a..3310566087b3d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -21,6 +21,7 @@ import org.apache.spark.sql.catalyst.trees import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.types._ +import org.apache.spark.sql.catalyst.util.Metadata object NamedExpression { private val curId = new java.util.concurrent.atomic.AtomicLong() @@ -43,6 +44,9 @@ abstract class NamedExpression extends Expression { def toAttribute: Attribute + /** Returns the metadata when an expression is a reference to another expression with metadata. */ + def metadata: Metadata = Metadata.empty + protected def typeSuffix = if (resolved) { dataType match { @@ -88,10 +92,16 @@ case class Alias(child: Expression, name: String) override def dataType = child.dataType override def nullable = child.nullable + override def metadata: Metadata = { + child match { + case named: NamedExpression => named.metadata + case _ => Metadata.empty + } + } override def toAttribute = { if (resolved) { - AttributeReference(name, child.dataType, child.nullable)(exprId, qualifiers) + AttributeReference(name, child.dataType, child.nullable, metadata)(exprId, qualifiers) } else { UnresolvedAttribute(name) } @@ -108,15 +118,20 @@ case class Alias(child: Expression, name: String) * @param name The name of this attribute; it should only be used during analysis or for debugging. * @param dataType The [[DataType]] of this attribute. * @param nullable True if null is a valid value for this attribute.
+ * @param metadata The metadata of this attribute. * @param exprId A globally unique id used to check if different AttributeReferences refer to the * same attribute. * @param qualifiers a list of strings that can be used to refer to this attribute in a fully * qualified way. Consider the examples tableName.name, subQueryAlias.name. * tableName and subQueryAlias are possible qualifiers. */ -case class AttributeReference(name: String, dataType: DataType, nullable: Boolean = true) - (val exprId: ExprId = NamedExpression.newExprId, val qualifiers: Seq[String] = Nil) - extends Attribute with trees.LeafNode[Expression] { +case class AttributeReference( + name: String, + dataType: DataType, + nullable: Boolean = true, + override val metadata: Metadata = Metadata.empty)( + val exprId: ExprId = NamedExpression.newExprId, + val qualifiers: Seq[String] = Nil) extends Attribute with trees.LeafNode[Expression] { override def equals(other: Any) = other match { case ar: AttributeReference => exprId == ar.exprId && dataType == ar.dataType @@ -128,10 +143,12 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea var h = 17 h = h * 37 + exprId.hashCode() h = h * 37 + dataType.hashCode() + h = h * 37 + metadata.hashCode() h } - override def newInstance() = AttributeReference(name, dataType, nullable)(qualifiers = qualifiers) + override def newInstance() = + AttributeReference(name, dataType, nullable, metadata)(qualifiers = qualifiers) /** * Returns a copy of this [[AttributeReference]] with changed nullability. @@ -140,7 +157,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea if (nullable == newNullability) { this } else { - AttributeReference(name, dataType, newNullability)(exprId, qualifiers) + AttributeReference(name, dataType, newNullability, metadata)(exprId, qualifiers) } } @@ -159,7 +176,7 @@ case class AttributeReference(name: String, dataType: DataType, nullable: Boolea if (newQualifiers.toSet == qualifiers.toSet) { this } else { - AttributeReference(name, dataType, nullable)(exprId, newQualifiers) + AttributeReference(name, dataType, nullable, metadata)(exprId, newQualifiers) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala index 4e6e1166bfffb..6069f9b0a68dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/dataTypes.scala @@ -24,16 +24,16 @@ import scala.reflect.ClassTag import scala.reflect.runtime.universe.{TypeTag, runtimeMirror, typeTag} import scala.util.parsing.combinator.RegexParsers -import org.json4s.JsonAST.JValue import org.json4s._ +import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ import org.apache.spark.sql.catalyst.ScalaReflectionLock import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference, Expression} +import org.apache.spark.sql.catalyst.util.Metadata import org.apache.spark.util.Utils - object DataType { def fromJson(json: String): DataType = parseDataType(parse(json)) @@ -70,10 +70,11 @@ object DataType { private def parseStructField(json: JValue): StructField = json match { case JSortedObject( + ("metadata", metadata: JObject), ("name", JString(name)), ("nullable", JBool(nullable)), ("type", dataType: JValue)) => - StructField(name, parseDataType(dataType), nullable) + StructField(name,
parseDataType(dataType), nullable, Metadata.fromJObject(metadata)) } @deprecated("Use DataType.fromJson instead", "1.2.0") @@ -388,24 +389,34 @@ case class ArrayType(elementType: DataType, containsNull: Boolean) extends DataT * @param name The name of this field. * @param dataType The data type of this field. * @param nullable Indicates if values of this field can be `null` values. + * @param metadata The metadata of this field. The metadata should be preserved during + * transformation if the content of the column is not modified, e.g., in selection. */ -case class StructField(name: String, dataType: DataType, nullable: Boolean) { +case class StructField( + name: String, + dataType: DataType, + nullable: Boolean, + metadata: Metadata = Metadata.empty) { private[sql] def buildFormattedString(prefix: String, builder: StringBuilder): Unit = { builder.append(s"$prefix-- $name: ${dataType.typeName} (nullable = $nullable)\n") DataType.buildFormattedString(dataType, s"$prefix |", builder) } + // Override the default toString to be compatible with legacy Parquet files. + override def toString: String = s"StructField($name,$dataType,$nullable)" + private[sql] def jsonValue: JValue = { ("name" -> name) ~ ("type" -> dataType.jsonValue) ~ - ("nullable" -> nullable) + ("nullable" -> nullable) ~ + ("metadata" -> metadata.jsonValue) } } object StructType { protected[sql] def fromAttributes(attributes: Seq[Attribute]): StructType = - StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable))) + StructType(attributes.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))) } case class StructType(fields: Seq[StructField]) extends DataType { @@ -439,7 +450,7 @@ case class StructType(fields: Seq[StructField]) extends DataType { } protected[sql] def toAttributes = - fields.map(f => AttributeReference(f.name, f.dataType, f.nullable)()) + fields.map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) def treeString: String = { val builder = new StringBuilder diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala new file mode 100644 index 0000000000000..2f2082fa3c863 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/Metadata.scala @@ -0,0 +1,255 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable + +import org.json4s._ +import org.json4s.jackson.JsonMethods._ + +/** + * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, + * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and + * Array[Metadata].
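+ * + * For example, a sketch of typical usage with the [[MetadataBuilder]] defined below (the keys and values here are purely illustrative): + * {{{ + * val m = new MetadataBuilder() + * .putString("name", "age") + * .putLong("index", 1L) + * .build() + * m.getString("name") // returns "age" + * }}}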
JSON is used for serialization. + * + * The default constructor is private. Users should use either [[MetadataBuilder]] or + * [[Metadata$#fromJson]] to create Metadata instances. + * + * @param map an immutable map that stores the data + */ +sealed class Metadata private[util] (private[util] val map: Map[String, Any]) extends Serializable { + + /** Gets a Long. */ + def getLong(key: String): Long = get(key) + + /** Gets a Double. */ + def getDouble(key: String): Double = get(key) + + /** Gets a Boolean. */ + def getBoolean(key: String): Boolean = get(key) + + /** Gets a String. */ + def getString(key: String): String = get(key) + + /** Gets a Metadata. */ + def getMetadata(key: String): Metadata = get(key) + + /** Gets a Long array. */ + def getLongArray(key: String): Array[Long] = get(key) + + /** Gets a Double array. */ + def getDoubleArray(key: String): Array[Double] = get(key) + + /** Gets a Boolean array. */ + def getBooleanArray(key: String): Array[Boolean] = get(key) + + /** Gets a String array. */ + def getStringArray(key: String): Array[String] = get(key) + + /** Gets a Metadata array. */ + def getMetadataArray(key: String): Array[Metadata] = get(key) + + /** Converts to its JSON representation. */ + def json: String = compact(render(jsonValue)) + + override def toString: String = json + + override def equals(obj: Any): Boolean = { + obj match { + case that: Metadata => + if (map.keySet == that.map.keySet) { + map.keys.forall { k => + (map(k), that.map(k)) match { + case (v0: Array[_], v1: Array[_]) => + v0.view == v1.view + case (v0, v1) => + v0 == v1 + } + } + } else { + false + } + case other => + false + } + } + + override def hashCode: Int = Metadata.hash(this) + + private def get[T](key: String): T = { + map(key).asInstanceOf[T] + } + + private[sql] def jsonValue: JValue = Metadata.toJsonValue(this) +} + +object Metadata { + + /** Returns an empty Metadata. */ + def empty: Metadata = new Metadata(Map.empty) + + /** Creates a Metadata instance from JSON. */ + def fromJson(json: String): Metadata = { + fromJObject(parse(json).asInstanceOf[JObject]) + } + + /** Creates a Metadata instance from JSON AST. */ + private[sql] def fromJObject(jObj: JObject): Metadata = { + val builder = new MetadataBuilder + jObj.obj.foreach { + case (key, JInt(value)) => + builder.putLong(key, value.toLong) + case (key, JDouble(value)) => + builder.putDouble(key, value) + case (key, JBool(value)) => + builder.putBoolean(key, value) + case (key, JString(value)) => + builder.putString(key, value) + case (key, o: JObject) => + builder.putMetadata(key, fromJObject(o)) + case (key, JArray(value)) => + if (value.isEmpty) { + // If it is an empty array, we cannot infer its element type. We put an empty Array[Long].
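+ // A value stored with, say, putStringArray(key, Array.empty) therefore comes back + // from JSON as an empty Array[Long]; equality is unaffected, because empty arrays + // compare equal in `equals` and hash identically regardless of element type.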
+ builder.putLongArray(key, Array.empty) + } else { + value.head match { + case _: JInt => + builder.putLongArray(key, value.asInstanceOf[List[JInt]].map(_.num.toLong).toArray) + case _: JDouble => + builder.putDoubleArray(key, value.asInstanceOf[List[JDouble]].map(_.num).toArray) + case _: JBool => + builder.putBooleanArray(key, value.asInstanceOf[List[JBool]].map(_.value).toArray) + case _: JString => + builder.putStringArray(key, value.asInstanceOf[List[JString]].map(_.s).toArray) + case _: JObject => + builder.putMetadataArray( + key, value.asInstanceOf[List[JObject]].map(fromJObject).toArray) + case other => + throw new RuntimeException(s"Unsupported array element type ${other.getClass}.") + } + } + case other => + throw new RuntimeException(s"Unsupported type ${other.getClass}.") + } + builder.build() + } + + /** Converts to JSON AST. */ + private def toJsonValue(obj: Any): JValue = { + obj match { + case map: Map[_, _] => + val fields = map.toList.map { case (k: String, v) => (k, toJsonValue(v)) } + JObject(fields) + case arr: Array[_] => + val values = arr.toList.map(toJsonValue) + JArray(values) + case x: Long => + JInt(x) + case x: Double => + JDouble(x) + case x: Boolean => + JBool(x) + case x: String => + JString(x) + case x: Metadata => + toJsonValue(x.map) + case other => + throw new RuntimeException(s"Unsupported type ${other.getClass}.") + } + } + + /** Computes the hash code for the types we support. */ + private def hash(obj: Any): Int = { + obj match { + case map: Map[_, _] => + map.mapValues(hash).## + case arr: Array[_] => + // Seq.empty[T] has the same hashCode regardless of T. + arr.toSeq.map(hash).## + case x: Long => + x.## + case x: Double => + x.## + case x: Boolean => + x.## + case x: String => + x.## + case x: Metadata => + hash(x.map) + case other => + throw new RuntimeException(s"Unsupported type ${other.getClass}.") + } + } +} + +/** + * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. + */ +class MetadataBuilder { + + private val map: mutable.Map[String, Any] = mutable.Map.empty + + /** Returns the immutable version of this map. Used for Java interop. */ + protected def getMap = map.toMap + + /** Includes the content of an existing [[Metadata]] instance. */ + def withMetadata(metadata: Metadata): this.type = { + map ++= metadata.map + this + } + + /** Puts a Long. */ + def putLong(key: String, value: Long): this.type = put(key, value) + + /** Puts a Double. */ + def putDouble(key: String, value: Double): this.type = put(key, value) + + /** Puts a Boolean. */ + def putBoolean(key: String, value: Boolean): this.type = put(key, value) + + /** Puts a String. */ + def putString(key: String, value: String): this.type = put(key, value) + + /** Puts a [[Metadata]]. */ + def putMetadata(key: String, value: Metadata): this.type = put(key, value) + + /** Puts a Long array. */ + def putLongArray(key: String, value: Array[Long]): this.type = put(key, value) + + /** Puts a Double array. */ + def putDoubleArray(key: String, value: Array[Double]): this.type = put(key, value) + + /** Puts a Boolean array. */ + def putBooleanArray(key: String, value: Array[Boolean]): this.type = put(key, value) + + /** Puts a String array. */ + def putStringArray(key: String, value: Array[String]): this.type = put(key, value) + + /** Puts a [[Metadata]] array. */ + def putMetadataArray(key: String, value: Array[Metadata]): this.type = put(key, value) + + /** Builds the [[Metadata]] instance.
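+ * + * A usage sketch (the key and value are illustrative): + * {{{ + * val metadata = new MetadataBuilder().putString("doc", "a comment").build() + * metadata.getString("doc") // returns "a comment" + * }}}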
*/ + def build(): Metadata = { + new Metadata(map.toMap) + } + + private def put(key: String, value: Any): this.type = { + map.put(key, value) + this + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala new file mode 100644 index 0000000000000..0063d31666c85 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/MetadataSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.json4s.jackson.JsonMethods.parse + import org.scalatest.FunSuite + +class MetadataSuite extends FunSuite { + + val baseMetadata = new MetadataBuilder() + .putString("purpose", "ml") + .putBoolean("isBase", true) + .build() + + val summary = new MetadataBuilder() + .putLong("numFeatures", 10L) + .build() + + val age = new MetadataBuilder() + .putString("name", "age") + .putLong("index", 1L) + .putBoolean("categorical", false) + .putDouble("average", 45.0) + .build() + + val gender = new MetadataBuilder() + .putString("name", "gender") + .putLong("index", 5) + .putBoolean("categorical", true) + .putStringArray("categories", Array("male", "female")) + .build() + + val metadata = new MetadataBuilder() + .withMetadata(baseMetadata) + .putBoolean("isBase", false) // overwrite an existing key + .putMetadata("summary", summary) + .putLongArray("long[]", Array(0L, 1L)) + .putDoubleArray("double[]", Array(3.0, 4.0)) + .putBooleanArray("boolean[]", Array(true, false)) + .putMetadataArray("features", Array(age, gender)) + .build() + + test("metadata builder and getters") { + assert(age.getLong("index") === 1L) + assert(age.getDouble("average") === 45.0) + assert(age.getBoolean("categorical") === false) + assert(age.getString("name") === "age") + assert(metadata.getString("purpose") === "ml") + assert(metadata.getBoolean("isBase") === false) + assert(metadata.getMetadata("summary") === summary) + assert(metadata.getLongArray("long[]").toSeq === Seq(0L, 1L)) + assert(metadata.getDoubleArray("double[]").toSeq === Seq(3.0, 4.0)) + assert(metadata.getBooleanArray("boolean[]").toSeq === Seq(true, false)) + assert(gender.getStringArray("categories").toSeq === Seq("male", "female")) + assert(metadata.getMetadataArray("features").toSeq === Seq(age, gender)) + } + + test("metadata json conversion") { + val json = metadata.json + withClue("the json method must produce a valid JSON string") { + parse(json) + } + val parsed = Metadata.fromJson(json) + assert(parsed === metadata) + assert(parsed.## === metadata.##) + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java
b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java index 37e88d72b9172..0c85cdc0aa640 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/DataType.java @@ -17,9 +17,7 @@ package org.apache.spark.sql.api.java; -import java.util.HashSet; -import java.util.List; -import java.util.Set; +import java.util.*; /** * The base type of all Spark SQL data types. @@ -151,15 +149,31 @@ public static MapType createMapType( * Creates a StructField by specifying the name ({@code name}), data type ({@code dataType}), * whether values of this field can be null values ({@code nullable}), and the metadata * ({@code metadata}). */ - public static StructField createStructField(String name, DataType dataType, boolean nullable) { + public static StructField createStructField( + String name, + DataType dataType, + boolean nullable, + Metadata metadata) { if (name == null) { throw new IllegalArgumentException("name should not be null."); } if (dataType == null) { throw new IllegalArgumentException("dataType should not be null."); } + if (metadata == null) { + throw new IllegalArgumentException("metadata should not be null."); + } + + return new StructField(name, dataType, nullable, metadata); + } - return new StructField(name, dataType, nullable); + /** + * Creates a StructField with empty metadata. + * + * @see #createStructField(String, DataType, boolean, Metadata) + */ + public static StructField createStructField(String name, DataType dataType, boolean nullable) { + return createStructField(name, dataType, nullable, (new MetadataBuilder()).build()); } /** @@ -191,5 +205,4 @@ public static StructType createStructType(StructField[] fields) { return new StructType(fields); } - } diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java new file mode 100644 index 0000000000000..0f819fb01a76a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/Metadata.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, + * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and + * Array[Metadata]. JSON is used for serialization. + * + * The constructor is package-private. Users should use {@link MetadataBuilder} to create instances.
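+ * + * A usage sketch from Java (the key and value are illustrative): + * <pre> + * MetadataBuilder builder = new MetadataBuilder(); + * builder.putString("doc", "first name"); + * Metadata metadata = builder.build(); + * </pre>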
+ */ +class Metadata extends org.apache.spark.sql.catalyst.util.Metadata { + Metadata(scala.collection.immutable.Map map) { + super(map); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java new file mode 100644 index 0000000000000..6e6b12f0722c5 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/MetadataBuilder.java @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.api.java; + +/** + * Builder for {@link Metadata}. If there is a key collision, the latter will overwrite the former. + */ +public class MetadataBuilder extends org.apache.spark.sql.catalyst.util.MetadataBuilder { + @Override + public Metadata build() { + return new Metadata(getMap()); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java index b48e2a2c5f953..7c60d492bcdf0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java +++ b/sql/core/src/main/java/org/apache/spark/sql/api/java/StructField.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.api.java; +import java.util.Map; + /** * A StructField object represents a field in a StructType object. * A StructField object comprises four fields, {@code String name}, {@code DataType dataType}, @@ -24,20 +26,27 @@ * The field of {@code dataType} specifies the data type of a StructField. * The field of {@code nullable} specifies if values of a StructField can contain {@code null} * values. + * The field of {@code metadata} provides extra information about the StructField. * * To create a {@link StructField}, - * {@link DataType#createStructField(String, DataType, boolean)} + * {@link DataType#createStructField(String, DataType, boolean, Metadata)} + * should be used.
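+ * + * For example (a sketch; the field name, type and metadata key are illustrative): + * <pre> + * MetadataBuilder builder = new MetadataBuilder(); + * builder.putString("doc", "age in years"); + * StructField field = + * DataType.createStructField("age", DataType.IntegerType, true, builder.build()); + * </pre>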
*/ public class StructField { private String name; private DataType dataType; private boolean nullable; + private Metadata metadata; - protected StructField(String name, DataType dataType, boolean nullable) { + protected StructField( + String name, + DataType dataType, + boolean nullable, + Metadata metadata) { this.name = name; this.dataType = dataType; this.nullable = nullable; + this.metadata = metadata; } public String getName() { @@ -52,6 +61,10 @@ public boolean isNullable() { return nullable; } + public Metadata getMetadata() { + return metadata; + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ -62,6 +75,7 @@ public boolean equals(Object o) { if (nullable != that.nullable) return false; if (!dataType.equals(that.dataType)) return false; if (!name.equals(that.name)) return false; + if (!metadata.equals(that.metadata)) return false; return true; } @@ -71,6 +85,7 @@ public int hashCode() { int result = name.hashCode(); result = 31 * result + dataType.hashCode(); result = 31 * result + (nullable ? 1 : 0); + result = 31 * result + metadata.hashCode(); return result; } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index a41a500c9a5d0..4953f8399a96b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{Optimizer, DefaultOptimizer} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.RuleExecutor -import org.apache.spark.sql.catalyst.types.DataType import org.apache.spark.sql.execution.{SparkStrategies, _} import org.apache.spark.sql.json._ import org.apache.spark.sql.parquet.ParquetRelation diff --git a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala index 047dc85df6c1d..eabe312f92371 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/json/JsonRDD.scala @@ -117,10 +117,7 @@ private[sql] object JsonRDD extends Logging { } }.flatMap(field => field).toSeq - StructType( - (topLevelFields ++ structFields).sortBy { - case StructField(name, _, _) => name - }) + StructType((topLevelFields ++ structFields).sortBy(_.name)) } makeStruct(resolved.keySet.toSeq, Nil) @@ -128,7 +125,7 @@ private[sql] object JsonRDD extends Logging { private[sql] def nullTypeToStringType(struct: StructType): StructType = { val fields = struct.fields.map { - case StructField(fieldName, dataType, nullable) => { + case StructField(fieldName, dataType, nullable, _) => { val newType = dataType match { case NullType => StringType case ArrayType(NullType, containsNull) => ArrayType(StringType, containsNull) @@ -163,9 +160,7 @@ private[sql] object JsonRDD extends Logging { StructField(name, dataType, true) } } - StructType(newFields.toSeq.sortBy { - case StructField(name, _, _) => name - }) + StructType(newFields.toSeq.sortBy(_.name)) } case (ArrayType(elementType1, containsNull1), ArrayType(elementType2, containsNull2)) => ArrayType(compatibleType(elementType1, elementType2), containsNull1 || containsNull2) @@ -413,7 +408,7 @@ private[sql] object JsonRDD extends Logging { // TODO: Reuse the row instead of creating a new one for every record. 
val row = new GenericMutableRow(schema.fields.length) schema.fields.zipWithIndex.foreach { - case (StructField(name, dataType, _), i) => + case (StructField(name, dataType, _, _), i) => row.update(i, json.get(name).flatMap(v => Option(v)).map( enforceCorrectType(_, dataType)).orNull) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/package.scala index e98d151286818..f0e57e2a7447b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/package.scala @@ -125,6 +125,9 @@ package object sql { @DeveloperApi type DataType = catalyst.types.DataType + @DeveloperApi + val DataType = catalyst.types.DataType + /** * :: DeveloperApi :: * @@ -414,4 +417,24 @@ package object sql { */ @DeveloperApi val StructField = catalyst.types.StructField + + /** + * :: DeveloperApi :: + * + * Metadata is a wrapper over Map[String, Any] that limits the value type to simple ones: Boolean, + * Long, Double, String, Metadata, Array[Boolean], Array[Long], Array[Double], Array[String], and + * Array[Metadata]. JSON is used for serialization. + * + * The default constructor is private. Users should use either [[MetadataBuilder]] or + * [[Metadata$#fromJson]] to create Metadata instances. + * + * @param map an immutable map that stores the data + */ + @DeveloperApi + type Metadata = catalyst.util.Metadata + + /** + * Builder for [[Metadata]]. If there is a key collision, the latter will overwrite the former. + */ + type MetadataBuilder = catalyst.util.MetadataBuilder } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala index 609f7db562a31..142598c904b37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/types/util/DataTypeConversions.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.types.util import org.apache.spark.sql._ -import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField} +import org.apache.spark.sql.api.java.{DataType => JDataType, StructField => JStructField, MetadataBuilder => JMetaDataBuilder} import scala.collection.JavaConverters._ @@ -31,7 +31,8 @@ protected[sql] object DataTypeConversions { JDataType.createStructField( scalaStructField.name, asJavaDataType(scalaStructField.dataType), - scalaStructField.nullable) + scalaStructField.nullable, + (new JMetaDataBuilder).withMetadata(scalaStructField.metadata).build()) } /** @@ -68,7 +69,8 @@ protected[sql] object DataTypeConversions { StructField( javaStructField.getName, asScalaDataType(javaStructField.getDataType), - javaStructField.isNullable) + javaStructField.isNullable, + javaStructField.getMetadata) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala index 100ecb45e9e88..6c9db639c0f6c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataTypeSuite.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql import org.scalatest.FunSuite -import org.apache.spark.sql.catalyst.types.DataType - class DataTypeSuite extends FunSuite { test("construct an ArrayType") { @@ -79,8 +77,12 @@ class DataTypeSuite extends FunSuite { checkDataTypeJsonRepr(ArrayType(StringType, false))
checkDataTypeJsonRepr(MapType(IntegerType, StringType, true)) checkDataTypeJsonRepr(MapType(IntegerType, ArrayType(DoubleType), false)) + val metadata = new MetadataBuilder() + .putString("name", "age") + .build() checkDataTypeJsonRepr( StructType(Seq( StructField("a", IntegerType, nullable = true), - StructField("b", ArrayType(DoubleType), nullable = false)))) + StructField("b", ArrayType(DoubleType), nullable = false), + StructField("c", DoubleType, nullable = false, metadata)))) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 4acd92d33d180..6befe1b755cc6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -17,17 +17,16 @@ package org.apache.spark.sql +import java.util.TimeZone + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.sql.TestData._ import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.joins.BroadcastHashJoin -import org.apache.spark.sql.test._ -import org.scalatest.BeforeAndAfterAll -import java.util.TimeZone -/* Implicits */ -import TestSQLContext._ -import TestData._ +import org.apache.spark.sql.test.TestSQLContext._ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { // Make sure the tables are loaded. @@ -697,6 +696,30 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll { ("true", "false") :: Nil) } + test("metadata is propagated correctly") { + val person = sql("SELECT * FROM person") + val schema = person.schema + val docKey = "doc" + val docValue = "first name" + val metadata = new MetadataBuilder() + .putString(docKey, docValue) + .build() + val schemaWithMeta = new StructType(Seq( + schema("id"), schema("name").copy(metadata = metadata), schema("age"))) + val personWithMeta = applySchema(person, schemaWithMeta) + def validateMetadata(rdd: SchemaRDD): Unit = { + assert(rdd.schema("name").metadata.getString(docKey) == docValue) + } + personWithMeta.registerTempTable("personWithMeta") + validateMetadata(personWithMeta.select('name)) + validateMetadata(personWithMeta.select("name".attr)) + validateMetadata(personWithMeta.select('id, 'name)) + validateMetadata(sql("SELECT * FROM personWithMeta")) + validateMetadata(sql("SELECT id, name FROM personWithMeta")) + validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId")) + validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId")) + } + test("SPARK-3371 Renaming a function expression with group by gives error") { registerFunction("len", (s: String) => s.length) checkAnswer( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index c4dd3e860f5fd..836dd17fcc3a2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -166,4 +166,15 @@ object TestData { // An RDD with 4 elements and 8 partitions val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8) withEmptyParts.registerTempTable("withEmptyParts") + + case class Person(id: Int, name: String, age: Int) + case class Salary(personId: Int, salary: Double) + val person = TestSQLContext.sparkContext.parallelize( + Person(0, 
"mike", 30) :: + Person(1, "jim", 20) :: Nil) + person.registerTempTable("person") + val salary = TestSQLContext.sparkContext.parallelize( + Salary(0, 2000.0) :: + Salary(1, 1000.0) :: Nil) + salary.registerTempTable("salary") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala index 8415af41be3af..e0e0ff9cb3d3d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/api/java/ScalaSideDataTypeConversionSuite.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.api.java -import org.apache.spark.sql.types.util.DataTypeConversions import org.scalatest.FunSuite -import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField} -import org.apache.spark.sql.{StructType => SStructType} -import DataTypeConversions._ +import org.apache.spark.sql.{DataType => SDataType, StructField => SStructField, StructType => SStructType} +import org.apache.spark.sql.types.util.DataTypeConversions._ class ScalaSideDataTypeConversionSuite extends FunSuite { @@ -67,11 +65,15 @@ class ScalaSideDataTypeConversionSuite extends FunSuite { checkDataType(simpleScalaStructType) // Complex StructType. + val metadata = new MetadataBuilder() + .putString("name", "age") + .build() val complexScalaStructType = SStructType( SStructField("simpleArray", simpleScalaArrayType, true) :: SStructField("simpleMap", simpleScalaMapType, true) :: SStructField("simpleStruct", simpleScalaStructType, true) :: - SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: Nil) + SStructField("boolean", org.apache.spark.sql.BooleanType, false) :: + SStructField("withMeta", org.apache.spark.sql.DoubleType, false, metadata) :: Nil) checkDataType(complexScalaStructType) // Complex ArrayType.