diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java
index 6e13e97b4cbcc..b0c818f5a4dfb 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.orc;
 
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector;
 
 import org.apache.spark.sql.types.ArrayType;
 import org.apache.spark.sql.types.DataType;
@@ -31,26 +32,22 @@
  */
 public class OrcArrayColumnVector extends OrcColumnVector {
   private final OrcColumnVector data;
-  private final long[] offsets;
-  private final long[] lengths;
 
   OrcArrayColumnVector(
       DataType type,
       ColumnVector vector,
-      OrcColumnVector data,
-      long[] offsets,
-      long[] lengths) {
+      OrcColumnVector data) {
 
     super(type, vector);
 
     this.data = data;
-    this.offsets = offsets;
-    this.lengths = lengths;
   }
 
   @Override
   public ColumnarArray getArray(int rowId) {
-    return new ColumnarArray(data, (int) offsets[rowId], (int) lengths[rowId]);
+    int offsets = (int) ((ListColumnVector) baseData).offsets[rowId];
+    int lengths = (int) ((ListColumnVector) baseData).lengths[rowId];
+    return new ColumnarArray(data, offsets, lengths);
   }
 
   @Override
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
index 0becd2572f99c..7fe1b306142e0 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
@@ -29,7 +29,7 @@
  * this column vector is used to adapt Hive ColumnVector with Spark ColumnarVector.
  */
 public abstract class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector {
-  private final ColumnVector baseData;
+  protected final ColumnVector baseData;
   private int batchSize;
 
   OrcColumnVector(DataType type, ColumnVector vector) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
index 3bc7cc8f80142..89f6996e4610f 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
@@ -53,15 +53,13 @@ static OrcColumnVector toOrcColumnVector(DataType type, ColumnVector vector) {
       ListColumnVector listVector = (ListColumnVector) vector;
       OrcColumnVector dataVector = toOrcColumnVector(
         ((ArrayType) type).elementType(), listVector.child);
-      return new OrcArrayColumnVector(
-        type, vector, dataVector, listVector.offsets, listVector.lengths);
+      return new OrcArrayColumnVector(type, vector, dataVector);
     } else if (vector instanceof MapColumnVector) {
       MapColumnVector mapVector = (MapColumnVector) vector;
       MapType mapType = (MapType) type;
       OrcColumnVector keysVector = toOrcColumnVector(mapType.keyType(), mapVector.keys);
       OrcColumnVector valuesVector = toOrcColumnVector(mapType.valueType(), mapVector.values);
-      return new OrcMapColumnVector(
-        type, vector, keysVector, valuesVector, mapVector.offsets, mapVector.lengths);
+      return new OrcMapColumnVector(type, vector, keysVector, valuesVector);
     } else {
       throw new IllegalArgumentException(
         String.format("OrcColumnVectorUtils.toOrcColumnVector should not take %s as type " +
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java
index ace8d157792dc..7eedd8b594128 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.orc;
 
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector;
 
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.Decimal;
@@ -32,28 +33,24 @@
 public class OrcMapColumnVector extends OrcColumnVector {
   private final OrcColumnVector keys;
   private final OrcColumnVector values;
-  private final long[] offsets;
-  private final long[] lengths;
 
   OrcMapColumnVector(
       DataType type,
       ColumnVector vector,
       OrcColumnVector keys,
-      OrcColumnVector values,
-      long[] offsets,
-      long[] lengths) {
+      OrcColumnVector values) {
 
     super(type, vector);
 
     this.keys = keys;
     this.values = values;
-    this.offsets = offsets;
-    this.lengths = lengths;
   }
 
   @Override
   public ColumnarMap getMap(int ordinal) {
-    return new ColumnarMap(keys, values, (int) offsets[ordinal], (int) lengths[ordinal]);
+    int offsets = (int) ((MapColumnVector) baseData).offsets[ordinal];
+    int lengths = (int) ((MapColumnVector) baseData).lengths[ordinal];
+    return new ColumnarMap(keys, values, offsets, lengths);
   }
 
   @Override
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
index ead2c2cf1b70f..6a91a9cfe9340 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
@@ -34,7 +34,9 @@ import org.apache.orc.mapreduce.OrcInputFormat
 import org.apache.spark.{SparkConf, SparkException}
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.TableIdentifier
+import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation, RecordReaderIterator}
+import org.apache.spark.sql.execution.datasources.v2.BatchScanExec
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
 import org.apache.spark.sql.types._
@@ -713,6 +715,28 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-37728: Reading nested columns with ORC vectorized reader should not " +
+    "cause ArrayIndexOutOfBoundsException") {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      val df = spark.range(100).map { _ =>
+        val arrayColumn = (0 until 50).map(_ => (0 until 1000).map(k => k.toString))
+        arrayColumn
+      }.toDF("record").repartition(1)
+      df.write.format("orc").save(path)
+
+      withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+        val readDf = spark.read.orc(path)
+        val vectorizationEnabled = readDf.queryExecution.executedPlan.find {
+          case scan @ (_: FileSourceScanExec | _: BatchScanExec) => scan.supportsColumnar
+          case _ => false
+        }.isDefined
+        assert(vectorizationEnabled)
+        checkAnswer(readDf, df)
+      }
+    }
+  }
 }
 
 class OrcV1QuerySuite extends OrcQuerySuite {