diff --git a/build.sbt b/build.sbt index 77041601e0..0b99a59434 100644 --- a/build.sbt +++ b/build.sbt @@ -47,6 +47,8 @@ val sparkVersion = settingKey[String]("Spark version") spark / sparkVersion := getSparkVersion() connectCommon / sparkVersion := getSparkVersion() connectServer / sparkVersion := getSparkVersion() +kernelDefaults / sparkVersion := getSparkVersion() +goldenTables / sparkVersion := getSparkVersion() // Dependent library versions val defaultSparkVersion = LATEST_RELEASED_SPARK_VERSION @@ -144,6 +146,25 @@ lazy val commonSettings = Seq( unidocSourceFilePatterns := Nil, ) +/** + * Java-/Scala-/Uni-Doc settings aren't working yet against Spark Master: + * 1) delta-spark on Spark Master uses JDK 17, while delta-iceberg uses JDK 8 or 11. For some + *    reason, generating delta-spark unidoc also compiles delta-iceberg. + * 2) delta-spark unidoc fails to compile: Spark 3.5 is on its classpath, likely due to the + *    iceberg issue above. + */ +def crossSparkUniDocSettings(): Seq[Setting[_]] = getSparkVersion() match { + case LATEST_RELEASED_SPARK_VERSION => Seq( + // Java-/Scala-/Uni-Doc Settings + scalacOptions ++= Seq( + "-P:genjavadoc:strictVisibility=true" // hide package private types and methods in javadoc + ), + unidocSourceFilePatterns := Seq(SourceFilePattern("io/delta/tables/", "io/delta/exceptions/")) + ) + + case SPARK_MASTER_VERSION => Seq() +} + /** * Note: we cannot access sparkVersion.value here, since that can only be used within a task or * setting macro. @@ -158,12 +179,6 @@ def crossSparkSettings(): Seq[Setting[_]] = getSparkVersion() match { Compile / unmanagedSourceDirectories += (Compile / baseDirectory).value / "src" / "main" / "scala-spark-3.5", Test / unmanagedSourceDirectories += (Test / baseDirectory).value / "src" / "test" / "scala-spark-3.5", Antlr4 / antlr4Version := "4.9.3", - - // Java-/Scala-/Uni-Doc Settings - scalacOptions ++= Seq( - "-P:genjavadoc:strictVisibility=true" // hide package private types and methods in javadoc - ), - unidocSourceFilePatterns := Seq(SourceFilePattern("io/delta/tables/", "io/delta/exceptions/")) ) case SPARK_MASTER_VERSION => Seq( @@ -188,13 +203,6 @@ def crossSparkSettings(): Seq[Setting[_]] = getSparkVersion() match { "--add-opens=java.base/sun.security.action=ALL-UNNAMED", "--add-opens=java.base/sun.util.calendar=ALL-UNNAMED" ) - - // Java-/Scala-/Uni-Doc Settings - // This isn't working yet against Spark Master. - // 1) delta-spark on Spark Master uses JDK 17. delta-iceberg uses JDK 8 or 11. For some reason, - // generating delta-spark unidoc compiles delta-iceberg - // 2) delta-spark unidoc fails to compile. spark 3.5 is on its classpath. likely due to iceberg - // issue above.
) } @@ -221,6 +229,7 @@ lazy val connectCommon = (project in file("spark-connect/common")) name := "delta-connect-common", commonSettings, crossSparkSettings(), + crossSparkUniDocSettings(), releaseSettings, Compile / compile := runTaskOnlyOnSparkMaster( task = Compile / compile, @@ -280,6 +289,7 @@ lazy val connectServer = (project in file("spark-connect/server")) emptyValue = () ).value, crossSparkSettings(), + crossSparkUniDocSettings(), libraryDependencies ++= Seq( "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf", @@ -307,6 +317,7 @@ lazy val spark = (project in file("spark")) sparkMimaSettings, releaseSettings, crossSparkSettings(), + crossSparkUniDocSettings(), libraryDependencies ++= Seq( // Adding test classifier seems to break transitive resolution of the core dependencies "org.apache.spark" %% "spark-hive" % sparkVersion.value % "provided", @@ -450,6 +461,7 @@ lazy val kernelApi = (project in file("kernel/kernel-api")) scalaStyleSettings, javaOnlyReleaseSettings, Test / javaOptions ++= Seq("-ea"), + crossSparkSettings(), libraryDependencies ++= Seq( "org.roaringbitmap" % "RoaringBitmap" % "0.9.25", "org.slf4j" % "slf4j-api" % "1.7.36", @@ -504,6 +516,7 @@ lazy val kernelDefaults = (project in file("kernel/kernel-defaults")) scalaStyleSettings, javaOnlyReleaseSettings, Test / javaOptions ++= Seq("-ea"), + crossSparkSettings(), libraryDependencies ++= Seq( "org.apache.hadoop" % "hadoop-client-runtime" % hadoopVersion, "com.fasterxml.jackson.core" % "jackson-databind" % "2.13.5", @@ -520,10 +533,10 @@ lazy val kernelDefaults = (project in file("kernel/kernel-defaults")) "org.openjdk.jmh" % "jmh-core" % "1.37" % "test", "org.openjdk.jmh" % "jmh-generator-annprocess" % "1.37" % "test", - "org.apache.spark" %% "spark-hive" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-sql" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-core" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-catalyst" % defaultSparkVersion % "test" classifier "tests", + "org.apache.spark" %% "spark-hive" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-sql" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-core" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-catalyst" % sparkVersion.value % "test" classifier "tests", ), javaCheckstyleSettings("dev/kernel-checkstyle.xml"), // Unidoc settings @@ -1218,14 +1231,15 @@ lazy val goldenTables = (project in file("connectors/golden-tables")) name := "golden-tables", commonSettings, skipReleaseSettings, + crossSparkSettings(), libraryDependencies ++= Seq( // Test Dependencies "org.scalatest" %% "scalatest" % scalaTestVersion % "test", "commons-io" % "commons-io" % "2.8.0" % "test", - "org.apache.spark" %% "spark-sql" % defaultSparkVersion % "test", - "org.apache.spark" %% "spark-catalyst" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-core" % defaultSparkVersion % "test" classifier "tests", - "org.apache.spark" %% "spark-sql" % defaultSparkVersion % "test" classifier "tests" + "org.apache.spark" %% "spark-sql" % sparkVersion.value % "test", + "org.apache.spark" %% "spark-catalyst" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-core" % sparkVersion.value % "test" classifier "tests", + "org.apache.spark" %% "spark-sql" % sparkVersion.value % "test" classifier "tests" ) ) diff --git 
a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java index 4e6096dad0..8c3bce0138 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/Scan.java @@ -36,6 +36,7 @@ import io.delta.kernel.internal.util.ColumnMapping; import io.delta.kernel.internal.util.PartitionUtils; import io.delta.kernel.internal.util.Tuple2; +import io.delta.kernel.internal.util.VariantUtils; /** * Represents a scan of a Delta table. @@ -196,6 +197,15 @@ public FilteredColumnarBatch next() { nextDataBatch = nextDataBatch.withDeletedColumnAt(rowIndexOrdinal); } + // Transform physical variant columns (struct of binaries) into logical variant + // columns. + if (ScanStateRow.getVariantFeatureEnabled(scanState)) { + nextDataBatch = VariantUtils.withVariantColumns( + engine.getExpressionHandler(), + nextDataBatch + ); + } + // Add partition columns nextDataBatch = PartitionUtils.withPartitionColumns( diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java index 55c639a016..2117794462 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnVector.java @@ -175,6 +175,14 @@ default ArrayValue getArray(int rowId) { throw new UnsupportedOperationException("Invalid value request for data type"); } + /** + * Return the variant value located at {@code rowId}. Returns null if the slot for {@code rowId} + * is null. + */ + default VariantValue getVariant(int rowId) { + throw new UnsupportedOperationException("Invalid value request for data type"); + } + /** * Get the child vector associated with the given ordinal. This method is applicable only to the * {@code struct} type columns. diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnarBatch.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnarBatch.java index 9f4df477fc..c4a4279491 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnarBatch.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/ColumnarBatch.java @@ -104,6 +104,22 @@ default ColumnarBatch slice(int start, int end) { throw new UnsupportedOperationException("Not yet implemented!"); } + /** + * Return a copy of this {@link ColumnarBatch} with the column at given {@code ordinal} + * replaced with {@code newVector} and the schema field at given {@code ordinal} replaced + * with {@code newColumnSchema}. + * + * @param ordinal Ordinal of the column vector to replace. + * @param newColumnSchema The schema field of the new column. + * @param newVector New column vector that will replace the column vector at the given + * {@code ordinal}. + * @return {@link ColumnarBatch} with a new column vector at the given ordinal.
+ */ + default ColumnarBatch withReplacedColumnVector(int ordinal, StructField newColumnSchema, + ColumnVector newVector) { + throw new UnsupportedOperationException("Not yet implemented!"); + } + /** * @return iterator of {@link Row}s in this batch */ diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java index adcacbc0f4..560f3113d7 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/Row.java @@ -117,4 +117,10 @@ public interface Row { * Throws error if the column at given ordinal is not of map type, */ MapValue getMap(int ordinal); + + /** + * Return variant value of the column located at the given ordinal. + * Throws error if the column at given ordinal is not of variant type. + */ + VariantValue getVariant(int ordinal); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/data/VariantValue.java b/kernel/kernel-api/src/main/java/io/delta/kernel/data/VariantValue.java new file mode 100644 index 0000000000..abf57d5450 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/data/VariantValue.java @@ -0,0 +1,25 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.data; + +/** + * Abstraction to represent a single Variant value in a {@link ColumnVector}. 
+ */ +public interface VariantValue { + byte[] getValue(); + + byte[] getMetadata(); +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java index 5a90c07e45..4c5b94b59f 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/TableFeatures.java @@ -51,6 +51,7 @@ public static void validateReadSupportedTable( break; case "deletionVectors": // fall through case "timestampNtz": // fall through + case "variantType-preview": // fall through case "vacuumProtocolCheck": // fall through case "v2Checkpoint": break; diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java index ae4793fa47..74e76f979b 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ChildVectorBasedRow.java @@ -21,6 +21,7 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.StructType; /** @@ -111,5 +112,10 @@ public MapValue getMap(int ordinal) { return getChild(ordinal).getMap(rowId); } + @Override + public VariantValue getVariant(int ordinal) { + return getChild(ordinal).getVariant(rowId); + } + protected abstract ColumnVector getChild(int ordinal); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java index 01a12bb84d..c4d6aeaf8c 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/GenericRow.java @@ -23,6 +23,7 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.*; /** @@ -134,6 +135,12 @@ public MapValue getMap(int ordinal) { return (MapValue) getValue(ordinal); } + @Override + public VariantValue getVariant(int ordinal) { + throwIfUnsafeAccess(ordinal, VariantType.class, "variant"); + return (VariantValue) getValue(ordinal); + } + private Object getValue(int ordinal) { return ordinalToValue.get(ordinal); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java index 0d3ab71ee6..d772bba415 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/data/ScanStateRow.java @@ -41,6 +41,7 @@ public class ScanStateRow extends GenericRow { .add("partitionColumns", new ArrayType(StringType.STRING, false)) .add("minReaderVersion", IntegerType.INTEGER) .add("minWriterVersion", IntegerType.INTEGER) + .add("variantFeatureEnabled", BooleanType.BOOLEAN) .add("tablePath", StringType.STRING); private static final Map<String, Integer> COL_NAME_TO_ORDINAL = @@ -64,6 +65,10 @@ public static ScanStateRow of( valueMap.put(COL_NAME_TO_ORDINAL.get("partitionColumns"), metadata.getPartitionColumns()); valueMap.put(COL_NAME_TO_ORDINAL.get("minReaderVersion"), protocol.getMinReaderVersion());
valueMap.put(COL_NAME_TO_ORDINAL.get("minWriterVersion"), protocol.getMinWriterVersion()); + valueMap.put( + COL_NAME_TO_ORDINAL.get("variantFeatureEnabled"), + protocol.getReaderFeatures().contains("variantType-preview") + ); valueMap.put(COL_NAME_TO_ORDINAL.get("tablePath"), tablePath); return new ScanStateRow(valueMap); } @@ -147,4 +152,15 @@ public static String getColumnMappingMode(Row scanState) { public static String getTableRoot(Row scanState) { return scanState.getString(COL_NAME_TO_ORDINAL.get("tablePath")); } + + /** + * Get whether the "variantType" table feature is enabled from the scan state {@link Row} + * returned by {@link Scan#getScanState(Engine)}. + * + * @param scanState Scan state {@link Row} + * @return Boolean indicating whether "variantType" is enabled. + */ + public static Boolean getVariantFeatureEnabled(Row scanState) { + return scanState.getBoolean(COL_NAME_TO_ORDINAL.get("variantFeatureEnabled")); + } }
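A minimal sketch of how an engine-side read loop is expected to consume these two additions (the flag lookup plus the coalescing pass), mirroring the Scan.java hunk earlier in this diff; the `VariantReadSketch` wrapper and its `postProcess` hook are hypothetical:

```java
import io.delta.kernel.data.ColumnarBatch;
import io.delta.kernel.data.Row;
import io.delta.kernel.engine.Engine;
import io.delta.kernel.internal.data.ScanStateRow;
import io.delta.kernel.internal.util.VariantUtils;

final class VariantReadSketch {
    // Hypothetical per-batch hook: the feature flag is materialized once into the
    // scan state row (see ScanStateRow.of above), so this is a cheap boolean lookup.
    static ColumnarBatch postProcess(Engine engine, Row scanState, ColumnarBatch batch) {
        if (ScanStateRow.getVariantFeatureEnabled(scanState)) {
            // Rewrites each physical struct<value: binary, metadata: binary> column
            // whose logical type is variant into a vector answering getVariant(rowId).
            batch = VariantUtils.withVariantColumns(engine.getExpressionHandler(), batch);
        }
        return batch;
    }
}
```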
diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VariantUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VariantUtils.java new file mode 100644 index 0000000000..a46df42880 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VariantUtils.java @@ -0,0 +1,67 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.delta.kernel.internal.util; + +import java.util.Arrays; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.ColumnarBatch; +import io.delta.kernel.engine.ExpressionHandler; +import io.delta.kernel.expressions.*; +import io.delta.kernel.types.*; + +public class VariantUtils { + public static ColumnarBatch withVariantColumns( + ExpressionHandler expressionHandler, + ColumnarBatch dataBatch) { + for (int i = 0; i < dataBatch.getSchema().length(); i++) { + StructField field = dataBatch.getSchema().at(i); + if (!(field.getDataType() instanceof StructType) && + !(field.getDataType() instanceof ArrayType) && + !(field.getDataType() instanceof MapType) && + (field.getDataType() != VariantType.VARIANT || + dataBatch.getColumnVector(i).getDataType() == VariantType.VARIANT)) { + continue; + } + + ExpressionEvaluator evaluator = expressionHandler.getEvaluator( + // Field here is variant type if it's actually a variant. + // TODO: probably better to pass in the schema as an expression argument + // so the schema is enforced at the expression level. Need to pass in a literal + // schema + new StructType().add(field), + new ScalarExpression( + "variant_coalesce", + Arrays.asList(new Column(field.getName())) + ), + VariantType.VARIANT + ); + + ColumnVector variantCol = evaluator.eval(dataBatch); + dataBatch = dataBatch.withReplacedColumnVector(i, field, variantCol); + } + return dataBatch; + } + + private static ColumnVector[] getColumnBatchVectors(ColumnarBatch batch) { + ColumnVector[] res = new ColumnVector[batch.getSchema().length()]; + for (int i = 0; i < batch.getSchema().length(); i++) { + res[i] = batch.getColumnVector(i); + } + return res; + } +} diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java index 206321c0d3..96ebc171af 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/internal/util/VectorUtils.java @@ -205,6 +205,8 @@ private static Object getValueAsObject( return toJavaList(columnVector.getArray(rowId)); } else if (dataType instanceof MapType) { return toJavaMap(columnVector.getMap(rowId)); + } else if (dataType instanceof VariantType) { + return columnVector.getVariant(rowId); } else { throw new UnsupportedOperationException("unsupported data type"); } diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java index 33affe3d8a..6141d84a13 100644 --- a/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/BasePrimitiveType.java @@ -64,6 +64,7 @@ public static List<DataType> getAllPrimitiveTypes() { put("timestamp_ntz", TimestampNTZType.TIMESTAMP_NTZ); put("binary", BinaryType.BINARY); put("string", StringType.STRING); + put("variant", VariantType.VARIANT); } }); diff --git a/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java b/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java new file mode 100644 index 0000000000..71a84cdb71 --- /dev/null +++ b/kernel/kernel-api/src/main/java/io/delta/kernel/types/VariantType.java @@ -0,0 +1,31 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.types; + +import io.delta.kernel.annotation.Evolving; + +/** + * A logical variant type.
+ * @since 4.0.0 + */ +@Evolving +public class VariantType extends BasePrimitiveType { + public static final VariantType VARIANT = new VariantType(); + + private VariantType() { + super("variant"); + } +} diff --git a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala index db2814ab77..83e31ba672 100644 --- a/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala +++ b/kernel/kernel-api/src/test/scala/io/delta/kernel/internal/SnapshotManagerSuite.scala @@ -759,7 +759,7 @@ class SnapshotManagerSuite extends AnyFunSuite with MockFileSystemClientUtils { // corrupt incomplete multi-part checkpoint val corruptedCheckpointStatuses = FileNames.checkpointFileWithParts(logPath, 10, 5).asScala .map(p => FileStatus.of(p.toString, 10, 10)) - .take(4) + .take(4).toSeq val deltas = deltaFileStatuses(10L to 13L) testExpectedError[RuntimeException]( corruptedCheckpointStatuses ++ deltas, @@ -808,7 +808,7 @@ class SnapshotManagerSuite extends AnyFunSuite with MockFileSystemClientUtils { // _last_checkpoint refers to incomplete multi-part checkpoint val corruptedCheckpointStatuses = FileNames.checkpointFileWithParts(logPath, 20, 5).asScala .map(p => FileStatus.of(p.toString, 10, 10)) - .take(4) + .take(4).toSeq testExpectedError[RuntimeException]( files = corruptedCheckpointStatuses ++ deltaFileStatuses(10L to 20L) ++ singularCheckpointFileStatuses(Seq(10L)), diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultColumnarBatch.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultColumnarBatch.java index 69aac7aa43..d0869b5ad1 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultColumnarBatch.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultColumnarBatch.java @@ -109,6 +109,19 @@ public ColumnarBatch withNewSchema(StructType newSchema) { size, newSchema, columnVectors.toArray(new ColumnVector[0])); } + @Override + public ColumnarBatch withReplacedColumnVector(int ordinal, StructField newColumnSchema, + ColumnVector newVector) { + ArrayList<StructField> newStructFields = new ArrayList<>(schema.fields()); + newStructFields.set(ordinal, newColumnSchema); + StructType newSchema = new StructType(newStructFields); + + ArrayList<ColumnVector> newColumnVectors = new ArrayList<>(columnVectors); + newColumnVectors.set(ordinal, newVector); + return new DefaultColumnarBatch( + size, newSchema, newColumnVectors.toArray(new ColumnVector[0])); + } + @Override public int getSize() { return size;
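As a usage sketch (the helper class below is hypothetical, not part of this diff): withReplacedColumnVector is what lets the variant pass swap a single column without copying the rest of the batch, updating the schema field and the vector together:

```java
import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.ColumnarBatch;
import io.delta.kernel.types.StructField;
import io.delta.kernel.types.VariantType;

final class ReplaceExample {
    // Swap the physical struct-of-binaries column at `ordinal` for a logical variant
    // column, leaving all other columns of the batch untouched. `coalesced` would come
    // from the VARIANT_COALESCE evaluator shown later in this diff.
    static ColumnarBatch toLogicalVariant(
            ColumnarBatch batch, int ordinal, ColumnVector coalesced) {
        StructField old = batch.getSchema().at(ordinal);
        // Same name and nullability, but the logical variant type instead of the
        // physical struct<value, metadata> layout.
        StructField asVariant =
            new StructField(old.getName(), VariantType.VARIANT, old.isNullable());
        return batch.withReplacedColumnVector(ordinal, asVariant, coalesced);
    }
}
```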
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java index bbada452c0..8f60b3848c 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java @@ -33,6 +33,7 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.*; import io.delta.kernel.internal.util.InternalUtils; @@ -128,6 +129,11 @@ public MapValue getMap(int ordinal) { return (MapValue) parsedValues[ordinal]; } + @Override + public VariantValue getVariant(int ordinal) { + throw new UnsupportedOperationException("not yet implemented"); + } + private static void throwIfTypeMismatch(String expType, boolean hasExpType, JsonNode jsonNode) { if (!hasExpType) { throw new RuntimeException( diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java new file mode 100644 index 0000000000..c4c283e960 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/value/DefaultVariantValue.java @@ -0,0 +1,63 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.data.value; + +import java.util.Arrays; + +import io.delta.kernel.data.VariantValue; + +/** + * Default implementation of a Delta kernel VariantValue. + */ +public class DefaultVariantValue implements VariantValue { + private final byte[] value; + private final byte[] metadata; + + public DefaultVariantValue(byte[] value, byte[] metadata) { + this.value = value; + this.metadata = metadata; + } + + @Override + public byte[] getValue() { + return value; + } + + @Override + public byte[] getMetadata() { + return metadata; + } + + @Override + public String toString() { + return "VariantValue{value=" + Arrays.toString(value) + + ", metadata=" + Arrays.toString(metadata) + '}'; + } + + /** + * Compare two variants in bytes. Variant equality is more complex than this, and we haven't + * exposed it in the user surface yet. This method is only intended for tests. + */ + @Override + public boolean equals(Object other) { + if (other instanceof DefaultVariantValue) { + return Arrays.equals(value, ((DefaultVariantValue) other).getValue()) && + Arrays.equals(metadata, ((DefaultVariantValue) other).getMetadata()); + } else { + return false; + } + } +}
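Since equals is explicitly byte-wise and test-only (per the javadoc above), identical encodings compare equal while semantically equal but differently encoded variants do not; a small illustrative snippet with made-up bytes:

```java
import io.delta.kernel.data.VariantValue;
import io.delta.kernel.defaults.internal.data.value.DefaultVariantValue;

final class VariantEqualityExample {
    public static void main(String[] args) {
        byte[] value = new byte[] {0x01, 0x02}; // hypothetical encoded value bytes
        byte[] metadata = new byte[] {0x7F};    // hypothetical metadata bytes
        VariantValue a = new DefaultVariantValue(value, metadata);
        VariantValue b = new DefaultVariantValue(value.clone(), metadata.clone());
        // true: same bytes compare equal, even across distinct arrays...
        System.out.println(a.equals(b));
        // ...but this is not semantic variant equality: a logically identical
        // variant with a different binary encoding would compare unequal.
    }
}
```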
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java index 8196b93a17..7d8a997681 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/AbstractColumnVector.java @@ -22,6 +22,7 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.DataType; import static io.delta.kernel.internal.util.Preconditions.checkArgument; @@ -78,6 +79,10 @@ public boolean isNullAt(int rowId) { return nullability.get()[rowId]; } + public Optional<boolean[]> getNullability() { + return nullability; + } + @Override public boolean getBoolean(int rowId) { throw unsupportedDataAccessException("boolean"); @@ -138,6 +143,11 @@ public ArrayValue getArray(int rowId) { throw unsupportedDataAccessException("array"); } + @Override + public VariantValue getVariant(int rowId) { + throw unsupportedDataAccessException("variant"); + } + // TODO no need to override these here; update default implementations in `ColumnVector` // to have a more informative exception message protected UnsupportedOperationException unsupportedDataAccessException(String accessType) { diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java index bdadea6110..e7f5c5a25c 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultArrayVector.java @@ -88,4 +88,12 @@ public ColumnVector getElements() { } }; } + + public ColumnVector getElementVector() { + return elementVector; + } + + public int[] getOffsets() { + return offsets; + } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java index aad449d1eb..925a11746d 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java @@ -172,6 +172,13 @@ public ColumnVector getChild(int ordinal) { (rowId) -> (Row) rowIdToValueAccessor.apply(rowId)); } + @Override + public VariantValue getVariant(int rowId) { + assertValidRowId(rowId); + throwIfUnsafeAccess(VariantType.class, "variant"); + return (VariantValue) rowIdToValueAccessor.apply(rowId); + } + private void throwIfUnsafeAccess( Class<? extends DataType> expDataType, String accessType) { if (!expDataType.isAssignableFrom(dataType.getClass())) { String msg = String.format( diff --git 
a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java index ee0a0d4feb..403bcd5658 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultMapVector.java @@ -98,4 +98,16 @@ public ColumnVector getValues() { } }; } + + public ColumnVector getKeyVector() { + return keyVector; + } + + public ColumnVector getValueVector() { + return valueVector; + } + + public int[] getOffsets() { + return offsets; + } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java index a0aac0f12b..1656d572f8 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java @@ -23,6 +23,7 @@ import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.DataType; import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; @@ -155,6 +156,12 @@ public ArrayValue getArray(int rowId) { return rowIdToRowAccessor.apply(rowId).getArray(columnOrdinal); } + @Override + public VariantValue getVariant(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getVariant(columnOrdinal); + } + @Override public ColumnVector getChild(int childOrdinal) { StructType structType = (StructType) dataType; diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java new file mode 100644 index 0000000000..a12c426fda --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultVariantVector.java @@ -0,0 +1,93 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.data.vector; + +import java.util.Optional; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.VariantValue; +import io.delta.kernel.types.DataType; + +import static io.delta.kernel.internal.util.Preconditions.checkArgument; +import io.delta.kernel.defaults.internal.data.value.DefaultVariantValue; + +/** + * {@link io.delta.kernel.data.ColumnVector} implementation for variant type data. 
+ */ +public class DefaultVariantVector + extends AbstractColumnVector { + private final ColumnVector valueVector; + private final ColumnVector metadataVector; + + + /** + * Create an instance of {@link io.delta.kernel.data.ColumnVector} for variant type. + * + * @param size number of elements in the vector. + * @param type {@code variant} datatype definition. + * @param nullability Optional array of nullability value for each element in the vector. + * All values in the vector are considered non-null when parameter is + * empty. + * @param value The child binary column vector representing each variant's values. + * @param metadata The child binary column vector representing each variant's metadata. + */ + public DefaultVariantVector( + int size, + DataType type, + Optional<boolean[]> nullability, + ColumnVector value, + ColumnVector metadata) { + super(size, type, nullability); + this.valueVector = requireNonNull(value, "value is null"); + this.metadataVector = requireNonNull(metadata, "metadata is null"); + } + + /** + * Get the variant value at the given {@code rowId}. Returns null if the slot for + * {@code rowId} is null. + * + * @param rowId + * @return + */ + @Override + public VariantValue getVariant(int rowId) { + checkValidRowId(rowId); + if (isNullAt(rowId)) { + return null; + } + + return new DefaultVariantValue( + valueVector.getBinary(rowId), metadataVector.getBinary(rowId)); + } + + /** + * Get the child column vector at the given {@code ordinal}. Variants should only have two + * child vectors, one for value and one for metadata. + * + * @param ordinal + * @return + */ + @Override + public ColumnVector getChild(int ordinal) { + checkArgument(ordinal >= 0 && ordinal < 2, "Invalid ordinal " + ordinal); + if (ordinal == 0) { + return valueVector; + } else { + return metadataVector; + } + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java index 49c1fe00a2..f256c6300c 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java @@ -20,6 +20,7 @@ import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.DataType; import static io.delta.kernel.internal.util.Preconditions.checkArgument; @@ -137,6 +138,12 @@ public ArrayValue getArray(int rowId) { return underlyingVector.getArray(offset + rowId); } + @Override + public VariantValue getVariant(int rowId) { + checkValidRowId(rowId); + return underlyingVector.getVariant(offset + rowId); + } + @Override public ColumnVector getChild(int ordinal) { return new DefaultViewVector(underlyingVector.getChild(ordinal), offset, offset + size);
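A short sketch of the resulting data flow (the wrapper class and stand-in child vectors are hypothetical): the variant vector owns the two binary children and materializes a DefaultVariantValue per row on demand:

```java
import java.util.Optional;

import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.VariantValue;
import io.delta.kernel.defaults.internal.data.vector.DefaultVariantVector;
import io.delta.kernel.types.VariantType;

final class VariantVectorExample {
    // `values` and `metadatas` are binary child vectors of the same size, e.g. the two
    // fields of the physical struct<value: binary, metadata: binary> column from Parquet.
    static VariantValue firstVariant(ColumnVector values, ColumnVector metadatas, int size) {
        ColumnVector variants = new DefaultVariantVector(
            size,
            VariantType.VARIANT,
            Optional.empty(), // no nullability array: every slot treated as non-null
            values,
            metadatas);
        // Reads delegate to the children: value bytes from child 0, metadata from child 1.
        return variants.getVariant(0);
    }
}
```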
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java index b9c7f0f57d..a814f5f4c9 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/DefaultExpressionEvaluator.java @@ -15,6 +15,7 @@ */ package io.delta.kernel.defaults.internal.expressions; +import java.util.Arrays; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; @@ -32,8 +33,7 @@ import static io.delta.kernel.internal.util.ExpressionUtils.getRight; import static io.delta.kernel.internal.util.ExpressionUtils.getUnaryChild; import static io.delta.kernel.internal.util.Preconditions.checkArgument; -import io.delta.kernel.defaults.internal.data.vector.DefaultBooleanVector; -import io.delta.kernel.defaults.internal.data.vector.DefaultConstantVector; +import io.delta.kernel.defaults.internal.data.vector.*; import static io.delta.kernel.defaults.internal.DefaultEngineErrors.unsupportedExpressionException; import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.*; import static io.delta.kernel.defaults.internal.expressions.DefaultExpressionUtils.booleanWrapperVector; @@ -47,6 +47,7 @@ */ public class DefaultExpressionEvaluator implements ExpressionEvaluator { private final Expression expression; + private final StructType inputSchema; /** * Create a {@link DefaultExpressionEvaluator} instance bound to the given expression and @@ -67,12 +68,14 @@ public DefaultExpressionEvaluator( "Expression %s does not match expected output type %s", expression, outputType); throw unsupportedExpressionException(expression, reason); } + // TODO(richardc-db): Hack to avoid needing to pass the schema into the expression. + this.inputSchema = inputSchema; this.expression = transformResult.expression; } @Override public ColumnVector eval(ColumnarBatch input) { - return new ExpressionEvalVisitor(input).visit(expression); + return new ExpressionEvalVisitor(input, inputSchema).visit(expression); } @Override @@ -293,6 +296,20 @@ ExpressionTransformResult visitLike(final Predicate like) { return new ExpressionTransformResult(transformedExpression, BooleanType.BOOLEAN); } + ExpressionTransformResult visitVariantCoalesce(ScalarExpression variantCoalesce) { + checkArgument( + variantCoalesce.getChildren().size() == 1, + "Expected one input to 'variant_coalesce' but received %s", + variantCoalesce.getChildren().size() + ); + Expression transformedVariantInput = visit(childAt(variantCoalesce, 0)).expression; + return new ExpressionTransformResult( + new ScalarExpression( + "VARIANT_COALESCE", + Arrays.asList(transformedVariantInput)), + VariantType.VARIANT); + } + private Predicate validateIsPredicate( Expression baseExpression, ExpressionTransformResult result) { @@ -333,9 +350,11 @@ private Expression transformBinaryComparator(Predicate predicate) { */ private static class ExpressionEvalVisitor extends ExpressionVisitor<ColumnVector> { private final ColumnarBatch input; + private final StructType inputSchema; - ExpressionEvalVisitor(ColumnarBatch input) { + ExpressionEvalVisitor(ColumnarBatch input, StructType inputSchema) { this.input = input; + this.inputSchema = inputSchema; } /* @@ -575,6 +594,108 @@ ColumnVector visitLike(final Predicate like) { .collect(toList())); } + ColumnVector visitVariantCoalesce(ScalarExpression variantCoalesce) { + return variantCoalesceImpl( + visit(childAt(variantCoalesce, 0)), + inputSchema.at(0).getDataType() + ); + } + + private ColumnVector variantCoalesceImpl(ColumnVector inputVec, DataType dt) { + if (dt instanceof StructType) { + StructType structType = (StructType) dt; + DefaultStructVector structVec = (DefaultStructVector) inputVec; + ColumnVector[] structColVecs = new ColumnVector[structType.length()]; + for (int i = 0; i < structType.length(); 
i++) { + if (structType.at(i).getDataType() instanceof ArrayType || + structType.at(i).getDataType() instanceof StructType || + structType.at(i).getDataType() instanceof MapType || + structType.at(i).getDataType() instanceof VariantType) { + structColVecs[i] = variantCoalesceImpl( + structVec.getChild(i), + structType.at(i).getDataType() + ); + } else { + structColVecs[i] = structVec.getChild(i); + } + } + return new DefaultStructVector( + structVec.getSize(), + structType, + structVec.getNullability(), + structColVecs + ); + } + + if (dt instanceof ArrayType) { + ArrayType arrType = (ArrayType) dt; + DefaultArrayVector arrVec = (DefaultArrayVector) inputVec; + + if (arrType.getElementType() instanceof ArrayType || + arrType.getElementType() instanceof StructType || + arrType.getElementType() instanceof MapType || + arrType.getElementType() instanceof VariantType) { + ColumnVector elementVec = variantCoalesceImpl( + arrVec.getElementVector(), + arrType.getElementType() + ); + + return new DefaultArrayVector( + arrVec.getSize(), + arrType, + arrVec.getNullability(), + arrVec.getOffsets(), + elementVec + ); + } + return arrVec; + } + + if (dt instanceof MapType) { + MapType mapType = (MapType) dt; + DefaultMapVector mapVec = (DefaultMapVector) inputVec; + + ColumnVector valueVec = mapVec.getValueVector(); + if (mapType.getValueType() instanceof ArrayType || + mapType.getValueType() instanceof StructType || + mapType.getValueType() instanceof MapType || + mapType.getValueType() instanceof VariantType) { + valueVec = variantCoalesceImpl( + mapVec.getValueVector(), + mapType.getValueType() + ); + } + return new DefaultMapVector( + mapVec.getSize(), + mapType, + mapVec.getNullability(), + mapVec.getOffsets(), + mapVec.getKeyVector(), + valueVec + ); + } + + DefaultStructVector structBackingVariant = (DefaultStructVector) inputVec; + checkArgument( + structBackingVariant.getChild(0).getDataType() instanceof BinaryType, + "Expected struct field 0 backing variant to be binary type. Actual: %s", + structBackingVariant.getChild(0).getDataType() + ); + checkArgument( + structBackingVariant.getChild(1).getDataType() instanceof BinaryType, + "Expected struct field 1 backing variant to be binary type. Actual: %s", + structBackingVariant.getChild(1).getDataType() + ); + + return new DefaultVariantVector( + structBackingVariant.getSize(), + VariantType.VARIANT, + structBackingVariant.getNullability(), + structBackingVariant.getChild(0), + structBackingVariant.getChild(1) + ); + } + /** * Utility method to evaluate inputs to the binary input expression. Also validates the * evaluated expression result {@link ColumnVector}s are of the same size. 
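Putting the expression pieces together, this is roughly what VariantUtils.withVariantColumns drives per column; the calls mirror this diff's own usage, with only the wrapper class being hypothetical:

```java
import java.util.Arrays;

import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.ColumnarBatch;
import io.delta.kernel.engine.ExpressionHandler;
import io.delta.kernel.expressions.Column;
import io.delta.kernel.expressions.ExpressionEvaluator;
import io.delta.kernel.expressions.ScalarExpression;
import io.delta.kernel.types.StructField;
import io.delta.kernel.types.StructType;
import io.delta.kernel.types.VariantType;

final class VariantCoalesceExample {
    // Evaluate variant_coalesce over one column of `batch`, as VariantUtils does above.
    static ColumnVector coalesce(
            ExpressionHandler handler, ColumnarBatch batch, StructField field) {
        ExpressionEvaluator evaluator = handler.getEvaluator(
            new StructType().add(field),                  // single-column input schema
            new ScalarExpression(
                "variant_coalesce",
                Arrays.asList(new Column(field.getName()))),
            VariantType.VARIANT);                         // declared output type
        // The default evaluator's visitVariantCoalesce rebuilds nested struct/array/map
        // vectors, converting each struct<value, metadata> leaf into a variant vector.
        return evaluator.eval(batch);
    }
}
```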
diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java index 01888d6f77..726cdadf83 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/expressions/ExpressionVisitor.java @@ -61,6 +61,8 @@ abstract class ExpressionVisitor<R> { abstract R visitLike(Predicate predicate); + abstract R visitVariantCoalesce(ScalarExpression variantCoalesce); + final R visit(Expression expression) { if (expression instanceof PartitionValueExpression) { return visitPartitionValue((PartitionValueExpression) expression); @@ -109,6 +111,8 @@ private R visitScalarExpression(ScalarExpression expression) { return visitCoalesce(expression); case "LIKE": return visitLike(new Predicate(name, children)); + case "VARIANT_COALESCE": + return visitVariantCoalesce(expression); default: throw new UnsupportedOperationException( String.format("Scalar expression `%s` is not supported.", name)); diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java index dce7c5244c..644797940d 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetColumnReaders.java @@ -82,6 +82,15 @@ public static Converter createConverter( return createTimestampConverter(initialBatchSize, typeFromFile); } else if (typeFromClient instanceof TimestampNTZType) { return createTimestampNtzConverter(initialBatchSize, typeFromFile); + } else if (typeFromClient instanceof VariantType) { + // TODO(r.chen): Is converting the typeFromFile to the read schema ok? + // We lose the field metadata from the client. + return new RowColumnReader( + initialBatchSize, + // The physical schema representing variants can be different per file so we must + // infer the read schema from the type from file. 
+ (StructType) ParquetSchemaUtils.toKernelType(typeFromFile), + (GroupType) typeFromFile); } throw new UnsupportedOperationException(typeFromClient + " is not supported"); diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java index e525eed818..38307dde7b 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/parquet/ParquetSchemaUtils.java @@ -20,7 +20,7 @@ import static java.lang.String.format; import org.apache.parquet.schema.*; -import org.apache.parquet.schema.LogicalTypeAnnotation.DecimalLogicalTypeAnnotation; +import org.apache.parquet.schema.LogicalTypeAnnotation.*; import org.apache.parquet.schema.Type.Repetition; import static org.apache.parquet.schema.LogicalTypeAnnotation.TimeUnit.MICROS; import static org.apache.parquet.schema.LogicalTypeAnnotation.decimalType; @@ -157,6 +157,93 @@ public static MessageType toParquetSchema(StructType structType) { return new MessageType("Default Kernel Schema", types); } + /** + * Convert the given Parquet data type to a Kernel data type. + * + * TODO(r.chen): Test this function. + * + * @param type Parquet type object + * @return {@link DataType} representing the Parquet type in Kernel. + */ + public static DataType toKernelType(Type type) { + if (type.isPrimitive()) { + PrimitiveType pt = type.asPrimitiveType(); + + if (pt.getOriginalType() == OriginalType.DECIMAL) { + DecimalLogicalTypeAnnotation dlta = + (DecimalLogicalTypeAnnotation) pt.getLogicalTypeAnnotation(); + return new DecimalType(dlta.getPrecision(), dlta.getScale()); + } else if (pt.getPrimitiveTypeName() == BOOLEAN) { + return BooleanType.BOOLEAN; + } else if (pt.getPrimitiveTypeName() == INT32) { + if (pt.getOriginalType() == OriginalType.INT_8) { + return ByteType.BYTE; + } else if (pt.getOriginalType() == OriginalType.INT_16) { + return ShortType.SHORT; + } else if (pt.getLogicalTypeAnnotation() == LogicalTypeAnnotation.dateType()) { + return DateType.DATE; + } + return IntegerType.INTEGER; + } else if (pt.getPrimitiveTypeName() == INT64) { + if (pt.getOriginalType() == OriginalType.TIMESTAMP_MICROS) { + TimestampLogicalTypeAnnotation tlta = + (TimestampLogicalTypeAnnotation) pt.getLogicalTypeAnnotation(); + return tlta.isAdjustedToUTC() ? 
+ TimestampType.TIMESTAMP : TimestampNTZType.TIMESTAMP_NTZ; + } + return LongType.LONG; + } else if (pt.getPrimitiveTypeName() == FLOAT) { + return FloatType.FLOAT; + } else if (pt.getPrimitiveTypeName() == DOUBLE) { + return DoubleType.DOUBLE; + } else if (pt.getPrimitiveTypeName() == BINARY) { + if (pt.getLogicalTypeAnnotation() == LogicalTypeAnnotation.stringType()) { + return StringType.STRING; + } else { + return BinaryType.BINARY; + } + } else { + throw new UnsupportedOperationException( + "Converting the given Parquet data type to Kernel is not supported: " + type); + } + } else { + if (type.getLogicalTypeAnnotation() == LogicalTypeAnnotation.listType()) { + GroupType gt = (GroupType) type; + Type childType = gt.getType(0); + return new ArrayType( + toKernelType(childType), childType.getRepetition() == OPTIONAL); + } else if (type.getLogicalTypeAnnotation() == LogicalTypeAnnotation.mapType()) { + GroupType gt = (GroupType) type; + Type keyType = gt.getType(0); + Type valueType = gt.getType(1); + return new MapType( + toKernelType(keyType), + toKernelType(valueType), + valueType.getRepetition() == OPTIONAL + ); + } else { + List<StructField> kernelFields = new ArrayList<>(); + GroupType gt = (GroupType) type; + for (Type parquetType : gt.getFields()) { + FieldMetadata.Builder metadataBuilder = FieldMetadata.builder(); + if (parquetType.getId() != null) { + metadataBuilder.putLong( + ColumnMapping.PARQUET_FIELD_ID_KEY, + (long) (parquetType.getId().intValue()) + ); + } + kernelFields.add(new StructField( + parquetType.getName(), + toKernelType(parquetType), + parquetType.getRepetition() == OPTIONAL, + metadataBuilder.build() + )); + } + return new StructType(kernelFields); + } + } + } + private static List<Type> pruneFields( GroupType type, StructType deltaDataType, boolean hasFieldIds) { // prune fields including nested pruning like in pruneSchema diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java index 1474bd2db6..5b76dc8868 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/integration/DataBuilderUtils.java @@ -26,6 +26,7 @@ import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; +import io.delta.kernel.data.VariantValue; import io.delta.kernel.types.StructType; import static io.delta.kernel.internal.util.Preconditions.checkArgument; @@ -165,5 +166,10 @@ public MapValue getMap(int ordinal) { throw new UnsupportedOperationException( "map type unsupported for TestColumnBatchBuilder; use scala test utilities"); } + + @Override + public VariantValue getVariant(int ordinal) { + return (VariantValue) values.get(ordinal); + } } } diff --git a/kernel/kernel-defaults/src/test/scala-spark-3.5/DeltaExcludedBySparkVersionTestMixinShims.scala b/kernel/kernel-defaults/src/test/scala-spark-3.5/DeltaExcludedBySparkVersionTestMixinShims.scala new file mode 100644 index 0000000000..91526dc919 --- /dev/null +++ b/kernel/kernel-defaults/src/test/scala-spark-3.5/DeltaExcludedBySparkVersionTestMixinShims.scala @@ -0,0 +1,45 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.delta.kernel.defaults + +import org.scalatest.funsuite.AnyFunSuite + +import io.delta.kernel.defaults.utils.TestUtils + +trait DeltaExcludedBySparkVersionTestMixinShims extends AnyFunSuite with TestUtils { + /** + * Tests that are meant for Delta compiled against Spark Latest Release only. Executed since this + * is the Spark Latest Release shim. + */ + protected def testSparkLatestOnly( + testName: String, testTags: org.scalatest.Tag*) + (testFun: => Any) + (implicit pos: org.scalactic.source.Position): Unit = { + test(testName, testTags: _*)(testFun)(pos) + } + + /** + * Tests that are meant for Delta compiled against Spark Master Release only. Ignored since this + * is the Spark Latest Release shim. + */ + protected def testSparkMasterOnly( + testName: String, testTags: org.scalatest.Tag*) + (testFun: => Any) + (implicit pos: org.scalactic.source.Position): Unit = { + ignore(testName, testTags: _*)(testFun)(pos) + } +} diff --git a/kernel/kernel-defaults/src/test/scala-spark-3.5/shims/VariantShims.scala b/kernel/kernel-defaults/src/test/scala-spark-3.5/shims/VariantShims.scala new file mode 100644 index 0000000000..88c53cda79 --- /dev/null +++ b/kernel/kernel-defaults/src/test/scala-spark-3.5/shims/VariantShims.scala @@ -0,0 +1,53 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.delta.kernel.defaults + +import io.delta.kernel.defaults.internal.data.value.DefaultVariantValue + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.DataType + +object VariantShims { + + /** + * Spark's variant type is implemented for Spark 4.0 and is not implemented in Spark 3.5. Thus, + * any Spark 3.5 DataType cannot be a variant type. + */ + def isVariantType(dt: DataType): Boolean = false + + /** + * Converts Spark's variant value to Kernel Default's variant value for testing. + * This method should not be called when depending on Spark 3.5 because Spark 3.5 cannot create + * variants. + */ + def convertToKernelVariant(v: Any): DefaultVariantValue = + throw new UnsupportedOperationException("Not supported") + + /** + * Retrieves a Spark variant from a Spark row and converts it to Kernel Default's variant value + * for testing. + * + * Should not be called when testing using Spark 3.5. + */ + def getVariantAndConvertToKernel(r: Row, ordinal: Int): DefaultVariantValue = + throw new UnsupportedOperationException("Not supported") + + /** + * Returns Spark's variant type singleton. This should not be called when testing with Spark 3.5. 
+ */ + def getSparkVariantType(): DataType = throw new UnsupportedOperationException("Not supported") +} diff --git a/kernel/kernel-defaults/src/test/scala-spark-master/DeltaExcludedBySparkVersionTestMixinShims.scala b/kernel/kernel-defaults/src/test/scala-spark-master/DeltaExcludedBySparkVersionTestMixinShims.scala new file mode 100644 index 0000000000..1a1c97c862 --- /dev/null +++ b/kernel/kernel-defaults/src/test/scala-spark-master/DeltaExcludedBySparkVersionTestMixinShims.scala @@ -0,0 +1,47 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.delta.kernel.defaults + +import org.scalatest.funsuite.AnyFunSuite + +import io.delta.kernel.defaults.utils.TestUtils + +trait DeltaExcludedBySparkVersionTestMixinShims extends AnyFunSuite with TestUtils { + + /** + * Tests that are meant for Delta compiled against Spark Latest Release only. Ignored since this + * is the Spark Master shim. + */ + protected def testSparkLatestOnly( + testName: String, testTags: org.scalatest.Tag*) + (testFun: => Any) + (implicit pos: org.scalactic.source.Position): Unit = { + ignore(testName + " (Spark Latest Release Only)", testTags: _*)(testFun)(pos) + } + + /** + * Tests that are meant for Delta compiled against Spark Master (4.0+). Executed since this is the + * Spark Master shim. + */ + protected def testSparkMasterOnly( + testName: String, testTags: org.scalatest.Tag*) + (testFun: => Any) + (implicit pos: org.scalactic.source.Position): Unit = { + test(testName, testTags: _*)(testFun)(pos) + } + +} diff --git a/kernel/kernel-defaults/src/test/scala-spark-master/shims/VariantShims.scala b/kernel/kernel-defaults/src/test/scala-spark-master/shims/VariantShims.scala new file mode 100644 index 0000000000..20d0b2e56d --- /dev/null +++ b/kernel/kernel-defaults/src/test/scala-spark-master/shims/VariantShims.scala @@ -0,0 +1,46 @@ +/* + * Copyright (2024) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.delta.kernel.defaults + +import io.delta.kernel.defaults.internal.data.value.DefaultVariantValue + +import org.apache.spark.sql.Row +import org.apache.spark.sql.types.{DataType, VariantType} +import org.apache.spark.unsafe.types.VariantVal + +object VariantShims { + + /** Spark's variant type is only implemented in Spark 4.0 and above. 
*/ + def isVariantType(dt: DataType): Boolean = dt.isInstanceOf[VariantType] + + /** Converts Spark's variant value to Kernel Default's variant value for testing. */ + def convertToKernelVariant(v: Any): DefaultVariantValue = { + val sparkVariant = v.asInstanceOf[VariantVal] + new DefaultVariantValue(sparkVariant.getValue(), sparkVariant.getMetadata()) + } + + /** + * Retrieves a Spark variant from a Spark row and converts it to Kernel Default's variant value + * for testing. + */ + def getVariantAndConvertToKernel(r: Row, ordinal: Int): DefaultVariantValue = { + val sparkVariant = r.getAs[VariantVal](ordinal) + new DefaultVariantValue(sparkVariant.getValue(), sparkVariant.getMetadata()) + } + + def getSparkVariantType(): DataType = VariantType +} diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala index bc32b5d011..c0c1eb08fd 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableReadsSuite.scala @@ -316,12 +316,12 @@ class DeltaTableReadsSuite extends AnyFunSuite with TestUtils { Seq(TestRow(2), TestRow(2), TestRow(2)), TestRow("2", "2", TestRow(2, 2L)), "2" - ) :: Nil) + ) :: Nil).toSeq checkTable( path = path, expectedAnswer = expectedAnswer, - readCols = readCols + readCols = readCols.toSeq ) } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala index ad0dd8bc7d..c9153524c3 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/DeltaTableWritesSuite.scala @@ -42,7 +42,6 @@ import io.delta.kernel.utils.{CloseableIterable, CloseableIterator} import java.util.Optional import scala.collection.JavaConverters._ -import scala.collection.immutable.Seq class DeltaTableWritesSuite extends DeltaTableWriteSuiteBase with ParquetSuiteBase { val OBJ_MAPPER = new ObjectMapper() @@ -508,7 +507,7 @@ class DeltaTableWritesSuite extends DeltaTableWriteSuiteBase with ParquetSuiteBa val parquetAllTypes = goldenTablePath("parquet-all-types") val schema = removeUnsupportedTypes(tableSchema(parquetAllTypes)) - val data = readTableUsingKernel(engine, parquetAllTypes, schema).to[Seq] + val data = readTableUsingKernel(engine, parquetAllTypes, schema) val dataWithPartInfo = Seq(Map.empty[String, Literal] -> data) appendData(engine, tblPath, isNewTable = true, schema, Seq.empty, dataWithPartInfo) @@ -551,7 +550,7 @@ class DeltaTableWritesSuite extends DeltaTableWriteSuiteBase with ParquetSuiteBa "timestampType" ) val casePreservingPartCols = - casePreservingPartitionColNames(schema, partCols.asJava).asScala.to[Seq] + casePreservingPartitionColNames(schema, partCols.asJava).asScala.toSeq // get the partition values from the data batch at the given rowId def getPartitionValues(batch: ColumnarBatch, rowId: Int): Map[String, Literal] = { @@ -584,7 +583,7 @@ class DeltaTableWritesSuite extends DeltaTableWriteSuiteBase with ParquetSuiteBa }.toMap } - val data = readTableUsingKernel(engine, parquetAllTypes, schema).to[Seq] + val data = readTableUsingKernel(engine, parquetAllTypes, schema) // From the above table read data, convert each row as a new batch with partition info // Take the values of 
the partitionCols from the data and create a new batch with the diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayEngineMetricsSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayEngineMetricsSuite.scala index 0f54b810bf..670e56dd71 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayEngineMetricsSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/LogReplayEngineMetricsSuite.scala @@ -394,7 +394,7 @@ trait FileReadMetrics { self: Object => } } - def getVersionsRead: Seq[Long] = versionsRead + def getVersionsRead: Seq[Long] = versionsRead.toSeq def getLastCheckpointMetadataReadCalls: Int = lastCheckpointMetadataReadCalls diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala index 79119d05ea..65dd86c113 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/ScanSuite.scala @@ -21,18 +21,21 @@ import java.time.{Instant, OffsetDateTime} import java.time.temporal.ChronoUnit import java.util.Optional +import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ import io.delta.golden.GoldenTableUtils.goldenTablePath import org.apache.hadoop.conf.Configuration -import org.apache.spark.sql.{Row => SparkRow} +import org.apache.spark.sql.{DataFrame, Row => SparkRow} import org.apache.spark.sql.catalyst.plans.SQLHelper import org.apache.spark.sql.delta.{DeltaConfigs, DeltaLog} import org.apache.spark.sql.types.{IntegerType => SparkIntegerType, StructField => SparkStructField, StructType => SparkStructType} import org.scalatest.funsuite.AnyFunSuite -import io.delta.kernel.engine.{JsonHandler, ParquetHandler, Engine} +import io.delta.kernel.Scan import io.delta.kernel.data.{ColumnarBatch, ColumnVector, FilteredColumnarBatch, Row} +import io.delta.kernel.defaults.utils.TestRow +import io.delta.kernel.engine.{Engine, JsonHandler, ParquetHandler} import io.delta.kernel.expressions.{AlwaysFalse, AlwaysTrue, And, Column, Or, Predicate, ScalarExpression} import io.delta.kernel.expressions.Literal._ import io.delta.kernel.types.StructType @@ -40,12 +43,18 @@ import io.delta.kernel.types.StringType.STRING import io.delta.kernel.types.IntegerType.INTEGER import io.delta.kernel.utils.{CloseableIterator, FileStatus} import io.delta.kernel.{Scan, Snapshot, Table} -import io.delta.kernel.internal.util.InternalUtils -import io.delta.kernel.internal.{InternalScanFileUtils, ScanImpl} import io.delta.kernel.defaults.engine.{DefaultJsonHandler, DefaultParquetHandler, DefaultEngine} +import io.delta.kernel.internal.{InternalScanFileUtils, ScanImpl} +import io.delta.kernel.internal.data.ScanStateRow +import io.delta.kernel.internal.util.InternalUtils +import io.delta.kernel.internal.util.Utils.singletonCloseableIterator +import io.delta.kernel.internal.util.Utils.toCloseableIterator import io.delta.kernel.defaults.utils.{ExpressionTestUtils, TestUtils} -class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with SQLHelper { +class ScanSuite extends AnyFunSuite + with ExpressionTestUtils + with SQLHelper + with DeltaExcludedBySparkVersionTestMixinShims { import io.delta.kernel.defaults.ScanSuite._ @@ -1588,6 +1597,98 @@ class ScanSuite extends AnyFunSuite with TestUtils with ExpressionTestUtils with ) } } + + private def 
testReadWithVariant(testName: String)(df: => DataFrame): Unit = { + testSparkMasterOnly(testName) { + withTable("test_table") { + df.write + .format("delta") + .mode("overwrite") + .saveAsTable("test_table") + val path = spark.sql("describe table extended `test_table`") + .where("col_name = 'Location'") + .collect()(0) + .getString(1) + .replace("file:", "") + + val kernelSchema = tableSchema(path) + + val snapshot = latestSnapshot(path) + val scan = snapshot.getScanBuilder(defaultEngine).build() + val scanState = scan.getScanState(defaultEngine) + val physicalReadSchema = + ScanStateRow.getPhysicalDataReadSchema(defaultEngine, scanState) + val scanFilesIter = scan.getScanFiles(defaultEngine) + + val readRows = ArrayBuffer[Row]() + while (scanFilesIter.hasNext()) { + val scanFilesBatch = scanFilesIter.next() + val scanFileRows = scanFilesBatch.getRows() + while (scanFileRows.hasNext()) { + val scanFileRow = scanFileRows.next() + val fileStatus = InternalScanFileUtils.getAddFileStatus(scanFileRow) + + val physicalDataIter = defaultEngine.getParquetHandler.readParquetFiles( + singletonCloseableIterator(fileStatus), + physicalReadSchema, + Optional.empty()) + + val transformedRowsIter = Scan.transformPhysicalData( + defaultEngine, + scanState, + scanFileRow, + physicalDataIter + ) + + val transformedRows = transformedRowsIter.asScala.toSeq.map(_.getRows).flatMap(_.toSeq) + readRows.appendAll(transformedRows) + } + } + + checkAnswer(readRows.toSeq, df.collect().map(TestRow(_))) + } + } + } + + testReadWithVariant("basic variant") { + spark.range(0, 1, 1, 1).selectExpr( + "parse_json(cast(id as string)) as basic_v", + "named_struct('v', parse_json(cast(id as string))) as struct_v", + "named_struct('v', array(parse_json(cast(id as string)))) as struct_array_v", + "named_struct('v', map('key', parse_json(cast(id as string)))) as struct_map_v", + "named_struct('top', named_struct('v', parse_json(cast(id as string)))) as struct_struct_v", + """array( + parse_json(cast(id as string)), + parse_json(cast(id as string)), + parse_json(cast(id as string)) + ) as array_v""", + """array( + named_struct('v', parse_json(cast(id as string))), + named_struct('v', parse_json(cast(id as string))), + named_struct('v', parse_json(cast(id as string))) + ) as array_struct_v""", + """array( + map('v', parse_json(cast(id as string))), + map('k1', parse_json(cast(id as string)), 'k2', parse_json(cast(id as string))), + map('v', parse_json(cast(id as string))) + ) as array_map_v""", + "map('test', parse_json(cast(id as string))) as map_value_v", + "map('test', named_struct('v', parse_json(cast(id as string)))) as map_struct_v" + ) + } + + testReadWithVariant("basic null variant") { + spark.range(0, 10, 1, 1).selectExpr( + "cast(null as variant) basic_v", + "named_struct('v', cast(null as variant)) as struct_v", + """array( + parse_json(cast(id as string)), + parse_json(cast(id as string)), + null + ) as array_v""", + "map('test', cast(null as variant)) as map_value_v" + ) + } } object ScanSuite { diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala index 661d286a3c..578b92b52a 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestRow.scala @@ -16,9 +16,11 @@ package io.delta.kernel.defaults.utils import scala.collection.JavaConverters._ +import scala.collection.mutable.{Seq 
=> MutableSeq} import org.apache.spark.sql.{types => sparktypes} import org.apache.spark.sql.{Row => SparkRow} import io.delta.kernel.data.{ArrayValue, ColumnVector, MapValue, Row} +import io.delta.kernel.defaults.VariantShims import io.delta.kernel.types._ import java.sql.Timestamp @@ -44,7 +46,7 @@ import java.time.{Instant, LocalDate, LocalDateTime, ZoneOffset} * - ArrayType --> Seq[Any] * - MapType --> Map[Any, Any] * - StructType --> TestRow - * + * - VariantType --> VariantVal * For complex types array and map, the inner elements types should align with this mapping. */ class TestRow(val values: Array[Any]) { @@ -108,9 +110,10 @@ object TestRow { case _: ArrayType => arrayValueToScalaSeq(row.getArray(i)) case _: MapType => mapValueToScalaMap(row.getMap(i)) case _: StructType => TestRow(row.getStruct(i)) + case _: VariantType => row.getVariant(i) case _ => throw new UnsupportedOperationException("unrecognized data type") } - }) + }.toSeq) } def apply(row: SparkRow): TestRow = { @@ -133,13 +136,14 @@ object TestRow { case _: sparktypes.BinaryType => obj.asInstanceOf[Array[Byte]] case _: sparktypes.DecimalType => obj.asInstanceOf[java.math.BigDecimal] case arrayType: sparktypes.ArrayType => - obj.asInstanceOf[Seq[Any]] + obj.asInstanceOf[MutableSeq[Any]] .map(decodeCellValue(arrayType.elementType, _)) case mapType: sparktypes.MapType => obj.asInstanceOf[Map[Any, Any]].map { case (k, v) => decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v) } case _: sparktypes.StructType => TestRow(obj.asInstanceOf[SparkRow]) + case t if VariantShims.isVariantType(t) => VariantShims.convertToKernelVariant(obj) case _ => throw new UnsupportedOperationException("unrecognized data type") } } @@ -173,6 +177,7 @@ object TestRow { decodeCellValue(mapType.keyType, k) -> decodeCellValue(mapType.valueType, v) } case _: sparktypes.StructType => TestRow(row.getStruct(i)) + case t if VariantShims.isVariantType(t) => VariantShims.getVariantAndConvertToKernel(row, i) case _ => throw new UnsupportedOperationException("unrecognized data type") } }) @@ -204,6 +209,7 @@ object TestRow { TestRow.fromSeq(Seq.range(0, dataType.length()).map { ordinal => getAsTestObject(vector.getChild(ordinal), rowId) }) + case _: VariantType => vector.getVariant(rowId) case _ => throw new UnsupportedOperationException("unrecognized data type") } } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index a97aee85a7..35a31e1618 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import io.delta.golden.GoldenTableUtils import io.delta.kernel.{Scan, Snapshot, Table} import io.delta.kernel.data.{ColumnVector, ColumnarBatch, FilteredColumnarBatch, MapValue, Row} +import io.delta.kernel.defaults.VariantShims import io.delta.kernel.defaults.engine.DefaultEngine import io.delta.kernel.defaults.internal.data.vector.DefaultGenericVector import io.delta.kernel.engine.Engine @@ -71,7 +72,7 @@ trait TestUtils extends Assertions with SQLHelper { while (iter.hasNext) { result.append(iter.next()) } - result + result.toSeq } finally { iter.close() } @@ -157,7 +158,7 @@ trait TestUtils extends Assertions with SQLHelper { // for all primitive types Seq(new Column((basePath :+ 
field.getName).asJava.toArray(new Array[String](0)))); case _ => Seq.empty - } + }.toSeq } def collectScanFileRows(scan: Scan, engine: Engine = defaultEngine): Seq[Row] = { @@ -235,7 +236,7 @@ trait TestUtils extends Assertions with SQLHelper { } } } - result + result.toSeq } def readTableUsingKernel( @@ -684,7 +685,8 @@ trait TestUtils extends Assertions with SQLHelper { toSparkType(field.getDataType), field.isNullable ) - }) + }.toSeq) + case VariantType.VARIANT => VariantShims.getSparkVariantType() } }
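
Taken together, these shims let a single copy of the test sources compile against both Spark builds: the build adds scala-spark-3.5 or scala-spark-master as an extra source directory, and the mixin decides per build whether a test runs or is ignored. Below is a minimal sketch of the intended usage, not part of this diff; the suite name is hypothetical, and the spark session is assumed to come from TestUtils (which the shim trait already mixes in), as in the suites above.

package io.delta.kernel.defaults

import org.apache.spark.sql.types.IntegerType
import io.delta.kernel.defaults.utils.TestRow

// Hypothetical suite: the shim trait already extends AnyFunSuite and TestUtils.
class VariantShimUsageSuite extends DeltaExcludedBySparkVersionTestMixinShims {

  // Runs when compiled against Spark Master (4.0+); the scala-spark-3.5 shim
  // turns it into an ignored test, so parse_json never executes on a 3.5 classpath.
  testSparkMasterOnly("variant values convert through the shims") {
    val df = spark.range(3).selectExpr("parse_json(cast(id as string)) as v")
    assert(VariantShims.isVariantType(df.schema("v").dataType))
    // TestRow(SparkRow) dispatches variant cells to
    // VariantShims.getVariantAndConvertToKernel.
    assert(df.collect().map(TestRow(_)).length == 3)
  }

  // The mirror image: executed on the Spark Latest Release build only.
  testSparkLatestOnly("non-variant types are not classified as variants") {
    assert(!VariantShims.isVariantType(IntegerType))
  }
}

Putting the divergence behind compile-time source directories, rather than runtime version checks, keeps each shim trivial and lets each build compile against only the Spark API it actually ships with.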