# General equations 
$$
        \begin{align*}
            \frac{\partial {\bf u}}{\partial t} + \nabla \cdot \Gamma = f
        \end{align*} \\
        \begin{align*}
            {\bf u} &= \begin{bmatrix}
                p \\
                m \\
                c_f \\
                c_b
            \end{bmatrix}, \,
            \nabla = \frac{\partial}{\partial x}
        \end{align*} \\
        \begin{align*}
            \Gamma &= \begin{bmatrix}
                        0 & -D_m (1-\beta p) \partial_x m & -D_c \partial_x c_f & 0 
                    \end{bmatrix}^T \\
            f &= \begin{bmatrix}
                v \partial_x p - \alpha p c_b \\
                \alpha p c_b \\
                -k_s c_f p + \alpha p c_b \\
                v \partial_x c_b + k_s c_f p - \alpha p c_b
            \end{bmatrix} \\
        \end{align*} \\
        \begin{align*}
            v &= \left. a_{br} k_{br} m^2 p \right\vert_{x=0}  + \left. a_{gr}k_{gr} (m-m_c) p \right\vert_{x=0}
        \end{align*}
$$

With the boundaries both at $x=0$ and $x=L$:
$$
    \begin{align}
        &p = 0 \\
        &D_m (1-\beta p) \partial_x m =  0 \\
        &\frac{\partial c_f}{\partial x} = 0 \\
        &c_b = 0 
    \end{align}
$$


# Discretization
**Convention**: A subscript of $i$ indicates the spatial discretizetion, an upperscript of $n$ indicates the temporal discretization. A subscript with no upper script means the current timestamp ($n$).

### Equations

$$
\begin {align}
    p_i^{n+1} &= p_i + \frac{\Delta t}{2 \Delta x} \cdot v (p_{i+1} -p_{i-1}) - Depoly \\
    m_i^{n+1} &= m_i + D_m \frac{\Delta t}{2 \Delta x^2} \big[(2 - \beta (p_i + p_{i+1}))(m_{i+1}-m_i) - (2 - \beta (p_i + p_{i-1}))(m_{i}-m_{i-1}) \big] + Depoly \\
    c_{f,i}^{n+1} &= c_{f, i} + D_c \frac{\Delta  t}{\Delta x^2} (c_{f, i+1} + c_{f, i-1} - 2 c_{f, i}) - Bind + Depoly \\
    c_{b,i}^{n+1} &= c_{b, i} + v \frac{\Delta  t}{2 \Delta x} (c_{b, i+1} - c_{b, i-1}) + Bind - Depoly
\end{align}
$$

Where:
$$
\begin{align}
    Depoly &= \Delta t \, \alpha c_{b, i} p_i \\ 
    Bind &= \Delta t \, k_s c_{f, i} p_i \\ 
\end{align}
$$

### Boundaries
#### Right boundary
$$
    p_{M} = 0 \\
    m_{M} = m_{M-2} \\
    c_{f, M} = c_{f, M-2} \\
    c_{b, M} = 0 \\
$$

#### Left boundary
$$
\begin{align}
    p_{0} = p_2 \\
    c_{f, 0} = c_{f, 2} \\
    c_{b, 0} = 0 \\
\end{align}
$$
As for $m$:

$$
\begin{align*}
    \sum{J_{total}} &=J_m+J_p=D_m (1-\beta p) \nabla m +vp=0 \\
    \Rightarrow  \, & 0 = D_m(1-\beta p_0) \frac{m_1 - m_0}{\Delta x} + p_0^2 \left( a_{br} k_{br} m^2_0  + a_{gr}k_{gr} (m_0-m_c) \right) \\
    \Rightarrow \, & 0 = p_0^2 a_{br} k_{br} \cdot m^2_0 + \left( p_0^2 a_{gr}k_{gr} - \frac{D_m(1-\beta p_0)}{\Delta x} \right) \cdot m_0 + m_1 \frac{D_m(1-\beta p_0)}{\Delta x} - p_0^2 a_{gr}k_{gr} m_c \\ 
\end{align*}
$$


In [61]:
import numpy as np
import matplotlib.pyplot as plt

In [62]:
# Constants

Dm = 8.4
Dc = 10
beta = 0.1
alpha = 0.5
k_s = 0.01
k_gr = 8.7
k_br = 2.16E-5
a_gr = 2
a_br = 2
m_c = 0.2
l = 0.003



$$
    \begin{align*}
        depoly &= \Delta t \cdot \alpha c_{b,i} p_i \\
        bind &= \Delta t \cdot k_s c_{f,i} p_i
    \end{align*}

$$

$$
    \begin{align*}
        f = \begin{bmatrix}
            v_i \partial_x p - \alpha p c_b \\
            \alpha p c_b \\
            -k_s c_f p + \alpha p c_b \\
            v \partial_x c_b + k_s c_f p - \alpha p c_b
        \end{bmatrix} \\
    \end{align*} \\

$$

In [63]:
def get_f(p, cf, cb, dt, dx, v, alpha):
    """ Calculates the f part of the discretizied equations

    Args:
        p (_type_): p
        cf (_type_): cf
        cb (function): cb
        dt (_type_): dt
        dx (_type_): dx
        v (_type_): v

    Returns:
        tuple: f_p[-1:1], f_m[-1:1], f_cf[-1:1], f_cb[-1:1]
    """
    global k_s

    depoly = (dt*alpha[1: -1])*cb[1:-1]*p[1:-1]
    bind = (dt*k_s)*cf[1:-1]*p[1:-1]
    # print(f"Depoly={depoly}")

    f_p = v*dt/(2*dx) * (p[2:] - p[:-2]) - depoly
    f_m = depoly
    f_cf = depoly - bind
    f_cb = v*dt/(2*dx) * (cb[2:] - cb[:-2]) - f_cf
    # print(f_p)
    
    return f_p, f_m, f_cf, f_cb


# get_f(np.random.random(100), np.random.random(100), np.random.random(100), 0.01, 0.1, 5)

In [64]:
def solver(I, L, T, dx, dt, user_action=None, stability_safety_factor=1.0, alpha_func=0.5):
    """_summary_
get_f - A function that returns a list of 4 values according to ${\bf f}$. The spatial derivatives would be calculated based on regular finite difference, e.g. $\frac{r_p^{n, m+1} - r_p^{n, m}}{\Delta x}$
    Args:
        I (List[Callable]): Initial condition for all of the dependent variables (list of 4 callables) 
        L (float): Length of domain
        T (float):  Stop time for the simulation 
        dx (float): Space grid unit
        dt (float): Time unit
        user_action (Callable, optional): _description_. Defaults to None.
        stability_safety_factor (float, optional): _description_. Defaults to 1.0.

    Returns:
        _type_: _description_
    """

    global a_br, k_br, a_gr, k_gr, Dm, Dc, beta, l, m_c

    Nt = int(round(T/dt))
    t = np.linspace(0, Nt*dt, Nt+1)      # Mesh points in time

    # Find max(c) using a fake mesh and adapt dx to C and dt
    # c_max = max([c(x_) for x_ in np.linspace(0, L, 101)])
    # dx = dt*c_max/(stability_safety_factor*C)

    Nx = int(round(L/dx))
    x = np.linspace(0, L, Nx+1)          # Mesh points in space

    p = np.vectorize(I[0], otypes=[float])(x)
    m = np.vectorize(I[1], otypes=[float])(x)
    cf = np.vectorize(I[2], otypes=[float])(x)
    cb = np.vectorize(I[3], otypes=[float])(x)

    alpha = np.vectorize(alpha_func)(x)


    for n in t:
        v = (a_br*k_br*(m[0]**2) + a_gr*k_gr*(m[0])) * p[0]

        if user_action:
            user_action(p=p, m=m, cf=cf, cb=cb, v=v, dx=dx, dt=dt, x=x, t=t, n=n)

        
        f = get_f(p, cf, cb, dt, dx, v, alpha)
        
        # Calculate boundaries
        p_L, m_L, cf_L, cb_L = 0, m[-3], cf[-3], 0
        p_0, cf_0, cb_0 = p[2], cf[2], 0

        poly = [a_br*k_gr*p[0]**2, p[0]**2 *a_gr*k_gr - Dm*(1 - beta*p[0])/(2*dx),  m[2]*Dm*(1-beta *p[0])/(2*dx) - p[0]**2 *a_gr*k_gr*m_c]
        m_12 = np.roots(poly)
        m_0 = min(m_12) # TODO: check what value we should take
        # print(m_0)

        # Calculate inner-points
        m_flux = Dm*dt/(2*dx**2) * \
                    ((2-beta*(p[1: -1] + p[2:]))*(m[2:] - m[1: -1]) - (2-beta*(p[1: -1] + p[:-2]))*(m[1: -1] - m[:-2]))
        p[1: -1] = p[1: -1] + f[0]
        m[1: -1] = m[1: -1] +  m_flux + f[1]
        
        cf[1: -1] = cf[1: -1] + Dc*dt/(dx**2) * (cf[2:] + cf[:-2] - 2*cf[1:-1]) + f[2]
        cb[1: -1] = cb[1: -1] + f[3]

        # Set boundaries

        p[-1], m[-1], cf[-1], cb[-1] = p_L, m_L, cf_L, cb_L
        p[0], m[0], cf[0], cb[0] = p_0, m_0, cf_0, cb_0

    return p, m, cf, cb, x, t


In [65]:
import math
from IPython import display
from time import sleep

I = [
    lambda p: 0.1 * (p < 0.1),
    lambda m: 0,
    lambda cf: 0,
    lambda cb: 0.2 * (cb < 0.5),
]

# fig = plt.figure()
v = np.array(0)
def append_v(**kwargs):
    global v
    v_timestamp = kwargs['v']
    v = np.append(v, v_timestamp)

def plot(**kwargs):
    n = kwargs['n']
    if not( n % 0.01 <= 1e-6):
        return
        
    x = kwargs['x']
    m = kwargs['m']
    plt.title(f"t={n}")
    plt.plot(x, m, '-r')
    plt.ylim((-0.1, 0.1))
    display.display(plt.gcf())
    display.clear_output(wait=True)
    sleep(0.1)
    plt.clf()

def log_numbers(**kwargs):
    parameters = ["p", "m", "cf", "cb", "n"]
    [print(f"{k}: {kwargs[k]}, ", end="") for k in parameters]
    print()

def aggregate_gen(arr):
    def aggregate(**kwargs):
        for k in arr:
            _sum = sum(kwargs[k])
            print(f"{k}: {_sum}")
            

    return aggregate


monomers = []

def monomer_sum(**kwargs):
    global L, monomers

    x = kwargs["x"]
    p = kwargs["p"]
    m = kwargs["m"]

    sum_p = sum(p)/len(x)
    sum_m = sum(m)/len(x)
    # print(f"Monomers: {sum_m + sum_p}")
    monomers.append(sum_m + sum_p)
    return 


    
sol = solver(I, 1, 1, 0.01, 1e-6, user_action=plot, alpha_func=lambda y: 1)


[-5.52087350e-08 -5.53378322e-08 -5.53438531e-08 -5.53470722e-08
 -5.53478672e-08 -5.53462475e-08 -5.53422249e-08 -5.53354378e-08
 -5.54478954e-08 -3.56764019e-08 -3.55408630e-08 -3.55250261e-08
 -3.55073696e-08 -3.54875445e-08 -3.54655785e-08 -3.54415020e-08
 -3.54153478e-08 -3.53871505e-08 -3.53569466e-08 -3.53247747e-08
 -3.52906754e-08 -3.52546909e-08 -3.52168656e-08 -3.51772453e-08
 -3.51358778e-08 -3.50928125e-08 -3.50481004e-08 -3.50017942e-08
 -3.49539481e-08 -3.49046176e-08 -3.48538598e-08 -3.48017333e-08
 -3.47482976e-08 -3.46936137e-08 -3.46377439e-08 -3.45807514e-08
 -3.45227004e-08 -3.44636563e-08 -3.44036853e-08 -3.43428545e-08
 -3.42812316e-08 -3.42188854e-08 -3.41558849e-08 -3.40923000e-08
 -3.40282009e-08 -3.39636583e-08 -3.38987434e-08 -3.38335273e-08
 -3.37680818e-08 -3.37024784e-08 -3.36367890e-08 -3.35710853e-08
 -3.35054390e-08 -3.34399215e-08 -3.33746042e-08 -3.33095581e-08
 -3.32448537e-08 -3.31805612e-08 -3.31167503e-08 -3.30534900e-08
 -3.29908487e-08 -3.29288

KeyboardInterrupt: 

<Figure size 640x480 with 0 Axes>

In [None]:
print(100*(monomers[-1] - monomers[0])/monomers[0])
print(len(monomers))

### Notes

percent of monomer difference with m_c:

2.20327% on dt=1e-6 

2.204861% on dt=0.5e-6

m_c made p be unstable
