Skip to content
Permalink
Browse files
fix: query could hang transaction if ResultSet#next() is not called (#…
…643)

If the first statement of a read/write transaction was a query or a read operation,
and the application would not call ResultSet#next() on the return result, the transaction
would hang indefinetely as the query would be marked as the one that should initiate the
transaction (inline the BeginTransaction option). The query would however never be
executed, as the actual query execution is deferred until the first call to ResultSet#next().

Fixes #641
  • Loading branch information
olavloite committed Nov 18, 2020
1 parent 7584baa commit 48f92e3d1b26644bde62a8d864cec96c3c71687d
@@ -608,33 +608,33 @@ ExecuteBatchDmlRequest.Builder getExecuteBatchDmlRequestBuilder(Iterable<Stateme

ResultSet executeQueryInternalWithOptions(
final Statement statement,
com.google.spanner.v1.ExecuteSqlRequest.QueryMode queryMode,
final com.google.spanner.v1.ExecuteSqlRequest.QueryMode queryMode,
Options options,
ByteString partitionToken) {
final ByteString partitionToken) {
beforeReadOrQuery();
final ExecuteSqlRequest.Builder request = getExecuteSqlRequestBuilder(statement, queryMode);
if (partitionToken != null) {
request.setPartitionToken(partitionToken);
}
final int prefetchChunks =
options.hasPrefetchChunks() ? options.prefetchChunks() : defaultPrefetchChunks;
ResumableStreamIterator stream =
new ResumableStreamIterator(MAX_BUFFERED_CHUNKS, SpannerImpl.QUERY, span) {
@Override
CloseableIterator<PartialResultSet> startStream(@Nullable ByteString resumeToken) {
GrpcStreamIterator stream = new GrpcStreamIterator(statement, prefetchChunks);
final ExecuteSqlRequest.Builder request =
getExecuteSqlRequestBuilder(statement, queryMode);
if (partitionToken != null) {
request.setPartitionToken(partitionToken);
}
if (resumeToken != null) {
request.setResumeToken(resumeToken);
}
SpannerRpc.StreamingCall call =
rpc.executeQuery(request.build(), stream.consumer(), session.getOptions());
call.request(prefetchChunks);
stream.setCall(call);
stream.setCall(call, request.hasTransaction() && request.getTransaction().hasBegin());
return stream;
}
};
return new GrpcResultSet(
stream, this, request.hasTransaction() && request.getTransaction().hasBegin());
return new GrpcResultSet(stream, this);
}

/**
@@ -723,10 +723,6 @@ ResultSet readInternalWithOptions(
if (index != null) {
builder.setIndex(index);
}
TransactionSelector selector = getTransactionSelector();
if (selector != null) {
builder.setTransaction(selector);
}
if (partitionToken != null) {
builder.setPartitionToken(partitionToken);
}
@@ -740,15 +736,18 @@ CloseableIterator<PartialResultSet> startStream(@Nullable ByteString resumeToken
if (resumeToken != null) {
builder.setResumeToken(resumeToken);
}
TransactionSelector selector = getTransactionSelector();
if (selector != null) {
builder.setTransaction(selector);
}
SpannerRpc.StreamingCall call =
rpc.read(builder.build(), stream.consumer(), session.getOptions());
call.request(prefetchChunks);
stream.setCall(call);
stream.setCall(call, selector != null && selector.hasBegin());
return stream;
}
};
GrpcResultSet resultSet =
new GrpcResultSet(stream, this, selector != null && selector.hasBegin());
GrpcResultSet resultSet = new GrpcResultSet(stream, this);
return resultSet;
}

@@ -91,17 +91,14 @@ interface Listener {
static class GrpcResultSet extends AbstractResultSet<List<Object>> {
private final GrpcValueIterator iterator;
private final Listener listener;
private final boolean beginTransaction;
private GrpcStruct currRow;
private SpannerException error;
private ResultSetStats statistics;
private boolean closed;

GrpcResultSet(
CloseableIterator<PartialResultSet> iterator, Listener listener, boolean beginTransaction) {
GrpcResultSet(CloseableIterator<PartialResultSet> iterator, Listener listener) {
this.iterator = new GrpcValueIterator(iterator);
this.listener = listener;
this.beginTransaction = beginTransaction;
}

@Override
@@ -130,7 +127,7 @@ public boolean next() throws SpannerException {
}
return hasNext;
} catch (SpannerException e) {
throw yieldError(e, beginTransaction && currRow == null);
throw yieldError(e, iterator.isWithBeginTransaction() && currRow == null);
}
}

@@ -297,6 +294,10 @@ void close(@Nullable String message) {
stream.close(message);
}

boolean isWithBeginTransaction() {
return stream.isWithBeginTransaction();
}

/** @param a is a mutable list and b will be concatenated into a. */
private void concatLists(List<com.google.protobuf.Value> a, List<com.google.protobuf.Value> b) {
if (a.size() == 0 || b.size() == 0) {
@@ -760,6 +761,8 @@ protected List<Struct> getStructListInternal(int columnIndex) {
* @param message a message to include in the final RPC status
*/
void close(@Nullable String message);

boolean isWithBeginTransaction();
}

/** Adapts a streaming read/query call into an iterator over partial result sets. */
@@ -774,6 +777,7 @@ static class GrpcStreamIterator extends AbstractIterator<PartialResultSet>
private final Statement statement;

private SpannerRpc.StreamingCall call;
private boolean withBeginTransaction;
private SpannerException error;

@VisibleForTesting
@@ -792,8 +796,9 @@ protected final SpannerRpc.ResultStreamConsumer consumer() {
return consumer;
}

public void setCall(SpannerRpc.StreamingCall call) {
public void setCall(SpannerRpc.StreamingCall call, boolean withBeginTransaction) {
this.call = call;
this.withBeginTransaction = withBeginTransaction;
}

@Override
@@ -803,6 +808,11 @@ public void close(@Nullable String message) {
}
}

@Override
public boolean isWithBeginTransaction() {
return withBeginTransaction;
}

@Override
protected final PartialResultSet computeNext() {
PartialResultSet next;
@@ -873,8 +883,8 @@ public void onError(SpannerException e) {

// Visible only for testing.
@VisibleForTesting
void setCall(SpannerRpc.StreamingCall call) {
GrpcStreamIterator.this.setCall(call);
void setCall(SpannerRpc.StreamingCall call, boolean withBeginTransaction) {
GrpcStreamIterator.this.setCall(call, withBeginTransaction);
}
}
}
@@ -987,6 +997,11 @@ public void close(@Nullable String message) {
}
}

@Override
public boolean isWithBeginTransaction() {
return stream != null && stream.isWithBeginTransaction();
}

@Override
protected PartialResultSet computeNext() {
Context context = Context.current();
@@ -267,7 +267,7 @@ ApiFuture<Timestamp> commitAsync() {
final SettableApiFuture<Void> finishOps;
CommitRequest.Builder builder = CommitRequest.newBuilder().setSession(session.getName());
synchronized (lock) {
if (transactionIdFuture == null && transactionId == null) {
if (transactionIdFuture == null && transactionId == null && runningAsyncOperations == 0) {
finishOps = SettableApiFuture.create();
createTxnAsync(finishOps);
} else {
@@ -75,13 +75,14 @@ public void cancel(@Nullable String message) {}

@Override
public void request(int numMessages) {}
});
},
false);
consumer = stream.consumer();
resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false);
resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener());
}

public AbstractResultSet.GrpcResultSet resultSetWithMode(QueryMode queryMode) {
return new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false);
return new AbstractResultSet.GrpcResultSet(stream, new NoOpListener());
}

@Test
@@ -642,7 +643,7 @@ public com.google.protobuf.Value apply(@Nullable Value input) {

private void verifySerialization(
Function<Value, com.google.protobuf.Value> protoFn, Value... values) {
resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false);
resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener());
PartialResultSet.Builder builder = PartialResultSet.newBuilder();
List<Type.StructField> types = new ArrayList<>();
for (Value value : values) {
@@ -16,6 +16,7 @@

package com.google.cloud.spanner;

import static com.google.cloud.spanner.MockSpannerTestUtil.SELECT1;
import static com.google.common.truth.Truth.assertThat;
import static org.junit.Assert.fail;

@@ -65,6 +66,7 @@
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledThreadPoolExecutor;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.After;
import org.junit.AfterClass;
@@ -1139,6 +1141,123 @@ public ApiFuture<Long> apply(TransactionContext txn, Long input)
assertThat(countTransactionsStarted()).isEqualTo(1);
}

@Test
public void queryWithoutNext() {
DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d"));
assertThat(
client
.readWriteTransaction()
.run(
new TransactionCallable<Long>() {
@Override
public Long run(TransactionContext transaction) throws Exception {
// This will not actually send an RPC, so it will also not request a
// transaction.
transaction.executeQuery(SELECT1);
return transaction.executeUpdate(UPDATE_STATEMENT);
}
}))
.isEqualTo(UPDATE_COUNT);
assertThat(mockSpanner.countRequestsOfType(BeginTransactionRequest.class)).isEqualTo(0L);
assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(1L);
assertThat(countTransactionsStarted()).isEqualTo(1);
}

@Test
public void queryAsyncWithoutCallback() {
DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d"));
assertThat(
client
.readWriteTransaction()
.run(
new TransactionCallable<Long>() {
@Override
public Long run(TransactionContext transaction) throws Exception {
transaction.executeQueryAsync(SELECT1);
return transaction.executeUpdate(UPDATE_STATEMENT);
}
}))
.isEqualTo(UPDATE_COUNT);
assertThat(mockSpanner.countRequestsOfType(BeginTransactionRequest.class)).isEqualTo(0L);
assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(1L);
assertThat(countTransactionsStarted()).isEqualTo(1);
}

@Test
public void readWithoutNext() {
DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d"));
assertThat(
client
.readWriteTransaction()
.run(
new TransactionCallable<Long>() {
@Override
public Long run(TransactionContext transaction) throws Exception {
transaction.read("FOO", KeySet.all(), Arrays.asList("ID"));
return transaction.executeUpdate(UPDATE_STATEMENT);
}
}))
.isEqualTo(UPDATE_COUNT);
assertThat(mockSpanner.countRequestsOfType(BeginTransactionRequest.class)).isEqualTo(0L);
assertThat(mockSpanner.countRequestsOfType(ReadRequest.class)).isEqualTo(0L);
assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(1L);
assertThat(countTransactionsStarted()).isEqualTo(1);
}

@Test
public void readAsyncWithoutCallback() {
DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d"));
assertThat(
client
.readWriteTransaction()
.run(
new TransactionCallable<Long>() {
@Override
public Long run(TransactionContext transaction) throws Exception {
transaction.readAsync("FOO", KeySet.all(), Arrays.asList("ID"));
return transaction.executeUpdate(UPDATE_STATEMENT);
}
}))
.isEqualTo(UPDATE_COUNT);
assertThat(mockSpanner.countRequestsOfType(BeginTransactionRequest.class)).isEqualTo(0L);
assertThat(mockSpanner.countRequestsOfType(ReadRequest.class)).isEqualTo(0L);
assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(1L);
assertThat(countTransactionsStarted()).isEqualTo(1);
}

@Test
public void query_ThenUpdate_ThenConsumeResultSet()
throws InterruptedException, TimeoutException {
DatabaseClient client = spanner.getDatabaseClient(DatabaseId.of("p", "i", "d"));
assertThat(
client
.readWriteTransaction()
.run(
new TransactionCallable<Long>() {
@Override
public Long run(TransactionContext transaction) throws Exception {
ResultSet rs = transaction.executeQuery(SELECT1);
long updateCount = transaction.executeUpdate(UPDATE_STATEMENT);
// Consume the result set.
while (rs.next()) {}
return updateCount;
}
}))
.isEqualTo(UPDATE_COUNT);
// The update statement should start the transaction, and the query should use the transaction
// id returned by the update.
assertThat(mockSpanner.countRequestsOfType(BeginTransactionRequest.class)).isEqualTo(0L);
assertThat(mockSpanner.countRequestsOfType(ExecuteSqlRequest.class)).isEqualTo(2L);
assertThat(countTransactionsStarted()).isEqualTo(1);
List<AbstractMessage> requests = mockSpanner.getRequests();
requests = requests.subList(requests.size() - 3, requests.size());
assertThat(requests.get(0)).isInstanceOf(ExecuteSqlRequest.class);
assertThat(((ExecuteSqlRequest) requests.get(0)).getSql()).isEqualTo(UPDATE_STATEMENT.getSql());
assertThat(requests.get(1)).isInstanceOf(ExecuteSqlRequest.class);
assertThat(((ExecuteSqlRequest) requests.get(1)).getSql()).isEqualTo(SELECT1.getSql());
assertThat(requests.get(2)).isInstanceOf(CommitRequest.class);
}

private int countRequests(Class<? extends AbstractMessage> requestType) {
int count = 0;
for (AbstractMessage msg : mockSpanner.getRequests()) {
@@ -117,9 +117,10 @@ public void cancel(@Nullable String message) {}

@Override
public void request(int numMessages) {}
});
},
false);
consumer = stream.consumer();
resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener(), false);
resultSet = new AbstractResultSet.GrpcResultSet(stream, new NoOpListener());

JSONArray chunks = testCase.getJSONArray("chunks");
JSONObject expectedResult = testCase.getJSONObject("result");
@@ -116,6 +116,11 @@ protected PartialResultSet computeNext() {
public void close(@Nullable String message) {
stream.close();
}

@Override
public boolean isWithBeginTransaction() {
return false;
}
}

Starter starter = Mockito.mock(Starter.class);

0 comments on commit 48f92e3

Please sign in to comment.