Skip to content

Commit

Permalink
Merge pull request #42 from PKU-NIP-Lab/develop
Browse files: browse the repository at this point in the history
Improve for Numba CUDA backend
  • Loading branch information
chaoming0625 committed Apr 14, 2021
2 parents 7711b21 + 3768e6a commit eadb391
Show file tree
Hide file tree
Showing 8 changed files with 333 additions and 117 deletions.
9 changes: 7 additions & 2 deletions brainpy/backend/drivers/numba_cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ def visit_Expr(self, node, level=0):
def visit_Expression(self, node, level=0):
    """Visit a node through the shared non-assignment handler.

    Presumably called for ``ast.Expression`` nodes, following the
    ``visit_<NodeType>`` naming used by the sibling methods (confirm
    against the dispatcher).

    Parameters
    ----------
    node :
        The AST node to process.
    level : int, optional
        Passed through unchanged to ``visit_node_not_assign``; defaults to 0.
    """
    self.visit_node_not_assign(node, level)

def visit_Return(self, node, level=0):
    """Visit an ``ast.Return`` node through the non-assignment handler.

    ``visit_content_in_condition_control`` dispatches here when it meets an
    ``ast.Return`` node, so ``return`` statements inside condition blocks
    no longer fall through to the "does not support" error branch.

    Parameters
    ----------
    node :
        The ``ast.Return`` node to process.
    level : int, optional
        Passed through unchanged to ``visit_node_not_assign``; defaults to 0.
    """
    self.visit_node_not_assign(node, level)

def visit_content_in_condition_control(self, node, level):
if isinstance(node, ast.Expr):
self.visit_Expr(node, level)
Expand All @@ -194,6 +197,8 @@ def visit_content_in_condition_control(self, node, level):
self.visit_Call(node, level)
elif isinstance(node, ast.Raise):
self.visit_Raise(node, level)
elif isinstance(node, ast.Return):
self.visit_Return(node, level)
else:
code = tools.ast2code(ast.fix_missing_locations(node))
raise errors.CodeError(f'BrainPy does not support {type(node)} '
Expand Down Expand Up @@ -253,7 +258,7 @@ def visit_Call(self, node, level=0):
elif len(args) + len(kw_args) == 2:
value = kw_args['value'] if len(args) <= 1 else args[1]
if uniform_delay:
rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_in_idx][{idx_or_val}] = {value}'
rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_in_idx, {idx_or_val}] = {value}'
else:
rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_in_idx[{idx_or_val}], {idx_or_val}] = {value}'
else:
Expand All @@ -266,7 +271,7 @@ def visit_Call(self, node, level=0):
elif len(args) + len(kw_args) == 1:
idx = kw_args['idx'] if len(args) == 0 else args[0]
if uniform_delay:
rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_out_idx][{idx}]'
rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_out_idx, {idx}]'
else:
rep_expression = f'{dvar4call}.delay_data[{dvar4call}.delay_out_idx[{idx}], {idx}]'
else:
Expand Down
267 changes: 174 additions & 93 deletions brainpy/backend/drivers/numba_cuda.py

Large diffs are not rendered by default.

5 changes: 2 additions & 3 deletions brainpy/simulation/connectivity/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,14 +357,13 @@ def __init__(self):

def requires(self, *syn_requires):
# get synaptic requires
requires = set()
requires = []
for n in syn_requires:
if n in SUPPORTED_SYN_STRUCTURE:
requires.add(n)
requires.append(n)
else:
raise ValueError(f'Unknown synapse structure {n}. We only support '
f'{SUPPORTED_SYN_STRUCTURE}.')
requires = list(requires)

# synaptic structure to handle
needs = []
Expand Down
23 changes: 14 additions & 9 deletions examples/numba_cuda/AMPA_synapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import brainpy as bp
from numba import cuda

bp.backend.set(backend='numba-cuda', dt=0.05)
bp.integrators.set_default_odeint('exponential_euler')
Expand Down Expand Up @@ -97,6 +98,7 @@ def __init__(self, pre, post, conn, delay=0., g_max=0.10, E=0., tau=2.0, **kwarg

# data
self.s = bp.ops.zeros(self.num)
self.s0 = bp.ops.zeros(1)
self.g = self.register_constant_delay('g', size=self.num, delay_time=delay)

super(AMPA1, self).__init__(pre=pre, post=post, **kwargs)
Expand All @@ -115,39 +117,42 @@ def update(self, _t):
self.g.push(i, self.g_max * self.s[i])
post_id = self.post_ids[i]
self.post.input[post_id] -= self.g.pull(i) * (self.post.V[post_id] - self.E)
if i == 0:
self.s0[0] = self.s[i]
cuda.syncthreads()


def uniform_delay():
hh = HH(100, monitors=['V'])
ampa = AMPA1(pre=hh, post=hh, conn=bp.connect.All2All(), delay=10., monitors=['s'])
hh = HH(4000, monitors=['V'])
ampa = AMPA1(pre=hh, post=hh, conn=bp.connect.All2All(), delay=1., monitors=['s0'])
ampa.g_max /= hh.num
net = bp.Network(hh, ampa)

net.run(100., inputs=(hh, 'input', 10.), report=True)
net.driver.to_host()

fig, gs = bp.visualize.get_figure(row_num=2, col_num=1, )
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(hh.mon.ts, hh.mon.V)
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(ampa.mon.ts, ampa.mon.s, show=True)
bp.visualize.line_plot(ampa.mon.ts, ampa.mon.s0, show=True)


def non_uniform_delay():
hh = HH(100, monitors=['V'])
hh = HH(4000, monitors=['V'])
ampa = AMPA1(pre=hh, post=hh, conn=bp.connect.All2All(),
delay=lambda: np.random.random() * 10., monitors=['s'])
delay=lambda: np.random.random() * 1., monitors=['s0'])
ampa.g_max /= hh.num
net = bp.Network(hh, ampa)

net.run(100., inputs=(hh, 'input', 10.), report=True)
net.driver.to_host()

fig, gs = bp.visualize.get_figure(row_num=2, col_num=1, )
fig.add_subplot(gs[0, 0])
bp.visualize.line_plot(hh.mon.ts, hh.mon.V)
fig.add_subplot(gs[1, 0])
bp.visualize.line_plot(ampa.mon.ts, ampa.mon.s, show=True)
bp.visualize.line_plot(ampa.mon.ts, ampa.mon.s0, show=True)


if __name__ == '__main__':
uniform_delay()
# uniform_delay()
non_uniform_delay()
125 changes: 125 additions & 0 deletions examples/numba_cuda/COBA.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# -*- coding: utf-8 -*-


import time
import numpy as np
import brainpy as bp

# Seed NumPy's global RNG so the random initial voltages drawn later in this
# script are reproducible.
np.random.seed(1234)
# Integration time step handed to the BrainPy backend.
dt = 0.05
# NOTE(review): this example sits under examples/numba_cuda but selects the
# 'numba' backend -- confirm whether 'numba-cuda' was intended.
bp.backend.set('numba', dt=dt)

# Parameters
num = 4000 * 15  # total neuron count, split 75% / 25% below
num_exc = int(num * 0.75)  # size of the excitatory group
num_inh = int(num * 0.25)  # size of the inhibitory group
taum = 20  # divisor of dV in LIF.int_V (presumably the membrane time constant -- confirm units)
taue = 5  # decay constant of ge in LIF.int_ge
taui = 10  # decay constant of gi in LIF.int_gi
Vt = -50  # a neuron spikes when its integrated V reaches this value
Vr = -60  # value V is set to after a spike
El = -60  # constant term in the dV equation (presumably the leak potential -- confirm)
Erev_exc = 0.  # reversal value multiplying ge in the dV equation
Erev_inh = -80.  # reversal value multiplying gi in the dV equation
I = 20.  # constant drive added in the dV equation and written to LIF.input
we = 0.6  # excitatory synaptic weight (voltage)
wi = 6.7  # inhibitory synaptic weight
ref = 5.0  # V is not integrated until more than this much time has passed since the last spike


class LIF(bp.NeuGroup):
    """Integrate-and-fire neuron group driven by two decaying conductances.

    Per-neuron state arrays (all of length ``size``):

    - ``V``: membrane variable integrated by ``int_V``.
    - ``spike``: 1. on the step a neuron fires, 0. otherwise.
    - ``ge`` / ``gi``: excitatory / inhibitory conductances; they decay via
      ``int_ge`` / ``int_gi`` and are incremented by the synapse classes.
    - ``input``: set to the constant ``I`` on every update.
    - ``t_last_spike``: time of the most recent spike, initialised to -1e7
      so that no neuron starts inside the refractory window.

    The model constants (``taue``, ``taui``, ``taum``, ``Vt``, ``Vr``,
    ``El``, ``Erev_exc``, ``Erev_inh``, ``I``, ``ref``) are module globals.
    """
    target_backend = ['numpy', 'numba', 'numba-cuda']

    def __init__(self, size, **kwargs):
        """Allocate the state arrays and initialise the base group.

        Parameters
        ----------
        size :
            Number of neurons; forwarded to ``bp.ops.zeros`` / ``bp.ops.ones``
            and to ``bp.NeuGroup.__init__``.
        **kwargs :
            Forwarded unchanged to ``bp.NeuGroup.__init__``.
        """
        # variables
        self.V = bp.ops.zeros(size)
        self.spike = bp.ops.zeros(size)
        self.ge = bp.ops.zeros(size)
        self.gi = bp.ops.zeros(size)
        self.input = bp.ops.zeros(size)
        self.t_last_spike = bp.ops.ones(size) * -1e7

        super(LIF, self).__init__(size=size, **kwargs)

    @staticmethod
    @bp.odeint
    def int_ge(ge, t):
        """Derivative of the excitatory conductance: ``-ge / taue``."""
        dge = - ge / taue
        return dge

    @staticmethod
    @bp.odeint
    def int_gi(gi, t):
        """Derivative of the inhibitory conductance: ``-gi / taui``."""
        dgi = - gi / taui
        return dgi

    @staticmethod
    @bp.odeint
    def int_V(V, t, ge, gi):
        """Derivative of ``V`` given the current conductances."""
        dV = (ge * (Erev_exc - V) + gi * (Erev_inh - V) + El - V + I) / taum
        return dV

    def update(self, _t):
        """Advance every neuron by one step at time ``_t``.

        For each neuron: decay ``ge`` and ``gi``, clear the spike flag, and,
        if more than ``ref`` has elapsed since its last spike, integrate
        ``V``. When the new value reaches ``Vt`` the neuron is reset to
        ``Vr``, flagged as spiking and its spike time recorded; otherwise
        the integrated value is stored.
        """
        for i in range(self.num):
            self.ge[i] = self.int_ge(self.ge[i], _t)
            self.gi[i] = self.int_gi(self.gi[i], _t)
            self.spike[i] = 0.
            # V is left untouched while the neuron is refractory.
            if (_t - self.t_last_spike[i]) > ref:
                V = self.int_V(self.V[i], _t, self.ge[i], self.gi[i])
                if V >= Vt:
                    self.V[i] = Vr
                    self.spike[i] = 1.
                    self.t_last_spike[i] = _t
                else:
                    self.V[i] = V
            # NOTE(review): nesting reconstructed from a whitespace-stripped
            # listing; this reset is assumed to run once per neuron per step
            # (loop level) -- confirm against the repository file.
            self.input[i] = I


class ExcSyn(bp.TwoEndConn):
    """Excitatory synapse: each pre-synaptic spike adds ``we`` to ``post.ge``.

    ``post_ids`` and ``pre_slice_syn`` are obtained from the connector via
    ``requires``. For pre-synaptic neuron ``k``, ``pre_slice_syn[k]`` is
    unpacked as ``(start, end)`` and ``post_ids[start:end]`` is iterated as
    the post-synaptic indices to increment.
    """
    target_backend = ['numpy', 'numba', 'numba-cuda']

    def __init__(self, pre, post, conn, **kwargs):
        """Build the connectivity structures and initialise the base class.

        Parameters
        ----------
        pre, post :
            Pre- and post-synaptic groups; their ``size`` attributes are
            passed to ``conn``.
        conn :
            Callable connector; ``conn(pre.size, post.size)`` must return an
            object exposing ``requires``.
        **kwargs :
            Forwarded unchanged to ``bp.TwoEndConn.__init__``.
        """
        self.conn = conn(pre.size, post.size)
        self.post_ids, self.pre_slice_syn = self.conn.requires('post_ids', 'pre_slice_syn')
        super(ExcSyn, self).__init__(pre=pre, post=post, **kwargs)

    def update(self, _t):
        """Deliver ``we`` to every target of each pre-synaptic neuron that spiked."""
        for pre_id in range(self.pre.num):
            # spike holds 0. or 1., so truthiness selects neurons that fired.
            if self.pre.spike[pre_id]:
                start, end = self.pre_slice_syn[pre_id]
                for post_i in self.post_ids[start: end]:
                    self.post.ge[post_i] += we


class InhSyn(bp.TwoEndConn):
    """Inhibitory synapse: each pre-synaptic spike adds ``wi`` to ``post.gi``.

    ``post_ids`` and ``pre_slice_syn`` are obtained from the connector via
    ``requires``. For pre-synaptic neuron ``k``, ``pre_slice_syn[k]`` is
    unpacked as ``(start, end)`` and ``post_ids[start:end]`` is iterated as
    the post-synaptic indices to increment.
    """
    target_backend = ['numpy', 'numba', 'numba-cuda']

    def __init__(self, pre, post, conn, **kwargs):
        """Build the connectivity structures and initialise the base class.

        Parameters
        ----------
        pre, post :
            Pre- and post-synaptic groups; their ``size`` attributes are
            passed to ``conn``.
        conn :
            Callable connector; ``conn(pre.size, post.size)`` must return an
            object exposing ``requires``.
        **kwargs :
            Forwarded unchanged to ``bp.TwoEndConn.__init__``.
        """
        self.conn = conn(pre.size, post.size)
        self.post_ids, self.pre_slice_syn = self.conn.requires('post_ids', 'pre_slice_syn')
        super(InhSyn, self).__init__(pre=pre, post=post, **kwargs)

    def update(self, _t):
        """Deliver ``wi`` to every target of each pre-synaptic neuron that spiked."""
        for pre_id in range(self.pre.num):
            # spike holds 0. or 1., so truthiness selects neurons that fired.
            if self.pre.spike[pre_id]:
                start, end = self.pre_slice_syn[pre_id]
                for post_i in self.post_ids[start: end]:
                    self.post.gi[post_i] += wi


# Build the excitatory and inhibitory populations with no monitored variables,
# then overwrite the zero-initialised V with random starting values
# (standard normal * 5 - 55).
E_group = LIF(num_exc, monitors=[])
E_group.V = np.random.randn(num_exc) * 5. - 55.
I_group = LIF(num_inh, monitors=[])
I_group.V = np.random.randn(num_inh) * 5. - 55.
# Four projections, each using a FixedProb(0.02) connector: excitatory
# sources use ExcSyn (adds to ge), inhibitory sources use InhSyn (adds to gi).
E2E = ExcSyn(pre=E_group, post=E_group, conn=bp.connect.FixedProb(0.02))
E2I = ExcSyn(pre=E_group, post=I_group, conn=bp.connect.FixedProb(0.02))
I2E = InhSyn(pre=I_group, post=E_group, conn=bp.connect.FixedProb(0.02))
I2I = InhSyn(pre=I_group, post=I_group, conn=bp.connect.FixedProb(0.02))

net = bp.Network(E_group, I_group, E2E, E2I, I2E, I2I)
# Wall-clock timer started after network construction, so the reported time
# covers only the net.run call below.
t0 = time.time()

net.run(5000., report=True)
print('Used time {} s.'.format(time.time() - t0))

# NOTE(review): both groups are created with monitors=[], so the plot below
# presumably needs monitors=['spike'] on E_group before it can be re-enabled
# -- confirm.
# bp.visualize.raster_plot(net.ts, E_group.mon.spike, show=True)
2 changes: 0 additions & 2 deletions examples/numba_cuda/HH_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,7 @@ def update(self, _t):
group = HH(10000, monitors=['V'])

group.run(200., inputs=('input', 10.), report=True)
group.driver.to_host()
bp.visualize.line_plot(group.mon.ts, group.mon.V, show=True)

group.run(200., report=True)
group.driver.to_host()
bp.visualize.line_plot(group.mon.ts, group.mon.V, show=True)
2 changes: 0 additions & 2 deletions examples/numba_cuda/LIF_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def update(self, _t):
group = LIF(100000, monitors=['V'])

group.run(duration=200., inputs=('input', 26.), report=True)
group.driver.to_host()
bp.visualize.line_plot(group.mon.ts, group.mon.V, show=True)

group.run(duration=(200, 400.), report=True)
group.driver.to_host()
bp.visualize.line_plot(group.mon.ts, group.mon.V, show=True)
17 changes: 11 additions & 6 deletions tests/backend/drivers/test_numba_cuda_on_func_of_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import numpy as np

import brainpy as bp
from brainpy.backend.drivers.numba_cuda import set_monitor_done_in
from brainpy.backend.drivers.numba_cuda import NumbaCUDANodeDriver

bp.backend.set('numba-cuda', dt=0.02)
Expand Down Expand Up @@ -102,12 +103,16 @@ def update(self, _t):


def test_stochastic_lif_monitors1():
lif = StochasticLIF(1, monitors=['V', 'input', 'spike'])
driver = NumbaCUDANodeDriver(pop=lif)
driver.get_monitor_func(mon_length=100, show_code=True)
pprint(driver.formatted_funcs)
print()
print()

for place in ['cpu', 'cuda']:
set_monitor_done_in(place)
lif = StochasticLIF(1, monitors=['V', 'input', 'spike'])
driver = NumbaCUDANodeDriver(pop=lif)
driver.get_monitor_func(mon_length=100, show_code=True)
pprint(driver.formatted_funcs)
print()
print()



test_stochastic_lif_monitors1()
Expand Down

0 comments on commit eadb391

Please sign in to comment.