diff --git a/gapipy/client.py b/gapipy/client.py index a00291a..ac92ce6 100644 --- a/gapipy/client.py +++ b/gapipy/client.py @@ -24,6 +24,7 @@ 'maxsize': os.environ.get('GAPI_CLIENT_CONNECTION_POOL_MAXSIZE', 10), }, 'uuid': os.environ.get('GAPI_UUID', False), + 'max_retries': os.environ.get('GAPI_CLIENT_MAX_RETRIES', 0), } @@ -50,6 +51,7 @@ def __init__(self, **config): self.api_language = get_config(config, 'api_language') self.cache_backend = get_config(config, 'cache_backend') self.uuid = get_config(config, 'uuid') + self.max_retries = get_config(config, 'max_retries') # begin with default connection pool options and overwrite any that the # client has specified @@ -61,7 +63,7 @@ def __init__(self, **config): self.logger.setLevel(log_level) self._set_cache_instance(get_config(config, 'cache_options')) - self._set_requestor(self.connection_pool_options) + self._set_requestor(self.connection_pool_options, self.max_retries) # Prevent install issues where setup.py digs down the path and # eventually fails on a missing requests requirement by importing Query @@ -77,7 +79,7 @@ def _set_cache_instance(self, cache_options): cache = getattr(module, class_name)(**cache_options) self._cache = cache - def _set_requestor(self, pool_options): + def _set_requestor(self, pool_options, max_retries): """ Set the requestor based on connection pooling options. @@ -88,17 +90,20 @@ def _set_requestor(self, pool_options): # to break some CI environments import requests + session = requests.Session() + if not pool_options['enable']: - self._requestor = requests - return - session = requests.Session() - adapter = requests.adapters.HTTPAdapter( - pool_block=pool_options['block'], - pool_connections=pool_options['number'], - pool_maxsize=pool_options['maxsize'], - ) - logger.info( + adapter = requests.adapters.HTTPAdapter(max_retries=max_retries) + + else: + adapter = requests.adapters.HTTPAdapter( + pool_block=pool_options['block'], + pool_connections=pool_options['number'], + pool_maxsize=pool_options['maxsize'], + max_retries=max_retries, + ) + logger.info( 'Created connection pool (block={}, number={}, maxsize={})'.format( pool_options['block'], pool_options['number'], @@ -107,13 +112,13 @@ def _set_requestor(self, pool_options): prefix = _get_protocol_prefix(self.api_root) if prefix: session.mount(prefix, adapter) - logger.info('Mounted connection pool for "{}"'.format(prefix)) + logger.info('Mounted session for "{}"'.format(prefix)) else: session.mount('http://', adapter) session.mount('https://', adapter) logger.info( 'Could not find protocol prefix in API root, mounted ' - 'connection pool on both http and https.') + 'session on both http and https.') self._requestor = session diff --git a/tests/test_client.py b/tests/test_client.py index 87d7893..d8ae140 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -94,3 +94,34 @@ def test_correct_client_is_associated_with_resources(self, mock_get_data): self.assertEqual(en_itin._client, en_client) self.assertEqual(de_itin._client, de_client) + + def test_default_retries(self): + """Should not set any retries on the client's requestor.""" + http_retries = self.gapi.requestor.adapters['http://'].max_retries.total + https_retries = self.gapi.requestor.adapters['https://'].max_retries.total + + self.assertEqual(http_retries, 0) + self.assertEqual(https_retries, 0) + + def test_retries_no_connection_pooling(self): + """Should initialize the client's requestor with the passed number of retries.""" + expected_retries = 42 + client_with_retries = Client(max_retries=expected_retries) + + # Connection pooling defaults to https only + https_retries = client_with_retries.requestor.adapters['https://'].max_retries.total + + self.assertEqual(https_retries, expected_retries) + + def test_retries_with_connection_pooling(self): + """Should initialize the client's requestor with the passed number of retries.""" + expected_retries = 84 + connection_pool_options = {"enable": True} + + client_with_retries = Client(max_retries=expected_retries, connection_pool_options=connection_pool_options) + + # Connection pooling defaults to https only + https_retries = client_with_retries.requestor.adapters['https://'].max_retries.total + + self.assertEqual(https_retries, expected_retries) +