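# Lambda phage

Post-processing of the lambda phage example. The PS-TTN integrator is compared with SSA and with the matrix-integrator reference solution. Its error is then studied as a function of the time step size, the rank and the partition, followed by the mass error and the wall time.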

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np
from pysb.integrate import odesolve
from scripts.models.lambda_phage_pysb import model

from scripts.grid_class import GridParms
from scripts.tree_class import Tree
from scripts.output.output_helper import *
from scripts.reference_solutions.ssa_helper import SSASol

plt.style.use("./scripts/output/notebooks/custom_style.mplstyle")
%matplotlib inline

In [None]:
concentrations_ode = odesolve(model, np.arange(10.0))
# Integer-rounded value of every observable at the last ODE output time.
slice_vec = np.array(
    [int(np.round(concentrations_ode[obs.name][-1])) for obs in model.observables]
)

## Load initial data

In [None]:
# Reference solution (P_full) and its best approximation with r = 5.
with np.load("scripts/reference_solutions/lp_ode_ref_r5.npz") as data:
    P_exact, P_best_approximation_r5 = data['P_full'], data['P_best_approximation']

In [None]:
# Best approximation with r = 9. P_exact is overwritten with this file's P_full
# (presumably identical to the r = 5 file's; not checked here).
with np.load("scripts/reference_solutions/lp_ode_ref_r9.npz") as data:
    P_exact, P_best_approximation_r9 = data['P_full'], data['P_best_approximation']

In [None]:
# SSA results with 10^4 runs, converted to a full probability distribution.
ssa_1e4 = np.load("scripts/reference_solutions/lp_ssa_1e4.npy")
P_ssa_1e4 = (ssa_1e4_sol := SSASol(ssa_1e4)).calculateFullDistribution()

In [None]:
# SSA results with 10^5 runs, converted to a full probability distribution.
ssa_1e5 = np.load("scripts/reference_solutions/lp_ssa_1e5.npy")
P_ssa_1e5 = (ssa_1e5_sol := SSASol(ssa_1e5)).calculateFullDistribution()

In [None]:
# SSA results with 10^6 runs, converted to a full probability distribution.
ssa_1e6 = np.load("scripts/reference_solutions/lp_ssa_1e6.npy")
P_ssa_1e6 = (ssa_1e6_sol := SSASol(ssa_1e6)).calculateFullDistribution()

## Comparison of TTN and SSA results with matrix integrator

### Load data

#### Get walltimes

In [None]:
# Wall time of the partition-0 run with r = (5, 3), tau = 1e-3.
walltime_p0_r_5_3 = (time_series := TimeSeries("output/lambda_phage_p0_r5-3_e_tau1e-3")).getWallTime()

In [None]:
# Wall time of the partition-0 run with r = (5, 4), tau = 1e-3.
walltime_p0_r_5_4 = (time_series := TimeSeries("output/lambda_phage_p0_r5-4_e_tau1e-3")).getWallTime()

In [None]:
# Wall time of the partition-0 run with r = (5, 5), tau = 1e-3.
walltime_p0_r_5_5 = (time_series := TimeSeries("output/lambda_phage_p0_r5-5_e_tau1e-3")).getWallTime()

In [None]:
# Wall time of the partition-0 run with r = (9, 5), tau = 1e-3.
walltime_p0_r_9_5 = (time_series := TimeSeries("output/lambda_phage_p0_r9-5_e_tau1e-3")).getWallTime()

In [None]:
# Wall time of the partition-0 run with r = (9, 9), tau = 1e-3.
walltime_p0_r_9_9 = (time_series := TimeSeries("output/lambda_phage_p0_r9-9_e_tau1e-3")).getWallTime()

In [None]:
# Wall time of the "pfull_r5" run, tau = 1e-3.
walltime_pfull = (time_series := TimeSeries("output/lambda_phage_pfull_r5_e_tau1e-3")).getWallTime()

In [None]:
walltime_exact = 1_164  # hard-coded wall time [s] of the reference ("exact") solution

In [None]:
# Hard-coded SSA wall times [s] for 10^4, 10^5 and 10^6 runs.
walltime_ssa_1e4, walltime_ssa_1e5, walltime_ssa_1e6 = 855, 905, 2319

In [None]:
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# One row per bar: (wall time [s], tick label, bar color, legend label).
# Matplotlib omits legend labels starting with "_", so each method is listed once.
walltime_bars = [
    (walltime_exact, "exact", "grey", '_exact'),
    (walltime_ssa_1e6, "$10^6$ runs", colors[1], 'SSA'),
    (walltime_ssa_1e5, "$10^5$ runs", colors[1], '_SSA'),
    (walltime_ssa_1e4, "$10^4$ runs", colors[1], '_SSA'),
    (walltime_p0_r_9_9, "$r = (9, 9)$", colors[0], 'PS-TTN integrator'),
    (walltime_p0_r_5_5, "$r = (5, 5)$", colors[0], '_PS-TTN integrator'),
    (walltime_p0_r_5_4, "$r = (5, 4)$", colors[0], '_PS-TTN integrator'),
    (walltime_p0_r_5_3, "$r = (5, 3)$", colors[0], '_PS-TTN integrator'),
]
walltimes, labels, color, bar_labels = (list(column) for column in zip(*walltime_bars))

### Plots

#### Error between TTN integrator and the matrix integrator reference solution

In [None]:
def _distribution_and_error(run_dir):
    """Load a PS-TTN run and return (TimeSeries, full distribution, per-time error vs. P_exact)."""
    series = TimeSeries(run_dir)
    distribution = series.calculateFullDistribution()
    return series, distribution, np.linalg.norm(distribution - P_exact, axis=1)


time_series, P_p0_r5_3_tau_1e3, err_r_5_3 = _distribution_and_error("output/lambda_phage_p0_r5-3_e_tau1e-3")
time = time_series.time
time_series, P_p0_r5_4_tau_1e3, err_r_5_4 = _distribution_and_error("output/lambda_phage_p0_r5-4_e_tau1e-3")
time_series, P_p0_r5_5_tau_1e3, err_r_5_5 = _distribution_and_error("output/lambda_phage_p0_r5-5_e_tau1e-3")
time_series, P_p0_r9_5_tau_1e3, err_r_9_5 = _distribution_and_error("output/lambda_phage_p0_r9-5_e_tau1e-3")
time_series, P_p0_r9_9_tau_1e3, err_r_9_9 = _distribution_and_error("output/lambda_phage_p0_r9-9_e_tau1e-3")

In [None]:
def _restrict_to_box(P, extent, offset, n):
    """Copy the columns of ``P`` whose (shifted) grid index lies inside the box ``[0, n)``.

    ``P`` has one column per state of a grid with per-dimension sizes ``extent``.
    Columns are visited in the order produced by ``incrVecIndex``. A column's grid
    index plus ``offset`` gives its index in the target box of sizes ``n``; columns
    that fall outside the box are dropped.

    Returns an array of shape ``(P.shape[0], prod(n))`` whose columns are indexed
    with ``vecIndexToCombIndex`` on the box.
    """
    P_red = np.zeros((P.shape[0], int(np.prod(n))))
    vec_index = np.zeros(extent.size, dtype="int64")
    for i in range(P.shape[1]):
        box_index = vec_index + offset
        if np.all((box_index >= 0) & (box_index < n)):
            P_red[:, vecIndexToCombIndex(box_index, n)] = P[:, i]
        # Advances vec_index in place to the grid index of the next column.
        incrVecIndex(vec_index, extent, extent.size)
    return P_red


def calculateSSAError(P_ssa, P_exact, ssa_sol, grid_liml=None, grid_n=None):
    """Per-time-point 2-norm distance between an SSA distribution and the reference.

    Both distributions are restricted to the intersection of the reference grid
    box and the SSA box before the difference is taken. Probability mass outside
    that overlap is ignored.

    Parameters
    ----------
    P_ssa : ndarray
        SSA distribution, one row per time point and one column per state of
        the SSA box (sizes ``ssa_sol.n``).
    P_exact : ndarray
        Reference distribution, one row per time point and one column per state
        of the reference grid (sizes ``grid_n``).
    ssa_sol : SSASol
        Must provide ``n_min``, ``n_max`` and ``n`` (array) for the SSA box.
    grid_liml, grid_n : array_like of int, optional
        Lower limits and sizes of the reference grid. Defaults to the lambda
        phage grid used in this notebook.

    Returns
    -------
    ndarray
        One error value per time point (row).
    """
    if grid_liml is None:
        grid_liml = [0, 0, 0, 0, 0]
    if grid_n is None:
        grid_n = [16, 41, 11, 11, 11]
    grid_liml = np.asarray(grid_liml, dtype="int64")
    grid_n = np.asarray(grid_n, dtype="int64")

    # Overlap of the two boxes.
    # NOTE(review): assumes ssa_sol.n_max is an exclusive upper bound, like grid_liml + grid_n.
    n_min = np.maximum(grid_liml, ssa_sol.n_min)
    n_max = np.minimum(grid_liml + grid_n, ssa_sol.n_max)
    # An empty overlap becomes a zero-size box rather than a negative array shape.
    n = np.maximum(n_max - n_min, 0)

    # Shift each grid's local index so that state n_min maps to box index 0. This
    # matters whenever the two boxes do not both start at n_min.
    # NOTE(review): assumes SSA column index 0 corresponds to state ssa_sol.n_min.
    P_exact_red = _restrict_to_box(P_exact, grid_n, grid_liml - n_min, n)
    P_ssa_red = _restrict_to_box(P_ssa, ssa_sol.n, ssa_sol.n_min - n_min, n)

    return np.linalg.norm(P_ssa_red - P_exact_red, axis=1)

In [None]:
# Per-time error of each SSA estimate against the reference solution.
err_ssa_1e4, err_ssa_1e5, err_ssa_1e6 = (
    calculateSSAError(P_ssa, P_exact, ssa_sol)
    for P_ssa, ssa_sol in (
        (P_ssa_1e4, ssa_1e4_sol),
        (P_ssa_1e5, ssa_1e5_sol),
        (P_ssa_1e6, ssa_1e6_sol),
    )
)

In [None]:
# Per-time error of the r = 5 and r = 9 best approximations against the reference.
err_best_approximation_r5, err_best_approximation_r9 = (
    np.linalg.norm(P_best - P_exact, axis=1)
    for P_best in (P_best_approximation_r5, P_best_approximation_r9)
)

### Plots

In [None]:
# Top left: PS-TTN error, top right: SSA error (both vs. the reference);
# bottom: wall times of all methods. LaTeX labels are raw strings so that
# backslashes are not parsed as (invalid) Python escape sequences.
gs = gridspec.GridSpec(10, 8)

fig = plt.figure(figsize=(7, 9))

general_labels = ["best-approximation ($r=5$)", "best-approximation ($r=9$)"]

# [1:] drops the first output time (presumably t = 0) from the log-scale plots.
ax0 = plt.subplot(gs[0:5, :4])
ax0.plot(time[1:], err_r_5_3[1:], '.-', label="$r = (5, 3)$")
ax0.plot(time[1:], err_r_5_4[1:], '.-', label="$r = (5, 4)$")
ax0.plot(time[1:], err_r_5_5[1:], '.-', label="$r = (5, 5)$")
ax0.plot(time[1:], err_r_9_9[1:], '.-', label="$r = (9, 9)$")
# Best-approximation curves stay unlabeled on the axes; they share one figure legend.
l1, = ax0.plot(time[1:], err_best_approximation_r5[1:], 'k:')
l2, = ax0.plot(time[1:], err_best_approximation_r9[1:], 'k-.')

ax0.set_yscale("log")
ax0.set_xlabel("$t$")
ax0.legend(ncols=1)
ax0.set_title(r"$\Vert P^\mathrm{{TTN}}(t,\mathbf{{x}})-P^\mathrm{{ref}}(t,\mathbf{{x}}) \Vert$")

ax1 = plt.subplot(gs[0:5, 4:])
ax1.plot(time[1:], err_ssa_1e4[1:], 'x-', label="$10^4$ runs")
ax1.plot(time[1:], err_ssa_1e5[1:], 'x-', label="$10^5$ runs")
ax1.plot(time[1:], err_ssa_1e6[1:], 'x-', label="$10^6$ runs")
ax1.plot(time[1:], err_best_approximation_r5[1:], 'k:')
ax1.plot(time[1:], err_best_approximation_r9[1:], 'k-.')

ax1.set_yscale("log")
ax1.set_xlabel("$t$")
ax1.legend(ncols=1)
ax1.set_title(r"$\Vert P^\mathrm{{SSA}}(t,\mathbf{{x}})-P^\mathrm{{ref}}(t,\mathbf{{x}}) \Vert$")
ax1.yaxis.tick_right()
ax1.yaxis.set_ticks_position("both")
ax1.yaxis.set_label_position("right")

plt.setp((ax0, ax1), ylim=[2e-5, 2e-1], xticks=[1, 4, 7, 10])

ax2 = plt.subplot(gs[6:, 2:7])
ax2.barh(labels, walltimes, label=bar_labels, color=color)
ax2.set_xscale("log")
ax2.set_xlabel(r"wall time [$\mathrm{{s}}$]")
ax2.legend()

plt.subplots_adjust(hspace=0.5, wspace=0.0)

fig.legend([l1, l2], general_labels, loc="center", ncols=2, bbox_to_anchor=(0.5, 0.45))
plt.tight_layout()
plt.savefig("plots/comparison_ttn_ssa.pdf");

## Error depending on time step size

In [None]:
# Convergence in the time step size. For every tau and every rank/integrator
# configuration, record the maximal mass error and the error at the final time
# with respect to P_exact. This replaces 54 copy-pasted stanzas; the load order
# and the contents of every list are unchanged.
dm_max_r_5_6 = []
dm_max_r_5_5 = []
dm_max_r_5_4 = []
dm_max_r_5_3 = []
dm_max_r_9_9 = []
dm_max_r_9_9_i = []

err_r_5_6 = []
err_r_5_5 = []
err_r_5_4 = []
err_r_5_3 = []
err_r_9_9 = []
err_r_9_9_i = []

tau = []

# (run label spliced into the output directory name, max-mass-error list, final-time-error list)
tau_study_runs = [
    ("r5-6_e", dm_max_r_5_6, err_r_5_6),
    ("r5-5_e", dm_max_r_5_5, err_r_5_5),
    ("r5-4_e", dm_max_r_5_4, err_r_5_4),
    ("r5-3_e", dm_max_r_5_3, err_r_5_3),
    ("r9-9_e", dm_max_r_9_9, err_r_9_9),
    ("r9-9_i", dm_max_r_9_9_i, err_r_9_9_i),
]
tau_labels = ["5e-1", "2e-1", "1e-1", "5e-2", "2e-2", "1e-2", "5e-3", "2e-3", "1e-3"]

for tau_label in tau_labels:
    for run_label, dm_max_list, err_list in tau_study_runs:
        time_series = TimeSeries(f"output/lambda_phage_p0_{run_label}_tau{tau_label}")
        dm_max_list.append(np.abs(time_series.getMaxMassErr()))
        P_ttn = time_series.calculateFullDistribution()
        err_list.append(np.linalg.norm(P_ttn[-1] - P_exact[-1], ord=None))
    # Step size read from the last run of the group; every run in the group
    # carries the same tau label in its directory name.
    tau.append(time_series.getTau())

### Plot

In [None]:
# Left: error at t = 10 vs. step size; right: maximal mass error vs. step size.
# LaTeX labels are raw strings so backslashes are not read as Python escapes.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4), sharex="row")
# Step-size range on which the reference slopes are drawn.
t = np.linspace(0.005, 0.5)
# The "_i" runs reuse the color of the fourth curve (r = (9, 9)) with a dash-dot line.
color_r_9_9 = plt.rcParams['axes.prop_cycle'].by_key()['color'][3]

general_labels = ["$r=(5, 3)$", "$r=(5, 4)$", "$r=(5, 5)$", "$r=(9, 9)$"]

l1, = ax1.loglog(tau, err_r_5_3, '.-')
l2, = ax1.loglog(tau, err_r_5_4, '.-')
l3, = ax1.loglog(tau, err_r_5_5, '.-')
l4, = ax1.loglog(tau, err_r_9_9, '.-')
ax1.loglog(tau, err_r_9_9_i, '-.', color=color_r_9_9)
ax1.loglog(t, t**1*0.02, 'k:', label=r"$\Delta t$")
ax1.set_xlabel(r"$\Delta t$")
ax1.set_ylabel(r"$\Vert P^\mathrm{{TTN}}(10,\mathbf{{x}})-P^\mathrm{{ref}}(10,\mathbf{{x}}) \Vert$")
ax1.set_ylim([5e-5, 5e1])
ax1.legend()

ax2.loglog(tau, dm_max_r_5_3, '.-')
ax2.loglog(tau, dm_max_r_5_4, '.-')
ax2.loglog(tau, dm_max_r_5_5, '.-')
ax2.loglog(tau, dm_max_r_9_9, '.-')
ax2.loglog(tau, dm_max_r_9_9_i, '-.', color=color_r_9_9)
ax2.loglog(t, t**2*0.05, 'k--', label=r"$(\Delta t)^2$")
ax2.set_xlabel(r"$\Delta t$")
ax2.set_ylabel(r"$\max_t(|\Delta m(t)|)$")
ax2.set_ylim([5e-7, 5e1])
ax2.legend()

# Four shared entries -> four columns (the old ncol=5 was capped at 4 anyway).
fig.legend([l1, l2, l3, l4], general_labels, loc="upper center", ncols=4)
plt.subplots_adjust(wspace=0.4)
plt.savefig("plots/time_err_ttn.pdf", bbox_inches="tight");

## Comparison with deterministic solution

In [None]:
# Moments of the partition-0, r = (5, 5), tau = 1e-3 PS-TTN run and its output times.
time_series = TimeSeries("output/lambda_phage_p0_r5-5_e_tau1e-3")
t = time_series.time
concentrations = time_series.calculateMoments()

In [None]:
# Deterministic ODE solution of the PySB model, evaluated at the PS-TTN output times t.
concentrations_ode = odesolve(model, t)

In [None]:
# Mean copy numbers (solid, +/- one standard deviation shaded) vs. the
# deterministic ODE solution (dashed); species 0-1 left, species 2-4 right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 4))
# Standard deviation from the first two moments. Round-off can make the variance
# slightly negative, so it is clipped at 0 to avoid NaN from sqrt.
# NOTE(review): assumes concentrations[0] / [1] are the first / second raw moments.
deviation = np.sqrt(np.clip(concentrations[1] - concentrations[0]**2, 0.0, None))
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

observables = ["x1", "x2"]
observables_alt = ["$S_0$", "$S_1$"]
for i, (o, o_alt) in enumerate(zip(observables, observables_alt)):
    ax1.plot(t, concentrations[0][:, i], '-', label=o_alt, color=colors[i], alpha=0.7)
    ax1.fill_between(t, concentrations[0][:, i]-deviation[:, i], concentrations[0][:, i]+deviation[:, i], color=colors[i], alpha=.2)
    ax1.plot(t, concentrations_ode[o], '--', color=colors[i], alpha=1.0)
ax1.set_ylabel(r"$\langle x_i(t) \rangle$")
ax1.set_ylim([0.0, 20.0])

observables = ["x3", "x4", "x5"]
observables_alt = ["$S_2$", "$S_3$", "$S_4$"]
# Species indices 2, 3, 4 continue the color/column numbering of the left panel.
for i, (o, o_alt) in enumerate(zip(observables, observables_alt), start=2):
    ax2.plot(t, concentrations[0][:, i], '-', label=o_alt, color=colors[i], alpha=0.7)
    ax2.fill_between(t, concentrations[0][:, i]-deviation[:, i], concentrations[0][:, i]+deviation[:, i], color=colors[i], alpha=.2)
    ax2.plot(t, concentrations_ode[o], '--', color=colors[i], alpha=1.0)
ax2.set_ylabel(r"$\langle x_i(t) \rangle$")
ax2.set_ylim([0, 2.5])
ax2.yaxis.tick_right()
ax2.yaxis.set_ticks_position("both")
ax2.yaxis.set_label_position("right")
plt.setp((ax1, ax2), xlabel="$t$", xlim=[0.0, 10.0], xticks=[0.0, 2.5, 5.0, 7.5, 10.0])

# Collect the labeled lines of both panels into one figure legend.
lines_labels = [ax.get_legend_handles_labels() for ax in fig.axes]
lines, labels = [sum(ll, []) for ll in zip(*lines_labels)]
fig.legend(lines, labels, ncols=5, loc="upper center")

plt.savefig("plots/concentrations.pdf");

In [None]:
# Absolute deviation of the CME mean (from the PS-TTN run) from the ODE solution, per species.
fig, ax = plt.subplots(figsize=(6, 4), layout='constrained')
for i, o in enumerate(model.observables):
    ax.plot(t, np.abs(concentrations[0][:, i] - concentrations_ode[o.name]), '-', label="$S_{{{}}}$".format(i))
ax.set_xlabel("$t$")
# Raw string: "\m" is not a valid Python escape sequence.
ax.set_ylabel(r"$|c_i^\mathrm{{CME}}(t) - c_i^\mathrm{{ODE}}(t)|$")
fig.legend(*ax.get_legend_handles_labels(), loc="outside center right")

plt.savefig("plots/concentrations_err.pdf");

## Error depending on rank and partition

### Load data

#### Partition 0

In [None]:
# Per-time error vs. P_exact for partition 0; err_<rank index><partition index>.
errs_p0 = []
for rank_label in ("r5-3", "r5-4", "r5-5", "r9-9"):
    time_series = TimeSeries(f"output/lambda_phage_p0_{rank_label}_e_tau1e-3")
    P = time_series.calculateFullDistribution()
    errs_p0.append(np.linalg.norm(P - P_exact, axis=1, ord=None))
err_00, err_10, err_20, err_30 = errs_p0

#### Partition 1

In [None]:
# Per-time error vs. P_exact for partition 1; err_<rank index><partition index>.
errs_p1 = []
for rank_label in ("r5-3", "r5-4", "r5-5", "r9-9"):
    time_series = TimeSeries(f"output/lambda_phage_p1_{rank_label}_e_tau1e-3")
    P = time_series.calculateFullDistribution()
    errs_p1.append(np.linalg.norm(P - P_exact, axis=1, ord=None))
err_01, err_11, err_21, err_31 = errs_p1

#### Partition 2

In [None]:
# Per-time error vs. P_exact for partition 2; err_<rank index><partition index>.
errs_p2 = []
for rank_label in ("r5-3", "r5-4", "r5-5", "r9-9"):
    time_series = TimeSeries(f"output/lambda_phage_p2_{rank_label}_e_tau1e-3")
    P = time_series.calculateFullDistribution()
    errs_p2.append(np.linalg.norm(P - P_exact, axis=1, ord=None))
err_02, err_12, err_22, err_32 = errs_p2

### Plot

In [None]:
# Error grid: rows are ranks, columns are partitions P_0, P_1, P_2.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
cols = [r"$\mathcal{{P}}_0$", r"$\mathcal{{P}}_1$", r"$\mathcal{{P}}_2$"]
# The last row plots err_30/err_31/err_32, which are loaded from the r9-9 runs,
# so it is labeled r = (9, 9).
rows = ["$r = (5, 3)$", "$r = (5, 4)$", "$r = (5, 5)$", "$r = (9, 9)$"]
label = r"$\Vert P^\mathrm{{TTN}}(t,\mathbf{{x}})-P^\mathrm{{ref}}(t,\mathbf{{x}}) \Vert$"

fig, axs = plt.subplots(4, 3, figsize=(10, 8), sharex='col', sharey='row')

axs[0, 0].plot(time[1:], err_00[1:], '.-', color=colors[0])
axs[1, 0].plot(time[1:], err_10[1:], '.-', color=colors[0])
axs[2, 0].plot(time[1:], err_20[1:], '.-', color=colors[0])
axs[3, 0].plot(time[1:], err_30[1:], '.-', color=colors[0])

axs[0, 1].plot(time[1:], err_01[1:], '.-', color=colors[1])
axs[1, 1].plot(time[1:], err_11[1:], '.-', color=colors[1])
axs[2, 1].plot(time[1:], err_21[1:], '.-', color=colors[1])
axs[3, 1].plot(time[1:], err_31[1:], '.-', color=colors[1])

axs[0, 2].plot(time[1:], err_02[1:], '.-', color=colors[2])
axs[1, 2].plot(time[1:], err_12[1:], '.-', color=colors[2])
axs[2, 2].plot(time[1:], err_22[1:], '.-', color=colors[2])
axs[3, 2].plot(time[1:], err_32[1:], '.-', color=colors[2])

plt.setp(axs, ylim=[5e-5, 2e-2], yscale="log")

pad = 5  # in points

# Column headers above the top row.
for ax, col in zip(axs[0], cols):
    ax.annotate(col, xy=(0.5, 1), xytext=(0, pad+5),
                xycoords='axes fraction', textcoords='offset points',
                size=14, ha='center', va='baseline')

# Row headers to the left of the first column's y-axis label.
for ax, row in zip(axs[:, 0], rows):
    ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad-pad, 0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size=14, ha='right', va='center')

plt.setp(axs[-1, :], xlabel="$t$", xticks=[1, 4, 7, 10])
fig.suptitle(label, fontsize=16)
plt.subplots_adjust(hspace=.0, wspace=.0)
plt.tight_layout()
plt.savefig("plots/comparison_err_ttn.pdf");

In [None]:
# One panel per rank; within a panel the three partitions are compared.
# LaTeX labels are raw strings so backslashes are not read as Python escapes.
fig, axs = plt.subplots(2, 2, figsize=(6, 5))

# Only the first panel is labeled; its handles feed the shared figure legend.
axs[0, 0].plot(time[1:], err_00[1:], '.-', label=r"$\mathcal{{P}}_0$")
axs[0, 0].plot(time[1:], err_01[1:], '.-', label=r"$\mathcal{{P}}_1$")
axs[0, 0].plot(time[1:], err_02[1:], '.-', label=r"$\mathcal{{P}}_2$")
axs[0, 0].set_title("$r = (5, 3)$")

axs[0, 1].plot(time[1:], err_10[1:], '.-')
axs[0, 1].plot(time[1:], err_11[1:], '.-')
axs[0, 1].plot(time[1:], err_12[1:], '.-')
axs[0, 1].set_title("$r = (5, 4)$")
axs[0, 1].yaxis.tick_right()
axs[0, 1].yaxis.set_ticks_position("both")
axs[0, 1].yaxis.set_label_position("right")

axs[1, 0].plot(time[1:], err_20[1:], '.-')
axs[1, 0].plot(time[1:], err_21[1:], '.-')
axs[1, 0].plot(time[1:], err_22[1:], '.-')
axs[1, 0].set_title("$r = (5, 5)$")

axs[1, 1].plot(time[1:], err_30[1:], '.-')
axs[1, 1].plot(time[1:], err_31[1:], '.-')
axs[1, 1].plot(time[1:], err_32[1:], '.-')
axs[1, 1].set_title("$r = (9, 9)$")
axs[1, 1].yaxis.tick_right()
axs[1, 1].yaxis.set_ticks_position("both")
axs[1, 1].yaxis.set_label_position("right")

plt.setp(axs, xlabel="$t$", xticks=[1.0, 4.0, 7.0, 10.0], ylim=[3e-4, 2e-2], yscale="log")
plt.subplots_adjust(hspace=0.4)
fig.suptitle(r"$\Vert P^\mathrm{{TTN}}(t,\mathbf{{x}})-P^\mathrm{{ref}}(t,\mathbf{{x}}) \Vert$", fontsize=16, y=1.03)
fig.legend(*axs[0, 0].get_legend_handles_labels(), ncols=3, loc="center", bbox_to_anchor=(0.5, 0.93))
plt.tight_layout()
plt.savefig("plots/err_comparison_ttn_alt.pdf", bbox_inches="tight");

## Mass error

### Load data

#### Partition 0

In [None]:
# Absolute mass error over time for partition 0; mass_err<rank index><partition index>.
mass_errs_p0 = []
for rank_label in ("r5-3", "r5-4", "r5-5", "r5-6"):
    time_series = TimeSeries(f"output/lambda_phage_p0_{rank_label}_e_tau1e-3")
    mass_errs_p0.append(np.abs(time_series.getMassErr()))
    if rank_label == "r5-3":
        # Time axis for the mass-error plots, taken from the first run.
        time = time_series.time
mass_err00, mass_err10, mass_err20, mass_err30 = mass_errs_p0

#### Partition 1

In [None]:
# Absolute mass error over time for partition 1; mass_err<rank index><partition index>.
mass_errs_p1 = []
for rank_label in ("r5-3", "r5-4", "r5-5", "r5-6"):
    time_series = TimeSeries(f"output/lambda_phage_p1_{rank_label}_e_tau1e-3")
    mass_errs_p1.append(np.abs(time_series.getMassErr()))
mass_err01, mass_err11, mass_err21, mass_err31 = mass_errs_p1

#### Partition 2

In [None]:
# Absolute mass error over time for partition 2; mass_err<rank index><partition index>.
mass_errs_p2 = []
for rank_label in ("r5-3", "r5-4", "r5-5", "r5-6"):
    time_series = TimeSeries(f"output/lambda_phage_p2_{rank_label}_e_tau1e-3")
    mass_errs_p2.append(np.abs(time_series.getMassErr()))
mass_err02, mass_err12, mass_err22, mass_err32 = mass_errs_p2

### Plot

In [None]:
# One panel per rank; within a panel the three partitions' mass errors are compared.
# LaTeX labels are raw strings so backslashes are not read as Python escapes.
fig, axs = plt.subplots(2, 2, figsize=(7, 6))

# Only the first panel is labeled; its handles feed the shared figure legend.
axs[0, 0].plot(time[1:], mass_err00[1:], ".-", label=r"$\mathcal{{P}}_0$")
axs[0, 0].plot(time[1:], mass_err01[1:], ".-", label=r"$\mathcal{{P}}_1$")
axs[0, 0].plot(time[1:], mass_err02[1:], ".-", label=r"$\mathcal{{P}}_2$")
axs[0, 0].set_title("$r = (5, 3)$")

axs[0, 1].plot(time[1:], mass_err10[1:], ".-")
axs[0, 1].plot(time[1:], mass_err11[1:], ".-")
axs[0, 1].plot(time[1:], mass_err12[1:], ".-")
axs[0, 1].set_title("$r = (5, 4)$")
axs[0, 1].yaxis.tick_right()
axs[0, 1].yaxis.set_ticks_position("both")
axs[0, 1].yaxis.set_label_position("right")

axs[1, 0].plot(time[1:], mass_err20[1:], ".-")
axs[1, 0].plot(time[1:], mass_err21[1:], ".-")
axs[1, 0].plot(time[1:], mass_err22[1:], ".-")
axs[1, 0].set_title("$r = (5, 5)$")

axs[1, 1].plot(time[1:], mass_err30[1:], ".-")
axs[1, 1].plot(time[1:], mass_err31[1:], ".-")
axs[1, 1].plot(time[1:], mass_err32[1:], ".-")
axs[1, 1].set_title("$r = (5, 6)$")
axs[1, 1].yaxis.tick_right()
axs[1, 1].yaxis.set_ticks_position("both")
axs[1, 1].yaxis.set_label_position("right")

plt.setp(axs, xlabel="$t$", xticks=[1.0, 4.0, 7.0, 10.0], ylim=[3e-8, 5e-5], yscale="log")
plt.subplots_adjust(hspace=0.5)
fig.suptitle(r"$|\Delta m(t)|$", fontsize=16, y=1.05)
fig.legend(*axs[0, 0].get_legend_handles_labels(), ncols=3, loc="center", bbox_to_anchor=(0.5, 0.95))
plt.tight_layout()
plt.savefig("plots/mass_err_comparison_ttn.pdf");

In [None]:
# |Δm(t)| of partition P_1 only, comparing the four rank configurations
# (index 0, i.e. t = 0, is skipped as in the panel plot above).
fig, ax = plt.subplots()
for err, label in zip(
    (mass_err01, mass_err11, mass_err21, mass_err31),
    ("$r = (5, 3)$", "$r = (5, 4)$", "$r = (5, 5)$", "$r = (5, 6)$"),
):
    ax.plot(time[1:], err[1:], ".-", label=label)
ax.legend()
ax.ticklabel_format(style='sci', axis='y', scilimits=(-2,2))
# Raw string: "\D" is an invalid Python escape sequence (SyntaxWarning since 3.12).
plt.setp(ax, xlabel="$t$", ylabel=r"$|\Delta m(t)|$", xticks=[1.0, 4.0, 7.0, 10.0])
plt.savefig("plots/mass_err_comparison_ttn_rank.pdf");

## Wall time

### Load data

#### Partition 0

In [None]:
# Partition P_0: wall time reported by each run, for r = (5, 3), (5, 4), (5, 5), (9, 9).
p0_walltime = []
for p0_run in (
    "output/lambda_phage_p0_r5-3_e_tau1e-3",
    "output/lambda_phage_p0_r5-4_e_tau1e-3",
    "output/lambda_phage_p0_r5-5_e_tau1e-3",
    "output/lambda_phage_p0_r9-9_e_tau1e-3",
):
    time_series = TimeSeries(p0_run)
    p0_walltime.append(time_series.getWallTime())
walltime00, walltime10, walltime20, walltime30 = p0_walltime

#### Partition 1

In [None]:
# Partition P_1: wall time reported by each run, for r = (5, 3), (5, 4), (5, 5), (9, 9).
p1_walltime = []
for p1_run in (
    "output/lambda_phage_p1_r5-3_e_tau1e-3",
    "output/lambda_phage_p1_r5-4_e_tau1e-3",
    "output/lambda_phage_p1_r5-5_e_tau1e-3",
    "output/lambda_phage_p1_r9-9_e_tau1e-3",
):
    time_series = TimeSeries(p1_run)
    p1_walltime.append(time_series.getWallTime())
walltime01, walltime11, walltime21, walltime31 = p1_walltime

#### Partition 2

In [None]:
# Partition P_2: wall time reported by each run, for r = (5, 3), (5, 4), (5, 5), (9, 9).
p2_walltime = []
for p2_run in (
    "output/lambda_phage_p2_r5-3_e_tau1e-3",
    "output/lambda_phage_p2_r5-4_e_tau1e-3",
    "output/lambda_phage_p2_r5-5_e_tau1e-3",
    "output/lambda_phage_p2_r9-9_e_tau1e-3",
):
    time_series = TimeSeries(p2_run)
    p2_walltime.append(time_series.getWallTime())
walltime02, walltime12, walltime22, walltime32 = p2_walltime

### Plot

In [None]:
fig, axs = plt.subplots(2, 2, figsize=(6, 4.5))
# Raw strings: "\m" is an invalid Python escape sequence (SyntaxWarning since 3.12).
labels = [r"$\mathcal{{P}}_{}$".format(i) for i in range(3)]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# Wall times of partitions P_0, P_1, P_2, one list per rank configuration.
walltime0 = [walltime00, walltime01, walltime02]
walltime1 = [walltime10, walltime11, walltime12]
walltime2 = [walltime20, walltime21, walltime22]
walltime3 = [walltime30, walltime31, walltime32]

# Panels in row-major order: (5, 3), (5, 4) / (5, 5), (9, 9).
for ax, walltimes, title in zip(
    axs.flat,
    (walltime0, walltime1, walltime2, walltime3),
    ("$r = (5, 3)$", "$r = (5, 4)$", "$r = (5, 5)$", "$r = (9, 9)$"),
):
    ax.bar(labels, walltimes, color=colors)
    ax.set_title(title)

# Right-hand column: draw y tick labels on the right side.
for ax in axs[:, 1]:
    ax.yaxis.tick_right()
    ax.yaxis.set_ticks_position("both")
    ax.yaxis.set_label_position("right")

plt.subplots_adjust(wspace=0.06, hspace=0.5)
fig.suptitle(r"wall time [$\mathrm{s}$]", fontsize=16)

plt.savefig("plots/wall_time_comparison_ttn.pdf", bbox_inches="tight");

## Memory requirements

In [None]:
def memoryRequirementsP0(rank):
    """Estimated storage, in kB, of the tree tensor network for partition P_0.

    Counts the entries of three leaf factors (656, 121 and 11 rows —
    presumably the leaf grid sizes) and two connecting tensors, multiplies by
    8 bytes per entry (presumably float64), and scales bytes to kB (1e-3).

    Args:
        rank: sequence whose first two items, ``rank[0]`` and ``rank[1]``,
            are the two tree ranks.
    """
    r0, r1 = rank[0], rank[1]
    n_entries = (
        r0 * 656         # leaf x0
        + r1 * 121       # leaf x10
        + r1 * 11        # leaf x11
        + r0 ** 2        # root tensor q
        + r0 * r1 ** 2   # internal tensor q1
    )
    return n_entries * 8 * 1e-3

def memoryRequirementsP1(rank):
    """Estimated storage, in kB, of the tree tensor network for partition P_1.

    The two leaves below the internal node (656 and 121 rows) use ``rank[1]``;
    the remaining leaf (11 rows) uses ``rank[0]``. Connecting tensors are
    ``rank[0]**2`` and ``rank[0] * rank[1]**2`` entries. Entries are counted at
    8 bytes each (presumably float64) and scaled from bytes to kB (1e-3).
    """
    r0, r1 = rank[0], rank[1]
    n_entries = (
        r1 * 656         # leaf x00
        + r1 * 121       # leaf x01
        + r0 * 11        # leaf x1
        + r0 ** 2        # root tensor q
        + r0 * r1 ** 2   # internal tensor q0
    )
    return n_entries * 8 * 1e-3

def memoryRequirementsP2(rank):
    """Estimated storage, in kB, of the tree tensor network for partition P_2.

    The two leaves below the internal node (656 and 11 rows) use ``rank[1]``;
    the remaining leaf (121 rows) uses ``rank[0]``. Connecting tensors are
    ``rank[0]**2`` and ``rank[0] * rank[1]**2`` entries. Entries are counted at
    8 bytes each (presumably float64) and scaled from bytes to kB (1e-3).
    """
    r0, r1 = rank[0], rank[1]
    n_entries = (
        r1 * 656         # leaf x00
        + r1 * 11        # leaf x01
        + r0 * 121       # leaf x1
        + r0 ** 2        # root tensor q
        + r0 * r1 ** 2   # internal tensor q0
    )
    return n_entries * 8 * 1e-3

def memoryRequirementsPfull(rank):
    """Estimated storage, in kB, of a two-leaf low-rank representation.

    Only ``rank[0]`` is read: two leaf factors (656 and 1331 rows) plus a
    ``rank[0] x rank[0]`` core, at 8 bytes per entry (presumably float64),
    scaled from bytes to kB (1e-3).
    """
    r0 = rank[0]
    n_entries = r0 * 656 + r0 * 1331 + r0 ** 2
    return n_entries * 8 * 1e-3

In [None]:
fig, axs = plt.subplots(2, 2, figsize=(6, 4))
# Raw strings: "\m" is an invalid Python escape sequence (SyntaxWarning since 3.12).
labels = [r"$\mathcal{{P}}_{}$".format(i) for i in range(3)]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# One panel per rank configuration (row-major order); each bar is the
# estimated memory footprint of partition P_0, P_1 or P_2.
for ax, rank in zip(axs.flat, ([5, 3], [5, 4], [5, 5], [9, 9])):
    memory_req = [memoryRequirementsP0(rank), memoryRequirementsP1(rank), memoryRequirementsP2(rank)]
    if rank in ([5, 5], [9, 9]):
        # Report the P_0 footprint numerically for r = (5, 5) and r = (9, 9).
        print(f"r=({rank[0]},{rank[1]})", memoryRequirementsP0(rank))
    ax.bar(labels, memory_req, color=colors)
    ax.set_title(f"$r = ({rank[0]}, {rank[1]})$")

# Right-hand column: draw y tick labels on the right side.
for ax in axs[:, 1]:
    ax.yaxis.tick_right()
    ax.yaxis.set_ticks_position("both")
    ax.yaxis.set_label_position("right")

plt.subplots_adjust(hspace=0.5)
fig.suptitle(r"memory [$\mathrm{kB}$]", fontsize=16)

plt.savefig("plots/memory_comparison_ttn.pdf");

In [None]:
# TTN snapshot: partition P_0, r = (10, 10), "_i_" run with tau = 2e-2 (file output_t500.nc).
tree = readTree("output/lambda_phage_p0_r10-10_i_tau2e-2/output_t500.nc")
(
    P_slice_p0_r_10_10,
    P_marginal_p0_r_10_10,
) = tree.calculateObservables(slice_vec)

In [None]:
# TTN snapshot: partition P_0, r = (10, 10), "_e_" run with tau = 1e-3 (file output_t10000.nc).
tree = readTree("output/lambda_phage_p0_r10-10_e_tau1e-3/output_t10000.nc")
(
    P_slice_p0_r_10_10_e,
    P_marginal_p0_r_10_10_e,
) = tree.calculateObservables(slice_vec)

In [None]:
# Re-read the full ODE reference distribution from the r = 5 reference file.
with np.load("scripts/reference_solutions/lp_ode_ref_r5.npz") as ode_ref:
    P_exact = ode_ref['P_full']

In [None]:
from scripts.reference_solutions.ode_helper import calculateObservables

In [None]:
# Test problem: partition P_0, rank 5, "_i_" run with tau = 2e-2 (file output_t500.nc).
tree = readTree("output/lambda_phage_test_p0_r5_i_tau2e-2/output_t500.nc")
(
    P_test_slice_p0_r_5,
    P_test_marginal_p0_r_5,
) = tree.calculateObservables(slice_vec)

In [None]:
# Test problem: partition P_0, rank 7, "_e_" run, Householder variant (file output_t10000.nc).
tree = readTree("output/lambda_phage_test_p0_r7_e_tau1e-3_householder/output_t10000.nc")
(
    P_test_new_slice_p0_r_7_householder,
    P_test_new_marginal_p0_r_7_householder,
) = tree.calculateObservables(slice_vec)

In [None]:
# Test problem: partition P_0, rank 10, "_e_" run, Householder variant (file output_t10000.nc).
tree = readTree("output/lambda_phage_test_p0_r10_e_tau1e-3_householder/output_t10000.nc")
(
    P_test_new_slice_p0_r_10_householder,
    P_test_new_marginal_p0_r_10_householder,
) = tree.calculateObservables(slice_vec)

In [None]:
# Test problem: partition P_1, rank 10, "_e_" run, Householder variant (file output_t10000.nc).
tree = readTree("output/lambda_phage_test_p1_r10_e_tau1e-3_householder/output_t10000.nc")
(
    P_test_new_slice_p1_r_10_householder,
    P_test_new_marginal_p1_r_10_householder,
) = tree.calculateObservables(slice_vec)

In [None]:
# Test problem: partition P_0, rank 10, "_i_" run, Householder variant (file output_t10000.nc).
tree = readTree("output/lambda_phage_test_p0_r10_i_tau1e-3_householder/output_t10000.nc")
(
    P_test_new_slice_p0_r_10_householder_i,
    P_test_new_marginal_p0_r_10_householder_i,
) = tree.calculateObservables(slice_vec)

In [None]:
# Test problem: partition P_0, rank 12, "_e_" run, Householder variant (file output_t10000.nc).
# NOTE: `tree` from this cell is the one whose grid the comparison plot below uses.
tree = readTree("output/lambda_phage_test_p0_r12_e_tau1e-3_householder/output_t10000.nc")
(
    P_test_new_slice_p0_r_12_householder,
    P_test_new_marginal_p0_r_12_householder,
) = tree.calculateObservables(slice_vec)

In [None]:
# ODE reference for the test problem: two marginal distributions and `n`,
# presumably the grid size along each marginal's axis.
with np.load("scripts/reference_solutions/lp_test_ode_ref.npz") as test_ref:
    P_test_marginal0, P_test_marginal1, n = (
        test_ref['P_marginal0'],
        test_ref['P_marginal1'],
        test_ref['n'],
    )

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 3))
# TTN test-problem marginals, in plotting order (P_test_marginal_p0_r_5 is not plotted).
# The x grid for all of them comes from `tree`, i.e. the last tree loaded above.
# NOTE(review): assumes these runs all share tree.grid.n — confirm.
ttn_test_marginals = (
    P_test_new_marginal_p0_r_7_householder,
    P_test_new_marginal_p0_r_10_householder_i,
    P_test_new_marginal_p0_r_10_householder,
    P_test_new_marginal_p0_r_12_householder,
    P_test_new_marginal_p1_r_10_householder,
)
for axis_idx, (panel, reference) in enumerate(((ax1, P_test_marginal0), (ax2, P_test_marginal1))):
    for marginals in ttn_test_marginals:
        panel.plot(np.arange(tree.grid.n[axis_idx]), marginals[axis_idx])
    # ODE reference marginal, drawn last on each panel.
    panel.plot(np.arange(n[axis_idx]), reference)