[#3187] feat(spark-connector): Support SparkSQL extended syntax in Iceberg
caican00 committed May 4, 2024
1 parent 4d334aa commit 49baf5c
Showing 5 changed files with 502 additions and 1 deletion.
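
For orientation, the statement families exercised by the new tests are sketched below. This is an illustrative, hypothetical snippet, not part of the commit: demo.db.t is a placeholder Iceberg table, and a session with Iceberg's SQL extensions already registered is assumed (see the usage note after the new extensions class at the end of this diff).

import org.apache.spark.sql.SparkSession;

public final class IcebergExtensionSyntaxTour {
  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder().getOrCreate();

    // Partition evolution (testIcebergPartitionFieldOperations)
    spark.sql("ALTER TABLE demo.db.t ADD PARTITION FIELD bucket(16, id)");
    spark.sql("ALTER TABLE demo.db.t REPLACE PARTITION FIELD ts_day WITH months(ts)");
    spark.sql("ALTER TABLE demo.db.t DROP PARTITION FIELD months(ts)");

    // Branch and tag DDL (testIcebergBranchOperations / testIcebergTagOperations)
    spark.sql("ALTER TABLE demo.db.t CREATE BRANCH IF NOT EXISTS `branch1`");
    spark.sql("ALTER TABLE demo.db.t CREATE TAG IF NOT EXISTS `tag1`");
    spark.sql("SELECT * FROM demo.db.t VERSION AS OF 'branch1'").show();

    // Identifier fields (testIcebergIdentifierOperations)
    spark.sql("ALTER TABLE demo.db.t SET IDENTIFIER FIELDS id, name");

    // Write distribution and ordering (testIcebergDistributionAndOrderingOperations)
    spark.sql("ALTER TABLE demo.db.t WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY id DESC");
    spark.stop();
  }
}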
@@ -19,6 +19,7 @@
import java.util.Map;
import java.util.stream.Collectors;
import org.apache.hadoop.fs.Path;
import org.apache.iceberg.exceptions.ValidationException;
import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions;
import org.apache.spark.SparkConf;
import org.apache.spark.sql.Column;
@@ -48,6 +49,9 @@ public abstract class SparkIcebergCatalogIT extends SparkCommonIT {
private static final String ICEBERG_DELETE_MODE = "write.delete.mode";
private static final String ICEBERG_UPDATE_MODE = "write.update.mode";
private static final String ICEBERG_MERGE_MODE = "write.merge.mode";
private static final String ICEBERG_WRITE_DISTRIBUTION_MODE = "write.distribution-mode";
private static final String ICEBERG_SORT_ORDER = "sort-order";
private static final String ICEBERG_IDENTIFIER_FIELDS = "identifier-fields";

@Override
protected String getCatalogName() {
@@ -251,6 +255,15 @@ void testIcebergTableRowLevelOperations() {
testIcebergMergeIntoUpdateOperation();
}

@Test
void testIcebergSQLExtensions() {
testIcebergPartitionFieldOperations();
testIcebergBranchOperations();
testIcebergTagOperations();
testIcebergIdentifierOperations();
testIcebergDistributionAndOrderingOperations();
}

private void testMetadataColumns() {
String tableName = "test_metadata_columns";
dropTableIfExists(tableName);
@@ -503,6 +516,260 @@ private void testIcebergMergeIntoUpdateOperation() {
});
}

private void testIcebergPartitionFieldOperations() {
List<String> partitionFields =
Arrays.asList("name", "truncate(1, name)", "bucket(16, id)", "days(ts)");
String partitionExpression = "name=a/name_trunc=a/id_bucket=4/ts_day=2024-01-01";
String tableName = "test_iceberg_partition_field_operations";
dropTableIfExists(tableName);
sql(getCreateIcebergSimpleTableString(tableName));

// add partition fields
SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getIcebergSimpleTableColumn())
.withMetadataColumns(getIcebergMetadataColumns());
checker.check(tableInfo);

partitionFields.forEach(
partitionField ->
sql(String.format("ALTER TABLE %s ADD PARTITION FIELD %s", tableName, partitionField)));

tableInfo = getTableInfo(tableName);
checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getIcebergSimpleTableColumn())
.withMetadataColumns(getIcebergMetadataColumns())
.withIdentifyPartition(Collections.singletonList("name"))
.withTruncatePartition(1, "name")
.withBucketPartition(16, Collections.singletonList("id"))
.withDayPartition("ts");
checker.check(tableInfo);

sql(
String.format(
"INSERT INTO %s VALUES(2,'a',cast('2024-01-01 12:00:00' as timestamp));", tableName));
List<String> queryResult = getTableData(tableName);
Assertions.assertEquals(1, queryResult.size());
Assertions.assertEquals("2,a,2024-01-01 12:00:00", queryResult.get(0));
Path partitionPath = new Path(getTableLocation(tableInfo), partitionExpression);
checkDirExists(partitionPath);

// replace partition fields
sql(String.format("ALTER TABLE %s REPLACE PARTITION FIELD ts_day WITH months(ts)", tableName));
tableInfo = getTableInfo(tableName);
checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getIcebergSimpleTableColumn())
.withMetadataColumns(getIcebergMetadataColumns())
.withIdentifyPartition(Collections.singletonList("name"))
.withTruncatePartition(1, "name")
.withBucketPartition(16, Collections.singletonList("id"))
.withMonthPartition("ts");
checker.check(tableInfo);

// drop partition fields
sql(String.format("ALTER TABLE %s DROP PARTITION FIELD months(ts)", tableName));
tableInfo = getTableInfo(tableName);
checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(getIcebergSimpleTableColumn())
.withMetadataColumns(getIcebergMetadataColumns())
.withIdentifyPartition(Collections.singletonList("name"))
.withTruncatePartition(1, "name")
.withBucketPartition(16, Collections.singletonList("id"));
checker.check(tableInfo);
}

private void testIcebergBranchOperations() {
String tableName = "test_iceberg_branch_operations";
String fullTableName =
String.format("%s.%s.%s", getCatalogName(), getDefaultDatabase(), tableName);
String branch1 = "branch1";
dropTableIfExists(tableName);
createSimpleTable(tableName);

// create branch and query data using branch
sql(String.format("INSERT INTO %s VALUES(1, '1', 1);", tableName));
List<String> tableData = getTableData(tableName);
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
List<Row> snapshots =
getSparkSession()
.sql(String.format("SELECT snapshot_id FROM %s.snapshots", fullTableName))
.collectAsList();
Assertions.assertEquals(1, snapshots.size());
long snapshotId = snapshots.get(0).getLong(0);

sql(String.format("ALTER TABLE %s CREATE BRANCH IF NOT EXISTS `%s`", tableName, branch1));
sql(String.format("INSERT INTO %s VALUES(2, '2', 2);", tableName));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));

tableData =
getQueryData(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, branch1));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

sql(String.format("ALTER TABLE %s CREATE OR REPLACE BRANCH `%s`", tableName, branch1));
tableData =
getQueryData(
String.format("SELECT * FROM %s VERSION AS OF '%s' ORDER BY id", tableName, branch1));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));

// replace branch
sql(
String.format(
"ALTER TABLE %s REPLACE BRANCH `%s` AS OF VERSION %d RETAIN 1 DAYS",
tableName, branch1, snapshotId));
tableData =
getQueryData(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, branch1));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

// drop branch
sql(String.format("ALTER TABLE %s DROP BRANCH `%s`", tableName, branch1));
Assertions.assertThrows(
ValidationException.class,
() -> sql(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, branch1)));
}

private void testIcebergTagOperations() {
String tableName = "test_iceberg_tag_operations";
String fullTableName =
String.format("%s.%s.%s", getCatalogName(), getDefaultDatabase(), tableName);
String tag1 = "tag1";
dropTableIfExists(tableName);
createSimpleTable(tableName);

// create tag and query data using tag
sql(String.format("INSERT INTO %s VALUES(1, '1', 1);", tableName));
List<String> tableData = getTableData(tableName);
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
List<Row> snapshots =
getSparkSession()
.sql(String.format("SELECT snapshot_id FROM %s.snapshots", fullTableName))
.collectAsList();
Assertions.assertEquals(1, snapshots.size());
long snapshotId = snapshots.get(0).getLong(0);

sql(String.format("ALTER TABLE %s CREATE TAG IF NOT EXISTS `%s`", tableName, tag1));
sql(String.format("INSERT INTO %s VALUES(2, '2', 2);", tableName));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));

tableData = getQueryData(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, tag1));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

sql(String.format("ALTER TABLE %s CREATE OR REPLACE TAG `%s`", tableName, tag1));
tableData =
getQueryData(
String.format("SELECT * FROM %s VERSION AS OF '%s' ORDER BY id", tableName, tag1));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));

// replace tag
sql(
String.format(
"ALTER TABLE %s REPLACE TAG `%s` AS OF VERSION %d RETAIN 1 DAYS",
tableName, tag1, snapshotId));
tableData = getQueryData(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, tag1));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

// drop tag
sql(String.format("ALTER TABLE %s DROP TAG `%s`", tableName, tag1));
Assertions.assertThrows(
ValidationException.class,
() -> sql(String.format("SELECT * FROM %s VERSION AS OF '%s'", tableName, tag1)));
}

private void testIcebergIdentifierOperations() {
String tableName = "test_iceberg_identifier_operations";
List<SparkTableInfo.SparkColumnInfo> columnInfos =
Arrays.asList(
SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment", false),
SparkTableInfo.SparkColumnInfo.of("name", DataTypes.StringType, "", false),
SparkTableInfo.SparkColumnInfo.of("ts", DataTypes.TimestampType, null, true));
dropTableIfExists(tableName);
sql(
String.format(
"CREATE TABLE %s (id INT COMMENT 'id comment' NOT NULL, name STRING COMMENT '' NOT NULL, age INT)",
tableName));
SparkTableInfo tableInfo = getTableInfo(tableName);
SparkTableInfoChecker checker =
SparkTableInfoChecker.create()
.withName(tableName)
.withColumns(columnInfos)
.withMetadataColumns(getIcebergMetadataColumns());
checker.check(tableInfo);
Map<String, String> tableProperties = tableInfo.getTableProperties();
Assertions.assertNull(tableProperties.get(ICEBERG_IDENTIFIER_FIELDS));

// set identifier fields
sql(String.format("ALTER TABLE %s SET IDENTIFIER FIELDS id, name", tableName));
tableInfo = getTableInfo(tableName);
tableProperties = tableInfo.getTableProperties();
Assertions.assertEquals("[name,id]", tableProperties.get(ICEBERG_IDENTIFIER_FIELDS));

// drop identifier fields
sql(String.format("ALTER TABLE %s DROP IDENTIFIER FIELDS id, name", tableName));
tableInfo = getTableInfo(tableName);
tableProperties = tableInfo.getTableProperties();
Assertions.assertNull(tableProperties.get(ICEBERG_IDENTIFIER_FIELDS));
}

private void testIcebergDistributionAndOrderingOperations() {
String tableName = "test_iceberg_distribution_and_ordering_operations";
dropTableIfExists(tableName);
createSimpleTable(tableName);

SparkTableInfo tableInfo = getTableInfo(tableName);
Map<String, String> tableProperties = tableInfo.getTableProperties();
Assertions.assertNull(tableProperties.get(ICEBERG_WRITE_DISTRIBUTION_MODE));
Assertions.assertNull(tableProperties.get(ICEBERG_SORT_ORDER));

// set global ordering
sql(String.format("ALTER TABLE %s WRITE ORDERED BY id DESC", tableName));
tableInfo = getTableInfo(tableName);
tableProperties = tableInfo.getTableProperties();
Assertions.assertEquals("range", tableProperties.get(ICEBERG_WRITE_DISTRIBUTION_MODE));
Assertions.assertEquals("id DESC NULLS LAST", tableProperties.get(ICEBERG_SORT_ORDER));

// set local ordering
sql(String.format("ALTER TABLE %s WRITE LOCALLY ORDERED BY id DESC", tableName));
tableInfo = getTableInfo(tableName);
tableProperties = tableInfo.getTableProperties();
Assertions.assertEquals("none", tableProperties.get(ICEBERG_WRITE_DISTRIBUTION_MODE));
Assertions.assertEquals("id DESC NULLS LAST", tableProperties.get(ICEBERG_SORT_ORDER));

// set distribution
sql(String.format("ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION", tableName));
tableInfo = getTableInfo(tableName);
tableProperties = tableInfo.getTableProperties();
Assertions.assertEquals("hash", tableProperties.get(ICEBERG_WRITE_DISTRIBUTION_MODE));
Assertions.assertNull(tableProperties.get(ICEBERG_SORT_ORDER));

// set distribution with local ordering
sql(
String.format(
"ALTER TABLE %s WRITE DISTRIBUTED BY PARTITION LOCALLY ORDERED BY id desc", tableName));
tableInfo = getTableInfo(tableName);
tableProperties = tableInfo.getTableProperties();
Assertions.assertEquals("hash", tableProperties.get(ICEBERG_WRITE_DISTRIBUTION_MODE));
Assertions.assertEquals("id DESC NULLS LAST", tableProperties.get(ICEBERG_SORT_ORDER));
}

private List<SparkTableInfo.SparkColumnInfo> getIcebergSimpleTableColumn() {
return Arrays.asList(
SparkTableInfo.SparkColumnInfo.of("id", DataTypes.IntegerType, "id comment"),
2 changes: 2 additions & 0 deletions spark-connector/spark-connector/build.gradle.kts
@@ -19,6 +19,7 @@ val sparkMajorVersion: String = sparkVersion.substringBeforeLast(".")
val icebergVersion: String = libs.versions.iceberg.get()
val kyuubiVersion: String = libs.versions.kyuubi.get()
val scalaJava8CompatVersion: String = libs.versions.scala.java.compat.get()
val scalaCollectionCompatVersion: String = libs.versions.scala.collection.compat.get()

dependencies {
implementation(project(":catalogs:bundled-catalog", configuration = "shadow"))
@@ -30,6 +31,7 @@ dependencies {
compileOnly("org.apache.spark:spark-catalyst_$scalaVersion:$sparkVersion")
compileOnly("org.apache.spark:spark-sql_$scalaVersion:$sparkVersion")
compileOnly("org.scala-lang.modules:scala-java8-compat_$scalaVersion:$scalaJava8CompatVersion")
compileOnly("org.scala-lang.modules:scala-collection-compat_$scalaVersion:$scalaCollectionCompatVersion")
annotationProcessor(libs.lombok)
compileOnly(libs.lombok)

22 changes: 22 additions & 0 deletions spark-connector/spark-connector/src/main/java/com/datastrato/gravitino/spark/connector/iceberg/extensions/GravitinoIcebergSparkSessionExtensions.java
@@ -0,0 +1,22 @@
/*
* Copyright 2024 Datastrato Pvt Ltd.
* This software is licensed under the Apache License version 2.
*/
package com.datastrato.gravitino.spark.connector.iceberg.extensions;

import org.apache.spark.sql.SparkSessionExtensions;
import scala.Function1;

public class GravitinoIcebergSparkSessionExtensions
implements Function1<SparkSessionExtensions, Void> {

@Override
public Void apply(SparkSessionExtensions extensions) {

// Planner extension: inject the extended DataSourceV2 strategy that plans Iceberg's SQL-extension commands.
extensions.injectPlannerStrategy(IcebergExtendedDataSourceV2Strategy::new);

// Function1 requires a return value; with a Void return type the only valid value is null.
return null;
}
}
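
Usage note: Spark discovers session extensions through the spark.sql.extensions configuration, a comma-separated list of Function1<SparkSessionExtensions, Unit> class names; generics are erased at load time, so the Java-friendly Void signature above satisfies the contract. Below is a minimal, hypothetical registration sketch, not part of this commit, assuming a local session with Iceberg's own IcebergSparkSessionExtensions registered alongside the Gravitino one and an existing Iceberg table demo.db.t.

import com.datastrato.gravitino.spark.connector.iceberg.extensions.GravitinoIcebergSparkSessionExtensions;
import org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions;
import org.apache.spark.sql.SparkSession;

public class ExtensionsRegistrationSketch {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .master("local[*]")
            // Register both extension classes; order is not significant here.
            .config(
                "spark.sql.extensions",
                String.join(
                    ",",
                    IcebergSparkSessionExtensions.class.getName(),
                    GravitinoIcebergSparkSessionExtensions.class.getName()))
            .getOrCreate();

    // Iceberg's extended DDL now parses and plans, e.g.:
    spark.sql("ALTER TABLE demo.db.t WRITE ORDERED BY id DESC");
    spark.stop();
  }
}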
