In [72]:
from typing import Dict
from IPython.display import clear_output
import matplotlib.pyplot as plt
import random as random
import numpy as np
from time import sleep
from copy import deepcopy
from pprint import pprint

from Complexes.Simplex import Simplex

plt.rcParams["figure.dpi"] = 300


from plotnine import ggplot, aes, geom_line, labs, theme, xlab, ylab, xlim

import pandas as pd
def print_status_bar(progress: float, block_count: int = 10) -> None:
    """Clear the cell output and draw a one-line text progress bar.

    Args:
        progress: fraction complete; it is multiplied by ``block_count`` and rounded
            to get the number of filled blocks.
        block_count: nominal width of the bar in blocks.
    """
    clear_output(wait=True)
    # Negative products render as zero filled blocks; overshoot simply widens the bar.
    filled = max(0, round(progress * block_count))
    empty = block_count - filled
    print("[" + "▓" * filled + "░" * empty + "]")

In [2]:
def make_vdws_network(node_count: int,
                      poisson_parameter: int,
                      rewire_probability: float) -> Dict[int, set[int]]:
    """Build a ring-based small-world network with Poisson local ranges.

    Construction:
      1. Nodes ``0..node_count-1`` are placed on a ring, each linked to its two ring neighbours.
      2. Each node ``v`` draws a local degree ``d_v ~ Poisson(poisson_parameter)`` (via
         ``np.random``) and is linked to every node within ring distance ``d_v``.
      3. Each clockwise link ``v -> v+k`` (``k = 1..d_v``) is rewired with probability
         ``rewire_probability`` to a node drawn uniformly with ``random``. The rewire is skipped
         if the drawn node is ``v`` itself or already a neighbour of ``v``.

    Progress is reported through ``print_status_bar`` while building.

    Args:
        node_count: number of nodes.
        poisson_parameter: mean of the Poisson distribution for local degrees.
        rewire_probability: probability of rewiring each clockwise local link.

    Returns:
        Undirected adjacency mapping node -> set of neighbours: ``u in graph[v]`` iff
        ``v in graph[u]``, with no self-loops.
    """
    # Report roughly 160 times per phase. The integer period avoids float-modulo misses
    # when node_count is not a multiple of 160.
    debug_period = max(1, node_count // 160)

    vdws_graph = dict()

    # generate local degrees; np.random.poisson always returns exactly node_count samples
    local_degrees = np.random.poisson(poisson_parameter, node_count)

    for node in range(node_count):  # circular graph
        # Subtracting {node} prevents a self-loop when node_count == 1.
        vdws_graph[node] = {(node + 1) % node_count, (node - 1) % node_count} - {node}
    for node in vdws_graph:
        degree = local_degrees[node]
        for circular_neighbour_no_mod in range(node - degree, node + degree + 1):
            circular_neighbour = circular_neighbour_no_mod % node_count
            if circular_neighbour == node:
                continue
            vdws_graph[node].add(circular_neighbour)
            vdws_graph[circular_neighbour].add(node)
        # debug
        if node % debug_period == 0:
            print_status_bar(node / len(vdws_graph), block_count=80)
            print("[1/2]  Building circular graph base")
    # rewiring
    for node in vdws_graph:
        for clockwise_neighbour_no_mod in range(node + 1, node + 1 + local_degrees[node]):
            clockwise_neighbour = clockwise_neighbour_no_mod % node_count
            if random.random() < rewire_probability:
                x = int(random.random() * node_count)
                if x == node or x in vdws_graph[node]:
                    continue
                # The link may already be gone (it wrapped around the ring and was rewired from
                # its other end), or it may be a self-loop when the degree reaches node_count.
                # This check comes after the random draws, so the RNG stream matches
                # non-failing runs of the previous implementation.
                if clockwise_neighbour == node or clockwise_neighbour not in vdws_graph[node]:
                    continue
                vdws_graph[node].remove(clockwise_neighbour)
                vdws_graph[clockwise_neighbour].remove(node)
                vdws_graph[node].add(x)
                vdws_graph[x].add(node)
        # debug
        if node % debug_period == 0:
            print_status_bar(node / len(vdws_graph), block_count=80)
            print("[2/2]  Rewiring")
    return vdws_graph

In [3]:
# Seed both generators: numpy draws the local degrees, `random` drives the rewiring.
random.seed(69420)
np.random.seed(69420)
vdws_network = make_vdws_network(node_count=200_000, poisson_parameter=25, rewire_probability=0.01)
clear_output(wait=True)  # wipe the progress bar left behind by the generator
print(f"Generated VDWS network with {len(vdws_network)} nodes.")

Generated VDWS network with 200000 nodes.


In [4]:
class EpidemicModel:
    """Discrete-time epidemic simulation on a contact network, with vaccination.

    In each tick, infectious nodes may infect each infectable neighbour. The transmission
    probability depends on both nodes' states:

        pis  - infected (1)                -> susceptible (0)
        piv  - infected (1)                -> vaccinated (2)
        pvis - vaccinated and infected (3) -> susceptible (0)
        pviv - vaccinated and infected (3) -> vaccinated (2)

    Infectious nodes whose infection time has run out are removed. From tick
    ``vaccination_tick_threshold`` on, up to ``vaccination_per_tick`` susceptible nodes are
    vaccinated each tick. They are chosen at random, or in ``vaccination_priority`` order
    when that is given.
    """
    # STATE KEY
    ## 0 - susceptible
    ## 1 - infected
    ## 2 - vaccinated
    ## 3 - vaccinated and infected
    ## 4 - removed
    def __init__(self, network: Dict[int, set[int]], initial_infected_count: int = 5, infection_length: int = 2,
                 pis: float = 0.01, piv: float = 0.01, pvis: float = 0.01, pviv: float = 0.01,
                 vaccination_priority: list[int] = None, vaccination_tick_threshold: int = 50,
                 vaccination_per_tick: int = 400):
        """Create the model and infect ``initial_infected_count`` randomly chosen nodes.

        Args:
            network: adjacency mapping node -> set of neighbour nodes.
            initial_infected_count: nodes infected at tick 0, sampled with ``random``.
            infection_length: ticks a node stays infectious (values below 1 act like 1).
            pis, piv, pvis, pviv: transmission probabilities (see the class docstring).
            vaccination_priority: optional vaccination order. It must have one entry per
                network node. It is copied, so the caller's list is never modified.
            vaccination_tick_threshold: first tick count at which vaccination happens.
            vaccination_per_tick: maximum number of nodes vaccinated per tick.

        Raises:
            ValueError: if ``vaccination_priority`` length differs from the node count, or
                (from ``random.sample``) if ``initial_infected_count`` exceeds the node count.
        """
        from collections import deque

        # Validate first, so a bad argument fails before any random draws or state changes.
        if vaccination_priority is not None and len(vaccination_priority) != len(network):
            raise ValueError("Not enough nodes in vaccination_priority.")

        self.network = network

        self.pis = pis
        self.piv = piv
        self.pvis = pvis
        self.pviv = pviv

        self.infection_length = infection_length

        self.tick_count = 0

        self.vaccination_tick_threshold = vaccination_tick_threshold
        self.vaccination_per_tick = vaccination_per_tick

        # Copy into a deque: the caller's list stays intact, and popleft() is O(1),
        # whereas list.pop(0) is O(n).
        self.vaccination_priority = None if vaccination_priority is None else deque(vaccination_priority)

        self.states = dict()
        self.susceptible_nodes = set()
        self.state_counts = {state: 0 for state in range(5)}
        for node in self.network:
            self.states[node] = {"state": 0, "infection_time_remaining": 0}
            self.susceptible_nodes.add(node)
        self.state_counts[0] = len(self.network)
        # Sample the real node keys, so networks keyed by something other than 0..n-1 work.
        # For keys 0..n-1 in insertion order this picks exactly the same nodes as sampling
        # range(n) would.
        initial_infected = random.sample(list(self.network), initial_infected_count)
        for v in initial_infected:
            self.infect(v)

    def get_state_counts(self) -> Dict[int, int]:
        """Return a snapshot of node counts per state key (0..4).

        A copy is returned, so callers cannot corrupt the model's bookkeeping.
        """
        return dict(self.state_counts)

    def get_tick_count(self) -> int:
        """Return the number of ticks run so far."""
        return self.tick_count

    def get_transmission_probability(self, node: int, contact_node: int) -> float:
        """Return the probability that infectious ``node`` infects ``contact_node`` this tick.

        Raises:
            ValueError: if both nodes are the same, ``contact_node`` is not infectable, or
                ``node`` is not infectious.
        """
        if node == contact_node:
            raise ValueError("A node cannot infect to itself.")
        if self.get_state(node) == 1:
            if self.states[contact_node]["state"] == 0:
                return self.pis
            if self.states[contact_node]["state"] == 2:
                return self.piv
        if self.get_state(node) == 3:
            if self.states[contact_node]["state"] == 0:
                return self.pvis
            if self.states[contact_node]["state"] == 2:
                return self.pviv
        if self.states[contact_node]["state"] not in [0, 2]:
            raise ValueError(
                f"Node {contact_node} cannot be infectable, it is in state {self.states[contact_node]['state']}.")
        raise ValueError(f"Node {node} is not infectious, it is in state {self.states[node]['state']}.")

    def get_state(self, node: int) -> int:
        """Return the state key (0..4) of ``node``."""
        return self.states[node]["state"]

    def get_susceptible_nodes(self) -> set[int]:
        """Return a copy of the set of nodes currently in state 0."""
        return self.susceptible_nodes.copy()

    def is_infectious_period_over(self, node: int) -> bool:
        """Return True once ``node``'s remaining infection time has reached zero or below.

        ``<=`` rather than ``==`` lets nodes with ``infection_length <= 0`` still be removed.
        """
        return self.states[node]["infection_time_remaining"] <= 0

    def is_infectious(self, node: int) -> bool:
        """Return True if ``node`` is infected (1) or vaccinated and infected (3)."""
        return self.get_state(node) in [1, 3]

    def is_infectable(self, node: int) -> bool:
        """Return True if ``node`` is susceptible (0) or vaccinated (2)."""
        return self.get_state(node) in [0, 2]

    def is_epidemic_over(self) -> bool:
        """Return True when no node is in an infectious state (1 or 3)."""
        return self.state_counts[1] == 0 and self.state_counts[3] == 0

    def infect(self, node: int) -> None:
        """Infect ``node``: state 0 -> 1, or 2 -> 3, and start its infection timer.

        Raises:
            ValueError: if the node is already infectious or removed.
        """
        if self.get_state(node) in [1, 3, 4]:
            raise ValueError(f"Node cannot be infected, in state {self.states[node]}.")
        if self.get_state(node) == 2:
            self.state_counts[2] -= 1
            self.state_counts[3] += 1
            self.states[node]["state"] = 3
            self.states[node]["infection_time_remaining"] = self.infection_length
            return
        self.state_counts[0] -= 1
        self.state_counts[1] += 1
        self.states[node]["state"] = 1
        self.susceptible_nodes.remove(node)
        self.states[node]["infection_time_remaining"] = self.infection_length

    def vaccinate(self, node: int) -> None:
        """Vaccinate a susceptible node (state 0 -> 2).

        Raises:
            ValueError: if the node is not susceptible.
        """
        if self.get_state(node) != 0:
            raise ValueError(f"Node {node} is not vaccinatable, it is in state {self.get_state(node)}.")
        self.state_counts[0] -= 1
        self.state_counts[2] += 1
        self.states[node]["state"] = 2
        self.susceptible_nodes.remove(node)

    def remove(self, node: int) -> None:
        """Move an infectious node (state 1 or 3) to removed (4).

        Raises:
            ValueError: if the node is not infectious.
        """
        if self.get_state(node) not in [1, 3]:
            raise ValueError(
                f"Node {node} cannot be removed, as it is not infection. It is in state {self.states[node]['state']}.")
        if self.get_state(node) == 1:
            self.state_counts[1] -= 1
        else:
            self.state_counts[3] -= 1
        self.state_counts[4] += 1
        self.states[node]["state"] = 4
        self.states[node]["infection_time_remaining"] = 0

    def increment_tick_count(self) -> None:
        """Advance the tick counter by one."""
        self.tick_count += 1

    def reduce_infection_time_remaining(self, node: int) -> None:
        """Decrement an infectious node's remaining infection time by one.

        Raises:
            ValueError: if the node is not infectious.
        """
        if self.get_state(node) not in [1, 3]:
            raise ValueError(f"Node {node} is not infection, it has state {self.get_state(node)}.")
        self.states[node]["infection_time_remaining"] -= 1

    def tick(self) -> None:
        """Advance the simulation by one tick.

        Within a tick, in order:
          1. decrement infection timers;
          2. collect new infections from every infectious node, including those whose
             time just ran out;
          3. remove nodes whose infectious period is over;
          4. apply the new infections;
          5. from tick ``vaccination_tick_threshold`` on, vaccinate up to
             ``vaccination_per_tick`` susceptible nodes.

        Raises:
            Exception: if the epidemic is already over.
        """
        if self.is_epidemic_over():
            raise Exception("Model can not be ticked, epidemic is over.")
        self.increment_tick_count()

        # We decrement the infection time before infecting people, as they don't get infected until next tick
        for node in self.network:
            if self.is_infectious(node):
                self.reduce_infection_time_remaining(node)

        # Find nodes to infect by next tick
        nodes_to_infect = set()
        for node in self.network:
            if not self.is_infectious(node):
                continue
            for neighbour in self.network[node]:
                if not self.is_infectable(neighbour):
                    continue
                if random.random() < self.get_transmission_probability(node, neighbour):
                    nodes_to_infect.add(neighbour)

        # Remove the nodes that are no longer infectious
        for node in self.network:
            if self.is_infectious(node) and self.is_infectious_period_over(node):
                self.remove(node)

        # Infect nodes
        for node in nodes_to_infect:
            self.infect(node)

        # Move people to vaccinated
        if self.get_tick_count() < self.vaccination_tick_threshold:
            return
        if self.vaccination_priority is None:
            susceptible_nodes = self.get_susceptible_nodes()
            if len(susceptible_nodes) < self.vaccination_per_tick:
                for node in susceptible_nodes:
                    self.vaccinate(node)
                return
            # random.sample needs a sequence (sets are rejected since Python 3.11).
            for node in random.sample(tuple(susceptible_nodes), self.vaccination_per_tick):
                self.vaccinate(node)
            return
        to_vaccinate = set()
        while len(to_vaccinate) < self.vaccination_per_tick and self.vaccination_priority:
            node = self.vaccination_priority.popleft()
            if self.get_state(node) != 0:
                continue
            to_vaccinate.add(node)
        for node in to_vaccinate:
            self.vaccinate(node)


In [5]:
def print_state_counts(state_counts: dict[int, int], column_width: int) -> None:
    """Print a two-row table: S/I/V/VI/R labels, then the counts for state keys 0..4.

    Each cell is left-justified to ``column_width``. Longer cells are printed in full.
    """
    labels = ("S", "I", "V", "VI", "R")
    # Build both rows before printing, so a missing key prints nothing at all.
    header_row = "".join(label.ljust(column_width) for label in labels)
    count_row = "".join(str(state_counts[state]).ljust(column_width) for state in range(len(labels)))
    print(header_row)
    print(count_row)


def get_state_counts_text(state_counts: dict[int, int], column_width: int) -> str:
    """Return the S/I/V/VI/R state-count table as text.

    The table has two rows: the state labels, then the counts for state keys 0..4.
    Every cell is left-justified to ``column_width``, and each row ends with a newline.

    Args:
        state_counts: mapping from state key (0..4) to node count.
        column_width: minimum width of each cell; longer cells are not truncated.

    Returns:
        The two-line table as a single string.
    """
    labels = ("S", "I", "V", "VI", "R")
    counts = tuple(str(state_counts[state]) for state in range(len(labels)))
    return "".join(
        "".join(cell.ljust(column_width) for cell in row) + "\n"
        for row in (labels, counts)
    )


def print_model_status(epidemic_model: EpidemicModel, total_tick_count: int, column_width: int,
                       finished: bool = True) -> None:
    """Redraw a 100-block progress bar and the current state-count table for a model.

    Args:
        epidemic_model: the model being reported.
        total_tick_count: tick budget of the run; the bar shows ticks run / this budget.
        column_width: width of each column in the state-count table.
        finished: if True the last line reads "Total ticks N", otherwise "Tick N".
    """
    if epidemic_model.is_epidemic_over() or total_tick_count <= 0:
        # A finished epidemic, or an empty tick budget, counts as complete. This also
        # avoids dividing by zero when total_tick_count == 0.
        progress = 1
    else:
        progress = epidemic_model.get_tick_count() / total_tick_count
    print_status_bar(progress, block_count=100)
    print_state_counts(epidemic_model.get_state_counts(), column_width)
    label = "Total ticks" if finished else "Tick"
    print(f"{label} {epidemic_model.get_tick_count()}")


def run_model(epidemic_model: EpidemicModel, tick_count: int, column_width: int = 20, quiet=False) -> list[list[int]]:
    """Tick a model until its epidemic ends or ``tick_count`` ticks have run.

    Args:
        epidemic_model: model to advance; it is mutated in place.
        tick_count: maximum number of ticks to run.
        column_width: column width for the progress display.
        quiet: if True, print nothing.

    Returns:
        One list per state key, where list index = state key (0..4 = S, I, V, VI, R per
        the EpidemicModel state key). Each list holds that state's count before the first
        tick and after every tick that ran.
    """
    state_counts_list = [[state_count] for state_count in epidemic_model.get_state_counts().values()]

    if not quiet:
        print_model_status(epidemic_model, tick_count, column_width=column_width, finished=False)
    for _ in range(tick_count):
        if epidemic_model.is_epidemic_over():
            break
        epidemic_model.tick()
        for state, state_count in epidemic_model.get_state_counts().items():
            state_counts_list[state].append(state_count)
        if not quiet:
            print_model_status(epidemic_model, tick_count, column_width=column_width, finished=False)
    if not quiet:
        if epidemic_model.is_epidemic_over():
            print_model_end(epidemic_model)
        else:
            # The tick budget ran out first. print_model_end would raise "Epidemic not over!",
            # so report the final status instead.
            print_model_status(epidemic_model, tick_count, column_width=column_width, finished=True)
    return state_counts_list


def print_model_end(epidemic_model: EpidemicModel) -> None:
    """Clear the cell output and print the final state distribution and tick count.

    Raises:
        Exception: if the model's epidemic is still running.
    """
    if epidemic_model.is_epidemic_over():
        clear_output(wait=True)
        print("Model ended with the following state distribution.")
        print_state_counts(epidemic_model.get_state_counts(), column_width=10)
        print(f"Total tick count: {epidemic_model.get_tick_count()}")
        return
    raise Exception("Epidemic not over!")


def get_model_end_text(epidemic_model: EpidemicModel) -> str:
    """Return the end-of-run summary (header, state table, tick count) as one string.

    Raises:
        Exception: if the model's epidemic is still running.
    """
    if not epidemic_model.is_epidemic_over():
        raise Exception("Epidemic not over!")
    parts = [
        "Model ended with the following state distribution.\n",
        get_state_counts_text(epidemic_model.get_state_counts(), column_width=10),
        f"Total tick count: {epidemic_model.get_tick_count()}",
    ]
    return "".join(parts)


def plot_epidemic_model_results(states_over_time: list[list[int]], title: str = "", max_x: int = None,
                                max_y: int = None, figure_size=(20, 10), hidden_states=(0,)) -> None:
    """Plot per-state counts against tick number and show the figure.

    Args:
        states_over_time: one series per state key (index 0..4 = S, I, V, VI, R), as
            returned by ``run_model``. Entry ``i`` of each series is the count at tick ``i``.
        title: figure title.
        max_x: upper x-axis limit; defaults to the series length.
        max_y: upper y-axis limit; matplotlib autoscales when None.
        figure_size: figure size in inches, applied to this figure only.
        hidden_states: state indices not to draw; defaults to hiding S (state 0).
    """
    legend_names = ["S", "I", "V", "VI", "R"]
    series_length = len(states_over_time[0])
    # Entry i is the count after i ticks, so x must be 0, 1, ..., n-1. The previous
    # linspace(0, n, n) stretched the axis by n / (n - 1).
    ticks = np.arange(series_length)
    # Pass figsize per figure instead of mutating the global rcParams.
    fig, ax = plt.subplots(figsize=figure_size)
    ax.set_xlim([0, series_length if max_x is None else max_x])
    if max_y is not None:
        ax.set_ylim([0, max_y])
    ax.set_xlabel("Tick number", fontsize=20)
    ax.set_ylabel("Count", fontsize=20)
    ax.tick_params(axis="both", labelsize=20)
    ax.set_title(title, fontsize=20)
    for state, (states, legend_name) in enumerate(zip(states_over_time, legend_names)):
        if state in hidden_states:
            continue
        ax.plot(ticks, states, label=legend_name)
    ax.legend(loc="best", prop={"size": 20})
    plt.show()

In [6]:
# Re-seed `random` so the initial infected sample and every transmission draw are reproducible.
random.seed(69420)
ep_model_1 = EpidemicModel(network=vdws_network, infection_length=3)
ep_model_1_results = run_model(epidemic_model=ep_model_1, tick_count=500)

Model ended with the following state distribution.
S         I         V         VI        R         
9822      0         78620     0         111558    
Total tick count: 311


In [7]:
# Same seed as the baseline run. Only the probabilities involving a vaccinated node are lowered;
# pis keeps its default.
random.seed(69420)
ep_model_2 = EpidemicModel(
    network=vdws_network,
    infection_length=3,
    piv=0.005,
    pvis=0.005,
    pviv=0.0025,
)
ep_model_2_results = run_model(epidemic_model=ep_model_2, tick_count=500)

Model ended with the following state distribution.
S         I         V         VI        R         
103652    0         72040     0         24308     
Total tick count: 234


In [76]:
print(f"{ep_model_2_results[0][-1]}")  # final susceptible (state 0) count of the vaccine run

103652


In [67]:
# Shared x-axis limit so both figures use the same tick range.
x_max = max(len(results[0]) for results in (ep_model_1_results, ep_model_2_results))

In [77]:
# One series per state key (0..4 = S, I, V, VI, R), all from the baseline run.
y1, y2, y3, y4, y5 = ep_model_1_results
x = list(range(len(y1)))

df = pd.DataFrame({"Tick": x, "S": y1, "I": y2, "V": y3, "VI": y4, "R": y5})
df = df.melt(id_vars="Tick")  # long format: Tick, variable (state label), value (count)

p = (
        ggplot(df)
        + aes(x="Tick")
        + geom_line(aes(y="value", color="variable"))
        + labs(color="Category")
        + xlab("Tick")
        + ylab("Count")
        + theme(figure_size=(8, 4), dpi=300)
        + xlim(0, x_max)
)
p.save("disease-model-no-vaccine.pdf", "pdf")



In [78]:
# Every series and the tick axis must come from the vaccine run (ep_model_2). Taking the
# S column from ep_model_1_results would plot the baseline run's susceptible curve here.
y1, y2, y3, y4, y5 = ep_model_2_results
x = list(range(len(y1)))

df = pd.DataFrame(zip(x, y1, y2, y3, y4, y5), columns=["Tick", "S", "I", "V", "VI", "R"])
df = pd.melt(df, id_vars="Tick")  # long format: Tick, variable (state label), value (count)


p = (
        ggplot(df)
        + aes(x="Tick")
        + geom_line(aes(y="value", color="variable"))
        + labs(color="Category")
        + xlab("Tick")
        + ylab("Count")
        + theme(figure_size=(8, 4), dpi=300)
        + xlim(0, x_max)
)
p.save("disease-model-vaccine.pdf", "pdf")

