Skip to content

Commit

Permalink
Improve the default redirect handler robustness and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vietj committed Mar 1, 2017
1 parent 64442c8 commit 9f5a599
Show file tree
Hide file tree
Showing 6 changed files with 144 additions and 46 deletions.
Expand Up @@ -130,7 +130,7 @@ public synchronized void onPushPromiseRead(ChannelHandlerContext ctx, int stream
MultiMap headersMap = new Http2HeadersAdaptor(headers);
Http2Stream promisedStream = handler.connection().stream(promisedStreamId);
int port = remoteAddress().port();
HttpClientRequestPushPromise pushReq = new HttpClientRequestPushPromise(this, promisedStream, http2Pool.client, method, rawMethod, uri, host, port, headersMap);
HttpClientRequestPushPromise pushReq = new HttpClientRequestPushPromise(this, promisedStream, http2Pool.client, isSsl(), method, rawMethod, uri, host, port, headersMap);
if (metrics.isEnabled()) {
pushReq.metric(metrics.responsePushed(queueMetric, metric(), localAddress(), remoteAddress(), pushReq));
}
Expand Down
75 changes: 37 additions & 38 deletions src/main/java/io/vertx/core/http/impl/HttpClientImpl.java
Expand Up @@ -64,43 +64,42 @@
public class HttpClientImpl implements HttpClient, MetricsProvider {

private final Function<HttpClientResponse, Future<HttpClientRequest>> DEFAULT_HANDLER = resp -> {
int statusCode = resp.statusCode();
String location = resp.getHeader(HttpHeaders.LOCATION);
if (location != null && (statusCode == 301 || statusCode == 302 || statusCode == 303 || statusCode == 307)) {
HttpMethod m = resp.request().method();
if (statusCode == 301 || statusCode == 302 || statusCode == 303) {
m = HttpMethod.GET;
}
URI uri;
try {
uri = HttpUtils.resolveURIReference(resp.request().absoluteURI(), location);
} catch (URISyntaxException e) {
return null;
}
boolean ssl;
int port = uri.getPort();
String protocol = uri.getScheme();
char chend = protocol.charAt(protocol.length() - 1);
if (chend == 'p') {
ssl = false;
if (port == -1) {
port = 80;
try {
int statusCode = resp.statusCode();
String location = resp.getHeader(HttpHeaders.LOCATION);
if (location != null && (statusCode == 301 || statusCode == 302 || statusCode == 303 || statusCode == 307)) {
HttpMethod m = resp.request().method();
if (statusCode == 301 || statusCode == 302 || statusCode == 303) {
m = HttpMethod.GET;
}
URI uri = HttpUtils.resolveURIReference(resp.request().absoluteURI(), location);
boolean ssl;
int port = uri.getPort();
String protocol = uri.getScheme();
char chend = protocol.charAt(protocol.length() - 1);
if (chend == 'p') {
ssl = false;
if (port == -1) {
port = 80;
}
} else if (chend == 's') {
ssl = true;
if (port == -1) {
port = 443;
}
} else {
return null;
}
} else if (chend == 's') {
ssl = true;
if (port == -1) {
port = 443;
String requestURI = uri.getPath();
if (uri.getQuery() != null) {
requestURI += "?" + uri.getQuery();
}
} else {
return null;
}
String requestURI = uri.getPath();
if (uri.getQuery() != null) {
requestURI += "?" + uri.getQuery();
return Future.succeededFuture(createRequest(m, uri.getHost(), port, ssl, requestURI, null));
}
return Future.succeededFuture(doRequest(m, uri.getHost(), port, ssl, requestURI, null));
return null;
} catch (Exception e) {
return Future.failedFuture(e);
}
return null;
};

private static final Logger log = LoggerFactory.getLogger(HttpClientImpl.class);
Expand Down Expand Up @@ -457,12 +456,12 @@ public HttpClientRequest requestAbs(HttpMethod method, String absoluteURI) {
port = 443;
}
}
return doRequest(method, url.getHost(), port, ssl, url.getFile(), null);
return createRequest(method, url.getHost(), port, ssl, url.getFile(), null);
}

@Override
public HttpClientRequest request(HttpMethod method, int port, String host, String requestURI) {
return doRequest(method, host, port, null, requestURI, null);
return createRequest(method, host, port, null, requestURI, null);
}

@Override
Expand All @@ -472,7 +471,7 @@ public HttpClientRequest request(HttpMethod method, RequestOptions options, Hand

@Override
public HttpClientRequest request(HttpMethod method, RequestOptions options) {
return doRequest(method, options.getHost(), options.getPort(), options.isSsl(), options.getURI(), null);
return createRequest(method, options.getHost(), options.getPort(), options.isSsl(), options.getURI(), null);
}

@Override
Expand Down Expand Up @@ -936,11 +935,11 @@ private URL parseUrl(String surl) {
}

private HttpClient requestNow(HttpMethod method, RequestOptions options, Handler<HttpClientResponse> responseHandler) {
doRequest(method, options.getHost(), options.getPort(), options.isSsl(), options.getURI(), null).handler(responseHandler).end();
createRequest(method, options.getHost(), options.getPort(), options.isSsl(), options.getURI(), null).handler(responseHandler).end();
return this;
}

private HttpClientRequest doRequest(HttpMethod method, String host, int port, Boolean ssl, String relativeURI, MultiMap headers) {
private HttpClientRequest createRequest(HttpMethod method, String host, int port, Boolean ssl, String relativeURI, MultiMap headers) {
Objects.requireNonNull(method, "no null method accepted");
Objects.requireNonNull(host, "no null host accepted");
Objects.requireNonNull(relativeURI, "no null relativeURI accepted");
Expand Down
Expand Up @@ -18,6 +18,7 @@

import io.vertx.core.Handler;
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.logging.Logger;
import io.vertx.core.logging.LoggerFactory;

Expand All @@ -37,21 +38,23 @@ abstract class HttpClientRequestBase implements HttpClientRequest {
protected final String host;
protected final int port;
protected final String query;
protected final boolean ssl;
private Handler<Throwable> exceptionHandler;
private long currentTimeoutTimerId = -1;
private long currentTimeoutMs;
private long lastDataReceived;
protected Throwable exceptionOccurred;
private Object metric;

HttpClientRequestBase(HttpClientImpl client, io.vertx.core.http.HttpMethod method, String host, int port, String uri) {
HttpClientRequestBase(HttpClientImpl client, boolean ssl, HttpMethod method, String host, int port, String uri) {
this.client = client;
this.uri = uri;
this.method = method;
this.host = host;
this.port = port;
this.path = uri.length() > 0 ? HttpUtils.parsePath(uri) : "";
this.query = HttpUtils.parseQuery(uri);
this.ssl = ssl;
}

Object metric() {
Expand All @@ -67,7 +70,6 @@ void metric(Object metric) {
protected abstract void checkComplete();

protected String hostHeader() {
boolean ssl = client.getOptions().isSsl();
if ((port == 80 && !ssl) || (port == 443 && ssl)) {
return host;
} else {
Expand All @@ -77,7 +79,7 @@ protected String hostHeader() {

@Override
public String absoluteURI() {
return (client.getOptions().isSsl() ? "https://" : "http://") + hostHeader() + uri;
return (ssl ? "https://" : "http://") + hostHeader() + uri;
}

public String query() {
Expand Down
Expand Up @@ -52,7 +52,6 @@ public class HttpClientRequestImpl extends HttpClientRequestBase implements Http

private final VertxInternal vertx;
private final int port;
private final boolean ssl;
private Handler<HttpClientResponse> respHandler;
private Handler<Void> endHandler;
private boolean chunked;
Expand Down Expand Up @@ -80,8 +79,7 @@ public class HttpClientRequestImpl extends HttpClientRequestBase implements Http

HttpClientRequestImpl(HttpClientImpl client, boolean ssl, HttpMethod method, String host, int port,
String relativeURI, VertxInternal vertx) {
super(client, method, host, port, relativeURI);
this.ssl = ssl;
super(client, ssl, method, host, port, relativeURI);
this.chunked = false;
this.vertx = vertx;
this.port = port;
Expand Down
Expand Up @@ -43,13 +43,14 @@ public HttpClientRequestPushPromise(
Http2ClientConnection conn,
Http2Stream stream,
HttpClientImpl client,
boolean ssl,
HttpMethod method,
String rawMethod,
String uri,
String host,
int port,
MultiMap headers) throws Http2Exception {
super(client, method, host, port, uri);
super(client, ssl, method, host, port, uri);
this.conn = conn;
this.stream = new Http2ClientConnection.Http2ClientStream(conn, this, stream, false);
this.rawMethod = rawMethod;
Expand Down
98 changes: 98 additions & 0 deletions src/test/java/io/vertx/test/core/HttpTest.java
Expand Up @@ -19,6 +19,7 @@
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.internal.logging.InternalLoggerFactory;
import io.vertx.codegen.annotations.Nullable;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.AsyncResult;
import io.vertx.core.Context;
Expand All @@ -34,6 +35,7 @@
import io.vertx.core.http.HttpClientRequest;
import io.vertx.core.http.HttpClientResponse;
import io.vertx.core.http.HttpConnection;
import io.vertx.core.http.HttpFrame;
import io.vertx.core.http.HttpHeaders;
import io.vertx.core.http.HttpMethod;
import io.vertx.core.http.HttpServerOptions;
Expand All @@ -42,6 +44,7 @@
import io.vertx.core.http.impl.HeadersAdaptor;
import io.vertx.core.impl.EventLoopContext;
import io.vertx.core.impl.WorkerContext;
import io.vertx.core.net.NetSocket;
import io.vertx.test.netty.TestLoggerFactory;
import org.junit.Assume;
import org.junit.Rule;
Expand Down Expand Up @@ -70,6 +73,7 @@
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.IntStream;

import static io.vertx.test.core.TestUtils.assertIllegalArgumentException;
Expand Down Expand Up @@ -3503,6 +3507,100 @@ public void testFollowRedirectWithCustomHandler() throws Exception {
await();
}

@Test
public void testDefaultRedirectHandler() throws Exception {
testFoo("http://example.com", "http://example.com");
testFoo("http://example.com/somepath", "http://example.com/somepath");
testFoo("http://example.com:8000", "http://example.com:8000");
testFoo("http://example.com:8000/somepath", "http://example.com:8000/somepath");
testFoo("https://example.com", "https://example.com");
testFoo("https://example.com/somepath", "https://example.com/somepath");
testFoo("https://example.com:8000", "https://example.com:8000");
testFoo("https://example.com:8000/somepath", "https://example.com:8000/somepath");
testFoo("whatever://example.com", null);
testFoo("http://", null);
testFoo("http://:8080/somepath", null);
}

private void testFoo(String location, String expected) throws Exception {
int status = 301;
Map<String, String> headers = Collections.singletonMap("Location", location);
HttpMethod method = HttpMethod.GET;
String baseURI = "https://localhost:8080";
class MockReq implements HttpClientRequest {
public HttpClientRequest exceptionHandler(Handler<Throwable> handler) { throw new UnsupportedOperationException(); }
public HttpClientRequest write(Buffer data) { throw new UnsupportedOperationException(); }
public HttpClientRequest setWriteQueueMaxSize(int maxSize) { throw new UnsupportedOperationException(); }
public HttpClientRequest drainHandler(Handler<Void> handler) { throw new UnsupportedOperationException(); }
public HttpClientRequest handler(Handler<HttpClientResponse> handler) { throw new UnsupportedOperationException(); }
public HttpClientRequest pause() { throw new UnsupportedOperationException(); }
public HttpClientRequest resume() { throw new UnsupportedOperationException(); }
public HttpClientRequest endHandler(Handler<Void> endHandler) { throw new UnsupportedOperationException(); }
public HttpClientRequest setFollowRedirects(boolean followRedirects) { throw new UnsupportedOperationException(); }
public HttpClientRequest setChunked(boolean chunked) { throw new UnsupportedOperationException(); }
public boolean isChunked() { return false; }
public HttpMethod method() { return method; }
public String getRawMethod() { throw new UnsupportedOperationException(); }
public HttpClientRequest setRawMethod(String method) { throw new UnsupportedOperationException(); }
public String absoluteURI() { return baseURI; }
public String uri() { throw new UnsupportedOperationException(); }
public String path() { throw new UnsupportedOperationException(); }
public String query() { throw new UnsupportedOperationException(); }
public HttpClientRequest setHost(String host) { throw new UnsupportedOperationException(); }
public String getHost() { throw new UnsupportedOperationException(); }
public MultiMap headers() { throw new UnsupportedOperationException(); }
public HttpClientRequest putHeader(String name, String value) { throw new UnsupportedOperationException(); }
public HttpClientRequest putHeader(CharSequence name, CharSequence value) { throw new UnsupportedOperationException(); }
public HttpClientRequest putHeader(String name, Iterable<String> values) { throw new UnsupportedOperationException(); }
public HttpClientRequest putHeader(CharSequence name, Iterable<CharSequence> values) { throw new UnsupportedOperationException(); }
public HttpClientRequest write(String chunk) { throw new UnsupportedOperationException(); }
public HttpClientRequest write(String chunk, String enc) { throw new UnsupportedOperationException(); }
public HttpClientRequest continueHandler(@Nullable Handler<Void> handler) { throw new UnsupportedOperationException(); }
public HttpClientRequest sendHead() { throw new UnsupportedOperationException(); }
public HttpClientRequest sendHead(Handler<HttpVersion> completionHandler) { throw new UnsupportedOperationException(); }
public void end(String chunk) { throw new UnsupportedOperationException(); }
public void end(String chunk, String enc) { throw new UnsupportedOperationException(); }
public void end(Buffer chunk) { throw new UnsupportedOperationException(); }
public void end() { throw new UnsupportedOperationException(); }
public HttpClientRequest setTimeout(long timeoutMs) { throw new UnsupportedOperationException(); }
public HttpClientRequest pushHandler(Handler<HttpClientRequest> handler) { throw new UnsupportedOperationException(); }
public boolean reset(long code) { return false; }
public HttpConnection connection() { throw new UnsupportedOperationException(); }
public HttpClientRequest connectionHandler(@Nullable Handler<HttpConnection> handler) { throw new UnsupportedOperationException(); }
public HttpClientRequest writeCustomFrame(int type, int flags, Buffer payload) { throw new UnsupportedOperationException(); }
public boolean writeQueueFull() { throw new UnsupportedOperationException(); }
}
HttpClientRequest req = new MockReq();
class MockResp implements HttpClientResponse {
public HttpClientResponse resume() { throw new UnsupportedOperationException(); }
public HttpClientResponse exceptionHandler(Handler<Throwable> handler) { throw new UnsupportedOperationException(); }
public HttpClientResponse handler(Handler<Buffer> handler) { throw new UnsupportedOperationException(); }
public HttpClientResponse pause() { throw new UnsupportedOperationException(); }
public HttpClientResponse endHandler(Handler<Void> endHandler) { throw new UnsupportedOperationException(); }
public HttpVersion version() { throw new UnsupportedOperationException(); }
public int statusCode() { return status; }
public String statusMessage() { throw new UnsupportedOperationException(); }
public MultiMap headers() { throw new UnsupportedOperationException(); }
public String getHeader(String headerName) { return headers.get(headerName); }
public String getHeader(CharSequence headerName) { return getHeader(headerName.toString()); }
public String getTrailer(String trailerName) { throw new UnsupportedOperationException(); }
public MultiMap trailers() { throw new UnsupportedOperationException(); }
public List<String> cookies() { throw new UnsupportedOperationException(); }
public HttpClientResponse bodyHandler(Handler<Buffer> bodyHandler) { throw new UnsupportedOperationException(); }
public HttpClientResponse customFrameHandler(Handler<HttpFrame> handler) { throw new UnsupportedOperationException(); }
public NetSocket netSocket() { throw new UnsupportedOperationException(); }
public HttpClientRequest request() { return req; }
}
MockResp resp = new MockResp();
Function<HttpClientResponse, Future<HttpClientRequest>> handler = client.redirectHandler();
Future<HttpClientRequest> redirection = handler.apply(resp);
if (expected != null) {
assertEquals(location, redirection.result().absoluteURI());
} else {
assertTrue(redirection == null || redirection.failed());
}
}

@Test
public void testServerResponseCloseHandlerNotHoldingLock() throws Exception {
server.requestHandler(req -> {
Expand Down

0 comments on commit 9f5a599

Please sign in to comment.