[#3264] feat(spark-connector): Support Iceberg time travel in SQL queries #3265

Merged: 15 commits, May 20, 2024
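For reference, these are the Spark SQL time-travel forms this change supports, as exercised by the integration test below (table name and literal values are placeholders):

-- travel by timestamp: the snapshot current as of the given time (seconds since epoch or a timestamp string)
SELECT * FROM tbl TIMESTAMP AS OF 1715000000;
SELECT * FROM tbl FOR SYSTEM_TIME AS OF 1715000000;
-- travel by version: an Iceberg snapshot id, branch name, or tag name
SELECT * FROM tbl VERSION AS OF 4348234723;
SELECT * FROM tbl VERSION AS OF 'test_branch';
SELECT * FROM tbl FOR SYSTEM_VERSION AS OF 'test_tag';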
@@ -8,6 +8,7 @@
import com.datastrato.gravitino.integration.test.util.spark.SparkMetadataColumnInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfo;
import com.datastrato.gravitino.integration.test.util.spark.SparkTableInfoChecker;
import com.datastrato.gravitino.spark.connector.iceberg.SparkIcebergTable;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.io.File;
@@ -17,6 +18,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import lombok.Data;
import org.apache.hadoop.fs.Path;
@@ -31,6 +33,8 @@
import org.apache.spark.sql.connector.catalog.CatalogPlugin;
import org.apache.spark.sql.connector.catalog.FunctionCatalog;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCatalog;
import org.apache.spark.sql.connector.catalog.functions.UnboundFunction;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
@@ -239,6 +243,75 @@ void testIcebergTableRowLevelOperations(IcebergTableWriteProperties icebergTable
testIcebergMergeIntoUpdateOperation(icebergTableWriteProperties);
}

@Test
void testIcebergAsOfQuery() throws NoSuchTableException {
String tableName = "test_iceberg_as_of_query";
dropTableIfExists(tableName);
createSimpleTable(tableName);
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));

sql(String.format("INSERT INTO %s VALUES (1, '1', 1)", tableName));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

SparkIcebergTable sparkIcebergTable = getSparkIcebergTableInstance(tableName);
long snapshotId = getCurrentSnapshotId(tableName);
sparkIcebergTable.table().manageSnapshots().createBranch("test_branch", snapshotId).commit();
sparkIcebergTable.table().manageSnapshots().createTag("test_tag", snapshotId).commit();
long snapshotTimestamp = getCurrentSnapshotTimestamp(tableName);
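// Choose a query timestamp strictly after the first snapshot, then wait again so the
// second snapshot below commits after it; AS OF this timestamp must see only the first snapshot.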
long timestamp = waitUntilAfter(snapshotTimestamp + 1000);
waitUntilAfter(timestamp + 1000);
long timestampInSeconds = TimeUnit.MILLISECONDS.toSeconds(timestamp);

// create a second snapshot
sql(String.format("INSERT INTO %s VALUES (2, '2', 2)", tableName));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));

tableData =
getQueryData(
String.format("SELECT * FROM %s TIMESTAMP AS OF %s", tableName, timestampInSeconds));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format(
"SELECT * FROM %s FOR SYSTEM_TIME AS OF %s", tableName, timestampInSeconds));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

tableData =
getQueryData(String.format("SELECT * FROM %s VERSION AS OF %d", tableName, snapshotId));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format("SELECT * FROM %s FOR SYSTEM_VERSION AS OF %d", tableName, snapshotId));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

tableData =
getQueryData(String.format("SELECT * FROM %s VERSION AS OF 'test_branch'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format("SELECT * FROM %s FOR SYSTEM_VERSION AS OF 'test_branch'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

tableData = getQueryData(String.format("SELECT * FROM %s VERSION AS OF 'test_tag'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format("SELECT * FROM %s FOR SYSTEM_VERSION AS OF 'test_tag'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
}

private void testMetadataColumns() {
String tableName = "test_metadata_columns";
dropTableIfExists(tableName);
@@ -559,4 +632,32 @@ static IcebergTableWriteProperties of(
return new IcebergTableWriteProperties(isPartitionedTable, formatVersion, writeMode);
}
}

private SparkIcebergTable getSparkIcebergTableInstance(String tableName)
throws NoSuchTableException {
CatalogPlugin catalogPlugin =
getSparkSession().sessionState().catalogManager().catalog(getCatalogName());
Assertions.assertInstanceOf(TableCatalog.class, catalogPlugin);
TableCatalog catalog = (TableCatalog) catalogPlugin;
Table table = catalog.loadTable(Identifier.of(new String[] {getDefaultDatabase()}, tableName));
return (SparkIcebergTable) table;
}

private long getCurrentSnapshotTimestamp(String tableName) throws NoSuchTableException {
SparkIcebergTable sparkIcebergTable = getSparkIcebergTableInstance(tableName);
return sparkIcebergTable.table().currentSnapshot().timestampMillis();
}

private long getCurrentSnapshotId(String tableName) throws NoSuchTableException {
SparkIcebergTable sparkIcebergTable = getSparkIcebergTableInstance(tableName);
return sparkIcebergTable.table().currentSnapshot().snapshotId();
}

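// Busy-waits until the wall clock has passed the given millisecond timestamp, so commits
// made after this call are guaranteed a strictly later snapshot time.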
private long waitUntilAfter(Long timestampMillis) {
long current = System.currentTimeMillis();
while (current <= timestampMillis) {
current = System.currentTimeMillis();
}
return current;
}
}
@@ -65,9 +65,9 @@ public abstract class BaseCatalog implements TableCatalog, SupportsNamespaces {
protected Catalog gravitinoCatalogClient;
protected PropertiesConverter propertiesConverter;
protected SparkTransformConverter sparkTransformConverter;
protected final String metalakeName;
protected String catalogName;

private final String metalakeName;
private String catalogName;
private final GravitinoCatalogManager gravitinoCatalogManager;

protected BaseCatalog() {
@@ -92,6 +92,7 @@ protected abstract TableCatalog createAndInitSparkCatalog(
*
* @param identifier Spark's table identifier
* @param gravitinoTable Gravitino table to do DDL operations
* @param sparkTable Spark internal table to do IO operations
* @param sparkCatalog specific Spark catalog to do IO operations
* @param propertiesConverter transform properties between Gravitino and Spark
* @param sparkTransformConverter convert transforms between Gravitino and
@@ -101,6 +102,7 @@ protected abstract Table createSparkTable(
protected abstract Table createSparkTable(
Identifier identifier,
com.datastrato.gravitino.rel.Table gravitinoTable,
Table sparkTable,
TableCatalog sparkCatalog,
PropertiesConverter propertiesConverter,
SparkTransformConverter sparkTransformConverter);
@@ -194,8 +196,14 @@ public Table createTable(
partitionings,
distributionAndSortOrdersInfo.getDistribution(),
distributionAndSortOrdersInfo.getSortOrders());
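// Load the just-created table back through the underlying Spark catalog; createSparkTable
// needs the Spark-side table for IO operations.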
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident);
return createSparkTable(
ident, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter);
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (NoSuchSchemaException e) {
throw new NoSuchNamespaceException(ident.namespace());
} catch (com.datastrato.gravitino.exceptions.TableAlreadyExistsException e) {
@@ -211,9 +219,15 @@ public Table loadTable(Identifier ident) throws NoSuchTableException {
gravitinoCatalogClient
.asTableCatalog()
.loadTable(NameIdentifier.of(metalakeName, catalogName, database, ident.name()));
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident);
// Will create a catalog specific table
return createSparkTable(
ident, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter);
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
@@ -240,8 +254,14 @@ public Table alterTable(Identifier ident, TableChange... changes) throws NoSuchT
.alterTable(
NameIdentifier.of(metalakeName, catalogName, getDatabase(ident), ident.name()),
gravitinoTableChanges);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident);
return createSparkTable(
ident, gravitinoTable, sparkCatalog, propertiesConverter, sparkTransformConverter);
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
@@ -497,4 +517,16 @@ private static com.datastrato.gravitino.rel.TableChange.ColumnPosition transform
"Unsupported table column position %s", columnPosition.getClass().getName()));
}
}

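// Callers resolve the table through Gravitino first, so a missing Spark-side table here
// indicates an inconsistency and is surfaced as an unchecked exception.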
private Table loadSparkTable(Identifier ident) {
try {
return sparkCatalog.loadTable(ident);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(ident), ident.name())),
e);
}
}
}
@@ -12,7 +12,6 @@
import java.util.Map;
import org.apache.kyuubi.spark.connector.hive.HiveTable;
import org.apache.kyuubi.spark.connector.hive.HiveTableCatalog;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.TableCatalog;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
@@ -34,19 +33,10 @@ protected TableCatalog createAndInitSparkCatalog(
protected org.apache.spark.sql.connector.catalog.Table createSparkTable(
Identifier identifier,
Table gravitinoTable,
org.apache.spark.sql.connector.catalog.Table sparkTable,
TableCatalog sparkHiveCatalog,
PropertiesConverter propertiesConverter,
SparkTransformConverter sparkTransformConverter) {
org.apache.spark.sql.connector.catalog.Table sparkTable;
try {
sparkTable = sparkHiveCatalog.loadTable(identifier);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(identifier), identifier.name())),
e);
}
return new SparkHiveTable(
identifier,
gravitinoTable,
@@ -5,6 +5,7 @@

package com.datastrato.gravitino.spark.connector.iceberg;

import com.datastrato.gravitino.NameIdentifier;
import com.datastrato.gravitino.rel.Table;
import com.datastrato.gravitino.spark.connector.PropertiesConverter;
import com.datastrato.gravitino.spark.connector.SparkTransformConverter;
@@ -44,19 +45,10 @@ protected org.apache.spark.sql.connector.catalog.Table createSparkTable(
protected org.apache.spark.sql.connector.catalog.Table createSparkTable(
Identifier identifier,
Table gravitinoTable,
org.apache.spark.sql.connector.catalog.Table sparkTable,
TableCatalog sparkIcebergCatalog,
PropertiesConverter propertiesConverter,
SparkTransformConverter sparkTransformConverter) {
org.apache.spark.sql.connector.catalog.Table sparkTable;
try {
sparkTable = sparkIcebergCatalog.loadTable(identifier);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(identifier), identifier.name())),
e);
}
return new SparkIcebergTable(
identifier,
gravitinoTable,
@@ -85,4 +77,80 @@ public Identifier[] listFunctions(String[] namespace) throws NoSuchNamespaceExce
public UnboundFunction loadFunction(Identifier ident) throws NoSuchFunctionException {
return ((SparkCatalog) sparkCatalog).loadFunction(ident);
}

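// Spark routes VERSION AS OF / FOR SYSTEM_VERSION AS OF queries to this overload; for
// Iceberg the version string may be a snapshot id, branch name, or tag name.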
@Override
public org.apache.spark.sql.connector.catalog.Table loadTable(Identifier ident, String version)
throws NoSuchTableException {
try {
com.datastrato.gravitino.rel.Table gravitinoTable = loadGravitinoTable(ident);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident, version);
// Will create a catalog specific table
return createSparkTable(
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
}

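// Spark routes TIMESTAMP AS OF / FOR SYSTEM_TIME AS OF queries to this overload; the
// timestamp Spark passes is in microseconds since the epoch.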
@Override
public org.apache.spark.sql.connector.catalog.Table loadTable(Identifier ident, long timestamp)
throws NoSuchTableException {
try {
com.datastrato.gravitino.rel.Table gravitinoTable = loadGravitinoTable(ident);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident, timestamp);
// Will create a catalog specific table
return createSparkTable(
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
}

private com.datastrato.gravitino.rel.Table loadGravitinoTable(Identifier ident)
throws NoSuchTableException {
try {
String database = getDatabase(ident);
return gravitinoCatalogClient
.asTableCatalog()
.loadTable(NameIdentifier.of(metalakeName, catalogName, database, ident.name()));
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
}

private org.apache.spark.sql.connector.catalog.Table loadSparkTable(
Identifier ident, String version) {
try {
return sparkCatalog.loadTable(ident, version);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(ident), ident.name())),
e);
}
}

private org.apache.spark.sql.connector.catalog.Table loadSparkTable(
Identifier ident, long timestamp) {
try {
return sparkCatalog.loadTable(ident, timestamp);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(ident), ident.name())),
e);
}
}
}