## Code for Figure 2

In [None]:
from scipy.integrate import solve_ivp
from scipy.sparse.linalg import expm_multiply
import matplotlib.pyplot as plt
import numpy as np
import networkx as nx
from itertools import permutations
from matplotlib.animation import FuncAnimation
from matplotlib import rc
rc('animation', html='jshtml')

# Migration matrix for the linear model dn/dt = M @ n.  An off-diagonal entry
# M[i, j] is the rate at which population j moves into population i; each
# diagonal entry is minus the total outflow rate of that population.
M = np.array([
    [-0.1, 0.0, 0.1],
    [0.1, -0.1, 0.0],
    [0.0, 0.1, -0.1],
])

# Initial population sizes n(0)
n0 = np.array([300, 600, 900])

# Hyperparameters (not used by the linear migration model in this cell)
beta = np.array([0.1, 0.1, 0.1])
gamma = np.array([0.2, 0.2, 0.2])

# Sanity checks: every column of M must sum to zero (so the total population
# sum(n) is conserved), and every parameter vector needs one entry per population.
n = M.shape[0]
assert np.allclose(M.sum(axis=0), 0.0)
assert all(len(v) == n for v in (n0, beta, gamma))

# Sample times: num evenly spaced points on [start, stop], endpoints included
start, stop = 0, 50
num = stop - start + 1
domain = np.linspace(start, stop, num)

# Solve dn/dt = M @ n by applying the matrix exponential, n(t) = expm(t M) n0,
# at every sample time in one call; row k of sol is n(domain[k]).
sol = expm_multiply(M, n0, start=start, stop=stop, num=num, endpoint=True)

# Plot each population's size over the sampled times, one curve per population.
fig, ax_pop = plt.subplots(figsize=(6, 6))
ax_pop.set_title("Population Evolution")
for m in range(sol.shape[1]):
    ax_pop.plot(domain, sol[:, m], label=f'pop {m}')
ax_pop.set_xlabel(r'$t$')
ax_pop.legend()
plt.show()

# Fresh, empty figure for the migration-graph animation.
fig = plt.figure(figsize=(6, 6))

# Plot the migration graph at timestep step
def update(step):
    """Draw the migration graph for animation frame ``step``.

    Nodes are the populations 0..n-1 (n = M.shape[0]); node sizes are
    proportional to ``sol[step]``.  For every positive off-diagonal rate
    ``M[i, j]`` a directed edge ``j -> i`` is drawn and labelled
    ``m_{j->i} = M[i, j]``.  The title shows the frame's time,
    ``start + step * (stop - start) / (num - 1)``.

    Reads the notebook globals ``M``, ``sol``, ``start``, ``stop`` and ``num``.
    """
    n = M.shape[0]

    # Reuse (and wipe) the single axes of the current figure.
    ax = plt.subplot(1, 1, 1)
    ax.clear()
    ax.margins(0.3)

    g = nx.DiGraph()
    g.add_nodes_from(range(n))

    # In dn/dt = M @ n a positive off-diagonal M[dst, src] moves people from
    # population src into population dst.  Keep the numeric rate in the
    # 'weight' attribute (networkx treats 'weight' as a number) and put the
    # LaTeX text in a separate 'label' attribute used only for drawing.
    for dst, src in permutations(range(n), 2):
        rate = M[dst, src]
        if rate > 0:
            g.add_edge(src, dst, weight=rate,
                       label=rf"$m_{{{src}\to{dst}}} = {rate}$")

    # Create the layout for the graph drawing
    pos = nx.circular_layout(g)

    # Cycle through the default colour list so that more populations than
    # colours does not raise.
    colours = plt.rcParams['axes.prop_cycle'].by_key()['color']
    nx.draw(g, pos, node_size=5 * sol[step], ax=ax,
            node_color=[colours[k % len(colours)] for k in range(n)],
            with_labels=True)

    # Draw the m_{i\to j} labels (label_pos sets where along each edge they sit).
    nx.draw_networkx_edge_labels(
        g, pos, edge_labels=nx.get_edge_attributes(g, 'label'),
        label_pos=0.6, ax=ax,
        bbox=dict(boxstyle='round,pad=0.3', fc='white', ec='gray', lw=1, alpha=0.8))
    ax.set_title(rf"Migration Graph at time $t={start + step*((stop-start)/(num - 1))}$")

# Animate the migration graph: one frame per sample time, 100 ms apart.
ani = FuncAnimation(fig, update, frames=num, interval=100)
# As the last expression in the cell, the animation is rendered inline
# (rc 'animation.html' was set to 'jshtml' above).
ani

## Code for Figure 3

In [None]:
from scipy.integrate import solve_ivp
import numpy as np
from matplotlib import pyplot as plt, animation as ani

# This cell saves the animation as an .mp4 via animation.save(), which uses the
# ffmpeg movie writer.  Check up front that it is available and fail with a
# clear message, rather than storing the writer on the module as `ani.writer`
# (an attribute that has no effect on which writer save() uses).
if not ani.writers.is_available("ffmpeg"):
    raise RuntimeError("The ffmpeg movie writer is required to save the .mp4 animation.")

# The SIR model with cure
def ode(t, y, beta, gamma, c):
    """
    Right-hand side of the SIR model with a time-dependent cure.

    :param t: time
    :param y: sequence (S, I, R, D) of compartment values
    :param beta: infection rate
    :param gamma: death rate
    :param c: callable; c(t) is the cure rate applied to the infected at time t
    :return: tuple (dSdt, dIdt, dRdt, dDdt)
    """
    S, I, _R, _D = y  # R and D do not feed back into the dynamics
    infections = beta * S * I
    recoveries = c(t) * I
    deaths = gamma * I
    return -infections, infections - recoveries - deaths, recoveries, deaths

# Parameters
beta = 0.2               # infection rate
gamma = 0.05             # death rate
cure_strength = 0.1      # cure rate once the cure is available
y0 = [0.99, 0.01, 0, 0]  # initial proportions (S, I, R, D)
T = 120                  # total simulated time

# Figure and axes that the animation draws into
fig, ax = plt.subplots(figsize=(8, 6))
# NOTE(review): the title says "No Cure", but the animation sweeps the cure
# start time and is saved as 'sir_with_cure.mp4' -- confirm the intended title.
ax.set(
    xlim=(0, T),
    ylim=(0, 1),
    xlabel='Time',
    ylabel='Proportion',
    title='SIR Model No Cure',
)

# One initially empty curve per compartment, plus a vertical marker for the
# cure start time; the animation's update function fills in their data.
line_s, line_i, line_r, line_d, line_c = (
    ax.plot([], [], label=name, color=colour)[0]
    for name, colour in (
        ('Susceptible', 'blue'),
        ('Infected', 'red'),
        ('Recovered', 'green'),
        ('Dead', 'black'),
        ('Cure', 'purple'),
    )
)

# Function to update cure rate based on time
def update_cure_rate(cure_time):
    """Return the cure-rate schedule c(t) for a cure introduced at ``cure_time``.

    The returned function gives 0 for t < cure_time and the module-level
    ``cure_strength`` (looked up when called) from ``cure_time`` onward.
    """
    def cure_rate(t):
        if t < cure_time:
            return 0
        return cure_strength
    return cure_rate

# Animation update function
def update(frame):
    """Recompute and redraw the SIR curves for animation frame ``frame``.

    The cure start time is ``T - frame``: frame 0 starts the cure at ``T`` and
    frame ``T`` starts it at 0.  The ODE is re-solved on [0, T] with that
    schedule, the S/I/R/D curves are updated, and the vertical cure marker
    is moved to the cure start time.

    Returns every artist it modifies.  With ``blit=True``, FuncAnimation only
    redraws the returned artists, so the cure marker must be included too.
    """
    cure_time = T - frame  # frames 0..T sweep the cure start from T down to 0
    c = update_cure_rate(cure_time)

    # Solve the ODEs with the updated cure function
    sol = solve_ivp(ode, [0, T], y0, args=(beta, gamma, c), dense_output=True)
    t = np.linspace(0, T, 1000)
    y = sol.sol(t)

    # Update the compartment curves (rows of y are S, I, R, D)
    for line, values in zip((line_s, line_i, line_r, line_d), y):
        line.set_data(t, values)

    # Move the cure marker to the cure start time
    line_c.set_data([cure_time, cure_time], [0, 1])

    return line_s, line_i, line_r, line_d, line_c

# Build the animation: frames 0..T (one per integer cure start time),
# 100 ms apart, with blit=True requested.
animation = ani.FuncAnimation(fig, update, frames=T + 1, interval=100, blit=True)

# Legend built from the labels given to the line artists
ax.legend(loc='upper right')

# Render all frames to an .mp4 file
animation.save('sir_with_cure.mp4')

## Code for Figures 4 and 5

In [None]:
from scipy.integrate import solve_ivp
import numpy as np
from matplotlib import pyplot as plt, animation as ani

# No movie writer is configured here: this cell only draws static plots, so it
# should not require ffmpeg to be installed.

# Multi-city SIR model with cure
def multi_city_ode(t, y, beta, gamma, cure_rate, cure_time, m, production_rate=None):
    """
    Multi-city SIR model with cure and movement.

    :param t: time
    :param y: flattened state [S1..SN, I1..IN, R1..RN, D1..DN, C1..CN]
              (five blocks of N values; N is inferred as len(y) // 5)
    :param beta: infection rate (scalar, or one value per city)
    :param gamma: death rate (scalar, or one value per city)
    :param cure_rate: rate constant converting infected to recovered per unit
              of cure present in a city (only where C > 0)
    :param cure_time: for t > cure_time, every city whose cure stock is exactly
              zero starts producing cure at ``production_rate``
    :param m: movement matrix (N x N); m[i, j] is the per-capita rate of moving
              from city i to city j.  The diagonal is ignored.
    :param production_rate: cure production rate.  When None (the default) the
              module-level ``production_rate`` is used, looked up only when
              production is actually triggered.
    :return: flattened derivatives in the same layout as ``y``
    :raises ValueError: if ``y`` is not 1-D with a length divisible by 5, or
              ``m`` is not N x N
    """
    y = np.asarray(y, dtype=float)
    if y.ndim != 1 or y.size % 5:
        raise ValueError(f"y must be a flat array of 5*N values, got shape {y.shape}")
    S, I, R, D, C = y.reshape(5, -1)
    num_cities = S.size

    m = np.asarray(m, dtype=float)
    if m.shape != (num_cities, num_cities):
        raise ValueError(f"m must have shape {(num_cities, num_cities)}, got {m.shape}")

    # Movement only happens between distinct cities, so zero the diagonal.
    m_off = m.copy()
    np.fill_diagonal(m_off, 0.0)
    out_rate = m_off.sum(axis=1)  # total per-capita rate of leaving each city

    # Only a positive cure stock acts on the infected or moves between cities.
    # This guards against the solver pushing C slightly negative.
    C_pos = np.where(C > 0, C, 0.0)

    # Net movement for each compartment: row k of X @ m_off gives, for each
    # city i, sum_j X[k, j] * m[j, i] (inflow); subtract the outflow X * out_rate.
    X = np.stack([S, I, R, C_pos])
    flow_S, flow_I, flow_R, flow_C = X @ m_off - X * out_rate

    infections = beta * S * I
    deaths = gamma * I
    cures = cure_rate * C_pos * I

    dSdt = -infections + flow_S
    dIdt = infections - deaths - cures + flow_I
    dRdt = cures + flow_R
    dDdt = deaths  # the deceased do not move
    dCdt = flow_C

    # Introduce cure: after cure_time, cities with no cure at all start producing it.
    start_production = (t > cure_time) & (C == 0)
    if start_production.any():
        if production_rate is None:
            # Backward compatibility: fall back to the notebook-level global.
            production_rate = globals()["production_rate"]
        dCdt = dCdt + production_rate * start_production

    return np.concatenate([dSdt, dIdt, dRdt, dDdt, dCdt])


# Parameters
cure_time = 120        # cure production can start only for t > cure_time
num_cities = 3
beta = 0.5             # infection rate
gamma = 0.05           # death rate
production_rate = 500  # cure production rate; multi_city_ode reads this module-level value
cure_rate = 10000      # rate constant for cure * infected -> recovered

T = 120  # Total time
# NOTE(review): with cure_time == T the condition t > cure_time presumably never
# holds on the integration window [0, T], i.e. this run has no cure -- confirm.

# Movement matrix (N x N): m[i, j] is the rate of moving from city i to city j
m = np.array([
    [0, 0.01, 0.02],
    [0.01, 0, 0.01],
    [0.02, 0.01, 0],
])
# m = np.zeros((num_cities, num_cities))

# Initial state, laid out compartment by compartment (N values each)
y0 = np.concatenate([
    [0.7, 0.2, 0.09],          # S1, S2, S3
    [0, 0, 0.01],              # I1, I2, I3
    np.zeros(3 * num_cities),  # R, D and C all start at zero in every city
])

# Integrate over [0, T], reporting 5000 evenly spaced time points
t_span = (0, T)
t_eval = np.linspace(0, T, 5000)
solution = solve_ivp(
    multi_city_ode,
    t_span,
    y0,
    args=(beta, gamma, cure_rate, cure_time, m),
    t_eval=t_eval,
    dense_output=True,
)

# Split the (5N, len(t)) solution into per-compartment (N, len(t)) blocks
y = solution.y
S, I, R, D, C = (y[k * num_cities:(k + 1) * num_cities] for k in range(5))

# Plot results: one panel per compartment, one curve per city, the all-city
# total as a black dashed line, and a red marker at the cure time.
fig, axs = plt.subplots(4, 1, figsize=(10, 12), sharex=True)
time = solution.t
panels = (
    ("Susceptible", S, 'Total Susceptible'),
    ("Infected", I, 'Total Infected'),
    ("Recovered", R, 'Total Recovered'),
    ("Deceased", D, 'Total Deceased'),
)
for ax, (var, series, total_label) in zip(axs, panels):
    for i in range(num_cities):
        ax.plot(time, series[i], label=f'City {i + 1}')
    ax.plot(time, np.sum(series, axis=0), 'k--', label=total_label)
    ax.axvline(x=cure_time, color='r', linestyle='--', label='Cure Introduced')
    ax.set_ylim(0, 1)
    ax.set_ylabel(var)
    ax.legend()
axs[-1].set_xlabel("Time")

plt.tight_layout()
plt.show()