diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index bb546b3086b33..748883afb1ad3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -24,10 +24,12 @@
 import javax.annotation.Nullable;
 import java.math.BigDecimal;
 import java.sql.Date;
+import java.sql.Timestamp;
 import java.util.*;
 
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DateUtils;
 import static org.apache.spark.sql.types.DataTypes.*;
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.types.UTF8String;
@@ -107,7 +109,9 @@ public static int calculateBitSetWidthInBytes(int numFields) {
     // We support get() on a superset of the types for which we support set():
     final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
       Arrays.asList(new DataType[]{
-        StringType
+        StringType,
+        DateType,
+        TimestampType
       }));
     _readableFieldTypes.addAll(settableFieldTypes);
     readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
@@ -222,6 +226,16 @@ public void setString(int ordinal, String value) {
     throw new UnsupportedOperationException();
   }
 
+  @Override
+  public void setDate(int ordinal, Date value) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public void setTimestamp(int ordinal, Timestamp value) {
+    throw new UnsupportedOperationException();
+  }
+
   @Override
   public int size() {
     return numFields;
@@ -346,7 +360,23 @@ public BigDecimal getDecimal(int i) {
 
   @Override
   public Date getDate(int i) {
-    throw new UnsupportedOperationException();
+    final int daysSinceEpoch = getInt(i);
+    return DateUtils.toJavaDate(daysSinceEpoch);
+  }
+
+  @Override
+  public Timestamp getTimestamp(int i) {
+    assertIndexIsValid(i);
+    // The fixed-length slot holds the offset of the timestamp data within this row:
+    // an 8-byte millisecond value followed by a 4-byte nanos value.
+    final long offsetToTimestamp = getLong(i);
+    final long time =
+      PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + offsetToTimestamp);
+    final int nanos =
+      PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + offsetToTimestamp + 8);
+    final Timestamp timestamp = new Timestamp(time);
+    timestamp.setNanos(nanos);
+    return timestamp;
   }
 
   @Override
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
index 4190b7ffe1c8f..d3fa9e9514dd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
@@ -17,10 +17,12 @@
 package org.apache.spark.sql
 
+import java.sql.{Date, Timestamp}
+
 import scala.util.hashing.MurmurHash3
 
 import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DateUtils, StructType}
 
 object Row {
   /**
@@ -257,8 +259,14 @@ trait Row extends Serializable {
    *
    * @throws ClassCastException when data type does not match.
    */
-  // TODO(davies): This is not the right default implementation, we use Int as Date internally
-  def getDate(i: Int): java.sql.Date = DateUtils.toJavaDate(getInt(i))
+  def getDate(i: Int): java.sql.Date = DateUtils.toJavaDate(getInt(i))
+
+  /**
+   * Returns the value at position i of timestamp type as java.sql.Timestamp.
+   *
+   * @throws ClassCastException when data type does not match.
+   */
+  def getTimestamp(i: Int): java.sql.Timestamp = apply(i).asInstanceOf[java.sql.Timestamp]
 
   /**
    * Returns the value at position i of array type as a Scala Seq.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
index aa4099e4d7bf9..ed8bfd6422dae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala
@@ -17,6 +17,8 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.{Date, Timestamp}
+
 import org.apache.spark.sql.types._
 
 /**
@@ -241,6 +243,11 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR
 
   override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value))
 
+  override def setDate(ordinal: Int, value: java.sql.Date): Unit =
+    setInt(ordinal, DateUtils.fromJavaDate(value))
+
+  override def setTimestamp(ordinal: Int, value: java.sql.Timestamp): Unit = update(ordinal, value)
+
   override def getString(ordinal: Int): String = apply(ordinal).toString
 
   override def setInt(ordinal: Int, value: Int): Unit = {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
index 5b2c8572784bd..7e8ba1c85d7ab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverter.scala
@@ -119,6 +119,8 @@
       case FloatType => FloatUnsafeColumnWriter
       case DoubleType => DoubleUnsafeColumnWriter
       case StringType => StringUnsafeColumnWriter
+      case DateType => DateUnsafeColumnWriter
+      case TimestampType => TimestampUnsafeColumnWriter
       case t =>
         throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
   }
@@ -136,6 +138,8 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
 private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
 private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
 private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
+private object DateUnsafeColumnWriter extends DateUnsafeColumnWriter
+private object TimestampUnsafeColumnWriter extends TimestampUnsafeColumnWriter
 
 private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
   // Primitives don't write to the variable-length region:
@@ -221,3 +225,42 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
     8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
   }
 }
+
+private class DateUnsafeColumnWriter private() extends UnsafeColumnWriter {
+  def getSize(source: Row, column: Int): Int = {
+    // Dates are stored as a day count in the fixed-length region, so no extra space is needed.
+    0
+  }
+
+  override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+    target.setInt(column, DateUtils.fromJavaDate(source.getDate(column)))
+    0
+  }
+}
+
+private class TimestampUnsafeColumnWriter private() extends UnsafeColumnWriter {
+  def getSize(source: Row, column: Int): Int = {
+    // Although a timestamp is fixed-length, its 12 bytes (8-byte millis plus 4-byte nanos) do not
+    // fit in the single 8-byte word per field, so it is stored in the variable-length region.
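+    // Layout of the 16 bytes reserved there (offsets relative to appendCursor):
+    //   bytes 0..7   -> millisecond value (long)
+    //   bytes 8..11  -> nanosecond value (int)
+    //   bytes 12..15 -> padding that keeps the region 8-byte word aligned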
+    16
+  }
+
+  override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
+    val value = source.get(column).asInstanceOf[java.sql.Timestamp]
+    val time = value.getTime()
+    val nanos = value.getNanos()
+
+    val baseObject = target.getBaseObject
+    val baseOffset = target.getBaseOffset
+
+    PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, time)
+    PlatformDependent.UNSAFE.putInt(baseObject, baseOffset + appendCursor + 8, nanos)
+    // Record the offset of the timestamp data in the field's fixed-length slot.
+    target.setLong(column, appendCursor)
+    16
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
index 5fd892c42e69c..392b5feaf5927 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala
@@ -17,7 +17,9 @@
 package org.apache.spark.sql.catalyst.expressions
 
-import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType}
+import java.sql.{Date, Timestamp}
+
+import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType, DateUtils}
 
 /**
  * An extended interface to [[Row]] that allows the values for each column to be updated. Setting
@@ -36,7 +38,9 @@ trait MutableRow extends Row {
   def setByte(ordinal: Int, value: Byte)
   def setFloat(ordinal: Int, value: Float)
   def setString(ordinal: Int, value: String)
-  // TODO(davies): add setDate() and setDecimal()
+  def setDate(ordinal: Int, value: java.sql.Date)
+  def setTimestamp(ordinal: Int, value: java.sql.Timestamp)
+  // TODO(davies): add setDecimal()
 }
 
 /**
@@ -55,6 +59,8 @@ object EmptyRow extends Row {
   override def getShort(i: Int): Short = throw new UnsupportedOperationException
   override def getByte(i: Int): Byte = throw new UnsupportedOperationException
   override def getString(i: Int): String = throw new UnsupportedOperationException
+  override def getDate(i: Int): java.sql.Date = throw new UnsupportedOperationException
+  override def getTimestamp(i: Int): java.sql.Timestamp = throw new UnsupportedOperationException
   override def getAs[T](i: Int): T = throw new UnsupportedOperationException
   override def copy(): Row = this
 }
@@ -121,7 +127,21 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
     }
   }
 
-  // TODO(davies): add getDate and getDecimal
+  override def getDate(i: Int): java.sql.Date = {
+    values(i) match {
+      case null => null
+      case d: Int => DateUtils.toJavaDate(d)
+    }
+  }
+
+  override def getTimestamp(i: Int): java.sql.Timestamp = {
+    values(i) match {
+      case null => null
+      case t: java.sql.Timestamp => t
+    }
+  }
+
+  // TODO(davies): add getDecimal
 
   // Custom hashCode function that matches the efficient code generated version.
   override def hashCode: Int = {
@@ -197,7 +217,16 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
   override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
   override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
   override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
-  override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)}
+  override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = UTF8String(value) }
+
+  override def setDate(ordinal: Int, value: java.sql.Date): Unit = {
+    values(ordinal) = DateUtils.fromJavaDate(value)
+  }
+
+  override def setTimestamp(ordinal: Int, value: java.sql.Timestamp): Unit = {
+    values(ordinal) = value
+  }
+
   override def setNullAt(i: Int): Unit = { values(i) = null }
 
   override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
index 3a60c7fd32675..3f2de5c3a7be8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/UnsafeRowConverterSuite.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.{Date, Timestamp}
 import java.util.Arrays
 
 import org.scalatest.{FunSuite, Matchers}
@@ -72,7 +73,34 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
     unsafeRow.getString(1) should be ("Hello")
     unsafeRow.getString(2) should be ("World")
   }
+
+  test("basic conversion with primitive, string, date and timestamp types") {
+    val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType)
+    val converter = new UnsafeRowConverter(fieldTypes)
+
+    val row = new SpecificMutableRow(fieldTypes)
+    row.setLong(0, 0)
+    row.setString(1, "Hello")
+    row.setDate(2, Date.valueOf("1970-01-01"))
+    row.setTimestamp(3, Timestamp.valueOf("2015-05-08 08:10:25"))
+
+    val sizeRequired: Int = converter.getSizeRequirement(row)
+    // = null bitset + 4 fixed-length field slots + rounded "Hello" bytes + 16-byte timestamp
+    sizeRequired should be (8 + (8 * 4) +
+      ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
+      16)
+    val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
+    val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
+    numBytesWritten should be (sizeRequired)
+    val unsafeRow = new UnsafeRow()
+    unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
+    unsafeRow.getLong(0) should be (0)
+    unsafeRow.getString(1) should be ("Hello")
+    unsafeRow.getDate(2) should be (Date.valueOf("1970-01-01"))
+    unsafeRow.getTimestamp(3) should be (Timestamp.valueOf("2015-05-08 08:10:25"))
+  }
+
   test("null handling") {
     val fieldTypes: Array[DataType] = Array(
       NullType,
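
A minimal, self-contained sketch of the 12-byte timestamp encoding used above (hypothetical demo
code, not part of the patch): TimestampUnsafeColumnWriter.write stores an 8-byte millisecond value
followed by a 4-byte nanos value, and UnsafeRow.getTimestamp reads them back. A plain
java.nio.ByteBuffer stands in for the Unsafe calls here, so the byte order may differ from the
native layout, but the round-trip logic is the same:

  import java.nio.ByteBuffer
  import java.sql.Timestamp

  object TimestampLayoutDemo extends App {
    val ts = Timestamp.valueOf("2015-05-08 08:10:25")

    // Write: 8-byte millis, then 4-byte nanos, in a 16-byte (word-aligned) slot.
    val buf = ByteBuffer.allocate(16)
    buf.putLong(ts.getTime)
    buf.putInt(ts.getNanos)

    // Read it back the way UnsafeRow.getTimestamp does.
    buf.flip()
    val decoded = new Timestamp(buf.getLong)
    decoded.setNanos(buf.getInt)

    assert(decoded == ts)  // the encoding round-trips exactly
  }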