Skip to content

Commit

Permalink
Fix unconsumed parameter exception when authenticating with jwtUrlPar…
Browse files Browse the repository at this point in the history
…ameter (#3975)

Signed-off-by: Craig Perkins <cwperx@amazon.com>
Signed-off-by: Peter Nied <peternied@hotmail.com>
Signed-off-by: Peter Nied <petern@amazon.com>
Co-authored-by: Peter Nied <petern@amazon.com>
Co-authored-by: Peter Nied <peternied@hotmail.com>
(cherry picked from commit ccea744)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
3 people committed Feb 22, 2024
1 parent c9bbe8d commit b863395
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 4 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*
*/
package org.opensearch.security.http;

import java.security.KeyPair;
import java.util.Base64;
import java.util.List;
import java.util.Map;

import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope;
import org.apache.hc.core5.http.Header;
import org.junit.ClassRule;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;

import org.opensearch.test.framework.JwtConfigBuilder;
import org.opensearch.test.framework.TestSecurityConfig;
import org.opensearch.test.framework.cluster.ClusterManager;
import org.opensearch.test.framework.cluster.LocalCluster;
import org.opensearch.test.framework.cluster.TestRestClient;
import org.opensearch.test.framework.cluster.TestRestClient.HttpResponse;
import org.opensearch.test.framework.log.LogsRule;

import io.jsonwebtoken.SignatureAlgorithm;
import io.jsonwebtoken.security.Keys;

import static java.nio.charset.StandardCharsets.US_ASCII;
import static org.apache.http.HttpHeaders.AUTHORIZATION;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.equalTo;
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL;
import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.BASIC_AUTH_DOMAIN_ORDER;
import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS;

@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class)
@ThreadLeakScope(ThreadLeakScope.Scope.NONE)
public class JwtAuthenticationWithUrlParamTests {

public static final String CLAIM_USERNAME = "preferred-username";
public static final String CLAIM_ROLES = "backend-user-roles";
public static final String POINTER_USERNAME = "/user_name";

private static final KeyPair KEY_PAIR = Keys.keyPairFor(SignatureAlgorithm.RS256);
private static final String PUBLIC_KEY = new String(Base64.getEncoder().encode(KEY_PAIR.getPublic().getEncoded()), US_ASCII);

static final TestSecurityConfig.User ADMIN_USER = new TestSecurityConfig.User("admin").roles(ALL_ACCESS);

private static final String TOKEN_URL_PARAM = "token";

private static final JwtAuthorizationHeaderFactory tokenFactory = new JwtAuthorizationHeaderFactory(
KEY_PAIR.getPrivate(),
CLAIM_USERNAME,
CLAIM_ROLES,
AUTHORIZATION
);

public static final TestSecurityConfig.AuthcDomain JWT_AUTH_DOMAIN = new TestSecurityConfig.AuthcDomain(
"jwt",
BASIC_AUTH_DOMAIN_ORDER - 1
).jwtHttpAuthenticator(
new JwtConfigBuilder().jwtUrlParameter(TOKEN_URL_PARAM).signingKey(PUBLIC_KEY).subjectKey(CLAIM_USERNAME).rolesKey(CLAIM_ROLES)
).backend("noop");

@ClassRule
public static final LocalCluster cluster = new LocalCluster.Builder().clusterManager(ClusterManager.SINGLENODE)
.anonymousAuth(false)
.nodeSettings(
Map.of("plugins.security.restapi.roles_enabled", List.of("user_" + ADMIN_USER.getName() + "__" + ALL_ACCESS.getName()))
)
.authc(AUTHC_HTTPBASIC_INTERNAL)
.authc(JWT_AUTH_DOMAIN)
.users(ADMIN_USER)
.build();

@Rule
public LogsRule logsRule = new LogsRule("com.amazon.dlic.auth.http.jwt.HTTPJwtAuthenticator");

@Test
public void shouldAuthenticateWithJwtTokenInUrl_positive() {
Header jwtToken = tokenFactory.generateValidToken(ADMIN_USER.getName());
String jwtTokenValue = jwtToken.getValue();
try (TestRestClient client = cluster.getRestClient()) {
HttpResponse response = client.getAuthInfo(Map.of(TOKEN_URL_PARAM, jwtTokenValue));

response.assertStatusCode(200);
String username = response.getTextFromJsonBody(POINTER_USERNAME);
assertThat(username, equalTo(ADMIN_USER.getName()));
}
}

@Test
public void testUnauthenticatedRequest() {
try (TestRestClient client = cluster.getRestClient()) {
HttpResponse response = client.getAuthInfo();

response.assertStatusCode(401);
logsRule.assertThatContainExactly(String.format("No JWT token found in '%s' url parameter header", TOKEN_URL_PARAM));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

public class JwtConfigBuilder {
private String jwtHeader;
private String jwtUrlParameter;
private String signingKey;
private String subjectKey;
private String rolesKey;
Expand All @@ -27,6 +28,11 @@ public JwtConfigBuilder jwtHeader(String jwtHeader) {
return this;
}

public JwtConfigBuilder jwtUrlParameter(String jwtUrlParameter) {
this.jwtUrlParameter = jwtUrlParameter;
return this;
}

public JwtConfigBuilder signingKey(String signingKey) {
this.signingKey = signingKey;
return this;
Expand All @@ -51,6 +57,9 @@ public Map<String, Object> build() {
if (isNoneBlank(jwtHeader)) {
builder.put("jwt_header", jwtHeader);
}
if (isNoneBlank(jwtUrlParameter)) {
builder.put("jwt_url_parameter", jwtUrlParameter);
}
if (isNoneBlank(subjectKey)) {
builder.put("subject_key", subjectKey);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
Expand Down Expand Up @@ -114,6 +115,12 @@ public HttpResponse getAuthInfo(Header... headers) {
return executeRequest(new HttpGet(getHttpServerUri() + "/_opendistro/_security/authinfo?pretty"), headers);
}

public HttpResponse getAuthInfo(Map<String, String> urlParams, Header... headers) {
String urlParamsString = "?"
+ urlParams.entrySet().stream().map(e -> e.getKey() + "=" + e.getValue()).collect(Collectors.joining("&"));
return executeRequest(new HttpGet(getHttpServerUri() + "/_opendistro/_security/authinfo" + urlParamsString), headers);
}

public void confirmCorrectCredentials(String expectedUserName) {
HttpResponse response = getAuthInfo();
assertThat(response, notNullValue());
Expand Down
35 changes: 34 additions & 1 deletion src/main/java/org/opensearch/security/filter/NettyRequest.java
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,17 @@
import java.net.InetSocketAddress;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.TreeMap;
import javax.net.ssl.SSLEngine;

import com.google.common.base.Supplier;
import com.google.common.base.Suppliers;

import org.opensearch.http.netty4.Netty4HttpChannel;
import org.opensearch.rest.RestRequest.Method;
import org.opensearch.rest.RestUtils;
Expand All @@ -34,6 +39,7 @@ public class NettyRequest implements SecurityRequest {

protected final HttpRequest underlyingRequest;
protected final Netty4HttpChannel underlyingChannel;
protected final Supplier<CheckedAccessMap> parameters = Suppliers.memoize(() -> new CheckedAccessMap(params(uri())));

NettyRequest(final HttpRequest request, final Netty4HttpChannel channel) {
this.underlyingRequest = request;
Expand Down Expand Up @@ -78,7 +84,12 @@ public String uri() {

@Override
public Map<String, String> params() {
return params(underlyingRequest.uri());
return parameters.get();
}

@Override
public Set<String> getUnconsumedParams() {
return parameters.get().accessedKeys();
}

private static Map<String, String> params(String uri) {
Expand All @@ -96,4 +107,26 @@ private static Map<String, String> params(String uri) {

return params;
}

/** Records access of any keys if explicitly requested from this map */
private static class CheckedAccessMap extends HashMap<String, String> {
private final Set<String> accessedKeys = new HashSet<>();

public CheckedAccessMap(final Map<String, String> map) {
super(map);
}

@Override
public String get(final Object key) {
// Never noticed this about java's map interface the getter is not generic
if (key instanceof String) {
accessedKeys.add((String) key);
}
return super.get(key);
}

public Set<String> accessedKeys() {
return accessedKeys;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,11 @@
package org.opensearch.security.filter;

import java.net.InetSocketAddress;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import javax.net.ssl.SSLEngine;

import org.opensearch.rest.RestRequest;
Expand Down Expand Up @@ -71,7 +73,18 @@ public String uri() {

@Override
public Map<String, String> params() {
return underlyingRequest.params();
return new HashMap<>(underlyingRequest.params()) {
@Override
public String get(Object key) {
return underlyingRequest.param((String) key);
}
};
}

@Override
public Set<String> getUnconsumedParams() {
// params() Map consumes explict parameter access
return Set.of();
}

/** Gets access to the underlying request object */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
import javax.net.ssl.SSLEngine;

Expand Down Expand Up @@ -49,4 +50,7 @@ default String header(final String headerName) {

/** The parameters associated with this request */
Map<String, String> params();

/** The list of parameters that have been accessed but not recorded as being consumed */
Set<String> getUnconsumedParams();
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@
import java.util.Optional;

/**
* When a request is recieved by the security plugin this governs getting information about the request and complete with with a response
* When a request is received by the security plugin this governs getting information about the request and complete with a response
*/
public interface SecurityRequestChannel extends SecurityRequest {

/** Associate a response with this channel */
void queueForSending(final SecurityResponse response);

/** Acess the queued response */
/** Access the queued response */
Optional<SecurityResponse> getQueuedResponse();
}
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import static org.opensearch.security.http.SecurityHttpServerTransport.CONTEXT_TO_RESTORE;
import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE;
import static org.opensearch.security.http.SecurityHttpServerTransport.IS_AUTHENTICATED;
import static org.opensearch.security.http.SecurityHttpServerTransport.UNCONSUMED_PARAMS;

public class SecurityRestFilter {

Expand Down Expand Up @@ -144,6 +145,13 @@ public void handleRequest(RestRequest request, RestChannel channel, NodeClient c
}
});

NettyAttribute.popFrom(request, UNCONSUMED_PARAMS).ifPresent(unconsumedParams -> {
for (String unconsumedParam : unconsumedParams) {
// Consume the parameter on the RestRequest
request.param(unconsumedParam);
}
});

final SecurityRequestChannel requestChannel = SecurityRequestFactory.from(request, channel);

// Authenticate request
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@

package org.opensearch.security.http;

import java.util.Set;

import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Settings;
Expand All @@ -47,6 +49,7 @@
public class SecurityHttpServerTransport extends SecuritySSLNettyHttpServerTransport {

public static final AttributeKey<SecurityResponse> EARLY_RESPONSE = AttributeKey.newInstance("opensearch-http-early-response");
public static final AttributeKey<Set<String>> UNCONSUMED_PARAMS = AttributeKey.newInstance("opensearch-http-request-consumed-params");
public static final AttributeKey<ThreadContext.StoredContext> CONTEXT_TO_RESTORE = AttributeKey.newInstance(
"opensearch-http-request-thread-context"
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
import static org.opensearch.security.http.SecurityHttpServerTransport.EARLY_RESPONSE;
import static org.opensearch.security.http.SecurityHttpServerTransport.IS_AUTHENTICATED;
import static org.opensearch.security.http.SecurityHttpServerTransport.SHOULD_DECOMPRESS;
import static org.opensearch.security.http.SecurityHttpServerTransport.UNCONSUMED_PARAMS;

@Sharable
public class Netty4HttpRequestHeaderVerifier extends SimpleChannelInboundHandler<DefaultHttpRequest> {
Expand Down Expand Up @@ -84,6 +85,8 @@ public void channelRead0(ChannelHandlerContext ctx, DefaultHttpRequest msg) thro
// If request channel is completed and a response is sent, then there was a failure during authentication
restFilter.checkAndAuthenticateRequest(requestChannel);

ctx.channel().attr(UNCONSUMED_PARAMS).set(requestChannel.getUnconsumedParams());

ThreadContext.StoredContext contextToRestore = threadPool.getThreadContext().newStoredContext(false);
ctx.channel().attr(CONTEXT_TO_RESTORE).set(contextToRestore);

Expand Down

0 comments on commit b863395

Please sign in to comment.