Skip to content

Commit

Permalink
Merge pull request #792 from oberstet/fix_791
Browse files Browse the repository at this point in the history
allow WebSocketServerProtocol.onConnect to return a future/deferred; …
  • Loading branch information
oberstet committed Mar 19, 2017
2 parents 50aa7c1 + 363edda commit 5bfe2fa
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 47 deletions.
23 changes: 0 additions & 23 deletions autobahn/asyncio/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,6 @@
from trollius import iscoroutine
from trollius import Future

from autobahn.websocket.types import ConnectionDeny

if hasattr(asyncio, 'ensure_future'):
ensure_future = asyncio.ensure_future
else: # Deprecated since Python 3.4.4
Expand Down Expand Up @@ -199,27 +197,6 @@ class WebSocketServerProtocol(WebSocketAdapterProtocol, protocol.WebSocketServer
Base class for asyncio-based WebSocket server protocols.
"""

def _onConnect(self, request):
# onConnect() will return the selected subprotocol or None
# or a pair (protocol, headers) or raise an HttpException
##
# noinspection PyBroadException
try:
res = self.onConnect(request)
except ConnectionDeny as e:
self.failHandshake(e.reason, e.code)
except Exception as e:
self.failHandshake("Internal server error: {}".format(e), ConnectionDeny.INTERNAL_SERVER_ERROR)
else:
if yields(res):
# if onConnect was an async method, we need to await
# the actual result before calling succeedHandshake
ensure_future(res).add_done_callback(
lambda res: self.succeedHandshake(res.result())
)
else:
self.succeedHandshake(res)


class WebSocketClientProtocol(WebSocketAdapterProtocol, protocol.WebSocketClientProtocol):
"""
Expand Down
20 changes: 1 addition & 19 deletions autobahn/twisted/websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,12 @@
txaio.use_twisted()

import twisted.internet.protocol
from twisted.internet.defer import maybeDeferred
from twisted.internet.interfaces import ITransport
from twisted.internet.error import ConnectionDone, ConnectionAborted, \
ConnectionLost

from autobahn.wamp import websocket
from autobahn.websocket.types import ConnectionRequest, ConnectionResponse, \
ConnectionDeny
from autobahn.websocket.types import ConnectionRequest, ConnectionResponse, ConnectionDeny
from autobahn.websocket import protocol
from autobahn.twisted.util import peer2str, transport_channel_id

Expand Down Expand Up @@ -188,22 +186,6 @@ class WebSocketServerProtocol(WebSocketAdapterProtocol, protocol.WebSocketServer
Base class for Twisted-based WebSocket server protocols.
"""

def _onConnect(self, request):
# onConnect() will return the selected subprotocol or None
# or a pair (protocol, headers) or raise an HttpException
res = maybeDeferred(self.onConnect, request)

res.addCallback(self.succeedHandshake)

def forwardError(failure):
if failure.check(ConnectionDeny):
return self.failHandshake(failure.value.reason, failure.value.code)
else:
self.log.debug("Unexpected exception in onConnect ['{failure.value}']", failure=failure)
return self.failHandshake("Internal server error: {}".format(failure.value), ConnectionDeny.INTERNAL_SERVER_ERROR)

res.addErrback(forwardError)

def get_channel_id(self, channel_id_type=u'tls-unique'):
"""
Implements :func:`autobahn.wamp.interfaces.ITransport.get_channel_id`
Expand Down
25 changes: 20 additions & 5 deletions autobahn/websocket/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
IWebSocketChannelFrameApi, \
IWebSocketChannelStreamingApi

from autobahn.websocket.types import ConnectionRequest, ConnectionResponse
from autobahn.websocket.types import ConnectionRequest, ConnectionResponse, ConnectionDeny

from autobahn.util import Stopwatch, newid, wildcards2patterns, encode_truncate
from autobahn.util import _LazyHexFormatter
Expand Down Expand Up @@ -2747,10 +2747,25 @@ def processHandshake(self):
self.websocket_origin,
self.websocket_protocols,
self.websocket_extensions)
# Now fire onConnect() on derived class, to give that class a chance to accept or deny
# the connection. onConnect() may throw, in which case the connection is denied, or it
# may return a protocol from the protocols provided by client or None.
self._onConnect(request)

# The user's onConnect() handler must do one of the following:
# - return the subprotocol to be spoken
# - return None to continue with no subprotocol
# - return a pair (subprotocol, headers)
# - raise a ConnectionDeny to dismiss the client
f = txaio.as_future(self.onConnect, request)

def forward_error(err):
if isinstance(err.value, ConnectionDeny):
# the user handler explicitly denies the connection
self.failHandshake(err.value.reason, err.value.code)
else:
# the user handler ran into an unexpected error (and hence, user code needs fixing!)
self.log.warn("Unexpected exception in onConnect ['{err.value}']", err=err)
self.log.warn("{tb}", tb=txaio.failure_format_traceback(err))
return self.failHandshake("Internal server error: {}".format(err.value), ConnectionDeny.INTERNAL_SERVER_ERROR)

txaio.add_callbacks(f, self.succeedHandshake, forward_error)

elif self.serveFlashSocketPolicy:
flash_policy_file_request = self.data.find(b"<policy-file-request/>\x00")
Expand Down

0 comments on commit 5bfe2fa

Please sign in to comment.