diff --git a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
index f7da5564809..8465d506587 100644
--- a/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
+++ b/integration-test/src/test/java/com/datastrato/gravitino/integration/test/spark/iceberg/SparkIcebergCatalogIT.java
@@ -11,6 +11,8 @@
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import java.io.File;
+import java.sql.Timestamp;
+import java.time.Instant;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collections;
@@ -18,6 +20,7 @@
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
+import java.util.stream.IntStream;
 import org.apache.hadoop.fs.Path;
 import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions;
 import org.apache.spark.SparkConf;
@@ -251,6 +254,17 @@ void testIcebergTableRowLevelOperations() {
     testIcebergMergeIntoUpdateOperation();
   }
 
+  @Test
+  void testIcebergCallOperations() {
+    testIcebergCallRollbackToSnapshot();
+    testIcebergCallRollbackToTimestamp();
+    testIcebergCallSetCurrentSnapshot();
+    testIcebergCallRewriteDataFiles();
+    testIcebergCallExpireSnapshots();
+    testIcebergCallRewriteManifests();
+    testIcebergCallRewritePositionDeleteFiles();
+  }
+
   private void testMetadataColumns() {
     String tableName = "test_metadata_columns";
     dropTableIfExists(tableName);
@@ -503,6 +517,246 @@ private void testIcebergMergeIntoUpdateOperation() {
         });
   }
 
+  private void testIcebergCallRollbackToSnapshot() {
+    String tableName =
+        String.format(
+            "%s.%s.test_iceberg_call_rollback_to_snapshot",
+            getCatalogName(), getDefaultDatabase());
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    sql(String.format("INSERT INTO %s VALUES(1, '1', 1)", tableName));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+    List<Row> rows =
+        getSparkSession()
+            .sql(String.format("SELECT snapshot_id FROM %s.snapshots", tableName))
+            .collectAsList();
+    Assertions.assertEquals(1, rows.size());
+    long snapshotId = rows.get(0).getLong(0);
+    sql(String.format("INSERT INTO %s VALUES(2, '2', 2)", tableName));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(2, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
+    sql(
+        String.format(
+            "CALL %s.system.rollback_to_snapshot('%s', %d)",
+            getCatalogName(), tableName, snapshotId));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+  }
+
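+  // rollback_to_timestamp resolves the latest snapshot committed strictly before the
+  // given timestamp, so the test below waits until the clock has passed the first
+  // commit before capturing the rollback point.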
+  private void testIcebergCallRollbackToTimestamp() {
+    String tableName =
+        String.format(
+            "%s.%s.test_iceberg_call_rollback_to_timestamp",
+            getCatalogName(), getDefaultDatabase());
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    sql(String.format("INSERT INTO %s VALUES(1, '1', 1)", tableName));
+    List<Row> timestamp =
+        getSparkSession()
+            .sql(String.format("SELECT committed_at FROM %s.snapshots", tableName))
+            .collectAsList();
+    Assertions.assertEquals(1, timestamp.size());
+    Timestamp timestampAt = timestamp.get(0).getTimestamp(0);
+    waitUntilAfter(timestampAt.getTime());
+    Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis()));
+
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+
+    sql(String.format("INSERT INTO %s VALUES(2, '2', 2)", tableName));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(2, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
+
+    sql(
+        String.format(
+            "CALL %s.system.rollback_to_timestamp('%s', TIMESTAMP '%s')",
+            getCatalogName(), tableName, currentTimestamp));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+  }
+
+  private void testIcebergCallSetCurrentSnapshot() {
+    String tableName =
+        String.format(
+            "%s.%s.test_iceberg_call_set_current_snapshot",
+            getCatalogName(), getDefaultDatabase());
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    sql(String.format("INSERT INTO %s VALUES(1, '1', 1)", tableName));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+    List<Row> rows =
+        getSparkSession()
+            .sql(String.format("SELECT snapshot_id FROM %s.snapshots", tableName))
+            .collectAsList();
+    Assertions.assertEquals(1, rows.size());
+    long snapshotId = rows.get(0).getLong(0);
+    sql(String.format("INSERT INTO %s VALUES(2, '2', 2)", tableName));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(2, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
+    sql(
+        String.format(
+            "CALL %s.system.set_current_snapshot('%s', %d)",
+            getCatalogName(), tableName, snapshotId));
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(1, tableData.size());
+    Assertions.assertEquals("1,1,1", tableData.get(0));
+  }
+
+  private void testIcebergCallRewriteDataFiles() {
+    String tableName =
+        String.format(
+            "%s.%s.test_iceberg_call_rewrite_data_files",
+            getCatalogName(), getDefaultDatabase());
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    IntStream.rangeClosed(1, 5)
+        .forEach(
+            i -> sql(String.format("INSERT INTO %s VALUES(%d, '%d', %d)", tableName, i, i, i)));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(5, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));
+    List<Row> callResult =
+        getSparkSession()
+            .sql(
+                String.format(
+                    "CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', sort_order => 'id DESC NULLS LAST', where => 'id < 10')",
+                    getCatalogName(), tableName))
+            .collectAsList();
+    Assertions.assertEquals(1, callResult.size());
+    Assertions.assertEquals(5, callResult.get(0).getInt(0));
+    Assertions.assertEquals(1, callResult.get(0).getInt(1));
+  }
+
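+  // expire_snapshots drops snapshots committed before the given timestamp (the
+  // current snapshot is always retained); with five commits, four snapshots expire,
+  // and the deleted manifest-lists count (result column index 4) should be 4.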
+  private void testIcebergCallExpireSnapshots() {
+    String tableName =
+        String.format(
+            "%s.%s.test_iceberg_call_expire_snapshots",
+            getCatalogName(), getDefaultDatabase());
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    IntStream.rangeClosed(1, 5)
+        .forEach(
+            i -> sql(String.format("INSERT INTO %s VALUES(%d, '%d', %d)", tableName, i, i, i)));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(5, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));
+    List<Row> timestamp =
+        getSparkSession()
+            .sql(
+                String.format(
+                    "SELECT committed_at FROM %s.snapshots ORDER BY committed_at DESC",
+                    tableName))
+            .collectAsList();
+    Timestamp timestampAt = timestamp.get(0).getTimestamp(0);
+    waitUntilAfter(timestampAt.getTime());
+    Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis()));
+    List<Row> callResult =
+        getSparkSession()
+            .sql(
+                String.format(
+                    "CALL %s.system.expire_snapshots(table => '%s', older_than => TIMESTAMP '%s', max_concurrent_deletes => 4)",
+                    getCatalogName(), tableName, currentTimestamp))
+            .collectAsList();
+    Assertions.assertEquals(1, callResult.size());
+    Assertions.assertEquals(4, callResult.get(0).getLong(4));
+  }
+
+  private void testIcebergCallRewriteManifests() {
+    String tableName =
+        String.format("%s.%s.rewrite_manifests", getCatalogName(), getDefaultDatabase());
+    dropTableIfExists(tableName);
+    createSimpleTable(tableName);
+
+    IntStream.rangeClosed(1, 5)
+        .forEach(
+            i -> sql(String.format("INSERT INTO %s VALUES(%d, '%d', %d)", tableName, i, i, i)));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(5, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));
+
+    List<Row> callResult =
+        getSparkSession()
+            .sql(
+                String.format(
+                    "CALL %s.system.rewrite_manifests(table => '%s', use_caching => false)",
+                    getCatalogName(), tableName))
+            .collectAsList();
+    Assertions.assertEquals(1, callResult.size());
+    Assertions.assertEquals(5, callResult.get(0).getInt(0));
+    Assertions.assertEquals(1, callResult.get(0).getInt(1));
+  }
+
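+  // With format-version=2 and merge-on-read deletes, each DELETE below writes one
+  // position delete file; rewrite_position_delete_files(rewrite-all) compacts the two
+  // files into one, and all_delete_files keeps listing historical files, hence three.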
+  private void testIcebergCallRewritePositionDeleteFiles() {
+    String tableName =
+        String.format(
+            "%s.%s.rewrite_position_delete_files", getCatalogName(), getDefaultDatabase());
+    dropTableIfExists(tableName);
+    createIcebergTableWithTabProperties(
+        tableName,
+        false,
+        ImmutableMap.of(ICEBERG_FORMAT_VERSION, "2", ICEBERG_DELETE_MODE, "merge-on-read"));
+
+    sql(
+        String.format(
+            "INSERT INTO %s VALUES(1, '1', 1), (2, '2', 2), (3, '3', 3), (4, '4', 4), (5, '5', 5)",
+            tableName));
+    List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(5, tableData.size());
+    Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));
+
+    List<String> deleteFiles =
+        getQueryData(String.format("SELECT * FROM %s.all_delete_files", tableName));
+    Assertions.assertEquals(0, deleteFiles.size());
+
+    sql(String.format("DELETE FROM %s WHERE id = 1", tableName));
+    sql(String.format("DELETE FROM %s WHERE id = 2", tableName));
+
+    tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
+    Assertions.assertEquals(3, tableData.size());
+    Assertions.assertEquals("3,3,3;4,4,4;5,5,5", String.join(";", tableData));
+
+    deleteFiles = getQueryData(String.format("SELECT * FROM %s.all_delete_files", tableName));
+    Assertions.assertEquals(2, deleteFiles.size());
+
+    List<Row> callResult =
+        getSparkSession()
+            .sql(
+                String.format(
+                    "CALL %s.system.rewrite_position_delete_files(table => '%s', options => map('rewrite-all','true'))",
+                    getCatalogName(), tableName))
+            .collectAsList();
+    Assertions.assertEquals(1, callResult.size());
+    Assertions.assertEquals(2, callResult.get(0).getInt(0));
+    Assertions.assertEquals(1, callResult.get(0).getInt(1));
+
+    deleteFiles = getQueryData(String.format("SELECT * FROM %s.all_delete_files", tableName));
+    Assertions.assertEquals(3, deleteFiles.size());
+  }
+
   private List<SparkTableInfo.SparkColumnInfo> getIcebergSimpleTableColumn() {
     return Arrays.asList(
         SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
@@ -555,4 +809,13 @@ private void createIcebergTableWithTabProperties(
         tableName, partitionedClause, tblPropertiesStr);
     sql(createSql);
   }
+
+  private void waitUntilAfter(Long timestampMillis) {
+    // Busy-wait until the wall clock has strictly passed the given instant, so a
+    // timestamp captured afterwards is guaranteed to be later than the last commit.
+    long current = System.currentTimeMillis();
+    while (current <= timestampMillis) {
+      current = System.currentTimeMillis();
+    }
+  }
 }
diff --git a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java
index 5355dbc3dfd..b104aacb331 100644
--- a/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java
+++ b/spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/GravitinoIcebergCatalog.java
@@ -17,10 +17,13 @@
 import org.apache.iceberg.spark.SparkCatalog;
 import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
 import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
+import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
 import org.apache.spark.sql.connector.catalog.FunctionCatalog;
 import org.apache.spark.sql.connector.catalog.Identifier;
 import org.apache.spark.sql.connector.catalog.TableCatalog;
 import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
+import org.apache.spark.sql.connector.iceberg.catalog.Procedure;
+import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 
 /**
@@ -30,7 +33,8 @@
  * StagingTableCatalog and FunctionCatalog, allowing for advanced operations like table staging and
  * function management tailored to the needs of Iceberg tables.
  */
-public class GravitinoIcebergCatalog extends BaseCatalog implements FunctionCatalog {
+public class GravitinoIcebergCatalog extends BaseCatalog
+    implements FunctionCatalog, ProcedureCatalog {
 
   @Override
   protected TableCatalog createAndInitSparkCatalog(
@@ -101,6 +105,13 @@ public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException {
     return ((SparkCatalog) sparkCatalog).loadFunction(ident);
   }
 
+  @Override
+  public Procedure loadProcedure(Identifier identifier) throws NoSuchProcedureException {
+    // Delegate procedure lookup to the underlying Iceberg SparkCatalog, which supplies
+    // the built-in system procedures exercised by the integration tests above.
+    return ((SparkCatalog) sparkCatalog).loadProcedure(identifier);
+  }
+
   private void initHiveProperties(
       String catalogBackend,
       Map<String, String> gravitinoProperties,
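
Note on usage: the sketch below shows how the new ProcedureCatalog path is reached end to end once the connector is on the classpath. It is a minimal, hypothetical wiring example, not part of this change: the plugin class name and the spark.sql.gravitino.* configuration keys are assumptions about the connector's setup, and the catalog and table names are placeholders. The CALL statements themselves are exactly what the tests above exercise.

import org.apache.spark.sql.SparkSession;

public class IcebergProcedureExample {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .appName("gravitino-iceberg-procedures")
            // Assumed plugin class and configuration keys; check the connector docs.
            .config(
                "spark.plugins",
                "com.datastrato.gravitino.spark.connector.plugin.GravitinoSparkPlugin")
            .config("spark.sql.gravitino.uri", "http://localhost:8090")
            .config("spark.sql.gravitino.metalake", "test_metalake")
            // Iceberg's session extensions (also used by the test class above) provide
            // the CALL statement syntax.
            .config(
                "spark.sql.extensions",
                "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
            .getOrCreate();

    // Spark resolves system.rewrite_manifests through GravitinoIcebergCatalog.loadProcedure,
    // which delegates to the wrapped Iceberg SparkCatalog; my_catalog and db.tbl are
    // placeholders for a registered Iceberg catalog and an existing table.
    spark
        .sql("CALL my_catalog.system.rewrite_manifests(table => 'db.tbl', use_caching => false)")
        .show();
  }
}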