diff --git a/docs/changelog/134359.yaml b/docs/changelog/134359.yaml new file mode 100644 index 0000000000000..89cdcb63704d2 --- /dev/null +++ b/docs/changelog/134359.yaml @@ -0,0 +1,6 @@ +pr: 134359 +summary: Make `MutableSearchResponse` ref-counted to prevent use-after-close in async + search +area: Search +type: bug +issues: [] diff --git a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchConcurrentStatusIT.java b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchConcurrentStatusIT.java new file mode 100644 index 0000000000000..fb227d3b9f75e --- /dev/null +++ b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchConcurrentStatusIT.java @@ -0,0 +1,271 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.search; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.index.IndexRequestBuilder; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.aggregations.AggregationBuilders; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.test.ESIntegTestCase.SuiteScopeTestCase; +import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.LongAdder; +import java.util.stream.IntStream; + +@SuiteScopeTestCase +public class AsyncSearchConcurrentStatusIT extends AsyncSearchIntegTestCase { + private static String indexName; + private static int numShards; + + private static int numKeywords; + + @Override + public void setupSuiteScopeCluster() { + indexName = "test-async"; + numShards = randomIntBetween(1, 20); + int numDocs = randomIntBetween(100, 1000); + createIndex(indexName, Settings.builder().put("index.number_of_shards", numShards).build()); + numKeywords = randomIntBetween(50, 100); + Set keywordSet = new HashSet<>(); + for (int i = 0; i < numKeywords; i++) { + keywordSet.add(randomAlphaOfLengthBetween(10, 20)); + } + numKeywords = keywordSet.size(); + String[] keywords = keywordSet.toArray(String[]::new); + List reqs = new ArrayList<>(); + for (int i = 0; i < numDocs; i++) { + float metric = randomFloat(); + String keyword = keywords[randomIntBetween(0, numKeywords - 1)]; + reqs.add(prepareIndex(indexName).setSource("terms", keyword, "metric", metric)); + } + indexRandom(true, true, reqs); + } + + /** + * This test spins up a set of poller threads that repeatedly call + * {@code _async_search/{id}}. Each poller starts immediately, and once enough + * requests have been issued they signal a latch to indicate the group is "warmed up". + * The test waits on this latch to deterministically ensure pollers are active. + * In parallel, a consumer thread drives the async search to completion using the + * blocking iterator. This coordinated overlap exercises the window where the task + * is closing and some status calls may return {@code 410 GONE}. + */ + public void testConcurrentStatusFetchWhileTaskCloses() throws Exception { + final TimeValue timeout = TimeValue.timeValueSeconds(3); + final String aggName = "terms"; + final SearchSourceBuilder source = new SearchSourceBuilder().aggregation( + AggregationBuilders.terms(aggName).field("terms.keyword").size(Math.max(1, numKeywords)) + ); + + final int progressStep = (numShards > 2) ? randomIntBetween(2, numShards) : 2; + try (SearchResponseIterator it = assertBlockingIterator(indexName, numShards, source, 0, progressStep)) { + String id = getAsyncId(it); + + PollStats stats = new PollStats(); + + // Pick a random number of status-poller threads, at least 1, up to (4×numShards) + int pollerThreads = randomIntBetween(1, 4 * numShards); + + // Wait for pollers to be active + CountDownLatch warmed = new CountDownLatch(1); + + // Executor and coordination for pollers + ExecutorService pollerExec = Executors.newFixedThreadPool(pollerThreads); + AtomicBoolean running = new AtomicBoolean(true); + Queue failures = new ConcurrentLinkedQueue<>(); + + CompletableFuture pollers = createPollers(id, pollerThreads, stats, warmed, pollerExec, running, failures); + + // Wait until pollers are issuing requests (warming period) + assertTrue("pollers did not warm up in time", warmed.await(timeout.millis(), TimeUnit.MILLISECONDS)); + + // Start consumer on a separate thread and capture errors + var consumerExec = Executors.newSingleThreadExecutor(); + AtomicReference consumerError = new AtomicReference<>(); + Future consumer = consumerExec.submit(() -> { + try { + consumeAllResponses(it, aggName); + } catch (Throwable t) { + consumerError.set(t); + } + }); + + // Join consumer & surface errors + try { + consumer.get(timeout.millis(), TimeUnit.MILLISECONDS); + + if (consumerError.get() != null) { + fail("consumeAllResponses failed: " + consumerError.get()); + } + } catch (TimeoutException e) { + consumer.cancel(true); + fail(e, "Consumer thread did not finish within timeout"); + } catch (Exception ignored) { + // ignored + } finally { + // Stop pollers + running.set(false); + try { + pollers.get(timeout.millis(), TimeUnit.MILLISECONDS); + } catch (TimeoutException te) { + // The finally block will shut down the pollers forcibly + } catch (ExecutionException ee) { + failures.add(ExceptionsHelper.unwrapCause(ee.getCause())); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } finally { + pollerExec.shutdownNow(); + try { + pollerExec.awaitTermination(timeout.millis(), TimeUnit.MILLISECONDS); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + fail("Interrupted while stopping pollers: " + ie.getMessage()); + } + } + + // Shut down the consumer executor + consumerExec.shutdown(); + try { + consumerExec.awaitTermination(timeout.millis(), TimeUnit.MILLISECONDS); + } catch (InterruptedException ie) { + Thread.currentThread().interrupt(); + } + } + + assertNoWorkerFailures(failures); + assertStats(stats); + } + } + + private void assertNoWorkerFailures(Queue failures) { + assertTrue( + "Unexpected worker failures:\n" + failures.stream().map(ExceptionsHelper::stackTrace).reduce("", (a, b) -> a + "\n---\n" + b), + failures.isEmpty() + ); + } + + private void assertStats(PollStats stats) { + assertEquals(stats.totalCalls.sum(), stats.runningResponses.sum() + stats.completedResponses.sum()); + assertEquals("There should be no exceptions other than GONE", 0, stats.exceptions.sum()); + } + + private String getAsyncId(SearchResponseIterator it) { + AsyncSearchResponse response = it.next(); + try { + assertNotNull(response.getId()); + return response.getId(); + } finally { + response.decRef(); + } + } + + private void consumeAllResponses(SearchResponseIterator it, String aggName) throws Exception { + while (it.hasNext()) { + AsyncSearchResponse response = it.next(); + try { + if (response.getSearchResponse() != null && response.getSearchResponse().getAggregations() != null) { + assertNotNull(response.getSearchResponse().getAggregations().get(aggName)); + } + } finally { + response.decRef(); + } + } + } + + private CompletableFuture createPollers( + String id, + int threads, + PollStats stats, + CountDownLatch warmed, + ExecutorService pollerExec, + AtomicBoolean running, + Queue failures + ) { + @SuppressWarnings("unchecked") + final CompletableFuture[] tasks = IntStream.range(0, threads).mapToObj(i -> CompletableFuture.runAsync(() -> { + while (running.get()) { + AsyncSearchResponse resp = null; + try { + resp = getAsyncSearch(id); + stats.totalCalls.increment(); + + // Once enough requests have been sent, consider pollers "warmed". + if (stats.totalCalls.sum() >= threads) { + warmed.countDown(); + } + + if (resp.isRunning()) { + stats.runningResponses.increment(); + } else { + // Success-only assertions: if reported completed, we must have a proper search response + assertNull("Async search reported completed with failure", resp.getFailure()); + assertNotNull("Completed async search must carry a SearchResponse", resp.getSearchResponse()); + assertNotNull("Completed async search must have aggregations", resp.getSearchResponse().getAggregations()); + assertNotNull( + "Completed async search must contain the expected aggregation", + resp.getSearchResponse().getAggregations().get("terms") + ); + stats.completedResponses.increment(); + } + } catch (Exception e) { + Throwable cause = ExceptionsHelper.unwrapCause(e); + if (cause instanceof ElasticsearchStatusException) { + RestStatus status = ExceptionsHelper.status(cause); + if (status == RestStatus.GONE) { + stats.gone410.increment(); + } else { + stats.exceptions.increment(); + failures.add(cause); + } + } else { + stats.exceptions.increment(); + failures.add(cause); + } + } finally { + if (resp != null) { + resp.decRef(); + } + } + } + }, pollerExec).whenComplete((v, ex) -> { + if (ex != null) { + failures.add(ExceptionsHelper.unwrapCause(ex)); + } + })).toArray(CompletableFuture[]::new); + + return CompletableFuture.allOf(tasks); + } + + static final class PollStats { + final LongAdder totalCalls = new LongAdder(); + final LongAdder runningResponses = new LongAdder(); + final LongAdder completedResponses = new LongAdder(); + final LongAdder exceptions = new LongAdder(); + final LongAdder gone410 = new LongAdder(); + } +} diff --git a/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchRefcountAndFallbackTests.java b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchRefcountAndFallbackTests.java new file mode 100644 index 0000000000000..f17be5da89b70 --- /dev/null +++ b/x-pack/plugin/async-search/src/internalClusterTest/java/org/elasticsearch/xpack/search/AsyncSearchRefcountAndFallbackTests.java @@ -0,0 +1,254 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ +package org.elasticsearch.xpack.search; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.DocWriteResponse; +import org.elasticsearch.action.delete.DeleteResponse; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.action.search.ShardSearchFailure; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.lucene.Lucene; +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.search.SearchHits; +import org.elasticsearch.search.aggregations.AggregationReduceContext; +import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESSingleNodeTestCase; +import org.elasticsearch.test.client.NoOpClient; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.async.AsyncExecutionId; +import org.elasticsearch.xpack.core.async.AsyncTaskIndexService; +import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse; +import org.junit.After; +import org.junit.Before; + +import java.util.Map; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class AsyncSearchRefcountAndFallbackTests extends ESSingleNodeTestCase { + + private AsyncTaskIndexService store; + public String index = ".async-search"; + + private TestThreadPool threadPool; + private NoOpClient client; + + final int totalShards = 1; + final int skippedShards = 0; + + @Before + public void setup() { + this.threadPool = new TestThreadPool(getTestName()); + this.client = new NoOpClient(threadPool); + + ClusterService clusterService = getInstanceFromNode(ClusterService.class); + TransportService transportService = getInstanceFromNode(TransportService.class); + BigArrays bigArrays = getInstanceFromNode(BigArrays.class); + this.store = new AsyncTaskIndexService<>( + index, + clusterService, + transportService.getThreadPool().getThreadContext(), + client(), + "test_origin", + AsyncSearchResponse::new, + writableRegistry(), + bigArrays + ); + } + + @After + public void cleanup() { + terminate(threadPool); + } + + /** + * If another thread still holds a ref to the in-memory MutableSearchResponse, + * building an AsyncSearchResponse succeeds (final SearchResponse is still live). + */ + public void testBuildSucceedsIfAnotherThreadHoldsRef() { + // Build a SearchResponse (sr refCount -> 1) + SearchResponse searchResponse = createSearchResponse(totalShards, totalShards, skippedShards); + + // Take a ref - (msr refCount -> 1, sr refCount -> 2) + MutableSearchResponse msr = new MutableSearchResponse(threadPool.getThreadContext()); + msr.updateShardsAndClusters(totalShards, skippedShards, null); + msr.updateFinalResponse(searchResponse, false); + + searchResponse.decRef(); // sr refCount -> 1 + + // Simulate another thread : take a resource (msr refCount -> 2) + msr.incRef(); + // close resource (msr refCount -> 1) -> closeInternal not called yet + msr.decRef(); + + // Build a response + AsyncSearchResponse resp = msr.toAsyncSearchResponse(createAsyncSearchTask(), System.currentTimeMillis() + 60_000, false); + try { + assertNotNull("Expect SearchResponse when a live ref prevents close", resp.getSearchResponse()); + assertNull("No failure expected while ref is held", resp.getFailure()); + assertFalse("Response should not be marked running", resp.isRunning()); + } finally { + resp.decRef(); + } + + // Release msr (msr refCount -> 0, sr refCount -> 0) -> now calling closeInternal + msr.decRef(); + } + + /** + * When the in-memory MutableSearchResponse has been released (tryIncRef == false), + * the task falls back to the async-search store and returns the persisted response (200 OK). + */ + public void testFallbackToStoreWhenInMemoryResponseReleased() throws Exception { + // Create an AsyncSearchTask + AsyncSearchTask task = createAsyncSearchTask(); + assertNotNull(task); + + // Build a SearchResponse (ssr refCount -> 1) to be stored in the index + SearchResponse storedSearchResponse = createSearchResponse(totalShards, totalShards, skippedShards); + + // Take a ref - (msr refCount -> 1, ssr refCount -> 2) + MutableSearchResponse msr = task.getSearchResponse(); + msr.updateShardsAndClusters(totalShards, skippedShards, /*clusters*/ null); + msr.updateFinalResponse(storedSearchResponse, /*ccsMinimizeRoundtrips*/ false); + + // Build the AsyncSearchResponse to persist in the store + AsyncSearchResponse asyncSearchResponse = null; + try { + long now = System.currentTimeMillis(); + asyncSearchResponse = new AsyncSearchResponse( + task.getExecutionId().getEncoded(), + storedSearchResponse, + /*failure*/ null, + /*isPartial*/ false, + /*isRunning*/ false, + task.getStartTime(), + now + TimeValue.timeValueMinutes(1).millis() + ); + } finally { + if (asyncSearchResponse != null) { + // (ssr refCount -> 2, asr refCount -> 0) + asyncSearchResponse.decRef(); + } + } + + // Persist initial/final response to the store while we still hold a ref + PlainActionFuture write = new PlainActionFuture<>(); + store.createResponse(task.getExecutionId().getDocId(), Map.of(), asyncSearchResponse, write); + write.actionGet(); + + // Release the in-memory objects so the task path must fall back to the store + // - drop our extra ref to the SearchResponse that updateFinalResponse() took (sr -> 1) + storedSearchResponse.decRef(); + // - drop the in-memory MutableSearchResponse so mutableSearchResponse.tryIncRef() == false + // msr -> 0 (closeInternal runs, releasing its ssr) → tryIncRef() will now fail + msr.decRef(); + + PlainActionFuture future = new PlainActionFuture<>(); + task.getResponse(future); + + AsyncSearchResponse resp = future.actionGet(); + assertNotNull("Expected response loaded from store", resp); + assertNull("No failure expected when loaded from store", resp.getFailure()); + assertNotNull("SearchResponse must be present", resp.getSearchResponse()); + assertFalse("Should not be running", resp.isRunning()); + assertFalse("Should not be partial", resp.isPartial()); + assertEquals(RestStatus.OK, resp.status()); + } + + /** + * When both the in-memory MutableSearchResponse has been released AND the stored + * document has been deleted or not found, the task returns GONE (410). + */ + public void testGoneWhenInMemoryReleasedAndStoreMissing() throws Exception { + AsyncSearchTask task = createAsyncSearchTask(); + + SearchResponse searchResponse = createSearchResponse(totalShards, totalShards, skippedShards); + MutableSearchResponse msr = task.getSearchResponse(); + msr.updateShardsAndClusters(totalShards, skippedShards, null); + msr.updateFinalResponse(searchResponse, false); + + long now = System.currentTimeMillis(); + AsyncSearchResponse asr = new AsyncSearchResponse( + task.getExecutionId().getEncoded(), + searchResponse, + null, + false, + false, + task.getStartTime(), + now + TimeValue.timeValueMinutes(1).millis() + ); + asr.decRef(); + + PlainActionFuture write = new PlainActionFuture<>(); + store.createResponse(task.getExecutionId().getDocId(), Map.of(), asr, write); + write.actionGet(); + + searchResponse.decRef(); + msr.decRef(); + + // Delete the doc from store + PlainActionFuture del = new PlainActionFuture<>(); + store.deleteResponse(task.getExecutionId(), del); + del.actionGet(); + + // Now the task must surface GONE + PlainActionFuture future = new PlainActionFuture<>(); + task.getResponse(future); + + Exception ex = expectThrows(Exception.class, future::actionGet); + Throwable cause = ExceptionsHelper.unwrapCause(ex); + assertThat(cause, instanceOf(ElasticsearchStatusException.class)); + assertThat(ExceptionsHelper.status(cause), is(RestStatus.GONE)); + } + + private AsyncSearchTask createAsyncSearchTask() { + return new AsyncSearchTask( + 1L, + "search", + "indices:data/read/search", + TaskId.EMPTY_TASK_ID, + () -> "debug", + TimeValue.timeValueMinutes(1), + Map.of(), + Map.of(), + new AsyncExecutionId("debug", new TaskId("node", 1L)), + client, + threadPool, + isCancelled -> () -> new AggregationReduceContext.ForFinal(null, null, null, null, null, PipelineAggregator.PipelineTree.EMPTY), + store + ); + } + + private SearchResponse createSearchResponse(int totalShards, int successfulShards, int skippedShards) { + SearchResponse.Clusters clusters = new SearchResponse.Clusters(1, 1, 0); + return new SearchResponse( + SearchHits.empty(Lucene.TOTAL_HITS_GREATER_OR_EQUAL_TO_ZERO, Float.NaN), + null, + null, + false, + false, + null, + 0, + null, + totalShards, + successfulShards, + skippedShards, + 1L, + ShardSearchFailure.EMPTY_ARRAY, + clusters + ); + } +} diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java index d14b60f7b77f8..7eaf906ab13cf 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/AsyncSearchTask.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.search; +import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TotalHits; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; @@ -23,9 +24,10 @@ import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.action.search.TransportSearchAction; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.core.Releasable; -import org.elasticsearch.core.Releasables; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.search.SearchShardTarget; import org.elasticsearch.search.aggregations.AggregationReduceContext; import org.elasticsearch.search.aggregations.InternalAggregations; @@ -36,6 +38,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.async.AsyncExecutionId; import org.elasticsearch.xpack.core.async.AsyncTask; +import org.elasticsearch.xpack.core.async.AsyncTaskIndexService; import org.elasticsearch.xpack.core.search.action.AsyncSearchResponse; import org.elasticsearch.xpack.core.search.action.AsyncStatusResponse; @@ -54,6 +57,8 @@ * Task that tracks the progress of a currently running {@link SearchRequest}. */ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable { + private final Logger logger = Loggers.getLogger(getClass(), "async"); + private final AsyncExecutionId searchId; private final Client client; private final ThreadPool threadPool; @@ -74,6 +79,8 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable private final MutableSearchResponse searchResponse; + private final AsyncTaskIndexService store; + /** * Creates an instance of {@link AsyncSearchTask}. * @@ -100,7 +107,8 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable AsyncExecutionId searchId, Client client, ThreadPool threadPool, - Function, Supplier> aggReduceContextSupplierFactory + Function, Supplier> aggReduceContextSupplierFactory, + AsyncTaskIndexService store ) { super(id, type, action, () -> "async_search{" + descriptionSupplier.get() + "}", parentTaskId, taskHeaders); this.expirationTimeMillis = getStartTime() + keepAlive.getMillis(); @@ -112,6 +120,7 @@ final class AsyncSearchTask extends SearchTask implements AsyncTask, Releasable this.progressListener = new Listener(); setProgressListener(progressListener); searchResponse = new MutableSearchResponse(threadPool.getThreadContext()); + this.store = store; } /** @@ -203,7 +212,7 @@ public boolean addCompletionListener(ActionListener listene } } if (executeImmediately) { - ActionListener.respondAndRelease(listener, getResponseWithHeaders()); + getResponseWithHeaders(listener); } return true; // unused } @@ -222,12 +231,38 @@ public void addCompletionListener(Consumer listener) { } } if (executeImmediately) { - var response = getResponseWithHeaders(); - try { - listener.accept(response); - } finally { - response.decRef(); - } + getResponseWithHeaders(ActionListener.wrap(resp -> { listener.accept(resp); }, e -> { + ElasticsearchException ex = (e instanceof ElasticsearchException) + ? (ElasticsearchException) e + : ExceptionsHelper.convertToElastic(e); + ex.setStackTrace(new StackTraceElement[0]); + + // We are in the failure path where no MutableSearchResponse can be produced or retained. + // This happens only after the in-memory reference has been released and the stored response + // cannot be retrieved or retained (tryIncRef() failed or the doc is unavailable). At this point + // it is unsafe to call searchResponse.toAsyncSearchResponse(...): the underlying + // MutableSearchResponse is ref-counted and may already be closed (refcount == 0). + // + // Instead, we synthesize a minimal AsyncSearchResponse that carries the exception back to + // the caller. We set isRunning=false because both the in-memory state and stored document + // are unavailable and no further work can be observed. For isPartial we use false as a + // conservative choice: partial results may have existed earlier, but they are no longer + // accessible and we must not imply that any partial data is included. + AsyncSearchResponse failureResponse = new AsyncSearchResponse( + searchId.getEncoded(), + null, + ex, + false, + false, + getStartTime(), + expirationTimeMillis + ); + try { + listener.accept(failureResponse); + } finally { + failureResponse.decRef(); + } + })); } } @@ -246,7 +281,7 @@ private void internalAddCompletionListener(ActionListener l if (hasRun.compareAndSet(false, true)) { // timeout occurred before completion removeCompletionListener(id); - ActionListener.respondAndRelease(listener, getResponseWithHeaders()); + getResponseWithHeaders(listener); } }, waitForCompletion, threadPool.generic()); } catch (Exception exc) { @@ -263,7 +298,7 @@ private void internalAddCompletionListener(ActionListener l } } if (executeImmediately) { - ActionListener.respondAndRelease(listener, getResponseWithHeaders()); + getResponseWithHeaders(listener); } } @@ -314,47 +349,110 @@ private void executeCompletionListeners() { } // we don't need to restore the response headers, they should be included in the current // context since we are called by the search action listener. - AsyncSearchResponse finalResponse = getResponse(); - try { + getResponse(ActionListener.wrap(finalResponse -> { for (Consumer consumer : completionsListenersCopy.values()) { consumer.accept(finalResponse); } - } finally { - finalResponse.decRef(); - } + }, e -> { + ElasticsearchException ex = (e instanceof ElasticsearchException) + ? (ElasticsearchException) e + : ExceptionsHelper.convertToElastic(e); + ex.setStackTrace(new StackTraceElement[0]); + + // We are in the failure path where no MutableSearchResponse can be produced or retained. + // This happens only after the in-memory reference has been released and the stored response + // cannot be retrieved or retained (tryIncRef() failed or the doc is unavailable). At this point + // it is unsafe to call searchResponse.toAsyncSearchResponse(...): the underlying + // MutableSearchResponse is ref-counted and may already be closed (refcount == 0). + // + // Instead, we synthesize a minimal AsyncSearchResponse that carries the exception back to + // the caller. We set isRunning=false because both the in-memory state and stored document + // are unavailable and no further work can be observed. For isPartial we use false as a + // conservative choice: partial results may have existed earlier, but they are no longer + // accessible and we must not imply that any partial data is included. + AsyncSearchResponse failureResponse = new AsyncSearchResponse( + searchId.getEncoded(), + null, + ex, + false, + false, + getStartTime(), + expirationTimeMillis + ); + try { + for (Consumer consumer : completionsListenersCopy.values()) { + consumer.accept(failureResponse); + } + } finally { + failureResponse.decRef(); + } + })); } /** - * Returns the current {@link AsyncSearchResponse}. + * Invokes the listener with the current {@link AsyncSearchResponse} + * without restoring response headers into the calling thread context. + * + * Visible for testing */ - private AsyncSearchResponse getResponse() { - return getResponse(false); + void getResponse(ActionListener listener) { + getResponse(false, listener); } /** - * Returns the current {@link AsyncSearchResponse} and restores the response headers - * in the local thread context. + * Invokes the listener with the current {@link AsyncSearchResponse}, + * restoring response headers into the calling thread context. */ - private AsyncSearchResponse getResponseWithHeaders() { - return getResponse(true); + private void getResponseWithHeaders(ActionListener listener) { + getResponse(true, listener); } - private AsyncSearchResponse getResponse(boolean restoreResponseHeaders) { - MutableSearchResponse mutableSearchResponse = searchResponse; + private void getResponse(boolean restoreResponseHeaders, ActionListener listener) { + final MutableSearchResponse mutableSearchResponse = searchResponse; assert mutableSearchResponse != null; checkCancellation(); - AsyncSearchResponse asyncSearchResponse; + + // Fallback: fetch from store asynchronously + if (mutableSearchResponse.tryIncRef() == false) { + store.getResponse(searchId, restoreResponseHeaders, ActionListener.wrap(resp -> { + if (resp.tryIncRef() == false) { + listener.onFailure(new ElasticsearchStatusException("Async search: result no longer available", RestStatus.GONE)); + return; + } + ActionListener.respondAndRelease(listener, resp); + }, e -> { + final Exception unwrapped = (Exception) ExceptionsHelper.unwrapCause(e); + listener.onFailure( + new ElasticsearchStatusException("Async search: result no longer available", RestStatus.GONE, unwrapped) + ); + })); + return; + } + + // At this point we successfully incremented the ref on the MutableSearchResponse. + // This guarantees that the underlying SearchResponse it wraps will not be closed + // while we are using it, so calling toAsyncSearchResponse(..) is safe. try { - asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis, restoreResponseHeaders); - } catch (Exception e) { - ElasticsearchException exception = new ElasticsearchStatusException( - "Async search: error while reducing partial results", - ExceptionsHelper.status(e), - e - ); - asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis, exception); + AsyncSearchResponse asyncSearchResponse; + try { + asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis, restoreResponseHeaders); + } catch (Exception e) { + final ElasticsearchException ex = new ElasticsearchStatusException( + "Async search: error while reducing partial results", + ExceptionsHelper.status(e), + e + ); + asyncSearchResponse = mutableSearchResponse.toAsyncSearchResponse(this, expirationTimeMillis, ex); + } + ActionListener.respondAndRelease(listener, asyncSearchResponse); + } finally { + mutableSearchResponse.decRef(); } - return asyncSearchResponse; + } + + // Visible for testing. + MutableSearchResponse getSearchResponse() { + return searchResponse; } // checks if the search task should be cancelled @@ -381,7 +479,17 @@ public static AsyncStatusResponse getStatusResponse(AsyncSearchTask asyncTask) { @Override public void close() { - Releasables.close(searchResponse); + if (logger.isDebugEnabled()) { + logger.debug( + "AsyncSearchTask.close(): byThread={}, asyncId={}, taskId={}, hasCompleted={}, stack={}", + Thread.currentThread().getName(), + searchId != null ? searchId.getEncoded() : "", + getId(), + hasCompleted, + new Exception().getStackTrace() + ); + } + searchResponse.decRef(); } class Listener extends SearchProgressActionListener { diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java index 11ff403237888..e09a4722b31f4 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/MutableSearchResponse.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.search; +import org.apache.logging.log4j.Logger; import org.apache.lucene.search.TotalHits; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; @@ -14,10 +15,11 @@ import org.elasticsearch.action.search.SearchResponseMerger; import org.elasticsearch.action.search.ShardSearchFailure; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.logging.Loggers; import org.elasticsearch.common.lucene.Lucene; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.core.Releasable; +import org.elasticsearch.core.AbstractRefCounted; import org.elasticsearch.core.TimeValue; import org.elasticsearch.search.SearchHits; import org.elasticsearch.search.aggregations.InternalAggregations; @@ -38,7 +40,10 @@ * creating an async response concurrently. This limits the number of final reduction that can * run concurrently to 1 and ensures that we pause the search progress when an {@link AsyncSearchResponse} is built. */ -class MutableSearchResponse implements Releasable { +class MutableSearchResponse extends AbstractRefCounted { + + private final Logger logger = Loggers.getLogger(getClass(), "async"); + private int totalShards; private int skippedShards; private Clusters clusters; @@ -85,6 +90,7 @@ class MutableSearchResponse implements Releasable { * @param threadContext The thread context to retrieve the final response headers. */ MutableSearchResponse(ThreadContext threadContext) { + super(); this.isPartial = true; this.threadContext = threadContext; this.totalHits = Lucene.TOTAL_HITS_GREATER_OR_EQUAL_TO_ZERO; @@ -487,14 +493,28 @@ private String getShardsInResponseMismatchInfo(SearchResponse response, boolean } @Override - public synchronized void close() { + protected synchronized void closeInternal() { + + if (logger.isDebugEnabled()) { + logger.debug( + "MutableSearchResponse.close(): byThread={}, finalResponsePresent={}, clusterResponsesCount={}, stack={}", + Thread.currentThread().getName(), + finalResponse != null, + clusterResponses != null ? clusterResponses.size() : 0, + new Exception().getStackTrace() + ); + } + if (finalResponse != null) { finalResponse.decRef(); + finalResponse = null; } if (clusterResponses != null) { for (SearchResponse clusterResponse : clusterResponses) { clusterResponse.decRef(); } + clusterResponses.clear(); + clusterResponses = null; } } } diff --git a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java index ad648af2c5571..16aaac68754a6 100644 --- a/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java +++ b/x-pack/plugin/async-search/src/main/java/org/elasticsearch/xpack/search/TransportSubmitAsyncSearchAction.java @@ -207,7 +207,8 @@ public AsyncSearchTask createTask(long id, String type, String action, TaskId pa store.getClientWithOrigin(), nodeClient.threadPool(), isCancelled -> () -> searchService.aggReduceContextBuilder(isCancelled, originalSearchRequest.source().aggregations()) - .forFinalReduction() + .forFinalReduction(), + store ); } }; diff --git a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java index ab1b6189c0133..70be3a506ddba 100644 --- a/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java +++ b/x-pack/plugin/async-search/src/test/java/org/elasticsearch/xpack/search/AsyncSearchTaskTests.java @@ -91,7 +91,8 @@ private AsyncSearchTask createAsyncSearchTask() { new AsyncExecutionId("0", new TaskId("node1", 1)), new NoOpClient(threadPool), threadPool, - (t) -> () -> null + (t) -> () -> null, + null ); } @@ -112,7 +113,8 @@ public void testTaskDescription() { new AsyncExecutionId("0", new TaskId("node1", 1)), new NoOpClient(threadPool), threadPool, - (t) -> () -> null + (t) -> () -> null, + null ) ) { assertEquals(""" @@ -135,7 +137,8 @@ public void testWaitForInit() throws InterruptedException { new AsyncExecutionId("0", new TaskId("node1", 1)), new NoOpClient(threadPool), threadPool, - (t) -> () -> null + (t) -> () -> null, + null ) ) { int numShards = randomIntBetween(0, 10);