Skip to content

Commit

Permalink
Merge pull request #520 from ztqakita/master
Browse files Browse the repository at this point in the history
[dyn] add neuron scaling
  • Loading branch information
chaoming0625 committed Oct 24, 2023
2 parents 1d24470 + 809293b commit 5cd1832
Show file tree
Hide file tree
Showing 11 changed files with 347 additions and 110 deletions.
26 changes: 26 additions & 0 deletions brainpy/_src/dyn/neurons/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
mode: Optional[bm.Mode] = None,
name: Optional[str] = None,
method: str = 'exp_auto',
scaling: Optional[bm.Scaling] = None,

spk_fun: Callable = bm.surrogate.InvSquareGrad(),
spk_type: Any = None,
Expand All @@ -43,6 +44,10 @@ def __init__(
self.spk_fun = is_callable(spk_fun)
self.detach_spk = detach_spk
self._spk_type = spk_type
if scaling is None:
self.scaling = bm.get_membrane_scaling()
else:
self.scaling = scaling

@property
def spk_type(self):
Expand All @@ -51,5 +56,26 @@ def spk_type(self):
else:
return self._spk_type

def offset_scaling(self, x, bias=None, scale=None):
s = self.scaling.offset_scaling(x, bias=bias, scale=scale)
if isinstance(x, bm.Array):
x.value = s
return x
return s

def std_scaling(self, x, scale=None):
s = self.scaling.std_scaling(x, scale=scale)
if isinstance(x, bm.Array):
x.value = s
return x
return s

def inv_scaling(self, x, scale=None):
s = self.scaling.inv_scaling(x, scale=scale)
if isinstance(x, bm.Array):
x.value = s
return x
return s


GradNeuDyn.__doc__ = GradNeuDyn.__doc__.format(pneu=pneu_doc, dpneu=dpneu_doc)

0 comments on commit 5cd1832

Please sign in to comment.