# BAX pore assembly

In [None]:
from scripts.reference_solutions.ssa_helper import SSASol
from output_helper import *
%matplotlib inline

In [None]:
# Load the DLR solution from the output file and derive the observables that are
# compared against the SSA reference further below.
dlr_output_path = "output/bax_general_230724_so_tau1e-1_ss10/output_t500.nc"
# Presumably the fixed state at which the sliced distributions are taken — TODO confirm.
slice_vec = np.array([5, 2, 2, 2, 2, 2, 0, 0, 0, 3, 47])
# Species indices (0-based) shown in the two-dimensional distributions.
idx_2D = np.array([9, 10])

with xr.open_dataset(dlr_output_path) as ds:
    grid = GridInfo(ds)
    print(grid.t)
    lr_sol = LRSol(ds, grid)
    # Observables are evaluated while the dataset is still open.
    P_marginal = lr_sol.marginalDistributions()
    P_marginal2D = lr_sol.marginalDistribution2D(idx_2D)
    P_sliced = lr_sol.slicedDistributions(slice_vec)
    P_sliced2D = lr_sol.slicedDistribution2D(slice_vec, idx_2D)

In [None]:
# SSA reference solution used for the overview plots. Its observables are indexed
# [time step][...]; the cells below always take the last time step via [-1].
ssa_sol = SSASol(np.load("scripts/reference_solutions/bax_ssa_general_ref_1e4_5.npy"))
(P_marginal_ssa, P_marginal2D_ssa,
 P_sliced_ssa, P_sliced2D_ssa) = ssa_sol.calculateObservables(slice_vec, idx_2D)

Additional parameters:

In [None]:
# NOTE(review): disabled scaffolding — per-species meshes for DLR and SSA.
# `nssa_min` / `nssa` (first `mesh_ssa` line) are not defined anywhere in this
# notebook, so that line would fail if uncommented; the second variant uses ssa_sol.
# idx = np.arange(11)
# mesh = [grid.bin[j] * np.arange(n_el) + grid.liml[j] for j, n_el in enumerate(grid.n)]
# mesh_ssa = [np.arange(el) + nssa_min[j] for j, el in enumerate(nssa)]
# mesh_ssa = [np.arange(el) + ssa_sol.n_min[j] for j, el in enumerate(ssa_sol.n)]

## Figures 1 and 2
One-dimensional marginal and sliced distributions

In [None]:
# Overview: 1D marginal distributions of all 11 species, DLR approximation vs. SSA
# (last SSA time step). Species i is plotted as x_{i+1}.
fig, axs = plt.subplots(6, 2, figsize=(6, 20))
for i in range(11):
    ax = axs.flat[i]
    # DLR mesh includes the bin width, consistent with mesh_dlr used for the later figures.
    mesh_dlr_i = grid.bin[i] * np.arange(grid.n[i]) + grid.liml[i]
    mesh_ssa_i = np.arange(ssa_sol.n[i]) + ssa_sol.n_min[i]
    ax.plot(mesh_dlr_i, P_marginal[i], label="DLR approx.")
    ax.plot(mesh_ssa_i, P_marginal_ssa[-1][i], label="SSA")
    ax.set_xlabel(rf"$x_{{{i + 1}}}$")
axs.flat[0].legend()
axs.flat[-1].remove()  # 6x2 grid has one panel more than there are species
fig.tight_layout()

In [None]:
# Overview: 1D sliced distributions of all 11 species, DLR approximation vs. SSA
# (last SSA time step). Species i is plotted as x_{i+1}.
fig, axs = plt.subplots(6, 2, figsize=(6, 20))
for i in range(11):
    ax = axs.flat[i]
    mesh_dlr_i = grid.bin[i] * np.arange(grid.n[i]) + grid.liml[i]
    mesh_ssa_i = np.arange(ssa_sol.n[i]) + ssa_sol.n_min[i]
    ax.plot(mesh_dlr_i, P_sliced[i], label="DLR approx.")
    ax.plot(mesh_ssa_i, P_sliced_ssa[-1][i], label="SSA")
    ax.set_xlabel(rf"$x_{{{i + 1}}}$")
axs.flat[0].legend()
axs.flat[-1].remove()  # 6x2 grid has one panel more than there are species
fig.tight_layout()

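## Figures 3 and 4
Convergence of the SSA reference with the number of runs ($10^4$, $10^5$, $10^6$): sliced and marginal distribution of $x_1$ compared with the DLR approximation (saved as `bax_general_fig1` and `bax_general_fig2`).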

In [None]:
def _load_ssa_reference(path):
    """Load one SSA reference run.

    Returns the SSASol object together with its 1D marginal and 1D sliced
    observables; the 2D observables are not needed for the convergence figures.
    """
    sol = SSASol(np.load(path))
    marginal, _, sliced, _ = sol.calculateObservables(slice_vec, idx_2D)
    return sol, marginal, sliced


# SSA references with 10^4, 10^5 and 10^6 runs for the convergence figures.
ssa_sol_1e4, P_marginal_ssa_1e4, P_sliced_ssa_1e4 = _load_ssa_reference(
    "scripts/reference_solutions/bax_ssa_general_ref_1e4.npy")
ssa_sol_1e5, P_marginal_ssa_1e5, P_sliced_ssa_1e5 = _load_ssa_reference(
    "scripts/reference_solutions/bax_ssa_general_ref_1e5.npy")
ssa_sol_1e6, P_marginal_ssa_1e6, P_sliced_ssa_1e6 = _load_ssa_reference(
    "scripts/reference_solutions/bax_ssa_general_ref_1e6.npy")

In [None]:
# Meshes of the first species (x_1) for the DLR grid and for each SSA reference.
mesh_dlr = grid.bin[0] * np.arange(grid.n[0]) + grid.liml[0]
mesh_ssa_1e4 = np.arange(ssa_sol_1e4.n[0]) + ssa_sol_1e4.n_min[0]
mesh_ssa_1e5 = np.arange(ssa_sol_1e5.n[0]) + ssa_sol_1e5.n_min[0]
mesh_ssa_1e6 = np.arange(ssa_sol_1e6.n[0]) + ssa_sol_1e6.n_min[0]

# Interpolate the SSA sliced distributions of x_1 (last time step) onto the DLR mesh.
# np.interp holds the edge values constant outside the SSA mesh range.
P_sliced_ssa_1e4_red = np.interp(mesh_dlr, mesh_ssa_1e4, P_sliced_ssa_1e4[-1][0])
P_sliced_ssa_1e5_red = np.interp(mesh_dlr, mesh_ssa_1e5, P_sliced_ssa_1e5[-1][0])
P_sliced_ssa_1e6_red = np.interp(mesh_dlr, mesh_ssa_1e6, P_sliced_ssa_1e6[-1][0])

# The DLR observables have no time axis: P_sliced[i] is the distribution of species i
# (same layout as P_marginal[i] in the overview plots). Indexing with [-1][0] would pick
# a single value of the last species instead of the x_1 distribution.
P_sliced_dlr = P_sliced[0]

max_error_ssa_1e4 = np.max(np.abs(P_sliced_ssa_1e4_red - P_sliced_dlr))
max_error_ssa_1e5 = np.max(np.abs(P_sliced_ssa_1e5_red - P_sliced_dlr))
max_error_ssa_1e6 = np.max(np.abs(P_sliced_ssa_1e6_red - P_sliced_dlr))

fig, axs = plt.subplots(3, 1, figsize=(5, 7.5))
panels = [
    (r"10\,000", P_sliced_ssa_1e4_red, max_error_ssa_1e4),
    (r"100\,000", P_sliced_ssa_1e5_red, max_error_ssa_1e5),
    (r"1\,000\,000", P_sliced_ssa_1e6_red, max_error_ssa_1e6),
]
for ax, (runs, P_ssa_red, max_error) in zip(axs, panels):
    ax.plot(mesh_dlr, P_sliced_dlr, "bx", label="DLR approx.")
    ax.plot(mesh_dlr, P_ssa_red, "gs", label="SSA")
    # Render the error as "m \cdot 10^{e}" for LaTeX.
    mantissa, exponent = "{:.2e}".format(max_error).split("e")
    ax.set_title(
        rf"\textbf{{{runs} runs}}" "\n"
        rf"$\mathrm{{max}}(|\textrm{{SSA}}-\textrm{{DLR approx.}}|)$ = "
        rf"${mantissa} \cdot 10^{{{int(exponent)}}}$" "\n",
        y=0.95,
    )

plt.setp(axs.flat, yscale="log", xlabel="$x_1$", ylabel=r"$P_\mathrm{S}(x_1)$")
fig.tight_layout()
axs[0].legend(fontsize="9")
fig.savefig('plots/bax_general_fig1.pgf')
fig.savefig('plots/bax_general_fig1.pdf')

In [None]:
# Same convergence study as the previous figure, but for the 1D marginal distribution
# of x_1. Relies on mesh_dlr / mesh_ssa_* from the previous cell.
P_marginal_ssa_1e4_red = np.interp(mesh_dlr, mesh_ssa_1e4, P_marginal_ssa_1e4[-1][0])
P_marginal_ssa_1e5_red = np.interp(mesh_dlr, mesh_ssa_1e5, P_marginal_ssa_1e5[-1][0])
P_marginal_ssa_1e6_red = np.interp(mesh_dlr, mesh_ssa_1e6, P_marginal_ssa_1e6[-1][0])

# The DLR observables have no time axis: P_marginal[i] is the distribution of species i,
# so the x_1 distribution is P_marginal[0] (not P_marginal[-1][0]).
P_marginal_dlr = P_marginal[0]

max_error_ssa_1e4 = np.max(np.abs(P_marginal_ssa_1e4_red - P_marginal_dlr))
max_error_ssa_1e5 = np.max(np.abs(P_marginal_ssa_1e5_red - P_marginal_dlr))
max_error_ssa_1e6 = np.max(np.abs(P_marginal_ssa_1e6_red - P_marginal_dlr))

fig, axs = plt.subplots(3, 1, figsize=(5, 7.5))
panels = [
    (r"10\,000", P_marginal_ssa_1e4_red, max_error_ssa_1e4),
    (r"100\,000", P_marginal_ssa_1e5_red, max_error_ssa_1e5),
    (r"1\,000\,000", P_marginal_ssa_1e6_red, max_error_ssa_1e6),
]
for ax, (runs, P_ssa_red, max_error) in zip(axs, panels):
    ax.plot(mesh_dlr, P_marginal_dlr, "bx", label="DLR approx.")
    ax.plot(mesh_dlr, P_ssa_red, "gs", label="SSA")
    # Render the error as "m \cdot 10^{e}" for LaTeX.
    mantissa, exponent = "{:.2e}".format(max_error).split("e")
    ax.set_title(
        rf"\textbf{{{runs} runs}}" "\n"
        rf"$\mathrm{{max}}(|\textrm{{SSA}}-\textrm{{DLR approx.}}|)$ = "
        rf"${mantissa} \cdot 10^{{{int(exponent)}}}$" "\n",
        y=0.95,
    )

# Marginal (M) distribution; the sliced figure above uses P_S.
plt.setp(axs.flat, yscale="log", xlabel="$x_1$", ylabel=r"$P_\mathrm{M}(x_1)$")
fig.tight_layout()
axs[0].legend(fontsize="9")
fig.savefig('plots/bax_general_fig2.pgf')
fig.savefig('plots/bax_general_fig2.pdf')

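## Figures 5 and 6
Two-dimensional marginal and sliced distributions of the species pair selected by `idx_2D`, DLR approximation vs. SSA with $10^4$ runs (saved as `bax_general_fig3` and `bax_general_fig4`).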

In [None]:
# Reference SSA run for the 2D comparison.
# NOTE(review): this is the same file loaded into `ssa_sol` at the top of the notebook,
# and it overwrites `ssa_sol_1e4` / `P_*_ssa_1e4` from the convergence figures —
# re-running those cells after this one silently uses this dataset instead.
ssa_sol_1e4 = SSASol(np.load("scripts/reference_solutions/bax_ssa_general_ref_1e4_5.npy"))
(P_marginal_ssa_1e4, P_marginal2D_ssa_1e4,
 P_sliced_ssa_1e4, P_sliced2D_ssa_1e4) = ssa_sol_1e4.calculateObservables(slice_vec, idx_2D)

In [None]:
n_min = np.amax(np.stack((grid.liml, ssa_sol_1e4.n_min)), axis=0)
n_max = np.amin(np.stack((grid.n + grid.liml, ssa_sol_1e4.n_max)), axis=0)
n = n_max - n_min + 1
dx = np.prod(n)

print(n)

P_marginal2D_red = np.zeros(n[:2])
P_marginal2D_ssa_1e4_red = np.zeros(n[:2])
# P_marginal2D_ode_red = np.zeros(n[:2])

# P_sliced2D_r4_red = np.zeros(n[:2])
# P_sliced2D_r7_red = np.zeros(n[:2])
# P_sliced2D_ode_red = np.zeros(n[:2])
P_sliced2D_red = np.zeros(n[:2])
P_sliced2D_ssa_1e4_red = np.zeros(n[:2])
# P_sliced2D_ssa_1e5_red = np.zeros(n[:2])
# P_sliced2D_ssa_1e6_red = np.zeros(n[:2])

liml_red = n_min - grid.liml
liml_red_ssa_1e4 = n_min - ssa_sol_1e4.n_min
# liml_red_ssa_1e5 = n_min - ssa_sol_1e5.n_min
# liml_red_ssa_1e6 = n_min - ssa_sol_1e6.n_min

for i in range(n[9]):
    for j in range(n[10]):
        P_marginal2D_red[i, j] = P_marginal2D[i + liml_red[9], j + liml_red[10]]
        P_marginal2D_ssa_1e4_red[i, j] = P_marginal2D_ssa_1e4[-1][i +
                                                              liml_red_ssa_1e4[9] + ssa_sol_1e4.n[9] * (j + liml_red_ssa_1e4[10])]

        P_sliced2D_red[i, j] = P_sliced2D[i +
                                                liml_red[9], j + liml_red[10]]
        P_sliced2D_ssa_1e4_red[i, j] = P_sliced2D_ssa_1e4[-1][i +
                                                              liml_red_ssa_1e4[9] + ssa_sol_1e4.n[9] * (j + liml_red_ssa_1e4[10])]
        # P_sliced2D_ssa_1e5_red[i, j] = P_sliced2D_ssa_1e5[-1][i +
        #                                                       liml_red_ssa_1e5[0] + ssa_sol_1e5.n[0] * (j + liml_red_ssa_1e5[1])]
        # P_sliced2D_ssa_1e6_red[i, j] = P_sliced2D_ssa_1e6[-1][i +
        #                                                       liml_red_ssa_1e6[0] + ssa_sol_1e6.n[0] * (j + liml_red_ssa_1e6[1])]

In [None]:
xx1, xx2 = np.meshgrid(
    np.arange(n[0]) + n_min[0], np.arange(n[1]) + n_min[1], indexing='ij')

fig, axs = plt.subplots(1, 2, figsize=(6, 4))
plt.setp(axs.flat, xlabel='$x_1$', ylabel='$x_2$')
plotP2D(axs, [P_marginal2D_red, P_marginal2D_ssa_1e4_red], [
        xx1, xx2], ["DLR approx. ($r=4$)", "SSA (10\,000 runs)"])
fig.tight_layout()
fig.savefig('plots/bax_general_fig3.pgf')
fig.savefig('plots/bax_general_fig3.pdf')


In [None]:
xx1, xx2 = np.meshgrid(
    np.arange(n[0]) + n_min[0], np.arange(n[1]) + n_min[1], indexing='ij')

fig, axs = plt.subplots(1, 2, figsize=(6, 4))
plt.setp(axs.flat, xlabel='$x_1$', ylabel='$x_2$')
plotP2D(axs, [P_sliced2D_red, P_sliced2D_ssa_1e4_red], [
        xx1, xx2], ["DLR approx. ($r=4$)", "SSA (10\,000 runs)"])
fig.tight_layout()
fig.savefig('plots/bax_general_fig4.pgf')
fig.savefig('plots/bax_general_fig4.pdf')
