From 428ff129662db9e9a37be9948acb1d3a08f7e485 Mon Sep 17 00:00:00 2001 From: Ali Ince Date: Mon, 27 Nov 2017 14:06:24 +0000 Subject: [PATCH] changed how locks are monitored in tests --- .../transport/TransportWriteThrottleTest.java | 76 +++++++++++++++---- 1 file changed, 63 insertions(+), 13 deletions(-) diff --git a/community/bolt/src/test/java/org/neo4j/bolt/transport/TransportWriteThrottleTest.java b/community/bolt/src/test/java/org/neo4j/bolt/transport/TransportWriteThrottleTest.java index 23a3ad5276125..b6639faf993f4 100644 --- a/community/bolt/src/test/java/org/neo4j/bolt/transport/TransportWriteThrottleTest.java +++ b/community/bolt/src/test/java/org/neo4j/bolt/transport/TransportWriteThrottleTest.java @@ -32,20 +32,23 @@ import org.mockito.Answers; import org.mockito.ArgumentCaptor; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.Future; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.is; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.Mockito.atLeastOnce; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; @@ -55,12 +58,13 @@ public class TransportWriteThrottleTest private ChannelHandlerContext context; private Channel channel; private SocketChannelConfig config; - private ThrottleLock lock; + private TestThrottleLock lock; @Before public void setup() throws Exception { - lock = mock( ThrottleLock.class ); + lock = new TestThrottleLock(); + config = mock( SocketChannelConfig.class ); Attribute lockAttribute = mock( Attribute.class ); @@ -116,7 +120,8 @@ public void shouldNotLockWhenWritable() throws Exception } assertTrue( future.isDone() ); - verify( lock, never() ).lock( any(), anyLong() ); + assertThat( lock.lockCallCount(), is( 0 ) ); + assertThat( lock.unlockCallCount(), is( 0 ) ); } @Test @@ -146,8 +151,8 @@ public void shouldLockWhenNotWritable() throws Exception } assertFalse( future.isDone() ); - verify( lock, atLeastOnce() ).lock( any(), anyLong() ); - verify( lock, never() ).unlock( any() ); + assertThat( lock.lockCallCount(), greaterThan( 0 ) ); + assertThat( lock.unlockCallCount(), is( 0 ) ); } @Test @@ -161,8 +166,8 @@ public void shouldResumeWhenWritableOnceAgain() throws Exception throttle.acquire( channel ); // expect - verify( lock, atLeastOnce() ).lock( any(), anyLong() ); - verify( lock, never() ).unlock( any() ); + assertThat( lock.lockCallCount(), greaterThan( 0 ) ); + assertThat( lock.unlockCallCount(), is( 0 ) ); } @Test @@ -173,7 +178,12 @@ public void shouldResumeWhenWritabilityChanged() throws Exception when( channel.isWritable() ).thenReturn( false ); Future future = Executors.newSingleThreadExecutor().submit( () -> throttle.acquire( channel ) ); - Thread.sleep( 500 ); + + // Wait until lock is acquired. + if (!lock.waitLocked( 1, TimeUnit.SECONDS )) + { + fail( "lock should be acquired" ); + } // when when( channel.isWritable() ).thenReturn( true ); @@ -191,8 +201,8 @@ public void shouldResumeWhenWritabilityChanged() throws Exception fail( "should not throw" ); } - verify( lock, atLeastOnce() ).lock( any(), anyLong() ); - verify( lock, times( 1 ) ).unlock( any() ); + assertThat( lock.lockCallCount(), greaterThan( 0 ) ); + assertThat( lock.unlockCallCount(), is( 1 ) ); } private TransportThrottle newThrottle() @@ -209,4 +219,44 @@ private TransportThrottle newThrottleAndInstall( Channel channel ) return throttle; } + private static class TestThrottleLock implements ThrottleLock + { + private AtomicInteger lockCount = new AtomicInteger( 0 ); + private AtomicInteger unlockCount = new AtomicInteger( 0 ); + private Semaphore semaphore = new Semaphore( 1 ); + private volatile CountDownLatch lockWaiter = new CountDownLatch( 1 ); + + @Override + public void lock( Channel channel, long timeout ) throws InterruptedException + { + semaphore.acquire(); + lockCount.incrementAndGet(); + lockWaiter.countDown(); + } + + @Override + public void unlock( Channel channel ) + { + semaphore.release(); + unlockCount.incrementAndGet(); + lockWaiter = new CountDownLatch( 1 ); + } + + public boolean waitLocked( long timeout, TimeUnit unit ) throws InterruptedException + { + return lockWaiter.await( timeout, unit ); + } + + public int lockCallCount() + { + return lockCount.get(); + } + + public int unlockCallCount() + { + return unlockCount.get(); + } + + } + }