Skip to content

Commit

Permalink
Fixed SessionTrackingTest.
Browse files Browse the repository at this point in the history
Introduced WebSocketSession.Listener that can be used to be notified
of session opening and close, so that tests can be written more reliably.
  • Loading branch information
sbordet committed Jun 13, 2016
1 parent df2af60 commit ebee9f1
Show file tree
Hide file tree
Showing 3 changed files with 132 additions and 88 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,8 @@

package org.eclipse.jetty.websocket.jsr356.server;

import static org.hamcrest.Matchers.is;
import static org.junit.Assert.assertThat;

import java.net.URI;
import java.util.Collection;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

Expand All @@ -36,133 +33,104 @@
import org.eclipse.jetty.server.Server;
import org.eclipse.jetty.server.ServerConnector;
import org.eclipse.jetty.servlet.ServletContextHandler;
import org.eclipse.jetty.util.StringUtil;
import org.eclipse.jetty.util.thread.QueuedThreadPool;
import org.eclipse.jetty.websocket.common.WebSocketSession;
import org.eclipse.jetty.websocket.jsr356.ClientContainer;
import org.eclipse.jetty.websocket.jsr356.server.deploy.WebSocketServerContainerInitializer;
import org.eclipse.jetty.websocket.server.WebSocketServerFactory;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Ignore;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class SessionTrackingTest
{
public static class ClientSocket extends Endpoint
{
public Session session;
public CountDownLatch openLatch = new CountDownLatch(1);
public CountDownLatch closeLatch = new CountDownLatch(1);

@Override
public void onOpen(Session session, EndpointConfig config)
{
this.session = session;
openLatch.countDown();
}

@Override
public void onClose(Session session, CloseReason closeReason)
{
closeLatch.countDown();
}

public void waitForOpen(long timeout, TimeUnit unit) throws InterruptedException
{
assertThat("ClientSocket opened",openLatch.await(timeout,unit),is(true));
}
private Server server;
private ServerContainer serverContainer;
private WebSocketServerFactory wsServerFactory;
private URI serverURI;

public void waitForClose(long timeout, TimeUnit unit) throws InterruptedException
{
assertThat("ClientSocket opened",closeLatch.await(timeout,unit),is(true));
}
}

@ServerEndpoint("/test")
public static class EchoSocket
@Before
public void startServer() throws Exception
{
@OnMessage
public String echo(String msg)
{
return msg;
}
}

private static Server server;
private static WebSocketServerFactory wsServerFactory;
private static URI serverURI;

@BeforeClass
public static void startServer() throws Exception
{
Server server = new Server();
QueuedThreadPool serverThreads = new QueuedThreadPool();
serverThreads.setName("server");
server = new Server(serverThreads);
ServerConnector serverConnector = new ServerConnector(server);
serverConnector.setPort(0);
server.addConnector(serverConnector);
ServletContextHandler servletContextHandler = new ServletContextHandler(ServletContextHandler.NO_SESSIONS);
servletContextHandler.setContextPath("/");
server.setHandler(servletContextHandler);

ServerContainer serverContainer = WebSocketServerContainerInitializer.configureContext(servletContextHandler);
serverContainer = WebSocketServerContainerInitializer.configureContext(servletContextHandler);
serverContainer.addEndpoint(EchoSocket.class);

wsServerFactory = serverContainer.getBean(WebSocketServerFactory.class);

server.start();

String host = serverConnector.getHost();
if (StringUtil.isBlank(host))
{
host = "localhost";
}
serverURI = new URI("ws://" + host + ":" + serverConnector.getLocalPort());
serverURI = new URI("ws://localhost:" + serverConnector.getLocalPort());
}

@AfterClass
public static void stopServer() throws Exception
@After
public void stopServer() throws Exception
{
if (server == null)
{
return;
}

server.stop();
if (server != null)
server.stop();
}

@Test
@Ignore
public void testAddRemoveSessions() throws Exception
{
// Create Client
ClientContainer clientContainer = new ClientContainer();
QueuedThreadPool clientThreads = new QueuedThreadPool();
clientThreads.setName("client");
clientContainer.getClient().setExecutor(clientThreads);
try
{
CountDownLatch openedLatch = new CountDownLatch(2);
CountDownLatch closedLatch = new CountDownLatch(2);
wsServerFactory.addSessionListener(new WebSocketSession.Listener()
{
@Override
public void onOpened(WebSocketSession session)
{
openedLatch.countDown();
}

@Override
public void onClosed(WebSocketSession session)
{
closedLatch.countDown();
}
});

clientContainer.start();

// Establish connections
ClientSocket cli1 = new ClientSocket();
clientContainer.connectToServer(cli1,serverURI.resolve("/test"));
cli1.waitForOpen(1,TimeUnit.SECONDS);

// Assert open connections
assertServerOpenConnectionCount(1);
clientContainer.connectToServer(cli1, serverURI.resolve("/test"));
cli1.waitForOpen(1, TimeUnit.SECONDS);

// Establish new connection
ClientSocket cli2 = new ClientSocket();
clientContainer.connectToServer(cli2,serverURI.resolve("/test"));
cli2.waitForOpen(1,TimeUnit.SECONDS);
clientContainer.connectToServer(cli2, serverURI.resolve("/test"));
cli2.waitForOpen(1, TimeUnit.SECONDS);

// Assert open connections
openedLatch.await(5, TimeUnit.SECONDS);
assertServerOpenConnectionCount(2);

// Establish close both connections
cli1.session.close();
cli2.session.close();

cli1.waitForClose(1,TimeUnit.SECONDS);
cli2.waitForClose(1,TimeUnit.SECONDS);
cli1.waitForClose(1, TimeUnit.SECONDS);
cli2.waitForClose(1, TimeUnit.SECONDS);

// Assert open connections
closedLatch.await(5, TimeUnit.SECONDS);
assertServerOpenConnectionCount(0);
}
finally
Expand All @@ -173,13 +141,53 @@ public void testAddRemoveSessions() throws Exception

private void assertServerOpenConnectionCount(int expectedCount)
{
Collection<WebSocketSession> sessions = wsServerFactory.getBeans(WebSocketSession.class);
Set<Session> sessions = serverContainer.getOpenSessions();
int openCount = 0;
for (WebSocketSession session : sessions)
for (Session session : sessions)
{
assertThat("Session.isopen: " + session,session.isOpen(),is(true));
Assert.assertThat("Session.isopen: " + session, session.isOpen(), Matchers.is(true));
openCount++;
}
assertThat("Open Session Count",openCount,is(expectedCount));
Assert.assertThat("Open Session Count", openCount, Matchers.is(expectedCount));
}

private static class ClientSocket extends Endpoint
{
private Session session;
private CountDownLatch openLatch = new CountDownLatch(1);
private CountDownLatch closeLatch = new CountDownLatch(1);

@Override
public void onOpen(Session session, EndpointConfig config)
{
this.session = session;
openLatch.countDown();
}

@Override
public void onClose(Session session, CloseReason closeReason)
{
closeLatch.countDown();
}

public void waitForOpen(long timeout, TimeUnit unit) throws InterruptedException
{
Assert.assertThat("ClientSocket opened", openLatch.await(timeout, unit), Matchers.is(true));
}

public void waitForClose(long timeout, TimeUnit unit) throws InterruptedException
{
Assert.assertThat("ClientSocket opened", closeLatch.await(timeout, unit), Matchers.is(true));
}
}

@ServerEndpoint("/test")
public static class EchoSocket
{
@OnMessage
public String echo(String msg)
{
return msg;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -609,4 +609,10 @@ public String toString()
return builder.toString();
}

public static interface Listener
{
void onOpened(WebSocketSession session);

void onClosed(WebSocketSession session);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,9 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.Executor;
import java.util.function.Consumer;

import javax.servlet.ServletContext;
import javax.servlet.ServletException;
Expand Down Expand Up @@ -88,6 +90,7 @@ public class WebSocketServerFactory extends ContainerLifeCycle implements WebSoc
* Have the factory maintain 1 and only 1 scheduler. All connections share this scheduler.
*/
private final Scheduler scheduler = new ScheduledExecutorScheduler();
private final List<WebSocketSession.Listener> listeners = new CopyOnWriteArrayList<>();
private final String supportedVersions;
private final WebSocketPolicy defaultPolicy;
private final EventDriverFactory eventDriverFactory;
Expand Down Expand Up @@ -153,6 +156,16 @@ public WebSocketServerFactory(WebSocketPolicy policy, ByteBufferPool bufferPool)
supportedVersions = rv.toString();
}

public void addSessionListener(WebSocketSession.Listener listener)
{
listeners.add(listener);
}

public void removeSessionListener(WebSocketSession.Listener listener)
{
listeners.remove(listener);
}

@Override
public boolean acceptWebSocket(HttpServletRequest request, HttpServletResponse response) throws IOException
{
Expand Down Expand Up @@ -453,16 +466,33 @@ public boolean isUpgradeRequest(HttpServletRequest request, HttpServletResponse
return true;
}

@Override
public void onSessionOpened(WebSocketSession session)
{
addManaged(session);
notifySessionListeners(listener -> listener.onOpened(session));
}

@Override
public void onSessionClosed(WebSocketSession session)
{
removeBean(session);
notifySessionListeners(listener -> listener.onClosed(session));
}

@Override
public void onSessionOpened(WebSocketSession session)
private void notifySessionListeners(Consumer<WebSocketSession.Listener> consumer)
{
addManaged(session);
for (WebSocketSession.Listener listener : listeners)
{
try
{
consumer.accept(listener);
}
catch (Throwable x)
{
LOG.info("Exception while invoking listener " + listener, x);
}
}
}

@Override
Expand Down

0 comments on commit ebee9f1

Please sign in to comment.