Skip to content

Commit

Permalink
unit testing framework
Browse files Browse the repository at this point in the history
  • Loading branch information
miguelgrinberg committed Feb 17, 2014
1 parent b21e23a commit 9633aba
Show file tree
Hide file tree
Showing 4 changed files with 222 additions and 36 deletions.
70 changes: 35 additions & 35 deletions flask_socketio.py → flask_socketio/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import logging
from gevent import monkey
from socketio import socketio_manage
from socketio.server import SocketIOServer
from socketio.namespace import BaseNamespace
from flask import request, session
from flask.ctx import RequestContext
from werkzeug.debug import DebuggedApplication
from werkzeug.serving import run_with_reloader
from test_client import SocketIOTestClient

monkey.patch_all()

Expand All @@ -26,7 +25,6 @@ def __call__(self, environ, start_response):
else:
return self.wsgi_app(environ, start_response)


class SocketIO(object):
def __init__(self, app=None):
if app:
Expand All @@ -36,11 +34,11 @@ def __init__(self, app=None):
def init_app(self, app):
app.wsgi_app = SocketIOMiddleware(app, self)

def get_namespaces(self):
class GenericNamespace(BaseNamespace):
def get_namespaces(self, base_namespace=BaseNamespace):
class GenericNamespace(base_namespace):
socketio = self
base_emit = BaseNamespace.emit
base_send = BaseNamespace.send
base_emit = base_namespace.emit
base_send = base_namespace.send

def process_event(self, packet):
message = packet['name']
Expand Down Expand Up @@ -68,64 +66,63 @@ def recv_json(self, data):
self.socketio.dispatch_message(app, self, 'json', [data])

def emit(self, event, *args, **kwargs):
namespace = kwargs.pop('namespace', None)
ns_name = kwargs.pop('namespace', None)
broadcast = kwargs.pop('broadcast', False)
if broadcast:
if namespace is None:
namespace = self.ns_name
if ns_name is None:
ns_name = self.ns_name
callback = kwargs.pop('callback', None)
ret = None
for sessid, socket in self.socket.server.sockets.items():
if socket == self.socket:
ret = self.base_emit(event, *args, callback=callback, **kwargs)
else:
socket[namespace].base_emit(event, *args, **kwargs)
socket[ns_name].base_emit(event, *args, **kwargs)
return ret
if namespace is None:
if ns_name is None:
return self.base_emit(event, *args, **kwargs)
return request.namespace.socket[namespace].base_emit(event, *args, **kwargs)
return request.namespace.socket[ns_name].base_emit(event, *args, **kwargs)

def send(message, json=False, namespace=None, callback=None, broadcast=False):
def send(self, message, json=False, ns_name=None, callback=None, broadcast=False):
if broadcast:
if namespace is None:
namespace = self.ns_name
if ns_name is None:
ns_name = self.ns_name
ret = None
for sessid, socket in self.socket.server.sockets.items():
if socket == request.namespace.socket:
ret = self.base_send(message, json, callback=callback)
else:
socket[namespace].base_send(message, json)
socket[ns_name].base_send(message, json)
return ret
if namespace is None:
if ns_name is None:
return request.namespace.base_send(message, json, callback)
return request.namespace.socket[namespace].base_send(message, json, callback)
return request.namespace.socket[ns_name].base_send(message, json, callback)

namespaces = {}
for namespace in self.messages.keys():
if namespace == '/':
namespace = ''
namespaces[namespace] = GenericNamespace
for ns_name in self.messages.keys():
if ns_name == '/':
ns_name = ''
namespaces[ns_name] = GenericNamespace
return namespaces

def dispatch_message(self, app, namespace, message, args=[]):
if namespace.ns_name not in self.messages:
return
if message not in self.messages[namespace.ns_name]:
return
with app.app_context():
with RequestContext(app, namespace.environ):
request.namespace = namespace
for k, v in namespace.session.items():
session[k] = v
self.messages[namespace.ns_name][message](*args)
for k, v in session.items():
namespace.session[k] = v
with app.request_context(namespace.environ):
request.namespace = namespace
for k, v in namespace.session.items():
session[k] = v
self.messages[namespace.ns_name][message](*args)
for k, v in session.items():
namespace.session[k] = v

def on_message(self, message, handler, **options):
namespace = options.pop('namespace', '/')
if namespace not in self.messages:
self.messages[namespace] = {}
self.messages[namespace][message] = handler
ns_name = options.pop('namespace', '')
if ns_name not in self.messages:
self.messages[ns_name] = {}
self.messages[ns_name][message] = handler

def on(self, message, **options):
def decorator(f):
Expand All @@ -151,6 +148,9 @@ def run_server():
else:
SocketIOServer((host, port), app.wsgi_app, resource='socket.io').serve_forever()

def test_client(self, app, namespace=None):
return SocketIOTestClient(app, self, namespace)


def emit(event, *args, **kwargs):
return request.namespace.emit(event, *args, **kwargs)
Expand Down
113 changes: 113 additions & 0 deletions flask_socketio/test_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
This module contains a collection of auxiliary mock objects used by
unit tests.
"""


class TestServer(object):
counter = 0

def __init__(self):
self.sockets = {}

def add_socket(self, socket):
self.sockets[self.counter] = socket
self.counter += 1

def remove_socket(self, socket):
for id, s in self.sockets.items():
if s == socket:
del self.sockets[id]
return


class TestSocket(object):

This comment has been minimized.

def __init__(self, server):
self.server = server
self.namespace = {}

def __getitem__(self, ns_name):
return self.namespace[ns_name]


class TestBaseNamespace(object):
def __init__(self, ns_name, socket, request=None):
from werkzeug.test import EnvironBuilder
self.environ = EnvironBuilder().get_environ()
self.ns_name = ns_name
self.socket = socket
self.request = request
self.session = {}
self.received = []

def recv_connect(self):
pass

def recv_disconnect(self):
pass

def emit(self, event, *args, **kwargs):
self.received.append({'name': event, 'args': args})
callback = kwargs.pop('callback', None)
if callback:
callback()

def send(self, message, json=False, callback=None):
if not json:
self.received.append({'name': 'message', 'args': message})
else:
self.received.append({'name': 'json', 'args': message})
if callback:
callback()


class SocketIOTestClient(object):

This comment has been minimized.

Copy link
@TronPaul

TronPaul Feb 17, 2014

Might be good to add the context manager methods so this could be used with with clauses.

server = TestServer()

def __init__(self, app, socketio, ns_name=None):
self.socketio = socketio
self.socket = TestSocket(self.server)
self.server.add_socket(self.socket)
self.connect(app, ns_name)

def __del__(self):
self.server.remove_socket(self.socket)

def connect(self, app, ns_name=None):
if self.socket.namespace.get(ns_name):
self.disconnect(ns_name)
key_ns_name = ns_name
if ns_name is None or ns_name == '/':
ns_name = ''
self.socket.namespace[ns_name] = \
self.socketio.get_namespaces(
TestBaseNamespace)[ns_name](ns_name, self.socket, app)
self.socket[ns_name].recv_connect()

def disconnect(self, ns_name=None):
if ns_name is None or ns_name == '/':
ns_name = ''
if self.socket[ns_name]:
self.socket[ns_name].recv_disconnect()
del self.socket.namespace[ns_name]

def emit(self, event, *args, **kwargs):
ns_name = kwargs.pop('ns_name', None)
if ns_name is None or ns_name == '/':

This comment has been minimized.

Copy link
@TronPaul

TronPaul Feb 17, 2014

Should probably default to the namespace used in the test_client() call.

ns_name = ''
self.socket[ns_name].process_event({'name': event, 'args': args})

def send(self, message, json=False, namespace=None):
if namespace is None or namespace == '/':
namespace = ''
if not json:
self.socket[namespace].recv_message(message)
else:
self.socket[namespace].recv_json(message)

def get_received(self, namespace=None):
if namespace is None or namespace == '/':
namespace = ''
received = self.socket[namespace].received
self.socket[namespace].received = []
return received
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
author_email='miguelgrinberg50@gmail.com',
description='Socket.IO integration for Flask applications',
long_description=__doc__,
py_modules=['flask_socketio'],
packages=['flask_socketio'],
zip_safe=False,
include_package_data=True,
platforms='any',
install_requires=[
'Flask>=0.9',
'gevent-socketio>=0.3.6'
],
test_suite='test_socketio',
classifiers=[
'Environment :: Web Environment',
'Intended Audience :: Developers',
Expand Down
72 changes: 72 additions & 0 deletions test_socketio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import unittest
import coverage

cov = coverage.coverage()
cov.start()

from flask import Flask
from flask.ext.socketio import SocketIO, send, emit

app = Flask(__name__)
socketio = SocketIO(app)
disconnected = None

@socketio.on('connect')
def on_connect():
send('connected')

@socketio.on('disconnect')
def on_connect():
global disconnected
disconnected = '/'

@socketio.on('connect', namespace='/test')
def on_connect_test():
send('connected-test')

@socketio.on('disconnect', namespace='/test')
def on_disconnect_test():
global disconnected
disconnected = 'test'

@socketio.on('message')
def on_message(message):
send(message)

@socketio.on('message', namespace='/test')
def on_message_test(message):
send(message, json=True)


class TestSocketIO(unittest.TestCase):
@classmethod
def setUpClass(cls):
pass

@classmethod
def tearDownClass(cls):
cov.stop()
cov.report(include='flask_socketio/__init__.py')

def setUp(self):
pass

def tearDown(self):
pass

def test_connect(self):
client = socketio.test_client(app)
received = client.get_received()
self.assertTrue(len(received) == 1)
self.assertTrue(received[0]['args'] == 'connected')
client.disconnect()

def test_connect_namespace(self):
client = socketio.test_client(app, namespace='/test')
received = client.get_received('/test')
self.assertTrue(len(received) == 1)
self.assertTrue(received[0]['args'] == 'connected-test')
client.disconnect('/test')

if __name__ == '__main__':
unittest.main()

1 comment on commit 9633aba

@TronPaul
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My comments are notes for myself for later pull requests.

Please sign in to comment.