diff --git a/cabby/abstract.py b/cabby/abstract.py index 38cb16e..4577a61 100644 --- a/cabby/abstract.py +++ b/cabby/abstract.py @@ -25,7 +25,7 @@ class AbstractClient(object): taxii_version = None def __init__(self, host=None, discovery_path=None, port=None, - use_https=False, headers=None): + use_https=False, headers=None, timeout=None): self.host = host self.port = port @@ -47,6 +47,7 @@ def __init__(self, host=None, discovery_path=None, port=None, self.jwt_token = None self.headers = headers or {} + self.timeout = timeout self.log = logging.getLogger( "{}.{}".format(self.__module__, self.__class__.__name__)) @@ -193,13 +194,15 @@ def _execute_request(self, request, uri=None, service_type=None): 'key_file': self.key_file, 'key_password': self.key_password, 'ca_cert': self.ca_cert - }) + }, + timeout=self.timeout) else: message = dispatcher.send_taxii_request( session, self._prepare_url(uri), request, - taxii_binding=self.taxii_binding) + taxii_binding=self.taxii_binding, + timeout=self.timeout) return message diff --git a/cabby/dispatcher.py b/cabby/dispatcher.py index 2f8963f..b4dff46 100644 --- a/cabby/dispatcher.py +++ b/cabby/dispatcher.py @@ -34,7 +34,7 @@ def raise_http_error(status_code, response_stream=None): def send_taxii_request(session, url, request, taxii_binding=None, - tls_details=None): + tls_details=None, timeout=None): ''' Send XML message to a TAXII service and parse a response. ''' @@ -55,7 +55,7 @@ def send_taxii_request(session, url, request, taxii_binding=None, # https://github.com/kennethreitz/requests/issues/2519 is fixed try: response = get_response_using_key_pass( - url, request_body, session, **tls_details) + url, request_body, session, timeout=timeout, **tls_details) except urllib.error.HTTPError as e: log.error( "Error while connecting to {}".format(url), @@ -64,7 +64,8 @@ def send_taxii_request(session, url, request, taxii_binding=None, stream, headers = response, response.headers else: - response = session.post(url, data=request_body, stream=True) + response = session.post(url, data=request_body, stream=True, + timeout=timeout) if not response.ok: raise_http_error(response.status_code, response.raw) @@ -369,7 +370,7 @@ def obtain_jwt_token(session, jwt_url, username, password): def get_response_using_key_pass(url, data, session, cert_file, key_file, - key_password, ca_cert=None): + key_password, ca_cert=None, timeout=None): if sys.version_info < (2, 7, 9): raise ValueError( @@ -405,4 +406,7 @@ def get_response_using_key_pass(url, data, session, cert_file, key_file, request = urllib.request.Request(url, data, headers) - return opener.open(request) + if timeout: + return opener.open(request, timeout=timeout) + else: + return opener.open(request) diff --git a/tests/test_common.py b/tests/test_common.py index 9c9b4f4..85c53d0 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -3,6 +3,8 @@ import json import gzip import sys +import requests +from time import sleep from six import StringIO @@ -55,7 +57,6 @@ def get_sent_message(version): @pytest.mark.parametrize("version", [11, 10]) def test_set_headers(version): - httpretty.reset() httpretty.enable() @@ -86,7 +87,6 @@ def test_set_headers(version): @pytest.mark.parametrize("version", [11, 10]) def test_invalid_response(version): - httpretty.reset() httpretty.enable() @@ -112,7 +112,6 @@ def test_invalid_response(version): @pytest.mark.parametrize("version", [11, 10]) def test_invalid_response_status(version): - httpretty.reset() httpretty.enable() @@ -132,7 +131,6 @@ def test_invalid_response_status(version): @pytest.mark.parametrize("version", [11, 10]) def test_jwt_auth_response(version): - httpretty.reset() httpretty.enable() @@ -202,7 +200,6 @@ def compress(text): @pytest.mark.parametrize("version", [11, 10]) def test_gzip_response(version): - httpretty.reset() httpretty.enable() @@ -222,3 +219,37 @@ def test_gzip_response(version): httpretty.disable() httpretty.reset() + + +@pytest.mark.parametrize("version", [11, 10]) +def test_timeout(version): + httpretty.reset() + httpretty.enable() + + timeout_in_sec = 1 + + client = make_client(version) + # + # configure to raise the error before the timeout + # + client.timeout = timeout_in_sec / 2.0 + + def timeout_request_callback(request, uri, headers): + sleep(timeout_in_sec) + + return 200, headers, {'result': 'success'} + + uri = get_fix(version).DISCOVERY_URI_HTTP + + httpretty.register_uri( + httpretty.POST, + uri, + body=timeout_request_callback, + content_type='application/json' + ) + + with pytest.raises(requests.exceptions.Timeout): + client.discover_services(uri=uri) + + httpretty.disable() + httpretty.reset()