Skip to content

Commit

Permalink
Fix Potential Memory Leak in SecurityServerTransportInterceptor (#75669
Browse files Browse the repository at this point in the history
…) (#75673)

Calling `onFailure` on an `AbstractRunnable` would not trigger the `onAfter` hook.
If a request that actually needed the ref counting would run into an auth failure
we'd leak it. This currently isn't an issue I think since we only use the ref counting
with recovery and cluster state requests but would cause a memory leak if auth started
to actually fail here.
  • Loading branch information
original-brownbear committed Jul 26, 2021
1 parent 4b8be45 commit 2685560
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.DestructiveOperations;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.common.settings.Setting;
import org.elasticsearch.common.settings.Setting.Property;
import org.elasticsearch.common.settings.Settings;
Expand Down Expand Up @@ -47,7 +46,6 @@
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.Function;

import static org.elasticsearch.xpack.core.security.SecurityField.setting;
Expand Down Expand Up @@ -294,36 +292,62 @@ public void messageReceived(T request, TransportChannel channel, Task task) thro
}
}
assert filter != null;
final Thread executingThread = Thread.currentThread();

final AbstractRunnable receiveMessage = getReceiveRunnable(request, channel, task);
CheckedConsumer<Void, Exception> consumer = (x) -> {
final Executor executor;
if (executingThread == Thread.currentThread()) {
// only fork off if we get called on another thread this means we moved to
// an async execution and in this case we need to go back to the thread pool
// that was actually executing it. it's also possible that the
// thread-pool we are supposed to execute on is `SAME` in that case
// the handler is OK with executing on a network thread and we can just continue even if
// we are on another thread due to async operations
executor = threadPool.executor(ThreadPool.Names.SAME);
} else {
executor = threadPool.executor(executorName);
}

try {
executor.execute(receiveMessage);
} catch (Exception e) {
receiveMessage.onFailure(e);
}

};
ActionListener<Void> filterListener = ActionListener.wrap(consumer, receiveMessage::onFailure);
final ActionListener<Void> filterListener;
if (ThreadPool.Names.SAME.equals(executorName)) {
filterListener = new AbstractFilterListener(receiveMessage) {
@Override
public void onResponse(Void unused) {
receiveMessage.run();
}
};
} else {
final Thread executingThread = Thread.currentThread();
filterListener = new AbstractFilterListener(receiveMessage) {
@Override
public void onResponse(Void unused) {
if (executingThread == Thread.currentThread()) {
// only fork off if we get called on another thread this means we moved to
// an async execution and in this case we need to go back to the thread pool
// that was actually executing it. it's also possible that the
// thread-pool we are supposed to execute on is `SAME` in that case
// the handler is OK with executing on a network thread and we can just continue even if
// we are on another thread due to async operations
receiveMessage.run();
} else {
try {
threadPool.executor(executorName).execute(receiveMessage);
} catch (Exception e) {
onFailure(e);
}
}
}
};
}
filter.inbound(action, request, channel, filterListener);
} else {
getReceiveRunnable(request, channel, task).run();
}
}
}
}

private abstract static class AbstractFilterListener implements ActionListener<Void> {

protected final AbstractRunnable receiveMessage;

protected AbstractFilterListener(AbstractRunnable receiveMessage) {
this.receiveMessage = receiveMessage;
}

@Override
public void onFailure(Exception e) {
try {
receiveMessage.onFailure(e);
} finally {
receiveMessage.onAfter();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
package org.elasticsearch.xpack.security.transport;

import org.elasticsearch.Version;
import org.elasticsearch.action.admin.indices.delete.DeleteIndexAction;
import org.elasticsearch.action.admin.indices.delete.DeleteIndexRequest;
import org.elasticsearch.action.main.MainAction;
import org.elasticsearch.action.support.DestructiveOperations;
import org.elasticsearch.cluster.ClusterState;
Expand All @@ -17,12 +19,14 @@
import org.elasticsearch.common.util.concurrent.ThreadContext;
import org.elasticsearch.gateway.GatewayService;
import org.elasticsearch.license.XPackLicenseState;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.test.ClusterServiceUtils;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.TestThreadPool;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.Transport;
import org.elasticsearch.transport.Transport.Connection;
import org.elasticsearch.transport.TransportChannel;
import org.elasticsearch.transport.TransportException;
import org.elasticsearch.transport.TransportInterceptor.AsyncSender;
import org.elasticsearch.transport.TransportRequest;
Expand All @@ -47,10 +51,14 @@
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Consumer;

import static com.carrotsearch.randomizedtesting.RandomizedTest.randomBoolean;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Matchers.any;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
Expand Down Expand Up @@ -424,6 +432,56 @@ public void handleException(TransportException exp) {
assertEquals("value", threadContext.getHeader("key"));
}

public void testProfileSecuredRequestHandlerDecrementsRefCountOnFailure() throws Exception {
final String profileName = "some-profile";
final DestructiveOperations destructiveOperations =
new DestructiveOperations(
Settings.builder().put(DestructiveOperations.REQUIRES_NAME_SETTING.getKey(), true).build(),
clusterService.getClusterSettings()
);
final SecurityServerTransportInterceptor.ProfileSecuredRequestHandler<DeleteIndexRequest> requestHandler =
new SecurityServerTransportInterceptor.ProfileSecuredRequestHandler<>(
logger,
DeleteIndexAction.NAME,
randomBoolean(),
randomBoolean() ? ThreadPool.Names.SAME : ThreadPool.Names.GENERIC,
(request, channel, task) -> fail("should fail at destructive operations check to trigger listener failure"),
Collections.singletonMap(
profileName,
new ServerTransportFilter.NodeProfile(
null,
null,
threadContext,
randomBoolean(),
destructiveOperations,
randomBoolean(),
securityContext,
xPackLicenseState
)
),
xPackLicenseState,
threadPool
);
final TransportChannel channel = mock(TransportChannel.class);
when(channel.getProfileName()).thenReturn(profileName);
final AtomicBoolean exceptionSent = new AtomicBoolean(false);
doAnswer(invocationOnMock -> {
assertTrue(exceptionSent.compareAndSet(false, true));
return null;
}).when(channel).sendResponse(any(Exception.class));
final AtomicBoolean decRefCalled = new AtomicBoolean(false);
final DeleteIndexRequest deleteIndexRequest = new DeleteIndexRequest() {
@Override
public boolean decRef() {
assertTrue(decRefCalled.compareAndSet(false, true));
return super.decRef();
}
};
requestHandler.messageReceived(deleteIndexRequest, channel, mock(Task.class));
assertTrue(decRefCalled.get());
assertTrue(exceptionSent.get());
}

private String[] randomRoles() {
return generateRandomStringArray(3, 10, false, true);
}
Expand Down

0 comments on commit 2685560

Please sign in to comment.