diff --git a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java index 532778b7d3e7..d992bb15df6e 100644 --- a/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java +++ b/google-cloud-clients/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionPool.java @@ -573,7 +573,7 @@ public Timestamp getCommitTimestamp() { @Override public TransactionRunner allowNestedTransaction() { runner.allowNestedTransaction(); - return runner; + return this; } } diff --git a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java index c0934aeeabb5..7aba1b4ba435 100644 --- a/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java +++ b/google-cloud-clients/google-cloud-spanner/src/test/java/com/google/cloud/spanner/DatabaseClientImplTest.java @@ -27,6 +27,7 @@ import com.google.cloud.spanner.MockSpannerServiceImpl.SimulatedExecutionTime; import com.google.cloud.spanner.MockSpannerServiceImpl.StatementResult; import com.google.cloud.spanner.TransactionRunner.TransactionCallable; +import com.google.common.base.Stopwatch; import com.google.protobuf.ListValue; import com.google.spanner.v1.ResultSetMetadata; import com.google.spanner.v1.StructType; @@ -37,6 +38,7 @@ import io.grpc.inprocess.InProcessServerBuilder; import java.io.IOException; import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; import org.junit.After; import org.junit.AfterClass; import org.junit.Before; @@ -257,4 +259,102 @@ public Long run(TransactionContext transaction) throws Exception { assertThat(updateCount, is(equalTo(UPDATE_COUNT))); } } + + @Test + public void testAllowNestedTransactions() throws InterruptedException { + final DatabaseClientImpl client = + (DatabaseClientImpl) + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE]")); + // Wait until all sessions have been created. + final int minSessions = spanner.getOptions().getSessionPoolOptions().getMinSessions(); + Stopwatch watch = Stopwatch.createStarted(); + while (watch.elapsed(TimeUnit.SECONDS) < 5 + && client.pool.getNumberOfSessionsInPool() < minSessions) { + Thread.sleep(1L); + } + assertThat(client.pool.getNumberOfSessionsInPool(), is(equalTo(minSessions))); + Long res = + client + .readWriteTransaction() + .allowNestedTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + assertThat( + client.pool.getNumberOfSessionsInPool(), is(equalTo(minSessions - 1))); + return transaction.executeUpdate(UPDATE_STATEMENT); + } + }); + assertThat(res, is(equalTo(UPDATE_COUNT))); + assertThat(client.pool.getNumberOfSessionsInPool(), is(equalTo(minSessions))); + } + + @Test + public void testNestedTransactionsUsingTwoDatabases() throws InterruptedException { + final DatabaseClientImpl client1 = + (DatabaseClientImpl) + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE1]")); + final DatabaseClientImpl client2 = + (DatabaseClientImpl) + spanner.getDatabaseClient(DatabaseId.of("[PROJECT]", "[INSTANCE]", "[DATABASE2]")); + // Wait until all sessions have been created so we can actually check the number of sessions + // checked out of the pools. + final int minSessions = spanner.getOptions().getSessionPoolOptions().getMinSessions(); + Stopwatch watch = Stopwatch.createStarted(); + while (watch.elapsed(TimeUnit.SECONDS) < 5 + && (client1.pool.getNumberOfSessionsInPool() < minSessions + || client2.pool.getNumberOfSessionsInPool() < minSessions)) { + Thread.sleep(1L); + } + assertThat(client1.pool.getNumberOfSessionsInPool(), is(equalTo(minSessions))); + assertThat(client2.pool.getNumberOfSessionsInPool(), is(equalTo(minSessions))); + Long res = + client1 + .readWriteTransaction() + .allowNestedTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + // Client1 should have 1 session checked out. + // Client2 should have 0 sessions checked out. + assertThat( + client1.pool.getNumberOfSessionsInPool(), is(equalTo(minSessions - 1))); + assertThat(client2.pool.getNumberOfSessionsInPool(), is(equalTo(minSessions))); + Long add = + client2 + .readWriteTransaction() + .run( + new TransactionCallable() { + @Override + public Long run(TransactionContext transaction) throws Exception { + // Both clients should now have 1 session checked out. + assertThat( + client1.pool.getNumberOfSessionsInPool(), + is(equalTo(minSessions - 1))); + assertThat( + client2.pool.getNumberOfSessionsInPool(), + is(equalTo(minSessions - 1))); + try (ResultSet rs = transaction.executeQuery(SELECT1)) { + if (rs.next()) { + return rs.getLong(0); + } + return 0L; + } + } + }); + try (ResultSet rs = transaction.executeQuery(SELECT1)) { + if (rs.next()) { + return add + rs.getLong(0); + } + return add + 0L; + } + } + }); + assertThat(res, is(equalTo(2L))); + // All sessions should now be checked back in to the pools. + assertThat(client1.pool.getNumberOfSessionsInPool(), is(equalTo(minSessions))); + assertThat(client2.pool.getNumberOfSessionsInPool(), is(equalTo(minSessions))); + } }