In [30]:
import numpy as np
import pandas as pd
from mesa import Agent, Model
from mesa.time import RandomActivation
from mesa.space import NetworkGrid
from mesa.datacollection import DataCollector
from mesa import batchrunner
from matplotlib import pyplot as plt
from matplotlib.lines import Line2D
import networkx as nx
import enum
import random


# Hex colour palette keyed by single-letter series codes; the plotting cell
# selects entries from it by key (e.g. "X", "I", "S").
COLORS = dict(
    S="#2f4b7c",
    E="#ffa600",
    I="#f95d6a",
    R="#a05195",
    D="#003f5c",
    C="#ff7c43",
    X="#665191",
    Y="#d45087",
)

In [11]:
class State(enum.IntEnum):
    """Epidemiological compartment held by each ``Person`` agent.

    As an ``IntEnum``, every member also compares equal to its integer code:
    0 = susceptible, 1 = infected, 2 = recovered, 3 = deceased.
    """

    # Codes follow declaration order, starting at 0.
    SUSCEPTIBLE, INFECTED, RECOVERED, DECEASED = range(4)

In [12]:
class Person(Agent):
    """A single individual living on one node of the contact network.

    Each agent carries a ``State``. On every scheduler step a living agent:
      1. updates its status (an infected agent resolves with probability
         ``1 / model.recovery_period``, dying with probability ``model.CFR``
         and otherwise recovering; any living agent may then die from
         background mortality ``model.base_mortality_rate``);
      2. if still infected, contacts up to ``model.q`` susceptible network
         neighbours and infects each with probability ``model.beta``.

    Deceased agents are inert. All random draws use the Mesa model RNG
    (``self.model.random``) rather than the unseeded global ``random`` /
    ``numpy.random`` generators, so a run is reproducible once that RNG is
    seeded.
    """

    def __init__(self, uid, model):
        super().__init__(uid, model)

        # Naive start: everyone begins susceptible; the model seeds initial infections.
        self.state = State.SUSCEPTIBLE

    def contact_event(self):
        """Expose up to ``model.q`` random susceptible neighbours to infection.

        Each selected neighbour becomes ``State.INFECTED`` with probability
        ``model.beta``. If fewer than ``q`` susceptible neighbours exist, all of
        them are contacted.
        """
        rng = self.model.random
        neighbourhood = self.model.grid.get_neighbors(self.pos, include_center=False)
        susceptible_neighbourhood = [
            agent
            for agent in self.model.grid.get_cell_list_contents(neighbourhood)
            if agent.state == State.SUSCEPTIBLE
        ]

        # Sample without replacement: at most q distinct contacts this step.
        n_contacts = min(self.model.q, len(susceptible_neighbourhood))
        q_choice = rng.sample(susceptible_neighbourhood, n_contacts)

        for neighbour in q_choice:
            if rng.random() < self.model.beta:
                neighbour.state = State.INFECTED

    def status_update(self):
        """Apply infection resolution (recover/die) and background mortality."""
        rng = self.model.random

        if self.state == State.INFECTED:
            # Per-step resolution probability 1/recovery_period.
            if rng.random() < 1 / self.model.recovery_period:
                if rng.random() < self.model.CFR:
                    self.state = State.DECEASED
                    return  # already dead; background mortality is moot
                self.state = State.RECOVERED

        if rng.random() < self.model.base_mortality_rate:
            self.state = State.DECEASED

    def step(self):
        """Advance this agent by one tick: status update, then (if infected) contacts."""
        if self.state == State.DECEASED:
            return  # the dead neither change state nor make contacts

        self.status_update()

        if self.state == State.INFECTED:
            self.contact_event()
    

In [13]:
class NetworkInfectiousDiseaseModel(Model):
    """Agent-based SIRD epidemic on an Erdős–Rényi contact network.

    One ``Person`` agent is placed on each graph node and activated in random
    order every step.

    Parameters
    ----------
    nodes : int
        Number of agents / graph nodes.
    mean_degree : float
        Target expected node degree; the edge probability is
        ``mean_degree / (nodes - 1)`` (expected degree of G(n, p) is p(n-1)).
    recovery_period : float
        Read by agents; per-step infection-resolution probability is its reciprocal.
    beta : float
        Per-contact transmission probability.
    CFR : float
        Probability that a resolving infection ends in death.
    base_mortality_rate : float
        Per-step background death probability.
    I0 : float
        Probability that each agent starts infected.
    q : int
        Maximum number of susceptible neighbours an infected agent contacts per step.
    seed : int or None
        Seed for the Mesa model RNG (``self.random``), which drives graph
        generation and the initial infections. ``None`` keeps unseeded behaviour.
    """

    def __init__(self,
                 nodes=5000,
                 mean_degree=12,
                 recovery_period=14,
                 beta=0.04,
                 CFR=0.05,
                 base_mortality_rate=0.000125,
                 I0=0.005,
                 q=4,
                 seed=None):
        # Initialise the Mesa base Model; newer Mesa releases rely on this
        # (RNG setup, agent registry) and older ones accept/ignore the kwarg.
        super().__init__(seed=seed)

        self.N_agents = nodes
        self.recovery_period = recovery_period
        self.beta = beta
        self.CFR = CFR
        self.base_mortality_rate = base_mortality_rate
        self.q = q

        # Expected degree of G(n, p) is p * (n - 1); guard n <= 1 against division by zero.
        edge_probability = mean_degree / max(self.N_agents - 1, 1)
        self.graph = nx.erdos_renyi_graph(n=self.N_agents, p=edge_probability, seed=self.random)
        self.grid = NetworkGrid(self.graph)

        self.schedule = RandomActivation(self)
        self.running = True

        for idx, node in enumerate(self.graph.nodes()):
            agent = Person(uid=idx + 1, model=self)
            self.schedule.add(agent)
            self.grid.place_agent(agent, node)

            # Seed each agent as infected independently with probability I0.
            if self.random.random() < I0:
                agent.state = State.INFECTED

        self.datacollector = DataCollector({"Infectious": lambda m: m.number_infectious()})

    def step(self):
        """Record the current infectious count, then activate every agent once.

        Collection happens before the schedule advances, so the recorded series
        starts at the initial condition.
        """
        self.datacollector.collect(self)
        self.schedule.step()

    def number_infectious(self):
        """Return the number of agents currently in ``State.INFECTED`` (always an int)."""
        # Builtin sum keeps the result an int even when nobody is infected
        # (np.sum([]) returned the float 0.0).
        return sum(1 for agent in self.schedule.agents if agent.state == State.INFECTED)

In [14]:
# Batch experiment: q in {3, 5, 7}, 15 replicate runs per setting, at most
# 200 steps per run on a 2000-node network. Each run's whole DataCollector is
# kept under the "vals" reporter so its per-step series can be extracted later.
br = batchrunner.FixedBatchRunner(
    NetworkInfectiousDiseaseModel,
    parameters_list=[{"q": q_value} for q_value in (3, 5, 7)],
    fixed_parameters={"nodes": 2000},
    iterations=15,
    max_steps=200,
    model_reporters={"vals": lambda m: m.datacollector},
    display_progress=True,
)

In [15]:
# Execute every configured run (3 q settings x 15 iterations); results are read
# back below through br.get_model_vars_dataframe().
br.run_all()

0it [00:00, ?it/s]1it [00:02,  2.07s/it]2it [00:04,  2.10s/it]3it [00:06,  2.09s/it]4it [00:08,  2.07s/it]5it [00:10,  2.10s/it]6it [00:12,  2.09s/it]7it [00:14,  2.08s/it]8it [00:16,  2.07s/it]9it [00:18,  2.08s/it]10it [00:20,  2.10s/it]11it [00:22,  2.09s/it]12it [00:25,  2.09s/it]13it [00:27,  2.09s/it]14it [00:29,  2.08s/it]15it [00:31,  2.10s/it]16it [00:33,  2.10s/it]17it [00:35,  2.09s/it]18it [00:37,  2.09s/it]19it [00:39,  2.09s/it]20it [00:41,  2.12s/it]21it [00:44,  2.12s/it]22it [00:46,  2.11s/it]23it [00:48,  2.10s/it]24it [00:50,  2.09s/it]25it [00:52,  2.11s/it]26it [00:54,  2.09s/it]27it [00:56,  2.08s/it]28it [00:58,  2.08s/it]29it [01:00,  2.09s/it]30it [01:02,  2.11s/it]31it [01:04,  2.11s/it]32it [01:07,  2.10s/it]33it [01:09,  2.10s/it]34it [01:11,  2.10s/it]35it [01:13,  2.12s/it]36it [01:15,  2.12s/it]37it [01:17,  2.11s/it]38it [01:19,  2.11s/it]39it [01:21,  2.13s/it]40it [01:24,  2.12s/it]41it [01:26,  2.11s/it]42it 

In [16]:
# Group each run's "Infectious" time series by its q value:
#   by_q_values[q] -> list of pandas Series, one per replicate run.
# The model-vars frame is built once (the original rebuilt it on every lookup).
# sort_index() reproduces the original iteration order, which walked the index
# labels 0..n-1 via range(len) + label lookup — assumes that integer index.
run_results = br.get_model_vars_dataframe().sort_index()

by_q_values = {}
for q_value, collector in zip(run_results["q"], run_results["vals"]):
    by_q_values.setdefault(q_value, []).append(
        collector.get_model_vars_dataframe()["Infectious"]
    )

In [17]:
# Per-step mean and (sample) standard deviation of the infectious count across
# replicate runs, keyed by q. Replicates become columns of one frame, which is
# built once per q instead of twice.
means, sds = {}, {}

for q_value, series_list in by_q_values.items():
    replicates = pd.concat(series_list, axis=1)
    means[q_value] = replicates.mean(axis=1)
    sds[q_value] = replicates.std(axis=1)

In [39]:
fig = plt.figure(facecolor="w", figsize=(8, 6), dpi=600)
ax = fig.add_subplot(111, axisbelow=True)

colour_cycle = ["X", "I", "S"]

# Bind each q to a colour explicitly (ascending q -> "X", "I", "S", cycling if
# there are more q values) rather than relying on dict insertion order, so the
# traces always agree with the legend built from the same mapping below.
q_values = sorted(by_q_values)
colour_for_q = {
    q: COLORS[colour_cycle[i % len(colour_cycle)]] for i, q in enumerate(q_values)
}

for q in q_values:
    colour = colour_for_q[q]
    # Faint individual replicate trajectories.
    for ts in by_q_values[q]:
        ax.plot(ts, color=colour, alpha=0.3, lw=0.5, label="")
    # Dotted replicate mean, shaded down to zero.
    ax.plot(means[q], color=colour, lw=1.5, alpha=0.5, linestyle=":")
    ax.fill_between(means[q].index, means[q], 0, color=colour, alpha=0.2)

# Legend lists the largest q first, as in the original figure layout.
legend_qs = sorted(by_q_values, reverse=True)
legend_lines = [Line2D([0], [0], color=colour_for_q[q], lw=4, alpha=0.5) for q in legend_qs]
legend_labels = [f"Infectious (q={q})" for q in legend_qs]

ax.legend(legend_lines, legend_labels, title="", bbox_to_anchor=(0.5, -0.355),
          loc="lower center", ncol=max(len(legend_qs), 1), frameon=False)

ax.set_xlabel("Days")
ax.set_ylabel("Number in compartment")

fig.tight_layout(pad=5.0)
fig.savefig("Q_infector.pdf")