Skip to content

Commit

Permalink
Merge pull request #492 from randomir/fix-dwaveapiclient-close
Browse files Browse the repository at this point in the history
Fix dwaveapiclient close
  • Loading branch information
randomir committed Oct 19, 2021
2 parents 5c55571 + 8d33965 commit a1b8913
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 54 deletions.
14 changes: 14 additions & 0 deletions dwave/cloud/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,14 @@ class DWaveAPIClient:
"""Low-level client for D-Wave APIs. A thin wrapper around
`requests.Session` that handles API specifics such as authentication,
response and error parsing, retrying, etc.
Note:
To make sure the session is closed, call :meth:`.close`, or use the
context manager form (as show in the example below).
Example:
with DWaveAPIClient(endpoint='...', timeout=(5, 600)) as client:
client.session.get('...')
"""

DEFAULTS = {
Expand Down Expand Up @@ -165,6 +173,12 @@ def __init__(self, **config):
def close(self):
self.session.close()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.close()

@staticmethod
def _retry_config(backoff_max=None, **kwargs):
"""Create http idempotent urllib3.Retry config."""
Expand Down
5 changes: 4 additions & 1 deletion dwave/cloud/api/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,14 @@ def __init__(self, **config):
self.session = self.client.session
self._patch_session()

def close(self):
self.client.close()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.client.close()
self.close()

@classmethod
def from_client_config(cls, client: Union[DWaveAPIClient, 'dwave.cloud.client.base.Client']):
Expand Down
11 changes: 11 additions & 0 deletions releasenotes/notes/fix-api-client-cleanup-0aa8a3d441db2bdb.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
---
features:
- |
Add context manager protocol support to
``dwave.cloud.api.client.DWaveAPIClient`` to ensure resources are easily
cleaned up (session closed). Note that ``close()`` method is available for
cases when context manager pattern is inconvenient.
Similarly, we add ``close()`` method to resources in
``dwave.cloud.api.resource.*``, in addition to the existing context manager
protocol support.
10 changes: 10 additions & 0 deletions tests/api/resources/test_problems.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,8 @@ def match_invalid_batch_submit(request):
def tearDown(self):
self.mocker.stop()

self.api.close()


class TestMockProblemsStructured(StructuredProblemTestsMixin,
ProblemResourcesMockerMixin,
Expand Down Expand Up @@ -582,6 +584,10 @@ def setUpClass(cls):
# double-check
assert future.remote_status == constants.ProblemStatus.COMPLETED.value

@classmethod
def tearDownClass(cls):
cls.api.close()


@unittest.skipUnless(dimod, "dimod not installed")
@unittest.skipUnless(config, "SAPI access not configured")
Expand Down Expand Up @@ -616,3 +622,7 @@ def setUpClass(cls):

# double-check
assert future.remote_status == constants.ProblemStatus.COMPLETED.value

@classmethod
def tearDownClass(cls):
cls.api.close()
4 changes: 4 additions & 0 deletions tests/api/resources/test_regions.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,10 @@ class TestCloudRegions(unittest.TestCase):
def setUpClass(cls):
cls.api = Regions()

@classmethod
def tearDownClass(cls):
cls.api.close()

def test_list_solvers(self):
regions = self.api.list_regions()

Expand Down
4 changes: 4 additions & 0 deletions tests/api/resources/test_solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,10 @@ def setUpClass(cls):
with Client(**config) as client:
cls.api = Solvers.from_client_config(client)

@classmethod
def tearDownClass(cls):
cls.api.close()

def test_list_solvers(self):
"""List of all available solvers retrieved."""

Expand Down
105 changes: 52 additions & 53 deletions tests/api/test_api_clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def test_defaults(self):
# verify Retry object config
retry = client.session.get_adapter('https://').max_retries
conf = DWaveAPIClient.DEFAULTS['retry']
client.close()
self.assertEqual(retry.total, conf['total'])

def test_init(self):
Expand All @@ -53,31 +54,32 @@ def test_init(self):
verify=False,
proxies={'https': 'http://proxy.com'})

client = DWaveAPIClient(**config)
with DWaveAPIClient(**config) as client:
session = client.session
self.assertIsInstance(session, requests.Session)

session = client.session
self.assertIsInstance(session, requests.Session)
self.assertEqual(session.base_url, config['endpoint'])
self.assertEqual(session.cert, None)
self.assertEqual(session.headers['X-Auth-Token'], config['token'])
self.assertEqual(session.headers['Custom'], config['headers']['Custom'])
self.assertIn(__packagename__, session.headers['User-Agent'])
self.assertIn(__version__, session.headers['User-Agent'])
self.assertEqual(session.verify, config['verify'])
self.assertEqual(session.proxies, config['proxies'])

self.assertEqual(session.base_url, config['endpoint'])
self.assertEqual(session.cert, None)
self.assertEqual(session.headers['X-Auth-Token'], config['token'])
self.assertEqual(session.headers['Custom'], config['headers']['Custom'])
self.assertIn(__packagename__, session.headers['User-Agent'])
self.assertIn(__version__, session.headers['User-Agent'])
self.assertEqual(session.verify, config['verify'])
self.assertEqual(session.proxies, config['proxies'])

# verify Retry object config
retry = session.get_adapter('https://').max_retries
self.assertEqual(retry.total, config['retry']['total'])
# verify Retry object config
retry = session.get_adapter('https://').max_retries
self.assertEqual(retry.total, config['retry']['total'])

def test_sapi_client(self):
client = SolverAPIClient()
self.assertEqual(client.session.base_url, constants.DEFAULT_SOLVER_API_ENDPOINT)
with SolverAPIClient() as client:
self.assertEqual(client.session.base_url,
constants.DEFAULT_SOLVER_API_ENDPOINT)

def test_metadata_client(self):
client = MetadataAPIClient()
self.assertEqual(client.session.base_url, constants.DEFAULT_METADATA_API_ENDPOINT)
with MetadataAPIClient() as client:
self.assertEqual(client.session.base_url,
constants.DEFAULT_METADATA_API_ENDPOINT)


class TestRequests(unittest.TestCase):
Expand All @@ -97,9 +99,8 @@ def test_request(self, m):
m.get(requests_mock.ANY, status_code=404, request_headers=auth_headers)
m.get(config['endpoint'], json=data, request_headers=config['headers'])

client = DWaveAPIClient(**config)

self.assertEqual(client.session.get('').json(), data)
with DWaveAPIClient(**config) as client:
self.assertEqual(client.session.get('').json(), data)

@requests_mock.Mocker()
def test_paths(self, m):
Expand All @@ -115,10 +116,9 @@ def test_paths(self, m):
m.get(f"{baseurl}/{path_a}", json=data_a)
m.get(f"{baseurl}/{path_b}", json=data_b)

client = DWaveAPIClient(**config)

self.assertEqual(client.session.get(path_a).json(), data_a)
self.assertEqual(client.session.get(path_b).json(), data_b)
with DWaveAPIClient(**config) as client:
self.assertEqual(client.session.get(path_a).json(), data_a)
self.assertEqual(client.session.get(path_b).json(), data_b)

@requests_mock.Mocker()
def test_session_history(self, m):
Expand All @@ -130,17 +130,16 @@ def test_session_history(self, m):
m.get(requests_mock.ANY, status_code=404)
m.get(f"{baseurl}/path", json=dict(data=True))

client = DWaveAPIClient(**config)

client.session.get('path')
self.assertEqual(client.session.history[-1].request.path_url, '/path')
with DWaveAPIClient(**config) as client:
client.session.get('path')
self.assertEqual(client.session.history[-1].request.path_url, '/path')

with self.assertRaises(exceptions.ResourceNotFoundError):
client.session.get('unknown')
self.assertEqual(client.session.history[-1].exception.error_code, 404)
with self.assertRaises(exceptions.ResourceNotFoundError):
client.session.get('unknown')
self.assertEqual(client.session.history[-1].exception.error_code, 404)

client.session.get('/path')
self.assertEqual(client.session.history[-1].request.path_url, '/path')
client.session.get('/path')
self.assertEqual(client.session.history[-1].request.path_url, '/path')


class TestResponseParsing(unittest.TestCase):
Expand All @@ -151,10 +150,10 @@ def test_non_json(self, m):

m.get(requests_mock.ANY, text='text', status_code=200)

client = DWaveAPIClient(endpoint='https://mock')
with DWaveAPIClient(endpoint='https://mock') as client:

with self.assertRaises(exceptions.ResourceBadResponseError) as exc:
client.session.get('test')
with self.assertRaises(exceptions.ResourceBadResponseError) as exc:
client.session.get('test')

@requests_mock.Mocker()
def test_structured_error_response(self, m):
Expand All @@ -166,13 +165,13 @@ def test_structured_error_response(self, m):

m.get(requests_mock.ANY, json=error, status_code=error_code)

client = DWaveAPIClient(endpoint='https://mock')
with DWaveAPIClient(endpoint='https://mock') as client:

with self.assertRaisesRegex(exceptions.ResourceNotFoundError, error_msg) as exc:
client.session.get('test')
with self.assertRaisesRegex(exceptions.ResourceNotFoundError, error_msg) as exc:
client.session.get('test')

self.assertEqual(exc.error_msg, error_msg)
self.assertEqual(exc.error_code, error_code)
self.assertEqual(exc.error_msg, error_msg)
self.assertEqual(exc.error_code, error_code)

@requests_mock.Mocker()
def test_plain_text_error(self, m):
Expand All @@ -183,13 +182,13 @@ def test_plain_text_error(self, m):

m.get(requests_mock.ANY, text=error_msg, status_code=error_code)

client = DWaveAPIClient(endpoint='https://mock')
with DWaveAPIClient(endpoint='https://mock') as client:

with self.assertRaisesRegex(exceptions.ResourceNotFoundError, error_msg) as exc:
client.session.get('test')
with self.assertRaisesRegex(exceptions.ResourceNotFoundError, error_msg) as exc:
client.session.get('test')

self.assertEqual(exc.error_msg, error_msg)
self.assertEqual(exc.error_code, error_code)
self.assertEqual(exc.error_msg, error_msg)
self.assertEqual(exc.error_code, error_code)

@requests_mock.Mocker()
def test_unknown_errors(self, m):
Expand All @@ -200,10 +199,10 @@ def test_unknown_errors(self, m):

m.get(requests_mock.ANY, text=error_msg, status_code=error_code)

client = DWaveAPIClient(endpoint='https://mock')
with DWaveAPIClient(endpoint='https://mock') as client:

with self.assertRaisesRegex(exceptions.RequestError, error_msg) as exc:
client.session.get('test')
with self.assertRaisesRegex(exceptions.RequestError, error_msg) as exc:
client.session.get('test')

self.assertEqual(exc.error_msg, error_msg)
self.assertEqual(exc.error_code, error_code)
self.assertEqual(exc.error_msg, error_msg)
self.assertEqual(exc.error_code, error_code)

0 comments on commit a1b8913

Please sign in to comment.