Skip to content

Commit

Permalink
Allow null Endpoint / Add init()
Browse files Browse the repository at this point in the history
  • Loading branch information
trustin committed Jul 23, 2019
1 parent af49ff6 commit 97d0d24
Show file tree
Hide file tree
Showing 13 changed files with 205 additions and 167 deletions.
Expand Up @@ -160,7 +160,11 @@ static ClientRequestContext of(RpcRequest request, URI uri) {

/**
* Returns the remote {@link Endpoint} of the current {@link Request}.
*
* @return the remote {@link Endpoint}. {@code null} if the {@link Request} has failed
* before its remote {@link Endpoint} is determined.
*/
@Nullable
Endpoint endpoint();

/**
Expand Down
Expand Up @@ -106,8 +106,9 @@ public ClientRequestContext build() {
}

final DefaultClientRequestContext ctx = new DefaultClientRequestContext(
eventLoop(), meterRegistry(), sessionProtocol(), endpoint,
eventLoop(), meterRegistry(), sessionProtocol(),
method(), path(), query(), fragment, options, request());
ctx.init(endpoint);

if (isRequestStartTimeSet()) {
ctx.logBuilder().startRequest(fakeChannel(), sessionProtocol(), sslSession(),
Expand Down
Expand Up @@ -16,6 +16,7 @@

package com.linecorp.armeria.client;

import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

import java.time.Duration;
Expand All @@ -40,6 +41,7 @@
import com.linecorp.armeria.common.logging.RequestLog;
import com.linecorp.armeria.common.logging.RequestLogAvailability;
import com.linecorp.armeria.common.logging.RequestLogBuilder;
import com.linecorp.armeria.common.stream.StreamMessage;

import io.micrometer.core.instrument.MeterRegistry;
import io.netty.buffer.ByteBufAllocator;
Expand All @@ -58,7 +60,8 @@ public class DefaultClientRequestContext extends NonWrappingRequestContext imple
private final EventLoop eventLoop;
private final ClientOptions options;
@Nullable
private final EndpointSelector endpointSelector;
private EndpointSelector endpointSelector;
@Nullable
private Endpoint endpoint;
@Nullable
private final String fragment;
Expand All @@ -82,27 +85,14 @@ public class DefaultClientRequestContext extends NonWrappingRequestContext imple
* @param request the request associated with this context
*/
public DefaultClientRequestContext(
EventLoop eventLoop, MeterRegistry meterRegistry,
SessionProtocol sessionProtocol, Endpoint endpoint,
EventLoop eventLoop, MeterRegistry meterRegistry, SessionProtocol sessionProtocol,
HttpMethod method, String path, @Nullable String query, @Nullable String fragment,
ClientOptions options, Request request) {

super(meterRegistry, sessionProtocol, method, path, query, request);

this.eventLoop = requireNonNull(eventLoop, "eventLoop");
this.options = requireNonNull(options, "options");
requireNonNull(endpoint, "endpoint");
if (endpoint.isGroup()) {
final String groupName = endpoint.groupName();
endpointSelector = EndpointGroupRegistry.getNodeSelector(groupName);
if (endpointSelector == null) {
throw new EndpointGroupException("non-existent " + EndpointGroup.class.getSimpleName() +
": " + groupName);
}
} else {
endpointSelector = null;
}
this.endpoint = endpoint;
this.fragment = fragment;

log = new DefaultRequestLog(this, options.requestContentPreviewerFactory(),
Expand All @@ -116,28 +106,68 @@ public DefaultClientRequestContext(
if (!headers.isEmpty()) {
additionalRequestHeaders = headers;
}

runThreadLocalContextCustomizer();
if (endpointSelector != null) {
// NB: EndpointSelector.select() must be called after the thread-local context customizer is run
// because the customizer might set an attribute which may be accessed by the EndpointSelector.
this.endpoint = endpointSelector.select(this);
}
}

private HttpHeaders createAdditionalHeadersIfAbsent() {
final HttpHeaders additionalRequestHeaders = this.additionalRequestHeaders;
if (additionalRequestHeaders == null) {
return this.additionalRequestHeaders = HttpHeaders.of();
} else {
return additionalRequestHeaders;
/**
* Initializes this context with the specified {@link Endpoint}.
* This method must be invoked to finish the construction of this context.
*
* @return {@code true} if the initialization has succeeded.
* {@code false} if the initialization has failed and this context's {@link RequestLog} has been
* completed with the cause of the failure.
*/
public boolean init(Endpoint endpoint) {
assert this.endpoint == null : this.endpoint;
try {
if (endpoint.isGroup()) {
final String groupName = endpoint.groupName();
final EndpointSelector endpointSelector =
EndpointGroupRegistry.getNodeSelector(groupName);
if (endpointSelector == null) {
throw new EndpointGroupException(
"non-existent " + EndpointGroup.class.getSimpleName() + ": " + groupName);
}

this.endpointSelector = endpointSelector;
// Note: thread-local customizer must be run before EndpointSelector.select()
// so that the customizer can inject the attributes which may be required
// by the EndpointSelector.
runThreadLocalContextCustomizer();
this.endpoint = endpointSelector.select(this);
} else {
endpointSelector = null;
this.endpoint = endpoint;
runThreadLocalContextCustomizer();
}

return true;
} catch (Exception e) {
failEarly(e);
}

return false;
}

private void runThreadLocalContextCustomizer() {
final Consumer<ClientRequestContext> customizer = THREAD_LOCAL_CONTEXT_CUSTOMIZER.get();
if (customizer != null) {
customizer.accept(this);
try {
customizer.accept(this);
} catch (Exception e) {
failEarly(e);
}
}
}

private void failEarly(Exception cause) {
final RequestLogBuilder logBuilder = logBuilder();
final UnprocessedRequestException wrapped = new UnprocessedRequestException(cause);
logBuilder.endRequest(wrapped);
logBuilder.endResponse(wrapped);

final Request req = request();
if (req instanceof StreamMessage) {
((StreamMessage<?>) req).abort();
}
}

Expand All @@ -147,7 +177,7 @@ private DefaultClientRequestContext(DefaultClientRequestContext ctx, Request req
eventLoop = ctx.eventLoop();
options = ctx.options();
endpointSelector = ctx.endpointSelector();
this.endpoint = endpoint;
this.endpoint = requireNonNull(endpoint, "endpoint");
fragment = ctx.fragment();

log = new DefaultRequestLog(this, options.requestContentPreviewerFactory(),
Expand Down Expand Up @@ -176,12 +206,13 @@ private <T> void addAttr(Attribute<?> attribute) {

@Override
public ClientRequestContext newDerivedContext() {
return newDerivedContext(request(), endpoint());
return newDerivedContext(request());
}

@Override
public ClientRequestContext newDerivedContext(Request request) {
return newDerivedContext(request, endpoint());
checkState(endpoint != null, "endpoint not available");
return newDerivedContext(request, endpoint);
}

@Override
Expand Down Expand Up @@ -318,6 +349,15 @@ public void addAdditionalRequestHeaders(Iterable<? extends Entry<? extends CharS
additionalRequestHeaders = createAdditionalHeadersIfAbsent().toBuilder().addObject(headers).build();
}

private HttpHeaders createAdditionalHeadersIfAbsent() {
final HttpHeaders additionalRequestHeaders = this.additionalRequestHeaders;
if (additionalRequestHeaders == null) {
return this.additionalRequestHeaders = HttpHeaders.of();
} else {
return additionalRequestHeaders;
}
}

@Override
public boolean removeAdditionalRequestHeader(CharSequence name) {
requireNonNull(name, "name");
Expand Down Expand Up @@ -361,7 +401,7 @@ public String toString() {
buf.append('[')
.append(sessionProtocol().uriText())
.append("://")
.append(endpoint.authority())
.append(endpoint != null ? endpoint.authority() : "<unknown>")
.append(path())
.append('#')
.append(method())
Expand Down
Expand Up @@ -24,7 +24,6 @@
import com.linecorp.armeria.common.HttpRequest;
import com.linecorp.armeria.common.HttpResponse;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.logging.RequestLogAvailability;
import com.linecorp.armeria.internal.PathAndQuery;

import io.micrometer.core.instrument.MeterRegistry;
Expand Down Expand Up @@ -59,14 +58,7 @@ private HttpResponse execute(@Nullable EventLoop eventLoop, HttpRequest req) {
}

return execute(eventLoop, newReq.method(), pathAndQuery.path(), pathAndQuery.query(), null, newReq,
(ctx, cause) -> {
if (ctx != null && !ctx.log().isAvailable(RequestLogAvailability.REQUEST_START)) {
// An exception is raised even before sending a request, so abort the request to
// release the elements.
newReq.abort();
}
return HttpResponse.ofFailure(cause);
});
(ctx, cause) -> HttpResponse.ofFailure(cause));
}

@Override
Expand Down
Expand Up @@ -41,6 +41,11 @@

final class HttpClientDelegate implements Client<HttpRequest, HttpResponse> {

private static final Throwable CONTEXT_INITIALIZATION_FAILED = new Exception(
ClientRequestContext.class.getSimpleName() + " initialization failed", null, false, false) {
private static final long serialVersionUID = 837901495421033459L;
};

private final HttpClientFactory factory;
private final AddressResolverGroup<InetSocketAddress> addressResolverGroup;

Expand All @@ -52,46 +57,52 @@ final class HttpClientDelegate implements Client<HttpRequest, HttpResponse> {

@Override
public HttpResponse execute(ClientRequestContext ctx, HttpRequest req) throws Exception {
final Endpoint endpoint = ctx.endpoint();
if (endpoint == null) {
return HttpResponse.ofFailure(CONTEXT_INITIALIZATION_FAILED);
}

if (!isValidPath(req)) {
final IllegalArgumentException cause = new IllegalArgumentException("invalid path: " + req.path());
handleEarlyRequestException(ctx, req, cause);
return HttpResponse.ofFailure(cause);
}

final Endpoint endpoint = ctx.endpoint().withDefaultPort(ctx.sessionProtocol().defaultPort());
final Endpoint endpointWithPort = endpoint.withDefaultPort(ctx.sessionProtocol().defaultPort());
final EventLoop eventLoop = ctx.eventLoop();
final DecodedHttpResponse res = new DecodedHttpResponse(eventLoop);

final ClientConnectionTimingsBuilder timingsBuilder = new ClientConnectionTimingsBuilder();

if (endpoint.hasIpAddr()) {
if (endpointWithPort.hasIpAddr()) {
// IP address has been resolved already.
acquireConnectionAndExecute(ctx, endpoint, endpoint.ipAddr(), req, res, timingsBuilder);
acquireConnectionAndExecute(ctx, endpointWithPort, endpointWithPort.ipAddr(),
req, res, timingsBuilder);
} else {
// IP address has not been resolved yet.
final Future<InetSocketAddress> resolveFuture =
addressResolverGroup.getResolver(eventLoop)
.resolve(InetSocketAddress.createUnresolved(endpoint.host(),
endpoint.port()));
.resolve(InetSocketAddress.createUnresolved(endpointWithPort.host(),
endpointWithPort.port()));
if (resolveFuture.isDone()) {
finishResolve(ctx, endpoint, resolveFuture, req, res, timingsBuilder);
finishResolve(ctx, endpointWithPort, resolveFuture, req, res, timingsBuilder);
} else {
resolveFuture.addListener(
(FutureListener<InetSocketAddress>) future ->
finishResolve(ctx, endpoint, future, req, res, timingsBuilder));
finishResolve(ctx, endpointWithPort, future, req, res, timingsBuilder));
}
}

return res;
}

private void finishResolve(ClientRequestContext ctx, Endpoint endpoint,
private void finishResolve(ClientRequestContext ctx, Endpoint endpointWithPort,
Future<InetSocketAddress> resolveFuture, HttpRequest req,
DecodedHttpResponse res, ClientConnectionTimingsBuilder timingsBuilder) {
timingsBuilder.dnsResolutionEnd();
if (resolveFuture.isSuccess()) {
final String ipAddr = resolveFuture.getNow().getAddress().getHostAddress();
acquireConnectionAndExecute(ctx, endpoint, ipAddr, req, res, timingsBuilder);
acquireConnectionAndExecute(ctx, endpointWithPort, ipAddr, req, res, timingsBuilder);
} else {
timingsBuilder.build().setTo(ctx);
final Throwable cause = resolveFuture.cause();
Expand All @@ -100,18 +111,18 @@ private void finishResolve(ClientRequestContext ctx, Endpoint endpoint,
}
}

private void acquireConnectionAndExecute(ClientRequestContext ctx, Endpoint endpoint, String ipAddr,
HttpRequest req, DecodedHttpResponse res,
private void acquireConnectionAndExecute(ClientRequestContext ctx, Endpoint endpointWithPort,
String ipAddr, HttpRequest req, DecodedHttpResponse res,
ClientConnectionTimingsBuilder timingsBuilder) {
final EventLoop eventLoop = ctx.eventLoop();
if (!eventLoop.inEventLoop()) {
eventLoop.execute(() -> acquireConnectionAndExecute(ctx, endpoint, ipAddr,
eventLoop.execute(() -> acquireConnectionAndExecute(ctx, endpointWithPort, ipAddr,
req, res, timingsBuilder));
return;
}

final String host = extractHost(ctx, req, endpoint);
final int port = endpoint.port();
final String host = extractHost(ctx, req, endpointWithPort);
final int port = endpointWithPort.port();
final SessionProtocol protocol = ctx.sessionProtocol();
final HttpChannelPool pool = factory.pool(ctx.eventLoop());

Expand Down
29 changes: 14 additions & 15 deletions core/src/main/java/com/linecorp/armeria/client/UserClient.java
Expand Up @@ -16,7 +16,7 @@

package com.linecorp.armeria.client;

import static com.linecorp.armeria.internal.ClientUtil.createContextAndExecute;
import static com.linecorp.armeria.internal.ClientUtil.initContextAndExecuteWithFallback;

import java.net.URI;
import java.util.function.BiFunction;
Expand Down Expand Up @@ -143,19 +143,18 @@ protected final O execute(HttpMethod method, String path, @Nullable String query
protected final O execute(@Nullable EventLoop eventLoop,
HttpMethod method, String path, @Nullable String query, @Nullable String fragment,
I req, BiFunction<ClientRequestContext, Throwable, O> fallback) {
return createContextAndExecute(delegate(), () -> {
final ClientRequestContext ctx;
if (eventLoop == null) {
final ReleasableHolder<EventLoop> releasableEventLoop = factory().acquireEventLoop(endpoint);
ctx = new DefaultClientRequestContext(
releasableEventLoop.get(), meterRegistry, sessionProtocol, endpoint,
method, path, query, fragment, options(), req);
ctx.log().addListener(log -> releasableEventLoop.release(), RequestLogAvailability.COMPLETE);
} else {
ctx = new DefaultClientRequestContext(eventLoop, meterRegistry, sessionProtocol, endpoint,
method, path, query, fragment, options(), req);
}
return ctx;
}, fallback).response();
final DefaultClientRequestContext ctx;
if (eventLoop == null) {
final ReleasableHolder<EventLoop> releasableEventLoop = factory().acquireEventLoop(endpoint);
ctx = new DefaultClientRequestContext(
releasableEventLoop.get(), meterRegistry, sessionProtocol,
method, path, query, fragment, options(), req);
ctx.log().addListener(log -> releasableEventLoop.release(), RequestLogAvailability.COMPLETE);
} else {
ctx = new DefaultClientRequestContext(eventLoop, meterRegistry, sessionProtocol,
method, path, query, fragment, options(), req);
}

return initContextAndExecuteWithFallback(delegate(), ctx, endpoint, fallback);
}
}
Expand Up @@ -85,7 +85,9 @@ public interface KeySelector<K> {
KeySelector<String> HOST =
(ctx, req) -> {
final Endpoint endpoint = ctx.endpoint();
if (endpoint.isGroup()) {
if (endpoint == null) {
return "UNKNOWN";
} else if (endpoint.isGroup()) {
return endpoint.authority();
} else {
final String ipAddr = endpoint.ipAddr();
Expand Down

0 comments on commit 97d0d24

Please sign in to comment.