In [1]:
#!/usr/local/bin/python

import os, sys
import json
import numpy as np
import matplotlib.pyplot as plt
import chainer
from chainer import cuda
from matplotlib import animation
from optparse import OptionParser

from elecpy.solver.PDE import PDE
from elecpy.stim.ExtracellularStimulator import ExtracellularStimulator
from elecpy.stim.MembraneStimulator import MembraneStimulator
from elecpy.cell.ohararudy.model import model as cell_model_ohararudy
from elecpy.cell.luorudy.model import model as cell_model_luorudy
from elecpy.cell.mahajan.model import model as cell_model_mahajan
from elecpy.util.cmap_bipolar import bipolar
import elecpy.elecpy as elp

from matplotlib import animation, rc
from IPython.display import HTML

import time

In [2]:
with open ('elecpy/temp/sim_params.json','r') as f:
    sim_params = json.load(f)
    
#sim_params['log']['cnt'] = 10
#sim_params['time']['end'] = 0.2
    
print json.dumps(sim_params, indent=4)

{
    "geometory": {
        "width": 200, 
        "ds": 0.015, 
        "height": 200
    }, 
    "stimulation": {
        "extracellular": [
            {
                "name": "point", 
                "interval": 100.0, 
                "start": 0.0, 
                "shape": [
                    200, 
                    200
                ], 
                "amplitude": 50.0, 
                "duration": 20.0, 
                "size": [
                    91, 
                    61, 
                    5
                ]
            }
        ], 
        "membrane": []
    }, 
    "cell_type": "mahajan", 
    "log": {
        "path": "result", 
        "cnt": 1000
    }, 
    "time": {
        "end": 200, 
        "udt": 0.001
    }
}


In [3]:
def conv_cntSave2time(cnt_save, udt=None, cnt_log=None):
    """Convert a save index to simulation time (ms).

    Save index ``cnt_save`` covers udt counts starting at ``cnt_save * cnt_log``
    (see conv_time2cntSave), so its time is ``cnt_save * cnt_log * udt``.

    udt / cnt_log default to sim_params['time']['udt'] / sim_params['log']['cnt'].
    """
    if udt is None:
        udt = sim_params['time']['udt']          # Universal time step (ms)
    if cnt_log is None:
        cnt_log = sim_params['log']['cnt']       # num of udt for logging
    # Same expression as conv_cntUdt2time(cnt_save * cnt_log), so the result is
    # consistent with the t the main loop computes from cnt_udt.
    return (cnt_save * cnt_log) * udt

def conv_cntUdt2time(cnt_udt, udt=None):
    """Convert a count of universal time steps to simulation time (ms).

    udt defaults to sim_params['time']['udt'].
    """
    if udt is None:
        udt = sim_params['time']['udt']          # Universal time step (ms)
    return cnt_udt * udt

def conv_time2cntUdt(t, udt=None):
    """Convert simulation time t (ms) to the number of whole udt steps elapsed.

    Behaves like int(t / udt), except that floating-point noise smaller than
    ~1e-6 of a step is rounded away first: 0.3 / 0.001 evaluates to
    299.99999999999994, which plain int() would turn into 299.

    udt defaults to sim_params['time']['udt'].
    """
    if udt is None:
        udt = sim_params['time']['udt']          # Universal time step (ms)
    return int(round(t / udt, 6))

def conv_time2cntSave(t, udt=None, cnt_log=None):
    """Convert simulation time t (ms) to the index of the save interval it falls in.

    udt / cnt_log default to sim_params['time']['udt'] / sim_params['log']['cnt'].
    """
    if cnt_log is None:
        cnt_log = sim_params['log']['cnt']       # num of udt for logging
    return conv_time2cntUdt(t, udt) // cnt_log

In [4]:
if not os.path.isdir(sim_params['log']['path'] ) :
        os.mkdir(sim_params['log']['path'] )
        
with open('{0}/sim_params.json'.format(sim_params['log']['path'] ), 'w') as f:
    json.dump(sim_params, f, indent=4)
    
assert sim_params is not None

cuda.get_device(0).use()

# Constants
Sv           = 1400                  # Surface-to-volume ratio (cm^-1)
Cm           = 1.0                   # Membrane capacitance (uF/cm^2)
sigma_l_i    = 1.74                  # (mS/cm)
sigma_t_i    = 0.19                  # (mS/cm)
sigma_l_e    = 6.25                  # (mS/cm)
sigma_t_e    = 2.36                  # (mS/cm)

# Geometory settings
im_h         = sim_params['geometory']['height']
im_w         = sim_params['geometory']['width']
ds           = sim_params['geometory']['ds'] # Spatial discretization step (cm)
N            = im_h*im_w

# Time settings
udt          = sim_params['time']['udt']     # Universal time step (ms)
time_end     = sim_params['time']['end']

# Logging settings
cnt_log      = sim_params['log']['cnt']      # num of udt for logging
savepath     = sim_params['log']['path']

# Cell model settings
cells = None
if sim_params['cell_type'] == 'ohararudy':
    cells = cell_model_ohararudy(shape=(N,))
if sim_params['cell_type'] == 'luorudy':
    cells = cell_model_luorudy(shape=(N,))
if sim_params['cell_type'] == 'mahajan':
    cells = cell_model_mahajan(shape=(N,))
assert cells is not None

print "Stimulation settings",
stims_ext = []
stims_mem = []
if 'stimulation' in sim_params.keys():
    stim_param = sim_params['stimulation']
    if 'extracellular' in stim_param:
        for param in stim_param['extracellular']:
            stim = ExtracellularStimulator(**param)
            assert tuple(stim.shape) == (im_h, im_w)
            stims_ext.append(stim)
    if 'membrane' in stim_param:
        for param in stim_param['membrane']:
            stim = MembraneStimulator(**param)
            assert tuple(stim.shape) == (im_h, im_w)
            stims_mem.append(stim)
print "...done"

print "Allocating data...",
cells.create()
i_ion              = np.zeros((N),dtype=np.float64)
phie               = np.zeros((N),dtype=np.float64)
i_ext_e            = np.zeros((N),dtype=np.float64)
i_ext_i            = np.zeros((N),dtype=np.float64)
rhs_phie           = np.zeros((N),dtype=np.float64)
rhs_vmem           = np.zeros((N),dtype=np.float64)
vmem               = np.copy(cells.get_param('v'))
print "...done"

print "Initializing data...",
if 'restart' in sim_params.keys():
    cnt_restart = sim_params['restart']['count']
    srcpath = sim_params['restart']['source']
    pfx = '_{0:0>4}'.format(cnt_restart)
    phie = np.load('{0}/phie{1}.npy'.format(srcpath,pfx))
    vmem = np.load('{0}/vmem{1}.npy'.format(srcpath,pfx))
    cells.load('{0}/cell{1}'.format(srcpath,pfx))
    cnt_udt = cnt_restart * cnt_log
print "...done"


print 'Building PDE system ...',
sigma_l      = sigma_l_e + sigma_l_i
sigma_t      = sigma_t_e + sigma_t_i
if not 'pde_i' in locals(): pde_i = PDE( im_h, im_w, sigma_l_i, sigma_t_i, ds )
if not 'pde_m' in locals(): pde_m = PDE( im_h, im_w, sigma_l,   sigma_t,   ds )
print '...done'

Stimulation settings ...done
Allocating data... ...done
Initializing data... ...done
Building PDE system ... ...done


In [5]:
# Initialization
t         = 0.                       # Time (ms)
cnt_udt   = 0                        # Count of udt
dstep     = 1                        # Time step (# of udt)
cnt_save  = -1

run_udt   = True                     # Flag of running sim in udt
flg_st    = False                    # Flaf of stimulation
cnt_st_off = 0

print 'Main loop start!'
sim_result_image = []
start = time.time()

while t < time_end:
    
    t = conv_cntUdt2time(cnt_udt)
    dt = dstep * udt

    # Stimulation control
    i_ext_e[:] = 0.0
    flg_st_temp = False
    for s in stims_ext:
        i_ext_e += s.get_current(t)*Sv
        flg_st_temp = flg_st_temp or s.get_flag(t)
    for s in stims_mem:
        cells.set_param('st', s.get_current(t)) 

    # step.1 cell state transition
    cells.set_param('dt', dt )
    cells.set_param('v', vmem )
    cells.update()
    i_ion = cells.get_param('it')

    # step.2 phie
    rhs_phie = i_ext_e - i_ext_i - pde_i.forward(vmem)
    pde_cnt, phie = pde_m.solve(phie, rhs_phie, tol=1e-2)
    phie -= phie[0]

    # step.3 vmem
    rhs_vmem = pde_i.forward(vmem)
    rhs_vmem += pde_i.forward(phie)
    rhs_vmem -= i_ion * Sv
    rhs_vmem += i_ext_i
    rhs_vmem *= 1 / (Cm * Sv)
    vmem += dt * rhs_vmem


    # Logging & error check
    cnt_save_now = conv_time2cntSave(t)
    if cnt_save_now != cnt_save:
        print '------------------{0}ms({1})'.format(t, pde_cnt)
        cnt_save = cnt_save_now
        np.save('{0}/phie_{1:0>4}'.format(savepath,cnt_save), phie.reshape((im_h, im_w)))
        np.save('{0}/vmem_{1:0>4}'.format(savepath,cnt_save), vmem.reshape((im_h, im_w)))
        cells.save('{0}/cell_{1:0>4}'.format(savepath,cnt_save))
        sim_result_image.append((t, np.copy(vmem.reshape(im_h, im_w))))

        flg = False
        for i,v in enumerate(vmem):
            if v != v :
                print "error : invalid value {1} @ {0} ms, index {2}".format(t, v, i)
                flg = True
                break
        if flg is True:
            break

    # Stim off count
    if flg_st_temp is False:
        if flg_st is True:
            cnt_st_off = 0
        else:
            cnt_st_off += 1
        flg_st = flg_st_temp

    # Time step control
    if run_udt:
        if cnt_st_off >= 3 and cnt_udt % 10 == 0:
            dstep = 2
            run_udt = False
    else:
        if pde_cnt > 5:
            dstep = 1
            run_udt = True

    cnt_udt += dstep

print "elecpy done"

print "elapsed time:", time.time() - start

Main loop start!
------------------0.0ms(300)
------------------1.0ms(1)
------------------2.0ms(1)
------------------3.0ms(1)
------------------4.0ms(1)
------------------5.0ms(1)
------------------6.0ms(1)
------------------7.0ms(1)
------------------8.0ms(1)
------------------9.0ms(1)
------------------10.0ms(1)
------------------11.0ms(1)
------------------12.0ms(1)
------------------13.0ms(1)
------------------14.0ms(1)
------------------15.0ms(1)
------------------16.0ms(1)
------------------17.0ms(1)
------------------18.0ms(1)
------------------19.0ms(1)
------------------20.0ms(300)
------------------21.0ms(4)
------------------22.0ms(1)
------------------23.0ms(1)
------------------24.0ms(1)
------------------25.0ms(10)
------------------26.0ms(1)
------------------27.0ms(4)
------------------28.0ms(1)
------------------29.0ms(7)
------------------30.0ms(1)
------------------31.0ms(1)
------------------32.0ms(4)
------------------33.0ms(1)
------------------34.0ms(7)
--------

In [6]:
# Animate the logged membrane-potential snapshots (one frame per saved step).
fig, ax = plt.subplots()
ax.axis('off')

# The colormap is the same for every frame, so build it once.
vmem_cmap = bipolar(neutral=0, lutsize=1024)

ims = []
for t, img in sim_result_image:
    im = ax.imshow(
        img,
        vmin = -100.0, vmax = 100.0,
        cmap=vmem_cmap,
        interpolation='nearest')
    # An Axes title is not a per-frame artist in ArtistAnimation, so the time
    # stamp is drawn as a text artist that belongs to this frame.
    label = ax.text(0.5, 1.01, '{0:.1f} ms'.format(t),
                    ha='center', va='bottom', transform=ax.transAxes)
    ims.append([im, label])

ani = animation.ArtistAnimation(fig, ims)
# Alternatives for exporting instead of embedding:
#ani.save('anim.gif', writer="imagemagick")
#ani.save('anim.mp4', writer="ffmpeg")
#plt.show()

HTML(ani.to_html5_video())

* PDE を解く際の計算を scipy.sparse から cuSPARSE に変更した
* 2次元配列として確保し毎回 flatten していたデータを、最初から1次元ベクトルとして確保するようにした

これらの変更により、約2倍の高速化が得られた。

ただし、ヤコビ法の収束判定条件 $\varepsilon < 10^{-2}$（コード中の `pde_m.solve(..., tol=1e-2)`）が妥当かどうかは要検討。