Skip to content

Commit

Permalink
Merge pull request #13 from PKU-NIP-Lab/develop
Browse files Browse the repository at this point in the history
New update
  • Loading branch information
chaoming0625 committed Nov 19, 2020
2 parents 4e05915 + 3c1f740 commit 24030a4
Show file tree
Hide file tree
Showing 20 changed files with 716 additions and 290 deletions.
3 changes: 2 additions & 1 deletion README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,8 @@ The following packages need to be installed to use ``BrainPy``:

Packages recommended to install:

- Numba >= 0.40.0
- Numba >= 0.50.0
- TensorFlow >= 2.4


Define a Hodgkin–Huxley neuron model
Expand Down
11 changes: 0 additions & 11 deletions brainpy/core_system/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,5 @@
# -*- coding: utf-8 -*-

"""
This module defines the core of the framework, including the
abstraction of ``Neurons``, ``Synapses``, ``Monitor``, ``Network``,
and numerical integrator methods.
The core is so small, and the overall framework is easy to
understand. Using it, you can easily write your own neurons,
synapses, etc.
"""


from .base import *
from .types import *
from .neurons import *
Expand Down
50 changes: 33 additions & 17 deletions brainpy/core_system/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
# -*- coding: utf-8 -*-

import re
import inspect
import time
import typing
from copy import deepcopy

import autopep8
Expand Down Expand Up @@ -40,12 +42,12 @@ class BaseType(object):
"""

def __init__(self,
requires,
steps,
name,
vector_based=True,
heter_params_replace=None,
extra_functions=()):
requires: typing.Dict,
steps: typing.Union[typing.Callable, typing.List, typing.Tuple],
name: str,
vector_based: bool = True,
heter_params_replace: typing.Dict = None,
extra_functions: typing.Union[typing.Callable, typing.List, typing.Tuple] = ()):
# type : neuron based or group based code
# ---------------------------------------
self.vector_based = vector_based
Expand Down Expand Up @@ -75,16 +77,13 @@ def __init__(self,
raise ModelDefError(f'In "requires", each value must be a {TypeChecker.__name__}, '
f'but got "{type(v)}" for "{k}".')

# variables
# ----------
self.variables = self.requires['ST']._vars

# steps
# ------
self.steps = []
self.step_names = []
self.step_scopes = dict()
self.step_args = set()
step_vars = set()
if callable(steps):
steps = [steps]
elif isinstance(steps, (list, tuple)):
Expand All @@ -96,14 +95,17 @@ def __init__(self,
assert callable(func)
except AssertionError:
raise ModelDefError('"steps" must be a list/tuple of callable functions.')

# function name
func_name = tools.get_func_name(func, replace=True)
self.step_names.append(func_name)

# function arg
for arg in inspect.getfullargspec(func).args:
if arg in ARG_KEYWORDS:
continue
self.step_args.add(arg)

# function scope
scope = tools.get_func_scope(func, include_dispatcher=True)
for k, v in scope.items():
Expand All @@ -113,12 +115,25 @@ def __init__(self,
f'{self.name}: {k} = {v} and {k} = {self.step_scopes[k]}.\n'
f'This maybe cause a grievous mistake in the future. Please change!')
self.step_scopes[k] = v

# function
self.steps.append(func)

# set attribute
setattr(self, func_name, func)

# get the STATE variables
step_vars.update(re.findall(r'ST\[[\'"](\w+)[\'"]\]', tools.get_main_code(func)))

self.step_args = list(self.step_args)

# variables
# ----------
self.variables = self.requires['ST']._vars
for var in step_vars:
if var not in self.variables:
raise ModelDefError(f'Variable "{var}" is used in {self.name}, but not defined in "ST".')

# integrators
# -----------
self.integrators = []
Expand Down Expand Up @@ -154,7 +169,7 @@ def __init__(self,
# extra functions
# ---------------
if callable(extra_functions):
extra_functions = (extra_functions, )
extra_functions = (extra_functions,)
try:
assert isinstance(extra_functions, (tuple, list))
if len(extra_functions):
Expand All @@ -179,6 +194,7 @@ class ParsUpdate(dict):
- model : the model which this ParsUpdate belongs to
"""

def __init__(self,
all_pars,
num,
Expand Down Expand Up @@ -300,12 +316,12 @@ class BaseEnsemble(object):
"""

def __init__(self,
name,
num,
model,
monitors,
pars_update,
cls_type):
name: str,
num: int,
model: BaseType,
monitors: typing.Tuple,
pars_update: typing.Dict,
cls_type: str):
# class type
# -----------
assert cls_type in [_NEU_GROUP, _SYN_CONN], f'Only support "{_NEU_GROUP}" and "{_SYN_CONN}".'
Expand Down
20 changes: 17 additions & 3 deletions brainpy/core_system/neurons.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-

import typing

from .base import BaseEnsemble
from .base import BaseType
from .constants import _NEU_GROUP
Expand All @@ -21,7 +23,13 @@ class NeuType(BaseType):
It can be defined based on a group of neurons or a single neuron.
"""

def __init__(self, name, requires, steps, vector_based=True, heter_params_replace=None, extra_functions=()):
def __init__(self,
name: str,
requires: dict,
steps: typing.Union[typing.Callable, typing.List, typing.Tuple],
vector_based: bool = True,
heter_params_replace: typing.Dict = None,
extra_functions: typing.Union[typing.Callable, typing.List, typing.Tuple] = ()):
super(NeuType, self).__init__(requires=requires,
steps=steps,
name=name,
Expand All @@ -47,7 +55,12 @@ class NeuGroup(BaseEnsemble):
The name of the neuron group.
"""

def __init__(self, model, geometry, pars_update=None, monitors=None, name=None):
def __init__(self,
model,
geometry,
pars_update=None,
monitors=None,
name=None):
# name
# -----
if name is None:
Expand Down Expand Up @@ -83,7 +96,8 @@ def __init__(self, model, geometry, pars_update=None, monitors=None, name=None):
try:
assert isinstance(model, NeuType)
except AssertionError:
raise ModelUseError(f'{NeuGroup.__name__} receives an instance of {NeuType.__name__}, '
raise ModelUseError(f'{NeuGroup.__name__} receives an '
f'instance of {NeuType.__name__}, '
f'not {type(model).__name__}.')

# initialize
Expand Down
29 changes: 29 additions & 0 deletions brainpy/core_system/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,21 @@
from .. import numpy as np
from .. import profile
from .. import tools
from . neurons import NeuGroup
from ..errors import ModelDefError
from ..errors import ModelUseError
from ..integration.integrator import Integrator
from ..integration.sympy_tools import get_mapping_scope


class Runner(object):
"""Basic runner class.
Parameters
----------
ensemble : NeuGroup, SynConn
The ensemble of the models.
"""
def __init__(self, ensemble):
# ensemble: NeuGroup / SynConn
self.ensemble = ensemble
Expand Down Expand Up @@ -1022,3 +1030,24 @@ def set_schedule(self, schedule):
except AssertionError:
raise ModelUseError(f'Unknown step function "{s}" for model "{self._name}".')
self._schedule = schedule


class TrajectoryRunner(Runner):
"""Runner class for trajectory.
Parameters
----------
ensemble : NeuGroup
The neuron ensemble.
"""
def __init__(self, ensemble, target_vars):
try:
assert isinstance(ensemble, NeuGroup)
except AssertionError:
raise ModelUseError('TrajectoryRunner only supports the instance of NeuGroup.')
self.target_vars = target_vars
super(TrajectoryRunner, self).__init__(ensemble=ensemble)




3 changes: 0 additions & 3 deletions brainpy/core_system/synapses.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from .. import numpy as np
from .. import profile
from .. import tools
from ..connectivity import mat2ij
from ..connectivity import Connector
from ..connectivity import post2syn
from ..connectivity import pre2syn
from ..errors import ModelDefError
from ..errors import ModelUseError

Expand Down
13 changes: 13 additions & 0 deletions brainpy/dynamics/trajectory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# -*- coding: utf-8 -*-


def format_model():
pass




def plot_trajectory(neu, target_vars, initial_values=()):
# format initial values
pass

18 changes: 10 additions & 8 deletions brainpy/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
_math_funcs = [
# Basics
# --------
'real', 'imag', 'conj', 'conjugate', 'ndim', 'isreal', # 'angle',
'real', 'imag', 'conj', 'conjugate', 'ndim', 'isreal', 'isscalar', # 'angle',

# Arithmetic operations
# ----------------------
Expand Down Expand Up @@ -53,6 +53,7 @@
]

isreal = numpy.isreal
isscalar = numpy.isscalar
real = numpy.real
imag = numpy.imag
conj = numpy.conj
Expand Down Expand Up @@ -446,14 +447,11 @@
kaiser = numpy.kaiser

# https://numpy.org/doc/stable/reference/constants.html
_constants = ['e', 'pi', 'inf', 'nan', 'newaxis', 'euler_gamma']
_constants = ['e', 'pi', 'inf']

e = numpy.e
pi = numpy.pi
inf = numpy.inf
nan = numpy.nan
newaxis = numpy.newaxis
euler_gamma = numpy.euler_gamma

# https://numpy.org/doc/stable/reference/routines.linalg.html
_linear_algebra = [
Expand Down Expand Up @@ -531,11 +529,15 @@ def _reload(backend):
else:
global_vars[__ops] = getattr(numpy, __ops)

elif backend == 'jax':
jax = import_module('jax')
elif backend == 'tf-numpy':
tf_numpy = import_module('tensorflow.experimental.numpy')
from ._backends import tensorflow

for __ops in _all:
global_vars[__ops] = getattr(jax, __ops)
if hasattr(tf_numpy, __ops):
global_vars[__ops] = getattr(tf_numpy, __ops)
else:
global_vars[__ops] = getattr(tensorflow, __ops)

else:
raise ValueError(f'Unknown backend device: {backend}')
Expand Down

0 comments on commit 24030a4

Please sign in to comment.