diff --git a/ksqldb-engine/src/main/java/io/confluent/ksql/query/PullQueryWriteStream.java b/ksqldb-engine/src/main/java/io/confluent/ksql/query/PullQueryWriteStream.java index b20b476ea68f..783eca949c3f 100644 --- a/ksqldb-engine/src/main/java/io/confluent/ksql/query/PullQueryWriteStream.java +++ b/ksqldb-engine/src/main/java/io/confluent/ksql/query/PullQueryWriteStream.java @@ -28,9 +28,11 @@ import io.confluent.ksql.util.KsqlHostInfo; import io.confluent.ksql.util.RowMetadata; import io.vertx.core.AsyncResult; +import io.vertx.core.Context; import io.vertx.core.Future; import io.vertx.core.Handler; import io.vertx.core.Promise; +import io.vertx.core.Vertx; import io.vertx.core.impl.ConcurrentHashSet; import io.vertx.core.impl.future.SucceededFuture; import io.vertx.core.streams.WriteStream; @@ -97,12 +99,6 @@ public PullQueryWriteStream( ) { this.queryLimit = queryLimit; this.translator = translator; - - // register a drainHandler that will wake up anyone waiting on hasCapacity - drainHandler.add(ignored -> { - monitor.enter(); - monitor.leave(); - }); } private static final class HandledRow { @@ -211,6 +207,10 @@ private PullQueryRow pollRow() { if (monitor.enterIf(atHalfCapacity)) { try { drainHandler.forEach(h -> h.handle(null)); + // Users of this WriteStream, in particular vert.x's PipeImpl re-register the drain handler + // every time the WriteStream is full, so we can clear the drain handler collection after + // we have called it. + drainHandler.clear(); } finally { monitor.leave(); } @@ -273,7 +273,12 @@ public PullQueryWriteStream exceptionHandler(final Handler handler) { @Override public PullQueryWriteStream drainHandler(final Handler handler) { - drainHandler.add(handler); + // Make sure to run the drain handler in the context that registers the drain handler + // (that is, the pipe that is pushing data into this write stream), and not from the thread + // that is consuming data from the write queue. This will avoid data races inside the + // ReadStream. + final Context context = Vertx.currentContext(); + drainHandler.add(v -> context.runOnContext(handler)); return this; } diff --git a/ksqldb-engine/src/test/java/io/confluent/ksql/query/PullQueryWriteStreamTest.java b/ksqldb-engine/src/test/java/io/confluent/ksql/query/PullQueryWriteStreamTest.java index 6af59abc56f2..d3b96a969cdd 100644 --- a/ksqldb-engine/src/test/java/io/confluent/ksql/query/PullQueryWriteStreamTest.java +++ b/ksqldb-engine/src/test/java/io/confluent/ksql/query/PullQueryWriteStreamTest.java @@ -17,6 +17,9 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; import com.google.common.collect.ImmutableList; import io.confluent.ksql.GenericRow; @@ -24,6 +27,9 @@ import io.confluent.ksql.rest.entity.StreamedRow; import io.confluent.ksql.schema.ksql.LogicalSchema; import io.confluent.ksql.util.KeyValueMetadata; +import io.vertx.core.Context; +import io.vertx.core.Handler; +import io.vertx.core.Vertx; import java.util.List; import java.util.Optional; import java.util.OptionalInt; @@ -40,6 +46,7 @@ import org.junit.rules.Timeout; import org.junit.runner.RunWith; import org.mockito.Mock; +import org.mockito.MockedStatic; import org.mockito.junit.MockitoJUnitRunner; @RunWith(MockitoJUnitRunner.class) @@ -117,17 +124,21 @@ public void shouldHandleLimit() { @Test public void shouldCallDrainHandlerWhenHasCapacity() { // Given: - final AtomicBoolean called = new AtomicBoolean(false); + final Context context = mock(Context.class); + @SuppressWarnings("unchecked") final Handler handler = mock(Handler.class); writeStream.setWriteQueueMaxSize(1); writeStream.write(getData(1)); - writeStream.drainHandler(ignored -> called.set(true)); + try (MockedStatic mocked = mockStatic(Vertx.class)) { + mocked.when(Vertx::currentContext).thenReturn(context); + writeStream.drainHandler(handler); + } // When: writeStream.poll(); // Then: - assertThat("expected drain handler to be called", called.get()); + verify(context).runOnContext(handler); } @Test