[#3264] feat(spark-connector): Support Iceberg time travel in SQL queries #3265

Merged · 15 commits · May 20, 2024
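For context, the feature this PR enables: Spark SQL time travel over Iceberg tables served through the Gravitino spark-connector. A minimal sketch of the four supported query forms (the catalog, database, and table names are hypothetical, and a configured Gravitino spark-connector is assumed):

import org.apache.spark.sql.SparkSession;

public class TimeTravelDemo {
  public static void main(String[] args) {
    // Assumes the Gravitino spark-connector is on the classpath and a Spark
    // catalog named "iceberg_catalog" is configured; all names are illustrative.
    SparkSession spark =
        SparkSession.builder().appName("iceberg-time-travel-demo").getOrCreate();

    String table = "iceberg_catalog.db.events";

    // Snapshot-based time travel: accepts a snapshot id, branch, or tag.
    spark.sql("SELECT * FROM " + table + " VERSION AS OF 'test_branch'").show();
    spark.sql("SELECT * FROM " + table + " FOR SYSTEM_VERSION AS OF 'test_tag'").show();

    // Timestamp-based time travel: an integer literal is read as epoch seconds,
    // and a timestamp string is also accepted.
    spark.sql("SELECT * FROM " + table + " TIMESTAMP AS OF 1716163200").show();
    spark.sql("SELECT * FROM " + table + " FOR SYSTEM_TIME AS OF '2024-05-20 00:00:00'").show();

    spark.stop();
  }
}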
@@ -18,6 +18,7 @@
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import lombok.Data;
@@ -252,6 +253,74 @@ void testIcebergCallOperations() throws NoSuchTableException {
testIcebergCallRewritePositionDeleteFiles();
}

@Test
void testIcebergTimeTravelQuery() throws NoSuchTableException {
String tableName = "test_iceberg_as_of_query";
dropTableIfExists(tableName);
createSimpleTable(tableName);
checkTableColumns(tableName, getSimpleTableColumn(), getTableInfo(tableName));

sql(String.format("INSERT INTO %s VALUES (1, '1', 1)", tableName));
List<String> tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

SparkIcebergTable sparkIcebergTable = getSparkIcebergTableInstance(tableName);
long snapshotId = getCurrentSnapshotId(tableName);
sparkIcebergTable.table().manageSnapshots().createBranch("test_branch", snapshotId).commit();
sparkIcebergTable.table().manageSnapshots().createTag("test_tag", snapshotId).commit();
long snapshotTimestamp = getCurrentSnapshotTimestamp(tableName);
// Wait until the wall clock is strictly past the first snapshot's commit time,
// so a second-granularity timestamp taken now still precedes the next snapshot.
long timestamp = waitUntilAfter(snapshotTimestamp + 1000);
long timestampInSeconds = TimeUnit.MILLISECONDS.toSeconds(timestamp);

// create a second snapshot
sql(String.format("INSERT INTO %s VALUES (2, '2', 2)", tableName));
tableData = getQueryData(getSelectAllSqlWithOrder(tableName, "id"));
Assertions.assertEquals(2, tableData.size());
Assertions.assertEquals("1,1,1;2,2,2", String.join(";", tableData));

tableData =
getQueryData(
String.format("SELECT * FROM %s TIMESTAMP AS OF %s", tableName, timestampInSeconds));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format(
"SELECT * FROM %s FOR SYSTEM_TIME AS OF %s", tableName, timestampInSeconds));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

tableData =
getQueryData(String.format("SELECT * FROM %s VERSION AS OF %d", tableName, snapshotId));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format("SELECT * FROM %s FOR SYSTEM_VERSION AS OF %d", tableName, snapshotId));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

tableData =
getQueryData(String.format("SELECT * FROM %s VERSION AS OF 'test_branch'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format("SELECT * FROM %s FOR SYSTEM_VERSION AS OF 'test_branch'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));

tableData = getQueryData(String.format("SELECT * FROM %s VERSION AS OF 'test_tag'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
tableData =
getQueryData(
String.format("SELECT * FROM %s FOR SYSTEM_VERSION AS OF 'test_tag'", tableName));
Assertions.assertEquals(1, tableData.size());
Assertions.assertEquals("1,1,1", tableData.get(0));
}

private void testMetadataColumns() {
String tableName = "test_metadata_columns";
dropTableIfExists(tableName);
@@ -722,13 +791,31 @@ static IcebergTableWriteProperties of(
}
}

private SparkIcebergTable getSparkIcebergTableInstance(String tableName)
throws NoSuchTableException {
CatalogPlugin catalogPlugin =
getSparkSession().sessionState().catalogManager().catalog(getCatalogName());
Assertions.assertInstanceOf(TableCatalog.class, catalogPlugin);
TableCatalog catalog = (TableCatalog) catalogPlugin;
Table table = catalog.loadTable(Identifier.of(new String[] {getDefaultDatabase()}, tableName));
return (SparkIcebergTable) table;
}

private long getCurrentSnapshotTimestamp(String tableName) throws NoSuchTableException {
SparkIcebergTable sparkIcebergTable = getSparkIcebergTableInstance(tableName);
return sparkIcebergTable.table().currentSnapshot().timestampMillis();
}

private long getCurrentSnapshotId(String tableName) throws NoSuchTableException {
SparkIcebergTable sparkIcebergTable = getSparkIcebergTableInstance(tableName);
return sparkIcebergTable.table().currentSnapshot().snapshotId();
}

// Busy-waits until the current wall-clock time is strictly after the given instant.
private long waitUntilAfter(Long timestampMillis) {
long current = System.currentTimeMillis();
while (current <= timestampMillis) {
current = System.currentTimeMillis();
}
return current;
}
}
@@ -92,6 +92,7 @@ protected abstract TableCatalog createAndInitSparkCatalog(
*
* @param identifier Spark's table identifier
* @param gravitinoTable Gravitino table to do DDL operations
* @param sparkTable Spark internal table to do IO operations
* @param sparkCatalog specific Spark catalog to do IO operations
* @param propertiesConverter transform properties between Gravitino and Spark
* @param sparkTransformConverter convert transforms between Gravitino and
@@ -101,6 +102,7 @@ protected abstract Table createSparkTable(
protected abstract Table createSparkTable(
Identifier identifier,
com.datastrato.gravitino.rel.Table gravitinoTable,
Table sparkTable,
TableCatalog sparkCatalog,
PropertiesConverter propertiesConverter,
SparkTransformConverter sparkTransformConverter);
@@ -194,8 +196,14 @@ public Table createTable(
partitionings,
distributionAndSortOrdersInfo.getDistribution(),
distributionAndSortOrdersInfo.getSortOrders());
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident);
return createSparkTable(
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (NoSuchSchemaException e) {
throw new NoSuchNamespaceException(ident.namespace());
} catch (com.datastrato.gravitino.exceptions.TableAlreadyExistsException e) {
@@ -206,14 +214,16 @@ public Table createTable(
@Override
public Table loadTable(Identifier ident) throws NoSuchTableException {
try {
com.datastrato.gravitino.rel.Table gravitinoTable = loadGravitinoTable(ident);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident);
// Create a catalog-specific table
return createSparkTable(
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
@@ -240,8 +250,14 @@ public Table alterTable(Identifier ident, TableChange... changes) throws NoSuchTableException
.alterTable(
NameIdentifier.of(metalakeName, catalogName, getDatabase(ident), ident.name()),
gravitinoTableChanges);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident);
return createSparkTable(
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
@@ -377,6 +393,25 @@ public boolean dropNamespace(String[] namespace, boolean cascade)
}
}

protected com.datastrato.gravitino.rel.Table loadGravitinoTable(Identifier ident)
throws NoSuchTableException {
try {
String database = getDatabase(ident);
return gravitinoCatalogClient
.asTableCatalog()
.loadTable(NameIdentifier.of(metalakeName, catalogName, database, ident.name()));
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
}

protected String getDatabase(Identifier sparkIdentifier) {
if (sparkIdentifier.namespace().length > 0) {
return sparkIdentifier.namespace()[0];
}
return getCatalogDefaultNamespace();
}

private void validateNamespace(String[] namespace) {
Preconditions.checkArgument(
namespace.length == 1,
@@ -403,13 +438,6 @@ private com.datastrato.gravitino.rel.Column createGravitinoColumn(Column sparkColumn)
com.datastrato.gravitino.rel.Column.DEFAULT_VALUE_NOT_SET);
}

private String getDatabase(NameIdentifier gravitinoIdentifier) {
Preconditions.checkArgument(
gravitinoIdentifier.namespace().length() == 3,
@@ -497,4 +525,16 @@ private static com.datastrato.gravitino.rel.TableChange.ColumnPosition transform
"Unsupported table column position %s", columnPosition.getClass().getName()));
}
}

private Table loadSparkTable(Identifier ident) {
try {
return sparkCatalog.loadTable(ident);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(ident), ident.name())),
e);
}
}
}
@@ -12,7 +12,6 @@
import java.util.Map;
import org.apache.kyuubi.spark.connector.hive.HiveTable;
import org.apache.kyuubi.spark.connector.hive.HiveTableCatalog;
import org.apache.spark.sql.connector.catalog.Identifier;
import org.apache.spark.sql.connector.catalog.TableCatalog;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
@@ -34,19 +33,10 @@ protected TableCatalog createAndInitSparkCatalog(
protected org.apache.spark.sql.connector.catalog.Table createSparkTable(
Identifier identifier,
Table gravitinoTable,
org.apache.spark.sql.connector.catalog.Table sparkTable,
TableCatalog sparkHiveCatalog,
PropertiesConverter propertiesConverter,
SparkTransformConverter sparkTransformConverter) {
return new SparkHiveTable(
identifier,
gravitinoTable,
@@ -53,19 +53,10 @@ protected TableCatalog createAndInitSparkCatalog(
protected org.apache.spark.sql.connector.catalog.Table createSparkTable(
Identifier identifier,
Table gravitinoTable,
org.apache.spark.sql.connector.catalog.Table sparkTable,
TableCatalog sparkIcebergCatalog,
PropertiesConverter propertiesConverter,
SparkTransformConverter sparkTransformConverter) {
return new SparkIcebergTable(
identifier,
gravitinoTable,
@@ -128,6 +119,44 @@ public Catalog icebergCatalog() {
return ((SparkCatalog) sparkCatalog).icebergCatalog();
}

@Override
public org.apache.spark.sql.connector.catalog.Table loadTable(Identifier ident, String version)
throws NoSuchTableException {
try {
com.datastrato.gravitino.rel.Table gravitinoTable = loadGravitinoTable(ident);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident, version);
// Create a catalog-specific table
return createSparkTable(
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
}

@Override
public org.apache.spark.sql.connector.catalog.Table loadTable(Identifier ident, long timestamp)
throws NoSuchTableException {
try {
com.datastrato.gravitino.rel.Table gravitinoTable = loadGravitinoTable(ident);
org.apache.spark.sql.connector.catalog.Table sparkTable = loadSparkTable(ident, timestamp);
// Create a catalog-specific table
return createSparkTable(
ident,
gravitinoTable,
sparkTable,
sparkCatalog,
propertiesConverter,
sparkTransformConverter);
} catch (com.datastrato.gravitino.exceptions.NoSuchTableException e) {
throw new NoSuchTableException(ident);
}
}
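
These two overloads are the DataSourceV2 hooks that Spark's analyzer resolves time-travel syntax to: VERSION AS OF arrives through the String overload (for Iceberg, a snapshot id, branch, or tag), while TIMESTAMP AS OF arrives through the long overload, which Spark passes in microseconds since the epoch. A self-contained sketch of the unit conversions in play (the sample instant is made up):

import java.util.concurrent.TimeUnit;

public class TimeTravelUnits {
  public static void main(String[] args) {
    long snapshotMillis = 1716163200000L; // hypothetical snapshot commit time

    // An integer literal in TIMESTAMP AS OF is interpreted as epoch seconds,
    // matching the conversion done in the integration test above.
    long seconds = TimeUnit.MILLISECONDS.toSeconds(snapshotMillis);
    System.out.println("SELECT * FROM t TIMESTAMP AS OF " + seconds);

    // Spark's loadTable(Identifier, long) hook, by contrast, receives
    // microseconds since the epoch.
    long micros = TimeUnit.MILLISECONDS.toMicros(snapshotMillis);
    System.out.println("catalog.loadTable(ident, " + micros + "L)");
  }
}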

private boolean isSystemNamespace(String[] namespace)
throws NoSuchMethodException, InvocationTargetException, IllegalAccessException,
ClassNotFoundException {
@@ -136,4 +165,30 @@ private boolean isSystemNamespace(String[] namespace)
isSystemNamespace.setAccessible(true);
return (Boolean) isSystemNamespace.invoke(baseCatalog, (Object) namespace);
}

private org.apache.spark.sql.connector.catalog.Table loadSparkTable(
Identifier ident, String version) {
try {
return sparkCatalog.loadTable(ident, version);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(ident), ident.name())),
e);
}
}

private org.apache.spark.sql.connector.catalog.Table loadSparkTable(
Identifier ident, long timestamp) {
try {
return sparkCatalog.loadTable(ident, timestamp);
} catch (NoSuchTableException e) {
throw new RuntimeException(
String.format(
"Failed to load the real sparkTable: %s",
String.join(".", getDatabase(ident), ident.name())),
e);
}
}
}