Commit

Support map types in java beans
Punya Biswal committed Apr 21, 2015
1 parent 327ebf0 commit 5e00685
Showing 3 changed files with 132 additions and 48 deletions.
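
This change adds a TypeToken-based JavaTypeInference helper so that bean properties of java.util.Map type are inferred as Catalyst MapTypes. As a rough illustration of the user-facing effect (not part of the commit; the PersonBean class and the sqlContext and beans values below are assumed names), a bean with a Map-typed property can now be turned into a DataFrame directly:

import scala.beans.BeanProperty

// Hypothetical bean-style class; @BeanProperty generates the getters and
// setters that java.beans.Introspector discovers during schema inference.
class PersonBean extends java.io.Serializable {
  @BeanProperty var name: String = _
  @BeanProperty var scores: java.util.Map[String, Array[Int]] = _
}

// Assuming an existing SQLContext `sqlContext` and an RDD[PersonBean] `beans`:
val df = sqlContext.createDataFrame(beans, classOf[PersonBean])
println(df.schema("scores"))
// StructField("scores",
//   MapType(StringType, ArrayType(IntegerType, containsNull = false),
//           valueContainsNull = true),
//   nullable = true)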
106 changes: 106 additions & 0 deletions sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import java.beans.Introspector
import java.lang.{Iterable => JIterable}
import java.util.{Iterator => JIterator, Map => JMap}

import com.google.common.reflect.TypeToken

import org.apache.spark.sql.types._

import scala.language.existentials

/**
* Type-inference utilities for POJOs and Java collections.
*/
private [sql] object JavaTypeInference {

private val iterableType = TypeToken.of(classOf[JIterable[_]])
private val mapType = TypeToken.of(classOf[JMap[_, _]])
private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
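// These generic return types are resolved against concrete tokens (via
// resolveType, in inferDataType and elementType below) to recover the actual
// element, key, and value types of a parameterized Iterable or Map.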

/**
* Infers the corresponding SQL data type of a Java type.
* @param typeToken Java type
* @return (SQL data type, nullable)
*/
private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
typeToken.getRawType match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)

case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

case _ if typeToken.isArray =>
val (dataType, nullable) = inferDataType(typeToken.getComponentType)
(ArrayType(dataType, nullable), true)

case _ if mapType.isAssignableFrom(typeToken) =>
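// View the token as its java.util.Map supertype, then resolve the generic
// return types of keySet() and values() to recover the key and value tokens.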
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
val (keyDataType, _) = inferDataType(keyType)
val (valueDataType, nullable) = inferDataType(valueType)
(MapType(keyDataType, valueDataType, nullable), true)

case _ =>
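// Otherwise treat the type as a Java bean: every readable property except
// getClass becomes a field of the resulting StructType.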
val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
val fields = properties.map { property =>
val returnType = typeToken.method(property.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(returnType)
new StructField(property.getName, dataType, nullable)
}
(new StructType(fields), true)
}
}

private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
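// Resolves Iterable.iterator().next() against the concrete token to obtain
// the element type of any Iterable (e.g. the collection views of a Map).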
val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
val itemType = iteratorType.resolveType(nextReturnType)
itemType
}
}
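
Because inference now operates on Guava TypeTokens rather than raw Classes, the generic parameters of a Map survive erasure and can be inspected. A minimal sketch of the expected result for a java.util.Map<String, int[]> token (assuming code compiled inside the org.apache.spark.sql package, since the object is private[sql]):

import com.google.common.reflect.TypeToken

// Anonymous subclass so the TypeToken captures the full parameterized type.
val token = new TypeToken[java.util.Map[String, Array[Int]]]() {}

val (dataType, nullable) = JavaTypeInference.inferDataType(token)
// dataType == MapType(StringType, ArrayType(IntegerType, containsNull = false),
//                     valueContainsNull = true)
// nullable == true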
52 changes: 5 additions & 47 deletions sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,6 +25,8 @@ import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag

import com.google.common.reflect.TypeToken

import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
@@ -1222,56 +1224,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Returns a Catalyst Schema for the given java bean class.
*/
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
val (dataType, _) = inferDataType(beanClass)
val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass))
dataType.asInstanceOf[StructType].fields.map { f =>
AttributeReference(f.name, f.dataType, f.nullable)()
}
}

/**
* Infers the corresponding SQL data type of a Java class.
* @param clazz Java class
* @return (SQL data type, nullable)
*/
private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
clazz match {
case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
(c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)

case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)

case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)

case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)

case c: Class[_] if c.isArray =>
val (dataType, nullable) = inferDataType(c.getComponentType)
(ArrayType(dataType, nullable), true)

case _ =>
val beanInfo = Introspector.getBeanInfo(clazz)
val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
val fields = properties.map { property =>
val (dataType, nullable) = inferDataType(property.getPropertyType)
new StructField(property.getName, dataType, nullable)
}
(new StructType(fields), true)
}
}
}


@@ -19,7 +19,11 @@

import java.io.Serializable;
import java.util.Arrays;
import java.util.Map;

import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
import scala.collection.JavaConversions;
import scala.collection.Seq;

import org.junit.After;
@@ -34,6 +38,7 @@
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.*;
import scala.collection.mutable.Buffer;

import static org.apache.spark.sql.functions.*;

@@ -106,6 +111,7 @@ public void testShow() {
public static class Bean implements Serializable {
private double a = 0.0;
private Integer[] b = new Integer[]{0, 1};
private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });

public double getA() {
return a;
@@ -114,6 +120,10 @@ public double getA() {
public Integer[] getB() {
return b;
}

public Map<String, int[]> getC() {
return c;
}
}

@Test
Expand All @@ -127,7 +137,12 @@ public void testCreateDataFrameFromJavaBeans() {
Assert.assertEquals(
new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
schema.apply("b"));
Row first = df.select("a", "b").first();
ArrayType valueType = new ArrayType(DataTypes.IntegerType, false);
MapType mapType = new MapType(DataTypes.StringType, valueType, true);
Assert.assertEquals(
new StructField("c", mapType, true, Metadata.empty()),
schema.apply("c"));
Row first = df.select("a", "b", "c").first();
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
// Now Java lists and maps are converted to Scala Seqs and Maps. Once we get a Seq below,
// verify that it has the expected length, and contains the expected elements.
@@ -136,5 +151,10 @@
for (int i = 0; i < result.length(); i++) {
Assert.assertEquals(bean.getB()[i], result.apply(i));
}
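// The map value (a Java int[]) comes back as a Scala Buffer after conversion,
// so convert it back to an int[] before comparing with the bean's original array.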
Buffer<Integer> outputBuffer = (Buffer<Integer>) first.getJavaMap(2).get("hello");
Assert.assertArrayEquals(
bean.getC().get("hello"),
Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
}

}
