diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 4f879b881a6fe..dc11f338a006e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -40,7 +40,22 @@ object MimaExcludes {
     ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.connector.write.V1WriteBuilder"),
 
     // [SPARK-33955] Add latest offsets to source progress
-    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SourceProgress.this")
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.SourceProgress.this"),
+
+    // [SPARK-34862][SQL] Support nested column in ORC vectorized reader
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getBoolean"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getByte"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getShort"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getInt"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getLong"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getFloat"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getDouble"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getDecimal"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getUTF8String"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getBinary"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getArray"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getMap"),
+    ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.vectorized.ColumnVector.getChild")
   )
 
   // Exclude rules for 3.1.x
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 5c5cb52f30951..d91bb59f68c44 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -839,6 +839,13 @@ object SQLConf {
     .intConf
     .createWithDefault(4096)
 
+  val ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED =
+    buildConf("spark.sql.orc.enableNestedColumnVectorizedReader")
+      .doc("Enables vectorized ORC decoding for nested columns.")
+      .version("3.2.0")
+      .booleanConf
+      .createWithDefault(false)
+
   val ORC_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.orc.filterPushdown")
     .doc("When true, enable filter pushdown for ORC files.")
     .version("1.4.0")
@@ -3339,6 +3346,9 @@ class SQLConf extends Serializable with Logging {
 
   def orcVectorizedReaderBatchSize: Int = getConf(ORC_VECTORIZED_READER_BATCH_SIZE)
 
+  def orcVectorizedReaderNestedColumnEnabled: Boolean =
+    getConf(ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED)
+
   def parquetCompressionCodec: String = getConf(PARQUET_COMPRESSION)
 
   def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED)
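The new flag defaults to false, so nested-column vectorization is strictly opt-in. A minimal sketch of turning it on, either at session build time or on a live session; the config key comes from the SQLConf change above, while the app name and master are illustrative:

    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .appName("orc-nested-vectorization-demo")  // illustrative name
      .master("local[*]")                        // illustrative master
      // Opt in explicitly; createWithDefault(false) keeps it off otherwise.
      .config("spark.sql.orc.enableNestedColumnVectorizedReader", "true")
      .getOrCreate()

    // Or flip it on an existing session:
    spark.conf.set("spark.sql.orc.enableNestedColumnVectorizedReader", "true")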
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java
new file mode 100644
index 0000000000000..6e13e97b4cbcc
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A column vector implementation for Spark's {@link ArrayType}.
+ */
+public class OrcArrayColumnVector extends OrcColumnVector {
+  private final OrcColumnVector data;
+  private final long[] offsets;
+  private final long[] lengths;
+
+  OrcArrayColumnVector(
+      DataType type,
+      ColumnVector vector,
+      OrcColumnVector data,
+      long[] offsets,
+      long[] lengths) {
+
+    super(type, vector);
+
+    this.data = data;
+    this.offsets = offsets;
+    this.lengths = lengths;
+  }
+
+  @Override
+  public ColumnarArray getArray(int rowId) {
+    return new ColumnarArray(data, (int) offsets[rowId], (int) lengths[rowId]);
+  }
+
+  @Override
+  public boolean getBoolean(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public byte getByte(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public short getShort(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public int getInt(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public long getLong(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public float getFloat(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public double getDouble(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Decimal getDecimal(int rowId, int precision, int scale) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public UTF8String getUTF8String(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public byte[] getBinary(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public ColumnarMap getMap(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) {
+    throw new UnsupportedOperationException();
+  }
+}
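getArray is the only supported accessor here: row i's array is the slice [offsets(i), offsets(i) + lengths(i)) of the flattened child vector, which is exactly what the ColumnarArray constructor call encodes. A plain-Scala sketch of that layout, with made-up sample data:

    // Row i's array spans child elements [offsets(i), offsets(i) + lengths(i)).
    val child   = Array(10, 20, 30, 40, 50, 60)  // flattened element vector
    val offsets = Array(0L, 2L, 5L)              // start of each row's slice
    val lengths = Array(2L, 3L, 1L)              // element count per row

    def arrayAt(rowId: Int): Seq[Int] = {
      val start = offsets(rowId).toInt
      child.slice(start, start + lengths(rowId).toInt).toSeq
    }

    assert(arrayAt(0) == Seq(10, 20))
    assert(arrayAt(1) == Seq(30, 40, 50))
    assert(arrayAt(2) == Seq(60))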
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java
new file mode 100644
index 0000000000000..c2d8334d928c0
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcAtomicColumnVector.java
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import java.math.BigDecimal;
+
+import org.apache.hadoop.hive.ql.exec.vector.*;
+
+import org.apache.spark.sql.catalyst.util.DateTimeUtils;
+import org.apache.spark.sql.catalyst.util.RebaseDateTime;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DateType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.TimestampType;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A column vector implementation for Spark's AtomicType.
+ */
+public class OrcAtomicColumnVector extends OrcColumnVector {
+  private final boolean isTimestamp;
+  private final boolean isDate;
+
+  // Column vectors for each type. Only one of these is populated for any given type.
+  private LongColumnVector longData;
+  private DoubleColumnVector doubleData;
+  private BytesColumnVector bytesData;
+  private DecimalColumnVector decimalData;
+  private TimestampColumnVector timestampData;
+
+  OrcAtomicColumnVector(DataType type, ColumnVector vector) {
+    super(type, vector);
+
+    if (type instanceof TimestampType) {
+      isTimestamp = true;
+    } else {
+      isTimestamp = false;
+    }
+
+    if (type instanceof DateType) {
+      isDate = true;
+    } else {
+      isDate = false;
+    }
+
+    if (vector instanceof LongColumnVector) {
+      longData = (LongColumnVector) vector;
+    } else if (vector instanceof DoubleColumnVector) {
+      doubleData = (DoubleColumnVector) vector;
+    } else if (vector instanceof BytesColumnVector) {
+      bytesData = (BytesColumnVector) vector;
+    } else if (vector instanceof DecimalColumnVector) {
+      decimalData = (DecimalColumnVector) vector;
+    } else if (vector instanceof TimestampColumnVector) {
+      timestampData = (TimestampColumnVector) vector;
+    } else {
+      throw new UnsupportedOperationException();
+    }
+  }
+
+  @Override
+  public boolean getBoolean(int rowId) {
+    return longData.vector[getRowIndex(rowId)] == 1;
+  }
+
+  @Override
+  public byte getByte(int rowId) {
+    return (byte) longData.vector[getRowIndex(rowId)];
+  }
+
+  @Override
+  public short getShort(int rowId) {
+    return (short) longData.vector[getRowIndex(rowId)];
+  }
+
+  @Override
+  public int getInt(int rowId) {
+    int value = (int) longData.vector[getRowIndex(rowId)];
+    if (isDate) {
+      return RebaseDateTime.rebaseJulianToGregorianDays(value);
+    } else {
+      return value;
+    }
+  }
+
+  @Override
+  public long getLong(int rowId) {
+    int index = getRowIndex(rowId);
+    if (isTimestamp) {
+      return DateTimeUtils.fromJavaTimestamp(timestampData.asScratchTimestamp(index));
+    } else {
+      return longData.vector[index];
+    }
+  }
+
+  @Override
+  public float getFloat(int rowId) {
+    return (float) doubleData.vector[getRowIndex(rowId)];
+  }
+
+  @Override
+  public double getDouble(int rowId) {
+    return doubleData.vector[getRowIndex(rowId)];
+  }
+
+  @Override
+  public Decimal getDecimal(int rowId, int precision, int scale) {
+    if (isNullAt(rowId)) return null;
+    BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue();
+    return Decimal.apply(data, precision, scale);
+  }
+
+  @Override
+  public UTF8String getUTF8String(int rowId) {
+    if (isNullAt(rowId)) return null;
+    int index = getRowIndex(rowId);
+    BytesColumnVector col = bytesData;
+    return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]);
+  }
+
+  @Override
+  public byte[] getBinary(int rowId) {
+    if (isNullAt(rowId)) return null;
+    int index = getRowIndex(rowId);
+    byte[] binary = new byte[bytesData.length[index]];
+    System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length);
+    return binary;
+  }
+
+  @Override
+  public ColumnarArray getArray(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public ColumnarMap getMap(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) {
+    throw new UnsupportedOperationException();
+  }
+}
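The constructor dispatches on the concrete Hive vector subclass rather than on the Spark type. For orientation, a rough lookup of which Hive vector conventionally backs each Spark atomic type; this table is an assumption drawn from Hive's ORC vectorization conventions, not something the diff itself spells out:

    import org.apache.spark.sql.types._

    // Assumed mapping from Spark atomic types to Hive vector class names;
    // the real dispatch is the instanceof chain in the constructor above.
    def hiveVectorClassFor(dt: DataType): String = dt match {
      case BooleanType | ByteType | ShortType | IntegerType | LongType | DateType =>
        "LongColumnVector"      // integral values; dates as epoch days, rebased on read
      case FloatType | DoubleType  => "DoubleColumnVector"
      case StringType | BinaryType => "BytesColumnVector"
      case _: DecimalType          => "DecimalColumnVector"
      case TimestampType           => "TimestampColumnVector"
      case other => sys.error(s"unsupported atomic type: $other")
    }

    hiveVectorClassFor(DateType)  // "LongColumnVector"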
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
index 6e55fedfc4deb..0becd2572f99c 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
@@ -17,75 +17,29 @@
 
 package org.apache.spark.sql.execution.datasources.orc;
 
-import java.math.BigDecimal;
+import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
 
-import org.apache.hadoop.hive.ql.exec.vector.*;
-
-import org.apache.spark.sql.catalyst.util.DateTimeUtils;
-import org.apache.spark.sql.catalyst.util.RebaseDateTime;
 import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DateType;
-import org.apache.spark.sql.types.Decimal;
-import org.apache.spark.sql.types.TimestampType;
-import org.apache.spark.sql.vectorized.ColumnarArray;
-import org.apache.spark.sql.vectorized.ColumnarMap;
-import org.apache.spark.unsafe.types.UTF8String;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
 
 /**
- * A column vector class wrapping Hive's ColumnVector. Because Spark ColumnarBatch only accepts
- * Spark's vectorized.ColumnVector, this column vector is used to adapt Hive ColumnVector with
- * Spark ColumnarVector.
+ * An abstract column vector class wrapping Hive's {@link ColumnVector}.
+ *
+ * Because Spark {@link ColumnarBatch} only accepts Spark's vectorized.ColumnVector,
+ * this column vector is used to adapt a Hive ColumnVector to a Spark ColumnVector.
  */
-public class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector {
-  private ColumnVector baseData;
-  private LongColumnVector longData;
-  private DoubleColumnVector doubleData;
-  private BytesColumnVector bytesData;
-  private DecimalColumnVector decimalData;
-  private TimestampColumnVector timestampData;
-  private final boolean isTimestamp;
-  private final boolean isDate;
-
+public abstract class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector {
+  private final ColumnVector baseData;
   private int batchSize;
 
   OrcColumnVector(DataType type, ColumnVector vector) {
     super(type);
 
-    if (type instanceof TimestampType) {
-      isTimestamp = true;
-    } else {
-      isTimestamp = false;
-    }
-
-    if (type instanceof DateType) {
-      isDate = true;
-    } else {
-      isDate = false;
-    }
-
     baseData = vector;
-    if (vector instanceof LongColumnVector) {
-      longData = (LongColumnVector) vector;
-    } else if (vector instanceof DoubleColumnVector) {
-      doubleData = (DoubleColumnVector) vector;
-    } else if (vector instanceof BytesColumnVector) {
-      bytesData = (BytesColumnVector) vector;
-    } else if (vector instanceof DecimalColumnVector) {
-      decimalData = (DecimalColumnVector) vector;
-    } else if (vector instanceof TimestampColumnVector) {
-      timestampData = (TimestampColumnVector) vector;
-    } else {
-      throw new UnsupportedOperationException();
-    }
-  }
-
-  public void setBatchSize(int batchSize) {
-    this.batchSize = batchSize;
   }
 
   @Override
   public void close() {
-
   }
 
@@ -112,97 +66,18 @@ public int numNulls() {
     }
   }
 
-  /* A helper method to get the row index in a column. */
-  private int getRowIndex(int rowId) {
-    return baseData.isRepeating ? 0 : rowId;
-  }
-
   @Override
   public boolean isNullAt(int rowId) {
     return baseData.isNull[getRowIndex(rowId)];
   }
 
-  @Override
-  public boolean getBoolean(int rowId) {
-    return longData.vector[getRowIndex(rowId)] == 1;
-  }
-
-  @Override
-  public byte getByte(int rowId) {
-    return (byte) longData.vector[getRowIndex(rowId)];
-  }
-
-  @Override
-  public short getShort(int rowId) {
-    return (short) longData.vector[getRowIndex(rowId)];
-  }
-
-  @Override
-  public int getInt(int rowId) {
-    int value = (int) longData.vector[getRowIndex(rowId)];
-    if (isDate) {
-      return RebaseDateTime.rebaseJulianToGregorianDays(value);
-    } else {
-      return value;
-    }
-  }
-
-  @Override
-  public long getLong(int rowId) {
-    int index = getRowIndex(rowId);
-    if (isTimestamp) {
-      return DateTimeUtils.fromJavaTimestamp(timestampData.asScratchTimestamp(index));
-    } else {
-      return longData.vector[index];
-    }
-  }
-
-  @Override
-  public float getFloat(int rowId) {
-    return (float) doubleData.vector[getRowIndex(rowId)];
-  }
-
-  @Override
-  public double getDouble(int rowId) {
-    return doubleData.vector[getRowIndex(rowId)];
-  }
-
-  @Override
-  public Decimal getDecimal(int rowId, int precision, int scale) {
-    if (isNullAt(rowId)) return null;
-    BigDecimal data = decimalData.vector[getRowIndex(rowId)].getHiveDecimal().bigDecimalValue();
-    return Decimal.apply(data, precision, scale);
-  }
-
-  @Override
-  public UTF8String getUTF8String(int rowId) {
-    if (isNullAt(rowId)) return null;
-    int index = getRowIndex(rowId);
-    BytesColumnVector col = bytesData;
-    return UTF8String.fromBytes(col.vector[index], col.start[index], col.length[index]);
-  }
-
-  @Override
-  public byte[] getBinary(int rowId) {
-    if (isNullAt(rowId)) return null;
-    int index = getRowIndex(rowId);
-    byte[] binary = new byte[bytesData.length[index]];
-    System.arraycopy(bytesData.vector[index], bytesData.start[index], binary, 0, binary.length);
-    return binary;
-  }
-
-  @Override
-  public ColumnarArray getArray(int rowId) {
-    throw new UnsupportedOperationException();
-  }
-
-  @Override
-  public ColumnarMap getMap(int rowId) {
-    throw new UnsupportedOperationException();
+  public void setBatchSize(int batchSize) {
+    this.batchSize = batchSize;
   }
 
-  @Override
-  public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) {
-    throw new UnsupportedOperationException();
+  /* A helper method to get the row index in a column. */
+  protected int getRowIndex(int rowId) {
+    return baseData.isRepeating ? 0 : rowId;
   }
 }
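getRowIndex, now protected so all subclasses share it, collapses every rowId to 0 when Hive marks the vector as repeating: a run-length shortcut for a batch whose rows all carry the same value, stored only at slot 0. A toy Scala model of that contract (the type and names are invented for illustration):

    // Model of Hive's isRepeating convention: a constant batch stores one value.
    final case class RepeatingAwareVector(values: Array[Long], isRepeating: Boolean) {
      def rowIndex(rowId: Int): Int = if (isRepeating) 0 else rowId
      def get(rowId: Int): Long = values(rowIndex(rowId))
    }

    val constant = RepeatingAwareVector(Array(42L), isRepeating = true)
    assert((0 until 1024).forall(constant.get(_) == 42L))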
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
new file mode 100644
index 0000000000000..3bc7cc8f80142
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.hadoop.hive.ql.exec.vector.*;
+
+import org.apache.spark.sql.types.*;
+
+/**
+ * Utility class for {@link OrcColumnVector}.
+ */
+class OrcColumnVectorUtils {
+
+  /**
+   * Converts a Hive {@link ColumnVector} to a Spark {@link OrcColumnVector}.
+   *
+   * @param type The data type of the column vector
+   * @param vector Hive's column vector
+   * @return Spark's column vector
+   */
+  static OrcColumnVector toOrcColumnVector(DataType type, ColumnVector vector) {
+    if (vector instanceof LongColumnVector ||
+        vector instanceof DoubleColumnVector ||
+        vector instanceof BytesColumnVector ||
+        vector instanceof DecimalColumnVector ||
+        vector instanceof TimestampColumnVector) {
+      return new OrcAtomicColumnVector(type, vector);
+    } else if (vector instanceof StructColumnVector) {
+      StructColumnVector structVector = (StructColumnVector) vector;
+      OrcColumnVector[] fields = new OrcColumnVector[structVector.fields.length];
+      int ordinal = 0;
+      for (StructField f : ((StructType) type).fields()) {
+        fields[ordinal] = toOrcColumnVector(f.dataType(), structVector.fields[ordinal]);
+        ordinal++;
+      }
+      return new OrcStructColumnVector(type, vector, fields);
+    } else if (vector instanceof ListColumnVector) {
+      ListColumnVector listVector = (ListColumnVector) vector;
+      OrcColumnVector dataVector = toOrcColumnVector(
+        ((ArrayType) type).elementType(), listVector.child);
+      return new OrcArrayColumnVector(
+        type, vector, dataVector, listVector.offsets, listVector.lengths);
+    } else if (vector instanceof MapColumnVector) {
+      MapColumnVector mapVector = (MapColumnVector) vector;
+      MapType mapType = (MapType) type;
+      OrcColumnVector keysVector = toOrcColumnVector(mapType.keyType(), mapVector.keys);
+      OrcColumnVector valuesVector = toOrcColumnVector(mapType.valueType(), mapVector.values);
+      return new OrcMapColumnVector(
+        type, vector, keysVector, valuesVector, mapVector.offsets, mapVector.lengths);
+    } else {
+      throw new IllegalArgumentException(
+        String.format("OrcColumnVectorUtils.toOrcColumnVector should not take %s as type " +
+          "and %s as vector", type, vector));
+    }
+  }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
index 6a4b116cdef0b..40ed0b2454c12 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
@@ -180,7 +180,8 @@ public void initBatch(
         missingCol.setIsConstant();
         orcVectorWrappers[i] = missingCol;
       } else {
-        orcVectorWrappers[i] = new OrcColumnVector(dt, wrap.batch().cols[colId]);
+        orcVectorWrappers[i] = OrcColumnVectorUtils.toOrcColumnVector(
+          dt, wrap.batch().cols[colId]);
       }
     }
   }
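toOrcColumnVector recurses structurally: atomic vectors get the flat wrapper, while struct, list, and map vectors wrap recursively converted children. A sketch of the resulting wrapper shape, modeled over Spark DataTypes only; the real dispatch keys on the Hive vector subclass, and `wrapperFor` is a hypothetical helper added here for illustration:

    import org.apache.spark.sql.types._

    // Renders which wrapper the utility would pick for a given schema shape.
    def wrapperFor(dt: DataType): String = dt match {
      case s: StructType =>
        s"OrcStructColumnVector(${s.fields.map(f => wrapperFor(f.dataType)).mkString(", ")})"
      case a: ArrayType => s"OrcArrayColumnVector(${wrapperFor(a.elementType)})"
      case m: MapType   => s"OrcMapColumnVector(${wrapperFor(m.keyType)}, ${wrapperFor(m.valueType)})"
      case _            => "OrcAtomicColumnVector"
    }

    // e.g. array<struct<a:int,b:string>> nests a struct wrapper inside an array wrapper:
    wrapperFor(ArrayType(new StructType().add("a", IntegerType).add("b", StringType)))
    // => OrcArrayColumnVector(OrcStructColumnVector(OrcAtomicColumnVector, OrcAtomicColumnVector))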
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java
new file mode 100644
index 0000000000000..ace8d157792dc
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A column vector implementation for Spark's {@link MapType}.
+ */
+public class OrcMapColumnVector extends OrcColumnVector {
+  private final OrcColumnVector keys;
+  private final OrcColumnVector values;
+  private final long[] offsets;
+  private final long[] lengths;
+
+  OrcMapColumnVector(
+      DataType type,
+      ColumnVector vector,
+      OrcColumnVector keys,
+      OrcColumnVector values,
+      long[] offsets,
+      long[] lengths) {
+
+    super(type, vector);
+
+    this.keys = keys;
+    this.values = values;
+    this.offsets = offsets;
+    this.lengths = lengths;
+  }
+
+  @Override
+  public ColumnarMap getMap(int ordinal) {
+    return new ColumnarMap(keys, values, (int) offsets[ordinal], (int) lengths[ordinal]);
+  }
+
+  @Override
+  public boolean getBoolean(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public byte getByte(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public short getShort(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public int getInt(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public long getLong(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public float getFloat(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public double getDouble(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Decimal getDecimal(int rowId, int precision, int scale) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public UTF8String getUTF8String(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public byte[] getBinary(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public ColumnarArray getArray(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) {
+    throw new UnsupportedOperationException();
+  }
+}
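A map row is two parallel flattened vectors plus the same offsets/lengths indexing used for arrays: row i's entries occupy positions [offsets(i), offsets(i) + lengths(i)) in both the key and the value vector, which is what the ColumnarMap constructor call above encodes. A plain-Scala sketch with made-up data:

    val keys    = Array("a", "b", "a", "c")  // flattened key vector
    val values  = Array(1, 2, 3, 4)          // flattened value vector, parallel to keys
    val offsets = Array(0L, 2L)
    val lengths = Array(2L, 2L)

    def mapAt(rowId: Int): Map[String, Int] = {
      val start = offsets(rowId).toInt
      (start until start + lengths(rowId).toInt).map(i => keys(i) -> values(i)).toMap
    }

    assert(mapAt(0) == Map("a" -> 1, "b" -> 2))
    assert(mapAt(1) == Map("a" -> 3, "c" -> 4))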
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcStructColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcStructColumnVector.java
new file mode 100644
index 0000000000000..48e540d22095e
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcStructColumnVector.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.orc;
+
+import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.Decimal;
+import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.vectorized.ColumnarArray;
+import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.unsafe.types.UTF8String;
+
+/**
+ * A column vector implementation for Spark's {@link StructType}.
+ */
+public class OrcStructColumnVector extends OrcColumnVector {
+  private final OrcColumnVector[] fields;
+
+  OrcStructColumnVector(DataType type, ColumnVector vector, OrcColumnVector[] fields) {
+    super(type, vector);
+
+    this.fields = fields;
+  }
+
+  @Override
+  public org.apache.spark.sql.vectorized.ColumnVector getChild(int ordinal) {
+    return fields[ordinal];
+  }
+
+  @Override
+  public boolean getBoolean(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public byte getByte(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public short getShort(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public int getInt(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public long getLong(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public float getFloat(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public double getDouble(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public Decimal getDecimal(int rowId, int precision, int scale) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public UTF8String getUTF8String(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public byte[] getBinary(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public ColumnarArray getArray(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+
+  @Override
+  public ColumnarMap getMap(int rowId) {
+    throw new UnsupportedOperationException();
+  }
+}
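For structs, only getChild is meaningful: the field vectors are converted once by OrcColumnVectorUtils and then served by ordinal, so reading a struct column never goes through the per-row getters (which all throw). A toy Scala model of that ordinal lookup (the type and names are invented for illustration):

    // Model: a struct vector is just its child vectors, addressed by ordinal.
    final case class StructVector(fieldNames: Seq[String], children: Seq[Array[Any]]) {
      def getChild(ordinal: Int): Array[Any] = children(ordinal)
      def ordinalOf(name: String): Int = fieldNames.indexOf(name)
    }

    val sv = StructVector(
      Seq("id", "name"),
      Seq(Array[Any](1, 2), Array[Any]("x", "y")))
    assert(sv.getChild(sv.ordinalOf("name"))(1) == "y")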
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 8f4d1e5098029..3a5097441ab37 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -131,11 +131,27 @@ class OrcFileFormat
     }
   }
 
+  private def supportBatchForNestedColumn(
+      sparkSession: SparkSession,
+      schema: StructType): Boolean = {
+    val hasNestedColumn = schema.map(_.dataType).exists {
+      case _: ArrayType | _: MapType | _: StructType => true
+      case _ => false
+    }
+    if (hasNestedColumn) {
+      sparkSession.sessionState.conf.orcVectorizedReaderNestedColumnEnabled
+    } else {
+      true
+    }
+  }
+
   override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
     val conf = sparkSession.sessionState.conf
     conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled &&
       schema.length <= conf.wholeStageMaxNumFields &&
-      schema.forall(_.dataType.isInstanceOf[AtomicType])
+      schema.forall(s => supportDataType(s.dataType) &&
+        !s.dataType.isInstanceOf[UserDefinedType[_]]) &&
+      supportBatchForNestedColumn(sparkSession, schema)
   }
 
   override def isSplitable(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
index c763f4c9428c8..ad69bee4fb643 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala
@@ -33,6 +33,7 @@ import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.{SPARK_VERSION_SHORT, SparkException}
 import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY}
+import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, SchemaMergeUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
@@ -542,6 +543,7 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll with CommonFileDa
 }
 
 class OrcSourceSuite extends OrcSuite with SharedSparkSession {
+  import testImplicits._
 
   protected override def beforeAll(): Unit = {
     super.beforeAll()
@@ -602,4 +604,31 @@ class OrcSourceSuite extends OrcSuite with SharedSparkSession {
       checkAnswer(spark.read.orc(path), Seq(Row(0), Row(1), Row(2)))
     }
   }
+
+  test("SPARK-34862: Support ORC vectorized reader for nested column") {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      val df = spark.range(10).map { x =>
+        val stringColumn = s"$x" * 10
+        val structColumn = (x, s"$x" * 100)
+        val arrayColumn = (0 until 5).map(i => (x + i, s"$x" * 5))
+        val mapColumn = Map(
+          s"$x" -> (x * 0.1, (x, s"$x" * 100)),
+          (s"$x" * 2) -> (x * 0.2, (x, s"$x" * 200)),
+          (s"$x" * 3) -> (x * 0.3, (x, s"$x" * 300)))
+        (x, stringColumn, structColumn, arrayColumn, mapColumn)
+      }.toDF("int_col", "string_col", "struct_col", "array_col", "map_col")
+      df.write.format("orc").save(path)
+
+      withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+        val readDf = spark.read.orc(path)
+        val vectorizationEnabled = readDf.queryExecution.executedPlan.find {
+          case scan: FileSourceScanExec => scan.supportsColumnar
+          case _ => false
+        }.isDefined
+        assert(vectorizationEnabled)
+        checkAnswer(readDf, df)
+      }
+    }
+  }
 }
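Condensed from the test above, an end-to-end sketch of exercising the feature outside the test harness; the output path is hypothetical and `spark` is assumed to be an active SparkSession:

    import org.apache.spark.sql.execution.FileSourceScanExec
    import spark.implicits._

    spark.conf.set("spark.sql.orc.enableNestedColumnVectorizedReader", "true")

    val out = "/tmp/orc_nested_demo"  // hypothetical path
    // Tuples encode as struct columns, so struct_col is a nested type.
    Seq((1, (1, "a")), (2, (2, "b"))).toDF("id", "struct_col")
      .write.mode("overwrite").format("orc").save(out)

    val df = spark.read.orc(out)
    // With the flag on, the scan should report columnar support even though
    // the schema contains a struct column.
    val columnar = df.queryExecution.executedPlan.find {
      case s: FileSourceScanExec => s.supportsColumnar
      case _ => false
    }.isDefined
    assert(columnar)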