Skip to content
Merged

fixes #237

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion brainpy/dyn/neurons/input_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def __init__(

# variables
self.i = bm.Variable(bm.zeros(1))
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), trainable, self.varshape)
self.spike = variable(lambda s: bm.zeros(s, dtype=bool), mode, self.varshape)
if need_sort:
sort_idx = bm.argsort(self.times)
self.indices.value = self.indices[sort_idx]
Expand Down
6 changes: 5 additions & 1 deletion brainpy/math/operators/op_register.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def register_op(
A jitable JAX function.
"""
_check_brainpylib(register_op.__name__)
f = brainpylib.register_op(op_name, cpu_func, gpu_func, out_shapes, apply_cpu_func_to_gpu)
f = brainpylib.register_op(op_name,
cpu_func=cpu_func,
gpu_func=gpu_func,
out_shapes=out_shapes,
apply_cpu_func_to_gpu=apply_cpu_func_to_gpu)

def fixed_op(*inputs):
inputs = tuple([i.value if isinstance(i, JaxArray) else i for i in inputs])
Expand Down
13 changes: 7 additions & 6 deletions brainpy/math/operators/pre2post.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
# -*- coding: utf-8 -*-

import jax.numpy as jnp
from functools import partial
from typing import Union, Tuple

import jax.numpy as jnp
from jax import vmap, jit
from jax.lax import cond, scan, fori_loop
from functools import partial
from jax.lax import cond

from brainpy.errors import MathError
from brainpy.math.numpy_ops import as_device_array
from brainpy.math.jaxarray import JaxArray
from .utils import _check_brainpylib
from brainpy.math.numpy_ops import as_device_array
from brainpy.types import Tensor
from .pre2syn import pre2syn
from .syn2post import syn2post_mean
from brainpy.types import Tensor
from .utils import _check_brainpylib

try:
import brainpylib
Expand Down
2 changes: 1 addition & 1 deletion brainpy/math/operators/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
brainpylib = None


_BRAINPYLIB_MINIMAL_VERSION = '0.0.6'
_BRAINPYLIB_MINIMAL_VERSION = '0.0.5'


def _check_brainpylib(ops_name):
Expand Down