Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support concurrent refresh of refresh tokens #38382

Merged
merged 21 commits into from
Mar 1, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
70e620f
allow tokens to be reissued in a given time window
jkakavas Feb 4, 2019
318af42
Concludes work for supporting concurrent refreshes of access tokens i…
jkakavas Feb 4, 2019
541fc34
Merge remote-tracking branch 'origin/master' into support-concurrent-…
jkakavas Feb 4, 2019
55b7428
remove debug logging
jkakavas Feb 4, 2019
27324ab
fix merge woes
jkakavas Feb 5, 2019
f51a8f8
messing up while resolving merge conflicts is my super power
jkakavas Feb 5, 2019
44e0036
Merge remote-tracking branch 'origin/master' into support-concurrent-…
jkakavas Feb 5, 2019
7b70ca5
Handle/not handle deprecation warnings as needed
jkakavas Feb 5, 2019
2b54388
Handle deprecation header-AbstractUpgradeTestCase
jkakavas Feb 5, 2019
6a66302
Revert "Handle deprecation header-AbstractUpgradeTestCase"
jkakavas Feb 5, 2019
7337574
address feedback
jkakavas Feb 11, 2019
e655b0d
Fix versions for master. Will be changed back to V7_1_0 on backport
jkakavas Feb 11, 2019
0fa0f3c
add test with concurrent refreshes
jkakavas Feb 13, 2019
4d1e1dc
Implement suggested modifications
jkakavas Feb 26, 2019
d0971d1
Address feedback
jkakavas Feb 26, 2019
03475ac
Merge remote-tracking branch 'origin/master' into support-concurrent-…
jkakavas Feb 26, 2019
8f804c1
Fix TokenServiceTests
jkakavas Feb 27, 2019
309c8d1
address ffedback
jkakavas Feb 28, 2019
e13fd13
Merge remote-tracking branch 'origin/master' into support-concurrent-…
jkakavas Feb 28, 2019
35eb8ef
address feedback
jkakavas Feb 28, 2019
cbc626d
address feedback
jkakavas Mar 1, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package org.elasticsearch.xpack.core.security.authc.support;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.Version;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -50,6 +51,9 @@ public TokensInvalidationResult(StreamInput in) throws IOException {
this.invalidatedTokens = in.readStringList();
this.previouslyInvalidatedTokens = in.readStringList();
this.errors = in.readList(StreamInput::readException);
if (in.getVersion().before(Version.V_8_0_0)) {
in.readVInt();
}
}

public static TokensInvalidationResult emptyResult() {
Expand Down Expand Up @@ -93,5 +97,8 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeStringCollection(invalidatedTokens);
out.writeStringCollection(previouslyInvalidatedTokens);
out.writeCollection(errors, StreamOutput::writeException);
if (out.getVersion().before(Version.V_8_0_0)) {
out.writeVInt(5);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ public final class TokenService {
private static final String TOKEN_DOC_TYPE = "token";
private static final String TOKEN_DOC_ID_PREFIX = TOKEN_DOC_TYPE + "_";
static final int MINIMUM_BYTES = VERSION_BYTES + SALT_BYTES + IV_BYTES + 1;
private static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue();
static final int MINIMUM_BASE64_BYTES = Double.valueOf(Math.ceil((4 * MINIMUM_BYTES) / 3)).intValue();
private static final Logger logger = LogManager.getLogger(TokenService.class);

private final SecureRandom secureRandom = new SecureRandom();
Expand Down Expand Up @@ -892,14 +892,7 @@ public void onFailure(Exception e) {
logger.info("failed to update the original token document [{}], the update result was [{}]. Retrying",
tokenDocId, updateResponse.getResult());
client.threadPool().schedule(
() -> innerRefresh(
tokenDocId,
source,
seqNo,
primaryTerm,
clientAuth,
listener,
backoff,
() -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff,
refreshRequested),
backoff.next(), GENERIC);
} else {
Expand Down Expand Up @@ -953,14 +946,7 @@ public void onFailure(Exception e) {
if (backoff.hasNext()) {
logger.debug("failed to update the original token document [{}], retrying", tokenDocId);
client.threadPool().schedule(
() -> innerRefresh(
tokenDocId,
source,
seqNo,
primaryTerm,
clientAuth,
listener,
backoff,
() -> innerRefresh(tokenDocId, source, seqNo, primaryTerm, clientAuth, listener, backoff,
refreshRequested),
backoff.next(), GENERIC);
} else {
Expand Down Expand Up @@ -1396,35 +1382,13 @@ public String getAccessTokenAsString(UserToken userToken) throws IOException, Ge
}
}

// Used only for testing
protected String getDeprecatedAccessTokenString(UserToken userToken) throws IOException, GeneralSecurityException {
try (ByteArrayOutputStream os = new ByteArrayOutputStream(MINIMUM_BASE64_BYTES);
OutputStream base64 = Base64.getEncoder().wrap(os);
StreamOutput out = new OutputStreamStreamOutput(base64)) {
out.setVersion(Version.V_7_0_0);
KeyAndCache keyAndCache = keyCache.activeKeyCache;
Version.writeVersion(Version.V_7_0_0, out);
out.writeByteArray(keyAndCache.getSalt().bytes);
out.writeByteArray(keyAndCache.getKeyHash().bytes);
final byte[] initializationVector = getNewInitializationVector();
out.writeByteArray(initializationVector);
try (CipherOutputStream encryptedOutput =
new CipherOutputStream(out, getEncryptionCipher(initializationVector, keyAndCache, Version.V_7_0_0));
StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) {
encryptedStreamOutput.setVersion(Version.V_7_0_0);
encryptedStreamOutput.writeString(userToken.getId());
encryptedStreamOutput.close();
return new String(os.toByteArray(), StandardCharsets.UTF_8);
}
}
}

private void ensureEncryptionCiphersSupported() throws NoSuchPaddingException, NoSuchAlgorithmException {
Cipher.getInstance(ENCRYPTION_CIPHER);
SecretKeyFactory.getInstance(KDF_ALGORITHM);
}

private Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException {
// Package private for testing
Cipher getEncryptionCipher(byte[] iv, KeyAndCache keyAndCache, Version version) throws GeneralSecurityException {
Cipher cipher = Cipher.getInstance(ENCRYPTION_CIPHER);
BytesKey salt = keyAndCache.getSalt();
try {
Expand All @@ -1446,7 +1410,8 @@ private Cipher getDecryptionCipher(byte[] iv, SecretKey key, Version version,
return cipher;
}

private byte[] getNewInitializationVector() {
// Package private for testing
byte[] getNewInitializationVector() {
final byte[] initializationVector = new byte[IV_BYTES];
secureRandom.nextBytes(initializationVector);
return initializationVector;
Expand Down Expand Up @@ -1833,6 +1798,13 @@ void clearActiveKeyCache() {
this.keyCache.activeKeyCache.keyCache.invalidateAll();
}

/**
* For testing
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you say package private for testing?

*/
KeyAndCache getActiveKeyCache() {
return this.keyCache.activeKeyCache;
}

static final class KeyAndCache implements Closeable {
private final KeyAndTimestamp keyAndTimestamp;
private final Cache<BytesKey, SecretKey> keyCache;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,8 @@ public void testRefreshingMultipleTimesWithinWindowSucceeds() throws Exception {
try {
readyLatch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
completedLatch.countDown();
jaymode marked this conversation as resolved.
Show resolved Hide resolved
return;
}
threadSecurityClient.refreshToken(refreshRequest, ActionListener.wrap(result -> {
accessTokens.add(result.getTokenString());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.elasticsearch.xpack.security.authc;

import org.elasticsearch.ElasticsearchSecurityException;
import org.elasticsearch.Version;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.NoShardAvailableActionException;
import org.elasticsearch.action.get.GetAction;
Expand All @@ -23,6 +24,8 @@
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.io.stream.OutputStreamStreamOutput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.util.concurrent.ThreadContext;
Expand Down Expand Up @@ -51,7 +54,11 @@
import org.junit.Before;
import org.junit.BeforeClass;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.security.GeneralSecurityException;
import java.time.Clock;
import java.time.Instant;
import java.time.temporal.ChronoUnit;
Expand All @@ -61,6 +68,7 @@
import java.util.Map;
import java.util.function.Consumer;

import javax.crypto.CipherOutputStream;
import javax.crypto.SecretKey;

import static java.time.Clock.systemUTC;
Expand Down Expand Up @@ -198,7 +206,7 @@ public void testRotateKey() throws Exception {
authentication = token.getAuthentication();

ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getDeprecatedAccessTokenString(token));
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token));

try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
Expand All @@ -219,10 +227,10 @@ public void testRotateKey() throws Exception {
tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true);
final UserToken newToken = newTokenFuture.get().v1();
assertNotNull(newToken);
assertNotEquals(tokenService.getDeprecatedAccessTokenString(newToken), tokenService.getDeprecatedAccessTokenString(token));
assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token));

requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getDeprecatedAccessTokenString(newToken));
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken));
mockGetTokenFromId(newToken, false);

try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
Expand Down Expand Up @@ -258,7 +266,7 @@ public void testKeyExchange() throws Exception {
authentication = token.getAuthentication();

ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getDeprecatedAccessTokenString(token));
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token));
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
otherTokenService.getAndValidateToken(requestContext, future);
Expand Down Expand Up @@ -289,7 +297,7 @@ public void testPruneKeys() throws Exception {
authentication = token.getAuthentication();

ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getDeprecatedAccessTokenString(token));
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token));

try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
Expand All @@ -316,7 +324,7 @@ public void testPruneKeys() throws Exception {
tokenService.createUserToken(authentication, authentication, newTokenFuture, Collections.emptyMap(), true);
final UserToken newToken = newTokenFuture.get().v1();
assertNotNull(newToken);
assertNotEquals(tokenService.getDeprecatedAccessTokenString(newToken), tokenService.getDeprecatedAccessTokenString(token));
assertNotEquals(getDeprecatedAccessTokenString(tokenService, newToken), getDeprecatedAccessTokenString(tokenService, token));

metaData = tokenService.pruneKeys(1);
tokenService.refreshMetaData(metaData);
Expand All @@ -329,7 +337,7 @@ public void testPruneKeys() throws Exception {
}

requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getDeprecatedAccessTokenString(newToken));
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, newToken));
mockGetTokenFromId(newToken, false);
try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
Expand All @@ -351,7 +359,7 @@ public void testPassphraseWorks() throws Exception {
authentication = token.getAuthentication();

ThreadContext requestContext = new ThreadContext(Settings.EMPTY);
requestContext.putHeader("Authorization", "Bearer " + tokenService.getDeprecatedAccessTokenString(token));
requestContext.putHeader("Authorization", "Bearer " + getDeprecatedAccessTokenString(tokenService, token));

try (ThreadContext.StoredContext ignore = requestContext.newStoredContext(true)) {
PlainActionFuture<UserToken> future = new PlainActionFuture<>();
Expand All @@ -377,10 +385,10 @@ public void testGetTokenWhenKeyCacheHasExpired() throws Exception {
PlainActionFuture<Tuple<UserToken, String>> tokenFuture = new PlainActionFuture<>();
tokenService.createUserToken(authentication, authentication, tokenFuture, Collections.emptyMap(), true);
UserToken token = tokenFuture.get().v1();
assertThat(tokenService.getDeprecatedAccessTokenString(token), notNullValue());
assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue());

tokenService.clearActiveKeyCache();
assertThat(tokenService.getDeprecatedAccessTokenString(token), notNullValue());
assertThat(getDeprecatedAccessTokenString(tokenService, token), notNullValue());
}

public void testInvalidatedToken() throws Exception {
Expand Down Expand Up @@ -639,4 +647,28 @@ public static void assertAuthentication(Authentication result, Authentication ex
assertEquals(expected.getMetadata(), result.getMetadata());
assertEquals(AuthenticationType.TOKEN, result.getAuthenticationType());
}

protected String getDeprecatedAccessTokenString(TokenService tokenService, UserToken userToken) throws IOException,
GeneralSecurityException {
try (ByteArrayOutputStream os = new ByteArrayOutputStream(TokenService.MINIMUM_BASE64_BYTES);
OutputStream base64 = Base64.getEncoder().wrap(os);
StreamOutput out = new OutputStreamStreamOutput(base64)) {
out.setVersion(Version.V_7_0_0);
TokenService.KeyAndCache keyAndCache = tokenService.getActiveKeyCache();
Version.writeVersion(Version.V_7_0_0, out);
out.writeByteArray(keyAndCache.getSalt().bytes);
out.writeByteArray(keyAndCache.getKeyHash().bytes);
final byte[] initializationVector = tokenService.getNewInitializationVector();
out.writeByteArray(initializationVector);
try (CipherOutputStream encryptedOutput =
new CipherOutputStream(out, tokenService.getEncryptionCipher(initializationVector, keyAndCache, Version.V_7_0_0));
StreamOutput encryptedStreamOutput = new OutputStreamStreamOutput(encryptedOutput)) {
encryptedStreamOutput.setVersion(Version.V_7_0_0);
encryptedStreamOutput.writeString(userToken.getId());
encryptedStreamOutput.close();
return new String(os.toByteArray(), StandardCharsets.UTF_8);
}
}
}

}