# Diffusive toggle switch

In [None]:
import matplotlib.gridspec as gridspec
import matplotlib.pyplot as plt
import numpy as np

from scripts.grid_class import GridParms
from scripts.tree_class import Tree
from scripts.notebooks.output_helper import *
from scripts.reference_solutions.ssa_helper import SSASol

# Shared plot styling for every figure produced in this notebook.
plt.style.use("./scripts/notebooks/custom_style.mplstyle")
# IPython magic (notebook-only, not plain Python): render figures inline.
%matplotlib inline

## Load initial data

### TTN integrator

#### Partition 0, TTF

In [None]:
# First run: its grid determines the slice point shared by every run below.
tree = readTree("output/dts_p0_r5_e_tau1e-3/output_t500000.nc")
# All-zero index vector with one entry per grid dimension (tree.grid.d()).
slice_vec = np.zeros(tree.grid.d())
ttn_slice_p0_r_5, ttn_marginal_p0_r_5 = tree.calculateObservables(slice_vec)


def _read_ttn_observables(path):
    """Read the TTN output file at *path* and evaluate its observables at ``slice_vec``.

    Returns ``(tree, slice, marginal)``; callers rebind the notebook-level ``tree``
    from the first element so it keeps referring to the most recently read file.
    """
    loaded = readTree(path)
    return (loaded, *loaded.calculateObservables(slice_vec))


# Directory names encode rank (r), time step (tau) and, where present, Householder
# orthogonalization. The output_t<N> suffix is presumably the step index
# (N * tau == 500 for every run here) -- confirm against the writer.
tree, ttn_slice_p0_r_6, ttn_marginal_p0_r_6 = _read_ttn_observables("output/dts_p0_r6_e_tau1e-3/output_t500000.nc")
tree, ttn_slice_p0_r_6_hh, ttn_marginal_p0_r_6_hh = _read_ttn_observables("output/dts_p0_r6_e_tau5e-3_householder/output_t100000.nc")
tree, ttn_slice_p0_r_7_1e2, ttn_marginal_p0_r_7_1e2 = _read_ttn_observables("output/dts_p0_r7_e_tau1e-2/output_t50000.nc")
tree, ttn_slice_p0_r_7, ttn_marginal_p0_r_7 = _read_ttn_observables("output/dts_p0_r7_e_tau1e-3/output_t500000.nc")
tree, ttn_slice_p0_r_7_hh, ttn_marginal_p0_r_7_hh = _read_ttn_observables("output/dts_p0_r7_e_tau5e-2_householder/output_t10000.nc")
tree, ttn_slice_p0_r_8, ttn_marginal_p0_r_8 = _read_ttn_observables("output/dts_p0_r8_e_tau1e-3/output_t500000.nc")
tree, ttn_slice_p0_r_9, ttn_marginal_p0_r_9 = _read_ttn_observables("output/dts_p0_r9_e_tau1e-3/output_t500000.nc")
tree, ttn_slice_p0_r_9_hh, ttn_marginal_p0_r_9_hh = _read_ttn_observables("output/dts_p0_r9_e_tau5e-2_householder/output_t10000.nc")
tree, ttn_slice_p0_r_10_hh, ttn_marginal_p0_r_10_hh = _read_ttn_observables("output/dts_p0_r10_e_tau5e-2_householder/output_t10000.nc")
# tree, ttn_slice_p0_r_9_tau5e2, ttn_marginal_p0_r_9_tau5e2 = _read_ttn_observables("output/dts_p0_r9_e_tau5e-2/output_t10000.nc")
tree, ttn_slice_p0_r_10, ttn_marginal_p0_r_10 = _read_ttn_observables("output/dts_p0_r10_e_tau1e-2/output_t50000.nc")
tree, ttn_slice_p0_r_10_tau2e2, ttn_marginal_p0_r_10_tau2e2 = _read_ttn_observables("output/dts_p0_r10_e_tau2e-2/output_t25000.nc")
tree, ttn_slice_p0_r_10_tau5e2, ttn_marginal_p0_r_10_tau5e2 = _read_ttn_observables("output/dts_p0_r10_e_tau5e-2/output_t10000.nc")
tree, ttn_slice_p0_r_11, ttn_marginal_p0_r_11 = _read_ttn_observables("output/dts_p0_r11_e_tau1e-2/output_t50000.nc")

In [None]:
# Partition-1 run (plotted as "TT, $r=5$" in the convergence figure).
# NOTE(review): this rebinds `tree`, whose grid.n later plotting cells use for the
# partition-0 results too -- assumes both partitions share the same grid sizes.
tree = readTree("output/dts_p1_r5_e_tau1e-3/output_t500000.nc")
p1_observables = tree.calculateObservables(slice_vec)
ttn_slice_p1_r_5, ttn_marginal_p1_r_5 = p1_observables

### Matrix integrator

In [None]:
# One-site (matrix) toggle switch runs only use the first two slice coordinates.
ts_slice_vec = slice_vec[:2]

mat = readTree("output/ts_r5_e_tau1e-3/output_t500000.nc")
ts_slice_p0_r_5, ts_marginal_p0_r_5 = mat.calculateObservables(ts_slice_vec)

# `mat` is left bound to this last (r=20) run; later cells read its grid sizes.
mat = readTree("output/ts_r20_e_tau1e-2/output_t50000.nc")
ts_slice_p0_r_20, ts_marginal_p0_r_20 = mat.calculateObservables(ts_slice_vec)

### SSA

In [None]:
# Coordinate indices passed to SSASol.calculateObservables for its 2D observables.
idx_2D = np.arange(2)

In [None]:
ssa_1e4 = np.load("scripts/reference_solutions/diffusive_toggle_switch_ssa_1e+04.npy")
ssa_1e4_sol = SSASol(ssa_1e4)
# Keep the 1D marginals and slices; the 2D observables for idx_2D are discarded.
ssa_marginal_1e4, _, ssa_slice_1e4, _ = ssa_1e4_sol.calculateObservables(slice_vec, idx_2D)
# Every SSA cell reassigns `ssa_wall_time`, so keep a per-run-count copy that
# survives running the other cells. Units presumably seconds -- confirm.
ssa_wall_time_1e4 = 867.1
ssa_wall_time = ssa_wall_time_1e4

In [None]:
ssa_1e5 = np.load("scripts/reference_solutions/diffusive_toggle_switch_ssa_1e+05.npy")
ssa_1e5_sol = SSASol(ssa_1e5)
# Keep the 1D marginals and slices; the 2D observables for idx_2D are discarded.
ssa_marginal_1e5, _, ssa_slice_1e5, _ = ssa_1e5_sol.calculateObservables(slice_vec, idx_2D)
# Every SSA cell reassigns `ssa_wall_time`, so keep a per-run-count copy that
# survives running the other cells. Units presumably seconds -- confirm.
ssa_wall_time_1e5 = 1866.7
ssa_wall_time = ssa_wall_time_1e5

In [None]:
ssa_1e6 = np.load("scripts/reference_solutions/diffusive_toggle_switch_ssa_1e+06.npy")
ssa_1e6_sol = SSASol(ssa_1e6)
# Keep the 1D marginals and slices; the 2D observables for idx_2D are discarded.
ssa_marginal_1e6, _, ssa_slice_1e6, _ = ssa_1e6_sol.calculateObservables(slice_vec, idx_2D)
# Every SSA cell reassigns `ssa_wall_time`, so keep a per-run-count copy that
# survives running the other cells. Units presumably seconds -- confirm.
ssa_wall_time_1e6 = 4353.9
ssa_wall_time = ssa_wall_time_1e6

In [None]:
ssa_1e7 = np.load("scripts/reference_solutions/diffusive_toggle_switch_ssa_1e+07.npy")
ssa_1e7_sol = SSASol(ssa_1e7)
# Unlike the smaller runs, also keep the 2D marginal/slice for the idx_2D coordinates.
ssa_marginal_1e7, ssa_marginal2D_1e7, ssa_slice_1e7, ssa_slice2D_1e7 = ssa_1e7_sol.calculateObservables(slice_vec, idx_2D)
# Every SSA cell reassigns `ssa_wall_time`, so keep a per-run-count copy that
# survives running the other cells.
# NOTE(review): 1e5 is a round figure, unlike the measured-looking values of the
# other runs -- confirm it is an actual timing and not a placeholder.
ssa_wall_time_1e7 = 1e5
ssa_wall_time = ssa_wall_time_1e7

In [None]:
# Show the flat 2D SSA marginal as an n[0] x n[1] table.
print(np.reshape(ssa_marginal2D_1e7, (ssa_1e7_sol.n[0], ssa_1e7_sol.n[1])))

## Convergence with increasing rank

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6,8))

# Partition-0 TTN runs of increasing rank, with the partition-1 TT run as a dashed reference.
rank_series = (
    (ttn_marginal_p0_r_7, "v", "TTN, $r=7$"),
    (ttn_marginal_p0_r_8, "<", "TTN, $r=8$"),
    (ttn_marginal_p0_r_9, "^", "TTN, $r=9$"),
    (ttn_marginal_p0_r_10, ">", "TTN, $r=10$"),
    (ttn_marginal_p1_r_5, "k--", "TT, $r=5$"),
)

# Top panel shows the marginal of x_0, bottom panel the marginal of x_1.
for species_idx, axis in enumerate((ax1, ax2)):
    x_values = np.arange(tree.grid.n[species_idx])
    for marginals, fmt, label in rank_series:
        axis.plot(x_values, marginals[species_idx], fmt, label=label)
    axis.set_xlabel(f"$x_{species_idx}$")

ax2.legend()
plt.savefig("plots/dts_fig1.pdf")

## Householder orthogonalization

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6,8))

# Slices of the Householder-orthogonalized runs against the plain r=9 run.
householder_series = (
    (ttn_slice_p0_r_7_hh, "^", "TTN, $r=7$, Householder"),
    (ttn_slice_p0_r_9, "^", "TTN, $r=9$"),
    (ttn_slice_p0_r_9_hh, ">", "TTN, $r=9$, Householder"),
    (ttn_slice_p0_r_10_hh, ">", "TTN, $r=10$, Householder"),
)

for k, axis in enumerate((ax1, ax2)):
    for slices, fmt, label in householder_series:
        axis.plot(np.arange(tree.grid.n[k]), slices[k], fmt, label=label)
    # SSA reference: its state range is offset by n_min; [-1] presumably selects
    # the last stored time point -- confirm in SSASol.
    ssa_states = np.arange(ssa_1e7_sol.n[k]) + ssa_1e7_sol.n_min[k]
    axis.plot(ssa_states, ssa_slice_1e7[-1][k], ".-", label="SSA, $10^7$ runs")
    axis.set_xlabel(f"$x_{k}$")

ax2.legend()
plt.savefig("plots/dts_fig1a.pdf")

## Comparison with one-site solution

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6,8))
# Partition-0 TTN marginals against the one-site (matrix) toggle switch solutions.
# `mat` is the matrix run read last (r=20); its grid is assumed to match the r=5 run.

# ax1.plot(np.arange(tree.grid.n[0]), ttn_marginal_p0_r_8[0], "k.", label="$r=8$")
ax1.plot(np.arange(tree.grid.n[0]), ttn_marginal_p0_r_9[0], "v", label="$r=9$")
ax1.plot(np.arange(tree.grid.n[0]), ttn_marginal_p0_r_10[0], "<", label="$r=10$")
ax1.plot(np.arange(mat.grid.n[0]), ts_marginal_p0_r_5[0], "k--", label="one-site toggle switch, $r=5$")
# ts_marginal_p0_r_20 comes from the ts_r20 run, so the label must say r=20.
ax1.plot(np.arange(mat.grid.n[0]), ts_marginal_p0_r_20[0], "k:", label="one-site toggle switch, $r=20$")
# The x_0 label belongs on the top panel (ax2 gets its own x_1 label below).
ax1.set_xlabel("$x_0$")

# ax2.plot(np.arange(tree.grid.n[1]), ttn_marginal_p0_r_8[1], "k.", label="$r=8$")
ax2.plot(np.arange(tree.grid.n[1]), ttn_marginal_p0_r_9[1], "v", label="$r=9$")
ax2.plot(np.arange(tree.grid.n[1]), ttn_marginal_p0_r_10[1], "<", label="$r=10$")
ax2.plot(np.arange(mat.grid.n[1]), ts_marginal_p0_r_5[1], "k--", label="one-site toggle switch, $r=5$")
ax2.plot(np.arange(mat.grid.n[1]), ts_marginal_p0_r_20[1], "k:", label="one-site toggle switch, $r=20$")
ax2.set_xlabel("$x_1$")

ax2.legend()
plt.savefig("plots/dts_fig2.pdf")

## Comparison of different time step sizes

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))

# Same rank (r=7), two time step sizes.
step_series = (
    (ttn_marginal_p0_r_7_1e2, "TTN, $r=7$, $\\tau=0.01$"),
    (ttn_marginal_p0_r_7, "TTN, $r=7$, $\\tau=0.001$"),
)

for k, axis in enumerate((ax1, ax2)):
    grid_k = np.arange(tree.grid.n[k])
    for marginals, label in step_series:
        axis.plot(grid_k, marginals[k], label=label)
    axis.set_xlabel(f"$x_{k}$")

ax2.legend()
plt.savefig("plots/dts_fig3.pdf")

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6, 8))

# Same rank (r=10), three time step sizes from largest to smallest.
step_series = (
    (ttn_marginal_p0_r_10_tau5e2, "TTN, $r=10$, $\\tau=0.05$"),
    (ttn_marginal_p0_r_10_tau2e2, "TTN, $r=10$, $\\tau=0.02$"),
    (ttn_marginal_p0_r_10, "TTN, $r=10$, $\\tau=0.01$"),
)

for k, axis in enumerate((ax1, ax2)):
    grid_k = np.arange(tree.grid.n[k])
    for marginals, label in step_series:
        axis.plot(grid_k, marginals[k], label=label)
    axis.set_xlabel(f"$x_{k}$")

ax2.legend()
plt.savefig("plots/dts_fig3a.pdf")

## Comparison with SSA

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6,8))

# TTN marginals shown in this figure. Other loaded runs (plain TTN r=7, 8, 9, 11
# and SSA with 10^4 runs) can be added to these tables for further comparisons.
ttn_series = (
    (ttn_marginal_p0_r_9_hh, "v", "TTN, $r=9$, Householder"),
    (ttn_marginal_p0_r_10_hh, "<", "TTN, $r=10$, Householder"),
    (ttn_marginal_p0_r_10, "^", "TTN, $r=10$"),
)
# SSA references with increasing run counts; each uses its own state range offset by n_min.
ssa_series = (
    (ssa_1e5_sol, ssa_marginal_1e5, "SSA, $10^5$ runs"),
    (ssa_1e6_sol, ssa_marginal_1e6, "SSA, $10^6$ runs"),
    (ssa_1e7_sol, ssa_marginal_1e7, "SSA, $10^7$ runs"),
)

for k, axis in enumerate((ax1, ax2)):
    for marginals, fmt, label in ttn_series:
        axis.plot(np.arange(tree.grid.n[k]), marginals[k], fmt, label=label)
    for sol, marginals, label in ssa_series:
        axis.plot(np.arange(sol.n[k]) + sol.n_min[k], marginals[-1][k], ".-", label=label)
    axis.set_xlabel(f"$x_{k}$")

ax2.legend()
plt.savefig("plots/dts_fig4.pdf")

## Check symmetry

In [None]:
# Marginals are ordered A0, B0, A1, B1, A2, B2, A3, B3; compare mirrored compartments.
symmetry_pairs = (
    ("A0-A3", 0, 6),
    ("B0-B3", 1, 7),
    ("A1-A2", 2, 4),
    ("B1-B2", 3, 5),
)
for pair_name, left, right in symmetry_pairs:
    difference = ttn_marginal_p0_r_10[left] - ttn_marginal_p0_r_10[right]
    print(f"|{pair_name}| =", np.linalg.norm(difference))

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1)

# Each panel overlays two mirrored compartments of species A (r=9 TTN marginals).
panel_curves = (
    (ax1, ((0, "k.", "A, compartment 0"), (6, "rx", "A, compartment 3"))),
    (ax2, ((2, "k.", "A, compartment 1"), (4, "rx", "A, compartment 2"))),
)
for axis, curves in panel_curves:
    for idx, fmt, label in curves:
        axis.plot(np.arange(tree.grid.n[idx]), ttn_marginal_p0_r_9[idx], fmt, label=label)
    axis.legend()

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1)
# SSA (10^5 runs) version of the symmetry check. Marginals are ordered
# A0, B0, A1, B1, A2, B2, A3, B3, so indices 0/6 and 2/4 are all species A.
ax1.plot(np.arange(ssa_1e5_sol.n[0])+ssa_1e5_sol.n_min[0], ssa_marginal_1e5[-1][0], "k.", label="A, compartment 0")
ax1.plot(np.arange(ssa_1e5_sol.n[6])+ssa_1e5_sol.n_min[6], ssa_marginal_1e5[-1][6], "rx", label="A, compartment 3")
ax1.legend()

# Indices 2 and 4 are A1 and A2 (species B would be 3 and 5), so label them as A.
ax2.plot(np.arange(ssa_1e5_sol.n[2])+ssa_1e5_sol.n_min[2], ssa_marginal_1e5[-1][2], "k.", label="A, compartment 1")
ax2.plot(np.arange(ssa_1e5_sol.n[4])+ssa_1e5_sol.n_min[4], ssa_marginal_1e5[-1][4], "rx", label="A, compartment 2")
ax2.legend()

#### Partition 1, QTT

In [None]:
# Partition-1 run. Store it under the p1 names: assigning to ttn_*_p0_r_11 would
# silently overwrite the partition-0 rank-11 results loaded earlier.
# NOTE(review): this rebinds `tree` to the partition-1 file.
tree = readTree("output/dts_p1_r5_e_tau1e-3/output_t500000.nc")
ttn_slice_p1_r_5, ttn_marginal_p1_r_5 = tree.calculateObservables(slice_vec)

## Mass error

In [None]:
# Mass-error histories (TimeSeries.getMassErr) of the two tau=0.01 partition-0 runs.
mass_err_p0_r10 = TimeSeries("output/dts_p0_r10_e_tau1e-2").getMassErr()
# Keep the r=11 series bound to `time_series`; a later cell reads its time axis.
time_series = TimeSeries("output/dts_p0_r11_e_tau1e-2")
mass_err_p0_r11 = time_series.getMassErr()

In [None]:
# Report the maximal mass error per rank; leaves `time_series` bound to the r=11 run.
for rank in (10, 11):
    time_series = TimeSeries(f"output/dts_p0_r{rank}_e_tau1e-2")
    print(f"r = {rank}:", time_series.getMaxMassErr())

In [None]:
# Time axis for the mass-error plot, taken from the series bound last (r=11).
# NOTE(review): assumes the r=10 run was stored at the same time points.
time = time_series.time

In [None]:
fig, ax = plt.subplots()
# Absolute mass error over time for both ranks, on a logarithmic y-axis.
ax.plot(time, np.abs(mass_err_p0_r10), label="TTN, $r=10$")
ax.plot(time, np.abs(mass_err_p0_r11), label="TTN, $r=11$")
ax.set_xlabel("$t$")
# Escape the backslash: "\D" is an invalid escape sequence (SyntaxWarning since Python 3.12).
ax.set_ylabel("$|\\Delta m(t)|$")
ax.set_yscale("log")
ax.legend()
plt.savefig("plots/dts_fig5.pdf")