In [2]:
import numpy as np
from matplotlib import pyplot as plt

### Equations:

\begin{equation} \tag{1}
m_i^{(t)} = \frac{a_i^{(t)}}{a_i^{(t)}+b_t^{(t)}} m_{i+1}^{(t-1)} + \frac{b_i^{(t)}}{a_i^{(t)}+b_t^{(t)}} m_{i-1}^{(t-1)} - a_i^{(t)}
\end{equation}

\begin{equation*} \tag{2}
n_i^{(t)} = \frac{a_i^{(t)}}{a_i^{(t)}+b_t^{(t)}} n_{i+1}^{(t-1)} + \frac{b_i^{(t)}}{a_i^{(t)}+b_t^{(t)}} n_{i-1}^{(t-1)} - b_i^{(t)}
\end{equation*}

\begin{equation*} \tag{3}
(a_i^{(t)} + b_i^{(t)})^2 = b_i^{(t)} (m_{i+1}^{(t-1)} - m_{i-1}^{(t-1)})
\end{equation*}

\begin{equation*} \tag{4}
(a_i^{(t)} + b_i^{(t)})^2 = a_i^{(t)} (n_{i-1}^{(t-1)} - n_{i+1}^{(t-1)})
\end{equation*}

### Solving:

Strategy: at each time $t$, suppose we have $m_{i+1}^{(t-1)}$, $m_{i-1}^{(t-1)}$ and $n_{i+1}^{(t-1)}$, $n_{i-1}^{(t-1)}$. <br>
Then, first solve for $a_i^{(t)}$ and $b_i^{(t)}$ using eq (3) and eq (4). <br>
Plug into to eq (1) and eq (2) to solve for $m_{i}^{(t)}$ and $n_{i}^{(t)}$.

Equating (3) and (4):

\begin{equation*}
a_i^{(t)} (n_{i-1}^{(t-1)} - n_{i+1}^{(t-1)}) = b_i^{(t)} (m_{i+1}^{(t-1)} - m_{i-1}^{(t-1)})
\end{equation*}

\begin{equation*}
a_i^{(t)} = \frac{(m_{i+1}^{(t-1)} - m_{i-1}^{(t-1)})}{(n_{i-1}^{(t-1)} - n_{i+1}^{(t-1)})} b_i^{(t)}
\end{equation*}

\begin{equation*}
a_i^{(t)} = \frac{\Delta m}{\Delta n} b_i^{(t)}
\end{equation*}

where $\Delta m$ and $\Delta n$ are calculable constants obtained from previous timestep: 

\begin{equation*}
\Delta m = (m_{i+1}^{(t-1)} - m_{i-1}^{(t-1)})
\end{equation*}

\begin{equation*}
\Delta n = (n_{i-1}^{(t-1)} - n_{i+1}^{(t-1)})
\end{equation*}

Expanding (3):

\begin{equation*}
(a_i^{(t)})^2 + 2 a_i^{(t)} b_i^{(t)} + (b_i^{(t)})^2 = b_i^{(t)} (m_{i+1}^{(t-1)} - m_{i-1}^{(t-1)})
\end{equation*}

Substituting expression for $a_i^{(t)}$ in terms of $b_i^{(t)}$:

\begin{equation*}
(\frac{\Delta m}{\Delta n} b_i^{(t)})^2 + 2 (\frac{\Delta m}{\Delta n} b_i^{(t)}) b_i^{(t)} + (b_i^{(t)})^2 = b_i^{(t)} \Delta m
\end{equation*}

\begin{equation*}
(b_i^{(t)})^2 \left( \left(\frac{\Delta m}{\Delta n} \right)^2 + 2\frac{\Delta m}{\Delta n} + 1 \right) = b_i^{(t)} \Delta m
\end{equation*}

if $b_i^{(t)} \neq 0$:

\begin{equation*}
b_i^{(t)} = \frac{\Delta m}{\left( \left(\frac{\Delta m}{\Delta n} \right)^2 + 2\frac{\Delta m}{\Delta n} + 1 \right)}
\end{equation*}

\begin{equation*}
= \frac{\Delta m}{\left( \frac{\Delta m}{\Delta n} + 1 \right)^2}
\end{equation*}

which means:

\begin{equation*}
a_i^{(t)} = \frac{\Delta m}{\Delta n} \frac{\Delta m}{\left( \frac{\Delta m}{\Delta n} + 1 \right)^2}
\end{equation*}

Finally, we can plugin values of $a_i^{(t)}$ and $b_i^{(t)}$ into eq (1) and eq (2) to obtain $m_i^{(t)}$ and $n_i^{(t)}$.

### Initialization:

#### Contants:
Left end = $-k$ <br>
Right end = $l$ <br>
Payoff factor = $\lambda$ <br>

#### Positional Boundary Values:
$n_{-k}^{(t)} = 0$ -------- $n_{l}^{(t)} = 1$ <br>
$m_{-k}^{(t)} = 0$ -------- $m_{l}^{(t)} = \lambda$ <br>
for all time step $t$ <br>

#### Time Boundary Values:
$m_i^{(0)} = \frac{\lambda}{k+l}(i+k)$ <br>
$n_i^{(0)} = \frac{l-i}{k+l}$ <br>
for all positions $-k \leq i \leq l$ <br>

In [30]:
k = 6      # left endpoint: -k
l = 6      # right endpoint: l
L = k+l+1    # total length: k+l
T = 100    # Time steps

# Note: position index i will be shifted
# before: left end = -k, middle = 0, right end = l
# now: left end = 0, middle = k, right end = k + l (= L)

_lambda = 0.6

# initialize the m,n values are time t=0, 
# then iterate starting from t=1, use m(t-1), n(t-1) to compute a(t) and b(t) 
# then use a(t) and b(t) to compute m(t) and n(t)

# Each matrix: vertical axis (rows): time step, horizontal axis (colunmns): position
# Therefore, m[t][i] access m value at time t, position index i
a = np.zeros(T*L).reshape((T, L))
b = np.zeros(T*L).reshape((T, L))
m = np.zeros(T*L).reshape((T, L))
n = np.zeros(T*L).reshape((T, L))

# Initialize positional boundary values
for t in range(a.shape[0]):
    n[t][0] = 1
    m[t][0] = 0
    n[t][L-1] = 0
    m[t][L-1] = _lambda
    
# Initialize time boundary values
for i in range(L):
    m[0][i] = (_lambda / (L-1) ) * i
    n[0][i] = (1 / (L-1)) * ((L-1) - i)

In [31]:
m[0]

array([0.        , 0.05454545, 0.10909091, 0.16363636, 0.21818182,
       0.27272727, 0.32727273, 0.38181818, 0.43636364, 0.49090909,
       0.54545455, 0.6       ])

In [32]:
m[1]

array([0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0. , 0.6])

In [33]:
n[0]

array([1.        , 0.90909091, 0.81818182, 0.72727273, 0.63636364,
       0.54545455, 0.45454545, 0.36363636, 0.27272727, 0.18181818,
       0.09090909, 0.        ])

In [34]:
n[1]

array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1.])

In [37]:
# Recursive solving algorithm

for t in range(1, T):
    for i in range(1, L-1):
        delta_m = m[t-1][i+1] - m[t-1][i-1]
        delta_n = n[t-1][i-1] - n[t-1][i+1]
        # Compute a,b using previous m,n
        b[t][i] = delta_m / (((delta_m / delta_n) + 1)**2)
        a[t][i] = (delta_m / delta_n) * b[t][i]
        print(a[t][i], b[t][i])
        # Compute current m,n using current a,b
        m[t][i] = (a[t][i]/(a[t][i] + b[t][i]))*m[t-1][i+1] + (b[t][i]/(a[t][i] + b[t][i]))*m[t-1][i-1] - a[t][i]
        n[t][i] = (a[t][i]/(a[t][i] + b[t][i]))*n[t-1][i+1] + (b[t][i]/(a[t][i] + b[t][i]))*n[t-1][i-1] - n[t][i]

0.025568181818181816 0.042613636363636354
0.02556818181818182 0.042613636363636374
0.02556818181818182 0.04261363636363638
0.02556818181818181 0.042613636363636374
0.025568181818181823 0.04261363636363636
0.025568181818181813 0.04261363636363635
0.02556818181818181 0.042613636363636374
0.02556818181818183 0.04261363636363637
0.025568181818181806 0.04261363636363635
0.02556818181818181 0.04261363636363636
nan 0.0
nan 0.0
nan 0.0
nan 0.0
nan 0.0
nan 0.0
nan 0.0
nan 0.0
nan 0.0
-0.0303164084039126 0.2044324857888359
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
nan nan
n

  b[t][i] = delta_m / (((delta_m / delta_n) + 1)**2)
  a[t][i] = (delta_m / delta_n) * b[t][i]
  a[t][i] = (delta_m / delta_n) * b[t][i]
