diff --git a/hug/use.py b/hug/use.py index f9a86596..e73e4263 100644 --- a/hug/use.py +++ b/hug/use.py @@ -161,98 +161,106 @@ def request(self, method, url, url_params=empty.dict, headers=empty.dict, timeou return Response(data, status_code, response._headers) -if getattr(socket, 'AF_UNIX', False): - class Socket(Service): - __slots__ = ('connection_pool', 'timeout', 'connection', 'send_and_receive') - - Connection = namedtuple('Connection', ('connect_to', 'proto', 'sockopts')) - protocols = { - 'tcp': (socket.AF_INET, socket.SOCK_STREAM), - 'unix_stream': (socket.AF_UNIX, socket.SOCK_STREAM), - 'udp': (socket.AF_INET, socket.SOCK_DGRAM), - 'unix_dgram': (socket.AF_UNIX, socket.SOCK_DGRAM) - } - streams = ('tcp', 'unix_stream') - datagrams = ('udp', 'unix_dgram') - inet = ('tcp', 'udp') - unix = ('unix_stream', 'unix_dgram') - - def __init__(self, connect_to, proto, version=None, - headers=empty.dict, timeout=None, pool=0, raise_on=(500, ), **kwargs): - super().__init__(timeout=timeout, raise_on=raise_on, version=version, **kwargs) - connect_to = tuple(connect_to) if proto in Socket.inet else connect_to - self.timeout = timeout - self.connection = Socket.Connection(connect_to, proto, set()) - self.connection_pool = Queue(maxsize=pool if pool else 1) - - if proto in Socket.streams: - self.send_and_receive = self._stream_send_and_receive - else: - self.send_and_receive = self._dgram_send_and_receive - - def settimeout(self, timeout): - """Set the default timeout""" - self.timeout = timeout - - def setsockopt(self, *sockopts): - """Add socket options to set""" - if type(sockopts[0]) in (list, tuple): - for sock_opt in sockopts[0]: - level, option, value = sock_opt - self.connection.sockopts.add((level, option, value)) - else: - level, option, value = sockopts + +class Socket(Service): + __slots__ = ('connection_pool', 'timeout', 'connection', 'send_and_receive') + + on_unix = getattr(socket, 'AF_UNIX', False) + Connection = namedtuple('Connection', ('connect_to', 'proto', 'sockopts')) + protocols = { + 'tcp': (socket.AF_INET, socket.SOCK_STREAM), + 'udp': (socket.AF_INET, socket.SOCK_DGRAM), + } + streams = set(('tcp',)) + datagrams = set(('udp',)) + inet = set(('tcp', 'udp',)) + unix = set() + + if on_unix: + protocols.update({ + 'unix_dgram': (socket.AF_UNIX, socket.SOCK_DGRAM), + 'unix_stream': (socket.AF_UNIX, socket.SOCK_STREAM) + }) + streams.add('unix_stream') + datagrams.add('unix_dgram') + unix.update(('unix_stream', 'unix_dgram')) + + def __init__(self, connect_to, proto, version=None, + headers=empty.dict, timeout=None, pool=0, raise_on=(500, ), **kwargs): + super().__init__(timeout=timeout, raise_on=raise_on, version=version, **kwargs) + connect_to = tuple(connect_to) if proto in Socket.inet else connect_to + self.timeout = timeout + self.connection = Socket.Connection(connect_to, proto, set()) + self.connection_pool = Queue(maxsize=pool if pool else 1) + + if proto in Socket.streams: + self.send_and_receive = self._stream_send_and_receive + else: + self.send_and_receive = self._dgram_send_and_receive + + def settimeout(self, timeout): + """Set the default timeout""" + self.timeout = timeout + + def setsockopt(self, *sockopts): + """Add socket options to set""" + if type(sockopts[0]) in (list, tuple): + for sock_opt in sockopts[0]: + level, option, value = sock_opt self.connection.sockopts.add((level, option, value)) + else: + level, option, value = sockopts + self.connection.sockopts.add((level, option, value)) - def _register_socket(self): - """Create/Connect socket, apply options""" - _socket = socket.socket(*Socket.protocols[self.connection.proto]) - _socket.settimeout(self.timeout) + def _register_socket(self): + """Create/Connect socket, apply options""" + _socket = socket.socket(*Socket.protocols[self.connection.proto]) + _socket.settimeout(self.timeout) - # Reconfigure original socket options. - if self.connection.sockopts: - for sock_opt in self.connection.sockopts: - level, option, value = sock_opt - _socket.setsockopt(level, option, value) + # Reconfigure original socket options. + if self.connection.sockopts: + for sock_opt in self.connection.sockopts: + level, option, value = sock_opt + _socket.setsockopt(level, option, value) - _socket.connect(self.connection.connect_to) - return _socket + _socket.connect(self.connection.connect_to) + return _socket - def _stream_send_and_receive(self, _socket, message, *args, **kwargs): - """TCP/Stream sender and receiver""" - data = BytesIO() + def _stream_send_and_receive(self, _socket, message, *args, **kwargs): + """TCP/Stream sender and receiver""" + data = BytesIO() - _socket_fd = _socket.makefile(mode='rwb', encoding='utf-8') - _socket_fd.write(message.encode('utf-8')) - _socket_fd.flush() + _socket_fd = _socket.makefile(mode='rwb', encoding='utf-8') + _socket_fd.write(message.encode('utf-8')) + _socket_fd.flush() - for received in _socket_fd: - data.write(received) - data.seek(0) + for received in _socket_fd: + data.write(received) + data.seek(0) - _socket_fd.close() - return data + _socket_fd.close() + return data - def _dgram_send_and_receive(self, _socket, message, buffer_size=4096, *args): - """User Datagram Protocol sender and receiver""" - _socket.sendto(message.encode('utf-8'), self.connection.connect_to) - data, address = _socket.recvfrom(buffer_size) - return BytesIO(data) + def _dgram_send_and_receive(self, _socket, message, buffer_size=4096, *args): + """User Datagram Protocol sender and receiver""" + _socket.sendto(message.encode('utf-8'), self.connection.connect_to) + data, address = _socket.recvfrom(buffer_size) + return BytesIO(data) - def request(self, message, timeout=False, *args, **kwargs): - """Populate connection pool, send message, return BytesIO, and cleanup""" - if not self.connection_pool.full(): - self.connection_pool.put(self._register_socket()) + def request(self, message, timeout=False, *args, **kwargs): + """Populate connection pool, send message, return BytesIO, and cleanup""" + if not self.connection_pool.full(): + self.connection_pool.put(self._register_socket()) - _socket = self.connection_pool.get() + _socket = self.connection_pool.get() - # setting timeout to None enables the socket to block. - if timeout or timeout is None: - _socket.settimeout(timeout) + # setting timeout to None enables the socket to block. + if timeout or timeout is None: + _socket.settimeout(timeout) - data = self.send_and_receive(_socket, message, *args, **kwargs) + data = self.send_and_receive(_socket, message, *args, **kwargs) - if self.connection.proto in Socket.streams: - _socket.shutdown(socket.SHUT_RDWR) + if self.connection.proto in Socket.streams: + _socket.shutdown(socket.SHUT_RDWR) - return Response(data, None, None) + return Response(data, None, None) diff --git a/tests/test_use.py b/tests/test_use.py index b6fe8ec1..15673392 100644 --- a/tests/test_use.py +++ b/tests/test_use.py @@ -138,6 +138,7 @@ def test_request(self): class TestSocket(object): """Test to ensure the Socket Service object enables sending/receiving data from arbitrary server/port sockets""" + on_unix = getattr(socket, 'AF_UNIX', False) tcp_service = use.Socket(connect_to=('www.google.com', 80), proto='tcp', timeout=60) udp_service = use.Socket(connect_to=('8.8.8.8', 53), proto='udp', timeout=60) @@ -148,13 +149,33 @@ def test_init(self): def test_protocols(self): """Test to ensure all supported protocols are present""" protocols = sorted(['tcp', 'udp', 'unix_stream', 'unix_dgram']) - assert sorted(self.tcp_service.protocols) == protocols + if self.on_unix: + assert sorted(self.tcp_service.protocols) == protocols + else: + protocols.remove('unix_stream') + protocols.remove('unix_dgram') + assert sorted(self.tcp_service.protocols) == protocols def test_streams(self): - assert self.tcp_service.streams == ('tcp', 'unix_stream') + if self.on_unix: + assert self.tcp_service.streams == set(('tcp', 'unix_stream',)) + else: + assert self.tcp_service.streams == set(('tcp',)) def test_datagrams(self): - assert self.tcp_service.datagrams == ('udp', 'unix_dgram') + if self.on_unix: + assert self.tcp_service.datagrams == set(('udp', 'unix_dgram',)) + else: + assert self.tcp_service.datagrams == set(('udp',)) + + def test_inet(self): + assert self.tcp_service.inet == set(('tcp', 'udp',)) + + def test_unix(self): + if self.on_unix: + assert self.tcp_service.unix == set(('unix_stream', 'unix_dgram',)) + else: + assert self.tcp_service.unix == set() def test_connection(self): assert self.tcp_service.connection.connect_to == ('www.google.com', 80)