diff --git a/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java b/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java index e3399a0b473..fed94c7f799 100644 --- a/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java +++ b/binder/src/androidTest/java/io/grpc/binder/BinderChannelSmokeTest.java @@ -22,6 +22,8 @@ import static java.util.concurrent.TimeUnit.SECONDS; import android.content.Context; +import android.os.Parcel; +import android.os.Parcelable; import androidx.test.core.app.ApplicationProvider; import androidx.test.ext.junit.runners.AndroidJUnit4; import com.google.common.io.ByteStreams; @@ -29,16 +31,26 @@ import com.google.common.util.concurrent.ListenableFuture; import com.google.common.util.concurrent.SettableFuture; import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptors; +import io.grpc.ConnectivityState; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; import io.grpc.ManagedChannel; import io.grpc.Metadata; import io.grpc.MethodDescriptor; import io.grpc.NameResolverRegistry; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; import io.grpc.ServerInterceptors; import io.grpc.ServerServiceDefinition; +import io.grpc.Status.Code; +import io.grpc.StatusRuntimeException; import io.grpc.internal.GrpcUtil; import io.grpc.internal.testing.FakeNameResolverProvider; import io.grpc.stub.ClientCalls; +import io.grpc.stub.MetadataUtils; import io.grpc.stub.ServerCalls; import io.grpc.stub.StreamObserver; import io.grpc.testing.TestUtils; @@ -49,6 +61,7 @@ import java.util.concurrent.Executors; import java.util.concurrent.atomic.AtomicReference; import org.junit.After; +import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; @@ -64,6 +77,8 @@ public final class BinderChannelSmokeTest { private static final int SLIGHTLY_MORE_THAN_ONE_BLOCK = 16 * 1024 + 100; private static final String MSG = "Some text which will be repeated many many times"; private static final String SERVER_TARGET_URI = "fake://server"; + private static final Metadata.Key POISON_KEY = ParcelableUtils.metadataKey( + "poison-bin", PoisonParcelable.CREATOR); final MethodDescriptor method = MethodDescriptor.newBuilder(StringMarshaller.INSTANCE, StringMarshaller.INSTANCE) @@ -87,6 +102,7 @@ public final class BinderChannelSmokeTest { ManagedChannel channel; AtomicReference headersCapture = new AtomicReference<>(); AtomicReference clientUidCapture = new AtomicReference<>(); + PoisonParcelable parcelableForResponseHeaders; @Before public void setUp() throws Exception { @@ -116,6 +132,7 @@ public void setUp() throws Exception { .addMethod(singleLargeResultMethod, singleLargeResultCallHandler) .addMethod(bidiMethod, bidiCallHandler) .build(), + new AddParcelableServerInterceptor(), TestUtils.recordRequestHeadersInterceptor(headersCapture), PeerUids.newPeerIdentifyingServerInterceptor()); @@ -124,13 +141,20 @@ public void setUp() throws Exception { NameResolverRegistry.getDefaultRegistry().register(fakeNameResolverProvider); HostServices.configureService(serverAddress, HostServices.serviceParamsBuilder() - .setServerFactory((service, receiver) -> - BinderServerBuilder.forAddress(serverAddress, receiver) - .addService(serviceDef) - .build()) - .build()); - - channel = BinderChannelBuilder.forAddress(serverAddress, appContext).build(); + .setServerFactory((service, receiver) -> + BinderServerBuilder.forAddress(serverAddress, receiver) + .inboundParcelablePolicy(InboundParcelablePolicy.newBuilder() + .setAcceptParcelableMetadataValues(true) + .build()) + .addService(serviceDef) + .build()) + .build()); + + channel = BinderChannelBuilder.forAddress(serverAddress, appContext) + .inboundParcelablePolicy(InboundParcelablePolicy.newBuilder() + .setAcceptParcelableMetadataValues(true) + .build()) + .build(); } @After @@ -209,6 +233,42 @@ public void testConnectViaTargetUri() throws Exception { assertThat(doCall("Hello").get()).isEqualTo("Hello"); } + @Test + public void testUncaughtServerException() throws Exception { + // Use a poison parcelable to cause an unexpected Exception in the server's onTransact(). + PoisonParcelable bad = new PoisonParcelable(); + Metadata extraHeadersToSend = new Metadata(); + extraHeadersToSend.put(POISON_KEY, bad); + ClientCall call = + ClientInterceptors.intercept(channel, + MetadataUtils.newAttachHeadersInterceptor(extraHeadersToSend)) + .newCall(method, CallOptions.DEFAULT.withDeadlineAfter(5, SECONDS)); + try { + ClientCalls.blockingUnaryCall(call, "hello"); + Assert.fail(); + } catch (StatusRuntimeException e) { + // We don't care how *our* RPC failed, but make sure we didn't have to rely on the deadline. + assertThat(e.getStatus().getCode()).isNotEqualTo(Code.DEADLINE_EXCEEDED); + assertThat(channel.getState(false)).isEqualTo(ConnectivityState.IDLE); + } + } + + @Test + public void testUncaughtClientException() throws Exception { + // Use a poison parcelable to cause an unexpected Exception in the client's onTransact(). + parcelableForResponseHeaders = new PoisonParcelable(); + ClientCall call = channel + .newCall(method, CallOptions.DEFAULT.withDeadlineAfter(5, SECONDS)); + try { + ClientCalls.blockingUnaryCall(call, "hello"); + Assert.fail(); + } catch (StatusRuntimeException e) { + // We don't care *how* our RPC failed, but make sure we didn't have to rely on the deadline. + assertThat(e.getStatus().getCode()).isNotEqualTo(Code.DEADLINE_EXCEEDED); + assertThat(channel.getState(false)).isEqualTo(ConnectivityState.IDLE); + } + } + private static String createLargeString(int size) { StringBuilder sb = new StringBuilder(); while (sb.length() < size) { @@ -286,4 +346,44 @@ public void onCompleted() { delegate.onCompleted(); } } + + class AddParcelableServerInterceptor implements ServerInterceptor { + @Override + public Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + return next.startCall(new SimpleForwardingServerCall(call) { + @Override + public void sendHeaders(Metadata headers) { + if (parcelableForResponseHeaders != null) { + headers.put(POISON_KEY, parcelableForResponseHeaders); + } + super.sendHeaders(headers); + } + }, headers); + } + } + + static class PoisonParcelable implements Parcelable { + + public static final Creator CREATOR = new Parcelable.Creator() { + @Override + public PoisonParcelable createFromParcel(Parcel parcel) { + throw new RuntimeException("ouch"); + } + + @Override + public PoisonParcelable[] newArray(int n) { + return new PoisonParcelable[n]; + } + }; + + @Override + public int describeContents() { + return 0; + } + + @Override + public void writeToParcel(Parcel parcel, int flags) { + } + } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java index bdcd53a9ea6..70b89165174 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderTransport.java @@ -415,6 +415,22 @@ final void sendOutOfBandClose(int callId, Status status) { @Override public final boolean handleTransaction(int code, Parcel parcel) { + try { + return handleTransactionInternal(code, parcel); + } catch (RuntimeException e) { + logger.log(Level.SEVERE, + "Terminating transport for uncaught Exception in transaction " + code, e); + synchronized (this) { + // This unhandled exception may have put us in an inconsistent state. Force terminate the + // whole transport so our peer knows something is wrong and so that clients can retry with + // a fresh transport instance on both sides. + shutdownInternal(Status.INTERNAL.withCause(e), true); + return false; + } + } + } + + private boolean handleTransactionInternal(int code, Parcel parcel) { if (code < FIRST_CALL_ID) { synchronized (this) { switch (code) {