In [None]:
from importing import *
from TravelAndMutate.randominterface import NumpyRandomGenerator
from TravelAndMutate.system import SystemMix as System

In [None]:
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from time import time
from scipy.optimize import root_scalar

In [None]:
# Deterministic seed derived from the package name: the product of its character
# codes, reduced modulo 2**32. This equals the wrapped integer product cast to uint32,
# but states the wraparound explicitly instead of relying on silent integer overflow.
seed_accumulator = 1
for char_code in map(ord, "TravelAndMutate"):
    seed_accumulator = (seed_accumulator * char_code) % 2**32
seed = np.uint32(seed_accumulator)
rng = NumpyRandomGenerator(seed)

In [None]:
def kernel(rng, tmax, params, commuting_mat):
    """Run one stochastic simulation and return per-patch trajectories as fractions.

    Parameters
    ----------
    rng : object
        Random generator wrapper (e.g. ``NumpyRandomGenerator``); its ``cpprng``
        attribute is handed to ``System``.
    tmax :
        Simulated time span, passed unchanged to ``System.spreadForTime``.
    params : pandas.DataFrame
        One row per patch. Every column is forwarded to ``System`` via
        ``to_records(index=False)``; column ``"N"`` is also used to normalise
        the trajectories.
    commuting_mat : numpy.ndarray
        Commuting matrix passed unchanged to ``System``.

    Returns
    -------
    list of pandas.DataFrame
        One frame per patch, in the row order of ``params``, indexed by ``"t"``,
        with every trajectory column divided by that patch's ``N``.
    """
    # Named `system` rather than `sys` so the stdlib module name is not shadowed.
    system = System(rng.cpprng, commuting_mat, params.to_records(index=False))
    system.spreadForTime(tmax)
    # Positional access: patches are enumerated by row position, so this stays correct
    # even if `params` carries a non-default (e.g. labelled) index.
    populations = params["N"].to_numpy()
    return [
        pd.DataFrame.from_records(system.getFullTrajectory(p), index="t") / populations[p]
        for p in range(len(params))
    ]

In [None]:
def Rinf_equation(Rinf, R0, s0):
    """Residual of the final-size relation ``Rinf = 1 - s0 * exp(-R0 * Rinf)``.

    A zero of this function in ``Rinf`` is a self-consistent final recovered
    fraction for reproduction number ``R0`` and initial susceptible fraction ``s0``.
    Works element-wise on numpy arrays as well as on scalars.

    Returns
    -------
    ``1 - Rinf - s0 * exp(-R0 * Rinf)``
    """
    surviving_susceptible = s0 * np.exp(-R0 * Rinf)
    return (1 - Rinf) - surviving_susceptible

In [None]:
# Benchmark: Npatch patches with identical SEIR parameters and an identity commuting
# matrix, simulated `reps` times. The analytic checks below treat every patch as an
# independent single-population epidemic (NOTE(review): assumes the identity matrix
# means no coupling in SystemMix -- confirm).
start = time()

Npatch = 5

# Identity commuting matrix (same values and integer dtype as np.where(eye, 1, 0)).
commuting = np.eye(Npatch, dtype=int)
display(pd.DataFrame(commuting))

tmax = 1500

params = pd.DataFrame()
params["N"] = np.full(Npatch, 1e6, dtype="u4")
params["beta"] = np.full(Npatch, 0.06)
params["epsilon"] = np.full(Npatch, 0.1)
params["mu"] = np.full(Npatch, 0.03)
params["I0"] = np.full(Npatch, 2, dtype="u4")
params.index.name = "Patch"
display(params)

reps = 100
# avgs[replica, patch] = recovered fraction at the end of the run (final epidemic size).
avgs = []
for i in range(reps):
    result = kernel(rng, tmax, params, commuting)
    avgs.append([r["R"].iloc[-1] for r in result])
avgs = np.array(avgs)

end = time()
print(f"Duration per replica: {np.round((end-start)/reps, 2)} sec")

# Added only after the simulations: kernel() forwards every column of `params` to
# System, which may not accept an extra "R0" field (NOTE(review): confirm).
params["R0"] = commuting @ (params["beta"] / params["mu"])

# A replica/patch counts as a major outbreak when its final size exceeds this fraction.
outbreak_threshold = 0.2
measured_survival = np.sum(avgs > outbreak_threshold) / reps / Npatch
# Branching-process approximation: each initial infective's chain goes extinct with
# probability 1/R0 (assuming exponentially distributed infectious periods), so all I0
# chains die out with probability (1/R0)**I0. For R0 <= 1 extinction is certain, hence
# the clip at 0. (The previous 1-(1-1/R0)**I0 only agreed with this at R0 == 2.)
R0_patch0 = params["R0"].iloc[0]
I0_patch0 = params["I0"].iloc[0]
expected_survival = max(0.0, 1 - (1 / R0_patch0) ** I0_patch0)

print()
print(f"Measured survived {measured_survival}, expected {expected_survival}")

fig = plt.figure(figsize=(8,6))
gs = plt.GridSpec(Npatch, 2)

# Left column: one trajectory panel per patch (shared time axis); right column: histogram.
axs = [fig.add_subplot(gs[0,0])]
axs = axs + [fig.add_subplot(gs[i,0], sharex=axs[0]) for i in range(1,Npatch)]
axs = axs + [fig.add_subplot(gs[:,1])]
for i in range(Npatch):
    # Trajectories shown are those of the last replica only.
    result[i][["S","E","I","R"]].plot(ax=axs[i])
    axs[i].set_xlabel("Time")
    axs[i].set_ylabel("Fraction")
    # Dotted line: deterministic final size. Rinf_equation is concave in Rinf with
    # f(0) = 1 - s0 >= 0 and f(1) = -s0*exp(-R0) <= 0, so [0, 1] brackets exactly
    # one root -- the physical one -- regardless of R0.
    s0 = 1 - params["I0"].iloc[i] / params["N"].iloc[i]
    Rinf = root_scalar(
        Rinf_equation,
        args=(params["R0"].iloc[i], s0),
        bracket=[0, 1],
        method="brentq",
    ).root
    axs[i].axhline(Rinf, color="k", linestyle=":")
    axs[i].set_ylim([-0.05,1.05])
    axs[i].legend(title=f"Patch {i}", ncol=2)

# One stacked dataset per patch (columns of avgs).
axs[-1].hist(
    avgs,
    bins=np.linspace(0,1,11),
    stacked=True,
    label=[f"Patch {p}" for p in range(Npatch)],
)
axs[-1].set_ylabel("Counts")
axs[-1].set_xlabel(r"$R_\infty$")
axs[-1].legend()

fig.tight_layout(h_pad=0.1)