From 65238f6b647e08403dc77b98f16f942d48da4fd5 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 12 Jul 2022 20:04:30 +0800 Subject: [PATCH 1/2] updates --- brainpy/dyn/neurons/input_groups.py | 2 +- brainpy/math/operators/op_register.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/brainpy/dyn/neurons/input_groups.py b/brainpy/dyn/neurons/input_groups.py index d03fb4b17..80520091e 100644 --- a/brainpy/dyn/neurons/input_groups.py +++ b/brainpy/dyn/neurons/input_groups.py @@ -96,7 +96,7 @@ def __init__( # variables self.i = bm.Variable(bm.zeros(1)) - self.spike = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape) + self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape) if need_sort: sort_idx = bm.argsort(self.times) self.indices.value = self.indices[sort_idx] diff --git a/brainpy/math/operators/op_register.py b/brainpy/math/operators/op_register.py index 0be715ccb..c4c195523 100644 --- a/brainpy/math/operators/op_register.py +++ b/brainpy/math/operators/op_register.py @@ -47,7 +47,11 @@ def register_op( A jitable JAX function. """ _check_brainpylib(register_op.__name__) - f = brainpylib.register_op(op_name, cpu_func, gpu_func, out_shapes, apply_cpu_func_to_gpu) + f = brainpylib.register_op(op_name, + cpu_func=cpu_func, + gpu_func=gpu_func, + out_shapes=out_shapes, + apply_cpu_func_to_gpu=apply_cpu_func_to_gpu) def fixed_op(*inputs): inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs]) From 58fe093d7fb3786161e7de96019bca6cad6122b5 Mon Sep 17 00:00:00 2001 From: chaoming Date: Tue, 12 Jul 2022 20:14:27 +0800 Subject: [PATCH 2/2] fixs --- brainpy/math/operators/pre2post.py | 13 +++++++------ brainpy/math/operators/utils.py | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/brainpy/math/operators/pre2post.py b/brainpy/math/operators/pre2post.py index 0bf0e59e0..be6b8a40c 100644 --- a/brainpy/math/operators/pre2post.py +++ b/brainpy/math/operators/pre2post.py @@ -1,18 +1,19 @@ # -*- coding: utf-8 -*- -import jax.numpy as jnp +from functools import partial from typing import Union, Tuple + +import jax.numpy as jnp from jax import vmap, jit -from jax.lax import cond, scan, fori_loop -from functools import partial +from jax.lax import cond from brainpy.errors import MathError -from brainpy.math.numpy_ops import as_device_array from brainpy.math.jaxarray import JaxArray -from .utils import _check_brainpylib +from brainpy.math.numpy_ops import as_device_array +from brainpy.types import Tensor from .pre2syn import pre2syn from .syn2post import syn2post_mean -from brainpy.types import Tensor +from .utils import _check_brainpylib try: import brainpylib diff --git a/brainpy/math/operators/utils.py b/brainpy/math/operators/utils.py index bae41f232..730599fc3 100644 --- a/brainpy/math/operators/utils.py +++ b/brainpy/math/operators/utils.py @@ -8,7 +8,7 @@ brainpylib = None -_BRAINPYLIB_MINIMAL_VERSION = '0.0.6' +_BRAINPYLIB_MINIMAL_VERSION = '0.0.5' def _check_brainpylib(ops_name):