Skip to content

Commit

Permalink
[SPARK-8748][SQL] Move castability test out from Cast case class into…
Browse files Browse the repository at this point in the history
… Cast object.

This patch moved resolve function in Cast case class into the companion object,
and renamed it canCast. We can then use this in the analyzer without a Cast expr.
  • Loading branch information
rxin committed Jul 1, 2015
1 parent ccdb052 commit 4d2d989
Showing 1 changed file with 68 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,65 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {

override def checkInputDataTypes(): TypeCheckResult = {
if (resolve(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"cannot cast ${child.dataType} to $dataType")
}
}
object Cast {

override def foldable: Boolean = child.foldable
/**
* Returns true iff we can cast `from` type to `to` type.
*/
def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
case (fromType, toType) if fromType == toType => true

case (NullType, _) => true

override def nullable: Boolean = forceNullable(child.dataType, dataType) || child.nullable
case (_, StringType) => true

case (StringType, BinaryType) => true

case (StringType, BooleanType) => true
case (DateType, BooleanType) => true
case (TimestampType, BooleanType) => true
case (_: NumericType, BooleanType) => true

case (StringType, TimestampType) => true
case (BooleanType, TimestampType) => true
case (DateType, TimestampType) => true
case (_: NumericType, TimestampType) => true

case (_, DateType) => true

private[this] def forceNullable(from: DataType, to: DataType) = (from, to) match {
case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
case (DateType, _: NumericType) => true
case (TimestampType, _: NumericType) => true
case (_: NumericType, _: NumericType) => true

case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
canCast(fromType, toType) &&
resolvableNullability(fn || forceNullable(fromType, toType), tn)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
canCast(fromKey, toKey) &&
(!forceNullable(fromKey, toKey)) &&
canCast(fromValue, toValue) &&
resolvableNullability(fn || forceNullable(fromValue, toValue), tn)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.length == toFields.length &&
fromFields.zip(toFields).forall {
case (fromField, toField) =>
canCast(fromField.dataType, toField.dataType) &&
resolvableNullability(
fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
toField.nullable)
}

case _ => false
}

private def resolvableNullability(from: Boolean, to: Boolean) = !from || to

private def forceNullable(from: DataType, to: DataType) = (from, to) match {
case (StringType, _: NumericType) => true
case (StringType, TimestampType) => true
case (DoubleType, TimestampType) => true
Expand All @@ -58,61 +100,24 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w
case (_, DecimalType.Fixed(_, _)) => true // TODO: not all upcasts here can really give null
case _ => false
}
}

private[this] def resolvableNullability(from: Boolean, to: Boolean) = !from || to

private[this] def resolve(from: DataType, to: DataType): Boolean = {
(from, to) match {
case (from, to) if from == to => true

case (NullType, _) => true

case (_, StringType) => true

case (StringType, BinaryType) => true

case (StringType, BooleanType) => true
case (DateType, BooleanType) => true
case (TimestampType, BooleanType) => true
case (_: NumericType, BooleanType) => true

case (StringType, TimestampType) => true
case (BooleanType, TimestampType) => true
case (DateType, TimestampType) => true
case (_: NumericType, TimestampType) => true

case (_, DateType) => true

case (StringType, _: NumericType) => true
case (BooleanType, _: NumericType) => true
case (DateType, _: NumericType) => true
case (TimestampType, _: NumericType) => true
case (_: NumericType, _: NumericType) => true

case (ArrayType(from, fn), ArrayType(to, tn)) =>
resolve(from, to) &&
resolvableNullability(fn || forceNullable(from, to), tn)

case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
resolve(fromKey, toKey) &&
(!forceNullable(fromKey, toKey)) &&
resolve(fromValue, toValue) &&
resolvableNullability(fn || forceNullable(fromValue, toValue), tn)

case (StructType(fromFields), StructType(toFields)) =>
fromFields.size == toFields.size &&
fromFields.zip(toFields).forall {
case (fromField, toField) =>
resolve(fromField.dataType, toField.dataType) &&
resolvableNullability(
fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
toField.nullable)
}
/** Cast the child expression to the target data type. */
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with Logging {

case _ => false
override def checkInputDataTypes(): TypeCheckResult = {
if (Cast.canCast(child.dataType, dataType)) {
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
s"cannot cast ${child.dataType} to $dataType")
}
}

override def foldable: Boolean = child.foldable

override def nullable: Boolean = Cast.forceNullable(child.dataType, dataType) || child.nullable

override def toString: String = s"CAST($child, $dataType)"

// [[func]] assumes the input is no longer null because eval already does the null check.
Expand Down

0 comments on commit 4d2d989

Please sign in to comment.