In [None]:
import numpy as np
from scipy import integrate
import matplotlib.pyplot as plt

In [None]:
class SocialSpatialModel:
    """Spatial SEIR model over an m x m grid of subpopulations.

    Every per-cell quantity is handled as a flattened vector of length m**2;
    2-D (m, m) inputs are flattened in row-major (C) order.

    Parameters
    ----------
    grid_size : int
        Side length m of the square grid.
    subpopulatons : array_like, m**2 values
        Population of each grid cell. (The parameter name is misspelled but
        kept as-is so existing keyword callers keep working.)
    n_infected : array_like, m**2 values
        Initially infected individuals in each grid cell.
    contact_rates : object
        Stored as ``self.A``; presumably consumed by ``compute_beta`` (TODO).
    vulnerabilities : object
        Stored as ``self.Q``; assumed to be the output of an already-run
        logistic social vulnerability model.
    exposure_rate : float
        ``alpha``, the rate at which exposed individuals become infected.
    recovery_rate : float
        ``gamma``, the rate at which infected individuals recover.
    """

    def __init__(self, grid_size, subpopulatons, n_infected, contact_rates, vulnerabilities, exposure_rate, recovery_rate):
        self.m = grid_size  # the grid is m x m, i.e. m**2 cells
        self.N_k = subpopulatons  # population per cell, m**2 values
        self.I_0 = n_infected  # initially infected per cell, m**2 values
        self.A = contact_rates
        self.Q = vulnerabilities
        self.alpha = exposure_rate
        self.gamma = recovery_rate
        self.B = self.compute_beta()

    def compute_beta(self):
        """Return the (m**2, m**2) matrix of transmission rates.

        TODO: not implemented yet. It currently returns None, so ``self.B``
        is None and the ODE right-hand side cannot be evaluated until this
        is written (presumably from ``self.A`` and ``self.Q``).
        """
        return None

    def get_initial_state(self):
        """Return the initial state as one flat vector ``[S0, E0, I0, R0]``.

        ``scipy.integrate.odeint`` requires a 1-D ``y0``. The four
        compartments, each of length m**2, are therefore concatenated into
        a single vector of length 4 * m**2, in S, E, I, R order. Everyone not
        initially infected starts susceptible. E and R start empty.

        Raises
        ------
        ValueError
            If the population or initially-infected inputs do not contain
            exactly m**2 values.
        """
        n_cells = self.m ** 2
        population = np.asarray(self.N_k, dtype=float).ravel()
        infected = np.asarray(self.I_0, dtype=float).ravel()
        if population.size != n_cells or infected.size != n_cells:
            raise ValueError(
                f"expected {n_cells} values per compartment for a "
                f"{self.m} x {self.m} grid, got populations={population.size}, "
                f"infected={infected.size}"
            )
        susceptible = population - infected
        exposed = np.zeros(n_cells)
        recovered = np.zeros(n_cells)
        return np.concatenate([susceptible, exposed, infected, recovered])

In [None]:
def derivative(X, t, B, alpha, gamma):
    """Right-hand side of the spatial SEIR system, in the form ``odeint`` expects.

    Parameters
    ----------
    X : array_like, shape (4 * n,)
        Flat state vector ``[S, E, I, R]``. Each block has length n, where
        n = m**2 is the number of cells of the flattened m x m grid.
    t : float
        Current time. It is unused because the system is autonomous, but
        ``odeint`` always passes it.
    B : array_like, shape (n, n)
        Transmission-rate matrix. ``B @ I`` gives, per cell, the rate at
        which susceptibles in that cell become exposed.
    alpha : float or array_like of shape (n,)
        Rate E -> I.
    gamma : float or array_like of shape (n,)
        Rate I -> R.

    Returns
    -------
    numpy.ndarray, shape (4 * n,)
        ``[dS/dt, dE/dt, dI/dt, dR/dt]`` concatenated in the same order as X.

    Raises
    ------
    ValueError
        If ``len(X)`` is not divisible by 4.
    """
    # odeint hands us one flat vector; split it back into the 4 compartments
    # (plain tuple-unpacking would only work for a single-cell grid).
    S, E, I, _ = np.split(np.asarray(X, dtype=float), 4)
    new_exposures = (B @ I) * S
    dsdt = -new_exposures
    dedt = new_exposures - alpha * E
    didt = alpha * E - gamma * I
    drdt = gamma * I
    return np.concatenate([dsdt, dedt, didt, drdt])

In [None]:
# Output times: 2000 evenly spaced points on [0, 180].
# NOTE(review): the time unit is whatever unit alpha/gamma/B are expressed in
# (days?) -- confirm.
t = np.linspace(0, 180, 2000)

# TODO: supply real parameters. SocialSpatialModel requires grid_size,
# subpopulatons, n_infected, contact_rates, vulnerabilities, exposure_rate and
# recovery_rate, so this no-argument call raises TypeError until filled in.
simulation = SocialSpatialModel()

# odeint needs a 1-D y0. np.ravel accepts either a flat state vector or an
# (S0, E0, I0, R0) tuple of equal-length arrays and yields [S0, E0, I0, R0].
solution = integrate.odeint(
    func=derivative,
    y0=np.ravel(simulation.get_initial_state()),
    t=t,
    args=(simulation.B, simulation.alpha, simulation.gamma),
).T
# solution has shape (4 * m**2, len(t)): one row per (compartment, cell),
# one column per output time.

# Split the rows back into the four compartments, each (m**2, len(t)).
# The grid size comes from the model instance; there is no notebook-level `m`.
n_cells = simulation.m ** 2
assert solution.shape == (4 * n_cells, t.size)
S, E, I, R = np.split(solution, 4, axis=0)