diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java index 405c5f86812..1dba4b92ffe 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/SessionClient.java @@ -22,6 +22,7 @@ import com.google.api.pathtemplate.PathTemplate; import com.google.cloud.grpc.GrpcTransportOptions.ExecutorFactory; import com.google.cloud.spanner.spi.v1.SpannerRpc; +import com.google.cloud.spanner.spi.v1.SpannerRpc.Option; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; @@ -109,6 +110,13 @@ Object value() { return ImmutableMap.copyOf(tmp); } + static Map createRequestOptions( + long channelId, XGoogSpannerRequestId requestId) { + return ImmutableMap.of( + Option.CHANNEL_HINT, channelId, + Option.REQUEST_ID, requestId); + } + private final class BatchCreateSessionsRunnable implements Runnable { private final long channelHint; private final int sessionCount; @@ -219,15 +227,14 @@ public XGoogSpannerRequestId nextRequestId(long channelId, int attempt) { SessionImpl createSession() { // The sessionChannelCounter could overflow, but that will just flip it to Integer.MIN_VALUE, // which is also a valid channel hint. - final Map options; final long channelId; synchronized (this) { - options = optionMap(SessionOption.channelHint(sessionChannelCounter++)); channelId = sessionChannelCounter; + sessionChannelCounter++; } + XGoogSpannerRequestId reqId = nextRequestId(channelId, 1); ISpan span = spanner.getTracer().spanBuilder(SpannerImpl.CREATE_SESSION, this.commonAttributes); try (IScope s = spanner.getTracer().withSpan(span)) { - XGoogSpannerRequestId reqId = this.nextRequestId(channelId, 1); com.google.spanner.v1.Session session = spanner .getRpc() @@ -235,10 +242,13 @@ SessionImpl createSession() { db.getName(), spanner.getOptions().getDatabaseRole(), spanner.getOptions().getSessionLabels(), - reqId.withOptions(options)); + createRequestOptions(channelId, reqId)); SessionReference sessionReference = new SessionReference( - session.getName(), session.getCreateTime(), session.getMultiplexed(), options); + session.getName(), + session.getCreateTime(), + session.getMultiplexed(), + optionMap(SessionOption.channelHint(channelId))); SessionImpl sessionImpl = new SessionImpl(spanner, sessionReference); sessionImpl.setRequestIdCreator(this); return sessionImpl; @@ -399,7 +409,6 @@ void asyncBatchCreateSessions( */ private List internalBatchCreateSessions( final int sessionCount, final long channelHint) throws SpannerException { - final Map options = optionMap(SessionOption.channelHint(channelHint)); ISpan parent = spanner.getTracer().getCurrentSpan(); ISpan span = spanner @@ -417,7 +426,7 @@ private List internalBatchCreateSessions( sessionCount, spanner.getOptions().getDatabaseRole(), spanner.getOptions().getSessionLabels(), - reqId.withOptions(options)); + createRequestOptions(channelHint, reqId)); span.addAnnotation( String.format( "Request for %d sessions returned %d sessions", sessionCount, sessions.size())); @@ -428,7 +437,10 @@ private List internalBatchCreateSessions( new SessionImpl( spanner, new SessionReference( - session.getName(), session.getCreateTime(), session.getMultiplexed(), options)); + session.getName(), + session.getCreateTime(), + session.getMultiplexed(), + optionMap(SessionOption.channelHint(channelHint)))); sessionImpl.setRequestIdCreator(this); res.add(sessionImpl); } diff --git a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java index 325aace2d2c..49a0c7eb3d2 100644 --- a/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java +++ b/google-cloud-spanner/src/main/java/com/google/cloud/spanner/XGoogSpannerRequestId.java @@ -124,8 +124,7 @@ public void incrementAttempt() { this.attempt++; } - @SuppressWarnings("unchecked") - public Map withOptions(Map options) { + Map withOptions(Map options) { Map copyOptions = new HashMap<>(); if (options != null) { copyOptions.putAll(options); @@ -139,11 +138,11 @@ public int hashCode() { return Objects.hash(this.nthClientId, this.nthChannelId, this.nthRequest, this.attempt); } - public interface RequestIdCreator { + interface RequestIdCreator { XGoogSpannerRequestId nextRequestId(long channelId, int attempt); } - public static class NoopRequestIdCreator implements RequestIdCreator { + static class NoopRequestIdCreator implements RequestIdCreator { NoopRequestIdCreator() {} @Override @@ -152,7 +151,7 @@ public XGoogSpannerRequestId nextRequestId(long channelId, int attempt) { } } - public static void assertMonotonicityOfIds(String prefix, List reqIds) { + static void assertMonotonicityOfIds(String prefix, List reqIds) { int size = reqIds.size(); List violations = new ArrayList<>(); @@ -164,7 +163,7 @@ public static void assertMonotonicityOfIds(String prefix, List labels, @Nullable Map options) throws SpannerException { - // By default sessions are not multiplexed + // By default, sessions are not multiplexed return createSession(databaseName, databaseRole, labels, options, false); } @@ -2043,8 +2043,10 @@ GrpcCallContext newCallContext( context = context.withChannelAffinity(affinity.intValue()); } } - String methodName = method.getFullMethodName(); - context = withRequestId(context, options, methodName); + if (method != null) { + String methodName = method.getFullMethodName(); + context = withRequestId(context, options, methodName); + } } context = context.withExtraHeaders(metadataProvider.newExtraHeaders(resource, projectName)); if (routeToLeader && leaderAwareRoutingEnabled) { @@ -2065,7 +2067,8 @@ GrpcCallContext newCallContext( return (GrpcCallContext) context.merge(apiCallContextFromContext); } - GrpcCallContext withRequestId(GrpcCallContext context, Map options, String methodName) { + GrpcCallContext withRequestId( + GrpcCallContext context, Map options, String methodName) { XGoogSpannerRequestId reqId = (XGoogSpannerRequestId) options.get(Option.REQUEST_ID); if (reqId == null) { return context; diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java index a6196df01e6..307504a20ca 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionImplTest.java @@ -21,6 +21,7 @@ import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.fail; +import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -223,7 +224,7 @@ public void writeAtLeastOnce() throws ParseException { ArgumentCaptor commit = ArgumentCaptor.forClass(CommitRequest.class); CommitResponse response = CommitResponse.newBuilder().setCommitTimestamp(Timestamps.parse(timestampString)).build(); - Mockito.when(rpc.commit(commit.capture(), Mockito.eq(options))).thenReturn(response); + Mockito.when(rpc.commit(commit.capture(), anyMap())).thenReturn(response); Timestamp timestamp = session.writeAtLeastOnce( @@ -255,7 +256,7 @@ public void writeAtLeastOnceWithOptions() throws ParseException { ArgumentCaptor commit = ArgumentCaptor.forClass(CommitRequest.class); CommitResponse response = CommitResponse.newBuilder().setCommitTimestamp(Timestamps.parse(timestampString)).build(); - Mockito.when(rpc.commit(commit.capture(), Mockito.eq(options))).thenReturn(response); + Mockito.when(rpc.commit(commit.capture(), anyMap())).thenReturn(response); session.writeAtLeastOnceWithOptions( Collections.singletonList(Mutation.newInsertBuilder("T").set("C").to("x").build()), Options.tag(tag)); @@ -340,7 +341,7 @@ public void newMultiUseReadOnlyTransactionContextClosesOldSingleUseContext() { public void writeClosesOldSingleUseContext() throws ParseException { ReadContext ctx = session.singleUse(TimestampBound.strong()); - Mockito.when(rpc.commit(Mockito.any(), Mockito.eq(options))) + Mockito.when(rpc.commit(Mockito.any(), anyMap())) .thenReturn( CommitResponse.newBuilder() .setCommitTimestamp(Timestamps.parse("2015-10-01T10:54:20.021Z")) @@ -442,7 +443,7 @@ public void request(int numMessages) {} private void mockRead(final PartialResultSet myResultSet) { final ArgumentCaptor consumer = ArgumentCaptor.forClass(SpannerRpc.ResultStreamConsumer.class); - Mockito.when(rpc.read(Mockito.any(), consumer.capture(), Mockito.eq(options), eq(false))) + Mockito.when(rpc.read(Mockito.any(), consumer.capture(), anyMap(), eq(false))) .then( invocation -> { consumer.getValue().onPartialResultSet(myResultSet); @@ -458,8 +459,7 @@ public void multiUseReadOnlyTransactionReturnsEmptyTransactionMetadata() { PartialResultSet.newBuilder() .setMetadata(newMetadata(Type.struct(Type.StructField.of("C", Type.string())))) .build(); - Mockito.when(rpc.beginTransaction(Mockito.any(), Mockito.eq(options), eq(false))) - .thenReturn(txnMetadata); + Mockito.when(rpc.beginTransaction(Mockito.any(), anyMap(), eq(false))).thenReturn(txnMetadata); mockRead(resultSet); ReadOnlyTransaction txn = session.readOnlyTransaction(TimestampBound.strong()); @@ -477,8 +477,7 @@ public void multiUseReadOnlyTransactionReturnsMissingTimestamp() { PartialResultSet.newBuilder() .setMetadata(newMetadata(Type.struct(Type.StructField.of("C", Type.string())))) .build(); - Mockito.when(rpc.beginTransaction(Mockito.any(), Mockito.eq(options), eq(false))) - .thenReturn(txnMetadata); + Mockito.when(rpc.beginTransaction(Mockito.any(), anyMap(), eq(false))).thenReturn(txnMetadata); mockRead(resultSet); ReadOnlyTransaction txn = session.readOnlyTransaction(TimestampBound.strong()); @@ -497,8 +496,7 @@ public void multiUseReadOnlyTransactionReturnsMissingTransactionId() throws Pars PartialResultSet.newBuilder() .setMetadata(newMetadata(Type.struct(Type.StructField.of("C", Type.string())))) .build(); - Mockito.when(rpc.beginTransaction(Mockito.any(), Mockito.eq(options), eq(false))) - .thenReturn(txnMetadata); + Mockito.when(rpc.beginTransaction(Mockito.any(), anyMap(), eq(false))).thenReturn(txnMetadata); mockRead(resultSet); ReadOnlyTransaction txn = session.readOnlyTransaction(TimestampBound.strong()); diff --git a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java index 4538be784b1..2b9ad35af1b 100644 --- a/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java +++ b/google-cloud-spanner/src/test/java/com/google/cloud/spanner/SessionPoolTest.java @@ -1647,13 +1647,13 @@ public void testSessionNotFoundWrite() { SpannerExceptionFactoryTest.newSessionNotFoundException(sessionName); List mutations = Collections.singletonList(Mutation.newInsertBuilder("FOO").build()); final SessionImpl closedSession = mockSession(); - when(closedSession.writeWithOptions(mutations)).thenThrow(sessionNotFound); + when(closedSession.writeWithOptions(eq(mutations), any())).thenThrow(sessionNotFound); final SessionImpl openSession = mockSession(); com.google.cloud.spanner.CommitResponse response = mock(com.google.cloud.spanner.CommitResponse.class); when(response.getCommitTimestamp()).thenReturn(Timestamp.now()); - when(openSession.writeWithOptions(mutations)).thenReturn(response); + when(openSession.writeWithOptions(eq(mutations), any())).thenReturn(response); doAnswer( invocation -> { executor.submit( @@ -1690,13 +1690,14 @@ public void testSessionNotFoundWriteAtLeastOnce() { SpannerExceptionFactoryTest.newSessionNotFoundException(sessionName); List mutations = Collections.singletonList(Mutation.newInsertBuilder("FOO").build()); final SessionImpl closedSession = mockSession(); - when(closedSession.writeAtLeastOnceWithOptions(mutations)).thenThrow(sessionNotFound); + when(closedSession.writeAtLeastOnceWithOptions(eq(mutations), any())) + .thenThrow(sessionNotFound); final SessionImpl openSession = mockSession(); com.google.cloud.spanner.CommitResponse response = mock(com.google.cloud.spanner.CommitResponse.class); when(response.getCommitTimestamp()).thenReturn(Timestamp.now()); - when(openSession.writeAtLeastOnceWithOptions(mutations)).thenReturn(response); + when(openSession.writeAtLeastOnceWithOptions(eq(mutations), any())).thenReturn(response); doAnswer( invocation -> { executor.submit( @@ -1732,10 +1733,10 @@ public void testSessionNotFoundPartitionedUpdate() { SpannerExceptionFactoryTest.newSessionNotFoundException(sessionName); Statement statement = Statement.of("UPDATE FOO SET BAR=1 WHERE 1=1"); final SessionImpl closedSession = mockSession(); - when(closedSession.executePartitionedUpdate(statement)).thenThrow(sessionNotFound); + when(closedSession.executePartitionedUpdate(eq(statement), any())).thenThrow(sessionNotFound); final SessionImpl openSession = mockSession(); - when(openSession.executePartitionedUpdate(statement)).thenReturn(1L); + when(openSession.executePartitionedUpdate(eq(statement), any())).thenReturn(1L); doAnswer( invocation -> { executor.submit(