Skip to content

Commit

Permalink
Add date and timestamp support to UnsafeRow.
Browse files Browse the repository at this point in the history
  • Loading branch information
viirya committed May 7, 2015
1 parent 4f87e95 commit 4c07b57
Show file tree
Hide file tree
Showing 6 changed files with 145 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@
import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.*;

import org.apache.spark.sql.Row;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DateUtils;
import static org.apache.spark.sql.types.DataTypes.*;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.types.UTF8String;
Expand Down Expand Up @@ -107,7 +109,9 @@ public static int calculateBitSetWidthInBytes(int numFields) {
// We support get() on a superset of the types for which we support set():
final Set<DataType> _readableFieldTypes = new HashSet<DataType>(
Arrays.asList(new DataType[]{
StringType
StringType,
DateType,
TimestampType
}));
_readableFieldTypes.addAll(settableFieldTypes);
readableFieldTypes = Collections.unmodifiableSet(_readableFieldTypes);
Expand Down Expand Up @@ -222,6 +226,16 @@ public void setString(int ordinal, String value) {
throw new UnsupportedOperationException();
}

// UnsafeRow stores DateType values as primitive ints in its fixed-width region
// (see getDate below), so mutation through the boxed java.sql.Date type is not
// supported. NOTE(review): writers presumably convert with
// DateUtils.fromJavaDate and call setInt instead — see DateUnsafeColumnWriter.
@Override
public void setDate(int ordinal, Date value) {
throw new UnsupportedOperationException();
}

// Timestamps occupy a 16-byte record in the variable-length region (see
// getTimestamp below), which this fixed-width setter cannot allocate or
// relocate, so mutation through the boxed java.sql.Timestamp type is
// unsupported.
@Override
public void setTimestamp(int ordinal, Timestamp value) {
throw new UnsupportedOperationException();
}

@Override
public int size() {
return numFields;
Expand Down Expand Up @@ -346,7 +360,21 @@ public BigDecimal getDecimal(int i) {

/**
 * Returns the DateType value at ordinal {@code i} as a java.sql.Date.
 *
 * DateType values are stored in the fixed-width region as an int holding the
 * number of days since the Unix epoch; convert back on read.
 */
@Override
public Date getDate(int i) {
  // Bounds-check the ordinal for consistency with getTimestamp below.
  assertIndexIsValid(i);
  final int daysSinceEpoch = getInt(i);
  return DateUtils.toJavaDate(daysSinceEpoch);
}

/**
 * Returns the TimestampType value at ordinal {@code i} as a java.sql.Timestamp.
 *
 * The fixed-width slot for a timestamp column holds the offset (relative to
 * {@code baseOffset}) of a 16-byte record in the variable-length region: an
 * 8-byte millisecond value followed by a 4-byte nanosecond value.
 */
@Override
public Timestamp getTimestamp(int i) {
  assertIndexIsValid(i);
  final long recordOffset = getLong(i);
  final long millis =
    PlatformDependent.UNSAFE.getLong(baseObject, baseOffset + recordOffset);
  final int nanoPart =
    PlatformDependent.UNSAFE.getInt(baseObject, baseOffset + recordOffset + 8);
  final Timestamp result = new Timestamp(millis);
  result.setNanos(nanoPart);
  return result;
}

@Override
Expand Down
14 changes: 11 additions & 3 deletions sql/catalyst/src/main/scala/org/apache/spark/sql/Row.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@

package org.apache.spark.sql

import java.sql.{Date, Timestamp}

import scala.util.hashing.MurmurHash3

import org.apache.spark.sql.catalyst.expressions.GenericRow
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.{DateUtils, StructType}

object Row {
/**
Expand Down Expand Up @@ -257,8 +259,14 @@ trait Row extends Serializable {
*
* @throws ClassCastException when data type does not match.
*/
// TODO(davies): This is not the right default implementation, we use Int as Date internally
def getDate(i: Int): java.sql.Date = apply(i).asInstanceOf[java.sql.Date]
def getDate(i: Int): java.sql.Date = {
  // DateType values are stored internally as an Int counting days since the
  // Unix epoch; materialize a java.sql.Date on read.
  DateUtils.toJavaDate(getInt(i))
}

/**
 * Returns the value at position i of timestamp type as java.sql.Timestamp.
 *
 * @throws ClassCastException when data type does not match.
 */
def getTimestamp(i: Int): java.sql.Timestamp = apply(i).asInstanceOf[java.sql.Timestamp]

/**
* Returns the value at position i of array type as a Scala Seq.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}

import org.apache.spark.sql.types._

/**
Expand Down Expand Up @@ -241,6 +243,11 @@ final class SpecificMutableRow(val values: Array[MutableValue]) extends MutableR

override def setString(ordinal: Int, value: String): Unit = update(ordinal, UTF8String(value))

override def setDate(ordinal: Int, value: java.sql.Date): Unit = {
  // Keep the internal representation consistent with getDate: an Int holding
  // days since the Unix epoch.
  val daysSinceEpoch = DateUtils.fromJavaDate(value)
  setInt(ordinal, daysSinceEpoch)
}

// NOTE(review): unlike setDate, the Timestamp is stored boxed as-is (no
// conversion to an internal primitive representation) — confirm readers such
// as getTimestamp expect the boxed value.
override def setTimestamp(ordinal: Int, value: java.sql.Timestamp): Unit = update(ordinal, value)

override def getString(ordinal: Int): String = apply(ordinal).toString

override def setInt(ordinal: Int, value: Int): Unit = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,8 @@ private object UnsafeColumnWriter {
case FloatType => FloatUnsafeColumnWriter
case DoubleType => DoubleUnsafeColumnWriter
case StringType => StringUnsafeColumnWriter
case DateType => DateUnsafeColumnWriter
case TimestampType => TimestampUnsafeColumnWriter
case t =>
throw new UnsupportedOperationException(s"Do not know how to write columns of type $t")
}
Expand All @@ -136,6 +138,8 @@ private object LongUnsafeColumnWriter extends LongUnsafeColumnWriter
private object FloatUnsafeColumnWriter extends FloatUnsafeColumnWriter
private object DoubleUnsafeColumnWriter extends DoubleUnsafeColumnWriter
private object StringUnsafeColumnWriter extends StringUnsafeColumnWriter
private object DateUnsafeColumnWriter extends DateUnsafeColumnWriter
private object TimestampUnsafeColumnWriter extends TimestampUnsafeColumnWriter

private abstract class PrimitiveUnsafeColumnWriter extends UnsafeColumnWriter {
// Primitives don't write to the variable-length region:
Expand Down Expand Up @@ -221,3 +225,36 @@ private class StringUnsafeColumnWriter private() extends UnsafeColumnWriter {
8 + ByteArrayMethods.roundNumberOfBytesToNearestWord(numBytes)
}
}

private class DateUnsafeColumnWriter private() extends UnsafeColumnWriter {
  /** A date fits in the fixed-width slot, so no variable-length bytes are needed. */
  def getSize(source: Row, column: Int): Int = 0

  override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
    // Store the date in its internal days-since-epoch Int representation.
    val daysSinceEpoch = DateUtils.fromJavaDate(source.getDate(column))
    target.setInt(column, daysSinceEpoch)
    0
  }
}

private class TimestampUnsafeColumnWriter private() extends UnsafeColumnWriter {
  /**
   * A timestamp needs 12 bytes of payload (8-byte millis + 4-byte nanos), which
   * exceeds the 8-byte fixed-width word per field, so it is written to the
   * variable-length region padded out to a 16-byte word boundary.
   */
  def getSize(source: Row, column: Int): Int = 16

  override def write(source: Row, target: UnsafeRow, column: Int, appendCursor: Int): Int = {
    val timestamp = source.get(column).asInstanceOf[java.sql.Timestamp]
    val baseObject = target.getBaseObject
    val baseOffset = target.getBaseOffset

    // Lay out millis then nanos at the append cursor ...
    PlatformDependent.UNSAFE.putLong(baseObject, baseOffset + appendCursor, timestamp.getTime())
    PlatformDependent.UNSAFE.putInt(baseObject, baseOffset + appendCursor + 8, timestamp.getNanos())
    // ... and record the record's offset in the column's fixed-width slot, so
    // UnsafeRow.getTimestamp can find it.
    target.setLong(column, appendCursor)
    16
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType}
import java.sql.{Date, Timestamp}

import org.apache.spark.sql.types.{UTF8String, DataType, StructType, AtomicType, DateUtils}

/**
* An extended interface to [[Row]] that allows the values for each column to be updated. Setting
Expand All @@ -36,7 +38,9 @@ trait MutableRow extends Row {
def setByte(ordinal: Int, value: Byte)
def setFloat(ordinal: Int, value: Float)
def setString(ordinal: Int, value: String)
// TODO(davies): add setDate() and setDecimal()
def setDate(ordinal: Int, value: java.sql.Date)
def setTimestamp(ordinal: Int, value: java.sql.Timestamp)
// TODO(davies): add setDecimal()
}

/**
Expand All @@ -55,6 +59,8 @@ object EmptyRow extends Row {
override def getShort(i: Int): Short = throw new UnsupportedOperationException
override def getByte(i: Int): Byte = throw new UnsupportedOperationException
override def getString(i: Int): String = throw new UnsupportedOperationException
// EmptyRow has no columns, so every typed accessor fails unconditionally.
override def getDate(i: Int): java.sql.Date = throw new UnsupportedOperationException
override def getTimestamp(i: Int): java.sql.Timestamp = throw new UnsupportedOperationException
override def getAs[T](i: Int): T = throw new UnsupportedOperationException
override def copy(): Row = this
}
Expand Down Expand Up @@ -121,7 +127,21 @@ class GenericRow(protected[sql] val values: Array[Any]) extends Row {
}
}

// TODO(davies): add getDate and getDecimal
/**
 * Returns the value at position i as a java.sql.Date.
 *
 * DateType values are stored internally as days-since-epoch Ints; an
 * already-materialized java.sql.Date (e.g. a row built directly from external
 * values) is also accepted and returned unchanged.
 *
 * @throws ClassCastException when the stored value is neither.
 */
override def getDate(i: Int): java.sql.Date = {
  values(i) match {
    case null => null
    case d: Int => DateUtils.toJavaDate(d)
    case d: java.sql.Date => d
    case other =>
      // The documented contract is ClassCastException; a bare non-exhaustive
      // match would throw scala.MatchError instead.
      throw new ClassCastException(s"Cannot convert value of type ${other.getClass} to Date")
  }
}

/**
 * Returns the value at position i as a java.sql.Timestamp.
 *
 * Timestamps are stored boxed, so the value is returned unchanged.
 *
 * @throws ClassCastException when the stored value is not a Timestamp.
 */
override def getTimestamp(i: Int): java.sql.Timestamp = {
  values(i) match {
    case null => null
    case t: java.sql.Timestamp => t
    case other =>
      // The documented contract is ClassCastException; a bare non-exhaustive
      // match would throw scala.MatchError instead.
      throw new ClassCastException(s"Cannot convert value of type ${other.getClass} to Timestamp")
  }
}

// TODO(davies): add getDecimal

// Custom hashCode function that matches the efficient code generated version.
override def hashCode: Int = {
Expand Down Expand Up @@ -197,7 +217,16 @@ class GenericMutableRow(v: Array[Any]) extends GenericRow(v) with MutableRow {
override def setFloat(ordinal: Int, value: Float): Unit = { values(ordinal) = value }
override def setInt(ordinal: Int, value: Int): Unit = { values(ordinal) = value }
override def setLong(ordinal: Int, value: Long): Unit = { values(ordinal) = value }
override def setString(ordinal: Int, value: String) { values(ordinal) = UTF8String(value)}
override def setString(ordinal: Int, value: String): Unit = { values(ordinal) = UTF8String(value) }

override def setDate(ordinal: Int, value: java.sql.Date): Unit = {
  // Store the internal days-since-epoch Int so getDate can convert it back.
  val daysSinceEpoch = DateUtils.fromJavaDate(value)
  values(ordinal) = daysSinceEpoch
}

// Timestamps are stored boxed as-is; getTimestamp returns them unchanged.
override def setTimestamp(ordinal: Int, value: java.sql.Timestamp): Unit = {
values(ordinal) = value
}

override def setNullAt(i: Int): Unit = { values(i) = null }

override def setShort(ordinal: Int, value: Short): Unit = { values(ordinal) = value }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.spark.sql.catalyst.expressions

import java.sql.{Date, Timestamp}
import java.util.Arrays

import org.scalatest.{FunSuite, Matchers}
Expand Down Expand Up @@ -72,7 +73,33 @@ class UnsafeRowConverterSuite extends FunSuite with Matchers {
unsafeRow.getString(1) should be ("Hello")
unsafeRow.getString(2) should be ("World")
}

// End-to-end round trip: write a row holding a long, a string, a date and a
// timestamp through UnsafeRowConverter, then read every field back out of the
// resulting UnsafeRow.
test("basic conversion with primitive, string, date and timestamp types") {
val fieldTypes: Array[DataType] = Array(LongType, StringType, DateType, TimestampType)
val converter = new UnsafeRowConverter(fieldTypes)

val row = new SpecificMutableRow(fieldTypes)
row.setLong(0, 0)
row.setString(1, "Hello")
row.setDate(2, Date.valueOf("1970-01-01"))
row.setTimestamp(3, Timestamp.valueOf("2015-05-08 08:10:25"))

// Expected size: 8 (presumably the null-bitset width for 4 fields — confirm
// against calculateBitSetWidthInBytes) + 4 fixed-width 8-byte slots + the
// word-rounded string payload (its bytes plus an 8-byte length/offset word) +
// 16 bytes for the timestamp record in the variable-length region.
val sizeRequired: Int = converter.getSizeRequirement(row)
sizeRequired should be (8 + (8 * 4) +
ByteArrayMethods.roundNumberOfBytesToNearestWord("Hello".getBytes.length + 8) +
16)
val buffer: Array[Long] = new Array[Long](sizeRequired / 8)
val numBytesWritten = converter.writeRow(row, buffer, PlatformDependent.LONG_ARRAY_OFFSET)
numBytesWritten should be (sizeRequired)

// Reading back must round-trip the date (days-since-epoch Int) and the
// timestamp (millis + nanos stored in the variable-length region).
val unsafeRow = new UnsafeRow()
unsafeRow.pointTo(buffer, PlatformDependent.LONG_ARRAY_OFFSET, fieldTypes.length, null)
unsafeRow.getLong(0) should be (0)
unsafeRow.getString(1) should be ("Hello")
unsafeRow.getDate(2) should be (Date.valueOf("1970-01-01"))
unsafeRow.getTimestamp(3) should be (Timestamp.valueOf("2015-05-08 08:10:25"))
}

test("null handling") {
val fieldTypes: Array[DataType] = Array(
NullType,
Expand Down

0 comments on commit 4c07b57

Please sign in to comment.