diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java index c6aa195544e..54d65936fab 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderClientTransport.java @@ -148,31 +148,33 @@ public synchronized void onUnbound(Status reason) { @Override public synchronized Runnable start(Listener clientTransportListener) { this.clientTransportListener = checkNotNull(clientTransportListener); - return () -> { - synchronized (BinderClientTransport.this) { - if (inState(TransportState.NOT_STARTED)) { - setState(TransportState.SETUP); - try { - if (preAuthorizeServer) { - preAuthorize(serviceBinding.resolve()); - } else { - serviceBinding.bind(); - } - } catch (StatusException e) { - shutdownInternal(e.getStatus(), true); - return; - } - if (readyTimeoutMillis >= 0) { - readyTimeoutFuture = - getScheduledExecutorService() - .schedule( - BinderClientTransport.this::onReadyTimeout, - readyTimeoutMillis, - MILLISECONDS); - } - } + return this::postStartRunnable; + } + + private synchronized void postStartRunnable() { + if (!inState(TransportState.NOT_STARTED)) { + return; + } + + setState(TransportState.SETUP); + + try { + if (preAuthorizeServer) { + preAuthorize(serviceBinding.resolve()); + } else { + serviceBinding.bind(); } - }; + } catch (StatusException e) { + shutdownInternal(e.getStatus(), true); + return; + } + + if (readyTimeoutMillis >= 0) { + readyTimeoutFuture = + getScheduledExecutorService() + .schedule( + BinderClientTransport.this::onReadyTimeout, readyTimeoutMillis, MILLISECONDS); + } } @GuardedBy("this") @@ -204,13 +206,16 @@ public void onFailure(Throwable t) { } private synchronized void handlePreAuthResult(Status authorization) { - if (inState(TransportState.SETUP)) { - if (!authorization.isOk()) { - shutdownInternal(authorization, true); - } else { - serviceBinding.bind(); - } + if (!inState(TransportState.SETUP)) { + return; + } + + if (!authorization.isOk()) { + shutdownInternal(authorization, true); + return; } + + serviceBinding.bind(); } private synchronized void onReadyTimeout() { @@ -252,17 +257,17 @@ public synchronized ClientStream newStream( Status failure = Status.INTERNAL.withDescription("Clashing call IDs"); shutdownInternal(failure, true); return newFailingClientStream(failure, attributes, headers, tracers); + } + + if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { + clientTransportListener.transportInUse(true); + } + Outbound.ClientOutbound outbound = + new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); + if (method.getType().clientSendsOneMessage()) { + return new SingleMessageClientStream(inbound, outbound, attributes); } else { - if (inbound.countsForInUse() && numInUseStreams.getAndIncrement() == 0) { - clientTransportListener.transportInUse(true); - } - Outbound.ClientOutbound outbound = - new Outbound.ClientOutbound(this, callId, method, headers, statsTraceContext); - if (method.getType().clientSendsOneMessage()) { - return new SingleMessageClientStream(inbound, outbound, attributes); - } else { - return new MultiMessageClientStream(inbound, outbound, attributes); - } + return new MultiMessageClientStream(inbound, outbound, attributes); } } @@ -314,39 +319,46 @@ void notifyTerminated() { @Override @GuardedBy("this") protected void handleSetupTransport(Parcel parcel) { - int remoteUid = Binder.getCallingUid(); - if (inState(TransportState.SETUP)) { - int version = parcel.readInt(); - IBinder binder = parcel.readStrongBinder(); - if (version != WIRE_FORMAT_VERSION) { - shutdownInternal(Status.UNAVAILABLE.withDescription("Wire format version mismatch"), true); - } else if (binder == null) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Malformed SETUP_TRANSPORT data"), true); - } else if (!setOutgoingBinder(OneWayBinderProxy.wrap(binder, offloadExecutor))) { - shutdownInternal( - Status.UNAVAILABLE.withDescription("Failed to observe outgoing binder"), true); - } else { - restrictIncomingBinderToCallsFrom(remoteUid); - attributes = setSecurityAttrs(attributes, remoteUid); - ListenableFuture authResultFuture = - register(checkServerAuthorizationAsync(remoteUid)); - Futures.addCallback( - authResultFuture, - new FutureCallback() { - @Override - public void onSuccess(Status result) { - handleAuthResult(result); - } - - @Override - public void onFailure(Throwable t) { - handleAuthResult(t); - } - }, - offloadExecutor); - } + if (!inState(TransportState.SETUP)) { + return; + } + + int version = parcel.readInt(); + if (version != WIRE_FORMAT_VERSION) { + shutdownInternal(Status.UNAVAILABLE.withDescription("Wire format version mismatch"), true); + return; + } + + IBinder binder = parcel.readStrongBinder(); + if (binder == null) { + shutdownInternal(Status.UNAVAILABLE.withDescription("Malformed SETUP_TRANSPORT data"), true); + return; + } + + if (!setOutgoingBinder(OneWayBinderProxy.wrap(binder, offloadExecutor))) { + shutdownInternal( + Status.UNAVAILABLE.withDescription("Failed to observe outgoing binder"), true); + return; } + + int remoteUid = Binder.getCallingUid(); + restrictIncomingBinderToCallsFrom(remoteUid); + attributes = setSecurityAttrs(attributes, remoteUid); + ListenableFuture authResultFuture = register(checkServerAuthorizationAsync(remoteUid)); + Futures.addCallback( + authResultFuture, + new FutureCallback() { + @Override + public void onSuccess(Status result) { + handleAuthResult(result); + } + + @Override + public void onFailure(Throwable t) { + handleAuthResult(t); + } + }, + offloadExecutor); } private ListenableFuture checkServerAuthorizationAsync(int remoteUid) { @@ -356,18 +368,21 @@ private ListenableFuture checkServerAuthorizationAsync(int remoteUid) { } private synchronized void handleAuthResult(Status authorization) { - if (inState(TransportState.SETUP)) { - if (!authorization.isOk()) { - shutdownInternal(authorization, true); - } else { - setState(TransportState.READY); - attributes = clientTransportListener.filterTransport(attributes); - clientTransportListener.transportReady(); - if (readyTimeoutFuture != null) { - readyTimeoutFuture.cancel(false); - readyTimeoutFuture = null; - } - } + if (!inState(TransportState.SETUP)) { + return; + } + + if (!authorization.isOk()) { + shutdownInternal(authorization, true); + return; + } + + setState(TransportState.READY); + attributes = clientTransportListener.filterTransport(attributes); + clientTransportListener.transportReady(); + if (readyTimeoutFuture != null) { + readyTimeoutFuture.cancel(false); + readyTimeoutFuture = null; } } diff --git a/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java b/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java index cbe64149e15..44909f9c0bc 100644 --- a/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java +++ b/binder/src/main/java/io/grpc/binder/internal/BinderServerTransport.java @@ -69,15 +69,22 @@ public BinderServerTransport( */ public synchronized void start(ServerTransportListener serverTransportListener) { this.listenerPromise.set(serverTransportListener); - if (!isShutdown()) { - sendSetupTransaction(); - // Check we're not shutdown again, since a failure inside sendSetupTransaction (or a callback - // it triggers), could have shut us down. - if (!isShutdown()) { - setState(TransportState.READY); - attributes = serverTransportListener.transportReady(attributes); - } + if (isShutdown()) { + // It's unlikely, but we could be shutdown externally between construction and start(). One + // possible cause is an extremely short handshake timeout. + return; } + + sendSetupTransaction(); + + // Check we're not shutdown again, since a failure inside sendSetupTransaction (or a callback + // it triggers), could have shut us down. + if (isShutdown()) { + return; + } + + setState(TransportState.READY); + attributes = serverTransportListener.transportReady(attributes); } StatsTraceContext createStatsTraceContext(String methodName, Metadata headers) { @@ -92,10 +99,10 @@ StatsTraceContext createStatsTraceContext(String methodName, Metadata headers) { synchronized Status startStream(ServerStream stream, String methodName, Metadata headers) { if (isShutdown()) { return Status.UNAVAILABLE.withDescription("transport is shutdown"); - } else { - listenerPromise.get().streamCreated(stream, methodName, headers); - return Status.OK; } + + listenerPromise.get().streamCreated(stream, methodName, headers); + return Status.OK; } @Override