In [1]:
import numpy as np
from scipy.sparse import csgraph
from scipy.integrate import solve_ivp
import scipy.linalg as la
import matplotlib.pyplot as plt

In [2]:
def spatial_PDE(N,b,k,p,T,I0,M=200):
    """
    A spatial PDE simulation that model time dependent variables: S, I, R
    
    N: N individuals the population has
    
    b: the number of interactions each day that could spread the disease (per individual)
    
    k: the fraction of the infectious population which recovers each day
    
    T: simulation time period
    
    I0: Initial positions of infectious individuals
    
    M: size of grid
    
    """
    def f(t,v):
        """
        v = [S, I, R]
        
        """
        # print(v.shape)
        # reshape 1d array back to MxM matrices
        S = v[:M**2].reshape((M,M))
        I = v[M**2:-M**2].reshape((M,M))
        R = v[-M**2:].reshape((M,M))
        
        # derivatives
        dSdt = -b * S * I + p * (1/M) ** 2 * csgraph.laplacian(S)
        dIdt = b * S * I - k * I + p * (1/M) ** 2 * csgraph.laplacian(I)
        dRdt = k * I + p * (1/M) ** 2 * csgraph.laplacian(R)
        
        # reshape MxM grids to 1d array
        dSdt = dSdt.reshape((1,M**2))
        dIdt = dIdt.reshape((1,M**2))
        dRdt = dRdt.reshape((1,M**2))
        return np.concatenate([dSdt, dIdt, dRdt],axis=None)
    
    # reshape initial MxM grids to 1d array
    S0 = (np.ones((M,M))-I0).reshape((1,M**2))
    R0 = np.zeros((M,M)).reshape((1,M**2))
    I0 = I0.reshape((1,M**2))
    
    # solve_ivp only accept 1d initial value, so conbine 3 1d array to 1
    v0 = np.concatenate([S0,I0,R0],axis=None)
    
    t_span = [0,T]
    t_eval = list(range(T))
    sol = solve_ivp(f, t_span, v0, t_eval=t_eval)
    return sol

In [3]:
k = 0.5
b = 0.5

M = 10
N = 100
T = 5
# grids with 1s and 0s, 1s are the infected
I0 = np.random.choice([1]*M+[0]*(M*M-M), size=(10,10), replace=False)

p = 1/M
        
sol = spatial_PDE(N,b,k,p,T,I0,M)

print(sol.y.shape)
# y has shape (M*M*3,T)
# ith column is the 1d array contains the info of the ith day

S = []
I = []
R = []
for dayind in range(T):
    S.append(sol.y[:,dayind][:M**2].reshape((M,M)))
    I.append(sol.y[:,dayind][M**2:-M**2].reshape((M,M)))
    R.append(sol.y[:,dayind][-M**2:].reshape((M,M)))
S = np.array(S)
I = np.array(I)
R = np.array(R)

print(S)
print(I)
print(R)
# plt.plot(sol.t, sol.y[0], label="susceptible", c='g')
# plt.plot(sol.t, sol.y[1], label="infectious", c='r')
# plt.plot(sol.t, sol.y[2], label="removed", c='b')



(300, 5)
[[[1.         1.         1.         1.         1.         1.
   1.         1.         1.         1.        ]
  [1.         1.         1.         1.         0.         0.
   1.         1.         1.         1.        ]
  [1.         1.         1.         1.         1.         1.
   1.         1.         1.         1.        ]
  [0.         0.         1.         1.         1.         1.
   1.         1.         1.         1.        ]
  [1.         1.         1.         1.         1.         1.
   0.         1.         1.         1.        ]
  [1.         1.         1.         1.         1.         1.
   1.         1.         1.         1.        ]
  [1.         1.         1.         1.         1.         1.
   1.         1.         0.         1.        ]
  [1.         1.         1.         0.         1.         1.
   1.         1.         1.         1.        ]
  [1.         1.         1.         1.         0.         1.
   1.         1.         0.         1.        ]
  [1.     