In [113]:
import numpy as np
import pandas as pd
from mesa import Agent, Model
from mesa.time import RandomActivation
from mesa.space import NetworkGrid
from mesa.datacollection import DataCollector
from mesa import batchrunner
from matplotlib import pyplot as plt
from matplotlib import colors as clrs
import networkx as nx
import json
from scipy import stats
import enum


# Hex colour for each single-letter compartment code; the plotting cell
# looks up "I", "E" and "X" from this table.
COLORS = dict(
    S="#2f4b7c",
    E="#ffa600",
    I="#f95d6a",
    R="#a05195",
    D="#003f5c",
    C="#ff7c43",
    X="#665191",
    Y="#d45087",
)

In [114]:
class State(enum.IntEnum):
    """Disease compartment of an agent, stored as an integer code.

    Codes are consecutive and start at zero:
    0 SUSCEPTIBLE, 1 EXPOSED, 2 SYMPTOMATIC, 3 RECOVERED, 4 DECEASED,
    5 QUARANTINED.
    """

    # Order matters: each name receives its position in this tuple as its value
    (SUSCEPTIBLE,
     EXPOSED,
     SYMPTOMATIC,
     RECOVERED,
     DECEASED,
     QUARANTINED) = range(6)

In [115]:
class Person(Agent):
    """Agent on the model's network grid that carries a single disease ``State``.

    Attributes:
        state: current ``State``; every agent starts out ``SUSCEPTIBLE``.
        days_remaining_of_quarantine: steps left in quarantine, or ``None``
            while the agent is not quarantined.
    """

    def __init__(self, uid, model):
        super().__init__(uid, model)

        # Naive start
        self.state = State.SUSCEPTIBLE
        self.days_remaining_of_quarantine = None

    def contact_event(self):
        """Expose each susceptible neighbour with probability ``model.beta``.

        Does nothing unless this agent is ``EXPOSED`` or ``SYMPTOMATIC``.
        """
        # This agent's own state cannot change in here, so test it once up
        # front instead of once per neighbour.
        if self.state not in (State.SYMPTOMATIC, State.EXPOSED):
            return

        neighbourhood = self.model.grid.get_neighbors(self.pos, include_center=False)
        for neighbour in self.model.grid.get_cell_list_contents(neighbourhood):
            if neighbour.state == State.SUSCEPTIBLE and np.random.rand() < self.model.beta:
                neighbour.state = State.EXPOSED

    def status_update(self):
        """Advance this agent's own disease state by one step.

        Transitions, evaluated in this order:

        * ``SYMPTOMATIC``: with probability ``1 / recovery_period`` the
          infection resolves (``DECEASED`` with probability ``CFR``, otherwise
          ``RECOVERED``); failing that, the agent becomes ``QUARANTINED`` with
          probability ``quarantine_capture_fraction``.
        * ``EXPOSED``: becomes ``SYMPTOMATIC`` with probability
          ``1 / latency_period``.
        * ``QUARANTINED``: the day counter is decremented and the agent becomes
          ``RECOVERED`` once it drops below one.
        * Any living agent: becomes ``DECEASED`` with probability
          ``base_mortality_rate``.
        """
        if self.state == State.DECEASED:
            # Nothing below can move an agent out of DECEASED, so skip the draws
            return

        if self.state == State.SYMPTOMATIC:
            if np.random.rand() < 1/self.model.recovery_period:
                if np.random.rand() < self.model.CFR:
                    self.state = State.DECEASED
                else:
                    self.state = State.RECOVERED
            elif np.random.rand() < self.model.quarantine_capture_fraction:
                self.state = State.QUARANTINED
                # +1 because the QUARANTINED branch below runs in this same
                # call and immediately takes one day off again
                self.days_remaining_of_quarantine = self.model.quarantine_length + 1

        if self.state == State.EXPOSED:
            if np.random.rand() < 1/self.model.latency_period:
                self.state = State.SYMPTOMATIC

        if self.state == State.QUARANTINED:
            self.days_remaining_of_quarantine -= 1
            if self.days_remaining_of_quarantine < 1:
                self.state = State.RECOVERED
                self.days_remaining_of_quarantine = None

        if np.random.rand() < self.model.base_mortality_rate:
            self.state = State.DECEASED
            # Do not leave a running quarantine counter on a deceased agent
            self.days_remaining_of_quarantine = None

    def step(self):
        """Update this agent's own state, then let it try to infect neighbours."""
        self.status_update()
        # contact_event() returns immediately for non-infectious agents
        self.contact_event()
    

In [116]:
class NetworkInfectiousDiseaseModel(Model):
    """Disease-spread model with one ``Person`` on every node of a random graph.

    Args:
        nodes: number of graph nodes, and therefore of agents.
        mean_degree: target average number of neighbours per node; the
            Erdos-Renyi edge probability is ``mean_degree / nodes``.
        recovery_period: a symptomatic agent's infection resolves with
            probability ``1 / recovery_period`` per step.
        latency_period: an exposed agent turns symptomatic with probability
            ``1 / latency_period`` per step.
        beta: per-step probability that an infectious agent exposes a given
            susceptible neighbour.
        CFR: probability that a resolving infection ends in death rather
            than recovery.
        base_mortality_rate: per-step probability of death applied to agents
            regardless of disease state.
        quarantine_capture_fraction: per-step probability that a symptomatic
            agent whose infection did not resolve is moved to quarantine.
        quarantine_length: number of steps an agent spends in quarantine.
        I0: probability that each agent starts ``EXPOSED`` instead of
            ``SUSCEPTIBLE``.
    """

    def __init__(self,
                 nodes=5000,
                 mean_degree=12,
                 recovery_period=7,
                 latency_period=4,
                 beta=0.0075,
                 CFR=0.05,
                 base_mortality_rate=0.000125,
                 quarantine_capture_fraction=0.1,
                 quarantine_length=7,
                 I0=0.05):
        # Let Mesa set up its own bookkeeping before agents are created
        super().__init__()

        self.N_agents = nodes
        self.recovery_period = recovery_period
        self.latency_period = latency_period
        self.beta = beta
        self.CFR = CFR
        self.base_mortality_rate = base_mortality_rate
        self.quarantine_capture_fraction = quarantine_capture_fraction
        self.quarantine_length = quarantine_length

        self.graph = nx.erdos_renyi_graph(n=self.N_agents, p=mean_degree/self.N_agents)
        self.grid = NetworkGrid(self.graph)

        self.schedule = RandomActivation(self)
        self.running = True

        for idx, node in enumerate(self.graph.nodes()):
            agent = Person(uid=idx + 1, model=self)
            self.schedule.add(agent)
            self.grid.place_agent(agent, node)

            # Seed the outbreak: each agent independently starts exposed
            if np.random.rand() < I0:
                agent.state = State.EXPOSED

        self.datacollector = DataCollector(agent_reporters={"State": "state"})

    def step(self):
        """Record every agent's state, then advance all agents by one step.

        The snapshot is taken before the agents act, so the states produced
        by the most recent step are only recorded on the following call.
        """
        self.datacollector.collect(self)
        self.schedule.step()

    def to_df(self):
        """Return the number of agents in each state at every recorded step.

        Returns:
            DataFrame indexed by ``Step`` with one title-cased column per
            ``State`` member, in enum order. A state that never occurred in
            the run is reported as a column of zeros.
        """
        state_profile = self.datacollector.get_agent_vars_dataframe()
        counts = pd.pivot_table(state_profile.reset_index(),
                                index="Step",
                                columns="State",
                                aggfunc=np.size,
                                fill_value=0)
        # The pivot nests the state codes under the counted column; keep only
        # the state level.
        counts = counts.droplevel(0, axis=1)
        # The pivot only has columns for states that actually occurred, so
        # align on the full enum before labelling to keep names and counts
        # matched up.
        counts = counts.reindex(columns=[state.value for state in State], fill_value=0)
        counts.columns = [state.name.title() for state in State]
        return counts

In [117]:
# Batch of 32 independent runs of the 5000-node model, at most 60 steps each.
# Every run reports its DataCollector under "vals" for later aggregation.
br = batchrunner.FixedBatchRunner(
    NetworkInfectiousDiseaseModel,
    fixed_parameters=dict(nodes=5000),
    iterations=32,
    max_steps=60,
    model_reporters={"vals": lambda model: model.datacollector},
    display_progress=True,
)

In [118]:
# Execute all 32 iterations configured on the runner; progress is shown per run
br.run_all()

0it [00:00, ?it/s]1it [00:03,  3.54s/it]2it [00:07,  3.54s/it]3it [00:10,  3.54s/it]4it [00:14,  3.52s/it]5it [00:18,  3.73s/it]6it [00:21,  3.65s/it]7it [00:25,  3.74s/it]8it [00:29,  3.69s/it]9it [00:32,  3.64s/it]10it [00:36,  3.75s/it]11it [00:40,  3.67s/it]12it [00:43,  3.66s/it]13it [00:47,  3.79s/it]14it [00:51,  3.72s/it]15it [00:55,  3.67s/it]16it [00:59,  3.82s/it]17it [01:02,  3.74s/it]18it [01:06,  3.69s/it]19it [01:09,  3.64s/it]20it [01:13,  3.61s/it]21it [01:17,  3.81s/it]22it [01:21,  3.75s/it]23it [01:24,  3.68s/it]24it [01:28,  3.65s/it]25it [01:31,  3.62s/it]26it [01:36,  3.87s/it]27it [01:39,  3.77s/it]28it [01:43,  3.73s/it]29it [01:47,  3.67s/it]30it [01:50,  3.63s/it]31it [01:54,  3.59s/it]32it [01:57,  3.57s/it]32it [01:57,  3.68s/it]


In [119]:
# Title-cased display name for every integer state code
STATE_LABELS = {state.value: state.name.title() for state in State}


def _state_counts_by_step(collector):
    """Count the agents in each state at every step of one finished run.

    Args:
        collector: the run's DataCollector holding the per-agent "State".

    Returns:
        DataFrame indexed by "Step" with one column per state name, in enum
        order. A state that never occurred in the run is a column of zeros,
        so every run yields the same set of columns.
    """
    counts = pd.pivot_table(collector.get_agent_vars_dataframe().reset_index(),
                            index="Step",
                            columns="State",
                            aggfunc=np.size,
                            fill_value=0)
    return (counts
            .droplevel(0, axis=1)  # drop the counted-column level, keep the state codes
            .rename(columns=STATE_LABELS)
            .reindex(columns=list(STATE_LABELS.values()), fill_value=0))


results = [_state_counts_by_step(collector)
           for collector in br.get_model_vars_dataframe().vals]

In [120]:
# Align all runs on a shared (Step, State) index, one column per run, and
# build that frame once for both summaries.
stacked_runs = pd.concat([run.stack() for run in results], axis=1)

# Row-wise mean and sample standard deviation across runs, reshaped back to
# one row per step and one column per state.
means = stacked_runs.mean(axis=1).unstack()
sds = stacked_runs.std(axis=1).unstack()

In [121]:
fig = plt.figure(facecolor="w", figsize=(8, 6), dpi=600)
ax = fig.add_subplot(111, axisbelow=True)

# Faint trace of every individual run; the empty label keeps them out of the legend
for model_results in results:
    for compartment, colour_key in (("Symptomatic", "I"), ("Exposed", "E"), ("Quarantined", "X")):
        ax.plot(model_results.index, model_results[compartment],
                color=COLORS[colour_key], alpha=0.2, lw=0.2, label="")

# Mean across runs (dashed), mean +/- 1.96 standard deviations (dotted), and
# a shaded area under the mean curve
for compartment, colour_key in (("Symptomatic", "I"), ("Exposed", "E")):
    colour = COLORS[colour_key]
    centre = means[compartment]
    half_width = 1.96 * sds[compartment]

    ax.plot(means.index, centre, color=colour, alpha=0.5, lw=1.25, linestyle="--",
            label=f"{compartment} (mean)")
    ax.plot(means.index, centre + half_width, color=colour, alpha=0.5, lw=1, linestyle=":")
    ax.plot(means.index, centre - half_width, color=colour, alpha=0.5, lw=1, linestyle=":")
    ax.fill_between(means.index, centre, color=colour, alpha=0.25)

# Quarantined is shown as a mean line only
ax.plot(means.index, means["Quarantined"], color=COLORS["X"], alpha=0.5, lw=1.25, linestyle="--",
        label="Quarantined (mean)")

# Single frameless legend placed below the axes
ax.legend(title="", bbox_to_anchor=(0.5, -0.355), loc="lower center", ncol=3, frameon=False)

ax.set_xlabel("Days")
ax.set_ylabel("Number in compartment")

fig.tight_layout(pad=5.0)
fig.savefig("SEIRDQ_ABM_output.pdf")