Skip to content

Commit

Permalink
Merge pull request #272 from kirklg/bugfix/windows_compatible_sockets
Browse files Browse the repository at this point in the history
Bugfix/windows compatible sockets
  • Loading branch information
timothycrosley committed Mar 20, 2016
2 parents 9eff2c6 + f983532 commit 429ad08
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 83 deletions.
168 changes: 88 additions & 80 deletions hug/use.py
Expand Up @@ -161,98 +161,106 @@ def request(self, method, url, url_params=empty.dict, headers=empty.dict, timeou

return Response(data, status_code, response._headers)

if getattr(socket, 'AF_UNIX', False):
class Socket(Service):
__slots__ = ('connection_pool', 'timeout', 'connection', 'send_and_receive')

Connection = namedtuple('Connection', ('connect_to', 'proto', 'sockopts'))
protocols = {
'tcp': (socket.AF_INET, socket.SOCK_STREAM),
'unix_stream': (socket.AF_UNIX, socket.SOCK_STREAM),
'udp': (socket.AF_INET, socket.SOCK_DGRAM),
'unix_dgram': (socket.AF_UNIX, socket.SOCK_DGRAM)
}
streams = ('tcp', 'unix_stream')
datagrams = ('udp', 'unix_dgram')
inet = ('tcp', 'udp')
unix = ('unix_stream', 'unix_dgram')

def __init__(self, connect_to, proto, version=None,
headers=empty.dict, timeout=None, pool=0, raise_on=(500, ), **kwargs):
super().__init__(timeout=timeout, raise_on=raise_on, version=version, **kwargs)
connect_to = tuple(connect_to) if proto in Socket.inet else connect_to
self.timeout = timeout
self.connection = Socket.Connection(connect_to, proto, set())
self.connection_pool = Queue(maxsize=pool if pool else 1)

if proto in Socket.streams:
self.send_and_receive = self._stream_send_and_receive
else:
self.send_and_receive = self._dgram_send_and_receive

def settimeout(self, timeout):
"""Set the default timeout"""
self.timeout = timeout

def setsockopt(self, *sockopts):
"""Add socket options to set"""
if type(sockopts[0]) in (list, tuple):
for sock_opt in sockopts[0]:
level, option, value = sock_opt
self.connection.sockopts.add((level, option, value))
else:
level, option, value = sockopts

class Socket(Service):
__slots__ = ('connection_pool', 'timeout', 'connection', 'send_and_receive')

on_unix = getattr(socket, 'AF_UNIX', False)
Connection = namedtuple('Connection', ('connect_to', 'proto', 'sockopts'))
protocols = {
'tcp': (socket.AF_INET, socket.SOCK_STREAM),
'udp': (socket.AF_INET, socket.SOCK_DGRAM),
}
streams = set(('tcp',))
datagrams = set(('udp',))
inet = set(('tcp', 'udp',))
unix = set()

if on_unix:
protocols.update({
'unix_dgram': (socket.AF_UNIX, socket.SOCK_DGRAM),
'unix_stream': (socket.AF_UNIX, socket.SOCK_STREAM)
})
streams.add('unix_stream')
datagrams.add('unix_dgram')
unix.update(('unix_stream', 'unix_dgram'))

def __init__(self, connect_to, proto, version=None,
headers=empty.dict, timeout=None, pool=0, raise_on=(500, ), **kwargs):
super().__init__(timeout=timeout, raise_on=raise_on, version=version, **kwargs)
connect_to = tuple(connect_to) if proto in Socket.inet else connect_to
self.timeout = timeout
self.connection = Socket.Connection(connect_to, proto, set())
self.connection_pool = Queue(maxsize=pool if pool else 1)

if proto in Socket.streams:
self.send_and_receive = self._stream_send_and_receive
else:
self.send_and_receive = self._dgram_send_and_receive

def settimeout(self, timeout):
"""Set the default timeout"""
self.timeout = timeout

def setsockopt(self, *sockopts):
"""Add socket options to set"""
if type(sockopts[0]) in (list, tuple):
for sock_opt in sockopts[0]:
level, option, value = sock_opt
self.connection.sockopts.add((level, option, value))
else:
level, option, value = sockopts
self.connection.sockopts.add((level, option, value))

def _register_socket(self):
"""Create/Connect socket, apply options"""
_socket = socket.socket(*Socket.protocols[self.connection.proto])
_socket.settimeout(self.timeout)
def _register_socket(self):
"""Create/Connect socket, apply options"""
_socket = socket.socket(*Socket.protocols[self.connection.proto])
_socket.settimeout(self.timeout)

# Reconfigure original socket options.
if self.connection.sockopts:
for sock_opt in self.connection.sockopts:
level, option, value = sock_opt
_socket.setsockopt(level, option, value)
# Reconfigure original socket options.
if self.connection.sockopts:
for sock_opt in self.connection.sockopts:
level, option, value = sock_opt
_socket.setsockopt(level, option, value)

_socket.connect(self.connection.connect_to)
return _socket
_socket.connect(self.connection.connect_to)
return _socket

def _stream_send_and_receive(self, _socket, message, *args, **kwargs):
"""TCP/Stream sender and receiver"""
data = BytesIO()
def _stream_send_and_receive(self, _socket, message, *args, **kwargs):
"""TCP/Stream sender and receiver"""
data = BytesIO()

_socket_fd = _socket.makefile(mode='rwb', encoding='utf-8')
_socket_fd.write(message.encode('utf-8'))
_socket_fd.flush()
_socket_fd = _socket.makefile(mode='rwb', encoding='utf-8')
_socket_fd.write(message.encode('utf-8'))
_socket_fd.flush()

for received in _socket_fd:
data.write(received)
data.seek(0)
for received in _socket_fd:
data.write(received)
data.seek(0)

_socket_fd.close()
return data
_socket_fd.close()
return data

def _dgram_send_and_receive(self, _socket, message, buffer_size=4096, *args):
"""User Datagram Protocol sender and receiver"""
_socket.sendto(message.encode('utf-8'), self.connection.connect_to)
data, address = _socket.recvfrom(buffer_size)
return BytesIO(data)
def _dgram_send_and_receive(self, _socket, message, buffer_size=4096, *args):
"""User Datagram Protocol sender and receiver"""
_socket.sendto(message.encode('utf-8'), self.connection.connect_to)
data, address = _socket.recvfrom(buffer_size)
return BytesIO(data)

def request(self, message, timeout=False, *args, **kwargs):
"""Populate connection pool, send message, return BytesIO, and cleanup"""
if not self.connection_pool.full():
self.connection_pool.put(self._register_socket())
def request(self, message, timeout=False, *args, **kwargs):
"""Populate connection pool, send message, return BytesIO, and cleanup"""
if not self.connection_pool.full():
self.connection_pool.put(self._register_socket())

_socket = self.connection_pool.get()
_socket = self.connection_pool.get()

# setting timeout to None enables the socket to block.
if timeout or timeout is None:
_socket.settimeout(timeout)
# setting timeout to None enables the socket to block.
if timeout or timeout is None:
_socket.settimeout(timeout)

data = self.send_and_receive(_socket, message, *args, **kwargs)
data = self.send_and_receive(_socket, message, *args, **kwargs)

if self.connection.proto in Socket.streams:
_socket.shutdown(socket.SHUT_RDWR)
if self.connection.proto in Socket.streams:
_socket.shutdown(socket.SHUT_RDWR)

return Response(data, None, None)
return Response(data, None, None)
27 changes: 24 additions & 3 deletions tests/test_use.py
Expand Up @@ -138,6 +138,7 @@ def test_request(self):

class TestSocket(object):
"""Test to ensure the Socket Service object enables sending/receiving data from arbitrary server/port sockets"""
on_unix = getattr(socket, 'AF_UNIX', False)
tcp_service = use.Socket(connect_to=('www.google.com', 80), proto='tcp', timeout=60)
udp_service = use.Socket(connect_to=('8.8.8.8', 53), proto='udp', timeout=60)

Expand All @@ -148,13 +149,33 @@ def test_init(self):
def test_protocols(self):
"""Test to ensure all supported protocols are present"""
protocols = sorted(['tcp', 'udp', 'unix_stream', 'unix_dgram'])
assert sorted(self.tcp_service.protocols) == protocols
if self.on_unix:
assert sorted(self.tcp_service.protocols) == protocols
else:
protocols.remove('unix_stream')
protocols.remove('unix_dgram')
assert sorted(self.tcp_service.protocols) == protocols

def test_streams(self):
assert self.tcp_service.streams == ('tcp', 'unix_stream')
if self.on_unix:
assert self.tcp_service.streams == set(('tcp', 'unix_stream',))
else:
assert self.tcp_service.streams == set(('tcp',))

def test_datagrams(self):
assert self.tcp_service.datagrams == ('udp', 'unix_dgram')
if self.on_unix:
assert self.tcp_service.datagrams == set(('udp', 'unix_dgram',))
else:
assert self.tcp_service.datagrams == set(('udp',))

def test_inet(self):
assert self.tcp_service.inet == set(('tcp', 'udp',))

def test_unix(self):
if self.on_unix:
assert self.tcp_service.unix == set(('unix_stream', 'unix_dgram',))
else:
assert self.tcp_service.unix == set()

def test_connection(self):
assert self.tcp_service.connection.connect_to == ('www.google.com', 80)
Expand Down

0 comments on commit 429ad08

Please sign in to comment.