Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Relax the validation of Location header when redirecting #5477

Merged
merged 6 commits into from
Mar 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,13 @@ ClientRequestContext newDerivedContext(RequestId id, @Nullable HttpRequest req,
@UnstableApi
String authority();

/**
* Returns the host part of {@link #authority()}, without a port number.
*/
@Nullable
@UnstableApi
String host();

/**
* Returns the {@link URI} constructed based on {@link ClientRequestContext#sessionProtocol()},
* {@link ClientRequestContext#authority()}, {@link ClientRequestContext#path()} and
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@
return unwrap().authority();
}

@Override
public String host() {
return unwrap().host();

Check warning on line 77 in core/src/main/java/com/linecorp/armeria/client/ClientRequestContextWrapper.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/client/ClientRequestContextWrapper.java#L77

Added line #L77 was not covered by tests
}

@Override
public URI uri() {
return unwrap().uri();
Expand Down
233 changes: 169 additions & 64 deletions core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,16 @@
import static com.linecorp.armeria.internal.client.ClientUtil.executeWithFallback;
import static com.linecorp.armeria.internal.client.RedirectingClientUtil.allowAllDomains;
import static com.linecorp.armeria.internal.client.RedirectingClientUtil.allowSameDomain;
import static com.linecorp.armeria.internal.common.ArmeriaHttpUtil.findAuthority;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.function.BiPredicate;
import java.util.function.Function;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.LinkedListMultimap;
import com.google.common.collect.Multimap;
Expand All @@ -48,8 +49,9 @@
import com.linecorp.armeria.common.HttpStatus;
import com.linecorp.armeria.common.RequestHeaders;
import com.linecorp.armeria.common.RequestHeadersBuilder;
import com.linecorp.armeria.common.RequestTarget;
import com.linecorp.armeria.common.RequestTargetForm;
import com.linecorp.armeria.common.ResponseHeaders;
import com.linecorp.armeria.common.Scheme;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.logging.RequestLogBuilder;
Expand All @@ -70,6 +72,8 @@
private static final Set<SessionProtocol> httpAndHttps =
Sets.immutableEnumSet(SessionProtocol.HTTP, SessionProtocol.HTTPS);

private static final Splitter pathSplitter = Splitter.on('/');

static Function<? super HttpClient, RedirectingClient> newDecorator(
ClientBuilderParams params, RedirectConfig redirectConfig) {
final boolean undefinedUri = Clients.isUndefinedUri(params.uri());
Expand Down Expand Up @@ -212,42 +216,54 @@
}

final RequestHeaders requestHeaders = log.requestHeaders();
final URI redirectUri;
try {
redirectUri = URI.create(requestHeaders.path()).resolve(location);
if (redirectUri.isAbsolute()) {
final SessionProtocol redirectProtocol = Scheme.parse(redirectUri.getScheme())
.sessionProtocol();
if (!allowedProtocols.contains(redirectProtocol)) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedProtocolRedirectException.of(
redirectProtocol, allowedProtocols));
return;
}

if (!domainFilter.test(ctx, redirectUri.getHost())) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedDomainRedirectException.of(redirectUri.getHost()));
return;
}
}
} catch (Throwable t) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, t);
// Resolve the actual redirect location.
final RequestTarget nextReqTarget = resolveLocation(ctx, location);
if (nextReqTarget == null) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,

Check warning on line 223 in core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java#L223

Added line #L223 was not covered by tests
new IllegalArgumentException("Invalid redirect location: " + location));
return;
}

final HttpRequestDuplicator newReqDuplicator =
newReqDuplicator(reqDuplicator, responseHeaders, requestHeaders, redirectUri);
final String nextScheme = nextReqTarget.scheme();
final String nextAuthority = nextReqTarget.authority();
final String nextHost = nextReqTarget.host();
assert nextReqTarget.form() == RequestTargetForm.ABSOLUTE &&
nextScheme != null && nextAuthority != null && nextHost != null
: "resolveLocation() must return an absolute request target: " + nextReqTarget;

final String redirectFullUri;
try {
redirectFullUri = buildFullUri(ctx, redirectUri, newReqDuplicator.headers());
// Reject if:
// 1) the protocol is not same with the original one; and
// 2) the protocol is not in the allow-list.
final SessionProtocol nextProtocol = SessionProtocol.of(nextScheme);
if (ctx.sessionProtocol() != nextProtocol &&
!allowedProtocols.contains(nextProtocol)) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedProtocolRedirectException.of(
nextProtocol, allowedProtocols));
return;
}

// Reject if:
// 1) the host is not same with the original one; and
// 2) the host does not pass the domain filter.
if (!nextHost.equals(ctx.host()) &&
!domainFilter.test(ctx, nextHost)) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
UnexpectedDomainRedirectException.of(nextHost));
return;
}
} catch (Throwable t) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response, t);
return;
}

if (isCyclicRedirects(redirectCtx, redirectFullUri, newReqDuplicator.headers())) {
final HttpRequestDuplicator newReqDuplicator =
newReqDuplicator(reqDuplicator, responseHeaders, requestHeaders,
nextReqTarget.toString(), nextAuthority);

if (isCyclicRedirects(redirectCtx, nextReqTarget.toString(), newReqDuplicator.headers().method())) {
handleException(ctx, derivedCtx, reqDuplicator, responseFuture, response,
CyclicRedirectsException.of(redirectCtx.originalUri(),
redirectCtx.redirectUris().values()));
Expand All @@ -274,17 +290,127 @@
});
}

@Nullable
@VisibleForTesting
static RequestTarget resolveLocation(ClientRequestContext ctx, String location) {
final long length = location.length();
assert length > 0;

final String resolvedUri;
if (location.charAt(0) == '/') {
if (length > 1 && location.charAt(1) == '/') {
// No scheme, e.g. //foo.com/bar
resolvedUri = ctx.sessionProtocol().uriText() + ':' + location;
} else {
// No scheme, no authority, e.g. /bar
resolvedUri = ctx.sessionProtocol().uriText() + "://" + ctx.authority() + location;
}
} else {
final int authorityIdx = findAuthority(location);
if (authorityIdx < 0) {
// A relative path, e.g. ./bar
resolvedUri = resolveRelativeLocation(ctx, location);
if (resolvedUri == null) {
return null;
}
} else {
// A full absolute URI, e.g. http://foo.com/bar
// Note that we should normalize an explicit scheme such as `h1c` into `http` or `https`,
// because otherwise a potentially malicious peer can force us to use inefficient protocols
// like HTTP/1.
final SessionProtocol proto = SessionProtocol.find(location.substring(0, authorityIdx - 3));
if (proto != null) {
switch (proto) {
case HTTP:
case HTTPS:
resolvedUri = location;
break;
default:
if (proto.isHttp()) {
resolvedUri = "http://" + location.substring(authorityIdx);
} else if (proto.isHttps()) {
resolvedUri = "https://" + location.substring(authorityIdx);
} else {
return null;

Check warning on line 334 in core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java

View check run for this annotation

Codecov / codecov/patch

core/src/main/java/com/linecorp/armeria/client/RedirectingClient.java#L334

Added line #L334 was not covered by tests
}
}
} else {
// Unknown scheme.
return null;
}
}
}

return RequestTarget.forClient(resolvedUri);
}

@Nullable
private static String resolveRelativeLocation(ClientRequestContext ctx, String location) {
final String originalPath = ctx.path();

// Find the base path, e.g.
// - /foo -> /
// - /foo/ -> /foo/
// - /foo/bar -> /foo/
final int lastSlashIdx = originalPath.lastIndexOf('/');
assert lastSlashIdx >= 0 : "originalPath doesn't contain a slash: " + originalPath;

// Generate the full path.
final String fullPath = originalPath.substring(0, lastSlashIdx + 1) + location;
final Iterator<String> it = pathSplitter.split(fullPath).iterator();
// Splitter will always emit an empty string as the first component, so we skip it.
assert it.hasNext() && it.next().isEmpty() : fullPath;

// Resolve `.` and `..` from the full path.
try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) {
final StringBuilder buf = tmp.stringBuilder();
buf.append(ctx.sessionProtocol().uriText()).append("://").append(ctx.authority());
final int authorityEndIdx = buf.length();
while (it.hasNext()) {
final String component = it.next();
switch (component) {
case ".":
if (!it.hasNext()) {
// Append '/' only when the '.' is the last component, e.g. /foo/. -> /foo/
buf.append('/');
}
break;
case "..":
final int idx = buf.lastIndexOf("/");
if (idx < authorityEndIdx) {
// Too few parents
return null;
}
if (it.hasNext()) {
// Don't keep the '/' because the next component will add it anyway,
// e.g. /foo/../bar -> /bar
buf.delete(idx, buf.length());
} else {
// Keep the last '/' if the '..' is the last component,
// e.g. /foo/bar/.. -> /foo/
buf.delete(idx + 1, buf.length());
}
break;
default:
buf.append('/').append(component);
break;
}
}

return buf.toString();
}
}

private static HttpRequestDuplicator newReqDuplicator(HttpRequestDuplicator reqDuplicator,
ResponseHeaders responseHeaders,
RequestHeaders requestHeaders, URI newUri) {
RequestHeaders requestHeaders,
String nextUri,
String nextAuthority) {

final RequestHeadersBuilder builder = requestHeaders.toBuilder();
builder.path(newUri.toString());
final String newAuthority = newUri.getAuthority();
if (newAuthority != null) {
// Update the old authority with the new one because the request is redirected to a different
// domain.
builder.authority(newAuthority);
}
builder.path(nextUri);
builder.authority(nextAuthority);

final HttpMethod method = requestHeaders.method();
if (responseHeaders.status() == HttpStatus.SEE_OTHER &&
!(method == HttpMethod.GET || method == HttpMethod.HEAD)) {
Expand Down Expand Up @@ -343,36 +469,15 @@
}
}

private static String buildFullUri(ClientRequestContext ctx, URI redirectUri, RequestHeaders newHeaders)
throws URISyntaxException {
// Build the full uri so we don't consider the situation, which session protocol or port is changed,
// as a cyclic redirects.
if (redirectUri.isAbsolute()) {
if (redirectUri.getPort() > 0) {
return redirectUri.toString();
}
final int port;
if (redirectUri.getScheme().startsWith("https")) {
port = SessionProtocol.HTTPS.defaultPort();
} else {
port = SessionProtocol.HTTP.defaultPort();
}
return new URI(redirectUri.getScheme(), redirectUri.getRawUserInfo(), redirectUri.getHost(), port,
redirectUri.getRawPath(), redirectUri.getRawQuery(), redirectUri.getRawFragment())
.toString();
}
return buildUri(ctx, newHeaders);
}

private static boolean isCyclicRedirects(RedirectContext redirectCtx, String redirectUri,
RequestHeaders newHeaders) {
final boolean added = redirectCtx.addRedirectUri(newHeaders.method(), redirectUri);
private static boolean isCyclicRedirects(RedirectContext redirectCtx,
String redirectUri, HttpMethod method) {
final boolean added = redirectCtx.addRedirectUri(method, redirectUri);
if (!added) {
return true;
}

return redirectCtx.originalUri().equals(redirectUri) &&
redirectCtx.request().method() == newHeaders.method();
redirectCtx.request().method() == method;
}

private static String buildUri(ClientRequestContext ctx, RequestHeaders headers) {
Expand All @@ -391,15 +496,15 @@
if (authority == null) {
authority = endpoint.authority();
}
setAuthorityAndPort(ctx, endpoint, sb, authority);
appendAuthority(ctx, endpoint, sb, authority);
sb.append(headers.path());
originalUri = sb.toString();
}
return originalUri;
}

private static void setAuthorityAndPort(ClientRequestContext ctx, Endpoint endpoint, StringBuilder sb,
String authority) {
private static void appendAuthority(ClientRequestContext ctx, Endpoint endpoint, StringBuilder sb,
String authority) {
// Add port number as well so that we don't raise a CyclicRedirectsException when the port is
// different.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ protected AbstractRequestContextBuilder(boolean server, RpcRequest rpcReq, URI u
this.reqTarget = reqTarget;
} else {
reqTarget = DefaultRequestTarget.createWithoutValidation(
RequestTargetForm.ORIGIN, null, null,
RequestTargetForm.ORIGIN, null, null, null, -1,
uri.getRawPath(), uri.getRawPath(), uri.getRawQuery(), uri.getRawFragment());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@

import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.internal.common.util.StringUtil;
import com.linecorp.armeria.internal.common.util.TemporaryThreadLocals;

import io.netty.util.AsciiString;

Expand Down Expand Up @@ -214,18 +215,17 @@ URI uri() {
checkState(scheme != null, ":scheme header does not exist.");
final String authority = authority();

final StringBuilder sb = new StringBuilder(
scheme.length() + 1 +
(authority != null ? (authority.length() + 2) : 0) +
path.length());
sb.append(scheme);
sb.append(':');
if (authority != null) {
sb.append("//");
sb.append(authority);
try (TemporaryThreadLocals tmp = TemporaryThreadLocals.acquire()) {
final StringBuilder sb = tmp.stringBuilder();
sb.append(scheme);
sb.append(':');
if (authority != null) {
sb.append("//");
sb.append(authority);
}
sb.append(path);
uri = sb.toString();
}
sb.append(path);
uri = sb.toString();
}

try {
Expand Down