Skip to content
This repository has been archived by the owner on Aug 18, 2023. It is now read-only.

Enable differentiation through hypernetworks #96

Merged
merged 6 commits into from
Oct 6, 2021
Merged

Enable differentiation through hypernetworks #96

merged 6 commits into from
Oct 6, 2021

Conversation

sahilpatelsp
Copy link
Contributor

#93

@zaqqwerty
Copy link
Contributor

As discussed, to be broken into two PRs, one to address #86 and a second to address #96 . Tests for all new features need to be added as well.

Copy link
Contributor

@zaqqwerty zaqqwerty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looking good! Other than my comments on the VQT hypernetwork test, a few additional tests would be good: trainable_variables tests for MLP, QNN, and QHBM

@@ -146,6 +146,49 @@ def test_loss_value_x_rot(self):
actual_thetas_grads, expected_thetas_grads, rtol=RTOL)
self.assertAllClose(actual_phis_grads, expected_phis_grads, rtol=RTOL)

def test_hyernetwork(self):
qubit = cirq.GridQubit(0, 0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If possible, could this be changed to the more general get_random_qhbm, like in the QMHL test?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

@@ -146,6 +146,49 @@ def test_loss_value_x_rot(self):
actual_thetas_grads, expected_thetas_grads, rtol=RTOL)
self.assertAllClose(actual_phis_grads, expected_phis_grads, rtol=RTOL)

def test_hyernetwork(self):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo in test name

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch!

@sahilpatelsp sahilpatelsp marked this pull request as ready for review October 4, 2021 16:51
Copy link
Contributor

@zaqqwerty zaqqwerty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Last thing, could we add a test of the values of the hypernetwork gradient? Can be for a simple scalar hypernetwork.

qhbmlib/vqt.py Outdated
return grad_vars
return grad_vars, [tf.zeros_like(g) for g in grad_vars]
return grad_qhbm
return grad_qhbm, [tf.zeros_like(grad) for grad in grad_qhbm]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: better not to have a collision with the function name, can this go back to just g?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree

qhbmlib/qmhl.py Outdated
return grad_vars
return grad_vars, [tf.zeros_like(g) for g in grad_vars]
return grad_qhbm
return grad_qhbm, [tf.zeros_like(grad) for grad in grad_qhbm]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: better not to have a collision with the function name, can this go back to just g?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I agree

Copy link
Contributor

@zaqqwerty zaqqwerty left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@zaqqwerty zaqqwerty merged commit b6a174f into google:main Oct 6, 2021
@jaeyoo
Copy link
Contributor

jaeyoo commented Oct 7, 2021

It shows the error. Could you please fix it?

I1006 15:14:34.905480    9052 test_util.py:2311] time(__main__.QNNTest.test_expectation): 8.53s
[       OK ] QNNTest.test_expectation
[ RUN      ] QNNTest.test_init
I1006 15:14:35.065576    9052 test_util.py:2311] time(__main__.QNNTest.test_init): 0.16s
[       OK ] QNNTest.test_init
[ RUN      ] QNNTest.test_pulled_back_circuits
I1006 15:14:50.442585    9052 test_util.py:2311] time(__main__.QNNTest.test_pulled_back_circuits): 15.38s
[       OK ] QNNTest.test_pulled_back_circuits
[ RUN      ] QNNTest.test_pulled_back_expectation
I1006 15:14:50.444001    9052 test_util.py:2311] time(__main__.QNNTest.test_pulled_back_expectation): 0.0s
[       OK ] QNNTest.test_pulled_back_expectation
[ RUN      ] QNNTest.test_pulled_back_sample_basic
I1006 15:14:50.892094    9052 test_util.py:2311] time(__main__.QNNTest.test_pulled_back_sample_basic): 0.45s
[       OK ] QNNTest.test_pulled_back_sample_basic
[ RUN      ] QNNTest.test_sample_basic
I1006 15:15:26.088289    9052 test_util.py:2311] time(__main__.QNNTest.test_sample_basic): 35.19s
[       OK ] QNNTest.test_sample_basic
[ RUN      ] QNNTest.test_sample_uneven
I1006 15:15:40.465752    9052 test_util.py:2311] time(__main__.QNNTest.test_sample_uneven): 14.37s
[       OK ] QNNTest.test_sample_uneven
[ RUN      ] QNNTest.test_session
[  SKIPPED ] QNNTest.test_session
[ RUN      ] QNNTest.test_trainable_variables
I1006 15:15:40.568558    9052 test_util.py:2311] time(__main__.QNNTest.test_trainable_variables): 0.1s
[       OK ] QNNTest.test_trainable_variables
======================================================================
FAIL: test_bit_circuit (__main__.BitCircuitTest)
BitCircuitTest.test_bit_circuit
Confirm correct bit injector circuit creation.
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tests/qnn_test.py", line 55, in test_bit_circuit
    self.assertAllEqual(test_symbols, expected_symbols)
AssertionError: 
Arrays are not equal

not equal where = (array([0, 1, 2]),)
not equal lhs = array(['build_bit_test_bit_0', 'build_bit_test_bit_1',
       'build_bit_test_bit_2'], dtype='not equal rhs = array([build_bit_test_bit_0, build_bit_test_bit_1, build_bit_test_bit_2],
      dtype=object)
Mismatched elements: 3 / 3 (100%)
 x: array([b'build_bit_test_bit_0', b'build_bit_test_bit_1',
       b'build_bit_test_bit_2'], dtype='|S20')
 y: array([build_bit_test_bit_0, build_bit_test_bit_1, build_bit_test_bit_2],
      dtype=object)

----------------------------------------------------------------------
Ran 13 tests in 75.615s

FAILED (failures=1, skipped=2)

@jaeyoo
Copy link
Contributor

jaeyoo commented Oct 7, 2021

NVM. I've fixed it. Closed.

Sign up for free to subscribe to this conversation on GitHub. Already have an account? Sign in.
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants