diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala index f65a107924ec5..9832207ee940c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ExtractValue.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import scala.collection.Map -import org.apache.spark.sql.{catalyst, AnalysisException} +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.types._ @@ -41,12 +41,14 @@ object ExtractValue { resolver: Resolver): ExtractValue = { (child.dataType, extraction) match { - case (StructType(fields), Literal(fieldName, StringType)) => - val ordinal = findField(fields, fieldName.toString, resolver) - GetStructField(child, fields(ordinal), ordinal) - case (ArrayType(StructType(fields), containsNull), Literal(fieldName, StringType)) => - val ordinal = findField(fields, fieldName.toString, resolver) - GetArrayStructFields(child, fields(ordinal), ordinal, containsNull) + case (StructType(fields), NonNullLiteral(v, StringType)) => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + GetStructField(child, fields(ordinal).copy(name = fieldName), ordinal) + case (ArrayType(StructType(fields), containsNull), NonNullLiteral(v, StringType)) => + val fieldName = v.toString + val ordinal = findField(fields, fieldName, resolver) + GetArrayStructFields(child, fields(ordinal).copy(name = fieldName), ordinal, containsNull) case (_: ArrayType, _) if extraction.dataType.isInstanceOf[IntegralType] => GetArrayItem(child, extraction) case (_: MapType, _) =>