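# Reaction cascade: TTN vs. SSA

Compares marginal distributions of the cascade model computed with tree tensor networks (TTN) of rank $r=5$ and $r=7$ against SSA reference solutions with $10^4$ to $10^7$ runs.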

In [None]:
from pathlib import Path

import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from pysb.integrate import odesolve
# from scripts.models.cascade_pysb import model

from scripts.grid_class import GridParms
from scripts.tree_class import Tree
from scripts.output.output_helper import *
from scripts.reference_solutions.ssa_helper import SSASol

# Relative path: assumes the notebook is run from the repository root, like the
# "output/" and "scripts/" paths used in the cells below — TODO confirm.
plt.style.use("./scripts/output/notebooks/custom_style.mplstyle")
%matplotlib inline

In [None]:
# Disabled: would derive slice_vec from the pysb ODE solution (rounded to ints),
# but the `model` import in the first cell is commented out. slice_vec is instead
# set to all zeros in the SSA section.
# NOTE(review): np.arange(1.0) is [0.0], so `[-1]` below picks the value at t = 0
# only — confirm this is intended.
# concentrations_ode = odesolve(model, np.arange(1.0))
# slice_vec = []
# for o in model.observables:
#     slice_vec.append(int(np.round(concentrations_ode[o.name][-1])))
# slice_vec = np.array(slice_vec)

## Load data

### TTN

In [None]:
def _load_ttn_snapshot(path):
    """Read a TTN output file and return ``(tree, marginals)``.

    ``marginals`` is the second value returned by ``calculateObservables``
    when evaluated at a length-20 all-zero slice.
    """
    snapshot = readTree(path)
    _, marginals = snapshot.calculateObservables(np.zeros(20, dtype="int"))
    return snapshot, marginals


tree_r5, ttn_marginal_r_5 = _load_ttn_snapshot("output/cascade_r5_e_tau1e-1/output_t3500.nc")
tree_r7, ttn_marginal_r_7 = _load_ttn_snapshot("output/cascade_r7_e_tau1e-1/output_t3500.nc")
# tree_r12, ttn_marginal_r_12 = _load_ttn_snapshot("output/cascade_r12_e_tau1e-1/output_t3500.nc")

### SSA

In [None]:
# Arguments for SSASol.calculateObservables below: idx_2D presumably selects the
# species pair for the 2D observables (confirm in SSASol), and slice_vec is a
# length-20 all-zero integer vector.
idx_2D = np.arange(2)
slice_vec = np.full(20, 0, dtype=int)

In [None]:
# SSA reference solution with 10^4 runs. The wall time gets a run-count-specific
# name: a shared `ssa_wall_time` would be overwritten by the following SSA cells.
ssa_1e4 = np.load("scripts/reference_solutions/cascade_ssa_1e+04.npy")
ssa_1e4_sol = SSASol(ssa_1e4)
ssa_marginal_1e4, _, _, _ = ssa_1e4_sol.calculateObservables(slice_vec, idx_2D)
ssa_wall_time_1e4 = 50.3  # presumably seconds — TODO confirm unit

In [None]:
# SSA reference solution with 10^5 runs. The wall time gets a run-count-specific
# name: a shared `ssa_wall_time` would be overwritten by the following SSA cells.
ssa_1e5 = np.load("scripts/reference_solutions/cascade_ssa_1e+05.npy")
ssa_1e5_sol = SSASol(ssa_1e5)
ssa_marginal_1e5, _, _, _ = ssa_1e5_sol.calculateObservables(slice_vec, idx_2D)
ssa_wall_time_1e5 = 505.6  # presumably seconds — TODO confirm unit

In [None]:
# SSA reference solution with 10^6 runs. The wall time gets a run-count-specific
# name: a shared `ssa_wall_time` would be overwritten by the following SSA cell.
ssa_1e6 = np.load("scripts/reference_solutions/cascade_ssa_1e+06.npy")
ssa_1e6_sol = SSASol(ssa_1e6)
ssa_marginal_1e6, _, _, _ = ssa_1e6_sol.calculateObservables(slice_vec, idx_2D)
ssa_wall_time_1e6 = 4504.3  # presumably seconds — TODO confirm unit

In [None]:
# SSA reference solution with 10^7 runs (the reference used in the convergence plot).
ssa_1e7 = np.load("scripts/reference_solutions/cascade_ssa_1e+07.npy")
ssa_1e7_sol = SSASol(ssa_1e7)
ssa_marginal_1e7, _, _, _ = ssa_1e7_sol.calculateObservables(slice_vec, idx_2D)
ssa_wall_time_1e7 = 43828.32  # presumably seconds — TODO confirm unit
# Backward-compatible alias: after Run All, `ssa_wall_time` always ended up
# holding the 10^7 value, so keep that name pointing at it.
ssa_wall_time = ssa_wall_time_1e7

## Convergence with increasing rank

In [None]:
def _plot_species_marginal(ax, species):
    """Overlay the TTN (r = 5, 7) and SSA (10^7 runs) marginals of one species on ``ax``.

    ``species`` indexes the per-species marginal lists; the x-axis label is
    derived from it so the label always matches the plotted species.
    """
    ttn_curves = [
        (tree_r5, ttn_marginal_r_5, "TTN, $r=5$"),
        (tree_r7, ttn_marginal_r_7, "TTN, $r=7$"),
        # (tree_r12, ttn_marginal_r_12, "TTN, $r=12$"),
    ]
    for tree, marginal, label in ttn_curves:
        # NOTE(review): the TTN abscissa starts at 0, whereas the SSA one below is
        # shifted by n_min — confirm the TTN grid has no lower offset.
        ax.plot(np.arange(tree.grid.n[species]), marginal[species], "v", label=label)

    # Uses the last entry of the SSA marginals (presumably the final time point — confirm).
    # The 10^5-run reference can be added analogously with ssa_1e5_sol / ssa_marginal_1e5 and "rx".
    ax.plot(
        np.arange(ssa_1e7_sol.n[species]) + ssa_1e7_sol.n_min[species],
        ssa_marginal_1e7[-1][species],
        "k.",
        label="SSA, $10^7$ runs",
    )
    # Previously hard-coded as $x_0$ / $x_1$ although species 18 and 19 are shown.
    ax.set_xlabel(f"$x_{{{species}}}$")
    ax.set_ylabel("Marginal probability")


fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))
_plot_species_marginal(ax1, 18)
_plot_species_marginal(ax2, 19)
ax2.legend()

# savefig does not create missing directories.
Path("plots").mkdir(exist_ok=True)
fig.savefig("plots/cascade_fig1.pdf")