Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] Fix unconsumed parameter exception when authenticating with jwtUrlParameter #4065

Merged
merged 2 commits into from
Feb 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
*/
package org.opensearch.security.http;

import java.security.KeyPair;
import java.util.Base64;
import java.util.List;
import java.util.Map;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import org.apache.http.Header;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;

import org.opensearch.test.framework.JwtConfigBuilder;
import org.opensearch.test.framework.TestSecurityConfig;
import org.opensearch.test.framework.cluster.ClusterManager;
import org.opensearch.test.framework.cluster.LocalCluster;
import org.opensearch.test.framework.cluster.TestRestClient;
import org.opensearch.test.framework.cluster.TestRestClient.HttpResponse;
import org.opensearch.test.framework.log.LogsRule;

import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.security.Keys;

import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.apache.http.HttpHeaders.AUTHORIZATION;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL;
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.BASIC_AUTH_DOMAIN_ORDER;
import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS;

@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class)
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class JwtAuthenticationWithUrlParamTests {

public static final String CLAIM_USERNAME = "preferred-username";
public static final String CLAIM_ROLES = "backend-user-roles";
public static final String POINTER_USERNAME = "/user_name";

private static final KeyPair KEY_PAIR = Keys.keyPairFor(SignatureAlgorithm.RS256);
private static final String PUBLIC_KEY = new String(Base64.getEncoder().encode(KEY_PAIR.getPublic().getEncoded()), US_ASCII);

static final TestSecurityConfig.User ADMIN_USER = new TestSecurityConfig.User("admin").roles(ALL_ACCESS);

private static final String TOKEN_URL_PARAM = "token";

private static final JwtAuthorizationHeaderFactory tokenFactory = new JwtAuthorizationHeaderFactory(
KEY_PAIR.getPrivate(),
CLAIM_USERNAME,
CLAIM_ROLES,
AUTHORIZATION
);

public static final TestSecurityConfig.AuthcDomain JWT_AUTH_DOMAIN = new TestSecurityConfig.AuthcDomain(
"jwt",
BASIC_AUTH_DOMAIN_ORDER - 1
).jwtHttpAuthenticator(
new JwtConfigBuilder().jwtUrlParameter(TOKEN_URL_PARAM).signingKey(PUBLIC_KEY).subjectKey(CLAIM_USERNAME).rolesKey(CLAIM_ROLES)
).backend("noop");

@ClassRule
public static final LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE)
.anonymousAuth(false)
.nodeSettings(
Map.of("plugins.security.restapi.roles_enabled", List.of("user_" + ADMIN_USER.getName() + "__" + ALL_ACCESS.getName()))
)
.authc(AUTHC_HTTPBASIC_INTERNAL)
.authc(JWT_AUTH_DOMAIN)
.users(ADMIN_USER)
.build();

@Rule
public LogsRule logsRule = new LogsRule("com.amazon.dlic.auth.http.jwt.HTTPJwtAuthenticator");

@Test
public void shouldAuthenticateWithJwtTokenInUrl_positive() {
Header jwtToken = tokenFactory.generateValidToken(ADMIN_USER.getName());
String jwtTokenValue = jwtToken.getValue();
try (TestRestClient client = cluster.getRestClient()) {
HttpResponse response = client.getAuthInfo(Map.of(TOKEN_URL_PARAM, jwtTokenValue));

response.assertStatusCode(200);
String username = response.getTextFromJsonBody(POINTER_USERNAME);
assertThat(username, equalTo(ADMIN_USER.getName()));
}
}

@Test
public void testUnauthenticatedRequest() {
try (TestRestClient client = cluster.getRestClient()) {
HttpResponse response = client.getAuthInfo();

response.assertStatusCode(401);
logsRule.assertThatContainExactly(String.format("No JWT token found in '%s' url parameter header", TOKEN_URL_PARAM));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

public class JwtConfigBuilder {
private String jwtHeader;
private String jwtUrlParameter;
private String signingKey;
private String subjectKey;
private String rolesKey;
Expand All @@ -27,6 +28,11 @@ public JwtConfigBuilder jwtHeader(String jwtHeader) {
return this;
}

public JwtConfigBuilder jwtUrlParameter(String jwtUrlParameter) {
this.jwtUrlParameter = jwtUrlParameter;
return this;
}

public JwtConfigBuilder signingKey(String signingKey) {
this.signingKey = signingKey;
return this;
Expand All @@ -51,6 +57,9 @@ public Map<String, Object> build() {
if (isNoneBlank(jwtHeader)) {
builder.put("jwt_header", jwtHeader);
}
if (isNoneBlank(jwtUrlParameter)) {
builder.put("jwt_url_parameter", jwtUrlParameter);
}
if (isNoneBlank(subjectKey)) {
builder.put("subject_key", subjectKey);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -114,6 +115,12 @@ public HttpResponse getAuthInfo(Header... headers) {
return executeRequest(new HttpGet(getHttpServerUri() + "/_opendistro/_security/authinfo?pretty"), headers);
}

public HttpResponse getAuthInfo(Map<String, String> urlParams, Header... headers) {
String urlParamsString = "?"
+ urlParams.entrySet().stream().map(e -> e.getKey() + "=" + e.getValue()).collect(Collectors.joining("&"));
return executeRequest(new HttpGet(getHttpServerUri() + "/_opendistro/_security/authinfo" + urlParamsString), headers);
}

public void confirmCorrectCredentials(String expectedUserName) {
HttpResponse response = getAuthInfo();
assertThat(response, notNullValue());
Expand Down
35 changes: 34 additions & 1 deletion src/main/java/org/opensearch/security/filter/NettyRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import javax.net.ssl.SSLEngine;

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;

import org.opensearch.http.netty4.Netty4HttpChannel;
import org.opensearch.rest.RestRequest.Method;
import org.opensearch.rest.RestUtils;
Expand All @@ -34,6 +39,7 @@ public class NettyRequest implements SecurityRequest {

protected final HttpRequest underlyingRequest;
protected final Netty4HttpChannel underlyingChannel;
protected final Supplier<CheckedAccessMap> parameters = Suppliers.memoize(() -> new CheckedAccessMap(params(uri())));

NettyRequest(final HttpRequest request, final Netty4HttpChannel channel) {
this.underlyingRequest = request;
Expand Down Expand Up @@ -78,7 +84,12 @@ public String uri() {

@Override
public Map<String, String> params() {
return params(underlyingRequest.uri());
return parameters.get();
}

@Override
public Set<String> getUnconsumedParams() {
return parameters.get().accessedKeys();
}

private static Map<String, String> params(String uri) {
Expand All @@ -96,4 +107,26 @@ private static Map<String, String> params(String uri) {

return params;
}

/** Records access of any keys if explicitly requested from this map */
private static class CheckedAccessMap extends HashMap<String, String> {
private final Set<String> accessedKeys = new HashSet<>();

public CheckedAccessMap(final Map<String, String> map) {
super(map);
}

@Override
public String get(final Object key) {
// Never noticed this about java's map interface the getter is not generic
if (key instanceof String) {
accessedKeys.add((String) key);
}
return super.get(key);
}

public Set<String> accessedKeys() {
return accessedKeys;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
package org.opensearch.security.filter;

import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import javax.net.ssl.SSLEngine;

import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -71,7 +73,18 @@

@Override
public Map<String, String> params() {
return underlyingRequest.params();
return new HashMap<>(underlyingRequest.params()) {
@Override
public String get(Object key) {
return underlyingRequest.param((String) key);
}
};
}

@Override
public Set<String> getUnconsumedParams() {
// params() Map consumes explict parameter access
return Set.of();

Check warning on line 87 in src/main/java/org/opensearch/security/filter/OpenSearchRequest.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/security/filter/OpenSearchRequest.java#L87

Added line #L87 was not covered by tests
}

/** Gets access to the underlying request object */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import javax.net.ssl.SSLEngine;

Expand Down Expand Up @@ -49,4 +50,7 @@ default String header(final String headerName) {

/** The parameters associated with this request */
Map<String, String> params();

/** The list of parameters that have been accessed but not recorded as being consumed */
Set<String> getUnconsumedParams();
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
import java.util.Optional;

/**
* When a request is recieved by the security plugin this governs getting information about the request and complete with with a response
* When a request is received by the security plugin this governs getting information about the request and complete with a response
*/
public interface SecurityRequestChannel extends SecurityRequest {

/** Associate a response with this channel */
void queueForSending(final SecurityResponse response);

/** Acess the queued response */
/** Access the queued response */
Optional<SecurityResponse> getQueuedResponse();
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import static org.opensearch.security.http.SecurityHttpServerTransport.CONTEXT_TO_RESTORE;
import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE;
import static org.opensearch.security.http.SecurityHttpServerTransport.IS_AUTHENTICATED;
import static org.opensearch.security.http.SecurityHttpServerTransport.UNCONSUMED_PARAMS;

public class SecurityRestFilter {

Expand Down Expand Up @@ -144,6 +145,13 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
}
});

NettyAttribute.popFrom(request, UNCONSUMED_PARAMS).ifPresent(unconsumedParams -> {
for (String unconsumedParam : unconsumedParams) {
// Consume the parameter on the RestRequest
request.param(unconsumedParam);
}
});

final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(request, channel);

// Authenticate request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

package org.opensearch.security.http;

import java.util.Set;

import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
Expand All @@ -47,6 +49,7 @@
public class SecurityHttpServerTransport extends SecuritySSLNettyHttpServerTransport {

public static final AttributeKey<SecurityResponse> EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response");
public static final AttributeKey<Set<String>> UNCONSUMED_PARAMS = AttributeKey.newInstance("opensearch-http-request-consumed-params");
public static final AttributeKey<ThreadContext.StoredContext> CONTEXT_TO_RESTORE = AttributeKey.newInstance(
"opensearch-http-request-thread-context"
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE;
import static org.opensearch.security.http.SecurityHttpServerTransport.IS_AUTHENTICATED;
import static org.opensearch.security.http.SecurityHttpServerTransport.SHOULD_DECOMPRESS;
import static org.opensearch.security.http.SecurityHttpServerTransport.UNCONSUMED_PARAMS;

@Sharable
public class Netty4HttpRequestHeaderVerifier extends SimpleChannelInboundHandler<DefaultHttpRequest> {
Expand Down Expand Up @@ -84,6 +85,8 @@ public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) thro
// If request channel is completed and a response is sent, then there was a failure during authentication
restFilter.checkAndAuthenticateRequest(requestChannel);

ctx.channel().attr(UNCONSUMED_PARAMS).set(requestChannel.getUnconsumedParams());

ThreadContext.StoredContext contextToRestore = threadPool.getThreadContext().newStoredContext(false);
ctx.channel().attr(CONTEXT_TO_RESTORE).set(contextToRestore);

Expand Down
Loading