Skip to content

Commit

Permalink
Merge d71fda1 into 0021d7e
Browse files Browse the repository at this point in the history
  • Loading branch information
Moutix committed Apr 2, 2017
2 parents 0021d7e + d71fda1 commit 7b900f7
Show file tree
Hide file tree
Showing 8 changed files with 258 additions and 30 deletions.
11 changes: 6 additions & 5 deletions smserver/smutils/smconn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Base module for handling all type of connection
"""

import abc
import datetime
import uuid

Expand All @@ -14,12 +15,12 @@


class StepmaniaConn(object):
""" A stepmania connection is represented by a token in the database """

log = logger.get_logger()
ENCODING = "binary"
ALLOWED_PACKET = []

""" A stepmania connection is represented by a token in the database """

def __init__(self, serv, ip, port):
self.mutex = Lock()

Expand Down Expand Up @@ -85,10 +86,10 @@ def send(self, packet):
)

self.log.debug("packet send to %s: %s", self.ip, packet)
self._send_data(packet.to_(self.ENCODING))
self.send_data(packet.to_(self.ENCODING))

def _send_data(self, data):
pass
def send_data(self, data):
""" Send biary data to the client """

def close(self):
""" Close the connection """
Expand Down
45 changes: 33 additions & 12 deletions smserver/smutils/smconnections/asynctcpserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@ def run(self):
data = data_left
data_left = b""
else:
data = (yield from self.reader.read(8192))

try:
data = yield from self.reader.read(8192)
except asyncio.CancelledError:
break
if data == b'':
break

Expand Down Expand Up @@ -57,7 +59,7 @@ def run(self):

self.close()

def _send_data(self, data):
def send_data(self, data):
self.writer.write(data)
self.loop.create_task(self.writer.drain())

Expand All @@ -67,10 +69,10 @@ def close(self):


class AsyncSocketServer(smconn.SMThread):
def __init__(self, server, ip, port):
def __init__(self, server, ip, port, loop=None):
smconn.SMThread.__init__(self, server, ip, port)

self.loop = asyncio.new_event_loop()
self.loop = loop or asyncio.new_event_loop()
self._serv = None
self.clients = {}

Expand All @@ -91,16 +93,24 @@ def client_done(task):
client.task.add_done_callback(client_done)

def run(self):
self._serv = self.loop.run_until_complete(
asyncio.streams.start_server(self._accept_client,
self.ip, self.port,
loop=self.loop))
self.start_server()
self.loop.run_forever()
self.loop.close()
smconn.SMThread.run(self)

def stop(self):
smconn.SMThread.stop(self)
def start_server(self):
""" Start the server in the given loop """

self._serv = self.loop.run_until_complete(asyncio.start_server(
self._accept_client,
host=self.ip,
port=self.port,
loop=self.loop,
))
return self._serv

def stop_server(self):
""" Stop the server in the given loop """

if self._serv is None:
return
Expand All @@ -109,5 +119,16 @@ def stop(self):
for sock in self._serv.sockets:
sock.shutdown(socket.SHUT_RDWR)

self.loop.stop()
self._serv.close()
self.loop.run_until_complete(
asyncio.wait_for(self._serv.wait_closed(), timeout=1, loop=self.loop)
)
for task in self.clients:
task.cancel()

self.loop.run_until_complete(asyncio.gather(*self.clients))

def stop(self):
smconn.SMThread.stop(self)
self.stop_server()
self.loop.stop()
2 changes: 1 addition & 1 deletion smserver/smutils/smconnections/smtcpsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def received_data(self):
full_data = b""
size = None

def _send_data(self, data):
def send_data(self, data):
with self.mutex:
try:
self._conn.sendall(data)
Expand Down
2 changes: 1 addition & 1 deletion smserver/smutils/smconnections/udpsocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def __init__(self, serv, ip, port, data):
def received_data(self):
yield self._data

def _send_data(self, data):
def send_data(self, data):
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM, socket.IPPROTO_UDP)
sock.sendto(data, (self.ip, self.port))

Expand Down
38 changes: 27 additions & 11 deletions smserver/smutils/smconnections/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,30 +22,28 @@ def run(self):
data = yield from self.websocket.recv()
except websockets.ConnectionClosed:
break

self._on_data(data)

self.close()

def _send_data(self, data):
def send_data(self, data):
self.loop.create_task(self.websocket.send(data))

def close(self):
self._serv.on_disconnect(self)
self.websocket.close()

class WebSocketServer(smconn.SMThread):
def __init__(self, server, ip, port):
def __init__(self, server, ip, port, loop=None):
smconn.SMThread.__init__(self, server, ip, port)

self.loop = asyncio.new_event_loop()
self.loop = loop or asyncio.new_event_loop()

self._serv = None
self.server = server
self.daemon = True
self.ip = ip
self.port = port
self.clients = {}

@asyncio.coroutine
def _accept_client(self, websocket, path=""):
Expand All @@ -60,14 +58,25 @@ def _accept_client(self, websocket, path=""):
raise

def run(self):
self._serv = self.loop.run_until_complete(
websockets.serve(self._accept_client, self.ip, self.port, loop=self.loop))

self.start_server()
self.loop.run_forever()
smconn.SMThread.run(self)

def stop(self):
smconn.SMThread.stop(self)
def start_server(self):
""" Start the websocket server """

self._serv = self.loop.run_until_complete(
websockets.serve(
self._accept_client,
host=self.ip,
port=self.port,
loop=self.loop,
)
)
return self._serv

def stop_server(self):
""" Stop the server in the given loop """

if self._serv is None:
return
Expand All @@ -77,5 +86,12 @@ def stop(self):
for sock in sockets:
sock.shutdown(socket.SHUT_RDWR)

self.loop.stop()
self._serv.close()
self.loop.run_until_complete(
asyncio.wait_for(self._serv.wait_closed(), timeout=1)
)

def stop(self):
smconn.SMThread.stop(self)
self.stop_server()
self.loop.stop()
Empty file.
113 changes: 113 additions & 0 deletions test/test_smutils/test_smconnections/test_asynctcpserver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
""" Test the async tcp server connection """

import unittest
import socket

import asyncio
import mock

from smserver.smutils.smconnections import asynctcpserver

class AsyncSocketServerTest(unittest.TestCase):
""" Test the thread which handle async tcp connection """

def setUp(self):
super().setUp()
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)

sock = socket.socket()
sock.bind(("127.0.0.1", 0))
self.ip, self.port = sock.getsockname()

self.mock_server = mock.MagicMock()

self.server = asynctcpserver.AsyncSocketServer(
self.mock_server, self.ip, self.port, self.loop
)

self.reader, self.writer = None, None

def tearDown(self):
self.loop.close()
self.mock_server.reset_mock()

@asyncio.coroutine
def client_connection(self):
""" Coroutine to open the client connection """

reader, writer = yield from asyncio.open_connection(
self.ip, self.port, loop=self.loop
)
return reader, writer

def start_client(self):
""" Start the client """

self.reader, self.writer = self.loop.run_until_complete(
self.client_connection()
)

def start_server(self):
""" Start the server """

self.server.start_server()

def stop_server(self):
""" Stop the server """

self.server.stop_server()


def test_server_close_while_client_connected(self):
""" Try stopping the server during client connection """

self.start_server()
self.start_client()
self.stop_server()
self.mock_server.add_connection.assert_called_once()

@mock.patch("smserver.smutils.smconnections.asynctcpserver.AsyncSocketClient._on_data")
def test_valid_package(self, on_data):
""" Test sending data to the server """

self.start_server()
self.start_client()
self.mock_server.add_connection.assert_called_once()
self.writer.write(b"\x00\x00\x00\x01\x54")
self.loop.run_until_complete(self.writer.drain())

on_data.assert_called_with(b"\x00\x00\x00\x01\x54")

self.stop_server()

@mock.patch("smserver.smutils.smconnections.asynctcpserver.AsyncSocketClient._on_data")
def test_double_valid_package(self, on_data):
""" Test sending data to the server """

self.start_server()
self.start_client()
self.mock_server.add_connection.assert_called_once()
self.writer.write(b"\x00\x00\x00\x01\x54\x00\x00\x00\x01\x55")
self.loop.run_until_complete(self.writer.drain())

self.assertEqual(on_data.call_count, 2)
self.assertEqual(on_data.call_args_list[0][0][0], b"\x00\x00\x00\x01\x54")
self.assertEqual(on_data.call_args_list[1][0][0], b"\x00\x00\x00\x01\x55")

self.stop_server()

@mock.patch("smserver.smutils.smconnections.asynctcpserver.AsyncSocketClient._on_data")
def test_invalid_package(self, on_data):
""" Test sending data to the server """

self.start_server()
self.start_client()
self.mock_server.add_connection.assert_called_once()
self.writer.write(b"\x00\x00\x43")
self.writer.write(b"\x00\x00\x00\x45")
self.loop.run_until_complete(self.writer.drain())

on_data.assert_not_called()

self.stop_server()
Loading

0 comments on commit 7b900f7

Please sign in to comment.