diff --git a/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandCursor.java b/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandCursor.java index 91286bd520..6af8b9ec67 100644 --- a/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandCursor.java +++ b/driver-core/src/main/com/mongodb/internal/operation/AsyncCommandCursor.java @@ -108,11 +108,12 @@ public void next(final OperationContext operationContext, final SingleResultCall commandCursorResult = withEmptyResults(commandCursorResult); funcCallback.onResult(batchResults, null); } else { - getMore(localServerCursor, operationContext, funcCallback); + getMoreLoop(localServerCursor, operationContext, funcCallback); } }, operationContext, callback); } + @Override public boolean isClosed() { return !resourceManager.operable(); @@ -162,14 +163,32 @@ public int getMaxWireVersion() { return maxWireVersion; } + private void getMoreLoop(final ServerCursor localServerCursor, + final OperationContext operationContext, + final SingleResultCallback> funcCallback) { + getMore(localServerCursor, operationContext, (nextBatch, t) -> { + if (t != null) { + funcCallback.onResult(null, t); + } else if (resourceManager.getServerCursor() == null || (nextBatch != null && !nextBatch.isEmpty())) { + commandCursorResult = withEmptyResults(commandCursorResult); + funcCallback.onResult(nextBatch, null); + } else if (!resourceManager.operable()) { + funcCallback.onResult(emptyList(), null); + } else { + getMoreLoop(assertNotNull(resourceManager.getServerCursor()), operationContext, funcCallback); + } + }); + } + private void getMore(final ServerCursor cursor, final OperationContext operationContext, final SingleResultCallback> callback) { resourceManager.executeWithConnection(operationContext, (connection, wrappedCallback) -> - getMoreLoop(assertNotNull(connection), cursor, operationContext, wrappedCallback), callback); + executeGetMoreCommand(assertNotNull(connection), cursor, operationContext, wrappedCallback), callback); } - private void getMoreLoop(final AsyncConnection connection, final ServerCursor serverCursor, - final OperationContext operationContext, - final SingleResultCallback> callback) { + private void executeGetMoreCommand(final AsyncConnection connection, + final ServerCursor serverCursor, + final OperationContext operationContext, + final SingleResultCallback> callback) { connection.commandAsync(namespace.getDatabaseName(), getMoreCommandDocument(serverCursor.getId(), connection.getDescription(), namespace, batchSize, comment), NoOpFieldNameValidator.INSTANCE, ReadPreference.primary(), @@ -188,19 +207,7 @@ private void getMoreLoop(final AsyncConnection connection, final ServerCursor se connection.getDescription().getServerAddress(), NEXT_BATCH, assertNotNull(commandResult)); ServerCursor nextServerCursor = commandCursorResult.getServerCursor(); resourceManager.setServerCursor(nextServerCursor); - List nextBatch = commandCursorResult.getResults(); - if (nextServerCursor == null || !nextBatch.isEmpty()) { - commandCursorResult = withEmptyResults(commandCursorResult); - callback.onResult(nextBatch, null); - return; - } - - if (!resourceManager.operable()) { - callback.onResult(emptyList(), null); - return; - } - - getMoreLoop(connection, nextServerCursor, operationContext, callback); + callback.onResult(commandCursorResult.getResults(), null); }); } diff --git a/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandCursorTest.java b/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandCursorTest.java index 464e817d60..6d2ef649bc 100644 --- a/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandCursorTest.java +++ b/driver-core/src/test/functional/com/mongodb/internal/operation/AsyncCommandCursorTest.java @@ -41,14 +41,18 @@ import org.bson.Document; import org.bson.codecs.Decoder; import org.bson.codecs.DocumentCodec; -import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import java.time.Duration; +import java.util.concurrent.atomic.AtomicInteger; import static com.mongodb.internal.operation.OperationUnitSpecification.getMaxWireVersionForServerVersion; import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -80,11 +84,9 @@ class AsyncCommandCursorTest { private OperationContext operationContext; private TimeoutContext timeoutContext; private ServerDescription serverDescription; - private AsyncCursor coreCursor; @BeforeEach void setUp() { - coreCursor = mock(AsyncCursor.class); timeoutContext = spy(new TimeoutContext(TimeoutSettings.create( MongoClientSettings.builder().timeout(TIMEOUT.toMillis(), MILLISECONDS).build()))); operationContext = spy(new OperationContext( @@ -105,7 +107,7 @@ void setUp() { serverDescription = mock(ServerDescription.class); when(operationContext.getTimeoutContext()).thenReturn(timeoutContext); doAnswer(invocation -> { - SingleResultCallback callback = invocation.getArgument(0); + SingleResultCallback callback = invocation.getArgument(1); callback.onResult(mockConnection, null); return null; }).when(connectionSource).getConnection(any(), any()); @@ -126,9 +128,9 @@ void shouldSkipKillsCursorsCommandWhenNetworkErrorOccurs() { //when commandBatchCursor.next(operationContext, (result, t) -> { - Assertions.assertNull(result); - Assertions.assertNotNull(t); - Assertions.assertEquals(MongoSocketException.class, t.getClass()); + assertNull(result); + assertNotNull(t); + assertEquals(MongoSocketException.class, t.getClass()); }); //then @@ -151,9 +153,9 @@ void shouldNotSkipKillsCursorsCommandWhenTimeoutExceptionDoesNotHaveNetworkError //when commandBatchCursor.next(operationContext, (result, t) -> { - Assertions.assertNull(result); - Assertions.assertNotNull(t); - Assertions.assertEquals(MongoOperationTimeoutException.class, t.getClass()); + assertNull(result); + assertNotNull(t); + assertEquals(MongoOperationTimeoutException.class, t.getClass()); }); commandBatchCursor.close(operationContext); @@ -182,9 +184,9 @@ void shouldSkipKillsCursorsCommandWhenTimeoutExceptionHaveNetworkErrorCause() { //when commandBatchCursor.next(operationContext, (result, t) -> { - Assertions.assertNull(result); - Assertions.assertNotNull(t); - Assertions.assertEquals(MongoOperationTimeoutException.class, t.getClass()); + assertNull(result); + assertNotNull(t); + assertEquals(MongoOperationTimeoutException.class, t.getClass()); }); commandBatchCursor.close(operationContext); @@ -199,6 +201,33 @@ void shouldSkipKillsCursorsCommandWhenTimeoutExceptionHaveNetworkErrorCause() { } + @Test + void shouldReleaseConnectionBetweenEmptyGetMoreResponses() { + AtomicInteger callCount = new AtomicInteger(); + doAnswer(invocation -> { + SingleResultCallback cb = invocation.getArgument(6); + cb.onResult(new BsonDocument("cursor", + new BsonDocument("ns", new BsonString(NAMESPACE.getFullName())) + .append("id", new BsonInt64(callCount.incrementAndGet() < 3 ? 1 : 0)) + .append("nextBatch", new BsonArrayWrapper<>(new BsonArray()))), null); + return null; + }).when(mockConnection).commandAsync(eq(NAMESPACE.getDatabaseName()), + argThat(doc -> doc.containsKey("getMore")), any(), any(), any(), any(), any()); + + when(serverDescription.getType()).thenReturn(ServerType.STANDALONE); + createBatchCursor().next(operationContext, (result, t) -> { + assertNotNull(result); + assertTrue(result.isEmpty()); + assertNull(t); + }); + + // 2 empty-batch getMores + 1 exhausted getMore = 3 getMores, but the 3rd + // exhausts the cursor (id=0), which makes the cursor break the loop and return an empty result. + verify(mockConnection, times(3)).release(); + verify(connectionSource, times(3)).getConnection(any(), any()); + assertEquals(3, callCount.get()); + } + private AsyncCursor createBatchCursor() { return new AsyncCommandCursor<>( COMMAND_CURSOR_DOCUMENT, diff --git a/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncCommandBatchCursorSpecification.groovy b/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncCommandBatchCursorSpecification.groovy index f0b73f24fe..b9884e3481 100644 --- a/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncCommandBatchCursorSpecification.groovy +++ b/driver-core/src/test/unit/com/mongodb/internal/operation/AsyncCommandBatchCursorSpecification.groovy @@ -167,8 +167,9 @@ class AsyncCommandBatchCursorSpecification extends Specification { def 'should handle getMore when there are empty results but there is a cursor'() { given: def initialConnection = referenceCountedAsyncConnection() - def connection = referenceCountedAsyncConnection() - def connectionSource = getAsyncConnectionSource(connection) + def connectionA = referenceCountedAsyncConnection() + def connectionB = referenceCountedAsyncConnection() + def connectionSource = getAsyncConnectionSource(connectionA, connectionB) when: def firstBatch = createCommandResult([], CURSOR_ID) @@ -177,14 +178,15 @@ class AsyncCommandBatchCursorSpecification extends Specification { def batch = nextBatch(cursor) then: - 1 * connection.commandAsync(*_) >> { - connection.getCount() == 1 + 1 * connectionA.commandAsync(*_) >> { + connectionA.getCount() == 1 connectionSource.getCount() == 1 it.last().onResult(response, null) } - 1 * connection.commandAsync(*_) >> { - connection.getCount() == 1 + then: + 1 * connectionB.commandAsync(*_) >> { + connectionB.getCount() == 1 connectionSource.getCount() == 1 it.last().onResult(response2, null) } @@ -196,7 +198,10 @@ class AsyncCommandBatchCursorSpecification extends Specification { cursor.close() then: - 0 * connection._ + 0 * connectionA._ + 0 * connectionB._ + connectionA.getCount() == 0 + connectionB.getCount() == 0 initialConnection.getCount() == 0 connectionSource.getCount() == 0