-
Notifications
You must be signed in to change notification settings - Fork 15
Enable differentiation through hypernetworks #96
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looking good! Other than my comments on the VQT hypernetwork test, a few additional tests would be good: trainable_variables
tests for MLP, QNN, and QHBM
tests/vqt_test.py
Outdated
@@ -146,6 +146,49 @@ def test_loss_value_x_rot(self): | |||
actual_thetas_grads, expected_thetas_grads, rtol=RTOL) | |||
self.assertAllClose(actual_phis_grads, expected_phis_grads, rtol=RTOL) | |||
|
|||
def test_hyernetwork(self): | |||
qubit = cirq.GridQubit(0, 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If possible, could this be changed to the more general get_random_qhbm
, like in the QMHL test?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done.
tests/vqt_test.py
Outdated
@@ -146,6 +146,49 @@ def test_loss_value_x_rot(self): | |||
actual_thetas_grads, expected_thetas_grads, rtol=RTOL) | |||
self.assertAllClose(actual_phis_grads, expected_phis_grads, rtol=RTOL) | |||
|
|||
def test_hyernetwork(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo in test name
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Great work! Last thing, could we add a test of the values of the hypernetwork gradient? Can be for a simple scalar hypernetwork.
qhbmlib/vqt.py
Outdated
return grad_vars | ||
return grad_vars, [tf.zeros_like(g) for g in grad_vars] | ||
return grad_qhbm | ||
return grad_qhbm, [tf.zeros_like(grad) for grad in grad_qhbm] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: better not to have a collision with the function name, can this go back to just g
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree
qhbmlib/qmhl.py
Outdated
return grad_vars | ||
return grad_vars, [tf.zeros_like(g) for g in grad_vars] | ||
return grad_qhbm | ||
return grad_qhbm, [tf.zeros_like(grad) for grad in grad_qhbm] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: better not to have a collision with the function name, can this go back to just g
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah I agree
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
It shows the error. Could you please fix it?
|
NVM. I've fixed it. Closed. |
#93