diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
index f7da5564809..4fff4031228 100644
--- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
+++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
@@ -19,6 +19,7 @@
 import java.util.Map;
 import java.util.stream.Collectors;
 import org.apache.hadoop.fs.Path;
+import org.apache.iceberg.exceptions.ValidationException;
 import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions;
 import org.apache.spark.SparkConf;
 import org.apache.spark.sql.Column;
@@ -48,6 +49,9 @@ public abstract class SparkIcebergCatalogIT extends SparkCommonIT {
   private static final String ICEBERG_DELETE_MODE = "write.delete.mode";
   private static final String ICEBERG_UPDATE_MODE = "write.update.mode";
   private static final String ICEBERG_MERGE_MODE = "write.merge.mode";
+  private static final String ICEBERG_WRITE_DISTRIBUTION_MODE = "write.distribution-mode";
+  private static final String ICEBERG_SORT_ORDER = "sort-order";
+  private static final String ICEBERG_IDENTIFIER_FIELDS = "identifier-fields";
 
   @Override
   protected String getCatalogName() {
@@ -251,6 +255,15 @@ void testIcebergTableRowLevelOperations() {
     testIcebergMergeIntoUpdateOperation();
   }
 
+  @Test
+  void testIcebergSQLExtensions() {
+    testIcebergPartitionFieldOperations();
+    testIcebergBranchOperations();
+    testIcebergTagOperations();
+    testIcebergIdentifierOperations();
+    testIcebergDistributionAndOrderingOperations();
+  }
+
   private void testMetadataColumns() {
     String tableName = "test_metadata_columns";
     dropTableIfExists(tableName);
@@ -503,6 +516,260 @@ private void testIcebergMergeIntoUpdateOperation() {
         });
   }
 
+  private void testIcebergPartitionFieldOperations() {
+    List<String> partitionFields =
+        Arrays.asList("name", "truncate(1, name)", "bucket(16, id)", "days(ts)");
+    String partitionExpression = "name=a/name_trunc=a/id_bucket=4/ts_day=2024-01-01";
+    String tableName = "test_iceberg_partition_field_operations";
+    dropTableIfExists(tableName);
+    sql(getCreateIcebergSimpleTableString(tableName));
+
+    // add partition fields
+    SparkTableInfo tableInfo = getTableInfo(tableName);
+    SparkTableInfoChecker checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(getIcebergSimpleTableColumn())
+            .withMetadataColumns(getIcebergMetadataColumns());
+    checker.check(tableInfo);
+
+    partitionFields.forEach(
+        partitionField ->
+            sql(String.format("ALTER TABLE %s ADD PARTITION FIELD %s", tableName, partitionField)));
+
+    tableInfo = getTableInfo(tableName);
+    checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(getIcebergSimpleTableColumn())
+            .withMetadataColumns(getIcebergMetadataColumns())
+            .withIdentifyPartition(Collections.singletonList("name"))
+            .withTruncatePartition(1, "name")
+            .withBucketPartition(16, Collections.singletonList("id"))
+            .withDayPartition("ts");
+    checker.check(tableInfo);
+
+    sql(
+        String.format(
+            "INSERT INTO %s VALUES(2,'a',cast('2024-01-01 12:00:00' as timestamp));", tableName));
+    List<String> queryResult = getTableData(tableName);
+    Assertions.assertEquals(1, queryResult.size());
+    Assertions.assertEquals("2,a,2024-01-01 12:00:00", queryResult.get(0));
+    Path partitionPath = new Path(getTableLocation(tableInfo), partitionExpression);
+    checkDirExists(partitionPath);
+
+    // replace partition fields
+    sql(String.format("ALTER TABLE %s REPLACE PARTITION FIELD ts_day WITH months(ts)", tableName));
+    tableInfo = getTableInfo(tableName);
+    checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(getIcebergSimpleTableColumn())
+            .withMetadataColumns(getIcebergMetadataColumns())
+            .withIdentifyPartition(Collections.singletonList("name"))
+            .withTruncatePartition(1, "name")
+            .withBucketPartition(16, Collections.singletonList("id"))
+            .withMonthPartition("ts");
+    checker.check(tableInfo);
+
+    // drop partition fields
+    sql(String.format("ALTER TABLE %s DROP PARTITION FIELD months(ts)", tableName));
+    tableInfo = getTableInfo(tableName);
+    checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(getIcebergSimpleTableColumn())
+            .withMetadataColumns(getIcebergMetadataColumns())
+            .withIdentifyPartition(Collections.singletonList("name"))
+            .withTruncatePartition(1, "name")
+            .withBucketPartition(16, Collections.singletonList("id"));
+    checker.check(tableInfo);
+  }
+
+  private void testIcebergBranchOperations() {
+    String tableName = "test_iceberg_branch_operations";
+    String fullTableName =
+        String.format("%s.%s.%s", getCatalogName(), getDefaultDatabase(), tableName);
+    String branch1 = "branch1";
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    // create branch and query data using branch
+    sql(String.format("INSERT INTO %s VALUES(1, '1', 1);", tableName));
+    List<String> tableData = getTableData(tableName);
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+    List<Row> snapshots =
+        getSparkSession()
+            .sql(String.format("SELECT snapshot_id FROM %s.snapshots", fullTableName))
+            .collectAsList();
+    Assertions.assertEquals(1, snapshots.size());
+    long snapshotId = snapshots.get(0).getLong(0);
+
+    sql(String.format("ALTER TABLE %s CREATE BRANCH IF NOT EXISTS `%s`", tableName, branch1));
+    sql(String.format("INSERT INTO %s VALUES(2, '2', 2);", tableName));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(2, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
+
+    tableData =
+        getQueryData(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, branch1));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+
+    sql(String.format("ALTER TABLE %s CREATE OR REPLACE BRANCH `%s`", tableName, branch1));
+    tableData =
+        getQueryData(
+            String.format("SELECT * FROM %s VERSION AS OF '%s' ORDER BY id", tableName, branch1));
+    Assertions.assertEquals(2, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
+
+    // replace branch
+    sql(
+        String.format(
+            "ALTER TABLE %s REPLACE BRANCH `%s` AS OF VERSION %d RETAIN 1 DAYS",
+            tableName, branch1, snapshotId));
+    tableData =
+        getQueryData(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, branch1));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+
+    // drop branch
+    sql(String.format("ALTER TABLE %s DROP BRANCH `%s`", tableName, branch1));
+    Assertions.assertThrows(
+        ValidationException.class,
+        () -> sql(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, branch1)));
+  }
+
+  private void testIcebergTagOperations() {
+    String tableName = "test_iceberg_tag_operations";
+    String fullTableName =
+        String.format("%s.%s.%s", getCatalogName(), getDefaultDatabase(), tableName);
+    String tag1 = "tag1";
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    // create tag and query data using tag
+    sql(String.format("INSERT INTO %s VALUES(1, '1', 1);", tableName));
+    List<String> tableData = getTableData(tableName);
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+    List<Row> snapshots =
+        getSparkSession()
+            .sql(String.format("SELECT snapshot_id FROM %s.snapshots", fullTableName))
+            .collectAsList();
+    Assertions.assertEquals(1, snapshots.size());
+    long snapshotId = snapshots.get(0).getLong(0);
+
+    sql(String.format("ALTER TABLE %s CREATE TAG IF NOT EXISTS `%s`", tableName, tag1));
+    sql(String.format("INSERT INTO %s VALUES(2, '2', 2);", tableName));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(2, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
+
+    tableData =
+        getQueryData(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, tag1));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+
+    sql(String.format("ALTER TABLE %s CREATE OR REPLACE TAG `%s`", tableName, tag1));
+    tableData =
+        getQueryData(
+            String.format("SELECT * FROM %s VERSION AS OF '%s' ORDER BY id", tableName, tag1));
+    Assertions.assertEquals(2, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
+
+    // replace tag
+    sql(
+        String.format(
+            "ALTER TABLE %s REPLACE TAG `%s` AS OF VERSION %d RETAIN 1 DAYS",
+            tableName, tag1, snapshotId));
+    tableData =
+        getQueryData(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, tag1));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+
+    // drop tag
+    sql(String.format("ALTER TABLE %s DROP TAG `%s`", tableName, tag1));
+    Assertions.assertThrows(
+        ValidationException.class,
+        () -> sql(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, tag1)));
+  }
+
+  private void testIcebergIdentifierOperations() {
+    String tableName = "test_iceberg_identifier_operations";
+    // Iceberg identifier fields must be non-null, so id and name are created as NOT NULL.
+    List<SparkTableInfo.SparkColumnInfo> columnInfos =
+        Arrays.asList(
+            SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment", false),
+            SparkTableInfo.SparkColumnInfo.of("name", DataTypes.StringType, "", false),
+            SparkTableInfo.SparkColumnInfo.of("age", DataTypes.IntegerType, null, true));
+    dropTableIfExists(tableName);
+    sql(
+        String.format(
+            "CREATE TABLE %s (id INT COMMENT 'id comment' NOT NULL, name STRING COMMENT '' NOT NULL, age INT)",
+            tableName));
+    SparkTableInfo tableInfo = getTableInfo(tableName);
+    SparkTableInfoChecker checker =
+        SparkTableInfoChecker.create()
+            .withName(tableName)
+            .withColumns(columnInfos)
+            .withMetadataColumns(getIcebergMetadataColumns());
+    checker.check(tableInfo);
+    Map<String, String> tableProperties = tableInfo.getTableProperties();
+    Assertions.assertNull(tableProperties.get(ICEBERG_IDENTIFIER_FIELDS));
+
+    // add identifier fields
+    sql(String.format("ALTER TABLE %s SET IDENTIFIER FIELDS id, name", tableName));
+    tableInfo = getTableInfo(tableName);
+    tableProperties = tableInfo.getTableProperties();
+    Assertions.assertEquals("[name,id]", tableProperties.get(ICEBERG_IDENTIFIER_FIELDS));
+
+    // drop identifier fields
+    sql(String.format("ALTER TABLE %s DROP IDENTIFIER FIELDS id, name", tableName));
+    tableInfo = getTableInfo(tableName);
+    tableProperties = tableInfo.getTableProperties();
+    Assertions.assertNull(tableProperties.get(ICEBERG_IDENTIFIER_FIELDS));
+  }
+
+  private void testIcebergDistributionAndOrderingOperations() {
+    String tableName = "test_iceberg_distribution_and_ordering_operations";
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    SparkTableInfo tableInfo = getTableInfo(tableName);
+    Map<String, String> tableProperties = tableInfo.getTableProperties();
+    Assertions.assertNull(tableProperties.get(ICEBERG_WRITE_DISTRIBUTION_MODE));
+    Assertions.assertNull(tableProperties.get(ICEBERG_SORT_ORDER));
+
+    // set global ordering
+    sql(String.format("ALTER TABLE %s WRITE ORDERED BY id DESC", tableName));
+    tableInfo = getTableInfo(tableName);
+    tableProperties = tableInfo.getTableProperties();
+    Assertions.assertEquals("range", tableProperties.get(ICEBERG_WRITE_DISTRIBUTION_MODE));
+    Assertions.assertEquals("id DESC NULLS LAST", tableProperties.get(ICEBERG_SORT_ORDER));
+
+    // set local ordering
+    sql(String.format("ALTER TABLE %s WRITE LOCALLY ORDERED BY id DESC", tableName));
+    tableInfo = getTableInfo(tableName);
+    tableProperties = tableInfo.getTableProperties();
+    Assertions.assertEquals("none", tableProperties.get(ICEBERG_WRITE_DISTRIBUTION_MODE));
+    Assertions.assertEquals("id DESC NULLS LAST", tableProperties.get(ICEBERG_SORT_ORDER));
+
+    // set distribution
+    sql(String.format("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName));
+    tableInfo = getTableInfo(tableName);
+    tableProperties = tableInfo.getTableProperties();
+    Assertions.assertEquals("hash", tableProperties.get(ICEBERG_WRITE_DISTRIBUTION_MODE));
+    Assertions.assertNull(tableProperties.get(ICEBERG_SORT_ORDER));
+
+    // set distribution with local ordering
+    sql(
+        String.format(
+            "ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY id desc", tableName));
+    tableInfo = getTableInfo(tableName);
+    tableProperties = tableInfo.getTableProperties();
+    Assertions.assertEquals("hash", tableProperties.get(ICEBERG_WRITE_DISTRIBUTION_MODE));
+    Assertions.assertEquals("id DESC NULLS LAST", tableProperties.get(ICEBERG_SORT_ORDER));
+  }
+
   private List<SparkTableInfo.SparkColumnInfo> getIcebergSimpleTableColumn() {
     return Arrays.asList(
         SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
diff --git a/spark-connector/spark-connector/build.gradle.kts b/spark-connector/spark-connector/build.gradle.kts
index 1cd7a7f2fb0..9461d13fe1a 100644
--- a/spark-connector/spark-connector/build.gradle.kts
+++ b/spark-connector/spark-connector/build.gradle.kts
@@ -19,6 +19,7 @@ val sparkMajorVersion: String = sparkVersion.substringBeforeLast(".")
 val icebergVersion: String = libs.versions.iceberg.get()
 val kyuubiVersion: String = libs.versions.kyuubi.get()
 val scalaJava8CompatVersion: String = libs.versions.scala.java.compat.get()
+val scalaCollectionCompatVersion: String = libs.versions.scala.collection.compat.get()
 
 dependencies {
   implementation(project(":catalogs:bundled-catalog", configuration = "shadow"))
@@ -30,6 +31,7 @@ dependencies {
   compileOnly("org.apache.spark:spark-catalyst_$scalaVersion:$sparkVersion")
   compileOnly("org.apache.spark:spark-sql_$scalaVersion:$sparkVersion")
   compileOnly("org.scala-lang.modules:scala-java8-compat_$scalaVersion:$scalaJava8CompatVersion")
+  compileOnly("org.scala-lang.modules:scala-collection-compat_$scalaVersion:$scalaCollectionCompatVersion")
 
   annotationProcessor(libs.lombok)
   compileOnly(libs.lombok)
diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/extensions/GravitinoIcebergSparkSessionExtensions.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/extensions/GravitinoIcebergSparkSessionExtensions.java
new file mode 100644
index 00000000000..3da60c979be
--- /dev/null
+++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/extensions/GravitinoIcebergSparkSessionExtensions.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright 2024 Datastrato Pvt Ltd.
+ * This software is licensed under the Apache License version 2.
+ */
+package com.datastrato.gravitino.spark.connector.iceberg.extensions;
+
+import org.apache.spark.sql.SparkSessionExtensions;
+import scala.Function1;
+
+public class GravitinoIcebergSparkSessionExtensions
+    implements Function1<SparkSessionExtensions, Void> {
+
+  @Override
+  public Void apply(SparkSessionExtensions extensions) {
+
+    // planner extensions
+    extensions.injectPlannerStrategy(IcebergExtendedDataSourceV2Strategy::new);
+
+    // There must be a return value, and Void only supports returning null, not other types.
+    return null;
+  }
+}
diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/extensions/IcebergExtendedDataSourceV2Strategy.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/extensions/IcebergExtendedDataSourceV2Strategy.java
new file mode 100644
index 00000000000..7678ec0a237
--- /dev/null
+++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/extensions/IcebergExtendedDataSourceV2Strategy.java
@@ -0,0 +1,206 @@
+/*
+ * Copyright 2024 Datastrato Pvt Ltd.
+ * This software is licensed under the Apache License version 2.
+ */
+package com.datastrato.gravitino.spark.connector.iceberg.extensions;
+
+import com.datastrato.gravitino.spark.connector.iceberg.GravitinoIcebergCatalog;
+import org.apache.iceberg.spark.Spark3Util;
+import org.apache.spark.sql.SparkSession;
+import org.apache.spark.sql.catalyst.plans.logical.AddPartitionField;
+import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceBranch;
+import org.apache.spark.sql.catalyst.plans.logical.CreateOrReplaceTag;
+import org.apache.spark.sql.catalyst.plans.logical.DropBranch;
+import org.apache.spark.sql.catalyst.plans.logical.DropIdentifierFields;
+import org.apache.spark.sql.catalyst.plans.logical.DropPartitionField;
+import org.apache.spark.sql.catalyst.plans.logical.DropTag;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.catalyst.plans.logical.ReplacePartitionField;
+import org.apache.spark.sql.catalyst.plans.logical.SetIdentifierFields;
+import org.apache.spark.sql.catalyst.plans.logical.SetWriteDistributionAndOrdering;
+import org.apache.spark.sql.connector.catalog.CatalogPlugin;
+import org.apache.spark.sql.connector.catalog.Identifier;
+import org.apache.spark.sql.connector.catalog.TableCatalog;
+import org.apache.spark.sql.execution.SparkPlan;
+import org.apache.spark.sql.execution.datasources.v2.AddPartitionFieldExec;
+import org.apache.spark.sql.execution.datasources.v2.CreateOrReplaceBranchExec;
+import org.apache.spark.sql.execution.datasources.v2.CreateOrReplaceTagExec;
+import org.apache.spark.sql.execution.datasources.v2.DropBranchExec;
+import org.apache.spark.sql.execution.datasources.v2.DropIdentifierFieldsExec;
+import org.apache.spark.sql.execution.datasources.v2.DropPartitionFieldExec;
+import org.apache.spark.sql.execution.datasources.v2.DropTagExec;
+import org.apache.spark.sql.execution.datasources.v2.ExtendedDataSourceV2Strategy;
+import org.apache.spark.sql.execution.datasources.v2.ReplacePartitionFieldExec;
+import org.apache.spark.sql.execution.datasources.v2.SetIdentifierFieldsExec;
+import org.apache.spark.sql.execution.datasources.v2.SetWriteDistributionAndOrderingExec;
+import scala.Option;
+import scala.Some;
+import scala.Tuple2;
+import scala.collection.Seq;
+
+public class IcebergExtendedDataSourceV2Strategy extends ExtendedDataSourceV2Strategy {
+
+  private final SparkSession spark;
+
+  public IcebergExtendedDataSourceV2Strategy(SparkSession spark) {
+    super(spark);
+    this.spark = spark;
+  }
+
+  @Override
+  public Seq<SparkPlan> apply(LogicalPlan plan) {
+    if (plan instanceof AddPartitionField) {
+      AddPartitionField addPartitionField = (AddPartitionField) plan;
+      return IcebergCatalogAndIdentifier.buildCatalogAndIdentifier(spark, addPartitionField.table())
+          .map(
+              catalogAndIdentifier -> {
+                TableCatalog catalog = catalogAndIdentifier._1();
+                Identifier identifier = catalogAndIdentifier._2();
+                return new AddPartitionFieldExec(
+                    catalog, identifier, addPartitionField.transform(), addPartitionField.name());
+              })
+          .getOrElse(() -> super.apply(plan));
+    } else if (plan instanceof CreateOrReplaceBranch) {
+      CreateOrReplaceBranch createOrReplaceBranch = (CreateOrReplaceBranch) plan;
+      return IcebergCatalogAndIdentifier.buildCatalogAndIdentifier(
+              spark, createOrReplaceBranch.table())
+          .map(
+              catalogAndIdentifier -> {
+                TableCatalog catalog = catalogAndIdentifier._1();
+                Identifier identifier = catalogAndIdentifier._2();
+                return new CreateOrReplaceBranchExec(
+                    catalog,
+                    identifier,
+                    createOrReplaceBranch.branch(),
+                    createOrReplaceBranch.branchOptions(),
+                    createOrReplaceBranch.replace(),
+                    createOrReplaceBranch.ifNotExists());
+              })
+          .getOrElse(() -> super.apply(plan));
+    } else if (plan instanceof CreateOrReplaceTag) {
+      CreateOrReplaceTag createOrReplaceTag = (CreateOrReplaceTag) plan;
+      return IcebergCatalogAndIdentifier.buildCatalogAndIdentifier(
+              spark, createOrReplaceTag.table())
+          .map(
+              catalogAndIdentifier -> {
+                TableCatalog catalog = catalogAndIdentifier._1();
+                Identifier identifier = catalogAndIdentifier._2();
+                return new CreateOrReplaceTagExec(
+                    catalog,
+                    identifier,
+                    createOrReplaceTag.tag(),
+                    createOrReplaceTag.tagOptions(),
+                    createOrReplaceTag.replace(),
+                    createOrReplaceTag.ifNotExists());
+              })
+          .getOrElse(() -> super.apply(plan));
+    } else if (plan instanceof DropBranch) {
+      DropBranch dropBranch = (DropBranch) plan;
+      return IcebergCatalogAndIdentifier.buildCatalogAndIdentifier(spark, dropBranch.table())
+          .map(
+              catalogAndIdentifier -> {
+                TableCatalog catalog = catalogAndIdentifier._1();
+                Identifier identifier = catalogAndIdentifier._2();
+                return new DropBranchExec(
+                    catalog, identifier, dropBranch.branch(), dropBranch.ifExists());
+              })
+          .getOrElse(() -> super.apply(plan));
+    } else if (plan instanceof DropTag) {
+      DropTag dropTag = (DropTag) plan;
+      return IcebergCatalogAndIdentifier.buildCatalogAndIdentifier(spark, dropTag.table())
+          .map(
+              catalogAndIdentifier -> {
+                TableCatalog catalog = catalogAndIdentifier._1();
+                Identifier identifier = catalogAndIdentifier._2();
+                return new DropTagExec(catalog, identifier, dropTag.tag(), dropTag.ifExists());
+              })
+          .getOrElse(() -> super.apply(plan));
+    } else if (plan instanceof DropPartitionField) {
+      DropPartitionField dropPartitionField = (DropPartitionField) plan;
+      return IcebergCatalogAndIdentifier.buildCatalogAndIdentifier(
+              spark, dropPartitionField.table())
+          .map(
+              catalogAndIdentifier -> {
+                TableCatalog catalog = catalogAndIdentifier._1();
+                Identifier identifier = catalogAndIdentifier._2();
+                return new DropPartitionFieldExec(
+                    catalog, identifier, dropPartitionField.transform());
+              })
+          .getOrElse(() -> super.apply(plan));
+    } else if (plan instanceof ReplacePartitionField) {
+      ReplacePartitionField replacePartitionField = (ReplacePartitionField) plan;
+      return IcebergCatalogAndIdentifier.buildCatalogAndIdentifier(
+              spark, replacePartitionField.table())
+          .map(
+              catalogAndIdentifier -> {
+                TableCatalog catalog = catalogAndIdentifier._1();
+                Identifier identifier = catalogAndIdentifier._2();
+                return new ReplacePartitionFieldExec(
+                    catalog,
+                    identifier,
+                    replacePartitionField.transformFrom(),
+                    replacePartitionField.transformTo(),
+                    replacePartitionField.name());
+              })
+          .getOrElse(() -> super.apply(plan));
+    } else if (plan instanceof SetIdentifierFields) {
+      SetIdentifierFields setIdentifierFields = (SetIdentifierFields) plan;
+      return IcebergCatalogAndIdentifier.buildCatalogAndIdentifier(
+              spark, setIdentifierFields.table())
+          .map(
+              catalogAndIdentifier -> {
+                TableCatalog catalog = catalogAndIdentifier._1();
+                Identifier identifier = catalogAndIdentifier._2();
+                return new SetIdentifierFieldsExec(
+                    catalog, identifier, setIdentifierFields.fields());
+              })
+          .getOrElse(() -> super.apply(plan));
+    } else if (plan instanceof DropIdentifierFields) {
+      DropIdentifierFields dropIdentifierFields = (DropIdentifierFields) plan;
+      return IcebergCatalogAndIdentifier.buildCatalogAndIdentifier(
+              spark, dropIdentifierFields.table())
+          .map(
+              catalogAndIdentifier -> {
+                TableCatalog catalog = catalogAndIdentifier._1();
+                Identifier identifier = catalogAndIdentifier._2();
+                return new DropIdentifierFieldsExec(
+                    catalog, identifier, dropIdentifierFields.fields());
+              })
+          .getOrElse(() -> super.apply(plan));
+    } else if (plan instanceof SetWriteDistributionAndOrdering) {
+      SetWriteDistributionAndOrdering setWriteDistributionAndOrdering =
+          (SetWriteDistributionAndOrdering) plan;
+      return IcebergCatalogAndIdentifier.buildCatalogAndIdentifier(
+              spark, setWriteDistributionAndOrdering.table())
+          .map(
+              catalogAndIdentifier -> {
+                TableCatalog catalog = catalogAndIdentifier._1();
+                Identifier identifier = catalogAndIdentifier._2();
+                return new SetWriteDistributionAndOrderingExec(
+                    catalog,
+                    identifier,
+                    setWriteDistributionAndOrdering.distributionMode(),
+                    setWriteDistributionAndOrdering.sortOrder());
+              })
+          .getOrElse(() -> super.apply(plan));
+    } else {
+      return super.apply(plan);
+    }
+  }
+
+  static class IcebergCatalogAndIdentifier {
+    static Option<Tuple2<TableCatalog, Identifier>> buildCatalogAndIdentifier(
+        SparkSession spark, Seq<String> identifier) {
+      Spark3Util.CatalogAndIdentifier catalogAndIdentifier =
+          Spark3Util.catalogAndIdentifier(
+              spark, scala.collection.JavaConversions.seqAsJavaList(identifier));
+      CatalogPlugin catalog = catalogAndIdentifier.catalog();
+      if (catalog instanceof GravitinoIcebergCatalog) {
+        return Some.apply(new Tuple2<>((TableCatalog) catalog, catalogAndIdentifier.identifier()));
+      } else {
+        // TODO: support SparkSessionCatalog
+        return Option.empty();
+      }
+    }
+  }
+}
diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java
index 201666cc004..57fdb6bfa47 100644
--- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java
+++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/plugin/GravitinoDriverPlugin.java
@@ -12,6 +12,7 @@
 import com.datastrato.gravitino.spark.connector.catalog.GravitinoCatalogManager;
 import com.datastrato.gravitino.spark.connector.hive.GravitinoHiveCatalog;
 import com.datastrato.gravitino.spark.connector.iceberg.GravitinoIcebergCatalog;
+import com.datastrato.gravitino.spark.connector.iceberg.extensions.GravitinoIcebergSparkSessionExtensions;
 import com.google.common.base.Preconditions;
 import java.util.Collections;
 import java.util.Locale;
@@ -35,7 +36,10 @@ public class GravitinoDriverPlugin implements DriverPlugin {
   private GravitinoCatalogManager catalogManager;
 
   private static final String[] GRAVITINO_DRIVER_EXTENSIONS =
-      new String[] {IcebergSparkSessionExtensions.class.getName()};
+      new String[] {
+        GravitinoIcebergSparkSessionExtensions.class.getName(),
+        IcebergSparkSessionExtensions.class.getName()
+      };
 
   @Override
   public Map<String, String> init(SparkContext sc, PluginContext pluginContext) {
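
For reference, a minimal sketch of what this change enables once the driver plugin injects GravitinoIcebergSparkSessionExtensions alongside Iceberg's own IcebergSparkSessionExtensions: the Iceberg SQL-extension statements exercised in the tests above can be issued against a Gravitino Iceberg catalog through a plain SparkSession. The catalog, schema, and table names used here (iceberg_catalog, db, events with columns id and ts) are illustrative assumptions, not names defined in this patch.

import org.apache.spark.sql.SparkSession;

public class IcebergSqlExtensionsExample {
  public static void main(String[] args) {
    // Assumes the session was started with the Gravitino Spark connector plugin configured,
    // so the Gravitino and Iceberg session extensions are already registered, and that
    // "iceberg_catalog" is a Gravitino Iceberg catalog containing table db.events(id, ts).
    SparkSession spark = SparkSession.builder().appName("iceberg-sql-extensions").getOrCreate();

    // Partition evolution, routed to AddPartitionFieldExec / DropPartitionFieldExec.
    spark.sql("ALTER TABLE iceberg_catalog.db.events ADD PARTITION FIELD days(ts)");
    spark.sql("ALTER TABLE iceberg_catalog.db.events DROP PARTITION FIELD days(ts)");

    // Branch and tag management, routed to CreateOrReplaceBranchExec / CreateOrReplaceTagExec.
    spark.sql("ALTER TABLE iceberg_catalog.db.events CREATE BRANCH IF NOT EXISTS `audit`");
    spark.sql("ALTER TABLE iceberg_catalog.db.events CREATE TAG IF NOT EXISTS `v1`");

    // Write distribution and ordering, routed to SetWriteDistributionAndOrderingExec.
    spark.sql(
        "ALTER TABLE iceberg_catalog.db.events "
            + "WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY id DESC");

    spark.stop();
  }
}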