# BAX pore assembly

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np

from scripts.grid_class import GridParms
from scripts.tree_class import Tree
from scripts.notebooks.output_helper import *
from scripts.reference_solutions.ssa_helper import SSASol

# Apply the custom matplotlib style shipped with the notebooks in scripts/notebooks/.
plt.style.use("./scripts/notebooks/custom_style.mplstyle")
%matplotlib inline

In [None]:
# Point in state space that defines the sliced distributions P_S: one entry per
# species x_0, ..., x_10 (the partitions below use the species indices 0-10).
slice_vec = np.array([12, 9, 2, 1] + [0] * 5 + [50, 0])

## Load initial data

### Tree tensor integrator

#### Partition 0
(0 1 2)(((3 4 6 7)(5 8))(9 10))

#### Partition 1
((0 1 2)(3 4 5))((6 7 9)(8 10))

In [None]:
# Partition P1, ranks (5, 5, 5) (tau = 2e-2 per the directory name):
# sliced (P_slice_*) and marginal (P_sum_*) distributions of the stored snapshot.
tree = readTree("output/bax_p1_r5-5-5_i_tau2e-2/output_t7250.nc")
P_slice_p1_r_5_5_5, P_sum_p1_r_5_5_5 = tree.calculateObservables(slice_vec)

In [None]:
# Partition P1, ranks (5, 10, 10): sliced and marginal distributions.
tree = readTree("output/bax_p1_r5-10-10_i_tau2e-2/output_t7250.nc")
P_slice_p1_r_5_10_10, P_sum_p1_r_5_10_10 = tree.calculateObservables(slice_vec)

In [None]:
# Partition P1, ranks (5, 15, 15): sliced and marginal distributions.
tree = readTree("output/bax_p1_r5-15-15_i_tau2e-2/output_t7250.nc")
P_slice_p1_r_5_15_15, P_sum_p1_r_5_15_15 = tree.calculateObservables(slice_vec)

#### Partition 2
((0 1)(2 3 4))((5 6 7 8)(9 10))

#### Two partitions
(0 1 2 3 4)(5 6 7 8 9 10)

In [None]:
# Two-partition PS-TTN solution (rank 5 per the directory name). Its observables
# serve as the reference P^TTN,ref in the comparison plots below.
tree = readTree("output/bax_pfull_r5_i_tau2e-2/output_t7250.nc")
P_slice_full, P_sum_full = tree.calculateObservables(slice_vec)

### Matrix integrator
Two partitions, $r=5$, explicit Euler with variable time step size ($\min_i \tau_i = 1.0$) and 100 substeps.

In [None]:
# Matrix-integrator solution (two partitions, r = 5). Marginal and sliced
# distributions are computed while the dataset is still open.
with xr.open_dataset("output/bax_matrix/output_t100.nc") as ds:
    grid = GridInfo(ds)
    lr_sol = LRSol(ds, grid)
    P_sum_matrix = lr_sol.marginalDistributions()
    P_slice_matrix = lr_sol.slicedDistributions(slice_vec)

# Species indices passed to SSASol.calculateObservables in the SSA cells below
# (presumably selecting the species pair for its 2D observables -- confirm).
idx_2D = np.array([0, 1])

### SSA

In [None]:
# SSA ensemble with 10^7 runs (per the file name); used as the SSA reference.
# The discarded outputs are presumably the 2D observables selected by idx_2D.
ssa_1e7 = np.load("scripts/reference_solutions/bax_ssa_1e7.npy")
ssa_1e7_sol = SSASol(ssa_1e7)
P_sum_ssa_1e7, _, P_slice_ssa_1e7, _ = ssa_1e7_sol.calculateObservables(slice_vec, idx_2D)

In [None]:
# SSA ensemble with 2*10^6 runs: marginal and sliced distributions.
ssa_2e6 = np.load("scripts/reference_solutions/bax_ssa_2e6.npy")
ssa_2e6_sol = SSASol(ssa_2e6)
P_sum_ssa_2e6, _, P_slice_ssa_2e6, _ = ssa_2e6_sol.calculateObservables(slice_vec, idx_2D)

In [None]:
# SSA ensemble with 10^6 runs: marginal and sliced distributions.
ssa_1e6 = np.load("scripts/reference_solutions/bax_ssa_1e6.npy")
ssa_1e6_sol = SSASol(ssa_1e6)
P_sum_ssa_1e6, _, P_slice_ssa_1e6, _ = ssa_1e6_sol.calculateObservables(slice_vec, idx_2D)

In [None]:
# SSA ensemble with 10^5 runs: marginal and sliced distributions.
ssa_1e5 = np.load("scripts/reference_solutions/bax_ssa_1e5.npy")
ssa_1e5_sol = SSASol(ssa_1e5)
P_sum_ssa_1e5, _, P_slice_ssa_1e5, _ = ssa_1e5_sol.calculateObservables(slice_vec, idx_2D)

In [None]:
# SSA ensemble with 10^4 runs: marginal and sliced distributions.
ssa_1e4 = np.load("scripts/reference_solutions/bax_ssa_1e4.npy")
ssa_1e4_sol = SSASol(ssa_1e4)
P_sum_ssa_1e4, _, P_slice_ssa_1e4, _ = ssa_1e4_sol.calculateObservables(slice_vec, idx_2D)

## Sliced distributions

### TTN integrator

In [None]:
# Sliced distributions P_S(x_0) (left) and P_S(x_9) (right) of the two-partition
# solution and of the partition-P1 PS-TTN solutions.
# Top row: logarithmic scale (tails), bottom row: linear scale.
fig, axs = plt.subplots(2, 2, figsize=(7, 5), sharex='col')

# Raw strings: "\m" is not a valid escape sequence (SyntaxWarning on recent Python);
# the label text itself is unchanged.
ttn_slices = [
    (P_slice_full, '.-', "2 partitions"),
    (P_slice_p1_r_5_10_10, 'v', r"$\mathcal{{P}}_1$, $r = (5, 10, 10)$"),
    (P_slice_p1_r_5_15_15, '^', r"$\mathcal{{P}}_1$, $r = (5, 15, 15)$"),
]
# (species index, window of grid points shown) for the left and right column.
ttn_columns = [(0, slice(11, 26)), (9, slice(45, None))]

for col_idx, (species, window) in enumerate(ttn_columns):
    states = np.arange(tree.grid.n[species])[window]
    for row_idx in range(2):
        for P_slice, fmt, label in ttn_slices:
            # Only the top-left panel is labelled; it feeds the shared figure legend.
            extra = {"label": label} if (row_idx, col_idx) == (0, 0) else {}
            axs[row_idx, col_idx].plot(states, P_slice[species][window], fmt, **extra)

axs[0, 0].set_ylabel("$P_\\mathrm{{S}}(x_0)$")
axs[0, 0].set_yscale("log")
axs[0, 0].minorticks_off()

axs[1, 0].set_xlabel(xlabel="$x_0$")
axs[1, 0].set_ylabel("$P_\\mathrm{{S}}(x_0)$")
axs[1, 0].set_xticks(np.arange(11, 25+2, 2))
axs[1, 0].set_xlim([10, 26])

axs[0, 1].set_ylabel("$P_\\mathrm{{S}}(x_9)$")
axs[0, 1].set_yscale("log")
axs[0, 1].yaxis.tick_right()
axs[0, 1].yaxis.set_ticks_position("both")
axs[0, 1].yaxis.set_label_position("right")

axs[1, 1].set_xlabel(xlabel="$x_9$")
axs[1, 1].set_ylabel("$P_\\mathrm{{S}}(x_9)$")
axs[1, 1].set_xticks(np.arange(45, 55+2, 2))
axs[1, 1].yaxis.tick_right()
axs[1, 1].yaxis.set_ticks_position("both")
axs[1, 1].yaxis.set_label_position("right")
fig.legend(*axs[0, 0].get_legend_handles_labels(), ncols=3, loc="upper center")
plt.subplots_adjust(hspace=.0)

plt.savefig("plots/sliced_ttn.pdf");

In [None]:
# Reload the partition-P1 solutions.
# NOTE(review): these are the same files and observables as loaded earlier in the
# notebook. Besides recomputing them, this cell re-binds `tree` to the rank (5, 15, 15)
# tree (previously the two-partition tree), whose grid supplies the x-axes below.
tree = readTree("output/bax_p1_r5-5-5_i_tau2e-2/output_t7250.nc")
P_slice_p1_r_5_5_5, P_sum_p1_r_5_5_5 = tree.calculateObservables(slice_vec)

tree = readTree("output/bax_p1_r5-10-10_i_tau2e-2/output_t7250.nc")
P_slice_p1_r_5_10_10, P_sum_p1_r_5_10_10 = tree.calculateObservables(slice_vec)

tree = readTree("output/bax_p1_r5-15-15_i_tau2e-2/output_t7250.nc")
P_slice_p1_r_5_15_15, P_sum_p1_r_5_15_15 = tree.calculateObservables(slice_vec)

### SSA

In [None]:
# Sliced distributions P_S(x_0) (left) and P_S(x_9) (right) from SSA ensembles of
# increasing size, compared with the PS-TTN reference P_S^TTN,ref.
# Top row: logarithmic scale, bottom row: linear scale.
fig, axs = plt.subplots(2, 2, figsize=(7, 5), sharex='col')

# (SSA solution, sliced distributions, marker, legend label, window along x_0).
# SSA histograms start at their own offset n_min, so the x_0 windows are hard-coded
# per ensemble; along x_9 the whole SSA support is drawn.
ssa_slices = [
    (ssa_1e7_sol, P_slice_ssa_1e7, 'v', "$10^7$ runs", slice(16, 31)),
    (ssa_1e6_sol, P_slice_ssa_1e6, '<', "$10^6$ runs", slice(15, 30)),
    (ssa_1e5_sol, P_slice_ssa_1e5, '^', "$10^5$ runs", slice(14, 29)),
    (ssa_1e4_sol, P_slice_ssa_1e4, '>', "$10^4$ runs", slice(14, 29)),
]
# (species index, window of the PS-TTN reference) for the left and right column.
ssa_columns = [(0, slice(17, 33)), (9, slice(46, None))]

for col_idx, (species, ref_window) in enumerate(ssa_columns):
    ref_states = np.arange(tree.grid.n[species])[ref_window]
    for row_idx in range(2):
        panel = axs[row_idx, col_idx]
        # Only the top-left panel is labelled; it feeds the shared figure legend.
        labelled = (row_idx, col_idx) == (0, 0)
        ref_extra = {"label": "$P_\\mathrm{{S}}^\\mathrm{{TTN,ref}}$"} if labelled else {}
        panel.plot(ref_states, P_slice_full[species][ref_window], '.-', **ref_extra)
        for sol, P_slice, fmt, label, x0_window in ssa_slices:
            window = x0_window if species == 0 else slice(None)
            ssa_states = (np.arange(sol.n[species]) + sol.n_min[species])[window]
            extra = {"label": label} if labelled else {}
            panel.plot(ssa_states, P_slice[-1][species][window], fmt, **extra)

axs[0, 0].set_ylabel("$P_\\mathrm{{S}}(x_0)$")
axs[0, 0].set_yscale("log")

axs[1, 0].set_xlabel(xlabel="$x_0$")
axs[1, 0].set_ylabel("$P_\\mathrm{{S}}(x_0)$")
axs[1, 0].set_xticks(np.arange(17, 33+2, 2))
axs[1, 0].set_xlim([16, 33])

axs[0, 1].set_ylabel("$P_\\mathrm{{S}}(x_9)$")
axs[0, 1].set_yscale("log")

axs[1, 1].set_xlabel(xlabel="$x_9$")
axs[1, 1].set_ylabel("$P_\\mathrm{{S}}(x_9)$")

# Right column: y-ticks and labels on the right-hand side.
for panel in (axs[0, 1], axs[1, 1]):
    panel.yaxis.tick_right()
    panel.yaxis.set_ticks_position("both")
    panel.yaxis.set_label_position("right")
fig.legend(*axs[0, 0].get_legend_handles_labels(), loc="upper center", ncols=5)
plt.subplots_adjust(hspace=.0)

plt.savefig("plots/sliced_ssa.pdf");

## Comparison of sliced distribution between TTN integrator and SSA

In [None]:
# Sliced distribution P_S(x_nn): partition-P1 PS-TTN solutions (left) and SSA
# ensembles (right), each compared with the PS-TTN reference P_S^TTN,ref.
# Top row: logarithmic scale, bottom row: linear scale.
fig, axs = plt.subplots(2, 2, figsize=(9, 6))
nn = 0
lim = [17, 32]

intvl = slice(lim[0], lim[1])
# Every PS-TTN curve indexes species nn (not a hard-coded 0) so it always matches
# the reference curve and the axis labels.
axs[0, 0].plot(np.arange(tree.grid.n[nn])[intvl], P_slice_full[nn][intvl], '.-', label="$P_\\mathrm{{S}}^\\mathrm{{TTN,ref}}$")
axs[0, 0].plot(np.arange(tree.grid.n[nn])[intvl], P_slice_p1_r_5_15_15[nn][intvl], 'v', label="$r = (5, 15, 15)$")
axs[0, 0].plot(np.arange(tree.grid.n[nn])[intvl], P_slice_p1_r_5_10_10[nn][intvl], '<', label="$r = (5, 10, 10)$")
axs[0, 0].plot(np.arange(tree.grid.n[nn])[intvl], P_slice_p1_r_5_5_5[nn][intvl], '^', label="$r = (5, 5, 5)$")
axs[0, 0].set_ylabel("$P_\\mathrm{{S}}(x_{})$".format(nn))
axs[0, 0].set_yscale("log")
axs[0, 0].set_title("PS-TTN integrator")
axs[0, 0].legend()

axs[1, 0].plot(np.arange(tree.grid.n[nn])[intvl], P_slice_full[nn][intvl], '.-')
axs[1, 0].plot(np.arange(tree.grid.n[nn])[intvl], P_slice_p1_r_5_15_15[nn][intvl], 'v')
axs[1, 0].plot(np.arange(tree.grid.n[nn])[intvl], P_slice_p1_r_5_10_10[nn][intvl], '<')
axs[1, 0].plot(np.arange(tree.grid.n[nn])[intvl], P_slice_p1_r_5_5_5[nn][intvl], '^')
axs[1, 0].set_xlabel(xlabel="$x_{}$".format(nn))
axs[1, 0].set_ylabel("$P_\\mathrm{{S}}(x_{})$".format(nn))
axs[1, 0].set_xticks(np.arange(17, 32+2, 2))

# NOTE(review): the SSA index windows below ([16:], [16:31], [15:30], [14:29]) are
# hard-coded for nn = 0 (each SSA histogram starts at its own n_min); adjust them
# when plotting another species.
axs[0, 1].plot(np.arange(tree.grid.n[nn])[intvl], P_slice_full[nn][intvl], '.-', label="$P_\\mathrm{{S}}^\\mathrm{{TTN,ref}}$")
axs[0, 1].plot((np.arange(ssa_1e7_sol.n[nn])+ssa_1e7_sol.n_min[nn])[16:], P_slice_ssa_1e7[-1][nn][16:], 'v', label="$10^7$ runs")
axs[0, 1].plot((np.arange(ssa_1e6_sol.n[nn])+ssa_1e6_sol.n_min[nn]), P_slice_ssa_1e6[-1][nn], '<', label="$10^6$ runs")
axs[0, 1].plot((np.arange(ssa_1e5_sol.n[nn])+ssa_1e5_sol.n_min[nn]), P_slice_ssa_1e5[-1][nn], '^', label="$10^5$ runs")
axs[0, 1].plot((np.arange(ssa_1e4_sol.n[nn])+ssa_1e4_sol.n_min[nn]), P_slice_ssa_1e4[-1][nn], '>', label="$10^4$ runs")
axs[0, 1].set_ylabel("$P_\\mathrm{{S}}(x_{})$".format(nn))
axs[0, 1].set_yscale("log")
axs[0, 1].set_title("SSA")
axs[0, 1].legend()

axs[1, 1].plot(np.arange(tree.grid.n[nn])[intvl], P_slice_full[nn][intvl], '.-')
axs[1, 1].plot((np.arange(ssa_1e7_sol.n[nn])+ssa_1e7_sol.n_min[nn])[16:31], P_slice_ssa_1e7[-1][nn][16:31], 'v')
axs[1, 1].plot((np.arange(ssa_1e6_sol.n[nn])+ssa_1e6_sol.n_min[nn])[15:30], P_slice_ssa_1e6[-1][nn][15:30], '<')
axs[1, 1].plot((np.arange(ssa_1e5_sol.n[nn])+ssa_1e5_sol.n_min[nn])[14:29], P_slice_ssa_1e5[-1][nn][14:29], '^')
axs[1, 1].plot((np.arange(ssa_1e4_sol.n[nn])+ssa_1e4_sol.n_min[nn])[14:29], P_slice_ssa_1e4[-1][nn][14:29], '>')
axs[1, 1].set_xlabel(xlabel="$x_{}$".format(nn))
axs[1, 1].set_ylabel("$P_\\mathrm{{S}}(x_{})$".format(nn))
axs[1, 1].set_xticks(np.arange(17, 32+2, 2))

# Right column: y-ticks and labels on the right-hand side.
axs[0, 1].yaxis.tick_right()
axs[0, 1].yaxis.set_ticks_position("both")
axs[0, 1].yaxis.set_label_position("right")
axs[1, 1].yaxis.tick_right()
axs[1, 1].yaxis.set_ticks_position("both")
axs[1, 1].yaxis.set_label_position("right")

# Common axis ranges per row (log row / linear row) and for all panels along x.
plt.setp(axs[0, :], ylim=[5e-12, 5e-2])
plt.setp(axs[1, :], ylim=[-0.0005, 0.0105])
plt.setp(axs, xlim=[16, 32])

plt.subplots_adjust(wspace=0.1, hspace=0.0)

plt.savefig("plots/sliced_ttn_ssa.pdf");

In [None]:
# Pointwise absolute errors of the sliced distribution P_S(x_0) on x_0 in [17, 32).
lim = [17, 32]
intvl = slice(*lim)

# Partition-P1 PS-TTN solutions vs. the two-partition PS-TTN reference.
P_slice_p1_r_5_5_5_err = abs(P_slice_p1_r_5_5_5[0][intvl] - P_slice_full[0][intvl])
P_slice_p1_r_5_10_10_err = abs(P_slice_p1_r_5_10_10[0][intvl] - P_slice_full[0][intvl])
P_slice_p1_r_5_15_15_err = abs(P_slice_p1_r_5_15_15[0][intvl] - P_slice_full[0][intvl])

# Smaller SSA ensembles vs. the 10^7-run ensemble. The index windows differ per
# ensemble, presumably to compensate for the different offsets n_min.
P_slice_ssa_1e6_err = abs(P_slice_ssa_1e6[-1][0][15:30] - P_slice_ssa_1e7[-1][0][16:31])
P_slice_ssa_1e5_err = abs(P_slice_ssa_1e5[-1][0][14:29] - P_slice_ssa_1e7[-1][0][16:31])
P_slice_ssa_1e4_err = abs(P_slice_ssa_1e4[-1][0][14:29] - P_slice_ssa_1e7[-1][0][16:31])

In [None]:
# Pointwise errors of P_S(x_0): partition-P1 PS-TTN solutions vs. the PS-TTN
# reference (left) and smaller SSA ensembles vs. the 10^7-run ensemble (right).
fig, axs = plt.subplots(1, 2, sharey="row")

axs[0].plot(np.arange(tree.grid.n[0])[intvl], P_slice_p1_r_5_15_15_err, 'v-', label="$r = (5, 15, 15)$")
axs[0].plot(np.arange(tree.grid.n[0])[intvl], P_slice_p1_r_5_10_10_err, '<-', label="$r = (5, 10, 10)$")
axs[0].plot(np.arange(tree.grid.n[0])[intvl], P_slice_p1_r_5_5_5_err, '^-', label="$r = (5, 5, 5)$")

# The 10^7-run ensemble is the SSA reference itself and is therefore not plotted.
axs[1].plot((np.arange(ssa_1e6_sol.n[0])+ssa_1e6_sol.n_min[0])[15:30], P_slice_ssa_1e6_err, 'v-', label="$10^6$ runs")
axs[1].plot((np.arange(ssa_1e5_sol.n[0])+ssa_1e5_sol.n_min[0])[14:29], P_slice_ssa_1e5_err, '<-', label="$10^5$ runs")
axs[1].plot((np.arange(ssa_1e4_sol.n[0])+ssa_1e4_sol.n_min[0])[14:29], P_slice_ssa_1e4_err, '^-', label="$10^4$ runs")

# Raw string: "\V" is not a valid escape sequence (SyntaxWarning on recent Python);
# the label text is unchanged.
# NOTE(review): the shared y-label names the PS-TTN reference, but the right panel's
# errors are taken w.r.t. the 10^7-run SSA ensemble.
axs[0].set_ylabel(r"$\Vert P_\mathrm{{S}}(x_0) - P_\mathrm{{S}}^\mathrm{{TTN,ref}}(x_0) \Vert$")
axs[0].set_title("PS-TTN integrator")
axs[0].legend()
axs[0].set_xticks(np.arange(17, 32+2, 2))
axs[0].set_xlim([16, 32])
axs[0].set_xlabel(xlabel="$x_0$")
axs[0].ticklabel_format(style='sci', axis='y', scilimits=(-2,2))

# Raised title, presumably to clear the scientific-notation offset text on the shared y-axis.
axs[1].set_title("SSA", y=1.05)
axs[1].legend()
axs[1].set_xticks(np.arange(17, 32+2, 2))
axs[1].set_xlim([16, 32])
axs[1].set_xlabel(xlabel="$x_0$")
plt.subplots_adjust(wspace=.0)

plt.savefig("plots/sliced_ttn_ssa_err.pdf");

## Marginal distributions

### TTN integrator

In [None]:
# Marginal distributions P_M(x_0) (left) and P_M(x_9) (right) of the partition-P1
# PS-TTN solutions compared with a two-partition solution.
nn = 0
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))
# NOTE(review): the left panel plots the matrix-integrator marginals (P_sum_matrix),
# the right panel the two-partition PS-TTN marginals (P_sum_full), both under the
# "2 partitions" legend entry -- confirm which reference is intended.
ax1.plot(np.arange(grid.n[nn]), P_sum_matrix[nn], '.-', label="2 partitions")
# Raw strings avoid the invalid "\m" escape; the label matches the plotted ranks.
ax1.plot(np.arange(tree.grid.n[nn]), P_sum_p1_r_5_10_10[nn], 'v', label=r"$\mathcal{{P}}_1$, $r = (5, 10, 10)$")
ax1.plot(np.arange(tree.grid.n[nn]), P_sum_p1_r_5_15_15[nn], '^', label=r"$\mathcal{{P}}_1$, $r = (5, 15, 15)$")
ax1.set_xlabel("$x_{}$".format(nn))
ax1.set_ylabel("$P_\\mathrm{{M}}(x_{})$".format(nn))

nn = 9
ax2.plot(np.arange(grid.n[nn])[45:], P_sum_full[nn][45:], '.-')
ax2.plot(np.arange(tree.grid.n[nn])[45:], P_sum_p1_r_5_10_10[nn][45:], 'v')
ax2.plot(np.arange(tree.grid.n[nn])[45:], P_sum_p1_r_5_15_15[nn][45:], '^')
# Derive the axis label from nn so it always names the plotted species.
ax2.set_xlabel("$x_{}$".format(nn))
ax2.yaxis.tick_right()
ax2.yaxis.set_ticks_position("both")
ax2.yaxis.set_label_position("right")
ax2.set_ylabel("$P_\\mathrm{{M}}(x_{})$".format(nn))
fig.legend(*ax1.get_legend_handles_labels(), ncols=3, loc="upper center")

plt.savefig("plots/marginal_ttn.pdf");

### SSA

In [None]:
# Marginal distributions P_M(x_0) (left) and P_M(x_9) (right) from SSA ensembles of
# increasing size, compared with a reference solution.
nn = 0
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(9, 4))
# NOTE(review): the curve labelled P_M^TTN,ref is the matrix-integrator marginal
# (P_sum_matrix), whereas the right panel uses P_sum_full (and the sliced figures label
# P_slice_full as the TTN reference) -- confirm which one is intended.
ax1.plot(np.arange(grid.n[nn]), P_sum_matrix[nn], '.-', label="$P_\\mathrm{{M}}^\\mathrm{TTN,ref}$")
ax1.plot((np.arange(ssa_1e7_sol.n[nn])+ssa_1e7_sol.n_min[nn]), P_sum_ssa_1e7[-1][nn], 'v', label="$10^7$ runs")
ax1.plot((np.arange(ssa_1e6_sol.n[nn])+ssa_1e6_sol.n_min[nn]), P_sum_ssa_1e6[-1][nn], '<', label="$10^6$ runs")
ax1.plot((np.arange(ssa_1e5_sol.n[nn])+ssa_1e5_sol.n_min[nn]), P_sum_ssa_1e5[-1][nn], '^', label="$10^5$ runs")
ax1.plot((np.arange(ssa_1e4_sol.n[nn])+ssa_1e4_sol.n_min[nn]), P_sum_ssa_1e4[-1][nn], '>', label="$10^4$ runs")
ax1.set_xlabel("$x_{}$".format(nn))
ax1.set_ylabel("$P_\\mathrm{{M}}(x_{})$".format(nn))

nn = 9
ax2.plot(np.arange(grid.n[nn])[45:], P_sum_full[nn][45:], '.-')
ax2.plot((np.arange(ssa_1e7_sol.n[nn])+ssa_1e7_sol.n_min[nn]), P_sum_ssa_1e7[-1][nn], 'v')
ax2.plot((np.arange(ssa_1e6_sol.n[nn])+ssa_1e6_sol.n_min[nn]), P_sum_ssa_1e6[-1][nn], '<')
ax2.plot((np.arange(ssa_1e5_sol.n[nn])+ssa_1e5_sol.n_min[nn]), P_sum_ssa_1e5[-1][nn], '^')
ax2.plot((np.arange(ssa_1e4_sol.n[nn])+ssa_1e4_sol.n_min[nn]), P_sum_ssa_1e4[-1][nn], '>')
# Derive the axis label from nn so it always names the plotted species.
ax2.set_xlabel("$x_{}$".format(nn))
ax2.yaxis.tick_right()
ax2.yaxis.set_ticks_position("both")
ax2.yaxis.set_label_position("right")
ax2.set_ylabel("$P_\\mathrm{{M}}(x_{})$".format(nn))
fig.legend(*ax1.get_legend_handles_labels(), ncols=5, loc="upper center")

plt.savefig("plots/ssa_ttn.pdf");

## Comparison of TTN and SSA results with matrix integrator

### Load data

#### Get walltimes

In [None]:
# Wall time of the partition-P1, rank (5, 5, 5) run (tau = 2e-2 per the directory name).
time_series = TimeSeries("output/bax_p1_r5-5-5_i_tau2e-2")
walltime_p1_r_5_5_5 = time_series.getWallTime()

In [None]:
# Wall time of the partition-P1, rank (5, 10, 10) run.
time_series = TimeSeries("output/bax_p1_r5-10-10_i_tau2e-2")
walltime_p1_r_5_10_10 = time_series.getWallTime()

In [None]:
# Wall time of the partition-P1, rank (5, 15, 15) run.
time_series = TimeSeries("output/bax_p1_r5-15-15_i_tau2e-2")
walltime_p1_r_5_15_15 = time_series.getWallTime()

In [None]:
# Wall time of the partition-P1, rank (10, 15, 25) run (tau = 1e-2 per the directory name).
time_series = TimeSeries("output/bax_p1_r10-15-25_i_tau1e-2")
walltime_p1_r_10_15_25 = time_series.getWallTime()

In [None]:
# Wall time of the two-partition run.
# NOTE(review): this reads the tau5e-2 run, whereas the two-partition solution loaded
# earlier comes from bax_pfull_r5_i_tau2e-2 -- confirm which run is intended.
# walltime_pfull does not appear in the wall-time lists `walltimes` / `wt` below.
time_series = TimeSeries("output/bax_pfull_r5_i_tau5e-2")
walltime_pfull = time_series.getWallTime()

In [None]:
# Hard-coded wall times in seconds (the bar charts label the axis "wall time [s]").
# NOTE(review): how these values were measured (SSA ensembles, matrix integrator) is
# not recorded here -- TODO document their provenance.
# walltime_ssa_2e6 and walltime_matrix do not appear in the wall-time lists below.
walltime_ssa_1e4 = 84
walltime_ssa_1e5 = 129
walltime_ssa_1e6 = 358
walltime_ssa_2e6 = 660
walltime_ssa_1e7 = 2898
walltime_matrix = 1.3e5

In [None]:
# Bar-chart data for the wall-time panels of the comparison figures.
# SSA ensembles use the second colour of the style's colour cycle, PS-TTN runs the first.
walltimes = [walltime_ssa_1e6, walltime_ssa_1e5, walltime_ssa_1e4, walltime_p1_r_10_15_25, walltime_p1_r_5_15_15, walltime_p1_r_5_10_10, walltime_p1_r_5_5_5]
labels_walltime = ["$10^6$ runs", "$10^5$ runs", "$10^4$ runs", "$r = (10, 15, 25)$", "$r = (5, 15, 15)$", "$r = (5, 10, 10)$", "$r = (5, 5, 5)$"]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
color = [colors[1], colors[1], colors[1], colors[0], colors[0], colors[0], colors[0]]
# Labels starting with "_" are excluded from the legend, so each method appears once.
bar_labels = ['SSA', '_SSA', '_SSA', 'PS-TTN integrator', '_PS-TTN integrator', '_PS-TTN integrator', '_PS-TTN integrator']

In [None]:
# Horizontal bar chart of the wall times of the SSA ensembles and the partition-P1
# PS-TTN runs (logarithmic axis).
wt = [walltime_ssa_1e7, walltime_ssa_1e6, walltime_ssa_1e5, walltime_ssa_1e4, walltime_p1_r_5_15_15, walltime_p1_r_5_10_10, walltime_p1_r_5_5_5]
labels_wt = ["$10^7$ runs", "$10^6$ runs", "$10^5$ runs", "$10^4$ runs", "$r = (5, 15, 15)$", "$r = (5, 10, 10)$", "$r = (5, 5, 5)$"]
b_color = [colors[1], colors[1], colors[1], colors[1], colors[0], colors[0], colors[0]]
# Labels starting with "_" are excluded from the legend, so each method appears once.
b_labels = ['SSA', '_SSA', '_SSA', '_SSA', 'PS-TTN integrator', '_PS-TTN integrator', '_PS-TTN integrator']

fig, ax = plt.subplots()
ax.barh(labels_wt, wt, label=b_labels, color=b_color)
ax.set_xscale("log")
# Raw string: "\m" is not a valid escape sequence; the label text is unchanged.
ax.set_xlabel(r"wall time [$\mathrm{{s}}$]")
fig.legend(loc="upper center", ncols=2)

plt.savefig("plots/comparison_walltime_ttn_ssa.pdf", bbox_inches="tight");

#### Error between TTN integrator and the matrix integrator reference solution

In [None]:
# Reload the partition-P1 solutions (same files as earlier) plus the rank (10, 15, 25)
# solution, under the sliced_* / marginal_* names used by the error computations below.
# The rank (10, 15, 25) run uses tau = 1e-2; snapshot 14500 presumably corresponds to
# the same time as snapshot 7250 at tau = 2e-2 (14500 * 1e-2 = 7250 * 2e-2) -- confirm.
tree = readTree("output/bax_p1_r5-5-5_i_tau2e-2/output_t7250.nc")
sliced_r_5_5_5, marginal_r_5_5_5 = tree.calculateObservables(slice_vec)

tree = readTree("output/bax_p1_r5-10-10_i_tau2e-2/output_t7250.nc")
sliced_r_5_10_10, marginal_r_5_10_10 = tree.calculateObservables(slice_vec)

tree = readTree("output/bax_p1_r5-15-15_i_tau2e-2/output_t7250.nc")
sliced_r_5_15_15, marginal_r_5_15_15 = tree.calculateObservables(slice_vec)

tree = readTree("output/bax_p1_r10-15-25_i_tau1e-2/output_t14500.nc")
sliced_r_10_15_25, marginal_r_10_15_25 = tree.calculateObservables(slice_vec)

In [None]:
# Per-species errors of the partition-P1 PS-TTN solutions with respect to the
# two-partition reference (P_slice_full / P_sum_full).
sliced_err_r_5_5_5 = np.zeros(len(sliced_r_5_5_5))
marginal_err_r_5_5_5 = np.zeros(len(marginal_r_5_5_5))
sliced_err_r_5_10_10 = np.zeros(len(sliced_r_5_5_5))
marginal_err_r_5_10_10 = np.zeros(len(marginal_r_5_5_5))
sliced_err_r_5_15_15 = np.zeros(len(sliced_r_5_5_5))
marginal_err_r_5_15_15 = np.zeros(len(marginal_r_5_5_5))
sliced_err_r_10_15_25 = np.zeros(len(sliced_r_5_5_5))
marginal_err_r_10_15_25 = np.zeros(len(marginal_r_5_5_5))

# np.linalg.norm without `ord` is the Euclidean norm for 1D arrays (and the Frobenius
# norm for 2D arrays).
for i in range(tree.grid.d()):
    sliced_err_r_5_5_5[i] = np.linalg.norm(sliced_r_5_5_5[i] - P_slice_full[i])
    marginal_err_r_5_5_5[i] = np.linalg.norm(marginal_r_5_5_5[i] - P_sum_full[i])

    sliced_err_r_5_10_10[i] = np.linalg.norm(sliced_r_5_10_10[i] - P_slice_full[i])
    marginal_err_r_5_10_10[i] = np.linalg.norm(marginal_r_5_10_10[i] - P_sum_full[i])

    sliced_err_r_5_15_15[i] = np.linalg.norm(sliced_r_5_15_15[i] - P_slice_full[i])
    marginal_err_r_5_15_15[i] = np.linalg.norm(marginal_r_5_15_15[i] - P_sum_full[i])

    sliced_err_r_10_15_25[i] = np.linalg.norm(sliced_r_10_15_25[i] - P_slice_full[i])
    marginal_err_r_10_15_25[i] = np.linalg.norm(marginal_r_10_15_25[i] - P_sum_full[i])

# Backward-compatible aliases: the rank (10, 15, 25) errors were stored under the
# misleading names *_r_5_20_20 (no rank (5, 20, 20) run is loaded).
sliced_err_r_5_20_20 = sliced_err_r_10_15_25
marginal_err_r_5_20_20 = marginal_err_r_10_15_25

#### Error between SSA and the SSA reference solution

In [None]:
# Per-species errors of the smaller SSA ensembles with respect to the 10^7-run SSA
# reference. Each SSA histogram starts at its own offset n_min, so the reference is
# read from the offset difference n_start onwards.
d = len(ssa_1e7_sol.n)
SSA_marginal_err_1e4 = np.zeros(d)
SSA_marginal_err_1e5 = np.zeros(d)
SSA_marginal_err_1e6 = np.zeros(d)
SSA_marginal_err_2e6 = np.zeros(d)
SSA_sliced_err_1e4 = np.zeros(d)
SSA_sliced_err_1e5 = np.zeros(d)
SSA_sliced_err_1e6 = np.zeros(d)
SSA_sliced_err_2e6 = np.zeros(d)

# (tag, solution, marginals, sliced distributions, marginal errors, sliced errors)
ssa_ensembles = [
    ("1e4", ssa_1e4_sol, P_sum_ssa_1e4, P_slice_ssa_1e4, SSA_marginal_err_1e4, SSA_sliced_err_1e4),
    ("1e5", ssa_1e5_sol, P_sum_ssa_1e5, P_slice_ssa_1e5, SSA_marginal_err_1e5, SSA_sliced_err_1e5),
    ("1e6", ssa_1e6_sol, P_sum_ssa_1e6, P_slice_ssa_1e6, SSA_marginal_err_1e6, SSA_sliced_err_1e6),
    ("2e6", ssa_2e6_sol, P_sum_ssa_2e6, P_slice_ssa_2e6, SSA_marginal_err_2e6, SSA_sliced_err_2e6),
]

for i in range(d):
    for tag, sol, P_sum, P_slice, marginal_err, sliced_err in ssa_ensembles:
        n_start = sol.n_min[i] - ssa_1e7_sol.n_min[i]
        # A negative start would make the slice below wrap around and silently
        # compare misaligned states.
        if n_start < 0:
            raise ValueError(
                f"SSA {tag}: support of species {i} starts below the 10^7-run reference support"
            )
        ref_window = slice(n_start, n_start + sol.n[i])
        marginal_err[i] = np.linalg.norm(P_sum_ssa_1e7[-1][i][ref_window] - P_sum[-1][i][:ssa_1e7_sol.n[i]])
        sliced_err[i] = np.linalg.norm(P_slice_ssa_1e7[-1][i][ref_window] - P_slice[-1][i][:ssa_1e7_sol.n[i]])

#### Error between the matrix integrator reference solution and the SSA reference solution

In [None]:
# Per-species errors of the SSA ensembles (10^6, 2*10^6, 10^7 runs) with respect to
# the two-partition PS-TTN reference (P_sum_full / P_slice_full). The PS-TTN arrays are
# indexed from state 0, while each SSA histogram starts at n_min, so the reference is
# sliced from n_min onwards.
# NOTE(review): if an SSA support extended past the PS-TTN grid, the two slices would
# differ in length and the subtraction would typically fail to broadcast -- assumed
# not to happen.
d = len(ssa_1e7_sol.n)
SSA_DLR_marginal_err_ref_1e6 = np.zeros(d)
SSA_DLR_marginal_err_ref_2e6 = np.zeros(d)
SSA_DLR_marginal_err_ref_1e7 = np.zeros(d)
SSA_DLR_sliced_err_ref_1e6 = np.zeros(d)
SSA_DLR_sliced_err_ref_2e6 = np.zeros(d)
SSA_DLR_sliced_err_ref_1e7 = np.zeros(d)
for i in range(d):
    SSA_DLR_marginal_err_ref_1e6[i] = np.linalg.norm(P_sum_full[i][ssa_1e6_sol.n_min[i] : ssa_1e6_sol.n_min[i]+ssa_1e6_sol.n[i]] - P_sum_ssa_1e6[-1][i][:grid.n[i]])
    SSA_DLR_marginal_err_ref_2e6[i] = np.linalg.norm(P_sum_full[i][ssa_2e6_sol.n_min[i] : ssa_2e6_sol.n_min[i]+ssa_2e6_sol.n[i]] - P_sum_ssa_2e6[-1][i][:grid.n[i]])
    SSA_DLR_marginal_err_ref_1e7[i] = np.linalg.norm(P_sum_full[i][ssa_1e7_sol.n_min[i] : ssa_1e7_sol.n_min[i]+ssa_1e7_sol.n[i]] - P_sum_ssa_1e7[-1][i][:grid.n[i]])
    SSA_DLR_sliced_err_ref_1e6[i] = np.linalg.norm(P_slice_full[i][ssa_1e6_sol.n_min[i] : ssa_1e6_sol.n_min[i]+ssa_1e6_sol.n[i]] - P_slice_ssa_1e6[-1][i][:grid.n[i]])
    SSA_DLR_sliced_err_ref_2e6[i] = np.linalg.norm(P_slice_full[i][ssa_2e6_sol.n_min[i] : ssa_2e6_sol.n_min[i]+ssa_2e6_sol.n[i]] - P_slice_ssa_2e6[-1][i][:grid.n[i]])
    SSA_DLR_sliced_err_ref_1e7[i] = np.linalg.norm(P_slice_full[i][ssa_1e7_sol.n_min[i] : ssa_1e7_sol.n_min[i]+ssa_1e7_sol.n[i]] - P_slice_ssa_1e7[-1][i][:grid.n[i]])

#### Error between the TTN solution and the SSA reference solution

In [None]:
# Per-species errors of the partition-P1 PS-TTN solutions with respect to the
# 10^7-run SSA reference. The PS-TTN observables are sliced to the SSA support
# [n_min, n_min + n) before comparing.
# NOTE(review): like the cell above, this assumes the SSA support lies within the
# PS-TTN grid; otherwise the slices differ in length.
d = len(ssa_1e7_sol.n)
marginal_err_r_5_5_5_SSA = np.zeros(d)
marginal_err_r_5_10_10_SSA = np.zeros(d)
marginal_err_r_5_15_15_SSA = np.zeros(d)
marginal_err_r_10_15_25_SSA = np.zeros(d)
sliced_err_r_5_5_5_SSA = np.zeros(d)
sliced_err_r_5_10_10_SSA = np.zeros(d)
sliced_err_r_5_15_15_SSA = np.zeros(d)
sliced_err_r_10_15_25_SSA = np.zeros(d)

for i in range(d):
    marginal_err_r_5_5_5_SSA[i] = np.linalg.norm(marginal_r_5_5_5[i][ssa_1e7_sol.n_min[i] : ssa_1e7_sol.n_min[i]+ssa_1e7_sol.n[i]] - P_sum_ssa_1e7[-1][i][:grid.n[i]])
    marginal_err_r_5_10_10_SSA[i] = np.linalg.norm(marginal_r_5_10_10[i][ssa_1e7_sol.n_min[i] : ssa_1e7_sol.n_min[i]+ssa_1e7_sol.n[i]] - P_sum_ssa_1e7[-1][i][:grid.n[i]])
    marginal_err_r_5_15_15_SSA[i] = np.linalg.norm(marginal_r_5_15_15[i][ssa_1e7_sol.n_min[i] : ssa_1e7_sol.n_min[i]+ssa_1e7_sol.n[i]] - P_sum_ssa_1e7[-1][i][:grid.n[i]])
    marginal_err_r_10_15_25_SSA[i] = np.linalg.norm(marginal_r_10_15_25[i][ssa_1e7_sol.n_min[i] : ssa_1e7_sol.n_min[i]+ssa_1e7_sol.n[i]] - P_sum_ssa_1e7[-1][i][:grid.n[i]])

    sliced_err_r_5_5_5_SSA[i] = np.linalg.norm(sliced_r_5_5_5[i][ssa_1e7_sol.n_min[i] : ssa_1e7_sol.n_min[i]+ssa_1e7_sol.n[i]] - P_slice_ssa_1e7[-1][i][:grid.n[i]])
    sliced_err_r_5_10_10_SSA[i] = np.linalg.norm(sliced_r_5_10_10[i][ssa_1e7_sol.n_min[i] : ssa_1e7_sol.n_min[i]+ssa_1e7_sol.n[i]] - P_slice_ssa_1e7[-1][i][:grid.n[i]])
    sliced_err_r_5_15_15_SSA[i] = np.linalg.norm(sliced_r_5_15_15[i][ssa_1e7_sol.n_min[i] : ssa_1e7_sol.n_min[i]+ssa_1e7_sol.n[i]] - P_slice_ssa_1e7[-1][i][:grid.n[i]])
    sliced_err_r_10_15_25_SSA[i] = np.linalg.norm(sliced_r_10_15_25[i][ssa_1e7_sol.n_min[i] : ssa_1e7_sol.n_min[i]+ssa_1e7_sol.n[i]] - P_slice_ssa_1e7[-1][i][:grid.n[i]])

In [None]:
# Spot check for one species s: marginals of two PS-TTN solutions and of the
# 10^7-run SSA reference, restricted to the SSA support.
s = 10

red_marginal_lr_1e2 = marginal_r_5_10_10[s][ssa_1e7_sol.n_min[s] : ssa_1e7_sol.n_min[s]+ssa_1e7_sol.n[s]]
red_marginal_lr = marginal_r_10_15_25[s][ssa_1e7_sol.n_min[s] : ssa_1e7_sol.n_min[s]+ssa_1e7_sol.n[s]]
red_marginal_ssa = P_sum_ssa_1e7[-1][s][:grid.n[s]]

# Each curve shows its own solution; the rank (5, 10, 10) data goes with its own x-range.
plt.plot(range(red_marginal_lr_1e2.size), red_marginal_lr_1e2, 'x', label="$r = (5, 10, 10)$")
plt.plot(range(red_marginal_lr.size), red_marginal_lr, label="$r = (10, 15, 25)$")
plt.plot(range(red_marginal_lr.size), red_marginal_ssa, ':', label="SSA ($10^7$ runs)")
plt.legend()
print(np.linalg.norm(red_marginal_lr-red_marginal_ssa))

In [None]:
# Per-species errors w.r.t. the 10^7-run SSA reference: rank (10, 15, 25) marginals,
# then rank (5, 5, 5) sliced distributions (one array per line).
print(marginal_err_r_10_15_25_SSA, sliced_err_r_5_5_5_SSA, sep="\n")

### Plots

In [None]:
# Absolute mass error |Δm(t)| over time for the partition-P1, rank (10, 15, 25) run.
time_series = TimeSeries("output/bax_p1_r10-15-25_i_tau1e-2")
mass_err = np.abs(time_series.getMassErr())
fig, ax = plt.subplots()
ax.plot(time_series.time, mass_err, '-')
ax.set_xlabel("$t$")
# Raw string: "\D" is not a valid escape sequence; the label text is unchanged.
ax.set_ylabel(r"$\Delta m(t)$")
ax.set_yscale("log")
ax.set_ylim([8e-9, 1e-2])

plt.savefig("plots/mass_err_ttn_r10-15-25.pdf");

In [None]:
# Figure: per-species errors of the marginals P_M w.r.t. the 10^7-run SSA reference
# (top left: PS-TTN solutions, top right: smaller SSA ensembles) and wall times (bottom).
# Raw strings below keep "\V" and "\m" literal without invalid-escape warnings.
gs = gridspec.GridSpec(2, 8)

fig = plt.figure(figsize=(7, 8))
ax00 = plt.subplot(gs[0, :4])
ax00.plot(np.arange(grid.d), marginal_err_r_5_5_5_SSA, '.-', label="$r = (5, 5, 5)$")
ax00.plot(np.arange(grid.d), marginal_err_r_5_10_10_SSA, '.-', label="$r = (5, 10, 10)$")
ax00.plot(np.arange(grid.d), marginal_err_r_5_15_15_SSA, '.-', label="$r = (5, 15, 15)$")
ax00.plot(np.arange(grid.d), marginal_err_r_10_15_25_SSA, '.-', label="$r = (10, 15, 25)$")
ax00.set_yscale("log")
ax00.set_xlabel("species $S_i$")
ax00.legend()
ax00.set_title(r"$\Vert P_M^\mathrm{{TTN}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$")

ax01 = plt.subplot(gs[0, 4:])
ax01.plot(np.arange(grid.d), SSA_marginal_err_1e4, 'x-', label="$10^4$ runs")
ax01.plot(np.arange(grid.d), SSA_marginal_err_1e5, 'x-', label="$10^5$ runs")
ax01.plot(np.arange(grid.d), SSA_marginal_err_1e6, 'x-', label="$10^6$ runs")
# SSA_marginal_err_2e6 (2*10^6 runs) is computed but not shown.
ax01.set_yscale("log")
ax01.set_xlabel("species $S_i$")
ax01.legend()
ax01.set_title(r"$\Vert P_M^\mathrm{{SSA}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$")
ax01.yaxis.tick_right()
ax01.yaxis.set_ticks_position("both")
ax01.yaxis.set_label_position("right")

# Common y-range and species ticks for both error panels.
plt.setp((ax00, ax01), ylim=[1e-7, 1e-1], xticks=[0, 2, 4, 6, 8, 10])

ax10 = plt.subplot(gs[1:, 2:7])
ax10.barh(labels_walltime, walltimes, label=bar_labels, color=color)
ax10.set_xscale("log")
ax10.set_xlabel(r"wall time [$\mathrm{{s}}$]")
ax10.legend()

plt.subplots_adjust(hspace=0.3, wspace=0.0)

plt.savefig("plots/comparison_marginal_ttn_ssa.pdf", bbox_inches="tight");

In [None]:
# Figure: per-species errors of the sliced distributions P_S w.r.t. the 10^7-run SSA
# reference (top left: PS-TTN solutions, top right: smaller SSA ensembles) and wall
# times (bottom). Raw strings keep "\V" and "\m" literal without invalid-escape warnings.
gs = gridspec.GridSpec(2, 8)

fig = plt.figure(figsize=(7, 8))
ax00 = plt.subplot(gs[0, :4])
ax00.plot(np.arange(grid.d), sliced_err_r_5_5_5_SSA, '.-', label="$r = (5, 5, 5)$")
ax00.plot(np.arange(grid.d), sliced_err_r_5_10_10_SSA, '.-', label="$r = (5, 10, 10)$")
ax00.plot(np.arange(grid.d), sliced_err_r_5_15_15_SSA, '.-', label="$r = (5, 15, 15)$")
ax00.plot(np.arange(grid.d), sliced_err_r_10_15_25_SSA, '.-', label="$r = (10, 15, 25)$")
ax00.set_yscale("log")
ax00.set_xlabel("species $S_i$")
ax00.legend()
ax00.set_title(r"$\Vert P_S^\mathrm{{TTN}}(x_i)-P_S^\mathrm{{SSA,ref}}(x_i) \Vert$")

ax01 = plt.subplot(gs[0, 4:])
ax01.plot(np.arange(grid.d), SSA_sliced_err_1e4, 'x-', label="$10^4$ runs")
ax01.plot(np.arange(grid.d), SSA_sliced_err_1e5, 'x-', label="$10^5$ runs")
ax01.plot(np.arange(grid.d), SSA_sliced_err_1e6, 'x-', label="$10^6$ runs")
# SSA_sliced_err_2e6 (2*10^6 runs) is computed but not shown.
ax01.set_yscale("log")
ax01.set_xlabel("species $S_i$")
ax01.legend()
ax01.set_title(r"$\Vert P_S^\mathrm{{SSA}}(x_i)-P_S^\mathrm{{SSA,ref}}(x_i) \Vert$")
ax01.yaxis.tick_right()
ax01.yaxis.set_ticks_position("both")
ax01.yaxis.set_label_position("right")

# Common y-range and species ticks for both error panels.
plt.setp((ax00, ax01), ylim=[1e-6, 1e-1], xticks=[0, 2, 4, 6, 8, 10])

ax10 = plt.subplot(gs[1:, 2:7])
ax10.barh(labels_walltime, walltimes, label=bar_labels, color=color)
ax10.set_xscale("log")
ax10.set_xlabel(r"wall time [$\mathrm{{s}}$]")
ax10.legend()

plt.subplots_adjust(hspace=0.3, wspace=0.0)

plt.savefig("plots/comparison_sliced_ttn_ssa.pdf", bbox_inches="tight");

## Error depending on time step size

In [None]:
# Maximal mass error max_t |Δm(t)| of the partition-P1 PS-TTN runs for decreasing time
# step sizes tau and four rank configurations. Runs are read in the same order as
# before: for each tau, ranks (5, 15, 25), (5, 15, 15), (5, 10, 10), (5, 5, 5).
dm_max_r_5_15_25 = []
dm_max_r_5_15_15 = []
dm_max_r_5_10_10 = []
dm_max_r_5_5_5 = []
tau = []

dm_max_by_rank = [
    ("5-15-25", dm_max_r_5_15_25),
    ("5-15-15", dm_max_r_5_15_15),
    ("5-10-10", dm_max_r_5_10_10),
    ("5-5-5", dm_max_r_5_5_5),
]
for tau_tag in ["5e-1", "2e-1", "1e-1", "5e-2", "2e-2", "1e-2"]:
    for rank_tag, dm_max in dm_max_by_rank:
        time_series = TimeSeries(f"output/bax_p1_r{rank_tag}_i_tau{tau_tag}")
        dm_max.append(np.abs(time_series.getMaxMassErr()))
    # The step size is read from the last run of the group (rank (5, 5, 5)).
    tau.append(time_series.getTau())

### Plot

In [None]:
# Step-size convergence: maximal absolute mass error over the whole run (partition 1)
# versus tau, for several rank configurations, with a tau^3 reference slope.
# LaTeX labels use raw strings: "\m" and "\D" are invalid Python escape sequences
# (SyntaxWarning since Python 3.12); the rendered text is unchanged.
fig, ax = plt.subplots()
# Reference-slope abscissa covering the tested step sizes (1e-2 ... 5e-1).
t = np.linspace(0.01, 0.5)

ax.loglog(tau, dm_max_r_5_5_5, '.-', label="$r=(5, 5, 5)$")
ax.loglog(tau, dm_max_r_5_10_10, '.-', label="$r=(5, 10, 10)$")
ax.loglog(tau, dm_max_r_5_15_15, '.-', label="$r=(5, 15, 15)$")
# ax.loglog(tau, dm_max_r_5_15_25, '.-', label="$r=(5, 15, 25)$")
# The factor 10 only shifts the reference line vertically; its slope is what matters.
ax.loglog(t, t**3*10, 'k:', label=r"$\tau^3$")
ax.set_xlabel(r"time step size $\tau$")
ax.set_ylabel(r"$\max_t(|\Delta m(t)|)$")
fig.legend(loc="upper center", ncols=5)

plt.savefig("plots/time_mass_err_ttn.pdf");

## Error depending on rank and partition

### Load data

#### Partition 0

In [None]:
# Partition 0: sliced and marginal errors of the output_t7250 snapshot of each rank
# configuration, measured against the SSA reference (last stored SSA entry).
ssa_reference = (P_slice_ssa_1e7[-1], P_sum_ssa_1e7[-1], slice_vec, ssa_1e7_sol)

sliced_err00, marginal_err00 = calculateDistributionError(
    "output/bax_p0_r5-5-5_i_tau2e-2/output_t7250.nc", *ssa_reference)
sliced_err10, marginal_err10 = calculateDistributionError(
    "output/bax_p0_r5-10-10_i_tau2e-2/output_t7250.nc", *ssa_reference)
sliced_err20, marginal_err20 = calculateDistributionError(
    "output/bax_p0_r5-15-15_i_tau2e-2/output_t7250.nc", *ssa_reference)
sliced_err30, marginal_err30 = calculateDistributionError(
    "output/bax_p0_r5-20-20_i_tau2e-2/output_t7250.nc", *ssa_reference)

#### Partition 1

In [None]:
# Partition 1: sliced and marginal errors of the output_t7250 snapshot of each rank
# configuration, measured against the SSA reference (last stored SSA entry).
ssa_reference = (P_slice_ssa_1e7[-1], P_sum_ssa_1e7[-1], slice_vec, ssa_1e7_sol)

sliced_err01, marginal_err01 = calculateDistributionError(
    "output/bax_p1_r5-5-5_i_tau2e-2/output_t7250.nc", *ssa_reference)
sliced_err11, marginal_err11 = calculateDistributionError(
    "output/bax_p1_r5-10-10_i_tau2e-2/output_t7250.nc", *ssa_reference)
sliced_err21, marginal_err21 = calculateDistributionError(
    "output/bax_p1_r5-15-15_i_tau2e-2/output_t7250.nc", *ssa_reference)
sliced_err31, marginal_err31 = calculateDistributionError(
    "output/bax_p1_r5-20-20_i_tau2e-2/output_t7250.nc", *ssa_reference)

#### Partition 2

In [None]:
# Partition 2: sliced and marginal errors of the output_t7250 snapshot of each rank
# configuration, measured against the SSA reference (last stored SSA entry).
ssa_reference = (P_slice_ssa_1e7[-1], P_sum_ssa_1e7[-1], slice_vec, ssa_1e7_sol)

sliced_err02, marginal_err02 = calculateDistributionError(
    "output/bax_p2_r5-5-5_i_tau2e-2/output_t7250.nc", *ssa_reference)
sliced_err12, marginal_err12 = calculateDistributionError(
    "output/bax_p2_r5-10-10_i_tau2e-2/output_t7250.nc", *ssa_reference)
sliced_err22, marginal_err22 = calculateDistributionError(
    "output/bax_p2_r5-15-15_i_tau2e-2/output_t7250.nc", *ssa_reference)
sliced_err32, marginal_err32 = calculateDistributionError(
    "output/bax_p2_r5-20-20_i_tau2e-2/output_t7250.nc", *ssa_reference)

#### Inset

In [None]:
# Partition 1 run at tau = 1e-2 (snapshot output_t14500), compared against the same
# SSA reference as the rank/partition study.
# NOTE(review): the variable names say r = (5, 15, 25), but the path loads the r10-15-25
# run, while the step-size study uses output/bax_p1_r5-15-25_i_tau1e-2. Confirm which run
# is intended and align either the names or the path.
sliced_err_r_5_15_25, marginal_err_r_5_15_25 = calculateDistributionError("output/bax_p1_r10-15-25_i_tau1e-2/output_t14500.nc", P_slice_ssa_1e7[-1], P_sum_ssa_1e7[-1], slice_vec, ssa_1e7_sol)

### Plot

In [None]:
# Marginal-distribution error of the TTN integrator, measured against the SSA reference
# (P_sum_ssa_1e7[-1] is what calculateDistributionError received).
# Rows: rank configurations; columns: partitions P_0, P_1, P_2.
# LaTeX strings are raw strings: "\m" and "\V" are invalid Python escapes.
# from mpl_toolkits.axes_grid1.inset_locator import inset_axes
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
cols = [r"$\mathcal{{P}}_0$", r"$\mathcal{{P}}_1$", r"$\mathcal{{P}}_2$"]
rows = ["$r = (5, 5, 5)$", "$r = (5, 10, 10)$", "$r = (5, 15, 15)$", "$r = (5, 20, 20)$"]
# The reference is the SSA solution, not a TTN one. The title previously read "TTN,ref";
# it now matches the 2x2 variant of this figure.
label = r"$\Vert P_M^\mathrm{{TTN}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$"

# marginal_errs[j][i]: partition cols[j], rank configuration rows[i].
marginal_errs = [
    [marginal_err00, marginal_err10, marginal_err20, marginal_err30],
    [marginal_err01, marginal_err11, marginal_err21, marginal_err31],
    [marginal_err02, marginal_err12, marginal_err22, marginal_err32],
]
# x-axis: species index; presumably tree.grid.d() is the number of species.
species = np.arange(tree.grid.d())

fig, axs = plt.subplots(4, 3, figsize=(10, 8), sharex='col', sharey='row')
for j, partition_errs in enumerate(marginal_errs):
    for i, err in enumerate(partition_errs):
        # One color per partition, so every column is visually consistent.
        axs[i, j].plot(species, err, '.-', color=colors[j])

# NOTE(review): the disabled inset below uses names (SSA_/DLR_marginal_err_r_5_15_25)
# that are not defined in this notebook section.
# axins = inset_axes(axs[-1, 1], width='65%', height='65%', loc='upper center')
# axins.plot(np.arange(tree.grid.d()), SSA_marginal_err_r_5_15_25, '--')
# axins.plot(np.arange(tree.grid.d()), DLR_marginal_err_r_5_15_25, '.-')
# axins.set_ylim([-0.02*0.6, 0.42*0.6])
# axins.annotate("TTN with 2 partitions\n($r=5$)", xy=(0.5, 1), xytext=(0, -30),
#                 xycoords='axes fraction', textcoords='offset points',
#                 size=11, ha='center', va='baseline')

plt.setp(axs, ylim=[-0.02, 0.22])
# plt.setp(axs[:, 0], ylim=[-0.02, 0.22])
# plt.setp(axs[:, 1], ylim=[-0.02, 0.47])
# plt.setp(axs[:, 2], ylim=[-0.02, 0.34])

pad = 5 # in points

# Column headers (partitions) above the top row.
for ax, col in zip(axs[0], cols):
    ax.annotate(col, xy=(0.5, 1), xytext=(0, pad+5),
                xycoords='axes fraction', textcoords='offset points',
                size=14, ha='center', va='baseline')

# Row headers (rank configurations) left of the first column.
for ax, row in zip(axs[:,0], rows):
    ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad-pad, 0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size=14, ha='right', va='center')

plt.setp(axs[-1, :], xlabel="species $S_i$", xticks=[0, 2, 4, 6, 8, 10])
fig.suptitle(label, fontsize=16)
plt.subplots_adjust(hspace=.0, wspace=.0)

plt.savefig("plots/comparison_marginal_ttn.pdf", bbox_inches="tight");

In [None]:
# Marginal-distribution error against the SSA reference: one panel per rank
# configuration, one line per partition. LaTeX strings are raw strings because
# "\m" and "\V" are invalid Python escape sequences.
fig, axs = plt.subplots(2, 2, figsize=(7, 5))
species = np.arange(tree.grid.d())
partition_labels = [r"$\mathcal{{P}}_0$", r"$\mathcal{{P}}_1$", r"$\mathcal{{P}}_2$"]
panels = [
    ("$r = (5, 5, 5)$", [marginal_err00, marginal_err01, marginal_err02]),
    ("$r = (5, 10, 10)$", [marginal_err10, marginal_err11, marginal_err12]),
    ("$r = (5, 15, 15)$", [marginal_err20, marginal_err21, marginal_err22]),
    ("$r = (5, 20, 20)$", [marginal_err30, marginal_err31, marginal_err32]),
]
# axs.flat visits (0, 0), (0, 1), (1, 0), (1, 1).
for ax, (title, errs) in zip(axs.flat, panels):
    for err, part_label in zip(errs, partition_labels):
        ax.plot(species, err, '.-', label=part_label)
    ax.set_title(title)

# Right-hand column: ticks on both sides, tick labels on the right.
for ax in axs[:, 1]:
    ax.yaxis.tick_right()
    ax.yaxis.set_ticks_position("both")
    ax.yaxis.set_label_position("right")

plt.setp(axs, ylim=[-0.02, 0.22])
plt.setp(axs, xlabel="species $S_i$", xticks=[0, 2, 4, 6, 8, 10])
plt.subplots_adjust(hspace=0.5)
fig.suptitle(r"$\Vert P_M^\mathrm{{TTN}}(x_i)-P_M^\mathrm{{SSA,ref}}(x_i) \Vert$", fontsize=16, y=1.05)
# Every panel uses the same partition labels, so the shared legend comes from the first.
fig.legend(*axs[0, 0].get_legend_handles_labels(), ncols=3, loc="center", bbox_to_anchor=(0.5, 0.95))

plt.savefig("plots/err_comparison_ttn_alt.pdf", bbox_inches="tight");

In [None]:
# Sliced-distribution error of the TTN integrator, measured against the SSA reference
# (P_slice_ssa_1e7[-1] is what calculateDistributionError received).
# Rows: rank configurations; columns: partitions P_0, P_1, P_2.
# LaTeX strings are raw strings: "\m" and "\V" are invalid Python escapes.
# from mpl_toolkits.axes_grid1.inset_locator import inset_axes
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
cols = [r"$\mathcal{{P}}_0$", r"$\mathcal{{P}}_1$", r"$\mathcal{{P}}_2$"]
rows = ["$r = (5, 5, 5)$", "$r = (5, 10, 10)$", "$r = (5, 15, 15)$", "$r = (5, 20, 20)$"]
# The reference is the SSA solution, not a TTN one. The title previously read "TTN,ref";
# it now matches the 2x2 variant of this figure.
label = r"$\Vert P_S^\mathrm{{TTN}}(x_i)-P_S^\mathrm{{SSA,ref}}(x_i) \Vert$"

# sliced_errs[j][i]: partition cols[j], rank configuration rows[i].
sliced_errs = [
    [sliced_err00, sliced_err10, sliced_err20, sliced_err30],
    [sliced_err01, sliced_err11, sliced_err21, sliced_err31],
    [sliced_err02, sliced_err12, sliced_err22, sliced_err32],
]
# x-axis: species index; presumably tree.grid.d() is the number of species.
species = np.arange(tree.grid.d())

fig, axs = plt.subplots(4, 3, figsize=(10, 8), sharex='col', sharey='row')
for j, partition_errs in enumerate(sliced_errs):
    for i, err in enumerate(partition_errs):
        # One color per partition, so every column is visually consistent.
        axs[i, j].plot(species, err, '.-', color=colors[j])

plt.setp(axs, ylim=[-0.0005, 0.0081])

pad = 5 # in points

# Column headers (partitions) above the top row.
for ax, col in zip(axs[0], cols):
    ax.annotate(col, xy=(0.5, 1), xytext=(0, pad+5),
                xycoords='axes fraction', textcoords='offset points',
                size=14, ha='center', va='baseline')

# Row headers (rank configurations) left of the first column.
for ax, row in zip(axs[:,0], rows):
    ax.annotate(row, xy=(0, 0.5), xytext=(-ax.yaxis.labelpad-pad, 0),
                xycoords=ax.yaxis.label, textcoords='offset points',
                size=14, ha='right', va='center')

plt.setp(axs[-1, :], xlabel="species $S_i$", xticks=[0, 2, 4, 6, 8, 10])
fig.suptitle(label, fontsize=16)
plt.subplots_adjust(hspace=.0, wspace=.0)

plt.savefig("plots/comparison_sliced_ttn.pdf", bbox_inches="tight");

In [None]:
# Sliced-distribution error against the SSA reference (log scale): one panel per
# rank configuration, one line per partition. LaTeX strings are raw strings because
# "\m" and "\V" are invalid Python escape sequences.
fig, axs = plt.subplots(2, 2, figsize=(7, 5))
species = np.arange(tree.grid.d())
partition_labels = [r"$\mathcal{{P}}_0$", r"$\mathcal{{P}}_1$", r"$\mathcal{{P}}_2$"]
panels = [
    ("$r = (5, 5, 5)$", [sliced_err00, sliced_err01, sliced_err02]),
    ("$r = (5, 10, 10)$", [sliced_err10, sliced_err11, sliced_err12]),
    ("$r = (5, 15, 15)$", [sliced_err20, sliced_err21, sliced_err22]),
    ("$r = (5, 20, 20)$", [sliced_err30, sliced_err31, sliced_err32]),
]
# axs.flat visits (0, 0), (0, 1), (1, 0), (1, 1).
for ax, (title, errs) in zip(axs.flat, panels):
    for err, part_label in zip(errs, partition_labels):
        ax.plot(species, err, '.-', label=part_label)
    ax.set_title(title)

# Right-hand column: ticks on both sides, tick labels on the right.
for ax in axs[:, 1]:
    ax.yaxis.tick_right()
    ax.yaxis.set_ticks_position("both")
    ax.yaxis.set_label_position("right")

plt.setp(axs, ylim=[1e-5, 1e-1], yscale="log")
plt.setp(axs, xlabel="species $S_i$", xticks=[0, 2, 4, 6, 8, 10])
plt.subplots_adjust(hspace=0.5)
fig.suptitle(r"$\Vert P_S^\mathrm{{TTN}}(x_i)-P_S^\mathrm{{SSA,ref}}(x_i) \Vert$", fontsize=16, y=1.05)
# Every panel uses the same partition labels, so the shared legend comes from the first.
fig.legend(*axs[0, 0].get_legend_handles_labels(), ncols=3, loc="center", bbox_to_anchor=(0.5, 0.95))

plt.savefig("plots/err_comparison_ttn_slice_alt.pdf", bbox_inches="tight");

## Mass error

### Load data

#### Partition 0

In [None]:
# Partition 0, tau = 2e-2: absolute mass error over time for each rank configuration.
# The time axis is taken from the first run; the mass-error plot reuses it for all runs.
mass_errs_p0 = []
for run_index, run_dir in enumerate(("output/bax_p0_r5-5-5_i_tau2e-2",
                                     "output/bax_p0_r5-10-10_i_tau2e-2",
                                     "output/bax_p0_r5-15-15_i_tau2e-2",
                                     "output/bax_p0_r5-20-20_i_tau2e-2")):
    time_series = TimeSeries(run_dir)
    mass_errs_p0.append(np.abs(time_series.getMassErr()))
    if run_index == 0:
        time = time_series.time
mass_err00, mass_err10, mass_err20, mass_err30 = mass_errs_p0

#### Partition 1

In [None]:
# Partition 1, tau = 2e-2: absolute mass error |Delta m(t)| over time for each rank configuration.
time_series = TimeSeries("output/bax_p1_r5-5-5_i_tau2e-2")
mass_err01 = np.abs(time_series.getMassErr())

time_series = TimeSeries("output/bax_p1_r5-10-10_i_tau2e-2")
mass_err11 = np.abs(time_series.getMassErr())

# Must be the r = (5, 15, 15) run. mass_err21 is plotted and labelled as r = (5, 15, 15)
# next to the 5-15-15 runs of partitions 0 and 2; this cell previously loaded the
# 5-15-25 run by mistake.
time_series = TimeSeries("output/bax_p1_r5-15-15_i_tau2e-2")
mass_err21 = np.abs(time_series.getMassErr())

time_series = TimeSeries("output/bax_p1_r5-20-20_i_tau2e-2")
mass_err31 = np.abs(time_series.getMassErr())

#### Partition 2

In [None]:
# Partition 2, tau = 2e-2: absolute mass error over time for each rank configuration.
mass_errs_p2 = []
for run_dir in ("output/bax_p2_r5-5-5_i_tau2e-2",
                "output/bax_p2_r5-10-10_i_tau2e-2",
                "output/bax_p2_r5-15-15_i_tau2e-2",
                "output/bax_p2_r5-20-20_i_tau2e-2"):
    time_series = TimeSeries(run_dir)
    mass_errs_p2.append(np.abs(time_series.getMassErr()))
mass_err02, mass_err12, mass_err22, mass_err32 = mass_errs_p2

### Plot

In [None]:
# Absolute mass error over time: one panel per rank configuration, one line per partition.
# LaTeX strings are raw strings because "\m" and "\D" are invalid Python escape sequences.
fig, axs = plt.subplots(2, 2, figsize=(7, 5))
partition_labels = [r"$\mathcal{{P}}_0$", r"$\mathcal{{P}}_1$", r"$\mathcal{{P}}_2$"]
panels = [
    ("$r = (5, 5, 5)$", [mass_err00, mass_err01, mass_err02]),
    ("$r = (5, 10, 10)$", [mass_err10, mass_err11, mass_err12]),
    ("$r = (5, 15, 15)$", [mass_err20, mass_err21, mass_err22]),
    ("$r = (5, 20, 20)$", [mass_err30, mass_err31, mass_err32]),
]
# axs.flat visits (0, 0), (0, 1), (1, 0), (1, 1).
for ax, (title, errs) in zip(axs.flat, panels):
    for err, part_label in zip(errs, partition_labels):
        ax.plot(time, err, label=part_label)
    ax.set_title(title)

# Right-hand column: ticks on both sides, tick labels on the right.
for ax in axs[:, 1]:
    ax.yaxis.tick_right()
    ax.yaxis.set_ticks_position("both")
    ax.yaxis.set_label_position("right")

plt.setp(axs, xlabel="$t$", xticks=[0.0, 50.0, 100.0, 145.0], ylim=[3e-9, 1e-2], yscale="log")
plt.subplots_adjust(hspace=0.5)
fig.suptitle(r"$|\Delta m(t)|$", fontsize=16, y=1.05)
# Every panel uses the same partition labels, so the shared legend comes from the first.
fig.legend(*axs[0, 0].get_legend_handles_labels(), ncols=3, loc="center", bbox_to_anchor=(0.5, 0.95))

plt.savefig("plots/mass_err_comparison_ttn.pdf", bbox_inches="tight");

In [None]:
# Partition 1 only: absolute mass error over time for increasing rank.
# The y-label is a raw string because "\D" is an invalid Python escape sequence.
fig, ax = plt.subplots()
ax.plot(time, mass_err01, label="$r = (5, 5, 5)$")
ax.plot(time, mass_err11, label="$r = (5, 10, 10)$")
ax.plot(time, mass_err21, label="$r = (5, 15, 15)$")
ax.plot(time, mass_err31, label="$r = (5, 20, 20)$")
ax.legend()
plt.setp(ax, xlabel="$t$", ylabel=r"$|\Delta m(t)|$", xticks=[0.0, 25.0, 50.0, 75.0, 100.0, 125.0, 145.0])
plt.savefig("plots/mass_err_comparison_ttn_rank.pdf");

## Wall time

### Load data

#### Partition 0

In [None]:
# Partition 0, tau = 2e-2: wall time reported by each rank configuration's run.
walltimes_p0 = []
for run_dir in ("output/bax_p0_r5-5-5_i_tau2e-2",
                "output/bax_p0_r5-10-10_i_tau2e-2",
                "output/bax_p0_r5-15-15_i_tau2e-2",
                "output/bax_p0_r5-20-20_i_tau2e-2"):
    time_series = TimeSeries(run_dir)
    walltimes_p0.append(time_series.getWallTime())
walltime00, walltime10, walltime20, walltime30 = walltimes_p0

#### Partition 1

In [None]:
# Partition 1, tau = 2e-2: wall time reported by each rank configuration's run.
walltimes_p1 = []
for run_dir in ("output/bax_p1_r5-5-5_i_tau2e-2",
                "output/bax_p1_r5-10-10_i_tau2e-2",
                "output/bax_p1_r5-15-15_i_tau2e-2",
                "output/bax_p1_r5-20-20_i_tau2e-2"):
    time_series = TimeSeries(run_dir)
    walltimes_p1.append(time_series.getWallTime())
walltime01, walltime11, walltime21, walltime31 = walltimes_p1

#### Partition 2

In [None]:
# Partition 2, tau = 2e-2: wall time reported by each rank configuration's run.
walltimes_p2 = []
for run_dir in ("output/bax_p2_r5-5-5_i_tau2e-2",
                "output/bax_p2_r5-10-10_i_tau2e-2",
                "output/bax_p2_r5-15-15_i_tau2e-2",
                "output/bax_p2_r5-20-20_i_tau2e-2"):
    time_series = TimeSeries(run_dir)
    walltimes_p2.append(time_series.getWallTime())
walltime02, walltime12, walltime22, walltime32 = walltimes_p2

### Plot

In [None]:
# Wall time per partition: one bar chart per rank configuration.
# LaTeX strings are raw strings because "\m" is an invalid Python escape sequence.
fig, axs = plt.subplots(2, 2, figsize=(7, 5))
labels = [r"$\mathcal{{P}}_{}$".format(i) for i in range(3)]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
walltime0 = [walltime00, walltime01, walltime02]
walltime1 = [walltime10, walltime11, walltime12]
walltime2 = [walltime20, walltime21, walltime22]
walltime3 = [walltime30, walltime31, walltime32]
titles = ["$r = (5, 5, 5)$", "$r = (5, 10, 10)$", "$r = (5, 15, 15)$", "$r = (5, 20, 20)$"]

# axs.flat visits (0, 0), (0, 1), (1, 0), (1, 1).
for ax, walltimes, title in zip(axs.flat, (walltime0, walltime1, walltime2, walltime3), titles):
    ax.bar(labels, walltimes, color=colors)
    ax.set_title(title)

# Right-hand column: ticks on both sides, tick labels on the right.
for ax in axs[:, 1]:
    ax.yaxis.tick_right()
    ax.yaxis.set_ticks_position("both")
    ax.yaxis.set_label_position("right")

# plt.setp(axs, ylabel="wall time [$\mathrm{s}$]")
# plt.setp(axs, ylim=[1e0, 1e6], yscale="log")
plt.subplots_adjust(hspace=0.3)
fig.suptitle(r"wall time [$\mathrm{s}$]", fontsize=16)

plt.savefig("plots/wall_time_comparison_ttn.pdf");

## Memory requirements

In [None]:
def memoryRequirementsP0(rank):
    """Estimated storage of the partition-0 tree tensor network, in MB.

    Partition 0 is (0 1 2)(((3 4 6 7)(5 8))(9 10)). ``rank`` holds three rank
    levels: rank[0] at the root, then rank[1] and rank[2] deeper in the tree.
    The entry count is converted to MB at 8 bytes per entry (1 MB = 1e6 bytes).
    """
    r0, r1, r2 = rank[0], rank[1], rank[2]
    # Leaf bases: (presumably) grid size of the leaf's species times the leaf rank.
    leaf_entries = r0 * 11776 + r2 * 1936 + r2 * 44 + r1 * 3136
    # Connecting tensors: parent rank times the ranks of its two children.
    entries = leaf_entries + r0 ** 2 + r0 * r1 ** 2 + r1 * r2 ** 2
    return entries * 8 * 1e-6

def memoryRequirementsP1(rank):
    """Estimated storage of the partition-1 tree tensor network, in MB.

    Partition 1 is ((0 1 2)(3 4 5))((6 7 9)(8 10)). rank[0] is the root rank;
    rank[1] and rank[2] are the ranks of the left and right subtrees.
    The entry count is converted to MB at 8 bytes per entry (1 MB = 1e6 bytes).
    """
    r0, r1, r2 = rank[0], rank[1], rank[2]
    # Leaf bases of the left subtree (rank r1) and the right subtree (rank r2).
    leaf_entries = r1 * 11776 + r1 * 1331 + r2 * 896 + r2 * 224
    # Root tensor, then the two connecting tensors below it.
    entries = leaf_entries + r0 ** 2 + r0 * r1 ** 2 + r0 * r2 ** 2
    return entries * 8 * 1e-6

def memoryRequirementsP2(rank):
    """Estimated storage of the partition-2 tree tensor network, in MB.

    Partition 2 is ((0 1)(2 3 4))((5 6 7 8)(9 10)). rank[0] is the root rank;
    rank[1] and rank[2] are the ranks of the left and right subtrees.
    The entry count is converted to MB at 8 bytes per entry (1 MB = 1e6 bytes).
    """
    r0, r1, r2 = rank[0], rank[1], rank[2]
    # Leaf bases of the left subtree (rank r1) and the right subtree (rank r2).
    leaf_entries = r1 * 736 + r1 * 1936 + r2 * 704 + r2 * 3136
    # Root tensor, then the two connecting tensors below it.
    entries = leaf_entries + r0 ** 2 + r0 * r1 ** 2 + r0 * r2 ** 2
    return entries * 8 * 1e-6

def memoryRequirementsPfull(rank):
    """Estimated storage of the two-partition (matrix) representation, in MB.

    The split is (0 1 2 3 4)(5 6 7 8 9 10); only rank[0] is used. The entry
    count is converted to MB at 8 bytes per entry (1 MB = 1e6 bytes).
    """
    r = rank[0]
    # Two low-rank factors plus the r x r coupling matrix.
    entries = r * 1424896 + r * 2207744 + r ** 2
    return entries * 8 * 1e-6

In [None]:
# Estimated memory per partition: one bar chart per rank configuration.
# LaTeX strings are raw strings because "\m" is an invalid Python escape sequence.
fig, axs = plt.subplots(2, 2, figsize=(7, 5))
labels = [r"$\mathcal{{P}}_{}$".format(i) for i in range(3)]
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# axs.flat visits (0, 0), (0, 1), (1, 0), (1, 1).
for ax, rank in zip(axs.flat, ([5, 5, 5], [5, 10, 10], [5, 15, 15], [5, 20, 20])):
    memory_req = [memoryRequirementsP0(rank), memoryRequirementsP1(rank), memoryRequirementsP2(rank)]
    ax.bar(labels, memory_req, color=colors)
    ax.set_title("$r = ({}, {}, {})$".format(*rank))

# Right-hand column: ticks on both sides, tick labels on the right.
for ax in axs[:, 1]:
    ax.yaxis.tick_right()
    ax.yaxis.set_ticks_position("both")
    ax.yaxis.set_label_position("right")

# plt.setp(axs, ylim=[1e0, 1e6], yscale="log")
plt.subplots_adjust(hspace=0.3)
fig.suptitle(r"memory [$\mathrm{MB}$]", fontsize=16)

plt.savefig("plots/memory_comparison_ttn.pdf");