Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ jobs:
- name: Test with pytest
run: |
cd brainpy
pytest _src/
export IS_GITHUB_ACTIONS=1 && pytest _src/


test_macos:
Expand Down Expand Up @@ -82,7 +82,7 @@ jobs:
- name: Test with pytest
run: |
cd brainpy
pytest _src/
export IS_GITHUB_ACTIONS=1 && pytest _src/


test_windows:
Expand Down Expand Up @@ -113,4 +113,4 @@ jobs:
- name: Test with pytest
run: |
cd brainpy
pytest _src/ -p no:faulthandler
set IS_GITHUB_ACTIONS=1 && pytest _src/
3 changes: 2 additions & 1 deletion brainpy/_src/losses/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,8 @@ def update(self, input, target):


def nll_loss(input, target, reduction: str = 'mean'):
r"""The negative log likelihood loss.
r"""
The negative log likelihood loss.

The negative log likelihood loss. It is useful to train a classification
problem with `C` classes.
Expand Down
35 changes: 6 additions & 29 deletions brainpy/_src/math/compat_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from .interoperability import *
from .ndarray import Array


__all__ = [
'full', 'full_like', 'eye', 'identity', 'diag', 'tri', 'tril', 'triu',
'empty', 'empty_like', 'ones', 'ones_like', 'zeros', 'zeros_like',
Expand Down Expand Up @@ -92,9 +91,8 @@
'ravel_multi_index', 'result_type', 'sort_complex', 'unpackbits', 'delete',

# unique
'add_docstring', 'add_newdoc', 'add_newdoc_ufunc', 'array2string', 'asanyarray',
'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'disp', 'genfromtxt',
'loadtxt', 'info', 'issubclass_', 'place', 'polydiv', 'put', 'putmask', 'safe_eval',
'asanyarray', 'ascontiguousarray', 'asfarray', 'asscalar', 'common_type', 'genfromtxt',
'loadtxt', 'info', 'place', 'polydiv', 'put', 'putmask', 'safe_eval',
'savetxt', 'savez_compressed', 'show_config', 'typename', 'copyto', 'matrix', 'asmatrix', 'mat',

]
Expand Down Expand Up @@ -204,11 +202,12 @@ def ascontiguousarray(a, dtype=None, order=None):
return asarray(a, dtype=dtype, order=order)


def asfarray(a, dtype=np.float_):
def asfarray(a, dtype=None):
if not np.issubdtype(dtype, np.inexact):
dtype = np.float_
dtype = np.float64
return asarray(a, dtype=dtype)


def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array:
del assume_unique
ar1_flat = ravel(ar1)
Expand All @@ -227,6 +226,7 @@ def in1d(ar1, ar2, assume_unique: bool = False, invert: bool = False) -> Array:
else:
return asarray((ar1_flat[:, None] == ar2_flat[None, :]).any(-1))


# Others
# ------
meshgrid = _compatible_with_brainpy_array(jnp.meshgrid)
Expand Down Expand Up @@ -454,7 +454,6 @@ def msort(a):
sometrue = any



def shape(a):
"""
Return the shape of an array.
Expand Down Expand Up @@ -648,7 +647,6 @@ def size(a, axis=None):
finfo = jnp.finfo
iinfo = jnp.iinfo


can_cast = _compatible_with_brainpy_array(jnp.can_cast)
choose = _compatible_with_brainpy_array(jnp.choose)
copy = _compatible_with_brainpy_array(jnp.copy)
Expand Down Expand Up @@ -678,23 +676,6 @@ def size(a, axis=None):
# Unique APIs
# -----------

add_docstring = np.add_docstring
add_newdoc = np.add_newdoc
add_newdoc_ufunc = np.add_newdoc_ufunc


def array2string(a, max_line_width=None, precision=None,
suppress_small=None, separator=' ', prefix="",
style=np._NoValue, formatter=None, threshold=None,
edgeitems=None, sign=None, floatmode=None, suffix="",
legacy=None):
a = as_numpy(a)
return array2string(a, max_line_width=max_line_width, precision=precision,
suppress_small=suppress_small, separator=separator, prefix=prefix,
style=style, formatter=formatter, threshold=threshold,
edgeitems=edgeitems, sign=sign, floatmode=floatmode, suffix=suffix,
legacy=legacy)


def asscalar(a):
return a.item()
Expand Down Expand Up @@ -731,13 +712,9 @@ def common_type(*arrays):
return array_type[0][precision]


disp = np.disp

genfromtxt = lambda *args, **kwargs: asarray(np.genfromtxt(*args, **kwargs))
loadtxt = lambda *args, **kwargs: asarray(np.loadtxt(*args, **kwargs))

info = np.info
issubclass_ = np.issubclass_


def place(arr, mask, vals):
Expand Down
14 changes: 13 additions & 1 deletion brainpy/_src/math/event/tests/test_event_csrmm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,27 @@
# -*- coding: utf-8 -*-

import os
from functools import partial

import jax
import pytest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('gpu')

import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)


# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

seed = 1234


Expand Down
7 changes: 5 additions & 2 deletions brainpy/_src/math/event/tests/test_event_csrmv.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-


import os
from functools import partial

import jax
Expand All @@ -19,6 +18,10 @@
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

seed = 1234

Expand Down
5 changes: 5 additions & 0 deletions brainpy/_src/math/jitconn/tests/test_event_matvec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import os

import jax
import jax.numpy as jnp
Expand All @@ -16,6 +17,10 @@
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

shapes = [(100, 200), (1000, 10)]

Expand Down
12 changes: 10 additions & 2 deletions brainpy/_src/math/jitconn/tests/test_get_weight_matrix.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
import os

import jax.numpy as jnp
import pytest
from absl.testing import parameterized
Expand All @@ -12,8 +14,14 @@
import platform

force_test = False # turn on to force test on windows locally
# if platform.system() == 'Windows' and not force_test:
# pytest.skip('skip windows', allow_module_level=True)
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)


shapes = [
(2, 2),
Expand Down
5 changes: 5 additions & 0 deletions brainpy/_src/math/jitconn/tests/test_matvec.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# -*- coding: utf-8 -*-
import os

import jax
import jax.numpy as jnp
Expand All @@ -16,6 +17,10 @@
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

shapes = [(100, 200), (1000, 10)]

Expand Down
5 changes: 5 additions & 0 deletions brainpy/_src/math/op_register/tests/test_taichi_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,11 @@
import brainpy.math as bm
from brainpy._src.dependency_check import import_taichi

import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

ti = import_taichi(error_if_not_found=False)
if ti is None:
pytest.skip('no taichi', allow_module_level=True)
Expand Down
16 changes: 15 additions & 1 deletion brainpy/_src/math/sparse/tests/test_csrmm.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,28 @@
# -*- coding: utf-8 -*-


import os
from functools import partial

import jax
import pytest
from absl.testing import parameterized

import brainpy as bp
import brainpy.math as bm

# bm.set_platform('gpu')

import platform
force_test = False # turn on to force test on windows locally
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

seed = 1234


Expand Down Expand Up @@ -133,7 +146,8 @@ def test_homo_grad(self, transpose, shape, homo_data):
argnums=0)
r1 = dense_f1(homo_data)
r2 = jax.grad(sum_op(bm.sparse.csrmm))(
bm.asarray([homo_data]), indices, indptr, matrix, shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
bm.asarray([homo_data]), indices, indptr, matrix,
shape=(shape[1], shape[0]) if transpose else (shape[0], shape[1]),
transpose=transpose)

self.assertTrue(bm.allclose(r1, r2))
Expand Down
6 changes: 5 additions & 1 deletion brainpy/_src/math/sparse/tests/test_csrmv.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# -*- coding: utf-8 -*-

import os
from functools import partial

import jax
Expand All @@ -17,6 +17,10 @@
if platform.system() == 'Windows' and not force_test:
pytest.skip('skip windows', allow_module_level=True)

# Skip the test in Github Actions
IS_GITHUB_ACTIONS = os.getenv('IS_GITHUB_ACTIONS', '0')
if IS_GITHUB_ACTIONS == '1':
pytest.skip('Skip the test in Github Actions', allow_module_level=True)

seed = 1234

Expand Down
1 change: 0 additions & 1 deletion brainpy/_src/math/surrogate/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# -*- coding: utf-8 -*-


from .base import *
from ._one_input_new import *
from ._two_inputs import *
11 changes: 10 additions & 1 deletion brainpy/_src/math/surrogate/_one_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

from brainpy._src.math.interoperability import as_jax
from brainpy._src.math.ndarray import Array
from .base import Surrogate

__all__ = [
'sigmoid',
Expand All @@ -32,6 +31,16 @@
]


class Surrogate(object):
"""The base surrograte gradient function."""

def __call__(self, *args, **kwargs):
raise NotImplementedError

def __repr__(self):
return f'{self.__class__.__name__}()'


class _OneInpSurrogate(Surrogate):
def __init__(self, forward_use_surrogate=False):
self.forward_use_surrogate = forward_use_surrogate
Expand Down
3 changes: 2 additions & 1 deletion brainpy/_src/math/surrogate/_one_input_new.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from brainpy._src.math.ndarray import Array

__all__ = [
'Surrogate',
'Sigmoid',
'sigmoid',
'PiecewiseQuadratic',
Expand Down Expand Up @@ -61,7 +62,7 @@ def _heaviside_imp(x, dx):


def _heaviside_batching(args, axes):
return heaviside_p.bind(*args), axes
return heaviside_p.bind(*args), [axes[0]]


def _heaviside_jvp(primals, tangents):
Expand Down
19 changes: 0 additions & 19 deletions brainpy/_src/math/surrogate/base.py

This file was deleted.

6 changes: 0 additions & 6 deletions brainpy/math/compat_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,20 +327,14 @@
sort_complex as sort_complex,
unpackbits as unpackbits,
delete as delete,
add_docstring as add_docstring,
add_newdoc as add_newdoc,
add_newdoc_ufunc as add_newdoc_ufunc,
array2string as array2string,
asanyarray as asanyarray,
ascontiguousarray as ascontiguousarray,
asfarray as asfarray,
asscalar as asscalar,
common_type as common_type,
disp as disp,
genfromtxt as genfromtxt,
loadtxt as loadtxt,
info as info,
issubclass_ as issubclass_,
place as place,
polydiv as polydiv,
put as put,
Expand Down
Loading