# In [None]:
import numpy as np


def get_off_diagonal_sums(matrix: np.ndarray) -> np.ndarray:
    """Return the column sums of a square matrix, excluding the diagonal.

    For each column ``j`` the result holds ``sum(matrix[i, j] for i != j)``.
    Pass ``matrix.T`` to get the corresponding off-diagonal row sums.

    Args:
        matrix: A square 2-D array (anything ``np.asarray`` accepts).

    Returns:
        A 1-D array with one entry per column.

    Raises:
        ValueError: If ``matrix`` is not a square 2-D array.
    """
    matrix = np.asarray(matrix)
    # Explicit check rather than ``assert`` so validation survives
    # ``python -O``. A non-square input must be rejected here: e.g. shape
    # (1, n) would otherwise silently broadcast ``np.diag`` against the
    # column sums and return meaningless values.
    if matrix.ndim != 2 or matrix.shape[0] != matrix.shape[1]:
        raise ValueError(
            f'expected a square 2-D matrix, got shape {matrix.shape}')

    return matrix.sum(axis=0) - np.diag(matrix)


class Population:
    """
    Represents the entire population, split into sub-populations.

    Susceptible, infected and recovered counts are held as 1-D arrays with
    one entry per sub-population. ``seed`` sets the initial infected counts
    and ``next_day`` advances the model by one day, moving individuals
    susceptible -> infected -> recovered so that
    ``susceptible + infected + recovered == population_sizes`` is preserved.
    """

    def __init__(self,
                 population_sizes: np.ndarray,
                 migration_matrix: np.ndarray,
                 alpha: float,
                 beta: np.ndarray,
                 tau: float,
                 gamma: float):
        """
        Args:
            population_sizes: Total size of each sub-population (1-D).
            migration_matrix: Square matrix of movements between
                sub-populations; presumably entry ``[i, j]`` counts people
                moving from ``i`` to ``j`` -- TODO confirm the orientation.
            alpha: Weight applied to the work-hours infection terms.
            beta: Per-sub-population parameter, not used yet; presumably the
                transmission rate -- TODO confirm.
            tau: Weight of the outside-work-hours infection term; the
                work-hours terms are weighted by ``1 - tau``.
            gamma: Fraction of the infected that recover on each ``next_day``.
        """
        self.population_sizes = population_sizes

        # Copy so model state never aliases the caller's array.
        self.susceptible = np.array(population_sizes, dtype=float)
        self.infected = None
        self.recovered = np.zeros(len(population_sizes))

        self._migration_matrix = migration_matrix

        # Off-diagonal column sums of M and of M.T (i.e. row sums);
        # presumably total inflow / outflow per sub-population -- TODO confirm.
        self._m_off_diag = get_off_diagonal_sums(migration_matrix)
        self._m_off_diag_t = get_off_diagonal_sums(migration_matrix.T)

        self._alpha = alpha
        self._beta = beta
        self._gamma = gamma
        self._tau = tau

    def next_day(self):
        """
        Advance every compartment by one day using the previous day's values.

        Raises:
            RuntimeError: If ``seed`` has not been called yet.
        """
        if self.infected is None:
            raise RuntimeError('Population must be seeded with virus first')

        # TODO: implement the force of infection (presumably built from the
        # susceptible / infected fractions self._get_x() / self._get_y());
        # every term is a 0 placeholder for now.
        outside_work_hours = 0
        during_work_hours_1 = 0
        during_work_hours_2 = 0

        # Flow S -> I, computed once so both compartments see the same value.
        new_infections = (
            self._tau * outside_work_hours
            + self._alpha * (1 - self._tau)
            * (during_work_hours_1 + during_work_hours_2)
        )
        # Flow I -> R from yesterday's infected count. It must also leave I;
        # otherwise S + I + R grows every day.
        recoveries = self._gamma * self.infected

        self.susceptible = self.susceptible - new_infections
        self.infected = self.infected + new_infections - recoveries
        self.recovered = self.recovered + recoveries

    def _get_x(self) -> np.ndarray:
        """Susceptible fraction of each sub-population."""
        return self.susceptible / self.population_sizes

    def _get_y(self) -> np.ndarray:
        """Infected fraction of each sub-population."""
        return self.infected / self.population_sizes

    def seed(self, infected: np.ndarray):
        """
        Set the initial infected count of each sub-population.

        The seeded individuals are taken out of the susceptible pool so that
        ``susceptible + infected + recovered == population_sizes`` holds.

        Raises:
            ValueError: If ``infected`` does not have one entry per
                sub-population.
        """
        infected = np.array(infected, dtype=float)
        if infected.shape != np.shape(self.population_sizes):
            raise ValueError(
                f'expected {len(self.population_sizes)} infected counts, '
                f'got shape {infected.shape}')
        self.infected = infected
        self.susceptible = self.population_sizes - infected - self.recovered


if __name__ == '__main__':
    # Demo setup: four sub-populations whose migration-matrix rows are all
    # the same, so the matrix is a single row stacked four times.
    demo_row = [100, 200, 300, 400]
    demo_migration = np.array([demo_row] * 4)

    population = Population(
        population_sizes=np.array([1000, 2000, 3000, 4000]),
        migration_matrix=demo_migration,
        alpha=1,
        beta=np.array([0.5, 0.1, 0.3, 0.4]),
        tau=2 / 3,
        gamma=0.1,
    )
    population.seed(np.array([10, 0, 0, 20]))

    # Run the model for 100 simulated days.
    for _day in range(100):
        population.next_day()
        print('Do some diagnostics here')