Commit: Revert commits
Routhleck committed May 12, 2024
1 parent 4e3151e commit 85a51df
Showing 11 changed files with 73 additions and 101 deletions.
2 changes: 0 additions & 2 deletions brainpy/_src/dnn/tests/test_linear.py
@@ -7,8 +7,6 @@
 
 from brainpy._src.dependency_check import import_taichi
 
-pytest.skip('Remove customize op tests', allow_module_level=True)
-
 if import_taichi(error_if_not_found=False) is None:
   pytest.skip('no taichi', allow_module_level=True)

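The same two-line removal repeats across the test modules in this commit: the unconditional module-level skip is deleted and only the taichi guard is kept. As a standalone sketch of the surviving pattern (the import path and the guard come from the diff; the enclosing test module is implied):

import pytest

from brainpy._src.dependency_check import import_taichi

# A pytest.skip() issued at import time aborts collection of the whole
# module; allow_module_level=True is required outside a test function.
if import_taichi(error_if_not_found=False) is None:
  pytest.skip('no taichi', allow_module_level=True)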
2 changes: 0 additions & 2 deletions brainpy/_src/dnn/tests/test_mode.py
@@ -6,8 +6,6 @@
 import brainpy.math as bm
 from brainpy._src.dependency_check import import_taichi
 
-pytest.skip('Remove customize op tests', allow_module_level=True)
-
 if import_taichi(error_if_not_found=False) is None:
   pytest.skip('no taichi', allow_module_level=True)

4 changes: 1 addition & 3 deletions brainpy/_src/dyn/projections/tests/test_STDP.py
@@ -6,9 +6,7 @@
 
 import brainpy as bp
 import brainpy.math as bm
-from brainpy._src.dependency_check import import_taichi
-
-pytest.skip('Remove customize op tests', allow_module_level=True)
+from brainpy._src.dependency_check import import_taichi
 
 if import_taichi(error_if_not_found=False) is None:
   pytest.skip('no taichi', allow_module_level=True)
2 changes: 0 additions & 2 deletions brainpy/_src/math/event/tests/test_event_csrmv.py
@@ -11,8 +11,6 @@
 import brainpy.math as bm
 from brainpy._src.dependency_check import import_taichi
 
-pytest.skip('Remove customize op tests', allow_module_level=True)
-
 if import_taichi(error_if_not_found=False) is None:
   pytest.skip('no taichi', allow_module_level=True)

2 changes: 0 additions & 2 deletions brainpy/_src/math/jitconn/tests/test_event_matvec.py
@@ -8,8 +8,6 @@
 import brainpy.math as bm
 from brainpy._src.dependency_check import import_taichi
 
-pytest.skip('Remove customize op tests', allow_module_level=True)
-
 if import_taichi(error_if_not_found=False) is None:
   pytest.skip('no taichi', allow_module_level=True)

2 changes: 0 additions & 2 deletions brainpy/_src/math/jitconn/tests/test_matvec.py
@@ -8,8 +8,6 @@
 import brainpy.math as bm
 from brainpy._src.dependency_check import import_taichi
 
-pytest.skip('Remove customize op tests', allow_module_level=True)
-
 if import_taichi(error_if_not_found=False) is None:
   pytest.skip('no taichi', allow_module_level=True)

35 changes: 8 additions & 27 deletions brainpy/_src/math/object_transform/autograd.py
@@ -884,12 +884,8 @@ def hessian(
     func: Callable,
     grad_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
     argnums: Optional[Union[int, Sequence[int]]] = None,
     return_value: bool = False,
     has_aux: Optional[bool] = None,
     holomorphic=False,
-
-    # deprecated
-    dyn_vars: Optional[Union[Variable, Sequence[Variable], Dict[str, Variable]]] = None,
-    child_objs: Optional[Union[BrainPyObject, Sequence[BrainPyObject], Dict[str, BrainPyObject]]] = None,
 ) -> ObjectTransform:
   """Hessian of ``func`` as a dense array.
@@ -916,29 +912,14 @@ def hessian(
   obj: ObjectTransform
     The transformed object.
   """
-  child_objs = check.is_all_objs(child_objs, out_as='dict')
-  dyn_vars = check.is_all_vars(dyn_vars, out_as='dict')
-
-  return jacfwd(jacrev(func,
-                       dyn_vars=dyn_vars,
-                       child_objs=child_objs,
-                       grad_vars=grad_vars,
-                       argnums=argnums,
-                       holomorphic=holomorphic),
-                dyn_vars=dyn_vars,
-                child_objs=child_objs,
-                grad_vars=grad_vars,
-                argnums=argnums,
-                holomorphic=holomorphic,
-                return_value=return_value)
-
-  # return GradientTransformPreserveTree(target=func,
-  #                                      transform=jax.hessian,
-  #                                      grad_vars=grad_vars,
-  #                                      argnums=argnums,
-  #                                      has_aux=False if has_aux is None else has_aux,
-  #                                      transform_setting=dict(holomorphic=holomorphic),
-  #                                      return_value=False)
+  return GradientTransformPreserveTree(target=func,
+                                       transform=jax.hessian,
+                                       grad_vars=grad_vars,
+                                       argnums=argnums,
+                                       has_aux=False if has_aux is None else has_aux,
+                                       transform_setting=dict(holomorphic=holomorphic),
+                                       return_value=False)
 
 
 def functional_vector_grad(func, argnums=0, return_value=False, has_aux=False):
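The hunk above swaps a jacfwd(jacrev(...)) composition for a direct wrap of jax.hessian via GradientTransformPreserveTree and drops the deprecated dyn_vars/child_objs plumbing. A minimal usage sketch, assuming bm.hessian with argnums=0 mirrors jax.hessian on a plain array argument (the function f and the input x are illustrative, not taken from the diff):

import jax
import jax.numpy as jnp
import brainpy.math as bm

def f(x):
  # scalar function of a length-3 vector
  return jnp.sum(x ** 3)

x = jnp.arange(1.0, 4.0)

# Hessian of f at x: a 3x3 matrix, here diagonal with entries 6 * x.
h = bm.hessian(f, argnums=0)(x)
assert bm.allclose(h, jax.hessian(f)(x))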
94 changes: 47 additions & 47 deletions brainpy/_src/math/object_transform/tests/test_autograd.py
@@ -1172,52 +1172,52 @@ def f(a, b):



-class TestHessian(unittest.TestCase):
-  def test_hessian5(self):
-    bm.set_mode(bm.training_mode)
-
-    class RNN(bp.DynamicalSystem):
-      def __init__(self, num_in, num_hidden):
-        super(RNN, self).__init__()
-        self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
-        self.out = bp.dnn.Dense(num_hidden, 1)
-
-      def update(self, x):
-        return self.out(self.rnn(x))
-
-    # define the loss function
-    def lossfunc(inputs, targets):
-      runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
-      predicts = runner.predict(inputs)
-      loss = bp.losses.mean_squared_error(predicts, targets)
-      return loss
-
-    model = RNN(1, 2)
-    data_x = bm.random.rand(1, 1000, 1)
-    data_y = data_x + bm.random.randn(1, 1000, 1)
-
-    bp.reset_state(model, 1)
-    losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
-    hess_matrix = losshess(data_x, data_y)
-
-    weights = model.train_vars().unique()
-
-    # define the loss function
-    def loss_func_for_jax(weight_vals, inputs, targets):
-      for k, v in weight_vals.items():
-        weights[k].value = v
-      runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
-      predicts = runner.predict(inputs)
-      loss = bp.losses.mean_squared_error(predicts, targets)
-      return loss
-
-    bp.reset_state(model, 1)
-    jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)
-
-    for k, v in hess_matrix.items():
-      for kk, vv in v.items():
-        self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))
-
-    bm.clear_buffer_memory()
+# class TestHessian(unittest.TestCase):
+#   def test_hessian5(self):
+#     bm.set_mode(bm.training_mode)
+#
+#     class RNN(bp.DynamicalSystem):
+#       def __init__(self, num_in, num_hidden):
+#         super(RNN, self).__init__()
+#         self.rnn = bp.dyn.RNNCell(num_in, num_hidden, train_state=True)
+#         self.out = bp.dnn.Dense(num_hidden, 1)
+#
+#       def update(self, x):
+#         return self.out(self.rnn(x))
+#
+#     # define the loss function
+#     def lossfunc(inputs, targets):
+#       runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
+#       predicts = runner.predict(inputs)
+#       loss = bp.losses.mean_squared_error(predicts, targets)
+#       return loss
+#
+#     model = RNN(1, 2)
+#     data_x = bm.random.rand(1, 1000, 1)
+#     data_y = data_x + bm.random.randn(1, 1000, 1)
+#
+#     bp.reset_state(model, 1)
+#     losshess = bm.hessian(lossfunc, grad_vars=model.train_vars())
+#     hess_matrix = losshess(data_x, data_y)
+#
+#     weights = model.train_vars().unique()
+#
+#     # define the loss function
+#     def loss_func_for_jax(weight_vals, inputs, targets):
+#       for k, v in weight_vals.items():
+#         weights[k].value = v
+#       runner = bp.DSTrainer(model, progress_bar=False, numpy_mon_after_run=False)
+#       predicts = runner.predict(inputs)
+#       loss = bp.losses.mean_squared_error(predicts, targets)
+#       return loss
+#
+#     bp.reset_state(model, 1)
+#     jax_hessian = jax.hessian(loss_func_for_jax, argnums=0)({k: v.value for k, v in weights.items()}, data_x, data_y)
+#
+#     for k, v in hess_matrix.items():
+#       for kk, vv in v.items():
+#         self.assertTrue(bm.allclose(vv, jax_hessian[k][kk], atol=1e-4))
+#
+#     bm.clear_buffer_memory()


26 changes: 15 additions & 11 deletions brainpy/_src/math/object_transform/tests/test_base.py
@@ -237,12 +237,15 @@ def test1(self):
     hh = bp.dyn.HH(1)
     hh.reset()
 
-    tree = jax.tree.structure(hh)
-    leaves = jax.tree.leaves(hh)
+    # tree = jax.tree.structure(hh)
+    # leaves = jax.tree.leaves(hh)
+    tree = jax.tree_structure(hh)
+    leaves = jax.tree_leaves(hh)
 
     print(tree)
     print(leaves)
-    print(jax.tree.unflatten(tree, leaves))
+    # print(jax.tree.unflatten(tree, leaves))
+    print(jax.tree_unflatten(tree, leaves))
     print()
@@ -278,17 +281,18 @@ def update(self, x):
 
     def not_close(x, y):
      assert not bm.allclose(x, y)
 
     def all_close(x, y):
      assert bm.allclose(x, y)
 
-    jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
+    # jax.tree.map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
+    jax.tree_map(all_close, all_states, variables, is_leaf=bm.is_bp_array)
 
-    random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
-    jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
+    # random_state = jax.tree.map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
+    # jax.tree.map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
+    random_state = jax.tree_map(bm.random.rand_like, all_states, is_leaf=bm.is_bp_array)
+    jax.tree_map(not_close, random_state, variables, is_leaf=bm.is_bp_array)
 
     obj.load_state_dict(random_state)
-    jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
-
-
-
+    # jax.tree.map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
+    jax.tree_map(all_close, random_state, variables, is_leaf=bm.is_bp_array)
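The edits in this file step back from the jax.tree.* namespace to the legacy top-level aliases (jax.tree_structure, jax.tree_map, jax.tree_unflatten), which keeps the tests importable on older JAX releases where the jax.tree module does not exist. A version-tolerant sketch (the try/except fallback is an assumption, not part of the diff):

import jax

# Prefer the newer jax.tree namespace when present; fall back to the
# legacy alias on older JAX releases that predate the jax.tree module.
try:
  tree_map = jax.tree.map
except AttributeError:
  tree_map = jax.tree_map

nested = {'a': [1, 2], 'b': (3,)}
print(tree_map(lambda v: v * 2, nested))  # {'a': [2, 4], 'b': (6,)}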
3 changes: 0 additions & 3 deletions brainpy/_src/math/sparse/tests/test_csrmv.py
@@ -9,9 +9,6 @@
 import brainpy as bp
 import brainpy.math as bm
 from brainpy._src.dependency_check import import_taichi
-
-pytest.skip('Remove customize op tests', allow_module_level=True)
-
 if import_taichi(error_if_not_found=False) is None:
   pytest.skip('no taichi', allow_module_level=True)

2 changes: 2 additions & 0 deletions requirements-dev.txt
@@ -8,6 +8,8 @@ tqdm
 pathos
 taichi==1.7.0
 numba
+braincore
+braintools


# test requirements
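The diff adds braincore and braintools as unpinned entries. A quick availability check (a sketch; only the package names come from the diff, and nothing about their APIs is assumed):

import importlib.util

# Report whether each new dev dependency resolves to an installed package.
for pkg in ('braincore', 'braintools'):
  found = importlib.util.find_spec(pkg) is not None
  print(pkg, 'installed' if found else 'missing')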
