# BAX pore assembly

In [None]:
import numpy as np
import matplotlib.pyplot as plt
plt.style.use("./scripts/output/notebooks/custom_style.mplstyle")
import xarray as xr

from scripts.grid_class import GridParms
from scripts.tree_class import Tree
from scripts.output.output_helper import *
from scripts.reference_solutions.ssa_helper import SSASol

%matplotlib inline

In [None]:
# Slicing point used for every sliced distribution in this notebook; it has 11
# entries, matching the species indices 0..10 listed in the partitions below.
slice_vec = np.array(
    [12, 9, 2, 1, 0, 0, 0, 0, 0, 50, 0]
)

## Load initial data

### Tree tensor integrator

#### Partition 0
(0 1 2)(((3 4 6 7)(5 8))(9 10))

#### Partition 1
((0 1 2)(3 4 5))((6 7 9)(8 10))

In [None]:
# Partition 1, ranks r = (5, 10, 10): read the stored tree and evaluate the
# (sliced, marginal) observable pair for species 0, 1, 3 and 9 at `slice_vec`.
tree = readTree("output/bax_p1_r5-10-10_i_tau5e-2/output_t2900.nc")
(
    (P_slice_p1_r_5_10_10_s0, P_sum_p1_r_5_10_10_s0),
    (P_slice_p1_r_5_10_10_s1, P_sum_p1_r_5_10_10_s1),
    (P_slice_p1_r_5_10_10_s3, P_sum_p1_r_5_10_10_s3),
    (P_slice_p1_r_5_10_10_s9, P_sum_p1_r_5_10_10_s9),
) = [tree.calculateObservables(species, slice_vec) for species in (0, 1, 3, 9)]

In [None]:
# Partition 1, ranks r = (5, 15, 25): read the stored tree and evaluate the
# (sliced, marginal) observable pair for species 0, 1, 3 and 9 at `slice_vec`.
tree = readTree("output/bax_p1_r5-15-25_i_tau5e-2/output_t2900.nc")
(
    (P_slice_p1_r_5_15_25_s0, P_sum_p1_r_5_15_25_s0),
    (P_slice_p1_r_5_15_25_s1, P_sum_p1_r_5_15_25_s1),
    (P_slice_p1_r_5_15_25_s3, P_sum_p1_r_5_15_25_s3),
    (P_slice_p1_r_5_15_25_s9, P_sum_p1_r_5_15_25_s9),
) = [tree.calculateObservables(species, slice_vec) for species in (0, 1, 3, 9)]

#### Partition 2
((0 1)(2 3 4))((5 6 7 8)(9 10))

#### Two partitions
(0 1 2 3 4)(5 6 7 8 9 10)

In [None]:
# Two-partition run: after this cell `tree` stays bound to this solution and is
# the object whose grid the plotting cells below query.
tree = readTree("output/bax_pfull_r5_i_tau5e-2/output_t2900.nc")

# One (sliced, marginal) pair per species, split into two parallel lists that
# are indexed by species number.
observables_full = [tree.calculateObservables(i, slice_vec) for i in range(tree.grid.d())]
P_slice_full = [sliced for sliced, _ in observables_full]
P_sum_full = [marginal for _, marginal in observables_full]

### Matrix integrator
Two partitions, $r=5$, explicit Euler with variable time step size ($\min_i \tau_i = 1.0$) and 100 substeps

In [None]:
# Matrix-integrator output: build the grid description and the low-rank
# solution from the dataset, then extract marginal and sliced distributions
# while the file is still open.
with xr.open_dataset("output/bax_matrix/output_t100.nc") as matrix_ds:
    grid = GridInfo(matrix_ds)
    lr_sol = LRSol(matrix_ds, grid)
    P_sum_matrix = lr_sol.marginalDistributions()
    P_slice_matrix = lr_sol.slicedDistributions(slice_vec)

# Species index pair handed to SSASol.calculateObservables in the SSA cells.
idx_2D = np.array([0, 1])

### SSA

In [None]:
# SSA data set "bax_ssa_1e7" (10^7 runs according to the plot labels below).
# calculateObservables returns four values; the first is kept as the marginal
# distributions and the third as the sliced distributions.
# NOTE(review): the meaning of the two discarded return values is not visible here.
ssa_1e7 = np.load("scripts/reference_solutions/bax_ssa_1e7.npy")
ssa_1e7_sol = SSASol(ssa_1e7)
P_sum_ssa_1e7, _, P_slice_ssa_1e7, _ = ssa_1e7_sol.calculateObservables(slice_vec, idx_2D)

In [None]:
# SSA data set "bax_ssa_2e6" (2 x 10^6 runs according to the plot labels below).
# Only the first (marginal) and third (sliced) return values are kept.
ssa_2e6 = np.load("scripts/reference_solutions/bax_ssa_2e6.npy")
ssa_2e6_sol = SSASol(ssa_2e6)
P_sum_ssa_2e6, _, P_slice_ssa_2e6, _ = ssa_2e6_sol.calculateObservables(slice_vec, idx_2D)

In [None]:
# SSA data set "bax_ssa_1e6" (10^6 runs according to the plot labels below).
# Only the first (marginal) and third (sliced) return values are kept.
ssa_1e6 = np.load("scripts/reference_solutions/bax_ssa_1e6.npy")
ssa_1e6_sol = SSASol(ssa_1e6)
P_sum_ssa_1e6, _, P_slice_ssa_1e6, _ = ssa_1e6_sol.calculateObservables(slice_vec, idx_2D)

In [None]:
# SSA data set "bax_ssa_1e5" (10^5 runs according to the plot labels below).
# Only the first (marginal) and third (sliced) return values are kept.
ssa_1e5 = np.load("scripts/reference_solutions/bax_ssa_1e5.npy")
ssa_1e5_sol = SSASol(ssa_1e5)
P_sum_ssa_1e5, _, P_slice_ssa_1e5, _ = ssa_1e5_sol.calculateObservables(slice_vec, idx_2D)

In [None]:
# SSA data set "bax_ssa_1e4" (10^4 runs according to the plot labels below).
# Only the first (marginal) and third (sliced) return values are kept.
ssa_1e4 = np.load("scripts/reference_solutions/bax_ssa_1e4.npy")
ssa_1e4_sol = SSASol(ssa_1e4)
P_sum_ssa_1e4, _, P_slice_ssa_1e4, _ = ssa_1e4_sol.calculateObservables(slice_vec, idx_2D)

## Sliced distributions

### TTN integrator

In [None]:
# Sliced distributions of the TTN integrator: species 0 in the left column and
# species 9 in the right column; top row logarithmic, bottom row linear y axis.
fig, axs = plt.subplots(2, 2, figsize=(7, 5), sharex='col', sharey='row')

# One entry per column: species index, plotted index window, x ticks and the
# partition-1 results for the two rank choices.
columns = [
    (0, slice(11, 26), np.arange(11, 25+2, 2), P_slice_p1_r_5_10_10_s0, P_slice_p1_r_5_15_25_s0),
    (9, slice(45, None), np.arange(45, 55+2, 2), P_slice_p1_r_5_10_10_s9, P_slice_p1_r_5_15_25_s9),
]
for col, (species, window, xticks, P_slice_r_5_10_10, P_slice_r_5_15_25) in enumerate(columns):
    x = np.arange(tree.grid.n[species])[window]
    for row in range(2):
        ax = axs[row, col]
        # Raw strings: "\m" is not a valid escape sequence in a normal literal.
        ax.plot(x, P_slice_full[species][window], '.-', label="2 partitions")
        ax.plot(x, P_slice_r_5_10_10[window], 'v', label=r"$\mathcal{{P}}_1$, $r = (5, 10, 10)$")
        ax.plot(x, P_slice_r_5_15_25[window], '^', label=r"$\mathcal{{P}}_1$, $r = (5, 15, 25)$")
        ax.set_ylabel("$P_\\mathrm{{S}}(x_%d)$" % species)
        if col == 1:
            # Right column carries its ticks and label on the right-hand side.
            ax.yaxis.tick_right()
            ax.yaxis.set_ticks_position("both")
            ax.yaxis.set_label_position("right")
    axs[0, col].set_yscale("log")
    axs[1, col].set_xlabel(xlabel="$x_%d$" % species)
    axs[1, col].set_xticks(xticks)

axs[1, 0].set_xlim([10, 26])

# The figure legend is built from the first panel only, so the labels attached
# to the other panels never show up.
fig.legend(*axs[0, 0].get_legend_handles_labels(), ncols=3, loc="upper center")
plt.subplots_adjust(hspace=.0, wspace=.0)

plt.savefig("plots/sliced_ttn.pdf");

### SSA

In [None]:
# Sliced distributions of the SSA runs against the two-partition TTN result:
# species 0 in the left column, species 9 in the right column; top row
# logarithmic, bottom row linear y axis.
fig, axs = plt.subplots(2, 2, figsize=(7, 5), sharex='col', sharey='row')

# (solution object, sliced distributions, marker, legend label) per SSA run.
ssa_runs = [
    (ssa_1e7_sol, P_slice_ssa_1e7, 'v', "$10^7$ runs"),
    (ssa_1e6_sol, P_slice_ssa_1e6, '<', "$10^6$ runs"),
    (ssa_1e5_sol, P_slice_ssa_1e5, '^', "$10^5$ runs"),
    (ssa_1e4_sol, P_slice_ssa_1e4, '>', "$10^4$ runs"),
]
for col, (species, window) in enumerate([(0, slice(11, 26)), (9, slice(45, None))]):
    ttn_x = np.arange(tree.grid.n[species])[window]
    for row in range(2):
        ax = axs[row, col]
        ax.plot(ttn_x, P_slice_full[species][window], '.-', label="$P_\\mathrm{{S}}^\\mathrm{{TTN,ref}}$")
        for sol, P_slice_ssa, marker, run_label in ssa_runs:
            # SSA data start at the smallest observed population n_min.
            ax.plot(np.arange(sol.n[species]) + sol.n_min[species], P_slice_ssa[-1][species], marker, label=run_label)
        ax.set_ylabel("$P_\\mathrm{{S}}(x_%d)$" % species)
        if col == 1:
            ax.yaxis.tick_right()
            ax.yaxis.set_ticks_position("both")
            ax.yaxis.set_label_position("right")
    axs[0, col].set_yscale("log")
    axs[1, col].set_xlabel(xlabel="$x_%d$" % species)

axs[1, 0].set_xticks(np.arange(11, 25+2, 2))
axs[1, 0].set_xlim([10, 26])

# Only the handles of the first panel feed the figure legend.
fig.legend(*axs[0, 0].get_legend_handles_labels(), loc="upper center", ncols=5)
plt.subplots_adjust(hspace=.0, wspace=.0)

plt.savefig("plots/sliced_ssa.pdf");

## Comparison of sliced distribution between TTN integrator and SSA

In [None]:
# Side-by-side comparison of the sliced distribution of species `nn`:
# left column TTN integrator, right column SSA; top row log, bottom row linear.
fig, axs = plt.subplots(2, 2, figsize=(9, 6), sharex='col', sharey='row')
nn = 0
lim = [11, 26]

intvl = slice(lim[0], lim[1])
x_ttn = np.arange(tree.grid.n[nn])[intvl]

# (solution object, sliced distributions, marker, legend label) per SSA run.
ssa_runs = [
    (ssa_1e7_sol, P_slice_ssa_1e7, 'v', "$10^7$ runs"),
    (ssa_1e6_sol, P_slice_ssa_1e6, '<', "$10^6$ runs"),
    (ssa_1e5_sol, P_slice_ssa_1e5, '^', "$10^5$ runs"),
    (ssa_1e4_sol, P_slice_ssa_1e4, '>', "$10^4$ runs"),
]

for row in range(2):
    ttn_ax, ssa_ax = axs[row]

    # NOTE: the partition-1 arrays are the species-0 results, so this cell is
    # only consistent for nn = 0.
    # Raw strings: "\m" is not a valid escape sequence in a normal literal.
    ttn_ax.plot(x_ttn, P_slice_full[nn][intvl], '.-', label="$P_\\mathrm{{S}}^\\mathrm{{TTN,ref}}$")
    ttn_ax.plot(x_ttn, P_slice_p1_r_5_10_10_s0[intvl], 'v', label=r"$\mathcal{{P}}_1$, $r = (5, 10, 10)$")
    ttn_ax.plot(x_ttn, P_slice_p1_r_5_15_25_s0[intvl], '^', label=r"$\mathcal{{P}}_1$, $r = (5, 15, 25)$")

    ssa_ax.plot(x_ttn, P_slice_full[nn][intvl], '.-', label="$P_\\mathrm{{S}}^\\mathrm{{TTN,ref}}$")
    for sol, P_slice_ssa, marker, run_label in ssa_runs:
        ssa_ax.plot(np.arange(sol.n[nn]) + sol.n_min[nn], P_slice_ssa[-1][nn], marker, label=run_label)

    for ax in (ttn_ax, ssa_ax):
        ax.set_ylabel("$P_\\mathrm{{S}}(x_{})$".format(nn))

    ssa_ax.yaxis.tick_right()
    ssa_ax.yaxis.set_ticks_position("both")
    ssa_ax.yaxis.set_label_position("right")

# Top row: log scale, column title and a per-panel legend.
for ax, title in zip(axs[0], ("PS-TTN integrator", "SSA")):
    ax.set_yscale("log")
    ax.set_title(title)
    ax.legend()

# Bottom row: x-axis decoration (shared with the top row column-wise).
for ax in axs[1]:
    ax.set_xlabel(xlabel="$x_{}$".format(nn))
    ax.set_xticks(np.arange(11, 25+2, 2))
    ax.set_xlim([10, 26])

plt.subplots_adjust(hspace=.0, wspace=.2)

plt.savefig("plots/sliced_ttn_ssa.pdf");

## Marginal distributions

### TTN integrator

In [None]:
nn = 0
# Marginal distributions of the TTN integrator: species 0 on the left,
# species 9 (entries 45 onwards) on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))

# The "2 partitions" curve is taken from P_sum_full on both panels so that the
# single legend entry refers to one data source (the left panel previously
# plotted the matrix-integrator result P_sum_matrix under the same label).
# Raw strings: "\m" is not a valid escape sequence in a normal literal.
x_all = np.arange(tree.grid.n[nn])
ax1.plot(x_all, P_sum_full[nn], '.-', label="2 partitions")
ax1.plot(x_all, P_sum_p1_r_5_10_10_s0, 'v', label=r"$\mathcal{{P}}_1$, $r = (5, 10, 10)$")
ax1.plot(x_all, P_sum_p1_r_5_15_25_s0, '^', label=r"$\mathcal{{P}}_1$, $r = (5, 15, 25)$")
ax1.set_xlabel("$x_{}$".format(nn))
ax1.set_ylabel("$P_\\mathrm{{M}}(x_{})$".format(nn))

nn = 9
x_tail = np.arange(tree.grid.n[nn])[45:]
ax2.plot(x_tail, P_sum_full[nn][45:], '.-')
ax2.plot(x_tail, P_sum_p1_r_5_10_10_s9[45:], 'v')
ax2.plot(x_tail, P_sum_p1_r_5_15_25_s9[45:], '^')
# The x label follows the plotted species (it used to be hard-coded to x_1).
ax2.set_xlabel("$x_{}$".format(nn))
ax2.yaxis.tick_right()
ax2.yaxis.set_ticks_position("both")
ax2.yaxis.set_label_position("right")
ax2.set_ylabel("$P_\\mathrm{{M}}(x_{})$".format(nn))
fig.legend(*ax1.get_legend_handles_labels(), ncols=3, loc="upper center")

plt.savefig("plots/marginal_ttn.pdf");

### SSA

In [None]:
nn = 0
# Marginal distributions of the SSA runs against the TTN reference:
# species 0 on the left, species 9 on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))

# (solution object, marginal distributions, marker, legend label) per SSA run.
ssa_runs = [
    (ssa_1e7_sol, P_sum_ssa_1e7, 'v', "$10^7$ runs"),
    (ssa_1e6_sol, P_sum_ssa_1e6, '<', "$10^6$ runs"),
    (ssa_1e5_sol, P_sum_ssa_1e5, '^', "$10^5$ runs"),
    (ssa_1e4_sol, P_sum_ssa_1e4, '>', "$10^4$ runs"),
]

# The reference curve is P_sum_full on both panels so that the single legend
# entry refers to one data source (the left panel previously plotted the
# matrix-integrator result P_sum_matrix under the same label).
ax1.plot(np.arange(tree.grid.n[nn]), P_sum_full[nn], '.-', label="$P_\\mathrm{{M}}^\\mathrm{TTN,ref}$")
for sol, P_sum_ssa, marker, run_label in ssa_runs:
    ax1.plot(np.arange(sol.n[nn]) + sol.n_min[nn], P_sum_ssa[-1][nn], marker, label=run_label)
ax1.set_xlabel("$x_{}$".format(nn))
ax1.set_ylabel("$P_\\mathrm{{M}}(x_{})$".format(nn))

nn = 9
ax2.plot(np.arange(tree.grid.n[nn])[45:], P_sum_full[nn][45:], '.-')
for sol, P_sum_ssa, marker, _ in ssa_runs:
    ax2.plot(np.arange(sol.n[nn]) + sol.n_min[nn], P_sum_ssa[-1][nn], marker)
# The x label follows the plotted species (it used to be hard-coded to x_1).
ax2.set_xlabel("$x_{}$".format(nn))
ax2.yaxis.tick_right()
ax2.yaxis.set_ticks_position("both")
ax2.yaxis.set_label_position("right")
ax2.set_ylabel("$P_\\mathrm{{M}}(x_{})$".format(nn))
fig.legend(*ax1.get_legend_handles_labels(), ncols=5, loc="upper center")

plt.savefig("plots/ssa_ttn.pdf");

## Comparison of TTN and SSA results with matrix integrator

### Load data

#### Get walltimes

In [None]:
# Wall time of the partition-1 run with ranks r = (5, 10, 10).
# Note: `time_series` is a scratch name that is rebound in many cells.
time_series = TimeSeries("output/bax_p1_r5-10-10_i_tau5e-2")
walltime_p1_r_5_10_10 = time_series.getWallTime()

In [None]:
# Wall time of the partition-1 run with ranks r = (5, 15, 25).
time_series = TimeSeries("output/bax_p1_r5-15-25_i_tau5e-2")
walltime_p1_r_5_15_25 = time_series.getWallTime()

In [None]:
# Wall time of the two-partition ("pfull") run.
time_series = TimeSeries("output/bax_pfull_r5_i_tau5e-2")
walltime_pfull = time_series.getWallTime()

In [None]:
# Hard-coded wall times for the runs that have no TimeSeries output here.
# Presumably in seconds (the bar chart axis is labelled "wall time [s]") - TODO confirm.
# NOTE(review): the provenance of these numbers is not recorded in the notebook.
walltime_ssa_1e6 = 358
walltime_ssa_2e6 = 660
walltime_ssa_1e7 = 2898
# Not referenced by any cell visible in this notebook.
walltime_matrix = 1.3e5

In [None]:
# Data for the horizontal wall-time bar chart: one bar per run, TTN runs first.
walltimes = [
    walltime_pfull, walltime_p1_r_5_15_25, walltime_p1_r_5_10_10,
    walltime_ssa_1e7, walltime_ssa_2e6, walltime_ssa_1e6,
]
# The backslash of "\mathcal" is doubled because "\m" is not a valid escape
# sequence; a raw string is not an option here since "\n" must stay a newline.
labels = [
    "2 partitions",
    "$\\mathcal{{P}}_1$\n$r = (5, 15, 25)$",
    "$\\mathcal{{P}}_1$\n$r = (5, 10, 10)$",
    "$10^7$ runs",
    "$2 \\times 10^6$ runs",
    "$10^6$ runs",
]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# First cycle colour for the three TTN bars, second for the three SSA bars.
color = [colors[0]] * 3 + [colors[1]] * 3
# Labels starting with "_" are skipped by the legend, leaving one entry per method.
bar_labels = ['PS-TTN integrator', '_PS-TTN integrator', '_PS-TTN integrator', 'SSA', '_SSA', '_SSA']

#### Error between TTN integrator and the matrix integrator reference solution

In [None]:
# Marginal-distribution error of the partition-1 runs for four rank choices;
# only the first return value of calculateMarginalDistributionError is kept.
p1_errors = []
for ranks in ("5-5-5", "5-10-10", "5-15-15", "5-15-25"):
    dlr_err, _ = calculateMarginalDistributionError(
        f"output/bax_p1_r{ranks}_i_tau5e-2/output_t2900.nc",
        P_sum_full, P_sum_ssa_1e7[-1], ssa_1e7_sol,
    )
    p1_errors.append(dlr_err)

(
    DLR_marginal_err_r_5_5_5,
    DLR_marginal_err_r_5_10_10,
    DLR_marginal_err_r_5_15_15,
    DLR_marginal_err_r_5_15_25,
) = p1_errors

#### Error between SSA and the SSA reference solution

In [None]:
# Maximum-norm error of each smaller SSA run against the 10^7-run reference,
# one value per species.
d = len(ssa_1e7_sol.n)
SSA_marginal_err_1e4 = np.zeros(d)
SSA_marginal_err_1e5 = np.zeros(d)
SSA_marginal_err_1e6 = np.zeros(d)
SSA_marginal_err_2e6 = np.zeros(d)

# (solution object, marginal distributions, output array) per SSA run.
comparisons = [
    (ssa_1e4_sol, P_sum_ssa_1e4, SSA_marginal_err_1e4),
    (ssa_1e5_sol, P_sum_ssa_1e5, SSA_marginal_err_1e5),
    (ssa_1e6_sol, P_sum_ssa_1e6, SSA_marginal_err_1e6),
    (ssa_2e6_sol, P_sum_ssa_2e6, SSA_marginal_err_2e6),
]

for i in range(d):
    reference = P_sum_ssa_1e7[-1][i]
    for sol, P_sum_ssa, err in comparisons:
        # Offset of this run's first stored population inside the reference array.
        n_start = sol.n_min[i] - ssa_1e7_sol.n_min[i]
        err[i] = np.linalg.norm(
            reference[n_start : n_start+sol.n[i]] - P_sum_ssa[-1][i][:ssa_1e7_sol.n[i]],
            np.inf,
        )

#### Error between the matrix integrator reference solution and the SSA reference solution

In [None]:
# Maximum-norm difference between the marginals in P_sum_full and three SSA
# runs, one value per species.
d = len(ssa_1e7_sol.n)
SSA_DLR_marginal_err_ref_1e6 = np.zeros(d)
SSA_DLR_marginal_err_ref_2e6 = np.zeros(d)
SSA_DLR_marginal_err_ref_1e7 = np.zeros(d)

# (solution object, marginal distributions, output array) per SSA run.
comparisons = [
    (ssa_1e6_sol, P_sum_ssa_1e6, SSA_DLR_marginal_err_ref_1e6),
    (ssa_2e6_sol, P_sum_ssa_2e6, SSA_DLR_marginal_err_ref_2e6),
    (ssa_1e7_sol, P_sum_ssa_1e7, SSA_DLR_marginal_err_ref_1e7),
]
for i in range(d):
    for sol, P_sum_ssa, err in comparisons:
        # Restrict P_sum_full to the population window covered by the SSA run.
        first = sol.n_min[i]
        err[i] = np.linalg.norm(
            P_sum_full[i][first : first+sol.n[i]] - P_sum_ssa[-1][i][:grid.n[i]],
            np.inf,
        )

### Plots

In [None]:
# Absolute mass error over time for the partition-1 run with r = (5, 15, 25).
time_series = TimeSeries("output/bax_p1_r5-15-25_i_tau5e-2")
mass_err = np.abs(time_series.getMassErr())
fig, ax = plt.subplots()
ax.plot(time_series.time, mass_err, '-')
ax.set_xlabel("time $t$")
# Raw string: "\D" is not a valid escape sequence in a normal literal.
ax.set_ylabel(r"$\Delta m$")
ax.set_yscale("log")
ax.set_ylim([8e-9, 1e-2])

plt.savefig("plots/mass_err_ttn_r5-15-25.pdf");

In [None]:
# Four-panel summary: TTN error per rank (top left), SSA error per run count
# (top right), reference-vs-reference error (bottom left), wall times (bottom right).
# NOTE: DLR_marginal_err_r_5_15_25 is reassigned by the "Inset" cell further
# down; re-run the error cell above before re-running this one out of order.
fig, axs = plt.subplots(2, 2, figsize=(12, 8))
species = np.arange(grid.d)

# Raw strings below: "\m" is not a valid escape sequence in a normal literal.
ttn_errors = [
    (DLR_marginal_err_r_5_5_5, "$r = (5, 5, 5)$"),
    (DLR_marginal_err_r_5_10_10, "$r = (5, 10, 10)$"),
    (DLR_marginal_err_r_5_15_15, "$r = (5, 15, 15)$"),
    (DLR_marginal_err_r_5_15_25, "$r = (5, 15, 25)$"),
]
for err, rank_label in ttn_errors:
    axs[0, 0].plot(species, err, '.-', label=rank_label)
axs[0, 0].set_yscale("log")
axs[0, 0].set_xlabel("species $S_i$")
axs[0, 0].legend()
axs[0, 0].set_ylabel(r"$\max_{{x_i}}|P_M(x_i)-P_M(x_i)^\mathrm{{TTN,ref}}|$")
axs[0, 0].set_ylim([1e-7, 1e-1])

ssa_errors = [
    (SSA_marginal_err_1e4, "$10^4$ runs"),
    (SSA_marginal_err_1e5, "$10^5$ runs"),
    (SSA_marginal_err_1e6, "$10^6$ runs"),
    (SSA_marginal_err_2e6, "$2 \\times 10^6$ runs"),
]
for err, run_label in ssa_errors:
    axs[0, 1].plot(species, err, '.-', label=run_label)
axs[0, 1].set_yscale("log")
axs[0, 1].set_xlabel("species $S_i$")
axs[0, 1].legend()
axs[0, 1].set_ylabel(r"$\max_{{x_i}}|P_M(x_i)-P_M(x_i)^\mathrm{{SSA,ref}}|$")
axs[0, 1].set_ylim([1e-7, 1e-1])

# The three curves of this panel are labelled and given a legend so they can
# be told apart (their colours do not match the run counts of the panel above).
reference_errors = [
    (SSA_DLR_marginal_err_ref_1e6, "$10^6$ runs"),
    (SSA_DLR_marginal_err_ref_2e6, "$2 \\times 10^6$ runs"),
    (SSA_DLR_marginal_err_ref_1e7, "$10^7$ runs"),
]
for err, run_label in reference_errors:
    axs[1, 0].plot(species, err, '.-', label=run_label)
axs[1, 0].set_yscale("log")
axs[1, 0].set_xlabel("species $S_i$")
axs[1, 0].legend()
axs[1, 0].set_ylabel(r"$\max_{{x_i}}|P_M(x_i)^\mathrm{{TTN,ref}}-P_M(x_i)^\mathrm{{SSA,ref}}|$")

# `labels`, `walltimes`, `bar_labels` and `color` come from the wall-time cell above.
axs[1, 1].barh(labels, walltimes, label=bar_labels, color=color)
axs[1, 1].set_xscale("log")
axs[1, 1].set_xlabel(r"wall time [$\mathrm{{s}}$]")
axs[1, 1].legend()
plt.subplots_adjust(wspace=.3, hspace=.3)

plt.savefig("plots/comparison_ttn_ssa.pdf");

## Error depending on time step size

In [None]:
# Maximum mass error for a sequence of time step sizes and two rank choices.
dm_max_r_5_15_25 = []
dm_max_r_5_15_15 = []
tau = []

for tau_tag in ("5e-1", "2e-1", "1e-1", "5e-2", "2e-2", "1e-2"):
    time_series = TimeSeries(f"output/bax_p1_r5-15-25_i_tau{tau_tag}")
    dm_max_r_5_15_25.append(np.abs(time_series.getMaxMassErr()))

    time_series = TimeSeries(f"output/bax_p1_r5-15-15_i_tau{tau_tag}")
    dm_max_r_5_15_15.append(np.abs(time_series.getMaxMassErr()))
    # The step size is read from the r = (5, 15, 15) run of each pair.
    tau.append(time_series.getTau())

### Plot

In [None]:
# Maximum mass error versus time step size on log-log axes, with a cubic
# guide line for comparison.
fig, ax = plt.subplots()
t = np.linspace(0.01, 0.5)

ax.loglog(tau, dm_max_r_5_15_15, '.-', label="$r=(5, 15, 15)$")
ax.loglog(tau, dm_max_r_5_15_25, '.-', label="$r=(5, 15, 25)$")
# Guide line 10 * tau^3; the factor 10 only shifts it vertically.
ax.loglog(t, t**3*10, 'k:', label="$\\tau^3$")
ax.set_xlabel("time step size $\\tau$")
# Raw string: "\m" and "\D" are not valid escape sequences in a normal literal.
ax.set_ylabel(r"$\max(\Delta m)$")
fig.legend(loc="upper center", ncols=3)

plt.savefig("plots/time_mass_err_ttn.pdf");

## Comparison with deterministic solution

In [None]:
# Moments of the partition-1, r = (5, 15, 25) run. The plotting cell below
# indexes the result as concentrations[0] and concentrations[1].
# NOTE: `t` rebinds the name used for the guide-line abscissa in the previous plot.
time_series = TimeSeries("output/bax_p1_r5-15-25_i_tau5e-2")
concentrations = time_series.calculateMoments()
t = time_series.time

In [None]:
# Deterministic counterpart: solve the PySB model on the same time grid `t`
# as the TTN time series loaded in the previous cell.
# NOTE(review): these imports sit mid-notebook; consider moving them to the import cell.
from pysb.integrate import odesolve
from scripts.models.bax_pysb import model
concentrations_ode = odesolve(model, t)

In [None]:
# TTN result (solid line with a shaded +/- one deviation band) against the
# ODE solution (dashed): species 0-1 on the left, species 2-4 on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7, 4))
# Presumably concentrations[0] is the first and concentrations[1] the second
# moment, which makes this the standard deviation - TODO confirm.
deviation = np.sqrt(concentrations[1]-concentrations[0]**2)
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# (axes, index of the first species on that axes, PySB observable names)
panels = [
    (ax1, 0, ["oBax1", "oBax2"]),
    (ax2, 2, ["oBax3", "oBax4", "oBax5"]),
]
for ax, first_species, observables in panels:
    for i, o in enumerate(observables, start=first_species):
        mean = concentrations[0][:, i]
        ax.plot(t, mean, '-', label="$S_{}$".format(i), color=colors[i], alpha=0.7)
        ax.fill_between(t, mean-deviation[:, i], mean+deviation[:, i], color=colors[i], alpha=.2)
        ax.plot(t, concentrations_ode[o], '--', color=colors[i], alpha=1.0)

# Raw string: "\l" is not a valid escape sequence in a normal literal.
ax1.set_ylabel(r"$\langle x_i(t) \rangle$")
ax1.set_ylim([0.0, 45.0])

ax2.set_ylim([0.0, 4])
ax2.yaxis.tick_right()
ax2.yaxis.set_ticks_position("both")
ax2.yaxis.set_label_position("right")
plt.setp((ax1, ax2), xlabel="$t$", xlim=[0.0, 145.0], xticks=[0.0, 50.0, 100.0, 145.0]);

# Merge the legend entries of both panels into one figure legend. Dedicated
# names are used so the notebook-wide `labels` list is not overwritten.
legend_handles, legend_labels = [], []
for ax in fig.axes:
    handles, ax_labels = ax.get_legend_handles_labels()
    legend_handles += handles
    legend_labels += ax_labels
fig.legend(legend_handles, legend_labels, ncols=5, loc="upper center")

plt.savefig("plots/concentrations.pdf");

In [None]:
# Absolute difference between the TTN result and the ODE solution for every
# observable of the PySB model.
fig, ax = plt.subplots(figsize=(6, 4), layout='constrained')
# Line style per group of four curves: 0-3 solid, 4-7 dashed, 8 and up dash-dotted.
linestyles = ('-', '--', '-.')
for i, o in enumerate(model.observables):
    ax.plot(
        t,
        np.abs((concentrations[0][:, i]-concentrations_ode[o.name])),
        linestyle=linestyles[min(i // 4, 2)],
        label="$S_{{{}}}$".format(i),
    )
ax.set_xlabel("$t$")
# Raw string: "\m" is not a valid escape sequence in a normal literal.
ax.set_ylabel(r"$|c_i^\mathrm{{CME}}(t) - c_i^\mathrm{{ODE}}(t)|$");
fig.legend(*ax.get_legend_handles_labels(), loc="outside center right")

plt.savefig("plots/concentrations_err.pdf");

## Error depending on rank and partition

### Load data

#### Partition 0

In [None]:
# Partition 0: pair of marginal-distribution errors for each of the four rank
# choices (first suffix digit = rank index, second = partition).
(
    (DLR_marginal_err00, SSA_marginal_err00),
    (DLR_marginal_err10, SSA_marginal_err10),
    (DLR_marginal_err20, SSA_marginal_err20),
    (DLR_marginal_err30, SSA_marginal_err30),
) = [
    calculateMarginalDistributionError(
        f"output/bax_p0_r{ranks}_i_tau5e-2/output_t2900.nc",
        P_sum_full, P_sum_ssa_1e7[-1], ssa_1e7_sol,
    )
    for ranks in ("5-5-5", "5-10-10", "5-15-15", "5-20-20")
]

#### Partition 1

In [None]:
# Partition 1: pair of marginal-distribution errors for each of the four rank
# choices (first suffix digit = rank index, second = partition).
(
    (DLR_marginal_err01, SSA_marginal_err01),
    (DLR_marginal_err11, SSA_marginal_err11),
    (DLR_marginal_err21, SSA_marginal_err21),
    (DLR_marginal_err31, SSA_marginal_err31),
) = [
    calculateMarginalDistributionError(
        f"output/bax_p1_r{ranks}_i_tau5e-2/output_t2900.nc",
        P_sum_full, P_sum_ssa_1e7[-1], ssa_1e7_sol,
    )
    for ranks in ("5-5-5", "5-10-10", "5-15-15", "5-20-20")
]

#### Partition 2

In [None]:
# Partition 2: pair of marginal-distribution errors for each of the four rank
# choices (first suffix digit = rank index, second = partition).
(
    (DLR_marginal_err02, SSA_marginal_err02),
    (DLR_marginal_err12, SSA_marginal_err12),
    (DLR_marginal_err22, SSA_marginal_err22),
    (DLR_marginal_err32, SSA_marginal_err32),
) = [
    calculateMarginalDistributionError(
        f"output/bax_p2_r{ranks}_i_tau5e-2/output_t2900.nc",
        P_sum_full, P_sum_ssa_1e7[-1], ssa_1e7_sol,
    )
    for ranks in ("5-5-5", "5-10-10", "5-15-15", "5-20-20")
]

#### Inset

In [None]:
# Errors for the inset of the comparison plot (the inset itself is currently
# commented out in the plotting cell below).
# NOTE(review): the variable names say r = (5, 15, 25), but the path points to
# the r10-15-25 / tau1e-2 run, and the reference passed here is P_sum_matrix
# rather than P_sum_full - confirm which is intended.
# NOTE(review): this overwrites DLR_marginal_err_r_5_15_25 computed earlier for
# the r5-15-25 run, so re-running the earlier comparison plot after this cell
# shows different data.
DLR_marginal_err_r_5_15_25, SSA_marginal_err_r_5_15_25 = calculateMarginalDistributionError("output/bax_p1_r10-15-25_i_tau1e-2/output_t14500.nc", P_sum_matrix, P_sum_ssa_1e7[-1], ssa_1e7_sol)

### Plot

In [None]:
# Marginal-distribution errors for every rank choice (rows) and partition
# (columns). Dashed: error w.r.t. the SSA reference; dotted line with markers:
# error w.r.t. the TTN reference.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
# Raw strings: "\m" is not a valid escape sequence in a normal literal.
cols = [r"$\mathcal{{P}}_0$", r"$\mathcal{{P}}_1$", r"$\mathcal{{P}}_2$"]
rows = ["$r = (5, 5, 5)$", "$r = (5, 10, 10)$", "$r = (5, 15, 15)$", "$r = (5, 20, 20)$"]
labels = [r"$\max_{{x_i}}|P_M(x_i)-P_M(x_i)^\mathrm{{SSA,ref}}|$",
          r"$\max_{{x_i}}|P_M(x_i)-P_M(x_i)^\mathrm{{TTN,ref}}|$"]

# marginal_errors[row][col] = (error vs. SSA reference, error vs. TTN reference)
marginal_errors = [
    [(SSA_marginal_err00, DLR_marginal_err00), (SSA_marginal_err01, DLR_marginal_err01), (SSA_marginal_err02, DLR_marginal_err02)],
    [(SSA_marginal_err10, DLR_marginal_err10), (SSA_marginal_err11, DLR_marginal_err11), (SSA_marginal_err12, DLR_marginal_err12)],
    [(SSA_marginal_err20, DLR_marginal_err20), (SSA_marginal_err21, DLR_marginal_err21), (SSA_marginal_err22, DLR_marginal_err22)],
    [(SSA_marginal_err30, DLR_marginal_err30), (SSA_marginal_err31, DLR_marginal_err31), (SSA_marginal_err32, DLR_marginal_err32)],
]

fig, axs = plt.subplots(4, 3, figsize=(10, 8), sharex='col', sharey='row')

species = np.arange(tree.grid.d())
for r, row_errors in enumerate(marginal_errors):
    for c, (ssa_err, dlr_err) in enumerate(row_errors):
        # One colour per partition (column).
        axs[r, c].plot(species, ssa_err, '--', color=colors[c])
        axs[r, c].plot(species, dlr_err, '.-', color=colors[c])

# Legend handles come from the first panel: dashed line first, then the
# dotted line, matching the order of `labels`.
line0, line1 = axs[0, 0].get_lines()

plt.setp(axs, ylim=[-0.02, 0.42])

pad = 5 # in points

# Column headers above the top row.
for ax, col in zip(axs[0], cols):
    ax.annotate(col, xy=(0.5, 1), xytext=(0, pad+5),
                xycoords='axes fraction', textcoords='offset points',
                size=14, ha='center', va='baseline')

# Row headers to the left of the first column, anchored at the y-axis label.
for ax, row in zip(axs[:,0], rows):
    ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad-pad, 0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size=14, ha='right', va='center')

plt.setp(axs[-1, :], xlabel="species $S_i$", xticks=[0, 2, 4, 6, 8, 10])
fig.legend([line0, line1], labels, ncols=2, loc="upper center");
plt.subplots_adjust(hspace=.0, wspace=.0)

plt.savefig("plots/comparison_ttn.pdf");

## Mass error

### Load data

#### Partition 0

In [None]:
# Partition 0: absolute mass error over time for the four rank choices.
# The time axis `time` is taken from the first (r = 5-5-5) run.
p0_mass_errs = []
for ranks in ("5-5-5", "5-10-10", "5-15-15", "5-20-20"):
    time_series = TimeSeries(f"output/bax_p0_r{ranks}_i_tau5e-2")
    p0_mass_errs.append(np.abs(time_series.getMassErr()))
    if ranks == "5-5-5":
        time = time_series.time

mass_err00, mass_err10, mass_err20, mass_err30 = p0_mass_errs

#### Partition 1

In [None]:
# Partition 1: absolute mass error over time for the four rank choices
# (first suffix digit = rank index, second = partition).
time_series = TimeSeries("output/bax_p1_r5-5-5_i_tau5e-2")
mass_err01 = np.abs(time_series.getMassErr())

time_series = TimeSeries("output/bax_p1_r5-10-10_i_tau5e-2")
mass_err11 = np.abs(time_series.getMassErr())

# Rank index 2 is r = (5, 15, 15): the mass error plot titles this panel
# "r = (5, 15, 15)" and partitions 0 and 2 load their r5-15-15 runs for it.
# The r5-15-25 run was loaded here before, which mixed rank choices in one panel.
time_series = TimeSeries("output/bax_p1_r5-15-15_i_tau5e-2")
mass_err21 = np.abs(time_series.getMassErr())

time_series = TimeSeries("output/bax_p1_r5-20-20_i_tau5e-2")
mass_err31 = np.abs(time_series.getMassErr())

#### Partition 2

In [None]:
# Partition 2: absolute mass error over time for the four rank choices.
p2_mass_errs = []
for ranks in ("5-5-5", "5-10-10", "5-15-15", "5-20-20"):
    time_series = TimeSeries(f"output/bax_p2_r{ranks}_i_tau5e-2")
    p2_mass_errs.append(np.abs(time_series.getMassErr()))

mass_err02, mass_err12, mass_err22, mass_err32 = p2_mass_errs

### Plot

In [None]:
# Mass error over time: one panel per rank choice, one curve per partition.
fig, axs = plt.subplots(2, 2, figsize=(7, 5))

# Raw strings: "\m" is not a valid escape sequence in a normal literal.
partition_labels = [r"$\mathcal{{P}}_0$", r"$\mathcal{{P}}_1$", r"$\mathcal{{P}}_2$"]

# (panel title, mass errors for partitions 0, 1, 2) in row-major panel order.
panels = [
    ("$r = (5, 5, 5)$", (mass_err00, mass_err01, mass_err02)),
    ("$r = (5, 10, 10)$", (mass_err10, mass_err11, mass_err12)),
    ("$r = (5, 15, 15)$", (mass_err20, mass_err21, mass_err22)),
    ("$r = (5, 20, 20)$", (mass_err30, mass_err31, mass_err32)),
]
for ax, (title, mass_errs) in zip(axs.flat, panels):
    for mass_err, partition_label in zip(mass_errs, partition_labels):
        ax.plot(time, mass_err, label=partition_label)
    ax.set_title(title)

# Right column carries its ticks and label on the right-hand side.
for ax in axs[:, 1]:
    ax.yaxis.tick_right()
    ax.yaxis.set_ticks_position("both")
    ax.yaxis.set_label_position("right")

# Raw string: "\D" is not a valid escape sequence in a normal literal.
plt.setp(axs, xlabel="time $t$", ylabel=r"$\Delta m$", xlim=[0.0, 145.0], xticks=[0.0, 50.0, 100.0, 145.0], ylim=[8e-9, 1e-2], yscale="log")
plt.subplots_adjust(hspace=0.5)
# The figure legend is built from the first panel only.
fig.legend(*axs[0, 0].get_legend_handles_labels(), ncols=3, loc="upper center")

plt.savefig("plots/mass_err_comparison_ttn.pdf");

## Wall time

### Load data

#### Partition 0

In [None]:
# Partition 0: wall time for the four rank choices.
p0_walltimes = []
for ranks in ("5-5-5", "5-10-10", "5-15-15", "5-20-20"):
    time_series = TimeSeries(f"output/bax_p0_r{ranks}_i_tau5e-2")
    p0_walltimes.append(time_series.getWallTime())

walltime00, walltime10, walltime20, walltime30 = p0_walltimes

#### Partition 1

In [None]:
# Partition 1: wall time for the four rank choices.
p1_walltimes = []
for ranks in ("5-5-5", "5-10-10", "5-15-15", "5-20-20"):
    time_series = TimeSeries(f"output/bax_p1_r{ranks}_i_tau5e-2")
    p1_walltimes.append(time_series.getWallTime())

walltime01, walltime11, walltime21, walltime31 = p1_walltimes

#### Partition 2

In [None]:
# Partition 2: wall time for the four rank choices.
p2_walltimes = []
for ranks in ("5-5-5", "5-10-10", "5-15-15", "5-20-20"):
    time_series = TimeSeries(f"output/bax_p2_r{ranks}_i_tau5e-2")
    p2_walltimes.append(time_series.getWallTime())

walltime02, walltime12, walltime22, walltime32 = p2_walltimes

### Plot

In [None]:
# Wall time per partition: one bar chart per rank choice.
fig, axs = plt.subplots(2, 2, figsize=(7, 5))
# Raw string: "\m" is not a valid escape sequence in a normal literal.
labels = [r"$\mathcal{{P}}_{}$".format(i) for i in range(3)]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# One list per rank choice, ordered by partition 0, 1, 2.
walltime0 = [walltime00, walltime01, walltime02]
walltime1 = [walltime10, walltime11, walltime12]
walltime2 = [walltime20, walltime21, walltime22]
walltime3 = [walltime30, walltime31, walltime32]

panels = zip(
    axs.flat,
    (walltime0, walltime1, walltime2, walltime3),
    ("$r = (5, 5, 5)$", "$r = (5, 10, 10)$", "$r = (5, 15, 15)$", "$r = (5, 20, 20)$"),
)
for ax, walltime_per_partition, title in panels:
    ax.bar(labels, walltime_per_partition, color=colors)
    ax.set_title(title)

# Right column carries its ticks on the right-hand side.
for ax in axs[:, 1]:
    ax.yaxis.tick_right()
    ax.yaxis.set_ticks_position("both")
    ax.yaxis.set_label_position("right")

plt.setp(axs[:, 0], ylabel=r"wall time [$\mathrm{s}$]")
plt.subplots_adjust(hspace=0.3)

plt.savefig("plots/wall_time_comparison_ttn.pdf");