Skip to content

Commit

Permalink
Fixed flaky tests caused by a race in websocket client's stop method
Browse files Browse the repository at this point in the history
  • Loading branch information
Zhen committed Dec 18, 2015
1 parent 0e69b03 commit 640ffe0
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 38 deletions.
Expand Up @@ -91,7 +91,7 @@ public static Collection<Object[]> documentedFullProtocolExamples()
@After
public void shutdown() throws Exception
{
client.close();
client.disconnect();
}

@Test
Expand Down
Expand Up @@ -66,8 +66,7 @@ public static Collection<Object[]> transports()
return asList(
new Object[]{
(Factory<Connection>) SecureWebSocketConnection::new,
new IOException( "Failed to connect to the server within 30 seconds" )
// TODO shutdown ssl connections properly
new IOException( "Failed to connect to the server within 10 seconds" )
},
new Object[]{
(Factory<Connection>) SecureSocketConnection::new,
Expand Down
Expand Up @@ -23,7 +23,7 @@

import org.neo4j.helpers.HostnamePort;

public interface Connection extends AutoCloseable
public interface Connection
{
Connection connect( HostnamePort address ) throws Exception;

Expand Down
Expand Up @@ -104,10 +104,4 @@ public void disconnect() throws IOException
socket.close();
}
}

@Override
public void close() throws Exception
{
disconnect();
}
}
Expand Up @@ -43,17 +43,19 @@ public class WebSocketConnection implements Connection, WebSocketListener
private final Supplier<WebSocketClient> clientSupplier;
private final Function<HostnamePort,URI> uriGenerator;

private final byte[] POISON_PILL = "poison".getBytes();

private WebSocketClient client;
private RemoteEndpoint server;

// Incoming data goes on this queue
private final LinkedBlockingQueue<byte[]> received = new LinkedBlockingQueue<>();

// Current input data being handled, popped off of 'received' queue
private byte[] currentRecieveBuffer = null;
private byte[] currentReceiveBuffer = null;

// Index into the current receive buffer
private int currentRecieveIndex = 0;
private int currentReceiveIndex = 0;

public WebSocketConnection()
{
Expand All @@ -66,6 +68,12 @@ public WebSocketConnection( Supplier<WebSocketClient> clientSupplier, Function<H
this.uriGenerator = uriGenerator;
}

WebSocketConnection( WebSocketClient client )
{
this( null, null );
this.client = client;
}

@Override
public Connection connect( HostnamePort address ) throws Exception
{
Expand All @@ -77,12 +85,13 @@ public Connection connect( HostnamePort address ) throws Exception
Session session = null;
try
{
session = client.connect( this, target ).get( 30, SECONDS );
session = client.connect( this, target ).get( 10, SECONDS );
}
catch ( Exception e )
{
throw new IOException( "Failed to connect to the server within 30 seconds", e );
throw new IOException( "Failed to connect to the server within 10 seconds", e );
}

server = session.getRemote();
return this;
}
Expand All @@ -105,9 +114,9 @@ public byte[] recv( int length ) throws Exception
while ( remaining > 0 )
{
waitForRecievedData( length, remaining, target );
for ( int i = 0; i < Math.min( remaining, currentRecieveBuffer.length - currentRecieveIndex ); i++ )
for ( int i = 0; i < Math.min( remaining, currentReceiveBuffer.length - currentReceiveIndex ); i++ )
{
target[length - remaining] = currentRecieveBuffer[currentRecieveIndex++];
target[length - remaining] = currentReceiveBuffer[currentReceiveIndex++];
remaining--;
}
}
Expand All @@ -124,19 +133,21 @@ private void waitForRecievedData( int length, int remaining, byte[] target )
throws InterruptedException, IOException
{
long start = System.currentTimeMillis();
while ( currentRecieveBuffer == null || currentRecieveIndex >= currentRecieveBuffer.length )
while ( currentReceiveBuffer == null || currentReceiveIndex >= currentReceiveBuffer.length )
{
currentRecieveIndex = 0;
currentRecieveBuffer = received.poll( 10, MILLISECONDS );
currentReceiveIndex = 0;
currentReceiveBuffer = received.poll( 10, MILLISECONDS );

if( client.isStopped() || client.isStopping() )
if( (currentReceiveBuffer == null && ( client.isStopped() || client.isStopping() ) ) ||
currentReceiveBuffer == POISON_PILL )
{
// no data received
throw new IOException( "Connection closed while waiting for data from the server." );
}
if ( System.currentTimeMillis() - start > 30_000 )
{
throw new IOException( "Waited 30 seconds for " + remaining + " bytes, " +
"" + (length - remaining) + " was recieved: " +
"" + (length - remaining) + " was received: " +
HexPrinter.hex( ByteBuffer.wrap( target ), 0, length - remaining ) );
}
}
Expand All @@ -145,17 +156,9 @@ private void waitForRecievedData( int length, int remaining, byte[] target )
@Override
public void disconnect() throws Exception
{
close();
client.stop();
}

@Override
public void close() throws Exception
{
if ( client != null )
{
client.stop();
}
}

@Override
public void onWebSocketBinary( byte[] bytes, int i, int i2 )
Expand All @@ -166,14 +169,7 @@ public void onWebSocketBinary( byte[] bytes, int i, int i2 )
@Override
public void onWebSocketClose( int i, String s )
{
try
{
close();
}
catch ( Exception e )
{
throw new RuntimeException( e );
}
received.add( POISON_PILL );
}

@Override
Expand Down
@@ -0,0 +1,74 @@
/*
* Copyright (c) 2002-2015 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Neo4j is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package org.neo4j.bolt.v1.transport.socket.client;

import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import java.io.IOException;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class WebSocketConnectionTest
{

@Rule
public ExpectedException expectedException = ExpectedException.none();

@Test
public void shouldNotThrowAnyExceptionWhenDataReceivedBeforeClose() throws Throwable
{
// Given
WebSocketClient client = mock( WebSocketClient.class );
WebSocketConnection conn = new WebSocketConnection( client );
when( client.isStopped() ).thenReturn( true );

byte[] data = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};

// When
conn.onWebSocketBinary( data, 0, 10 );
conn.recv( 10 );

// Then
// no exception
}


@Test
public void shouldThrowIOExceptionWhenNotEnoughDataReceivedBeforeClose() throws Throwable
{
// Given
WebSocketClient client = mock( WebSocketClient.class );
WebSocketConnection conn = new WebSocketConnection( client );
when( client.isStopped() ).thenReturn( true, true );

byte[] data = {0, 1, 2, 3};

// When && Then
conn.onWebSocketBinary( data, 0, 4 );

expectedException.expect( IOException.class );
expectedException.expectMessage( "Connection closed while waiting for data from the server." );
conn.recv( 10 );
}
}

0 comments on commit 640ffe0

Please sign in to comment.