In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
import random
import numpy as np
import networkx as nx

# Seed every random source the notebook touches so runs are reproducible.
setup_seed = 250
# Global NumPy RNG; presumably what nx.spring_layout (called below with no
# explicit seed) draws from — confirm against the networkx version in use.
np.random.seed(setup_seed)
# Must be *called*: the previous `random.seed = setup_seed` replaced the
# function with an int, leaving the RNG unseeded and breaking later calls.
random.seed(setup_seed)
# Dedicated RNG used for every simulation draw (patient zero, weights, rolls).
simulation_random_instance = random.Random(setup_seed)
plt.rcParams["figure.figsize"] = (15, 15)

# Small-world contact network parameters (Watts–Strogatz).
Rewiring_Prob = 0.3
Nodes = 150
Nearest_Neighbors = 6

# Epidemic parameters.
simulation_time = 150          # number of animation frames; frame 0 is the initial state
recovery_chance = 0.3          # per-draw chance an infected node recovers
lose_recovery_chance = 0.4     # per-draw chance a recovered node becomes susceptible again
infection_coefficient = 0.5    # multiplied by edge weight to get per-contact infection chance
death_chance = 0.1             # per-draw chance an infected node dies (rolled after recovery fails)


graph = nx.watts_strogatz_graph(Nodes, Nearest_Neighbors, Rewiring_Prob, setup_seed)

# Assign node attributes: every node starts Susceptible with both timers
# unset (-1), then one randomly chosen patient zero is infected on day 0.
# patient_zero_ind is drawn before the edge weights to keep the RNG order.
patient_zero_ind = simulation_random_instance.randint(0, Nodes - 1)
nx.set_node_attributes(graph, "Susceptible", "status")
nx.set_node_attributes(graph, -1, "time_since_infected")
nx.set_node_attributes(graph, -1, "time_since_recovered")
graph.nodes[patient_zero_ind]["status"] = "Infected"
graph.nodes[patient_zero_ind]["time_since_infected"] = 0

# Assign random edge strengths (weights) in [0.01, 1.0], rounded to 2 decimals.
for u, v in graph.edges():
    graph[u][v]["weight"] = round(simulation_random_instance.uniform(0.01, 1.0), 2)


# Node colour for each epidemic status.
color_map = {
    "Susceptible": "blue",
    "Infected": "red",
    "Recovered": "lime",
    "Deceased": "black"
}

# List of (u, v, weight) triples, one per edge, plus the matching line
# widths used when drawing (weight scaled by 3 so strong ties look thicker).
edge_strengths = [(u, v, graph[u][v]['weight']) for u, v in graph.edges()]
edge_strengths_numbers = [weight * 3 for _, _, weight in edge_strengths]

# Initial node colours, and one fixed layout reused for every frame.
node_colors = [color_map[graph.nodes[node]["status"]] for node in graph.nodes]
pos = nx.spring_layout(graph)

def get_next_node_status(curr_status, days_passed):
    """Roll the status a node should move to, given how long it has held it.

    An "Infected" node gets ``days_passed`` attempts. Each attempt first
    tries recovery (``recovery_chance``) and, if that fails, tries death
    (``death_chance``).

    A "Recovered" node gets ``days_passed`` attempts at losing its immunity
    (``lose_recovery_chance``).

    Any other status (including "Deceased" and "Susceptible"), or a
    non-positive ``days_passed``, is returned unchanged without drawing
    from ``simulation_random_instance``.

    Returns:
        The new status string, or ``curr_status`` if no transition fired.
    """
    if curr_status == "Infected":
        for _attempt in range(days_passed):
            recovered = simulation_random_instance.random() < recovery_chance
            if recovered:
                return "Recovered"
            died = simulation_random_instance.random() < death_chance
            if died:
                return "Deceased"
    elif curr_status == "Recovered":
        for _attempt in range(days_passed):
            lost_immunity = simulation_random_instance.random() < lose_recovery_chance
            if lost_immunity:
                return "Susceptible"
    # "Deceased" is terminal; every other status falls through unchanged.
    return curr_status


def _spread_infection_one_day():
    """Advance the epidemic on the module-level ``graph`` by one day, in place.

    Only nodes that were already Infected when the day began may transmit
    and age. Neighbours infected during this pass are recorded in
    ``newly_infected`` and skipped until the next day. Previously a node
    infected earlier in the same loop was processed again later in that loop.
    It could then infect its own neighbours and have its timer advanced, so
    the disease could cross many edges in a single day, depending on node
    iteration order.
    """
    newly_infected = set()
    for node in graph.nodes():
        attrs = graph.nodes[node]
        if attrs["status"] == "Infected" and node not in newly_infected:
            # One draw per neighbour; a susceptible neighbour is infected when
            # infection_coefficient * edge weight >= the draw.
            for neighbor in graph.neighbors(node):
                infect_chance = simulation_random_instance.random()
                neighbor_attrs = graph.nodes[neighbor]
                if (infection_coefficient * graph[neighbor][node]["weight"] >= infect_chance
                        and neighbor_attrs["status"] == "Susceptible"):
                    neighbor_attrs["status"] = "Infected"
                    neighbor_attrs["time_since_infected"] = 0
                    newly_infected.add(neighbor)

            # Age the infection, then roll for recovery/death. A transition
            # only takes effect once the node has been infected for > 1 day.
            attrs["time_since_infected"] += 1
            next_status = get_next_node_status(attrs["status"], attrs["time_since_infected"])
            if next_status == "Recovered" and attrs["time_since_infected"] > 1:
                attrs["status"] = next_status
                attrs["time_since_recovered"] = 0
                attrs["time_since_infected"] = -1
            elif next_status == "Deceased" and attrs["time_since_infected"] > 1:
                attrs["status"] = next_status
                attrs["time_since_infected"] = -1

        elif attrs["status"] == "Recovered":
            # Age the immunity, then roll for losing it (again only after > 1 day).
            attrs["time_since_recovered"] += 1
            next_status = get_next_node_status(attrs["status"], attrs["time_since_recovered"])
            if next_status != attrs["status"] and attrs["time_since_recovered"] > 1:
                attrs["status"] = next_status
                attrs["time_since_recovered"] = -1


def _draw_day(i):
    """Redraw ``graph`` on ``ax``, coloured by status and labelled "Day i"."""
    # Clear the previous frame first; otherwise every frame stacks new node,
    # edge and text artists on top of the old ones, and rendering slows down.
    ax.clear()
    node_colors = [color_map[graph.nodes[node]["status"]] for node in graph.nodes]
    nx.draw(graph, pos, ax=ax, with_labels=False, node_color=node_colors,
            edge_color="gray", node_size=125, width=edge_strengths_numbers)
    ax.text(0.5, 1.100, "Day " + str(i), transform=ax.transAxes, ha="center",
            fontsize=30, backgroundcolor="1")


def animate(i):
    """FuncAnimation callback: simulate day ``i`` and draw the result.

    Frame 0 only draws the initial state; every later frame first advances
    the simulation by one day and prints a progress percentage.
    """
    if i > 0:
        print(str(int(100*i/simulation_time)) + "%\n")
        _spread_infection_one_day()
    _draw_day(i)

# 16x9 inches at 1920/16 = 120 dpi gives 1920x1080-pixel frames.
fig, ax = plt.subplots(figsize=(16, 9), dpi=1920 / 16)
anim = animation.FuncAnimation(fig, animate, frames=simulation_time, interval=20)
# simulation_time frames at simulation_time/30 fps -> a 30-second video.
# extra_args are ffmpeg flags (H.264 encoding); presumably needs ffmpeg installed.
anim.save(
    'test_animation.mp4',
    fps=simulation_time / 30,
    extra_args=['-vcodec', 'libx264'],
)
