In [1]:
import numpy as np
import pandas as pd
from scipy.integrate import solve_ivp

from src.model.maleckar import init_states_constants, compute_rates, legend
from src.helpers import get_value_by_key, update_array_from_kwargs
from matplotlib import pyplot as plt

In [2]:
def run_model(S, C, R, A, config, legend, only_last_beat=True):

    stim_period = get_value_by_key(C, legend['constants'], 'stim_period')
    t_sampling = config['t_sampling']
    n_beats = config['n_beats']

    t_space = np.linspace(0, stim_period * n_beats, int(stim_period / t_sampling) * n_beats + 1, endpoint=True)
    t_tail = np.linspace(stim_period * (n_beats - 1), stim_period * n_beats, int(stim_period / t_sampling) + 1, endpoint=True)
    t_span = 0, t_space[-1]
    
    if only_last_beat:
        t_eval = t_tail
    else:
        t_eval = t_space

    sol = solve_ivp(compute_rates, y0=S,
                    t_span=t_span, t_eval=t_eval,
                    args=(C, R, A),
                    method='LSODA', rtol=1e-9,
                    max_step=1. * t_sampling)
    return sol

``` C
    while (t <= ft) {

        if ((t >= CL * (n_beats - beats_save) && (dt_counter % skip == 0))){
            print_to_scv(S + target_cell_index, A + target_cell_index, file_csv);
            // Uncomment to print potentials of the whole chain:
            // for (int i = 0; i < chain_length; ++i) {
            //     std::cout << S[i].V << " ";
            // }
            // std::cout << std::endl;
        }

        double VOI = t / 1000.;

        for (int i = 0; i < chain_length; ++i) {
            compute_rates(VOI, S + i, R + i, A + i, C, dt / 1000.,
                          (i) ? 0 : stim_amplitude, (i) ? 0 : stim_baseline);
                          
            double g_gap_junc = 5.0;
            double I_gap_junc = 0;
            if (i < chain_length - 1) {
                I_gap_junc += -g_gap_junc * (S[i + 1].V - S[i].V);
            }
            if (i > 0) {
                I_gap_junc += -g_gap_junc * (S[i - 1].V - S[i].V);
            }
            R[i].V -= I_gap_junc * 1000.00;
        }
        
        for (int i = 0; i < chain_length; ++i) {
            euler(dt / 1000., S + i, R + i);
        }

        t += dt;
        dt_counter++;
    }
```

In [208]:
S, C = init_states_constants()
R = np.zeros_like(S)
A = np.zeros(len(legend['algebraic']))

n_beats = 1
CL = 1000
stim_period = CL/1000.
t_sampling = 1e-3

C = update_array_from_kwargs(C, legend['constants'], stim_period=stim_period)

t_space = np.linspace(0, stim_period * n_beats, int(stim_period / t_sampling) * n_beats + 1, endpoint=True)
t_tail = np.linspace(stim_period * (n_beats - 1), stim_period * n_beats, int(stim_period / t_sampling) + 1, endpoint=True)
t_span = 0, t_space[-1]

only_last_beat = False
if only_last_beat:
    t_eval = t_tail
else:
    t_eval = t_space

sol = solve_ivp(compute_rates, y0=S,
                t_span=t_span, t_eval=t_eval,
                args=(C, R, A),
                method='LSODA', rtol=1e-9,
                max_step=1. * t_sampling)

# Chain

In [80]:
import time

In [194]:
def compute_rates_chain(t, S_chain, C_chain, R_chain, A_chain,
                        S_size, C_size, A_size,
                        chain_length):
    
    g_gap_junc = 5.0
    
    for index_cell in range(chain_length):
        S = S_chain[index_cell * S_size: (index_cell + 1) * S_size]
        C = C_chain[index_cell * C_size: (index_cell + 1) * C_size]
        R = R_chain[index_cell * S_size: (index_cell + 1) * S_size]
        A = A_chain[index_cell * A_size: (index_cell + 1) * A_size]
        
        compute_rates(t, S, C, R, A)
        
        I_gap_junc = 0
        
        if index_cell < chain_length - 1:
            I_gap_junc += -g_gap_junc * (S_chain[(index_cell + 1) * S_size] - S[0])
            
        if index_cell > 0:
            I_gap_junc += -g_gap_junc * (S_chain[(index_cell - 1) * S_size] - S[0])
            
        R[0] -= I_gap_junc * 1000

    return R_chain


def event_break(t, S_chain, S_size, threshold=0.1, safe_time=0.001):
    mean_abs_diff = np.mean(np.abs(np.diff(S_chain[::S_size], axis=0)), axis=0)
    return mean_abs_diff < threshold and t > safe_time

event_break.terminal = True
event_break.direction = 1

In [219]:
@njit
def compute_rates_chain(t, S_chain, C_chain, R_chain, A_chain,
                        S_size, C_size, A_size,
                        chain_length):
    
    g_gap_junc = 5.0
    
    for index_cell in range(chain_length):
        S = S_chain[index_cell * S_size: (index_cell + 1) * S_size]
        C = C_chain[index_cell * C_size: (index_cell + 1) * C_size]
        R = R_chain[index_cell * S_size: (index_cell + 1) * S_size]
        A = A_chain[index_cell * A_size: (index_cell + 1) * A_size]
    
        I_gap_junc = 0
        
        if index_cell < chain_length - 1:
            I_gap_junc += -g_gap_junc * (S_chain[(index_cell + 1) * S_size] - S[0])
            
        if index_cell > 0:
            I_gap_junc += -g_gap_junc * (S_chain[(index_cell - 1) * S_size] - S[0])
            
        R[0] -= I_gap_junc * 1000

    return R_chain


def event_break(t, S_chain, S_size, threshold=0.1, safe_time=0.001):
    mean_abs_diff = np.mean(np.abs(np.diff(S_chain[::S_size], axis=0)), axis=0)
    return mean_abs_diff < threshold and t > safe_time

event_break.terminal = True
event_break.direction = 1

In [218]:
from numba import njit

In [220]:
S, C = init_states_constants()

#  S = np.concatenate([S, np.array([0])])  # I_stim hack

R = np.zeros_like(S)
A = np.zeros(len(legend['algebraic']))

n_beats = 1
CL = 1000
stim_period = CL/1000.
t_sampling = 1e-4

C = update_array_from_kwargs(C, legend['constants'], stim_period=stim_period)

chain_length = 30

C_stim = update_array_from_kwargs(C, legend['constants'], stim_amplitude=-40)
C_no_stim = update_array_from_kwargs(C, legend['constants'], stim_amplitude=0)

S_chain = np.tile(S, chain_length)

C_chain = np.concatenate([C_stim, np.tile(C_no_stim, chain_length - 1)])
#C_chain = np.tile(C, chain_length)

R_chain = np.zeros(len(R) * chain_length)
A_chain = np.zeros(len(A) * chain_length)

In [221]:
R_chain = compute_rates_chain(0, S_chain, C_chain, R_chain, A_chain,
                              len(S), len(C), len(A), chain_length)
#S_chain += 1e-5 * R_chain

In [222]:
from scipy.integrate import LSODA

In [223]:
solver = LSODA(lambda t, y: compute_rates_chain(t, y, C_chain, R_chain, A_chain, len(S), len(C), len(A), chain_length),
               t0=0, y0=S_chain, t_bound=t_span[-1], rtol=1e-9, atol=1e-6)

In [224]:
#y = []
#t = []

t_chain = time.time()
while solver.status == 'running' and not event_break(solver.t, solver.y, len(S)):
    solver.step()
    #y.append(solver.y[0])
    #t.append(solver.t)
    
t_chain = time.time() - t_chain    
print(solver.t, t_chain)

0.029400457365996887 8.297836780548096


In [225]:
t_single = time.time()

target_cell = 0#chain_length // 2
S = solver.y[target_cell * len(S): (target_cell + 1) * len(S)]

sol = solve_ivp(compute_rates, y0=S,
                t_span=(solver.t, t_span[-1]),# t_eval=t_eval,
                args=(C, R, A), first_step=solver.step_size,
                method='LSODA', rtol=1e-9,
                max_step=1. * t_sampling)

t_single = time.time() - t_single

In [227]:
t_sum = t_chain + t_single

In [228]:
print(t_chain / t_sum * 100, t_single / t_sum * 100, t_sum)

94.93866841422786 5.061331585772129 8.74020767211914


In [201]:
plt.plot(t, y)
plt.plot(sol.t, sol.y[0], '.-')

[<matplotlib.lines.Line2D at 0x7fd13361a9d0>]

In [157]:
S

array([-2.01616516e+01,  1.29946417e+02,  8.52708090e+00,  7.00455474e-01,
        2.68621341e-04,  2.32499153e-04,  2.15019508e-02,  1.51156236e-01,
        1.80267302e-01,  6.46129109e-01,  5.65497189e+00,  1.29494089e+02,
        1.38699117e-01,  1.41433983e-01,  2.03356019e-01,  9.60587788e-01,
        1.17217103e-02,  2.57705588e-01,  1.79290074e+00,  2.53922501e-04,
        1.00402281e-01,  5.12036405e-02,  4.81969542e-01,  4.57127494e-01,
        1.43525225e+00,  1.44766437e-01,  9.43174874e-01,  1.44409652e-01,
        5.54878725e-03,  2.37347145e-02])

In [158]:
solver.status

'running'

In [159]:
solver.step_size

0.0006736447172107987

In [160]:
sol

  message: 'The solver successfully reached the end of the integration interval.'
     nfev: 25378
     njev: 524
      nlu: 524
      sol: None
   status: 0
  success: True
        t: array([0.05021799, 0.05022424, 0.05023049, ..., 1.        , 1.        ,
       1.        ])
 t_events: None
        y: array([[-2.01616516e+01, -2.01633350e+01, -2.01650181e+01, ...,
        -7.40762071e+01, -7.40762071e+01, -7.40762065e+01],
       [ 1.29946417e+02,  1.29946420e+02,  1.29946423e+02, ...,
         1.30012162e+02,  1.30012162e+02,  1.30012162e+02],
       [ 8.52708090e+00,  8.52708048e+00,  8.52708005e+00, ...,
         8.51822797e+00,  8.51822797e+00,  8.51822797e+00],
       ...,
       [ 1.44409652e-01,  1.44420210e-01,  1.44430769e-01, ...,
         4.31071334e-01,  4.31071334e-01,  4.31071334e-01],
       [ 5.54878725e-03,  5.54922813e-03,  5.54966911e-03, ...,
         4.70249003e-01,  4.70249003e-01,  4.70249003e-01],
       [ 2.37347145e-02,  2.37319574e-02,  2.37292017e-02, ...,


In [161]:
plt.plot(sol.t, sol.y[0], '.-')

[<matplotlib.lines.Line2D at 0x7fd1398841d0>]

In [109]:
print(solver.t)

0.029400457365996887


In [163]:
t_space = np.linspace(0, stim_period * n_beats, int(stim_period / t_sampling) * n_beats + 1, endpoint=True)
t_tail = np.linspace(stim_period * (n_beats - 1), stim_period * n_beats, int(stim_period / t_sampling) + 1, endpoint=True)
t_span = 0, t_space[-1]

only_last_beat = False
if only_last_beat:
    t_eval = t_tail
else:
    t_eval = t_space

sol = solve_ivp(compute_rates_chain, y0=S_chain,
                t_span=t_span, t_eval=t_eval,
                args=(C_chain, R_chain, A_chain, len(S), len(C), len(A), chain_length),
                method='LSODA', rtol=1e-9,
                max_step=1e-3,
                #event=[lambda t, s: event_break(s, len(S))]
               )

In [164]:
y = sol.y.reshape((chain_length, len(S), -1))

In [165]:
plt.plot(y[-1, 0])
#plt.plot(sol.y[0 + 3 * len(S)])

[<matplotlib.lines.Line2D at 0x7fd1397cb4d0>]

In [47]:
I_gap_junc = np.diff(y[:, 0], axis=0, prepend=y[:1, 0]) - np.diff(y[:, 0], axis=0, append=y[-1:, 0])

In [51]:
plt.plot(I_gap_junc[10] * 5e-4)
#plt.xlim(-1, 200)
plt.grid()

In [10]:
from matplotlib import cm

In [52]:
n = 10

for i in range(10):
    plt.plot(I_gap_junc[i], color=cm.viridis(i / n))
    
#plt.xlim(0, 150)

In [53]:
dV = y[2:, 0] - y[:-2, 0]

In [67]:
plt.plot(event_break(sol.y, len(S)))

[<matplotlib.lines.Line2D at 0x7fd17ea4b490>]

In [56]:
plt.plot(dV[9])

[<matplotlib.lines.Line2D at 0x7fd17e3b4890>]

In [58]:
n = 10

for i in range(len(dV)):
    plt.plot(dV[i], lw=0.5, color='0.8')
    
plt.plot(dV[0])
plt.plot(dV[-1])

[<matplotlib.lines.Line2D at 0x7fd17f070610>]