# Cascade

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np

from scripts.grid_class import GridParms
from scripts.tree_class import Tree
from scripts.notebooks.output_helper import *
from scripts.reference_solutions.ssa_helper import SSASol

# Apply the project's Matplotlib style sheet. The path is relative, so the notebook
# has to be started from the directory that contains `scripts/`.
plt.style.use("./scripts/notebooks/custom_style.mplstyle")
%matplotlib inline

In [None]:
# Length-20 integer vector of zeros; it is handed to the `calculateObservables`
# calls below (what its entries select is defined by that helper, not shown here).
slice_vec = np.zeros(shape=20, dtype=int)

## Load initial data

# Tree

## Binary tree format

In [None]:
# Stored `output_t3500.nc` snapshots of the binary-tree runs labelled r5, r6 and r7.
tree_p0_r5 = readTree("output/cascade_bt_r5_e_tau1e-1/output_t3500.nc")
tree_p0_r6 = readTree("output/cascade_bt_r6_e_tau1e-1/output_t3500.nc")
tree_p0_r7 = readTree("output/cascade_bt_r7_e_tau1e-1/output_t3500.nc")

# Only the second value returned by `calculateObservables` is needed; it is later
# indexed by species name when the marginal errors are computed.
ttn_marginal_p0_r5 = tree_p0_r5.calculateObservables(slice_vec)[1]
ttn_marginal_p0_r6 = tree_p0_r6.calculateObservables(slice_vec)[1]
ttn_marginal_p0_r7 = tree_p0_r7.calculateObservables(slice_vec)[1]

In [None]:
# Wall time of the binary-tree run labelled r5.
walltime_ttn_p0_r5 = TimeSeries("output/cascade_bt_r5_e_tau1e-1").getWallTime()

In [None]:
# Wall time of the binary-tree run labelled r6.
walltime_ttn_p0_r6 = TimeSeries("output/cascade_bt_r6_e_tau1e-1").getWallTime()

In [None]:
# Wall time of the binary-tree run labelled r7.
walltime_ttn_p0_r7 = TimeSeries("output/cascade_bt_r7_e_tau1e-1").getWallTime()

## Tensor train format

In [None]:
# Stored `output_t3500.nc` snapshots of the tensor-train runs labelled r5, r6 and r7.
tree_p1_r5 = readTree("output/cascade_tt_r5_e_tau1e-1/output_t3500.nc")
tree_p1_r6 = readTree("output/cascade_tt_r6_e_tau1e-1/output_t3500.nc")
tree_p1_r7 = readTree("output/cascade_tt_r7_e_tau1e-1/output_t3500.nc")

# Only the second value returned by `calculateObservables` is needed; it is later
# indexed by species name when the marginal errors are computed.
ttn_marginal_p1_r5 = tree_p1_r5.calculateObservables(slice_vec)[1]
ttn_marginal_p1_r6 = tree_p1_r6.calculateObservables(slice_vec)[1]
ttn_marginal_p1_r7 = tree_p1_r7.calculateObservables(slice_vec)[1]

In [None]:
# Wall time of the tensor-train run labelled r5.
walltime_ttn_p1_r5 = TimeSeries("output/cascade_tt_r5_e_tau1e-1").getWallTime()

In [None]:
# Wall time of the tensor-train run labelled r6.
walltime_ttn_p1_r6 = TimeSeries("output/cascade_tt_r6_e_tau1e-1").getWallTime()

In [None]:
# Wall time of the tensor-train run labelled r7.
walltime_ttn_p1_r7 = TimeSeries("output/cascade_tt_r7_e_tau1e-1").getWallTime()

### SSA

In [None]:
# Second argument of the SSA `calculateObservables` calls below: the index pair (0, 1).
idx_2D = np.asarray((0, 1))
# Length-20 integer vector of zeros, first argument of the same calls.
slice_vec = np.zeros(shape=20, dtype=int)

In [None]:
# SSA archive `cascade_ssa_1e+04.npz`: the raw result and the stored wall time.
with np.load("scripts/reference_solutions/cascade_ssa_1e+04.npz") as archive:
    ssa_1e4, walltime_ssa_1e4 = archive["result"], archive["wall_time"]

ssa_1e4_sol = SSASol(ssa_1e4)
# Of the four values returned, only the first (the marginals) is kept.
ssa_marginal_1e4 = ssa_1e4_sol.calculateObservables(slice_vec, idx_2D)[0]

In [None]:
# SSA archive `cascade_ssa_1e+05.npz`: the raw result and the stored wall time.
with np.load("scripts/reference_solutions/cascade_ssa_1e+05.npz") as archive:
    ssa_1e5, walltime_ssa_1e5 = archive["result"], archive["wall_time"]

ssa_1e5_sol = SSASol(ssa_1e5)
# Of the four values returned, only the first (the marginals) is kept.
ssa_marginal_1e5 = ssa_1e5_sol.calculateObservables(slice_vec, idx_2D)[0]

In [None]:
# SSA archive `cascade_ssa_1e+06.npz`: the raw result and the stored wall time.
with np.load("scripts/reference_solutions/cascade_ssa_1e+06.npz") as archive:
    ssa_1e6, walltime_ssa_1e6 = archive["result"], archive["wall_time"]

ssa_1e6_sol = SSASol(ssa_1e6)
# Of the four values returned, only the first (the marginals) is kept.
ssa_marginal_1e6 = ssa_1e6_sol.calculateObservables(slice_vec, idx_2D)[0]

In [None]:
# SSA archive `cascade_ssa_1e+07.npz`; it serves as the reference in the error
# computations below, and no wall time is read from it.
with np.load("scripts/reference_solutions/cascade_ssa_1e+07.npz") as archive:
    ssa_1e7 = archive["result"]

ssa_1e7_sol = SSASol(ssa_1e7)
# Of the four values returned, only the first (the marginals) is kept.
ssa_marginal_1e7 = ssa_1e7_sol.calculateObservables(slice_vec, idx_2D)[0]

# Error between SSA and the SSA reference solution

In [None]:
# Per-species norm of the difference between the marginals of the smaller SSA
# ensembles (1e4, 1e5, 1e6) and the 1e7 reference, using the last entry ([-1]) of
# each marginal sequence.
d = len(ssa_1e7_sol.n)
SSA_marginal_err_1e4 = np.zeros(d)
SSA_marginal_err_1e5 = np.zeros(d)
SSA_marginal_err_1e6 = np.zeros(d)

# (solution object, its marginals, output array) for each ensemble to be compared.
comparisons = (
    (ssa_1e4_sol, ssa_marginal_1e4, SSA_marginal_err_1e4),
    (ssa_1e5_sol, ssa_marginal_1e5, SSA_marginal_err_1e5),
    (ssa_1e6_sol, ssa_marginal_1e6, SSA_marginal_err_1e6),
)

for i in range(d):
    for sol, marginal, err in comparisons:
        # Difference of the two `n_min` entries, used as the start index into the
        # reference marginal; the window has the length `n[i]` of the smaller ensemble.
        start = sol.n_min[i] - ssa_1e7_sol.n_min[i]
        reference_window = ssa_marginal_1e7[-1][i][start : start + sol.n[i]]
        err[i] = np.linalg.norm(reference_window - marginal[-1][i][: ssa_1e7_sol.n[i]])

# Error between the TTN solution and the SSA reference solution

In [None]:
def marginal_err(marginal_tt, tree: Tree):
    """Return, per species, the norm of the TTN marginal minus the SSA reference marginal.

    For every species index ``i`` in ``range(tree.grid.d())`` the entry
    ``marginal_tt[tree.species_names[i]]`` is sliced to
    ``[n_min[i], n_min[i] + n[i])`` taken from the notebook-level ``ssa_1e7_sol``.
    It is compared with ``ssa_marginal_1e7[-1][i]`` truncated to
    ``tree.grid.n[i]`` entries.

    Args:
        marginal_tt: Container indexed by species name that yields sliceable arrays.
        tree: Tree providing ``species_names``, ``grid.n`` and ``grid.d()``.

    Returns:
        NumPy array with one ``np.linalg.norm`` value (default ``ord``) per species index.

    Note:
        Reads the notebook globals ``ssa_1e7_sol`` and ``ssa_marginal_1e7``, so the
        SSA reference cell has to be executed first.
    """
    per_species = []
    for i in range(tree.grid.d()):
        lower = ssa_1e7_sol.n_min[i]
        upper = lower + ssa_1e7_sol.n[i]
        ttn_part = marginal_tt[tree.species_names[i]][lower:upper]
        ssa_part = ssa_marginal_1e7[-1][i][:tree.grid.n[i]]
        per_species.append(np.linalg.norm(ttn_part - ssa_part))
    return np.array(per_species)

In [None]:
# Binary-tree runs (p0) against the SSA reference.
marginal_err_p0_r5_SSA, marginal_err_p0_r6_SSA, marginal_err_p0_r7_SSA = (
    marginal_err(marginal, tree)
    for marginal, tree in (
        (ttn_marginal_p0_r5, tree_p0_r5),
        (ttn_marginal_p0_r6, tree_p0_r6),
        (ttn_marginal_p0_r7, tree_p0_r7),
    )
)

# Tensor-train runs (p1) against the SSA reference.
marginal_err_p1_r5_SSA, marginal_err_p1_r6_SSA, marginal_err_p1_r7_SSA = (
    marginal_err(marginal, tree)
    for marginal, tree in (
        (ttn_marginal_p1_r5, tree_p1_r5),
        (ttn_marginal_p1_r6, tree_p1_r6),
        (ttn_marginal_p1_r7, tree_p1_r7),
    )
)

# Plot the results

## With clustering

In [None]:
# Inputs of the wall-time bar charts: three SSA ensembles followed by the three
# tensor-train runs. The four lists below are index-aligned.
walltimes = [
    walltime_ssa_1e6,
    walltime_ssa_1e5,
    walltime_ssa_1e4,
    walltime_ttn_p1_r5,
    walltime_ttn_p1_r6,
    walltime_ttn_p1_r7,
]
labels_walltime = [
    "$10^6$ runs",
    "$10^5$ runs",
    "$10^4$ runs",
    "$r = 5$",
    "$r = 6$",
    "$r = 7$",
]

# Second colour of the active cycle for the SSA bars, first colour for the TTN bars.
colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]
color = [colors[1]] * 3 + [colors[0]] * 3
# Labels that start with an underscore are skipped by Matplotlib's legend, so each
# group appears there only once.
bar_labels = ["SSA", "_SSA", "_SSA", "PS-TTN", "_PS-TTN", "_PS-TTN"]

In [None]:
# Figure: per-species marginal error of the tensor-train runs (left) and of the
# smaller SSA ensembles (middle), both against the 1e7-run SSA reference, next to
# the wall times (right).
# The LaTeX labels are raw strings: "\V" and "\m" are invalid escape sequences in
# ordinary literals and raise a SyntaxWarning on Python >= 3.12. The string values
# themselves are unchanged.
gs = gridspec.GridSpec(1, 30)

fig = plt.figure(figsize=(8.5, 3))

# Left panel: TTN error for the three ranks.
ax00 = plt.subplot(gs[0, :9])
ax00.plot(np.arange(tree_p1_r5.grid.d()), marginal_err_p1_r5_SSA, '.-', label="$r = 5$")
ax00.plot(np.arange(tree_p1_r6.grid.d()), marginal_err_p1_r6_SSA, '.-', label="$r = 6$")
ax00.plot(np.arange(tree_p1_r7.grid.d()), marginal_err_p1_r7_SSA, '.-', label="$r = 7$")
ax00.set_yscale("log")
ax00.set_xlabel("species $S_i$")
ax00.legend()
ax00.set_title(r"$\Vert P_M^\mathrm{{TTN}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$")

# Middle panel: SSA error for the three ensemble sizes.
ax01 = plt.subplot(gs[0, 9:18])
ax01.plot(np.arange(tree_p1_r5.grid.d()), SSA_marginal_err_1e4, 'x-', label="$10^4$ runs")
ax01.plot(np.arange(tree_p1_r5.grid.d()), SSA_marginal_err_1e5, 'x-', label="$10^5$ runs")
ax01.plot(np.arange(tree_p1_r5.grid.d()), SSA_marginal_err_1e6, 'x-', label="$10^6$ runs")
ax01.set_yscale("log")
ax01.set_xlabel("species $S_i$")
ax01.legend()
ax01.set_title(r"$\Vert P_M^\mathrm{{SSA}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$")
# The y tick labels are blanked here; both panels get the same y limits just below.
ax01.set_yticklabels([])

plt.setp((ax00, ax01), ylim=[1e-5, 1e-1], xticks=[0, 4, 8, 12, 16, 19])

# Right panel: horizontal wall-time bars (linear x axis).
ax10 = plt.subplot(gs[0:, 22:])
ax10.barh(labels_walltime, walltimes, label=bar_labels, color=color)
ax10.set_xlabel(r"wall time [$\mathrm{{s}}$]")
ax10.legend()

plt.subplots_adjust(wspace=0.3)

plt.savefig("plots/cascade_comparison_marginal_ttn_ssa.pdf", bbox_inches="tight");

In [None]:
# Figure: per-species marginal error against the SSA reference for the tensor-train
# runs (left) and the binary-tree runs (right).
# The LaTeX titles are raw strings: "\V" and "\m" are invalid escape sequences in
# ordinary literals and raise a SyntaxWarning on Python >= 3.12. The string values
# themselves are unchanged.
fig, axs = plt.subplots(1, 2, figsize=(7, 3))

# Only the left panel carries labels; the figure-level legend below picks them up.
axs[0].plot(np.arange(tree_p1_r5.grid.d()), marginal_err_p1_r5_SSA, '.-', label="$r = 5$")
axs[0].plot(np.arange(tree_p1_r6.grid.d()), marginal_err_p1_r6_SSA, '.-', label="$r = 6$")
axs[0].plot(np.arange(tree_p1_r7.grid.d()), marginal_err_p1_r7_SSA, '.-', label="$r = 7$")
axs[0].set_yscale("log")
axs[0].set_xlabel("species $S_i$")
axs[0].set_title(r"$\Vert P_M^\mathrm{{TT}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$")

axs[1].plot(np.arange(tree_p0_r5.grid.d()), marginal_err_p0_r5_SSA, '.-')
axs[1].plot(np.arange(tree_p0_r6.grid.d()), marginal_err_p0_r6_SSA, '.-')
axs[1].plot(np.arange(tree_p0_r7.grid.d()), marginal_err_p0_r7_SSA, '.-')
axs[1].set_yscale("log")
axs[1].set_xlabel("species $S_i$")
axs[1].set_title(r"$\Vert P_M^\mathrm{{BT}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$")
# Tick labels go to the right of the right-hand panel, with ticks on both sides.
axs[1].yaxis.tick_right()
axs[1].yaxis.set_ticks_position("both")
axs[1].yaxis.set_label_position("right")

# One shared legend placed just above the figure.
fig.legend(loc="center", ncol=3, bbox_to_anchor=(0.5, 1.02))
plt.subplots_adjust(wspace=0.02)
plt.setp(axs, ylim=[1e-4, 3e-2], xticks=[0, 4, 8, 12, 16, 19])

plt.savefig("plots/cascade_comparison_marginal_tt_bt.pdf", bbox_inches="tight");

In [None]:
# Figure (2x2): marginal error against the SSA reference for the tensor-train runs
# (top left), the binary-tree runs (top right) and the smaller SSA ensembles
# (bottom left), plus the wall times (bottom right).
# The LaTeX labels are raw strings: "\V" and "\m" are invalid escape sequences in
# ordinary literals and raise a SyntaxWarning on Python >= 3.12. The string values
# themselves are unchanged.
fig, axs = plt.subplots(2, 2, figsize=(7, 6))

axs[0, 0].plot(np.arange(tree_p1_r5.grid.d()), marginal_err_p1_r5_SSA, '.-', label="$r = 5$")
axs[0, 0].plot(np.arange(tree_p1_r6.grid.d()), marginal_err_p1_r6_SSA, '.-', label="$r = 6$")
axs[0, 0].plot(np.arange(tree_p1_r7.grid.d()), marginal_err_p1_r7_SSA, '.-', label="$r = 7$")
axs[0, 0].set_yscale("log")
axs[0, 0].set_xlabel("species $S_i$")
axs[0, 0].legend()
axs[0, 0].set_title(r"$\Vert P_M^\mathrm{{TT}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$")

axs[0, 1].plot(np.arange(tree_p0_r5.grid.d()), marginal_err_p0_r5_SSA, '.-', label="$r = 5$")
axs[0, 1].plot(np.arange(tree_p0_r6.grid.d()), marginal_err_p0_r6_SSA, '.-', label="$r = 6$")
axs[0, 1].plot(np.arange(tree_p0_r7.grid.d()), marginal_err_p0_r7_SSA, '.-', label="$r = 7$")
axs[0, 1].set_yscale("log")
axs[0, 1].set_xlabel("species $S_i$")
axs[0, 1].legend()
axs[0, 1].set_title(r"$\Vert P_M^\mathrm{{BT}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$")

axs[1, 0].plot(np.arange(tree_p1_r5.grid.d()), SSA_marginal_err_1e4, 'x-', label="$10^4$ runs")
axs[1, 0].plot(np.arange(tree_p1_r5.grid.d()), SSA_marginal_err_1e5, 'x-', label="$10^5$ runs")
axs[1, 0].plot(np.arange(tree_p1_r5.grid.d()), SSA_marginal_err_1e6, 'x-', label="$10^6$ runs")
axs[1, 0].set_yscale("log")
axs[1, 0].set_xlabel("species $S_i$")
axs[1, 0].legend()
axs[1, 0].set_title(r"$\Vert P_M^\mathrm{{SSA}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$")
# Tick labels go to the right of this panel, with ticks drawn on both sides.
axs[1, 0].yaxis.tick_right()
axs[1, 0].yaxis.set_ticks_position("both")
axs[1, 0].yaxis.set_label_position("right")

# Same y range and species ticks on all three error panels.
plt.setp(axs[0, :], ylim=[1e-6, 1e-1], xticks=[0, 4, 8, 12, 16, 19])
plt.setp(axs[1, 0], ylim=[1e-6, 1e-1], xticks=[0, 4, 8, 12, 16, 19])

# Wall times on a logarithmic axis.
axs[1, 1].barh(labels_walltime, walltimes, label=bar_labels, color=color)
axs[1, 1].set_xscale("log")
axs[1, 1].set_xlabel(r"wall time [$\mathrm{{s}}$]")
axs[1, 1].legend(loc="center right", ncols=1)
plt.tight_layout()

plt.savefig("plots/cascade_comparison_marginal_tt_bt_ssa.pdf", bbox_inches="tight");

# Error depending on time step size

# Comparison with deterministic solution

In [None]:
# Time series of the tensor-train run stored under `cascade_tt_r6_e_tau1e-2`. Note
# that the directory name carries `tau1e-2`, unlike the `tau1e-1` runs used above.
time_series = TimeSeries("output/cascade_tt_r6_e_tau1e-2")
# The plotting cell indexes this as [0] and [1], each keyed by species name, and
# combines them as sqrt([1] - [0]**2), i.e. it uses them as first and second moments.
concentrations = time_series.calculateMoments()
# Time axis for the moment plot; read after `calculateMoments()` (order kept as is).
t = time_series.time

In [None]:
# Figure: first moment of selected species over time, with a band of plus/minus one
# standard deviation around it.
fig, ax = plt.subplots(figsize=(5, 3))

# Standard deviation = sqrt(second moment - first moment squared). Round-off can
# push that difference marginally below zero, which would give NaN under the square
# root (and gaps in the band), so it is clipped at zero first.
deviation = {
    key: np.sqrt(np.maximum(concentrations[1][key] - concentrations[0][key] ** 2, 0.0))
    for key in concentrations[0]
}

observables = ["$S_{{0}}$", "$S_{{2}}$", "$S_{{4}}$", "$S_{{8}}$", "$S_{{16}}$"]
idx = ["S0", "S2", "S4", "S8", "S16"]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
for j, (key, legend_label) in enumerate(zip(idx, observables)):
    mean = concentrations[0][key]
    ax.plot(t, mean, '-', label=legend_label, alpha=0.7, color=colors[j])
    ax.fill_between(t, mean - deviation[key], mean + deviation[key], alpha=.1, color=colors[j])

# Raw string: "\l" is an invalid escape sequence in an ordinary literal and raises a
# SyntaxWarning on Python >= 3.12. The label text itself is unchanged.
ax.set_ylabel(r"$\langle x_i \rangle (t)$")
plt.setp(ax, xlabel="$t$", xlim=[0.0, 350.0], ylim=[0.0, 14.0]);

# The figure has a single axes, so its handles and labels are all the legend needs.
lines, labels = ax.get_legend_handles_labels()
fig.legend(lines, labels, ncols=5, loc="upper center")

plt.savefig("plots/cascade_concentrations.pdf");

# Mass error

In [None]:
# Absolute mass error of the tensor-train runs labelled r5, r6 and r7.
# The time axis is read from the r5 run only and reused for all three curves in the
# plot below, which presumes the runs share the same output times (not checked here).
run_r5 = TimeSeries("output/cascade_tt_r5_e_tau1e-1")
mass_err00 = np.abs(run_r5.getMassErr())
time = run_r5.time

mass_err10 = np.abs(TimeSeries("output/cascade_tt_r6_e_tau1e-1").getMassErr())

mass_err20 = np.abs(TimeSeries("output/cascade_tt_r7_e_tau1e-1").getMassErr())

In [None]:
# Figure: absolute mass error over time, one panel per rank.
# NOTE(review): the 2x2 grid has four axes but only three are drawn into, so the
# bottom-right panel stays empty. Confirm whether it should be hidden.
# The LaTeX strings are raw strings: "\m" and "\D" are invalid escape sequences in
# ordinary literals and raise a SyntaxWarning on Python >= 3.12. The string values
# themselves are unchanged.
fig, axs = plt.subplots(2, 2, figsize=(7, 6))

# Only this curve is labelled; the figure-level legend below is built from it.
axs[0, 0].plot(time, mass_err00, ".-", label=r"$\mathcal{{P}}_0$")
axs[0, 0].set_title("$r = 5$")

axs[0, 1].plot(time, mass_err10, ".-")
axs[0, 1].set_title("$r = 6$")
# Tick labels go to the right of this panel, with ticks drawn on both sides.
axs[0, 1].yaxis.tick_right()
axs[0, 1].yaxis.set_ticks_position("both")
axs[0, 1].yaxis.set_label_position("right")

axs[1, 0].plot(time, mass_err20, ".-")
axs[1, 0].set_title("$r = 7$")

plt.setp(axs, xlabel="$t$", ylim=[3e-8, 5e-2], yscale="log")
plt.subplots_adjust(hspace=0.5)
fig.suptitle(r"$|\Delta m(t)|$", fontsize=16, y=1.05)
fig.legend(*axs[0, 0].get_legend_handles_labels(), ncols=3, loc="center", bbox_to_anchor=(0.5, 0.95))
plt.tight_layout()
# The suptitle sits at y=1.05, above the figure canvas. A tight bounding box keeps
# it in the saved PDF, as the first three comparison figures in this notebook already do.
plt.savefig("plots/cascade_mass_err_comparison_ttn.pdf", bbox_inches="tight");