Skip to content

Commit

Permalink
[6.8] Call ActionListener.onResponse exactly once (#61691)
Browse files Browse the repository at this point in the history
Under specific circumstances we would call onResponse twice, 
which led to unexpected behavior.
Backport of  #61584
  • Loading branch information
jkakavas committed Sep 1, 2020
1 parent 8abd807 commit 2bb45b4
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 3 deletions.
32 changes: 32 additions & 0 deletions server/src/main/java/org/elasticsearch/action/ActionListener.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.common.CheckedConsumer;
import org.elasticsearch.common.CheckedRunnable;
import org.elasticsearch.common.CheckedSupplier;

import java.util.ArrayList;
Expand Down Expand Up @@ -170,6 +171,37 @@ public void onFailure(Exception e) {
};
}

/**
* Wraps a given listener and returns a new listener which executes the provided {@code runBefore}
* callback before the listener is notified via either {@code #onResponse} or {@code #onFailure}.
* If the callback throws an exception then it will be passed to the listener's {@code #onFailure} and its {@code #onResponse} will
* not be executed.
*/
static <Response> ActionListener<Response> runBefore(ActionListener<Response> delegate, CheckedRunnable<?> runBefore) {
return new ActionListener<Response>() {
@Override
public void onResponse(Response response) {
try {
runBefore.run();
} catch (Exception ex) {
delegate.onFailure(ex);
return;
}
delegate.onResponse(response);
}

@Override
public void onFailure(Exception e) {
try {
runBefore.run();
} catch (Exception ex) {
e.addSuppressed(ex);
}
delegate.onFailure(e);
}
};
}

/**
* Wraps a given listener and returns a new listener which makes sure {@link #onResponse(Object)}
* and {@link #onFailure(Exception)} of the provided listener will be called at most once.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,23 @@ public void testRunAfter() {
}
}

public void testRunBefore() {
{
AtomicBoolean afterSuccess = new AtomicBoolean();
ActionListener<Object> listener =
ActionListener.runBefore(ActionListener.wrap(r -> {}, e -> {}), () -> afterSuccess.set(true));
listener.onResponse(null);
assertThat(afterSuccess.get(), equalTo(true));
}
{
AtomicBoolean afterFailure = new AtomicBoolean();
ActionListener<Object> listener =
ActionListener.runBefore(ActionListener.wrap(r -> {}, e -> {}), () -> afterFailure.set(true));
listener.onFailure(null);
assertThat(afterFailure.get(), equalTo(true));
}
}

public void testNotifyOnce() {
AtomicInteger onResponseTimes = new AtomicInteger();
AtomicInteger onFailureTimes = new AtomicInteger();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,9 @@ protected void doLookupUser(String username, ActionListener<User> listener) {
if (realmEnabled == false) {
if (anonymousEnabled && AnonymousUser.isAnonymousUsername(username, config.globalSettings())) {
listener.onResponse(anonymousUser);
} else {
listener.onResponse(null);
}
listener.onResponse(null);
} else if (ClientReservedRealm.isReserved(username, config.globalSettings()) == false) {
listener.onResponse(null);
} else if (AnonymousUser.isAnonymousUsername(username, config.globalSettings())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,13 @@
import java.util.Map.Entry;
import java.util.concurrent.ExecutionException;
import java.util.function.Predicate;
import java.util.concurrent.atomic.AtomicInteger;

import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.empty;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.mockito.Matchers.any;
Expand Down Expand Up @@ -189,7 +191,7 @@ public void testLookup() throws Exception {
verifyVersionPredicate(principal, predicateCaptor.getValue());

PlainActionFuture<User> future = new PlainActionFuture<>();
reservedRealm.doLookupUser("foobar", future);
reservedRealm.doLookupUser("foobar", assertListenerIsOnlyCalledOnce(future));
final User doesntExist = future.actionGet();
assertThat(doesntExist, nullValue());
verifyNoMoreInteractions(usersStore);
Expand All @@ -204,12 +206,29 @@ public void testLookupDisabled() throws Exception {
final String principal = expectedUser.principal();

PlainActionFuture<User> listener = new PlainActionFuture<>();
reservedRealm.doLookupUser(principal, listener);
reservedRealm.doLookupUser(principal, assertListenerIsOnlyCalledOnce(listener));
final User user = listener.actionGet();
assertNull(user);
verifyZeroInteractions(usersStore);
}


public void testLookupDisabledAnonymous() throws Exception {
Settings settings = Settings.builder()
.put(XPackSettings.RESERVED_REALM_ENABLED_SETTING.getKey(), false)
.put(AnonymousUser.ROLES_SETTING.getKey(), "anonymous")
.build();
final ReservedRealm reservedRealm =
new ReservedRealm(mock(Environment.class), settings, usersStore, new AnonymousUser(settings),
securityIndex, threadPool);
final User expectedUser = new AnonymousUser(settings);
final String principal = expectedUser.principal();

PlainActionFuture<User> listener = new PlainActionFuture<>();
reservedRealm.doLookupUser(principal, assertListenerIsOnlyCalledOnce(listener));
assertThat(listener.actionGet(), equalTo(expectedUser));
}

public void testLookupThrows() throws Exception {
final ReservedRealm reservedRealm =
new ReservedRealm(mock(Environment.class), Settings.EMPTY, usersStore,
Expand Down Expand Up @@ -481,4 +500,13 @@ private void verifyVersionPredicate(String principal, Predicate<Version> version
}
assertThat(versionPredicate.test(Version.CURRENT), is(true));
}

private static <T> ActionListener<T> assertListenerIsOnlyCalledOnce(ActionListener<T> delegate) {
final AtomicInteger callCount = new AtomicInteger(0);
return ActionListener.runBefore(delegate, () -> {
if (callCount.incrementAndGet() != 1) {
fail("Listener was called twice");
}
});
}
}

0 comments on commit 2bb45b4

Please sign in to comment.