26 changes: 16 additions & 10 deletions brainpy/dyn/base.py
@@ -148,7 +148,12 @@ def register_delay(
self.global_delay_data[identifier] = (delay, delay_target)
self.local_delay_vars[identifier] = delay
else:
if self.global_delay_data[identifier][0].num_delay_step - 1 < max_delay_step:
delay = self.global_delay_data[identifier][0]
if delay is None:
delay = bm.LengthDelay(delay_target, max_delay_step, initial_delay_data)
self.global_delay_data[identifier] = (delay, delay_target)
self.local_delay_vars[identifier] = delay
elif delay.num_delay_step - 1 < max_delay_step:
self.global_delay_data[identifier][0].reset(delay_target, max_delay_step, initial_delay_data)
else:
self.global_delay_data[identifier] = (None, delay_target)
@@ -181,15 +186,16 @@ def get_delay_data(
return self.global_delay_data[identifier][1].value

if identifier in self.global_delay_data:
if isinstance(delay_step, (int, np.integer)):
# if isinstance(delay_step, (int, np.integer)):
if bm.ndim(delay_step) == 0:
return self.global_delay_data[identifier][0](delay_step, *indices)
else:
if len(indices) == 0:
indices = (jnp.arange(delay_step.size),)
return self.global_delay_data[identifier][0](delay_step, *indices)

elif identifier in self.local_delay_vars:
if isinstance(delay_step, (int, np.integer)):
if bm.ndim(delay_step) == 0:
return self.local_delay_vars[identifier](delay_step)
else:
if len(indices) == 0:
@@ -685,11 +691,11 @@ def __init__(
self.output.register_master(master=self)

# synaptic plasticity
if stp is None: stp = SynSTP()
if not isinstance(stp, SynSTP):
raise TypeError(f'plasticity must be instance of {SynSTP.__name__}, but we got {type(stp)}')
self.stp: SynSTP = stp
self.stp.register_master(master=self)
if stp is not None:
if not isinstance(stp, SynSTP):
raise TypeError(f'plasticity must be instance of {SynSTP.__name__}, but we got {type(stp)}')
stp.register_master(master=self)
self.stp: Optional[SynSTP] = stp

def init_weights(
self,
Expand Down Expand Up @@ -734,7 +740,7 @@ def init_weights(
return weight, conn_mask

def syn2post_with_all2all(self, syn_value, syn_weight):
if bm.size(syn_weight) == 1:
if bm.ndim(syn_weight) == 0:
if self.trainable:
post_vs = bm.sum(syn_value, keepdims=True, axis=tuple(range(syn_value.ndim))[1:])
else:
@@ -750,7 +756,7 @@ def syn2post_with_one2one(self, syn_value, syn_weight):
return syn_value * syn_weight

def syn2post_with_dense(self, syn_value, syn_weight, conn_mat):
if bm.size(syn_weight) == 1:
if bm.ndim(syn_weight) == 0:
post_vs = (syn_weight * syn_value) @ conn_mat
else:
post_vs = syn_value @ (syn_weight * conn_mat)
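A brief note on the base.py hunks above, with a hedged usage sketch: register_delay now rebuilds a LengthDelay when an identifier was previously registered with a None delay; get_delay_data dispatches on bm.ndim(delay_step) == 0, so traced 0-d arrays work as well as Python ints; the stp component of a synapse is now optional; and the all2all/dense projections treat only 0-d weights as the homogeneous-weight case, so a trainable (1,)-shaped weight is no longer collapsed. In the sketch below the population sizes, the identifier string, and the 2-step delay are illustrative assumptions, not part of this PR.

import brainpy as bp
import brainpy.math as bm

pre = bp.dyn.LIF(10)
post = bp.dyn.LIF(10)

# `stp=None` is now allowed: the synapse simply skips the short-term-plasticity
# machinery instead of constructing a default SynSTP object.
syn = bp.dyn.synapses.Exponential(pre, post, bp.connect.All2All(), stp=None)

# Registering and reading a delayed variable through the base-class API.
# A 0-d delay step (a Python int or a traced 0-d array) now takes the fast
# scalar branch of get_delay_data.
syn.register_delay('illustrative.spike', 2, pre.spike)
delayed_spikes = syn.get_delay_data('illustrative.spike', 2)
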
22 changes: 11 additions & 11 deletions brainpy/dyn/channels/K.py
@@ -17,7 +17,7 @@
'IK_p4_markov',
'IKDR_Ba2002',
'IK_TM1991',
'IK_HH',
'IK_HH1952',

'IKA_p4q_ss',
'IKA1_HM1992',
@@ -269,7 +269,7 @@ def f_p_beta(self, V):
return 0.5 * bm.exp((10 - V + self.V_sh) / 40)


class IK_HH(IK_p4_markov):
class IK_HH1952(IK_p4_markov):
r"""The potassium channel described by Hodgkin–Huxley model [1]_.

The dynamics of this channel is given by:
@@ -307,7 +307,7 @@ class IK_HH(IK_p4_markov):

See Also
--------
INa_HH
INa_HH1952
"""

def __init__(
@@ -322,14 +322,14 @@ def __init__(
name: str = None,
trainable: bool = False,
):
super(IK_HH, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
phi=phi,
E=E,
g_max=g_max,
trainable=trainable)
super(IK_HH1952, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
phi=phi,
E=E,
g_max=g_max,
trainable=trainable)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_alpha(self, V):
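IK_HH is renamed to IK_HH1952 (and its See Also reference updated). The class itself only supplies the Hodgkin–Huxley rate functions to the generic IK_p4_markov parent, which integrates dp/dt = phi * (alpha_p(V) * (1 - p) - beta_p(V) * p) and contributes the current g_max * p^4 * (E - V). A sketch of a custom channel built the same way; the rate expressions below are invented for illustration, not a published fit:

import brainpy as bp
import brainpy.math as bm

class MyKChannel(bp.dyn.channels.IK_p4_markov):
  """A custom delayed-rectifier K+ channel: only the voltage-dependent rates differ."""

  def f_p_alpha(self, V):
    # hypothetical opening rate (1/ms)
    return 0.02 * (V + 40.) / (1. - bm.exp(-(V + 40.) / 9.))

  def f_p_beta(self, V):
    # hypothetical closing rate (1/ms)
    return 0.5 * bm.exp(-(V + 45.) / 40.)
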
22 changes: 11 additions & 11 deletions brainpy/dyn/channels/Na.py
@@ -17,7 +17,7 @@
'INa_p3q_markov',
'INa_Ba2002',
'INa_TM1991',
'INa_HH',
'INa_HH1952',
]


@@ -284,7 +284,7 @@ def f_q_beta(self, V):
return 4. / (1 + bm.exp(-(V - self.V_sh - 40) / 5))


class INa_HH(INa_p3q_markov):
class INa_HH1952(INa_p3q_markov):
r"""The sodium current model described by Hodgkin–Huxley model [1]_.

The dynamics of this sodium current model is given by:
@@ -331,7 +331,7 @@ class INa_HH(INa_p3q_markov):

See Also
--------
IK_HH
IK_HH1952
"""

def __init__(
@@ -346,14 +346,14 @@ def __init__(
name: str = None,
trainable: bool = False,
):
super(INa_HH, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
E=E,
phi=phi,
g_max=g_max,
trainable=trainable)
super(INa_HH1952, self).__init__(size,
keep_size=keep_size,
name=name,
method=method,
E=E,
phi=phi,
g_max=g_max,
trainable=trainable)
self.V_sh = parameter(V_sh, self.varshape, allow_none=False)

def f_p_alpha(self, V):
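With both renames applied, a classic Hodgkin–Huxley 1952 cell can be written against the new class names. A minimal sketch, assuming the conductance-based container bp.dyn.CondNeuGroup and the leak channel bp.dyn.channels.IL; the conductance and reversal-potential values are the textbook ones, not taken from this diff:

import brainpy as bp

class HH1952(bp.dyn.CondNeuGroup):
  def __init__(self, size):
    super(HH1952, self).__init__(size)
    self.INa = bp.dyn.channels.INa_HH1952(size, g_max=120., E=50.)
    self.IK = bp.dyn.channels.IK_HH1952(size, g_max=36., E=-77.)
    self.IL = bp.dyn.channels.IL(size, g_max=0.03, E=-54.387)

neu = HH1952(1)
runner = bp.dyn.DSRunner(neu, inputs=('input', 5.), monitors=['V'])
runner.run(100.)
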
3 changes: 3 additions & 0 deletions brainpy/dyn/layers/conv.py
@@ -139,6 +139,9 @@ def update(self, sha, x):
return y
return y + self.b.value

def reset_state(self, batch_size=None):
pass


class Conv1D(GeneralConv):
def __init__(
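The convolution layer keeps no state across time steps, so the reset_state added above is intentionally a no-op; it exists so that a container can reset every child uniformly without special-casing stateless layers. A sketch of the kind of loop this enables (the network object is a hypothetical container):

# hypothetical: walk all unique nodes of a model and reset their state;
# stateless layers such as the Conv layers above now accept the call and do nothing
for node in network.nodes().unique().values():
  node.reset_state(batch_size=16)
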
8 changes: 1 addition & 7 deletions brainpy/dyn/layers/dropout.py
@@ -36,7 +36,7 @@ class Dropout(TrainingSystem):
neural networks from overfitting." The journal of machine learning
research 15.1 (2014): 1929-1958.
"""
def __init__(self, prob, seed=None, trainable=False, name=None):
def __init__(self, prob, seed=None, trainable=True, name=None):
super(Dropout, self).__init__(trainable=trainable, name=name)
self.prob = prob
self.rng = bm.random.RandomState(seed=seed)
@@ -47,9 +47,3 @@ def update(self, sha, x):
return bm.where(keep_mask, x / self.prob, 0.)
else:
return x

def reset(self, batch_size=1):
pass

def reset_state(self, batch_size=1):
pass
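
Dropout now defaults to trainable=True and no longer needs its own no-op reset methods; masking is applied only during the fitting phase. A minimal sketch of both phases, assuming the shared dict's 'fit' entry is what the surrounding update code checks (prob is the keep probability; sizes are illustrative):

import brainpy as bp
import brainpy.math as bm

layer = bp.dyn.layers.Dropout(prob=0.8, seed=42)
x = bm.ones((4, 10))

y_train = layer(dict(fit=True), x)     # elements kept with probability 0.8, rescaled by 1/0.8
y_predict = layer(dict(fit=False), x)  # returned unchanged
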
13 changes: 6 additions & 7 deletions brainpy/dyn/layers/linear.py
@@ -47,10 +47,15 @@ def __init__(
W_initializer: Union[Initializer, Callable, Tensor] = XavierNormal(),
b_initializer: Optional[Union[Initializer, Callable, Tensor]] = ZeroInit(),
trainable: bool = True,
name: str = None
name: str = None,
fit_online: bool = False,
fit_offline: bool = False,
):
super(Dense, self).__init__(trainable=trainable, name=name)

self.fit_online = fit_online
self.fit_offline = fit_offline

# shape
self.num_in = num_in
self.num_out = num_out
@@ -90,12 +95,6 @@ def update(self, sha, x):
self.fit_record['output'] = res
return res

def reset(self, batch_size=1):
pass

def reset_state(self, batch_size=1):
pass

def online_init(self):
if self.b is None:
num_input = self.num_in
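Dense gains explicit fit_online / fit_offline flags (both defaulting to False) in place of the removed no-op reset methods, so a readout can state which kind of trainer may fit it. A sketch of the new constructor arguments; the sizes are illustrative:

import brainpy as bp

readout = bp.dyn.layers.Dense(
  num_in=100,
  num_out=3,
  fit_offline=True,   # may be fitted by offline (e.g. ridge-regression style) trainers
  fit_online=False,   # not by online rules such as RLS/FORCE
)
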
71 changes: 45 additions & 26 deletions brainpy/dyn/layers/nvar.py
@@ -7,9 +7,8 @@
import numpy as np

import brainpy.math as bm
from brainpy.tools.checking import (check_integer, check_sequence)
from brainpy.dyn.training import TrainingSystem

from brainpy.tools.checking import (check_integer, check_sequence)

__all__ = [
'NVAR'
@@ -69,8 +68,8 @@ def __init__(
order: Union[int, Sequence[int]] = None,
stride: int = 1,
constant: bool = False,
trainable: bool = False,
name: str = None
trainable: bool = True,
name: str = None,
):
super(NVAR, self).__init__(trainable=trainable, name=name)

@@ -93,8 +92,11 @@ def __init__(

# delay variables
self.idx = bm.Variable(jnp.asarray([0]))
batch_size = 1 # first initialize the state with batch size = 1
self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in)))
if trainable:
batch_size = 1 # first initialize the state with batch size = 1
self.store = bm.Variable(jnp.zeros((self.num_delay, batch_size, self.num_in)), batch_axis=1)
else:
self.store = bm.Variable(jnp.zeros((self.num_delay, self.num_in)))

# linear dimension
self.linear_dim = self.delay * num_in
@@ -115,35 +117,52 @@ def __init__(
if self.constant:
self.num_out += 1

def reset(self, batch_size=1):
self.idx[0] = 0
self.reset_state(batch_size)

def reset_state(self, batch_size=1):
def reset_state(self, batch_size=None):
"""Reset the node state which depends on batch size."""
self.idx[0] = 0
# To store the last inputs.
# Note, the batch axis is not in the first dimension, so we
# manually handle the state of NVAR, rather return it.
self.store._value = jnp.zeros((self.num_delay, batch_size, self.num_in))
if batch_size is None:
self.store.value = jnp.zeros((self.num_delay, self.num_in))
else:
self.store.value = jnp.zeros((self.num_delay, batch_size, self.num_in))

def update(self, sha, x):
all_parts = []
select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay
# 1. Store the current input
self.store[self.idx[0]] = x
# 2. Linear part:
# select all previous inputs, including the current, with strides
select_ids = (self.idx[0] - jnp.arange(0, self.num_delay, self.stride)) % self.num_delay
linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature)
linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1))
# 3. constant
if self.constant:
constant = jnp.ones((linear_parts.shape[0], 1), dtype=x.dtype)
all_parts.append(constant)
all_parts.append(linear_parts)
# 3. Nonlinear part:
# select monomial terms and compute them
for ids in self.comb_ids:
all_parts.append(jnp.prod(linear_parts[:, ids], axis=2))

if self.trainable:
# 2. Linear part:
# select all previous inputs, including the current, with strides
linear_parts = jnp.moveaxis(self.store[select_ids], 0, 1) # (num_batch, num_time, num_feature)
linear_parts = jnp.reshape(linear_parts, (linear_parts.shape[0], -1))
# 3. constant
if self.constant:
constant = jnp.ones((linear_parts.shape[0], 1), dtype=x.dtype)
all_parts.append(constant)
all_parts.append(linear_parts)
# 3. Nonlinear part:
# select monomial terms and compute them
for ids in self.comb_ids:
all_parts.append(jnp.prod(linear_parts[:, ids], axis=2))

else:
# 2. Linear part:
# select all previous inputs, including the current, with strides
linear_parts = self.store[select_ids].flatten() # (num_time x num_feature,)
# 3. constant
if self.constant:
constant = jnp.ones((1,), dtype=x.dtype)
all_parts.append(constant)
all_parts.append(linear_parts)
# 3. Nonlinear part:
# select monomial terms and compute them
for ids in self.comb_ids:
all_parts.append(jnp.prod(linear_parts[ids], axis=1))

# 4. Finally
self.idx.value = (self.idx + 1) % self.num_delay
return jnp.concatenate(all_parts, axis=-1)
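The NVAR layer now allocates its delay buffer with a batch axis only when trainable, and reset_state(batch_size=None) rebuilds whichever shape applies; the output feature vector is the optional constant, the strided linear history, and the order-wise monomial products, in that order. A sketch of the two modes; num_in, delay, order, and the batch size are illustrative, and the leading constructor argument names num_in/delay are assumed from the part of the diff collapsed above:

import brainpy as bp
import brainpy.math as bm

# batched (trainable) mode: the store has shape (num_delay, batch, num_in)
nvar = bp.dyn.layers.NVAR(num_in=3, delay=2, order=2, trainable=True)
nvar.reset_state(batch_size=8)
y = nvar(dict(), bm.ones((8, 3)))   # -> (8, nvar.num_out)

# single-sample (non-trainable) mode: the store has shape (num_delay, num_in)
nvar1 = bp.dyn.layers.NVAR(num_in=3, delay=2, order=2, trainable=False)
nvar1.reset_state()
y1 = nvar1(dict(), bm.ones(3))      # -> (nvar1.num_out,)
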
14 changes: 5 additions & 9 deletions brainpy/dyn/layers/reservoir.py
@@ -3,7 +3,7 @@
from typing import Optional, Union, Callable, Tuple

import brainpy.math as bm
from brainpy.initialize import Normal, ZeroInit, Initializer, parameter
from brainpy.initialize import Normal, ZeroInit, Initializer, parameter, variable
from brainpy.tools.checking import check_float, check_initializer, check_string
from brainpy.tools.others import to_size
from brainpy.dyn.training import TrainingSystem
@@ -102,7 +102,7 @@ def __init__(
noise_rec: float = 0.,
noise_type: str = 'normal',
seed: Optional[int] = None,
trainable: bool = False,
trainable: bool = True,
name: str = None
):
super(Reservoir, self).__init__(trainable=trainable, name=name)
@@ -179,14 +179,10 @@ def __init__(
self.bias = None if (self.bias is None) else bm.TrainVar(self.bias)

# initialize state
batch_size = 1
self.state = bm.Variable(bm.zeros((batch_size,) + self.output_shape))
self.state = variable(bm.zeros, trainable, self.output_shape)

def reset(self, batch_size=1):
self.state._value = bm.zeros((batch_size,) + self.output_shape).value

def reset_state(self, batch_size=1):
pass
def reset_state(self, batch_size=None):
self.state.value = variable(bm.zeros, batch_size, self.output_shape)

def update(self, sha, x):
"""Feedforward output."""
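The Reservoir state is now created with the variable() helper, so it carries a leading batch axis only for trainable instances, and reset_state(batch_size) re-allocates it instead of writing through the private _value attribute. A hedged sketch, assuming the constructor takes the input size and reservoir size as its two leading arguments (names and sizes are illustrative):

import brainpy as bp
import brainpy.math as bm

res = bp.dyn.layers.Reservoir(3, 400, trainable=True)  # argument order assumed
res.reset_state(batch_size=8)        # state becomes zeros of shape (8, 400)
h = res(dict(), bm.ones((8, 3)))     # one recurrent step on a batch of 8 inputs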