[datastrato#3186] feat(spark-connector): Support Iceberg Spark Procedure
caican00 committed May 3, 2024
1 parent 4d334aa commit bbe1399
Showing 2 changed files with 253 additions and 1 deletion.
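What the change enables, in user terms: Spark SQL CALL statements against a Gravitino-managed Iceberg catalog now resolve to Iceberg's built-in maintenance procedures. Below is a minimal sketch of that flow, assuming the Gravitino Spark connector is on the classpath and already configured; the catalog name, table name, and snapshot id are placeholders, not values from this commit.

import org.apache.spark.sql.SparkSession;

public class IcebergProcedureExample {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .appName("iceberg-procedure-example")
            // Iceberg's SQL extensions provide the CALL statement syntax.
            .config(
                "spark.sql.extensions",
                "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions")
            .getOrCreate();

    // Roll a table back to an earlier snapshot through the Gravitino-managed
    // catalog; "my_catalog", "db.tbl", and the snapshot id are placeholders.
    spark.sql("CALL my_catalog.system.rollback_to_snapshot('db.tbl', 5781947118336215154)").show();
  }
}

The same pattern covers the other procedures exercised by the new tests, such as expire_snapshots, rewrite_data_files, and rewrite_manifests.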
@@ -11,13 +11,16 @@
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.File;
import java.sql.Timestamp;
import java.time.Instant;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.hadoop.fs.Path;
import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions;
import org.apache.spark.SparkConf;
@@ -251,6 +254,17 @@ void testIcebergTableRowLevelOperations() {
testIcebergMergeIntoUpdateOperation();
}

@Test
void testIcebergCallOperations() {
testIcebergCallRollbackToSnapshot();
testIcebergCallRollbackToTimestamp();
testIcebergCallSetCurrentSnapshot();
testIcebergCallRewriteDataFiles();
testIcebergCallExpireSnapshots();
testIcebergCallRewriteManifests();
testIcebergCallRewritePositionDeleteFiles();
}

private void testMetadataColumns() {
String tableName = "test_metadata_columns";
dropTableIfExists(tableName);
Expand Down Expand Up @@ -503,6 +517,228 @@ private void testIcebergMergeIntoUpdateOperation() {
});
}

private void testIcebergCallRollbackToSnapshot() {
String tableName =
String.format(
"%s.%s.test_iceberg_call_rollback_to_snapshot", getCatalogName(), getDefaultDatabase());
dropTableIfExists(tableName);
createSimpleTable(tableName);

sql(String.format("INSERT INTO %s VALUES(1, '1', 1)", tableName));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
List<Row> rows =
getSparkSession()
.sql(String.format("SELECT snapshot_id FROM %s.snapshots", tableName))
.collectAsList();
Assertions.assertEquals(1, rows.size());
long snapshotId = rows.get(0).getLong(0);
sql(String.format("INSERT INTO %s VALUES(2, '2', 2)", tableName));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
sql(
String.format(
"CALL %s.system.rollback_to_snapshot('%s', %d)",
getCatalogName(), tableName, snapshotId));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
}

private void testIcebergCallRollbackToTimestamp() {
String tableName =
String.format(
"%s.%s.test_iceberg_call_rollback_to_timestamp",
getCatalogName(), getDefaultDatabase());
dropTableIfExists(tableName);
createSimpleTable(tableName);

sql(String.format("INSERT INTO %s VALUES(1, '1', 1)", tableName));
List<Row> timestamp =
getSparkSession()
.sql(String.format("SELECT committed_at FROM %s.snapshots", tableName))
.collectAsList();
Assertions.assertEquals(1, timestamp.size());
Timestamp timestampAt = timestamp.get(0).getTimestamp(0);
waitUntilAfter(timestampAt.getTime());
Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis()));

List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

sql(String.format("INSERT INTO %s VALUES(2, '2', 2)", tableName));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));

sql(
String.format(
"CALL %s.system.rollback_to_timestamp('%s', TIMESTAMP '%s')",
getCatalogName(), tableName, currentTimestamp));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
}

private void testIcebergCallSetCurrentSnapshot() {
String tableName =
String.format(
"%s.%s.test_iceberg_call_set_current_snapshot", getCatalogName(), getDefaultDatabase());
dropTableIfExists(tableName);
createSimpleTable(tableName);

sql(String.format("INSERT INTO %s VALUES(1, '1', 1)", tableName));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
List<Row> rows =
getSparkSession()
.sql(String.format("SELECT snapshot_id FROM %s.snapshots", tableName))
.collectAsList();
Assertions.assertEquals(1, rows.size());
long snapshotId = rows.get(0).getLong(0);
sql(String.format("INSERT INTO %s VALUES(2, '2', 2)", tableName));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));
sql(
String.format(
"CALL %s.system.set_current_snapshot('%s', %d)",
getCatalogName(), tableName, snapshotId));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
}

private void testIcebergCallRewriteDataFiles() {
String tableName =
String.format(
"%s.%s.test_iceberg_call_rewrite_data_files", getCatalogName(), getDefaultDatabase());
dropTableIfExists(tableName);
createSimpleTable(tableName);

// Insert five rows (ids 1 through 5) to match the assertions below.
IntStream.rangeClosed(1, 5)
.forEach(
i -> sql(String.format("INSERT INTO %s VALUES(%d, '%d', %d)", tableName, i, i, i)));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(5, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));
List<Row> callResult =
getSparkSession()
.sql(
String.format(
"CALL %s.system.rewrite_data_files(table => '%s', strategy => 'sort', sort_order => 'id DESC NULLS LAST', where => 'id < 10')",
getCatalogName(), tableName))
.collectAsList();
Assertions.assertEquals(1, callResult.size());
Assertions.assertEquals(5, callResult.get(0).getInt(0));
Assertions.assertEquals(1, callResult.get(0).getInt(1));
}

private void testIcebergCallExpireSnapshots() {
String tableName =
String.format(
"%s.%s.test_iceberg_call_expire_snapshots", getCatalogName(), getDefaultDatabase());
dropTableIfExists(tableName);
createSimpleTable(tableName);

// Insert five rows (ids 1 through 5) to match the assertions below.
IntStream.rangeClosed(1, 5)
.forEach(
i -> sql(String.format("INSERT INTO %s VALUES(%d, '%d', %d)", tableName, i, i, i)));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(5, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));
List<Row> timestamp =
getSparkSession()
.sql(
String.format(
"SELECT committed_at FROM %s.snapshots ORDER BY committed_at DESC", tableName))
.collectAsList();
Timestamp timestampAt = timestamp.get(0).getTimestamp(0);
waitUntilAfter(timestampAt.getTime());
Timestamp currentTimestamp = Timestamp.from(Instant.ofEpochMilli(System.currentTimeMillis()));
List<Row> callResult =
getSparkSession()
.sql(
String.format(
"CALL %s.system.expire_snapshots(table => '%s', older_than => TIMESTAMP '%s', max_concurrent_deletes => 4)",
getCatalogName(), tableName, currentTimestamp))
.collectAsList();
Assertions.assertEquals(1, callResult.size());
Assertions.assertEquals(4, callResult.get(0).getInt(4));
}

private void testIcebergCallRewriteManifests() {
String tableName =
String.format("%s.%s.rewrite_manifests", getCatalogName(), getDefaultDatabase());
dropTableIfExists(tableName);
createSimpleTable(tableName);

// Insert five rows (ids 1 through 5) to match the assertions below.
IntStream.rangeClosed(1, 5)
.forEach(
i -> sql(String.format("INSERT INTO %s VALUES(%d, '%d', %d)", tableName, i, i, i)));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(5, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));

List<Row> callResult =
getSparkSession()
.sql(
String.format(
"CALL %s.system.rewrite_manifests(table => '%s', use_caching => false)",
getCatalogName(), tableName))
.collectAsList();
Assertions.assertEquals(1, callResult.size());
Assertions.assertEquals(5, callResult.get(0).getInt(0));
Assertions.assertEquals(1, callResult.get(0).getInt(1));
}

private void testIcebergCallRewritePositionDeleteFiles() {
String tableName =
String.format(
"%s.%s.rewrite_position_delete_files", getCatalogName(), getDefaultDatabase());
dropTableIfExists(tableName);
createIcebergTableWithTabProperties(
tableName,
false,
ImmutableMap.of(ICEBERG_FORMAT_VERSION, "2", ICEBERG_DELETE_MODE, "merge-on-read"));

sql(
String.format(
"INSERT INTO %s VALUES(1, '1', 1), (2, '2', 2), (3, '3', 3), (4, '4', 4), (5, '5', 5)",
tableName));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(5, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2;3,3,3;4,4,4;5,5,5", String.join(";", tableData));

List<String> deleteFiles =
getQueryData(String.format("SELECT * FROM %s.all_delete_files", tableName));
Assertions.assertEquals(0, deleteFiles.size());

sql(String.format("DELETE FROM %s WHERE id = 1", tableName));
sql(String.format("DELETE FROM %s WHERE id = 2", tableName));

tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(3, tableData.size());
Assertions.assertEquals("3,3,3;4,4,4;5,5,5", String.join(";", tableData));

deleteFiles = getQueryData(String.format("SELECT * FROM %s.all_delete_files", tableName));
Assertions.assertEquals(2, deleteFiles.size());

List<Row> callResult =
getSparkSession()
.sql(
String.format(
"CALL %s.system.rewrite_position_delete_files(table => '%s', options => map('rewrite-all','true'))",
getCatalogName(), tableName))
.collectAsList();
Assertions.assertEquals(1, callResult.size());
Assertions.assertEquals(2, callResult.get(0).getInt(0));
Assertions.assertEquals(1, callResult.get(0).getInt(1));

deleteFiles = getQueryData(String.format("SELECT * FROM %s.all_delete_files", tableName));
Assertions.assertEquals(3, deleteFiles.size());
}

private List<SparkTableInfo.SparkColumnInfo> getIcebergSimpleTableColumn() {
return Arrays.asList(
SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
@@ -555,4 +791,11 @@ private void createIcebergTableWithTabProperties(
tableName, partitionedClause, tblPropertiesStr);
sql(createSql);
}

// Busy-waits until the wall clock has advanced past the given epoch millisecond,
// so that a subsequent TIMESTAMP literal is strictly later than the last commit.
private void waitUntilAfter(long timestampMillis) {
long current = System.currentTimeMillis();
while (current <= timestampMillis) {
current = System.currentTimeMillis();
}
}
}
GravitinoIcebergCatalog.java
@@ -17,10 +17,13 @@
import org.apache.iceberg.spark.SparkCatalog;
import org.apache.spark.sql.catalyst.analysis.NoSuchFunctionException;
import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException;
import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.connector.catalog.FunctionCatalog;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.TableCatalog;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.connector.iceberg.catalog.Procedure;
import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;

/**
@@ -30,7 +33,8 @@
* StagingTableCatalog and FunctionCatalog, allowing for advanced operations like table staging and
* function management tailored to the needs of Iceberg tables.
*/
public class GravitinoIcebergCatalog extends BaseCatalog implements FunctionCatalog {
public class GravitinoIcebergCatalog extends BaseCatalog
implements FunctionCatalog, ProcedureCatalog {

@Override
protected TableCatalog createAndInitSparkCatalog(
@@ -101,6 +105,11 @@ public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException {
return ((SparkCatalog) sparkCatalog).loadFunction(ident);
}

@Override
public Procedure loadProcedure(Identifier identifier) throws NoSuchProcedureException {
return ((SparkCatalog) sparkCatalog).loadProcedure(identifier);
}

private void initHiveProperties(
String catalogBackend,
Map<String, String> gravitinoProperties,
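For context on the second file: Spark resolves a CALL statement by asking the catalog implementation, through Iceberg's ProcedureCatalog interface, for a Procedure matching the parsed identifier. GravitinoIcebergCatalog satisfies this by delegating to the wrapped Iceberg SparkCatalog, which supplies the built-in system procedures. A rough sketch of that resolution step follows; this is an illustrative helper, not Spark's actual analyzer code.

import org.apache.spark.sql.catalyst.analysis.NoSuchProcedureException;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.iceberg.catalog.Procedure;
import org.apache.spark.sql.connector.iceberg.catalog.ProcedureCatalog;

class ProcedureResolutionSketch {
  // "CALL my_catalog.system.rollback_to_snapshot(...)" parses into the
  // namespace {"system"} and the name "rollback_to_snapshot"; the catalog
  // maps that identifier to a Procedure whose parameters and output schema
  // Spark then validates against the supplied arguments.
  static Procedure resolve(ProcedureCatalog catalog, String[] namespace, String name)
      throws NoSuchProcedureException {
    return catalog.loadProcedure(Identifier.of(namespace, name));
  }
}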