From 324aec3160e20db01c9227f9a5b80038ea5cfe14 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 2 Oct 2022 11:22:49 +0800 Subject: [PATCH 1/4] Support initializing a Variable by data shape --- brainpy/math/jaxarray.py | 42 +++++++++++++++++++++++++---- brainpy/math/tests/test_jaxarray.py | 11 ++++++++ 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/brainpy/math/jaxarray.py b/brainpy/math/jaxarray.py index 409c400bd..f9003df16 100644 --- a/brainpy/math/jaxarray.py +++ b/brainpy/math/jaxarray.py @@ -885,10 +885,42 @@ def __jax_array__(self): class Variable(JaxArray): """The pointer to specify the dynamical variable. + + Initializing an instance of ``Variable`` by two ways: + + >>> import brainpy.math as bm + >>> # 1. init a Variable by the concreate data + >>> v1 = bm.Variable(bm.zeros(10)) + >>> # 2. init a Variable by the data shape + >>> v2 = bm.Variable(10) + + Note that when initializing a `Variable` by the data shape, + all values in this `Variable` will be initialized as zeros. + + Parameters + ---------- + value_or_size: Shape, Array + The value or the size of the value. + dtype: + The type of the data. + batch_axis: optional, int + The batch axis. """ __slots__ = ('_value', '_batch_axis') - def __init__(self, value, dtype=None, batch_axis: int = None): + def __init__( + self, + value_or_size, + dtype=None, + batch_axis: int = None + ): + if isinstance(value_or_size, int): + value = jnp.zeros(value_or_size, dtype=dtype) + elif isinstance(value_or_size, (tuple, list)) and all([isinstance(s, int) for s in value_or_size]): + value = jnp.zeros(value_or_size, dtype=dtype) + else: + value = value_or_size + super(Variable, self).__init__(value, dtype=dtype) # check batch axis @@ -1464,8 +1496,8 @@ class TrainVar(Variable): """ __slots__ = ('_value', '_batch_axis') - def __init__(self, value, dtype=None, batch_axis: int = None): - super(TrainVar, self).__init__(value, dtype=dtype, batch_axis=batch_axis) + def __init__(self, value_or_size, dtype=None, batch_axis: int = None): + super(TrainVar, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis) class Parameter(Variable): @@ -1473,8 +1505,8 @@ class Parameter(Variable): """ __slots__ = ('_value', '_batch_axis') - def __init__(self, value, dtype=None, batch_axis: int = None): - super(Parameter, self).__init__(value, dtype=dtype, batch_axis=batch_axis) + def __init__(self, value_or_size, dtype=None, batch_axis: int = None): + super(Parameter, self).__init__(value_or_size, dtype=dtype, batch_axis=batch_axis) register_pytree_node(JaxArray, diff --git a/brainpy/math/tests/test_jaxarray.py b/brainpy/math/tests/test_jaxarray.py index e4d6b3059..2f6a9c10e 100644 --- a/brainpy/math/tests/test_jaxarray.py +++ b/brainpy/math/tests/test_jaxarray.py @@ -40,3 +40,14 @@ def test_none(self): ee = a + e +class TestVariable(unittest.TestCase): + def test_variable_init(self): + self.assertTrue( + bm.array_equal(bm.Variable(bm.zeros(10)), + bm.Variable(10)) + ) + bm.random.seed(123) + self.assertTrue( + not bm.array_equal(bm.Variable(bm.random.rand(10)), + bm.Variable(10)) + ) From 409761fd42937ef2f1a565cfb7cb93100f6d6948 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 2 Oct 2022 13:21:15 +0800 Subject: [PATCH 2/4] fix bugs on GapJunction model --- brainpy/dyn/synapses/gap_junction.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/brainpy/dyn/synapses/gap_junction.py b/brainpy/dyn/synapses/gap_junction.py index 1b4027042..8d94c633f 100644 --- a/brainpy/dyn/synapses/gap_junction.py +++ 
b/brainpy/dyn/synapses/gap_junction.py @@ -29,8 +29,8 @@ def __init__( conn=conn, name=name) # checking - self.check_pre_attrs('V', 'spike') - self.check_post_attrs('V', 'input', 'spike') + self.check_pre_attrs('V') + self.check_post_attrs('V', 'input') # assert isinstance(self.output, _NullSynOut) # assert isinstance(self.stp, _NullSynSTP) From 10d25bb1caaf190aef3f689994b388e953239817 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 2 Oct 2022 13:21:55 +0800 Subject: [PATCH 3/4] add example of Fazli_2022_gj_coupled_bursting_pituitary_cells.py --- ...022_gj_coupled_bursting_pituitary_cells.py | 101 ++++++++++++++++++ 1 file changed, 101 insertions(+) create mode 100644 examples/simulation/Fazli_2022_gj_coupled_bursting_pituitary_cells.py diff --git a/examples/simulation/Fazli_2022_gj_coupled_bursting_pituitary_cells.py b/examples/simulation/Fazli_2022_gj_coupled_bursting_pituitary_cells.py new file mode 100644 index 000000000..89bad5eab --- /dev/null +++ b/examples/simulation/Fazli_2022_gj_coupled_bursting_pituitary_cells.py @@ -0,0 +1,101 @@ +# -*- coding: utf-8 -*- + + +"""" +Implementation of the paper: + +- Fazli, Mehran, and Richard Bertram. "Network Properties of Electrically + Coupled Bursting Pituitary Cells." Frontiers in Endocrinology 13 (2022). +""" + +import brainpy as bp +import brainpy.math as bm + + +class PituitaryCell(bp.NeuGroup): + def __init__(self, size, name=None): + super(PituitaryCell, self).__init__(size, name=name) + + # parameter values + self.vn = -5 + self.kc = 0.12 + self.ff = 0.005 + self.vca = 60 + self.vk = -75 + self.vl = -50.0 + self.gk = 2.5 + self.cm = 5 + self.gbk = 1 + self.gca = 2.1 + self.gsk = 2 + self.vm = -20 + self.vb = -5 + self.sn = 10 + self.sm = 12 + self.sbk = 2 + self.taun = 30 + self.taubk = 5 + self.ks = 0.4 + self.alpha = 0.0015 + self.gl = 0.2 + + # variables + self.V = bm.Variable(bm.random.random(self.num) * -90 + 20) + self.n = bm.Variable(bm.random.random(self.num) / 2) + self.b = bm.Variable(bm.random.random(self.num) / 2) + self.c = bm.Variable(bm.random.random(self.num)) + self.input = bm.Variable(self.num) + + # integrators + self.integral = bp.odeint(bp.JointEq(self.dV, self.dn, self.dc, self.db), method='exp_euler') + + def dn(self, n, t, V): + ninf = 1 / (1 + bm.exp((self.vn - V) / self.sn)) + return (ninf - n) / self.taun + + def db(self, b, t, V): + bkinf = 1 / (1 + bm.exp((self.vb - V) / self.sbk)) + return (bkinf - b) / self.taubk + + def dc(self, c, t, V): + minf = 1 / (1 + bm.exp((self.vm - V) / self.sm)) + ica = self.gca * minf * (V - self.vca) + return -self.ff * (self.alpha * ica + self.kc * c) + + def dV(self, V, t, n, b, c): + minf = 1 / (1 + bm.exp((self.vm - V) / self.sm)) + cinf = c ** 2 / (c ** 2 + self.ks * self.ks) + ica = self.gca * minf * (V - self.vca) + isk = self.gsk * cinf * (V - self.vk) + ibk = self.gbk * b * (V - self.vk) + ikdr = self.gk * n * (V - self.vk) + il = self.gl * (V - self.vl) + return -(ica + isk + ibk + ikdr + il + self.input) / self.cm + + def update(self, tdi, x=None): + V, n, c, b = self.integral(self.V.value, self.n.value, self.c.value, self.b.value, tdi.t, tdi.dt) + self.V.value = V + self.n.value = n + self.c.value = c + self.b.value = b + + def clear_input(self): + self.input.value = bm.zeros_like(self.input) + + +class PituitaryNetwork(bp.Network): + def __init__(self, num, gc): + super(PituitaryNetwork, self).__init__() + + self.N = PituitaryCell(num) + self.gj = bp.synapses.GapJunction(self.N, self.N, bp.conn.All2All(include_self=False), g_max=gc) + + +if __name__ == 
'__main__': + net = PituitaryNetwork(2, 0.002) + runner = bp.DSRunner(net, monitors={'V': net.N.V}, dt=0.5) + runner.run(10 * 1e3) + + fig, gs = bp.visualize.get_figure(1, 1, 6, 10) + fig.add_subplot(gs[0, 0]) + bp.visualize.line_plot(runner.mon.ts, runner.mon.V, plot_ids=(0, 1), show=True) From 359dcbeddeda9241b297b50b0b23ff649b7fd143 Mon Sep 17 00:00:00 2001 From: chaoming Date: Sun, 2 Oct 2022 13:22:30 +0800 Subject: [PATCH 4/4] update brainpylib --- brainpy/dyn/base.py | 4 +- brainpy/dyn/layers/dropout.py | 4 -- extensions/brainpylib/atomic_sum.py | 3 +- extensions/brainpylib/event_sum.py | 60 ++++++++++++++--------------- extensions/brainpylib/utils.py | 18 +++++++++ extensions/lib/event_sum_cpu.cc | 58 +++++++++++++++++++++++----- extensions/lib/event_sum_gpu.cu | 35 ++++++----------- setup.py | 9 ++++- 8 files changed, 120 insertions(+), 71 deletions(-) create mode 100644 extensions/brainpylib/utils.py diff --git a/brainpy/dyn/base.py b/brainpy/dyn/base.py index 27e31cab5..82d43fb05 100644 --- a/brainpy/dyn/base.py +++ b/brainpy/dyn/base.py @@ -756,8 +756,8 @@ def __init__( def __repr__(self): names = self.__class__.__name__ - return (f'{names}(name={self.name}, mode={self.mode}, ' - f'{" " * len(names)} pre={self.pre}, ' + return (f'{names}(name={self.name}, mode={self.mode}, \n' + f'{" " * len(names)} pre={self.pre}, \n' f'{" " * len(names)} post={self.post})') def check_pre_attrs(self, *attrs): diff --git a/brainpy/dyn/layers/dropout.py b/brainpy/dyn/layers/dropout.py index 542844006..0f6ae2b73 100644 --- a/brainpy/dyn/layers/dropout.py +++ b/brainpy/dyn/layers/dropout.py @@ -15,10 +15,6 @@ class Dropout(DynamicalSystem): In training, to compensate for the fraction of input values dropped (`rate`), all surviving values are multiplied by `1 / (1 - rate)`. - The parameter `shared_axes` allows to specify a list of axes on which - the mask will be shared: we will use size 1 on those axes for dropout mask - and broadcast it. Sharing reduces randomness, but can save memory. - This layer is active only during training (`mode='train'`). In other circumstances it is a no-op. diff --git a/extensions/brainpylib/atomic_sum.py b/extensions/brainpylib/atomic_sum.py index acd18f188..6f6246d12 100644 --- a/extensions/brainpylib/atomic_sum.py +++ b/extensions/brainpylib/atomic_sum.py @@ -115,7 +115,8 @@ def _atomic_sum_translation(c, values, pre_ids, post_ids, *, post_num, platform= shape_with_layout=x_shape(np.dtype(values_dtype), (post_num,), (0,)), ) elif platform == 'gpu': - if gpu_ops is None: raise ValueError('Cannot find compiled gpu wheels.') + if gpu_ops is None: + raise ValueError('Cannot find compiled gpu wheels.') opaque = gpu_ops.build_atomic_sum_descriptor(conn_size, post_num) if values_dim[0] != 1: diff --git a/extensions/brainpylib/event_sum.py b/extensions/brainpylib/event_sum.py index a3cf01d3b..433bf8373 100644 --- a/extensions/brainpylib/event_sum.py +++ b/extensions/brainpylib/event_sum.py @@ -7,14 +7,17 @@ from functools import partial +from typing import Union, Tuple import jax.numpy as jnp import numpy as np -from jax import core +from jax import core, dtypes from jax.abstract_arrays import ShapedArray from jax.interpreters import xla, batching from jax.lax import scan from jax.lib import xla_client +from .utils import GPUOperatorNotFound + try: from . 
import gpu_ops except ImportError: @@ -26,7 +29,10 @@ _event_sum_prim = core.Primitive("event_sum") -def event_sum(events, pre2post, post_num, values): +def event_sum(events: jnp.ndarray, + pre2post: Tuple[jnp.ndarray, jnp.ndarray], + post_num: int, + values: Union[float, jnp.ndarray]): # events if events.dtype != jnp.bool_: raise ValueError(f'"events" must be a vector of bool, while we got {events.dtype}') @@ -39,17 +45,16 @@ def event_sum(events, pre2post, post_num, values): if indices.dtype != indptr.dtype: raise ValueError(f"The dtype of pre2post[0] must be equal to that of pre2post[1], " f"while we got {(indices.dtype, indptr.dtype)}") - if indices.dtype not in [jnp.uint32, jnp.uint64]: - raise ValueError(f'The dtype of pre2post must be uint32 or uint64, while we got {indices.dtype}') + if indices.dtype not in [jnp.uint32, jnp.uint64, jnp.int32, jnp.int64]: + raise ValueError(f'The dtype of pre2post must be integer, while we got {indices.dtype}') # output value - values = jnp.asarray([values]) - if values.dtype not in [jnp.float32, jnp.float64]: - raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {values.dtype}.') - if values.size not in [1, indices.size]: + dtype = values.dtype if isinstance(values, jnp.ndarray) else dtypes.canonicalize_dtype(type(values)) + if dtype not in [jnp.float32, jnp.float64]: + raise ValueError(f'The dtype of "values" must be float32 or float64, while we got {dtype}.') + if np.size(values) not in [1, indices.size]: raise ValueError(f'The size of "values" must be 1 (a scalar) or len(pre2post[0]) (a vector), ' - f'while we got {values.size} != 1 != {indices.size}') - values = values.flatten() + f'while we got {np.size(values)} != 1 != {indices.size}') # bind operator return _event_sum_prim.bind(events, indices, indptr, values, post_num=post_num) @@ -58,12 +63,8 @@ def _event_sum_abstract(events, indices, indptr, values, *, post_num): return ShapedArray(dtype=values.dtype, shape=(post_num,)) -_event_sum_prim.def_abstract_eval(_event_sum_abstract) -_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim)) - - def _event_sum_translation(c, events, indices, indptr, values, *, post_num, platform="cpu"): - # The pre/post shape + # The shape of pre/post pre_size = np.array(c.get_shape(events).dimensions()[0], dtype=np.uint32) _pre_shape = x_shape(np.dtype(np.uint32), (), ()) _post_shape = x_shape(np.dtype(np.uint32), (), ()) @@ -71,21 +72,18 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, plat # The indices shape indices_shape = c.get_shape(indices) Itype = indices_shape.element_type() - assert Itype in [np.uint32, np.uint64] # The value shape values_shape = c.get_shape(values) Ftype = values_shape.element_type() - assert Ftype in [np.float32, np.float64] values_dim = values_shape.dimensions() # We dispatch a different call depending on the dtype - f_type = b'_f32' if Ftype == np.float32 else b'_f64' - i_type = b'_i32' if Itype == np.uint32 else b'_i64' + f_type = b'_f32' if Ftype in np.float32 else b'_f64' + i_type = b'_i32' if Itype in [np.uint32, np.int32] else b'_i64' - # And then the following is what changes between the GPU and CPU if platform == "cpu": - v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter' + v_type = b'_event_sum_homo' if len(values_dim) == 0 else b'_event_sum_heter' return x_ops.CustomCallWithLayout( c, platform.encode() + v_type + f_type + i_type, @@ -103,9 +101,12 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, 
plat c.get_shape(values)), shape_with_layout=x_shape(np.dtype(Ftype), (post_num,), (0,)), ) + + # GPU platform elif platform == 'gpu': if gpu_ops is None: - raise ValueError('Cannot find compiled gpu wheels.') + raise GPUOperatorNotFound('event_sum') + v_type = b'_event_sum_homo' if values_dim[0] == 1 else b'_event_sum_heter' opaque = gpu_ops.build_event_sum_descriptor(pre_size, post_num) return x_ops.CustomCallWithLayout( @@ -127,11 +128,7 @@ def _event_sum_translation(c, events, indices, indptr, values, *, post_num, plat raise ValueError("Unsupported platform, we only support 'cpu' or 'gpu'") -xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu") -xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu") - - -def _event_sum_batch(args, axes): +def _event_sum_batch(args, axes, *, post_num): batch_axes, batch_args, non_batch_args = [], {}, {} for ax_i, ax in enumerate(axes): if ax is None: @@ -143,19 +140,22 @@ def _event_sum_batch(args, axes): def f(_, x): pars = tuple([(x[f'ax{i}'] if i in batch_axes else non_batch_args[f'ax{i}']) for i in range(len(axes))]) - return 0, _event_sum_prim.bind(*pars) + return 0, _event_sum_prim.bind(*pars, post_num=post_num) + _, outs = scan(f, 0, batch_args) return outs, 0 +_event_sum_prim.def_abstract_eval(_event_sum_abstract) +_event_sum_prim.def_impl(partial(xla.apply_primitive, _event_sum_prim)) batching.primitive_batchers[_event_sum_prim] = _event_sum_batch - +xla.backend_specific_translations["cpu"][_event_sum_prim] = partial(_event_sum_translation, platform="cpu") +xla.backend_specific_translations["gpu"][_event_sum_prim] = partial(_event_sum_translation, platform="gpu") # --------------------------- # event sum kernel 2 # --------------------------- - _event_sum2_prim = core.Primitive("event_sum2") diff --git a/extensions/brainpylib/utils.py b/extensions/brainpylib/utils.py new file mode 100644 index 000000000..07df0b047 --- /dev/null +++ b/extensions/brainpylib/utils.py @@ -0,0 +1,18 @@ +# -*- coding: utf-8 -*- + + +__all__ = [ + 'GPUOperatorNotFound', +] + + +class GPUOperatorNotFound(Exception): + def __init__(self, name): + super(GPUOperatorNotFound, self).__init__(f''' +GPU operator for "{name}" does not found. 
+ +Please compile brainpylib GPU operators with the guidance in the following link: + +https://brainpy.readthedocs.io/en/latest/tutorial_advanced/compile_brainpylib.html + ''') + diff --git a/extensions/lib/event_sum_cpu.cc b/extensions/lib/event_sum_cpu.cc index e807e8799..18bf8c738 100644 --- a/extensions/lib/event_sum_cpu.cc +++ b/extensions/lib/event_sum_cpu.cc @@ -4,24 +4,42 @@ namespace brainpy_lib { namespace{ template void cpu_event_sum_homo(void *out, const void **in) { - // Parse the inputs const std::uint32_t pre_size = *reinterpret_cast(in[0]); const std::uint32_t post_size = *reinterpret_cast(in[1]); const bool *events = reinterpret_cast(in[2]); const I *indices = reinterpret_cast(in[3]); const I *indptr = reinterpret_cast(in[4]); - const F *values = reinterpret_cast(in[5]); - const F value = values[0]; + const F weight = *reinterpret_cast(in[5]); + F *result = reinterpret_cast(out); - // The output + // algorithm + memset(&result[0], 0, sizeof(F) * post_size); + for (std::uint32_t i=0; i + void cpu_event_sum_batch_homo(void *out, const void **in) { + const std::uint32_t pre_size = *reinterpret_cast(in[0]); + const std::uint32_t post_size = *reinterpret_cast(in[1]); + const bool *events = reinterpret_cast(in[2]); + const I *indices = reinterpret_cast(in[3]); + const I *indptr = reinterpret_cast(in[4]); + const F weight = *reinterpret_cast(in[5]); F *result = reinterpret_cast(out); // algorithm - memset(&result[0], 0, sizeof(result[0]) * post_size); + memset(&result[0], 0, sizeof(F) * post_size); for (std::uint32_t i=0; i void cpu_event_sum_heter(void *out, const void **in) { - // Parse the inputs const std::uint32_t pre_size = *reinterpret_cast(in[0]); const std::uint32_t post_size = *reinterpret_cast(in[1]); const bool *events = reinterpret_cast(in[2]); const I *indices = reinterpret_cast(in[3]); const I *indptr = reinterpret_cast(in[4]); const F *values = reinterpret_cast(in[5]); + F *result = reinterpret_cast(out); + + // algorithm + memset(&result[0], 0, sizeof(F) * post_size); + for (std::uint32_t i = 0; i < pre_size; ++i) { + if (events[i]){ + for (I j = indptr[i]; j < indptr[i+1]; ++j) { + result[indices[j]] += values[j]; + } + } + } + } + - // The output + // TODO:: batch version of "event_sum_heter" CPU operator + template + void cpu_event_sum_batch_heter(void *out, const void **in) { + const std::uint32_t pre_size = *reinterpret_cast(in[0]); + const std::uint32_t post_size = *reinterpret_cast(in[1]); + const bool *events = reinterpret_cast(in[2]); + const I *indices = reinterpret_cast(in[3]); + const I *indptr = reinterpret_cast(in[4]); + const F *values = reinterpret_cast(in[5]); F *result = reinterpret_cast(out); // algorithm - memset(&result[0], 0, sizeof(result[0]) * post_size); + memset(&result[0], 0, sizeof(F) * post_size); for (std::uint32_t i = 0; i < pre_size; ++i) { if (events[i]){ for (I j = indptr[i]; j < indptr[i+1]; ++j) { @@ -50,6 +88,8 @@ namespace{ } } } + + } void cpu_event_sum_homo_f32_i32(void *out, const void **in){cpu_event_sum_homo(out, in);} diff --git a/extensions/lib/event_sum_gpu.cu b/extensions/lib/event_sum_gpu.cu index e0dee75be..79973c9ec 100644 --- a/extensions/lib/event_sum_gpu.cu +++ b/extensions/lib/event_sum_gpu.cu @@ -458,8 +458,7 @@ namespace brainpy_lib { if (threadIdx.x < num_event) { const unsigned int pre_i = (r * 32) + threadIdx.x; shared_events[threadIdx.x] = events[pre_i]; - if (shared_events[threadIdx.x]) - { + if (shared_events[threadIdx.x]) { shPreStartID[threadIdx.x] = indptr[pre_i]; shRowLength[threadIdx.x] = 
indptr[pre_i + 1] - shPreStartID[threadIdx.x]; } @@ -532,8 +531,7 @@ namespace brainpy_lib { if (threadIdx.x < num_event) { const unsigned int pre_i = (r * 32) + threadIdx.x; shared_events[threadIdx.x] = events[pre_i]; - if (shared_events[threadIdx.x]) - { + if (shared_events[threadIdx.x]) { shPreStartID[threadIdx.x] = indptr[pre_i]; shRowLength[threadIdx.x] = indptr[pre_i + 1] - shPreStartID[threadIdx.x]; } @@ -553,7 +551,6 @@ namespace brainpy_lib { } - template inline void gpu_event_sum4_heter(cudaStream_t stream, void **buffers, @@ -578,17 +575,16 @@ namespace brainpy_lib { cudaMemset(result, 0, sizeof(F) * post_size); event_sum4_heter_kernel<<>>(max_post_conn, - pre_size, - events, - indices, - indptr, - values, - result); + pre_size, + events, + indices, + indptr, + values, + result); ThrowIfError(cudaGetLastError()); } - } // namespace @@ -758,24 +754,15 @@ namespace brainpy_lib { } // heterogeneous event sum 3 - void gpu_event_sum3_heter_f32_i32(cudaStream_t stream, - void **buffers, - const char *opaque, - std::size_t opaque_len) { + void gpu_event_sum3_heter_f32_i32(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { gpu_event_sum3_heter(stream, buffers, opaque, opaque_len); } - void gpu_event_sum3_heter_f32_i64(cudaStream_t stream, - void **buffers, - const char *opaque, - std::size_t opaque_len) { + void gpu_event_sum3_heter_f32_i64(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { gpu_event_sum3_heter(stream, buffers, opaque, opaque_len); } - void gpu_event_sum3_heter_f64_i32(cudaStream_t stream, - void **buffers, - const char *opaque, - std::size_t opaque_len) { + void gpu_event_sum3_heter_f64_i32(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len) { gpu_event_sum3_heter(stream, buffers, opaque, opaque_len); } diff --git a/setup.py b/setup.py index 07f979667..ab529cf28 100644 --- a/setup.py +++ b/setup.py @@ -53,6 +53,13 @@ ''') from None + +packages = find_packages() +if 'docs' in packages: + packages.remove('docs') +if 'tests' in packages: + packages.remove('tests') + # setup setup( name='brainpy', @@ -62,7 +69,7 @@ long_description_content_type="text/markdown", author='BrainPy Team', author_email='chao.brain@qq.com', - packages=find_packages(), + packages=packages, python_requires='>=3.7', install_requires=[ 'numpy>=1.15',
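
A note on the package-filtering change in setup.py above: the same effect can also be expressed with setuptools' exclude patterns. The snippet below is only an illustrative alternative with that intent, not what the patch itself does:

    from setuptools import find_packages

    # Exclude the documentation and test packages (and their sub-packages)
    # from the built distribution, instead of removing them afterwards.
    packages = find_packages(exclude=['docs', 'docs.*', 'tests', 'tests.*'])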
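
Two closing usage sketches for the series. First, a minimal sketch of the shape-based Variable construction added in PATCH 1/4; the tuple, dtype, and batch_axis behaviour below is inferred from the patched __init__ shown in the diff and is not separately verified:

    import jax.numpy as jnp
    import brainpy.math as bm

    v1 = bm.Variable(bm.zeros(10))                 # from concrete data (unchanged behaviour)
    v2 = bm.Variable(10)                           # from an int: a zero-filled vector of length 10
    v3 = bm.Variable((2, 10), dtype=jnp.float32)   # from a tuple of ints, with an explicit dtype
    v4 = bm.Variable((1, 10), batch_axis=0)        # batch_axis is forwarded as before
    w = bm.TrainVar(10)                            # TrainVar and Parameter accept the same value_or_size

    assert bm.array_equal(v1, v2)                  # both are all-zero vectors, as the new test checks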
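
Second, a hedged sketch of calling the event_sum operator reworked in PATCH 4/4. The top-level import and the CSR reading of pre2post are assumptions based on this diff and on released brainpylib packaging, not verified against this exact revision; running it requires a brainpylib build with the compiled kernels:

    import jax.numpy as jnp
    from brainpylib import event_sum   # assumed re-export; otherwise use brainpylib.event_sum

    events = jnp.array([True, False, True])               # one bool per presynaptic neuron
    indices = jnp.array([0, 1, 1, 2], dtype=jnp.uint32)   # post ids of each connection (CSR)
    indptr = jnp.array([0, 2, 3, 4], dtype=jnp.uint32)    # row pointers, length pre_num + 1
    post_num = 3

    # Homogeneous weight: every active connection contributes 0.5 to its post neuron.
    out = event_sum(events, (indices, indptr), post_num, 0.5)
    # Expected for this toy connectivity: [0.5, 0.5, 0.5].

    # Heterogeneous weights: one value per connection, same length as indices.
    out = event_sum(events, (indices, indptr), post_num, jnp.array([0.1, 0.2, 0.3, 0.4]))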