diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index 4aaabff15b6ee..a2ff938f16939 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -178,7 +178,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) // TODO: consider using Array[_] for ArrayType child to avoid // boxing of primitives val baseValue = value.asInstanceOf[Seq[_]] - val index = ordinal.asInstanceOf[Int] + val index = ordinal.asInstanceOf[Number].intValue() if (index >= baseValue.size || index < 0) { null } else { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala similarity index 98% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 72fdcebb4cbc8..e0bf07ed182f3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.sql.catalyst import org.apache.spark.sql.types._ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index 2b0f4618b21e0..b80911e7257fc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -26,6 +26,26 @@ import org.apache.spark.unsafe.types.UTF8String class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { + /** + * Runs through the testFunc for all integral data types. + * + * @param testFunc a test function that accepts a conversion function to convert an integer + * into another data type. + */ + private def testIntegralDataTypes(testFunc: (Int => Any) => Unit): Unit = { + testFunc(_.toByte) + testFunc(_.toShort) + testFunc(identity) + testFunc(_.toLong) + } + + test("GetArrayItem") { + testIntegralDataTypes { convert => + val array = Literal.create(Seq("a", "b"), ArrayType(StringType)) + checkEvaluation(GetArrayItem(array, Literal(convert(1))), "b") + } + } + test("CreateStruct") { val row = InternalRow(1, 2, 3) val c1 = 'a.int.at(0).as("a")