Skip to content

Commit

Permalink
Make transport and MockTicker tests more stable (#13713)
Browse files Browse the repository at this point in the history
Motivation:
A number of flaky tests were identified.

Modification:
Address test isolation failures and resource exhaustion errors.
Also enable parallel test execution, becuase the tests can probably tolerate that now.
Improve the DefaultMockTicker test to be deterministic, by adding the ability to wait for other threads to enter the sleep method.

Result:
Tests in the transport module should now run stable enough to run in parallel.
All the tests can now also be run in repeat.
  • Loading branch information
chrisvest committed Dec 1, 2023
1 parent a0a02fd commit 438b436
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 83 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
*/
package io.netty5.util.concurrent;

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;

import static io.netty5.util.internal.ObjectUtil.checkPositiveOrZero;
Expand All @@ -29,9 +31,12 @@
*/
final class DefaultMockTicker implements MockTicker {

private final Lock lock = new ReentrantLock();
private final Condition cond = lock.newCondition();
// The lock is fair, so waiters get to process condition signals in the order they (the waiters) queued up.
private final ReentrantLock lock = new ReentrantLock(true);
private final Condition tickCondition = lock.newCondition();
private final Condition sleeperCondition = lock.newCondition();
private final AtomicLong nanoTime = new AtomicLong();
private final Set<Thread> sleepers = Collections.newSetFromMap(new IdentityHashMap<>());

@Override
public long nanoTime() {
Expand All @@ -47,13 +52,30 @@ public void sleep(long delay, TimeUnit unit) throws InterruptedException {
return;
}

final long startTimeNanos = nanoTime();
final long delayNanos = unit.toNanos(delay);
lock.lockInterruptibly();
try {
final long startTimeNanos = nanoTime();
sleepers.add(Thread.currentThread());
sleeperCondition.signalAll();
do {
cond.await();
tickCondition.await();
} while (nanoTime() - startTimeNanos < delayNanos);
} finally {
sleepers.remove(Thread.currentThread());
lock.unlock();
}
}

/**
* Wait for the given thread to enter the {@link #sleep(long, TimeUnit)} method, and block.
*/
public void awaitSleepingThread(Thread thread) throws InterruptedException {
lock.lockInterruptibly();
try {
while (!sleepers.contains(thread)) {
sleeperCondition.await();
}
} finally {
lock.unlock();
}
Expand All @@ -72,7 +94,7 @@ public void advance(long amount, TimeUnit unit) {
lock.lock();
try {
nanoTime.addAndGet(amountNanos);
cond.signalAll();
tickCondition.signalAll();
} finally {
lock.unlock();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@
package io.netty5.util.concurrent;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.FutureTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

Expand Down Expand Up @@ -65,44 +66,74 @@ void advanceWithNegativeAmount() {
});
}

@Timeout(60)
@Test
void advanceWithWaiters() throws Exception {
final MockTicker ticker = Ticker.newMockTicker();
final List<Thread> threads = new ArrayList<>();
final DefaultMockTicker ticker = (DefaultMockTicker) Ticker.newMockTicker();
final int numWaiters = 4;
final List<CompletableFuture<Void>> futures = new ArrayList<>();
final List<FutureTask<Void>> futures = new ArrayList<>();
for (int i = 0; i < numWaiters; i++) {
futures.add(CompletableFuture.runAsync(() -> {
FutureTask<Void> task = new FutureTask<>(() -> {
try {
ticker.sleep(1, TimeUnit.MILLISECONDS);
} catch (InterruptedException e) {
throw new CompletionException(e);
}
}));
}

// Time did not advance at all, and thus future will not complete.
for (int i = 0; i < numWaiters; i++) {
final int finalCnt = i;
assertThrows(TimeoutException.class, () -> {
futures.get(finalCnt).get(1, TimeUnit.SECONDS);
});
}

// Advance just one nanosecond before completion.
ticker.advance(999_999, TimeUnit.NANOSECONDS);

// Still needs one more nanosecond.
for (int i = 0; i < numWaiters; i++) {
final int finalCnt = i;
assertThrows(TimeoutException.class, () -> {
futures.get(finalCnt).get(1, TimeUnit.SECONDS);
return null;
});
Thread thread = new Thread(task);
threads.add(thread);
futures.add(task);
thread.start();
}

// Reach at the 1 millisecond mark and ensure the future is complete.
ticker.advance(1, TimeUnit.NANOSECONDS);
for (int i = 0; i < numWaiters; i++) {
futures.get(i).get();
try {
// Wait for all threads to be sleeping.
for (Thread thread : threads) {
ticker.awaitSleepingThread(thread);
}

// Time did not advance at all, and thus future will not complete.
for (int i = 0; i < numWaiters; i++) {
final int finalCnt = i;
assertThrows(TimeoutException.class, () -> {
futures.get(finalCnt).get(1, TimeUnit.MILLISECONDS);
});
}

// Advance just one nanosecond before completion.
ticker.advance(999_999, TimeUnit.NANOSECONDS);

// All threads should still be sleeping.
for (Thread thread : threads) {
ticker.awaitSleepingThread(thread);
}

// Still needs one more nanosecond for our futures.
for (int i = 0; i < numWaiters; i++) {
final int finalCnt = i;
assertThrows(TimeoutException.class, () -> {
futures.get(finalCnt).get(1, TimeUnit.MILLISECONDS);
});
}

// Reach at the 1 millisecond mark and ensure the future is complete.
ticker.advance(1, TimeUnit.NANOSECONDS);
for (int i = 0; i < numWaiters; i++) {
futures.get(i).get();
}
} catch (InterruptedException ie) {
for (Thread thread : threads) {
String name = thread.getName();
Thread.State state = thread.getState();
StackTraceElement[] stackTrace = thread.getStackTrace();
thread.interrupt();
InterruptedException threadStackTrace = new InterruptedException(name + ": " + state);
threadStackTrace.setStackTrace(stackTrace);
ie.addSuppressed(threadStackTrace);
}
throw ie;
}
}

Expand Down
14 changes: 11 additions & 3 deletions transport/src/test/java/io/netty5/bootstrap/BootstrapTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.netty5.util.concurrent.Future;
import io.netty5.util.concurrent.Promise;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

Expand Down Expand Up @@ -65,9 +66,16 @@

public class BootstrapTest {

private static final EventLoopGroup groupA = new MultithreadEventLoopGroup(1, LocalHandler.newFactory());
private static final EventLoopGroup groupB = new MultithreadEventLoopGroup(1, LocalHandler.newFactory());
private static final ChannelHandler dummyHandler = new DummyHandler();
private static EventLoopGroup groupA;
private static EventLoopGroup groupB;
private static ChannelHandler dummyHandler;

@BeforeAll
public static void setUp() {
groupA = new MultithreadEventLoopGroup(1, LocalHandler.newFactory());
groupB = new MultithreadEventLoopGroup(1, LocalHandler.newFactory());
dummyHandler = new DummyHandler();
}

@AfterAll
public static void destroy() throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import org.junit.jupiter.api.Test;

import java.util.UUID;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertSame;
Expand All @@ -27,7 +29,7 @@ public class ChannelOptionTest {

@Test
public void testExists() {
String name = "test";
String name = "test" + UUID.randomUUID();
assertFalse(ChannelOption.exists(name));
ChannelOption<String> option = ChannelOption.valueOf(name);

Expand All @@ -37,7 +39,7 @@ public void testExists() {

@Test
public void testValueOf() {
String name = "test1";
String name = "test1" + UUID.randomUUID();
assertFalse(ChannelOption.exists(name));
ChannelOption<String> option = ChannelOption.valueOf(name);
ChannelOption<String> option2 = ChannelOption.valueOf(name);
Expand All @@ -47,7 +49,7 @@ public void testValueOf() {

@Test
public void testCreateOrFail() {
String name = "test2";
String name = "test2" + UUID.randomUUID();
assertFalse(ChannelOption.exists(name));
ChannelOption<String> option = ChannelOption.newInstance(name);
assertTrue(ChannelOption.exists(name));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.hamcrest.Matchers;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

Expand Down Expand Up @@ -64,11 +65,16 @@

public class DefaultChannelPipelineTest {

private static final EventLoopGroup group = new MultithreadEventLoopGroup(1, LocalHandler.newFactory());
private static EventLoopGroup group;

private Channel self;
private Channel peer;

@BeforeAll
public static void beforeClass() {
group = new MultithreadEventLoopGroup(1, LocalHandler.newFactory());
}

@AfterAll
public static void afterClass() throws Exception {
group.shutdownGracefully().asStage().sync();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.parallel.ResourceAccessMode;
import org.junit.jupiter.api.parallel.ResourceLock;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
Expand Down Expand Up @@ -53,6 +55,7 @@
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

@ResourceLock(value = "scheduler_timing", mode = ResourceAccessMode.READ)
public class SingleThreadEventLoopTest {

private static final Runnable NOOP = () -> { };
Expand Down Expand Up @@ -146,12 +149,14 @@ private static void testScheduleTask(EventLoop loopA) throws InterruptedExceptio
is(greaterThanOrEqualTo(TimeUnit.MILLISECONDS.toNanos(500))));
}

@ResourceLock(value = "scheduler_timing", mode = ResourceAccessMode.READ_WRITE)
@Test
@Timeout(value = 5000, unit = TimeUnit.MILLISECONDS)
public void scheduleTaskAtFixedRateA() throws Exception {
testScheduleTaskAtFixedRate(loopA);
}

@ResourceLock(value = "scheduler_timing", mode = ResourceAccessMode.READ_WRITE)
@Test
@Timeout(value = 5000, unit = TimeUnit.MILLISECONDS)
public void scheduleTaskAtFixedRateB() throws Exception {
Expand All @@ -164,11 +169,6 @@ private static void testScheduleTaskAtFixedRate(EventLoop loopA) throws Interrup
final CountDownLatch allTimeStampsLatch = new CountDownLatch(expectedTimeStamps);
Future<?> f = loopA.scheduleAtFixedRate(() -> {
timestamps.add(System.nanoTime());
try {
Thread.sleep(50);
} catch (InterruptedException e) {
// Ignore
}
allTimeStampsLatch.countDown();
}, 100, 100, TimeUnit.MILLISECONDS);
allTimeStampsLatch.await();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,12 @@
import static org.junit.jupiter.api.Assertions.assertEquals;

public class LocalTransportThreadModelTest2 {

private static final String LOCAL_CHANNEL = LocalTransportThreadModelTest2.class.getName();

static final int messageCountPerRun = 4;

@Test
@Timeout(value = 15000, unit = TimeUnit.MILLISECONDS)
public void testSocketReuse() throws Exception {
LocalAddress address = new LocalAddress(LocalTransportThreadModelTest2.class);
ServerBootstrap serverBootstrap = new ServerBootstrap();
LocalHandler serverHandler = new LocalHandler("SERVER");
serverBootstrap
Expand All @@ -53,9 +51,9 @@ public void testSocketReuse() throws Exception {
clientBootstrap
.group(new MultithreadEventLoopGroup(io.netty5.channel.local.LocalHandler.newFactory()))
.channel(LocalChannel.class)
.remoteAddress(new LocalAddress(LOCAL_CHANNEL)).handler(clientHandler);
.remoteAddress(address).handler(clientHandler);

serverBootstrap.bind(new LocalAddress(LOCAL_CHANNEL)).asStage().sync();
serverBootstrap.bind(address).asStage().sync();

int count = 100;
for (int i = 1; i < count + 1; i ++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,9 @@
import java.nio.channels.SocketChannel;
import java.nio.channels.spi.SelectorProvider;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.IntSupplier;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
Expand Down Expand Up @@ -160,35 +156,6 @@ public void testSelectableChannel() throws Exception {
}
}

@Test
public void testTaskRemovalOnShutdownThrowsNoUnsupportedOperationException() throws Exception {
final AtomicReference<Throwable> error = new AtomicReference<>();
final Runnable task = () -> {
// NOOP
};
// Just run often enough to trigger it normally.
for (int i = 0; i < 1000; i++) {
EventLoopGroup group = new MultithreadEventLoopGroup(1, NioHandler.newFactory());
final EventLoop loop = group.next();

Thread t = new Thread(() -> {
try {
for (;;) {
loop.execute(task);
}
} catch (Throwable cause) {
error.set(cause);
}
});
t.start();
Future<?> termination = group.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS);
t.join();
termination.asStage().sync();
assertThat(error.get(), instanceOf(RejectedExecutionException.class));
error.set(null);
}
}

@Test
public void testRebuildSelectorOnIOException() throws Exception {
final CountDownLatch latch = new CountDownLatch(1);
Expand Down

0 comments on commit 438b436

Please sign in to comment.