Revert "fix issue #661 (#662)"
This reverts commit 4bd1898.
Routhleck committed May 12, 2024
1 parent 4bd1898 commit d795517
Showing 3 changed files with 62 additions and 55 deletions.
94 changes: 47 additions & 47 deletions brainpy/_src/math/object_transform/tests/test_autograd.py
@@ -1172,52 +1172,52 @@ def f(a, b):



-class TestHessian(unittest.TestCase):
-  def test_hessian5(self):
-    bm.set_mode(bm.training_mode)
-
-    class RNN(bp.DynamicalSystem):
-      def __init__(self, num_in, num_hidden):
-        super(RNN, self).__init__()
-        self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
-        self.out = bp.dnn.Dense(num_hidden, 1)
-
-      def update(self, x):
-        return self.out(self.rnn(x))
-
-    # define the loss function
-    def lossfunc(inputs, targets):
-      runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
-      predicts = runner.predict(inputs)
-      loss = bp.losses.mean_squared_error(predicts, targets)
-      return loss
-
-    model = RNN(1, 2)
-    data_x = bm.random.rand(1, 1000, 1)
-    data_y = data_x + bm.random.randn(1, 1000, 1)
-
-    bp.reset_state(model, 1)
-    losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
-    hess_matrix = losshess(data_x, data_y)
-
-    weights = model.train_vars().unique()
-
-    # define the loss function
-    def loss_func_for_jax(weight_vals, inputs, targets):
-      for k, v in weight_vals.items():
-        weights[k].value = v
-      runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
-      predicts = runner.predict(inputs)
-      loss = bp.losses.mean_squared_error(predicts, targets)
-      return loss
-
-    bp.reset_state(model, 1)
-    jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)
-
-    for k, v in hess_matrix.items():
-      for kk, vv in v.items():
-        self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))
-
-    bm.clear_buffer_memory()
+# class TestHessian(unittest.TestCase):
+#   def test_hessian5(self):
+#     bm.set_mode(bm.training_mode)
+#
+#     class RNN(bp.DynamicalSystem):
+#       def __init__(self, num_in, num_hidden):
+#         super(RNN, self).__init__()
+#         self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
+#         self.out = bp.dnn.Dense(num_hidden, 1)
+#
+#       def update(self, x):
+#         return self.out(self.rnn(x))
+#
+#     # define the loss function
+#     def lossfunc(inputs, targets):
+#       runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
+#       predicts = runner.predict(inputs)
+#       loss = bp.losses.mean_squared_error(predicts, targets)
+#       return loss
+#
+#     model = RNN(1, 2)
+#     data_x = bm.random.rand(1, 1000, 1)
+#     data_y = data_x + bm.random.randn(1, 1000, 1)
+#
+#     bp.reset_state(model, 1)
+#     losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
+#     hess_matrix = losshess(data_x, data_y)
+#
+#     weights = model.train_vars().unique()
+#
+#     # define the loss function
+#     def loss_func_for_jax(weight_vals, inputs, targets):
+#       for k, v in weight_vals.items():
+#         weights[k].value = v
+#       runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
+#       predicts = runner.predict(inputs)
+#       loss = bp.losses.mean_squared_error(predicts, targets)
+#       return loss
+#
+#     bp.reset_state(model, 1)
+#     jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)
+#
+#     for k, v in hess_matrix.items():
+#       for kk, vv in v.items():
+#         self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))
+#
+#     bm.clear_buffer_memory()
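
The test disabled above cross-checks bm.hessian against jax.hessian on the same loss. A minimal, self-contained sketch of that cross-check idea on a plain least-squares loss (no BrainPy involved; only jax is assumed):

import jax
import jax.numpy as jnp

def loss(w, x, y):
    # mean squared error of a linear model; its Hessian in w is 2 * x.T @ x / n
    return jnp.mean((x @ w - y) ** 2)

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

hess = jax.hessian(loss, argnums=0)(w, x, y)  # Hessian via autodiff
expected = 2.0 * (x.T @ x) / x.shape[0]       # analytic Hessian of the MSE loss
assert jnp.allclose(hess, expected, atol=1e-4)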


21 changes: 14 additions & 7 deletions brainpy/_src/math/object_transform/tests/test_base.py
@@ -237,12 +237,15 @@ def test1(self):
    hh = bp.dyn.HH(1)
    hh.reset()

-    tree = jax.tree.structure(hh)
-    leaves = jax.tree.leaves(hh)
+    tree = jax.tree_structure(hh)
+    leaves = jax.tree_leaves(hh)
+    # tree = jax.tree.structure(hh)
+    # leaves = jax.tree.leaves(hh)

    print(tree)
    print(leaves)
-    print(jax.tree.unflatten(tree, leaves))
+    print(jax.tree_unflatten(tree, leaves))
+    # print(jax.tree.unflatten(tree, leaves))
    print()
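
This hunk swaps the jax.tree.* namespace (added in newer JAX releases) back to the older top-level aliases, presumably for compatibility with the JAX version the project currently targets. The jax.tree_util spelling works in both old and new releases; a minimal round-trip sketch on a plain pytree:

import jax

pytree = {'a': 1.0, 'b': [2.0, 3.0]}
treedef = jax.tree_util.tree_structure(pytree)  # container shape, no leaf values
leaves = jax.tree_util.tree_leaves(pytree)      # flat list of leaf values
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
assert rebuilt == pytree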


@@ -281,13 +284,17 @@ def not_close(x, y):
    def all_close(x, y):
      assert bm.allclose(x, y)

-    jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
+    jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
+    # jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)

-    random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
-    jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
+    random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
+    jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
+    # random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
+    # jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)

    obj.load_state_dict(random_state)
-    jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
+    jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
+    # jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
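
The same compatibility concern applies to mapping over pytrees: jax.tree_map is the old top-level spelling (later deprecated), jax.tree.map the new one, and jax.tree_util.tree_map is available under both. A small sketch of the mapping pattern the test relies on:

import jax

def square(x):
    return x * x

state = {'w': 2.0, 'b': [1.0, 3.0]}
mapped = jax.tree_util.tree_map(square, state)  # applies square to every leaf
print(mapped)  # {'b': [1.0, 9.0], 'w': 4.0}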



2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -6,7 +6,7 @@ matplotlib
msgpack
tqdm
pathos
-taichi
+taichi==1.7.0
numba
braincore
braintools
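
Pinning taichi to 1.7.0 fixes the development environment to a known-good release instead of tracking the latest version. A quick sanity check of what is actually installed (assuming, as is conventional, that the package exposes __version__):

import taichi as ti
print(ti.__version__)  # expected to correspond to the pinned 1.7.0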
