Skip to content

Commit

Permalink
Fix observation scope handlings in grpc server instrumentation (#3836)
Browse files Browse the repository at this point in the history
Prior to this change, when grpc server receives requests in parallel,
wrong scopes might have been used.
This change ensures the proper scope is used for dispatched callback
methods.

Fixes gh-3805
  • Loading branch information
ttddyy committed May 17, 2023
1 parent 4c76caa commit 081fd49
Show file tree
Hide file tree
Showing 6 changed files with 245 additions and 34 deletions.
2 changes: 1 addition & 1 deletion config/checkstyle/checkstyle.xml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
<!-- Imports -->
<module name="IllegalImportCheck" >
<property name="id" value="GeneralIllegalImportCheck"/>
<property name="illegalPkgs" value="com.google.common.(?!cache).*,org.apache.commons.text.*,org.jetbrains.*,jdk.internal.jline.internal.*,reactor.util.annotation.*,org.checkerframework.checker.*,javax.ws.*"/>
<property name="illegalPkgs" value="com.google.common.(?![cache|concurrent]).*,org.apache.commons.text.*,org.jetbrains.*,jdk.internal.jline.internal.*,reactor.util.annotation.*,org.checkerframework.checker.*,javax.ws.*"/>
<property name="illegalClasses" value="org\.assertj\.core\.api\.Java6Assertions\..*,javax.annotation.Nullable"/>
<property name="regexp" value="true"/>
</module>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import io.grpc.Status;
import io.micrometer.core.instrument.binder.grpc.GrpcObservationDocumentation.GrpcServerEvents;
import io.micrometer.observation.Observation;
import io.micrometer.observation.Observation.Scope;

/**
* A simple forwarding server call for {@link Observation}.
Expand All @@ -31,33 +30,27 @@
*/
class ObservationGrpcServerCall<ReqT, RespT> extends SimpleForwardingServerCall<ReqT, RespT> {

private final Scope scope;
private final Observation observation;

ObservationGrpcServerCall(ServerCall<ReqT, RespT> delegate, Scope scope) {
ObservationGrpcServerCall(ServerCall<ReqT, RespT> delegate, Observation observation) {
super(delegate);
this.scope = scope;
this.observation = observation;
}

@Override
public void sendMessage(RespT message) {
this.scope.getCurrentObservation().event(GrpcServerEvents.MESSAGE_SENT);
this.observation.event(GrpcServerEvents.MESSAGE_SENT);
super.sendMessage(message);
}

@Override
public void close(Status status, Metadata trailers) {
Observation observation = this.scope.getCurrentObservation();

if (status.getCause() != null) {
observation.error(status.getCause());
this.observation.error(status.getCause());
}

GrpcServerObservationContext context = (GrpcServerObservationContext) observation.getContext();
GrpcServerObservationContext context = (GrpcServerObservationContext) this.observation.getContext();
context.setStatusCode(status.getCode());

this.scope.close();
observation.stop();

super.close(status, trailers);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,37 +25,57 @@
* A simple forwarding client call listener for {@link Observation}.
*
* @param <RespT> type of message received one or more times from the server.
* @see io.grpc.Contexts
*/
class ObservationGrpcServerCallListener<RespT> extends SimpleForwardingServerCallListener<RespT> {

private final Scope scope;
private final Observation observation;

ObservationGrpcServerCallListener(Listener<RespT> delegate, Scope scope) {
ObservationGrpcServerCallListener(Listener<RespT> delegate, Observation observation) {
super(delegate);
this.scope = scope;
this.observation = observation;
}

@Override
public void onMessage(RespT message) {
this.scope.getCurrentObservation().event(GrpcServerEvents.MESSAGE_RECEIVED);
super.onMessage(message);
this.observation.event(GrpcServerEvents.MESSAGE_RECEIVED);
try (Scope scope = observation.openScope()) {
super.onMessage(message);
}
}

@Override
public void onHalfClose() {
try {
try (Scope scope = observation.openScope()) {
super.onHalfClose();
}
catch (Throwable ex) {
handleFailure(ex);
throw ex;
}

@Override
public void onCancel() {
try (Scope scope = this.observation.openScope()) {
super.onCancel();
}
finally {
this.observation.stop();
}
}

@Override
public void onComplete() {
try (Scope scope = this.observation.openScope()) {
super.onComplete();
}
finally {
this.observation.stop();
}
}

private void handleFailure(Throwable ex) {
Observation observation = this.scope.getCurrentObservation();
this.scope.close();
observation.error(ex).stop();
@Override
public void onReady() {
try (Scope scope = this.observation.openScope()) {
super.onReady();
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import io.grpc.ServerCall.Listener;
import io.micrometer.common.lang.Nullable;
import io.micrometer.observation.Observation;
import io.micrometer.observation.Observation.Scope;
import io.micrometer.observation.ObservationRegistry;

import java.util.Map;
Expand Down Expand Up @@ -88,23 +87,22 @@ public <ReqT, RespT> Listener<ReqT> interceptCall(ServerCall<ReqT, RespT> call,
return context;
};

Observation observation = GrpcObservationDocumentation.SERVER.observation(this.customConvention,
DEFAULT_CONVENTION, contextSupplier, this.registry);
Observation observation = GrpcObservationDocumentation.SERVER
.observation(this.customConvention, DEFAULT_CONVENTION, contextSupplier, this.registry)
.start();

if (observation.isNoop()) {
// do not instrument anymore
return next.startCall(call, headers);
}

Scope scope = observation.start().openScope();
ObservationGrpcServerCall<ReqT, RespT> serverCall = new ObservationGrpcServerCall<>(call, scope);
ObservationGrpcServerCall<ReqT, RespT> serverCall = new ObservationGrpcServerCall<>(call, observation);

try {
Listener<ReqT> result = next.startCall(serverCall, headers);
return new ObservationGrpcServerCallListener<>(result, scope);
return new ObservationGrpcServerCallListener<>(result, observation);
}
catch (Exception ex) {
scope.close();
observation.error(ex).stop();
throw ex;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
/*
* Copyright 2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.micrometer.core.instrument.binder.grpc;

import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.ManagedChannel;
import io.grpc.Metadata;
import io.grpc.Server;
import io.grpc.inprocess.InProcessChannelBuilder;
import io.grpc.inprocess.InProcessServerBuilder;
import io.grpc.stub.MetadataUtils;
import io.grpc.stub.StreamObserver;
import io.grpc.testing.protobuf.SimpleRequest;
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
import io.grpc.testing.protobuf.SimpleServiceGrpc.SimpleServiceFutureStub;
import io.grpc.testing.protobuf.SimpleServiceGrpc.SimpleServiceImplBase;
import io.micrometer.observation.Observation.Context;
import io.micrometer.observation.ObservationHandler;
import io.micrometer.observation.ObservationRegistry;
import io.micrometer.observation.ObservationTextPublisher;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;

import static org.assertj.core.api.Assertions.assertThat;
import static org.awaitility.Awaitility.await;

/**
* @author Tadaya Tsuyukubo
*/
class GrpcAsyncTest {

static final Metadata.Key<String> REQUEST_ID_KEY = Metadata.Key.of("request-id", Metadata.ASCII_STRING_MARSHALLER);

Server server;

ManagedChannel channel;

@BeforeEach
void setUp() throws Exception {
ObservationRegistry observationRegistry = ObservationRegistry.create();
observationRegistry.observationConfig()
.observationHandler(new ObservationTextPublisher())
.observationHandler(new StoreRequestIdInScopeObservationHandler());

this.server = InProcessServerBuilder.forName("sample")
.addService(new MyService())
.intercept(new ObservationGrpcServerInterceptor(observationRegistry))
.build();
this.server.start();

this.channel = InProcessChannelBuilder.forName("sample")
.intercept(new ObservationGrpcClientInterceptor(observationRegistry))
.build();
}

@AfterEach
void cleanUp() {
if (this.channel != null) {
this.channel.shutdownNow();
}
if (this.server != null) {
this.server.shutdownNow();
}
}

@Test
void simulate_trace_in_async_requests() {
// Send requests asynchronously with request-id in metadata.
// The request-id is stored in threadlocal in server when scope is opened.
// The main logic retrieves the request-id from threadlocal and include it as
// part of the response message.
// This simulates a tracer with span.
SimpleServiceFutureStub stub = SimpleServiceGrpc.newFutureStub(this.channel);
Map<ListenableFuture<SimpleResponse>, String> requestIds = new HashMap<>();
List<ListenableFuture<SimpleResponse>> futures = new ArrayList<>();
int max = 40;
for (int i = 0; i < max; i++) {
String message = "Hello-" + i;
SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(message).build();

String requestId = "req-" + i;
Metadata metadata = new Metadata();
metadata.put(REQUEST_ID_KEY, requestId);
ListenableFuture<SimpleResponse> future = stub
.withInterceptors(MetadataUtils.newAttachHeadersInterceptor(metadata))
.unaryRpc(request);

requestIds.put(future, requestId);
futures.add(future);
}
await().until(() -> futures.stream().allMatch(Future::isDone));
assertThat(futures).allSatisfy((future) -> {
// Make sure the request-id in the response message matches with the one sent
// to server.
String expectedRequestId = requestIds.get(future);
assertThat(future.get().getResponseMessage()).contains("request-id=" + expectedRequestId);
});
}

static class StoreRequestIdInScopeObservationHandler implements ObservationHandler<GrpcServerObservationContext> {

@Override
public boolean supportsContext(Context context) {
return context instanceof GrpcServerObservationContext;
}

@Override
public void onScopeOpened(GrpcServerObservationContext context) {
String requestId = context.getCarrier().get(REQUEST_ID_KEY);
assertThat(requestId).isNotNull();
MyService.requestIdHolder.set(requestId);
}

@Override
public void onScopeClosed(GrpcServerObservationContext context) {
MyService.requestIdHolder.remove();
}

}

static class MyService extends SimpleServiceImplBase {

static ThreadLocal<String> requestIdHolder = new ThreadLocal<>();

@Override
public void unaryRpc(SimpleRequest request, StreamObserver<SimpleResponse> responseObserver) {
StringBuilder sb = new StringBuilder();
sb.append("message=");
sb.append(request.getRequestMessage());
sb.append(",request-id=");
sb.append(requestIdHolder.get());
sb.append(",thread=");
sb.append(Thread.currentThread().getId());

SimpleResponse response = SimpleResponse.newBuilder().setResponseMessage(sb.toString()).build();
responseObserver.onNext(response);
responseObserver.onCompleted();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/
package io.micrometer.core.instrument.binder.grpc;

import com.google.common.util.concurrent.FutureCallback;
import com.google.common.util.concurrent.Futures;
import com.google.common.util.concurrent.ListenableFuture;
import io.grpc.ManagedChannel;
import io.grpc.MethodDescriptor.MethodType;
import io.grpc.Server;
Expand All @@ -27,6 +30,7 @@
import io.grpc.testing.protobuf.SimpleResponse;
import io.grpc.testing.protobuf.SimpleServiceGrpc;
import io.grpc.testing.protobuf.SimpleServiceGrpc.SimpleServiceBlockingStub;
import io.grpc.testing.protobuf.SimpleServiceGrpc.SimpleServiceFutureStub;
import io.grpc.testing.protobuf.SimpleServiceGrpc.SimpleServiceImplBase;
import io.grpc.testing.protobuf.SimpleServiceGrpc.SimpleServiceStub;
import io.micrometer.common.lang.Nullable;
Expand All @@ -48,6 +52,8 @@

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -137,6 +143,36 @@ void unaryRpc() {
GrpcClientEvents.MESSAGE_RECEIVED);
}

@Test
void unaryRpcAsync() {
SimpleServiceFutureStub stub = SimpleServiceGrpc.newFutureStub(channel);
List<String> messages = new ArrayList<>();
List<String> responses = new ArrayList<>();
List<ListenableFuture<SimpleResponse>> futures = new ArrayList<>();
int count = 40;
for (int i = 0; i < count; i++) {
String message = "Hello-" + i;
messages.add(message);
SimpleRequest request = SimpleRequest.newBuilder().setRequestMessage(message).build();
ListenableFuture<SimpleResponse> future = stub.unaryRpc(request);
Futures.addCallback(future, new FutureCallback<>() {
@Override
public void onSuccess(SimpleResponse result) {
responses.add(result.getResponseMessage());
}

@Override
public void onFailure(Throwable t) {

}
}, Executors.newCachedThreadPool());
futures.add(stub.unaryRpc(request));
}

await().until(() -> futures.stream().allMatch(Future::isDone));
assertThat(responses).hasSize(count).containsExactlyInAnyOrderElementsOf(messages);
}

@Test
void clientStreamingRpc() {
SimpleServiceStub asyncStub = SimpleServiceGrpc.newStub(channel);
Expand Down

0 comments on commit 081fd49

Please sign in to comment.