Skip to content

Commit

Permalink
Merge pull request #524 from chaoming0625/updates
Browse files Browse the repository at this point in the history
[math] the interface for operator registration
  • Loading branch information
Routhleck committed Oct 30, 2023
2 parents 06276ee + 1e857c7 commit e6c6664
Show file tree
Hide file tree
Showing 32 changed files with 527 additions and 401 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ We provide a Binder environment for BrainPy. You can use the following button to
- **[BrainPy](https://github.com/brainpy/BrainPy)**: The solution for the general-purpose brain dynamics programming.
- **[brainpy-examples](https://github.com/brainpy/examples)**: Comprehensive examples of BrainPy computation.
- **[brainpy-datasets](https://github.com/brainpy/datasets)**: Neuromorphic and Cognitive Datasets for Brain Dynamics Modeling.
- [第一届神经计算建模与编程培训班 (BrainPy First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course)
- [第一届神经计算建模与编程培训班 (First Training Course on Neural Modeling and Programming)](https://github.com/brainpy/1st-neural-modeling-and-programming-course)


## Citing
Expand Down
72 changes: 46 additions & 26 deletions brainpy/_src/dnn/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from typing import Dict, Optional, Union, Callable

import numba
import jax
import numpy as np
import jax.numpy as jnp
Expand Down Expand Up @@ -227,6 +228,45 @@ def update(self, x):
return x


def event_mm(pre_spike, post_inc, weight, w_min, w_max):
return weight


@numba.njit
def event_mm_imp(outs, ins):
pre_spike, post_inc, weight, w_min, w_max = ins
w_min = w_min[()]
w_max = w_max[()]
outs = outs
outs.fill(weight)
for i in range(pre_spike.shape[0]):
if pre_spike[i]:
outs[i] = np.clip(outs[i] + post_inc, w_min, w_max)


event_left_mm = bm.CustomOpByNumba(event_mm, event_mm_imp, multiple_results=False)


def event_mm2(post_spike, pre_inc, weight, w_min, w_max):
return weight


@numba.njit
def event_mm_imp2(outs, ins):
post_spike, pre_inc, weight, w_min, w_max = ins
w_min = w_min[()]
w_max = w_max[()]
outs = outs
outs.fill(weight)
for j in range(post_spike.shape[0]):
if post_spike[j]:
outs[:, j] = np.clip(outs[:, j] + pre_inc, w_min, w_max)


event_right_mm = bm.CustomOpByNumba(event_mm2, event_mm_imp2, multiple_results=False)



class AllToAll(Layer, SupportSTDP):
"""Synaptic matrix multiplication with All2All connections.
Expand Down Expand Up @@ -289,20 +329,15 @@ def update(self, pre_val):
post_val = pre_val @ self.weight
return post_val

def update_STDP(self, dW, constraints=None):
if isinstance(self.weight, float):
raise ValueError(f'Cannot update the weight of a constant node.')
if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}')
if self.weight.shape != dW.shape:
raise ValueError(f'The shape of delta_weight {dW.shape} '
f'should be the same as the shape of weight {self.weight.shape}.')
def stdp_update_on_pre(self, pre_spike, trace, w_min=None, w_max=None):
if not isinstance(self.weight, bm.Variable):
self.tracing_variable('weight', self.weight, self.weight.shape)
self.weight += dW
if constraints is not None:
self.weight.value = constraints(self.weight)
self.weight.value = event_left_mm(pre_spike, trace, self.weight, w_min, w_max)

def stdp_update_on_post(self, post_spike, trace, w_min=None, w_max=None):
if not isinstance(self.weight, bm.Variable):
self.tracing_variable('weight', self.weight, self.weight.shape)
self.weight.value = event_right_mm(post_spike, trace, self.weight, w_min, w_max)


class OneToOne(Layer, SupportSTDP):
Expand Down Expand Up @@ -338,21 +373,6 @@ def __init__(
def update(self, pre_val):
return pre_val * self.weight

def update_STDP(self, dW, constraints=None):
if isinstance(self.weight, float):
raise ValueError(f'Cannot update the weight of a constant node.')
if not isinstance(dW, (bm.ndarray, jnp.ndarray, np.ndarray)):
raise ValueError(f'"delta_weight" must be a array, but got {type(dW)}')
dW = dW.sum(axis=0)
if self.weight.shape != dW.shape:
raise ValueError(f'The shape of delta_weight {dW.shape} '
f'should be the same as the shape of weight {self.weight.shape}.')
if not isinstance(self.weight, bm.Variable):
self.tracing_variable('weight', self.weight, self.weight.shape)
self.weight += dW
if constraints is not None:
self.weight.value = constraints(self.weight)


class MaskedLinear(Layer, SupportSTDP):
r"""Synaptic matrix multiplication with masked dense computation.
Expand Down
8 changes: 2 additions & 6 deletions brainpy/_src/dyn/others/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,5 @@ def update(self):
def return_info(self):
return self.spike

def reset_state(self, batch_size=None, **kwargs):
self.spike = variable_(partial(jnp.zeros, dtype=self.spk_type),
self.varshape,
batch_size,
axis_names=self.sharding,
batch_axis_name=bm.sharding.BATCH_AXIS)
def reset_state(self, batch_or_mode=None, **kwargs):
self.spike = self.init_variable(partial(jnp.zeros, dtype=self.spk_type), batch_or_mode)
36 changes: 19 additions & 17 deletions brainpy/_src/dyn/projections/plasticity.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from brainpy._src.delay import register_delay_by_return
from brainpy._src.dyn.synapses.abstract_models import Expon
from brainpy._src.dynsys import DynamicalSystem, Projection
from brainpy._src.initialize import parameter
from brainpy._src.mixin import (JointType, ParamDescriber, SupportAutoDelay,
BindCondData, AlignPost, SupportSTDP)
from brainpy.types import ArrayType
Expand Down Expand Up @@ -111,7 +110,8 @@ def run(i, I_pre, I_post):
A1: float. The increment of :math:`A_{pre}` produced by a spike. Must be a positive value.
A2: float. The increment of :math:`A_{post}` produced by a spike. Must be a positive value.
W_max: float. The maximum weight.
pre: DynamicalSystem. The pre-synaptic neuron group.
W_min: float. The minimum weight.
pre: DynamicalSystem. The pre-synaptic neuron group.
delay: int, float. The pre spike delay length. (ms)
syn: DynamicalSystem. The synapse model.
comm: DynamicalSystem. The communication model, for example, dense or sparse connection layers.
Expand All @@ -135,6 +135,7 @@ def __init__(
A1: Union[float, ArrayType, Callable] = 0.96,
A2: Union[float, ArrayType, Callable] = 0.53,
W_max: Optional[float] = None,
W_min: Optional[float] = None,
# others
out_label: Optional[str] = None,
name: Optional[str] = None,
Expand All @@ -144,21 +145,21 @@ def __init__(

# synaptic models
check.is_instance(pre, JointType[DynamicalSystem, SupportAutoDelay])
check.is_instance(syn, ParamDescriber[DynamicalSystem])
check.is_instance(comm, JointType[DynamicalSystem, SupportSTDP])
check.is_instance(syn, ParamDescriber[DynamicalSystem])
check.is_instance(out, ParamDescriber[JointType[DynamicalSystem, BindCondData]])
check.is_instance(post, DynamicalSystem)
self.pre_num = pre.num
self.post_num = post.num
self.comm = comm
self.syn = syn
self._is_align_post = issubclass(syn.cls, AlignPost)

# delay initialization
delay_cls = register_delay_by_return(pre)
delay_cls.register_entry(self.name, delay)

# synapse and output initialization
if issubclass(syn.cls, AlignPost):
if self._is_align_post:
syn_cls, out_cls = align_post_add_bef_update(out_label, syn_desc=syn, out_desc=out, post=post,
proj_name=self.name)
else:
Expand All @@ -171,24 +172,27 @@ def __init__(
self.refs['delay'] = delay_cls
self.refs['syn'] = syn_cls # invisible to ``self.node()``
self.refs['out'] = out_cls # invisible to ``self.node()``
self.refs['comm'] = comm

# tracing pre-synaptic spikes using Exponential model
self.refs['pre_trace'] = _init_trace_by_align_pre2(pre, delay, Expon.desc(pre.num, tau=tau_s))

# tracing post-synaptic spikes using Exponential model
self.refs['post_trace'] = _init_trace_by_align_pre2(post, None, Expon.desc(post.num, tau=tau_t))

# synapse parameters
self.W_max = W_max
self.tau_s = parameter(tau_s, sizes=self.pre_num)
self.tau_t = parameter(tau_t, sizes=self.post_num)
self.A1 = parameter(A1, sizes=self.pre_num)
self.A2 = parameter(A2, sizes=self.post_num)
self.W_min = W_min
self.tau_s = tau_s
self.tau_t = tau_t
self.A1 = A1
self.A2 = A2

def update(self):
# pre-synaptic spikes
pre_spike = self.refs['delay'].at(self.name) # spike
# pre-synaptic variables
if issubclass(self.syn.cls, AlignPost):
if self._is_align_post:
# For AlignPost, we need "pre spikes @ comm matrix" for computing post-synaptic conductance
x = pre_spike
else:
Expand All @@ -201,19 +205,17 @@ def update(self):
post_spike = self.refs['post'].spike

# weight updates
Apre = self.refs['pre_trace'].g
Apost = self.refs['post_trace'].g
delta_w = - bm.outer(pre_spike, Apost * self.A2) + bm.outer(Apre * self.A1, post_spike)
self.comm.update_STDP(delta_w, constraints=self._weight_clip)
self.comm.stdp_update_on_pre(pre_spike, -Apost * self.A2, self.W_min, self.W_max)
Apre = self.refs['pre_trace'].g
self.comm.stdp_update_on_post(post_spike, Apre * self.A1, self.W_min, self.W_max)

# currents
# synaptic currents
current = self.comm(x)
if issubclass(self.syn.cls, AlignPost):
if self._is_align_post:
self.refs['syn'].add_current(current) # synapse post current
else:
self.refs['out'].bind_cond(current) # align pre
return current

def _weight_clip(self, w):
return w if self.W_max is None else bm.minimum(w, self.W_max)

26 changes: 18 additions & 8 deletions brainpy/_src/dyn/projections/tests/test_STDP.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-

import os
os.environ['JAX_TRACEBACK_FILTERING'] = 'off'
import matplotlib.pyplot as plt
import numpy as np
from absl.testing import parameterized

import brainpy as bp
Expand All @@ -20,15 +20,18 @@ def __init__(self, num_pre, num_post):
self.syn = bp.dyn.STDP_Song2000(
pre=self.pre,
delay=1.,
comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
weight=lambda s: bm.Variable(bm.random.rand(*s) * 0.1)),
# comm=bp.dnn.EventCSRLinear(bp.conn.FixedProb(1, pre=self.pre.num, post=self.post.num),
# weight=bp.init.Uniform(0., 0.1)),
comm=bp.dnn.AllToAll(self.pre.num, self.post.num, weight=bp.init.Uniform(.1, 0.1)),
syn=bp.dyn.Expon.desc(self.post.varshape, tau=5.),
out=bp.dyn.COBA.desc(E=0.),
post=self.post,
tau_s=16.8,
tau_t=33.7,
A1=0.96,
A2=0.53,
W_min=0.,
W_max=1.
)

def update(self, I_pre, I_post):
Expand All @@ -39,7 +42,7 @@ def update(self, I_pre, I_post):
Apre = self.syn.refs['pre_trace'].g
Apost = self.syn.refs['post_trace'].g
current = self.post.sum_inputs(self.post.V)
return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight
return self.pre.spike, self.post.spike, conductance, Apre, Apost, current, self.syn.comm.weight.flatten()

duration = 300.
I_pre = bp.inputs.section_input([0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0, 30, 0],
Expand All @@ -53,7 +56,14 @@ def run(i, I_pre, I_post):
pre_spike, post_spike, g, Apre, Apost, current, W = net.step_run(i, I_pre, I_post)
return pre_spike, post_spike, g, Apre, Apost, current, W

indices = bm.arange(0, duration, bm.dt)
bm.for_loop(run, [indices, I_pre, I_post], jit=True)
bm.clear_buffer_memory()
indices = np.arange(int(duration / bm.dt))
pre_spike, post_spike, g, Apre, Apost, current, W = bm.for_loop(run, [indices, I_pre, I_post])

fig, gs = bp.visualize.get_figure(4, 1, 3, 10)
bp.visualize.line_plot(indices, g, ax=fig.add_subplot(gs[0, 0]))
bp.visualize.line_plot(indices, Apre, ax=fig.add_subplot(gs[1, 0]))
bp.visualize.line_plot(indices, Apost, ax=fig.add_subplot(gs[2, 0]))
bp.visualize.line_plot(indices, W, ax=fig.add_subplot(gs[3, 0]))
plt.show()

bm.clear_buffer_memory()
2 changes: 1 addition & 1 deletion brainpy/_src/dyn/rates/tests/test_reservoir.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class Test_Reservoir(parameterized.TestCase):
def test_Reservoir(self, mode):
bm.random.seed()
input = bm.random.randn(10, 3)
layer = bp.dnn.Reservoir(input_shape=3,
layer = bp.dyn.Reservoir(input_shape=3,
num_out=5,
mode=mode)
if mode in [bm.NonBatchingMode()]:
Expand Down

0 comments on commit e6c6664

Please sign in to comment.