Skip to content

Commit

Permalink
Add tests for problem label (across methods, solvers)
Browse files Browse the repository at this point in the history
  • Loading branch information
randomir committed Jan 12, 2021
1 parent 647d158 commit a4e6dff
Show file tree
Hide file tree
Showing 4 changed files with 200 additions and 12 deletions.
2 changes: 1 addition & 1 deletion tests/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def handler(event, **data):
before = memo['before_sample']
args = dict(type_='ising', linear=lin, quadratic=quad,
offset=offset, params=params,
undirected_biases=False)
undirected_biases=False, label=None)
self.assertEqual(before['obj'], self.solver)
self.assertDictEqual(before['args'], args)

Expand Down
123 changes: 114 additions & 9 deletions tests/test_mock_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from requests.structures import CaseInsensitiveDict
from requests.exceptions import HTTPError
from concurrent.futures import TimeoutError
from parameterized import parameterized

try:
import dimod
Expand Down Expand Up @@ -68,7 +69,7 @@ def solver_data(id_, incomplete=False):
return obj


def complete_reply(id_, solver_name, answer=None, msg=None):
def complete_reply(id_, solver_name, answer=None, msg=None, label=""):
"""Reply with solutions for the test problem."""
response = {
"status": "COMPLETED",
Expand All @@ -85,7 +86,8 @@ def complete_reply(id_, solver_name, answer=None, msg=None):
"timing": {}
},
"type": "ising",
"id": id_
"id": id_,
"label": label
}

# optional answer fields override
Expand All @@ -99,19 +101,20 @@ def complete_reply(id_, solver_name, answer=None, msg=None):
return json.dumps(response)


def complete_no_answer_reply(id_, solver_name):
def complete_no_answer_reply(id_, solver_name, label=""):
"""A reply saying a problem is finished without providing the results."""
return json.dumps({
"status": "COMPLETED",
"solved_on": "2012-12-05T19:15:07+00:00",
"solver": solver_name,
"submitted_on": "2012-12-05T19:06:57+00:00",
"type": "ising",
"id": id_
"id": id_,
"label": label
})


def error_reply(id_, solver_name, error):
def error_reply(id_, solver_name, error, label=""):
"""A reply saying an error has occurred."""
return json.dumps({
"status": "FAILED",
Expand All @@ -120,6 +123,7 @@ def error_reply(id_, solver_name, error):
"submitted_on": "2013-01-18T10:25:59.941674",
"type": "ising",
"id": id_,
"label": label,
"error_message": error
})

Expand All @@ -132,15 +136,16 @@ def immediate_error_reply(code, msg):
})


def cancel_reply(id_, solver_name):
def cancel_reply(id_, solver_name, label=""):
"""A reply saying a problem was canceled."""
return json.dumps({
"status": "CANCELLED",
"solved_on": "2013-01-18T10:26:00.020954",
"solver": solver_name,
"submitted_on": "2013-01-18T10:25:59.941674",
"type": "ising",
"id": id_
"id": id_,
"label": label
})


Expand All @@ -149,7 +154,7 @@ def datetime_in_future(seconds=0):
return now + timedelta(seconds=seconds)


def continue_reply(id_, solver_name, now=None, eta_min=None, eta_max=None):
def continue_reply(id_, solver_name, now=None, eta_min=None, eta_max=None, label=""):
"""A reply saying a problem is still in the queue."""

if not now:
Expand All @@ -161,7 +166,8 @@ def continue_reply(id_, solver_name, now=None, eta_min=None, eta_max=None):
"solver": solver_name,
"submitted_on": now.isoformat(),
"type": "ising",
"id": id_
"id": id_,
"label": label
}
if eta_min:
resp.update({
Expand Down Expand Up @@ -1081,3 +1087,102 @@ def create_mock_session(client):
# .occurrences is deprecated in 0.8.0, scheduled for removal in 0.10.0+
with self.assertWarns(DeprecationWarning):
results.occurrences


class TestProblemLabel(unittest.TestCase):

class PrimaryAssertionSatisfied(Exception):
"""Raised by `on_submit_label_verifier` to signal correct label."""

def on_submit_label_verifier(self, expected_label):
"""Factory for mock Client._submit() that will verify existence, and
optionally validate label value.
"""

# replacement for Client._submit()
def _submit(client, body_data, computation):
body = json.loads(body_data.result())

if 'label' not in body:
if expected_label is None:
raise TestProblemLabel.PrimaryAssertionSatisfied
else:
raise AssertionError("label field missing")

label = body['label']
if label != expected_label:
raise AssertionError(
"unexpected label value: {!r} != {!r}".format(label, expected_label))

raise TestProblemLabel.PrimaryAssertionSatisfied

return _submit

def generate_sample_problems(self, solver):
linear, quadratic = test_problem(solver)

# test sample_{ising,qubo,bqm}
problems = [("sample_ising", (linear, quadratic)),
("sample_qubo", (quadratic,))]
if dimod:
bqm = dimod.BQM.from_ising(linear, quadratic)
problems.append(("sample_bqm", (bqm,)))

return problems

@parameterized.expand([
("undefined", None),
("empty", ""),
("string", "text label")
])
@mock.patch.object(Client, 'create_session', lambda client: mock.Mock())
def test_label_is_sent(self, name, label):
"""Problem label is set on problem submit."""

with Client('endpoint', 'token') as client:
solver = Solver(client, solver_data('solver'))
problems = self.generate_sample_problems(solver)

for method_name, problem_args in problems:
with self.subTest(method_name=method_name):
sample = getattr(solver, method_name)

with mock.patch.object(Client, '_submit', self.on_submit_label_verifier(label)):

with self.assertRaises(self.PrimaryAssertionSatisfied):
sample(*problem_args, label=label).result()

@parameterized.expand([
("undefined", None),
("empty", ""),
("string", "text label")
])
def test_label_is_received(self, name, label):
"""Problem label is set from response in result/sampleset."""

def make_session_generator(label):
def create_mock_session(client):
session = mock.Mock()
session.post = lambda a, _: choose_reply(a, {
'problems/': '[%s]' % complete_no_answer_reply(
'123', 'abc123', label=None)})
session.get = lambda a: choose_reply(a, {
'problems/123/': complete_reply(
'123', 'abc123', label=label)})
return session
return create_mock_session

with mock.patch.object(Client, 'create_session', make_session_generator(label)):
with Client('endpoint', 'token') as client:
solver = Solver(client, solver_data('abc123'))
problems = self.generate_sample_problems(solver)

for method_name, problem_args in problems:
with self.subTest(method_name=method_name):
sample = getattr(solver, method_name)

future = sample(*problem_args, label=label)
info = future.sampleset.info # ensure future is resolved

self.assertEqual(future.label, label)
self.assertEqual(info['problem_label'], label)
74 changes: 72 additions & 2 deletions tests/test_mock_unstructured_submission.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import unittest
from unittest import mock
from parameterized import parameterized

import dimod
import numpy
Expand All @@ -38,7 +39,7 @@ def unstructured_solver_data(problem_type='bqm'):
"description": "A test unstructured solver"
}

def complete_reply(sampleset, problem_type='bqm'):
def complete_reply(sampleset, problem_type='bqm', label=""):
"""Reply with the sampleset as a solution."""

return json.dumps([{
Expand All @@ -51,7 +52,8 @@ def complete_reply(sampleset, problem_type='bqm'):
"data": sampleset.to_serializable()
},
"type": problem_type,
"id": "problem-id"
"id": "problem-id",
"label": label
}])

def choose_reply(path, replies):
Expand Down Expand Up @@ -229,3 +231,71 @@ def mock_upload(self, bqm):
for fut in futs:
with self.assertRaises(type(mock_upload_exc)):
fut.result()


class TestProblemLabel(unittest.TestCase):

class PrimaryAssertionSatisfied(Exception):
"""Raised by `on_submit_label_verifier` to signal correct label."""

def on_submit_label_verifier(self, expected_label):
"""Factory for mock Client._submit() that will verify existence, and
optionally validate label value.
"""

# replacement for Client._submit()
def _submit(client, body_data, computation):
body = json.loads(body_data.result())

if 'label' not in body:
if expected_label is None:
raise TestProblemLabel.PrimaryAssertionSatisfied
else:
raise AssertionError("label field missing")

label = body['label']
if label != expected_label:
raise AssertionError(
"unexpected label value: {!r} != {!r}".format(label, expected_label))

raise TestProblemLabel.PrimaryAssertionSatisfied

return _submit

@parameterized.expand([
("undefined", None),
("empty", ""),
("string", "text label")
])
@mock.patch.object(Client, 'create_session', lambda client: mock.Mock())
def test_label_is_sent(self, name, label):
"""Problem label is set on problem submit."""

bqm = dimod.BQM.from_ising({}, {'ab': 1})

# use a global mocked session, so we can modify it on-fly
session = mock.Mock()

# upload is now part of submit, so we need to mock it
mock_problem_id = 'mock-problem-id'
def mock_upload(self, bqm):
return Present(result=mock_problem_id)

# construct a functional solver by mocking client and api response data
with mock.patch.multiple(Client, create_session=lambda self: session,
upload_problem_encoded=mock_upload):
with Client('endpoint', 'token') as client:
solver = BQMSolver(client, unstructured_solver_data())

problems = [("sample_ising", (bqm.linear, bqm.quadratic)),
("sample_qubo", (bqm.quadratic,)),
("sample_bqm", (bqm,))]

for method_name, problem_args in problems:
with self.subTest(method_name=method_name):
sample = getattr(solver, method_name)

with mock.patch.object(Client, '_submit', self.on_submit_label_verifier(label)):

with self.assertRaises(self.PrimaryAssertionSatisfied):
sample(*problem_args, label=label).result()
13 changes: 13 additions & 0 deletions tests/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,11 @@ def _submit_and_check(self, solver, linear, quad, **kwargs):
self.assertAlmostEqual(
energy, evaluate_ising(linear, quad, state, offset=offset))

# label is optional
label = kwargs.get('label')
if label is not None:
self.assertEqual(results.label, label)

return results


Expand Down Expand Up @@ -205,6 +210,14 @@ def test_submit_partial_problem(self):

self._submit_and_check(solver, linear, quad)

def test_problem_label(self):
"""Problem label is set."""

with Client(**config) as client:
solver = client.get_solver()
linear, quad = generate_random_ising_problem(solver)
self._submit_and_check(solver, linear, quad, label="test")

@unittest.skipUnless(dimod, "dimod required for 'Solver.sample_bqm'")
def test_submit_bqm_ising_problem(self):
"""Submit an Ising BQM with all supported coefficients set."""
Expand Down

0 comments on commit a4e6dff

Please sign in to comment.