diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index e7bf7cc1f1313..189451d0d9ad7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -67,6 +67,19 @@ object HiveTypeCoercion { }) } + /** + * Similar to [[findTightestCommonType]], if can not find the TightestCommonType, try to use + * [[findTightestCommonTypeToString]] to find the TightestCommonType. + */ + private def findTightestCommonTypeAndPromoteToString(types: Seq[DataType]): Option[DataType] = { + types.foldLeft[Option[DataType]](Some(NullType))((r, c) => r match { + case None => None + case Some(d) => + findTightestCommonTypeOfTwo(d, c).orElse(findTightestCommonTypeToString(d, c)) + }) + } + + /** * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. @@ -599,7 +612,7 @@ trait HiveTypeCoercion { // compatible with every child column. case Coalesce(es) if es.map(_.dataType).distinct.size > 1 => val types = es.map(_.dataType) - findTightestCommonType(types) match { + findTightestCommonTypeAndPromoteToString(types) match { case Some(finalDataType) => Coalesce(es.map(Cast(_, finalDataType))) case None => sys.error(s"Could not determine return type of Coalesce for ${types.mkString(",")}") @@ -634,7 +647,7 @@ trait HiveTypeCoercion { def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { case c: CaseWhenLike if c.childrenResolved && !c.valueTypesEqual => logDebug(s"Input values for null casting ${c.valueTypes.mkString(",")}") - val maybeCommonType = findTightestCommonType(c.valueTypes) + val maybeCommonType = findTightestCommonTypeAndPromoteToString(c.valueTypes) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(when, value) if value.dataType != commonType => @@ -650,7 +663,8 @@ trait HiveTypeCoercion { }.getOrElse(c) case c: CaseKeyWhen if c.childrenResolved && !c.resolved => - val maybeCommonType = findTightestCommonType((c.key +: c.whenList).map(_.dataType)) + val maybeCommonType = + findTightestCommonTypeAndPromoteToString((c.key +: c.whenList).map(_.dataType)) maybeCommonType.map { commonType => val castedBranches = c.branches.grouped(2).map { case Seq(when, then) if when.dataType != commonType => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index a47cc30e92e27..1a6ee8169c38d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -45,6 +45,16 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row("one", 6) :: Row("three", 3) :: Nil) } + test("SPARK-8010: promote numeric to string") { + val df = Seq((1, 1)).toDF("key", "value") + df.registerTempTable("src") + val queryCaseWhen = sql("select case when true then 1.0 else '1' end from src ") + val queryCoalesce = sql("select coalesce(null, 1, '1') from src ") + + checkAnswer(queryCaseWhen, Row("1.0") :: Nil) + checkAnswer(queryCoalesce, Row("1") :: Nil) + } + test("SPARK-6743: no columns from cache") { Seq( (83, 0, 38),