In [1]:
import os

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd

# Palette: 15 evenly spaced RGBA colours sampled from the "jet" colormap (one row per colour).
# Despite its name, jet_12_colors holds 15 colours; rows 3 and 12 are the two colours the
# plots in this notebook use.
jet = mpl.colormaps["jet"]
jet_12_colors = jet(np.linspace(start=0.0, stop=1.0, num=15))

In [2]:
# Load the Gillespie (SSA) reference trajectories and their per-time summary statistics.
path1 = "../autodl-tmp/toggle/sim/toggle_trajs10000_end40_init1_1_0_0.npz"  # Gillespie data path

data1 = np.load(path1)
T = 40         # final simulation time (the file name says "end40")
SIM_DT = 5e-3  # nominal sampling interval of the saved trajectories
samples = data1["data"]  # batch_size x time_points x species (Gx, Gy, Px, Py)
# The time grid is not stored in the file, so it is reconstructed here. Derive the point
# count from T instead of a second hard-coded 40, and check it against the data. A mismatch
# then fails immediately rather than as misaligned time indices in later cells.
sim_times = np.linspace(0, T, int(T / SIM_DT))
assert samples.ndim == 3 and samples.shape[1] == len(sim_times), (samples.shape, len(sim_times))
tfinal = sim_times[-1]
# Actual grid spacing (T / (n_points - 1)); later cells use it to convert a time to an index.
step = sim_times[1] - sim_times[0]
sim_Gx = samples[:, :, 0]
sim_Gy = samples[:, :, 1]
sim_Px = samples[:, :, 2]
sim_Py = samples[:, :, 3]

# Mean and (population) standard deviation across trajectories at every time point.
sim_Gx_mean = np.mean(sim_Gx, 0)
sim_Gx_std = np.std(sim_Gx, 0)
sim_Gy_mean = np.mean(sim_Gy, 0)
sim_Gy_std = np.std(sim_Gy, 0)
sim_Px_mean = np.mean(sim_Px, 0)
sim_Px_std = np.std(sim_Px, 0)
sim_Py_mean = np.mean(sim_Py, 0)
sim_Py_std = np.std(sim_Py, 0)

In [3]:
# Load the RNN samples, prepend the known initial state at t = 0, and summarise per time.
path2 = "../autodl-tmp/toggle/test2/rnn_samples.npz"

data2 = np.load(path2)
rnn_times = np.concatenate([np.zeros(1), data2["times"]], axis=0)
samples = data2["samples"]  # time_points x batch_size x species (Gx, Gy, Px, Py)
# Fail early if the stored times and samples are misaligned.
assert samples.ndim == 3 and samples.shape[0] == len(data2["times"]), samples.shape
n_rnn = samples.shape[1]  # batch size, taken from the data rather than hard-coded
# Initial state (Gx, Gy, Px, Py) = (1, 1, 0, 0), matching the "init1_1_0_0" Gillespie run.
rnn_Gx = np.concatenate([np.ones((1, n_rnn)), samples[:, :, 0]], axis=0)
rnn_Gy = np.concatenate([np.ones((1, n_rnn)), samples[:, :, 1]], axis=0)
rnn_Px = np.concatenate([np.zeros((1, n_rnn)), samples[:, :, 2]], axis=0)
rnn_Py = np.concatenate([np.zeros((1, n_rnn)), samples[:, :, 3]], axis=0)

# Mean and (population) standard deviation across the batch at every time point.
rnn_Gx_mean = np.mean(rnn_Gx, 1)
rnn_Gx_std = np.std(rnn_Gx, 1)
rnn_Gy_mean = np.mean(rnn_Gy, 1)
rnn_Gy_std = np.std(rnn_Gy, 1)
rnn_Px_mean = np.mean(rnn_Px, 1)
rnn_Px_std = np.std(rnn_Px, 1)
rnn_Py_mean = np.mean(rnn_Py, 1)
rnn_Py_std = np.std(rnn_Py, 1)

In [4]:
# Load the MET samples, prepend the known initial state at t = 0, and summarise per time.
path3 = "../autodl-tmp/toggle/test4/met_samples.npz"
data3 = np.load(path3)
met_times = np.concatenate([np.zeros(1), data3["times"]], axis=0)
samples = data3["samples"]  # time_points x batch_size x species (Gx, Gy, Px, Py)
# Fail early if the stored times and samples are misaligned.
assert samples.ndim == 3 and samples.shape[0] == len(data3["times"]), samples.shape
n_met = samples.shape[1]  # batch size, taken from the data rather than hard-coded
# Initial state (Gx, Gy, Px, Py) = (1, 1, 0, 0), matching the "init1_1_0_0" Gillespie run.
met_Gx = np.concatenate([np.ones((1, n_met)), samples[:, :, 0]], axis=0)
met_Gy = np.concatenate([np.ones((1, n_met)), samples[:, :, 1]], axis=0)
met_Px = np.concatenate([np.zeros((1, n_met)), samples[:, :, 2]], axis=0)
met_Py = np.concatenate([np.zeros((1, n_met)), samples[:, :, 3]], axis=0)

# Mean and (population) standard deviation across the batch at every time point.
met_Gx_mean = np.mean(met_Gx, 1)
met_Gx_std = np.std(met_Gx, 1)
met_Gy_mean = np.mean(met_Gy, 1)
met_Gy_std = np.std(met_Gy, 1)
met_Px_mean = np.mean(met_Px, 1)
met_Px_std = np.std(met_Px, 1)
met_Py_mean = np.mean(met_Py, 1)
met_Py_std = np.std(met_Py, 1)

### Plot the time-dependent average counts

In [46]:
plt.rc('font', size=30)


def _plot_average_counts(title, fname, names, sim_means, rnn_means, met_means):
    """Overlay Gillespie mean trajectories (solid lines) with RNN (circles) and MET (squares) means.

    The first species in `names` is drawn with colour row 12 of `jet_12_colors`,
    the second with colour row 3. The figure is saved to `fname` and closed.
    """
    colours = (jet_12_colors[12, :], jet_12_colors[3, :])
    fig, ax = plt.subplots(figsize=(9, 8))

    for name, colour, mean in zip(names, colours, sim_means):
        ax.plot(sim_times, mean, color=colour, linewidth=3, label=name, alpha=0.7)

    sampled = (("RNN", rnn_times, 'o', rnn_means), ("MET", met_times, 's', met_means))
    for prefix, times, marker, means in sampled:
        for name, colour, mean in zip(names, colours, means):
            ax.plot(times, mean, color=colour, marker=marker, linestyle='None',
                    markersize=10, alpha=0.5, label=f"{prefix}-{name}")

    ax.set_xlabel("Time (h)")
    ax.set_ylabel("Average count")
    ax.set_title(title)
    ax.legend()
    fig.set_size_inches(9, 8)
    fig.savefig(fname, bbox_inches="tight", dpi=400)
    plt.close(fig)


_plot_average_counts("Gene", "toggle_gene_average_count.svg", ("Gx", "Gy"),
                     (sim_Gx_mean, sim_Gy_mean), (rnn_Gx_mean, rnn_Gy_mean), (met_Gx_mean, met_Gy_mean))
_plot_average_counts("Protein", "toggle_pro_average_count.svg", ("Px", "Py"),
                     (sim_Px_mean, sim_Py_mean), (rnn_Px_mean, rnn_Py_mean), (met_Px_mean, met_Py_mean))

### Plot the marginal distributions

In [48]:
plt.rc('font', size=30)


def _hellinger_distance(a, b):
    """Hellinger distance between the empirical distributions of two discrete samples.

    Computed as sqrt(1 - BC), where BC = sum_k sqrt(p_a(k) * p_b(k)) is the
    Bhattacharyya coefficient over the observed values k.
    """
    pa = pd.Series(a).value_counts(normalize=True)
    pb = pd.Series(b).value_counts(normalize=True)
    # Index alignment gives NaN for values seen in only one sample. Series.sum skips NaN,
    # so those values contribute 0 to BC, as they should.
    bc = np.sqrt(pa * pb).sum()
    # Round-off can push BC marginally above 1 for (near-)identical samples.
    # Clip so the square root never returns NaN.
    return float(np.sqrt(max(0.0, 1.0 - bc)))


def _plot_marginal(species, sim_arr, rnn_arr, met_arr, tc, bins):
    """Histogram the `species` marginal at time `tc` for MET, Gillespie and RNN samples.

    sim_arr is indexed [trajectory, time]; rnn_arr and met_arr are indexed [time, sample].
    The legend title reports the Hellinger distance between the MET and Gillespie samples.
    """
    si = np.round(tc / step).astype(int)   # Gillespie grid index for time tc
    ri = np.abs(rnn_times - tc).argmin()   # nearest RNN output time
    mi = np.abs(met_times - tc).argmin()   # nearest MET output time
    x_sim = sim_arr[:, si]
    x_rnn = rnn_arr[ri, :]
    x_met = met_arr[mi, :]
    hd = round(_hellinger_distance(x_sim, x_met), 4)

    fig, ax = plt.subplots(figsize=(9, 8))
    data = [x_met, x_sim, x_rnn]
    # Weight each sample by 1/N so bar heights are probabilities rather than counts.
    weights = [np.ones_like(x, dtype=float) / len(x) for x in data]
    ax.hist(data, bins=bins, weights=weights,
            color=[jet_12_colors[3, :], "darkgray", jet_12_colors[12, :]],
            alpha=0.7, orientation="vertical")
    ax.legend(["MET", "Gillespie", "RNN"], title=r"$D_{HD}=" + f"{hd}$", fontsize=30)
    ax.set_ylabel("Marginal probability")
    ax.set_xlabel("Counts")
    ax.set_title(fr"$t=${tc}")
    fig.set_size_inches(9, 8)
    fig.savefig(f"toggle_{species}_marginal_tc{tc}.svg", bbox_inches="tight", dpi=400)
    plt.close(fig)


for tc in [1, 3, 40]:
    _plot_marginal("Gx", sim_Gx, rnn_Gx, met_Gx, tc, bins=np.arange(0, 5, 0.5))
    _plot_marginal("Px", sim_Px, rnn_Px, met_Px, tc, bins=np.arange(0, 120, 5))

### Plot the joint distributions of Px and Py

In [49]:
plt.rc('font', size=30)


def _plot_joint(px, py, tc, tag):
    """Save a log-scaled 2-D histogram of paired (Px, Py) samples at time `tc`.

    The file is written to toggle_joint_{tag}_tc{tc}.svg and the figure is closed.
    """
    # Keep a handle on this figure so later calls act on it, not on a stale `fig`
    # left over from an earlier cell.
    fig = plt.figure(figsize=(10, 8), dpi=400)
    ax = fig.add_subplot(1, 1, 1, facecolor=[68/255, 1/255, 80/255])
    # hist2d puts its first argument on the x-axis, so pass Px first to match the labels.
    # density=True divides counts by bin area; the 40x40 bins span each sample's own range.
    image = ax.hist2d(px, py, norm=mpl.colors.LogNorm(vmax=0.025),
                      bins=[40, 40], cmap="viridis", density=True)[3]
    fig.colorbar(image, ax=ax, label="Probability")
    ax.set_ylabel("Py")
    ax.set_xlabel("Px")
    ax.set_ylim(0, 80)
    ax.set_xlim(0, 80)
    ax.set_title(fr"$t=${tc}")
    fig.set_size_inches(10, 8)
    fig.savefig(f"toggle_joint_{tag}_tc{tc}.svg", bbox_inches="tight", dpi=400)
    plt.close(fig)


for tc in [1, 3, 40]:
    si = np.round(tc / step).astype(int)   # Gillespie grid index for time tc
    ri = np.abs(rnn_times - tc).argmin()   # nearest RNN output time
    mi = np.abs(met_times - tc).argmin()   # nearest MET output time
    _plot_joint(sim_Px[:, si], sim_Py[:, si], tc, "sim")
    _plot_joint(rnn_Px[ri, :], rnn_Py[ri, :], tc, "rnn")
    _plot_joint(met_Px[mi, :], met_Py[mi, :], tc, "met")

In [63]:
fig, axes = plt.subplots(3, 3, figsize=(40, 30))
plt.rc('font', size=30)
ts = [1, 3, 40]
VMAX = 0.025  # upper end of the log colour scale shared by all panels

# Rows: MET (top), Gillespie (middle), RNN (bottom). Columns: the times in `ts`.
images = []
densities = []
for i, tc in enumerate(ts):
    si = np.round(tc / step).astype(int)   # Gillespie grid index for time tc
    ri = np.abs(rnn_times - tc).argmin()   # nearest RNN output time
    mi = np.abs(met_times - tc).argmin()   # nearest MET output time
    panels = (
        (met_Px[mi, :], met_Py[mi, :]),
        (sim_Px[:, si], sim_Py[:, si]),
        (rnn_Px[ri, :], rnn_Py[ri, :]),
    )
    for row, (px, py) in enumerate(panels):
        ax = axes[row, i]
        ax.set(facecolor=[68/255, 1/255, 80/255])
        # The first hist2d argument goes on the x-axis: Px horizontal, Py vertical, as labelled.
        h = ax.hist2d(px, py, norm=mpl.colors.LogNorm(vmax=VMAX),
                      bins=[40, 40], cmap="viridis", density=True)
        densities.append(h[0])
        images.append(h[3])
        ax.set_xlim(0, 80)
        ax.set_ylim(0, 80)

# Each LogNorm above autoscales its own vmin, so the panels would use different colour scales
# while only one colorbar is drawn. Give every panel one norm, from the smallest non-zero
# density of any panel up to VMAX, so the single colorbar is valid for all nine.
vmin = min(d[d > 0].min() for d in densities)
shared_norm = mpl.colors.LogNorm(vmin=min(vmin, VMAX), vmax=VMAX)
for image in images:
    image.set_norm(shared_norm)

axes[0, 0].set_ylabel(r"$P_y$", fontsize=60)
axes[1, 0].set_ylabel(r"$P_y$", fontsize=60)
axes[2, 0].set_ylabel(r"$P_y$", fontsize=60)
axes[2, 0].set_xlabel(r"$P_x$", fontsize=60)
axes[2, 1].set_xlabel(r"$P_x$", fontsize=60)
axes[2, 2].set_xlabel(r"$P_x$", fontsize=60)

axes[0, 0].set_title(r"$t=1$", fontsize=60)
axes[0, 1].set_title(r"$t=3$", fontsize=60)
axes[0, 2].set_title(r"$t=40$", fontsize=60)

cb = fig.colorbar(images[-1], ax=axes[:, :], shrink=0.6, label="Probability", location="right")

fig.set_size_inches(40, 30)
fig.savefig("toggle_joint_met.svg", bbox_inches="tight", dpi=400)
plt.close(fig)

### Compare gene means and standard deviations of RNN and MET against Gillespie

In [50]:
# Keep a handle on this figure so set_size_inches/savefig act on it, not on a stale `fig`.
fig, ax = plt.subplots(figsize=(9, 8), dpi=400)
plt.rc('font', size=30)

# For each gene, plot the RNN and MET sample means against the Gillespie mean at the
# integer times t = 0..39; points on the black y = x line indicate agreement.
for species, marker, sim_arr, rnn_arr, met_arr in (
    ("Gx", 'o', sim_Gx, rnn_Gx, met_Gx),
    ("Gy", 's', sim_Gy, rnn_Gy, met_Gy),
):
    g1, g2, g3 = [], [], []  # Gillespie, RNN, MET means
    for tc in np.arange(40):
        si = np.round(tc / step).astype(int)   # Gillespie grid index for time tc
        ri = np.abs(rnn_times - tc).argmin()   # nearest RNN output time
        mi = np.abs(met_times - tc).argmin()   # nearest MET output time
        # Average the sampled counts directly. Weighting value_counts probabilities by
        # np.arange(len(p)) treats the observed values as 0, 1, ..., K-1. At t = 0, for
        # example, where the prepended RNN/MET initial state is all ones, that gives 0, not 1.
        g1.append(np.mean(sim_arr[:, si]))
        g2.append(np.mean(rnn_arr[ri, :]))
        g3.append(np.mean(met_arr[mi, :]))
    ax.plot(g1, g2, marker=marker, linestyle='None',
            color=jet_12_colors[12, :], markersize=15, alpha=0.5, label=f"RNN-{species}")
    ax.plot(g1, g3, marker=marker, linestyle='None',
            color=jet_12_colors[3, :], markersize=15, alpha=0.5, label=f"MET-{species}")

x = np.linspace(0, 1.5, 100)
y = x
ax.plot(x, y, color="black", lw=2)  # y = x reference line
ax.set_xlabel("Gillespie")
ax.legend()
fig.set_size_inches(9, 8)
fig.savefig("toggle_G_mean_compare.svg", bbox_inches="tight", dpi=400)
plt.close(fig)

In [51]:
# Keep a handle on this figure so set_size_inches/savefig act on it, not on a stale `fig`.
fig, ax = plt.subplots(figsize=(9, 8), dpi=400)
plt.rc('font', size=30)

# For each gene, plot the RNN and MET sample standard deviations against the Gillespie one
# at the integer times t = 0..39; points on the black y = x line indicate agreement.
for species, marker, sim_arr, rnn_arr, met_arr in (
    ("Gx", 'o', sim_Gx, rnn_Gx, met_Gx),
    ("Gy", 's', sim_Gy, rnn_Gy, met_Gy),
):
    g1, g2, g3 = [], [], []  # Gillespie, RNN, MET standard deviations
    for tc in np.arange(40):
        si = np.round(tc / step).astype(int)   # Gillespie grid index for time tc
        ri = np.abs(rnn_times - tc).argmin()   # nearest RNN output time
        mi = np.abs(met_times - tc).argmin()   # nearest MET output time
        # Population std (ddof=0) of the sampled counts themselves. Deriving it from
        # value_counts with np.arange(len(p)) as the support treats the observed values as
        # 0, 1, ..., K-1, which distorts the spread whenever the observed values have gaps.
        g1.append(np.std(sim_arr[:, si]))
        g2.append(np.std(rnn_arr[ri, :]))
        g3.append(np.std(met_arr[mi, :]))
    ax.plot(g1, g2, marker=marker, linestyle='None',
            color=jet_12_colors[12, :], markersize=15, alpha=0.5, label=f"RNN-{species}")
    ax.plot(g1, g3, marker=marker, linestyle='None',
            color=jet_12_colors[3, :], markersize=15, alpha=0.5, label=f"MET-{species}")

x = np.linspace(0, 1, 100)
y = x
ax.plot(x, y, color="black", lw=2)  # y = x reference line
ax.set_xlabel("Gillespie")
ax.legend()
fig.set_size_inches(9, 8)
fig.savefig("toggle_G_std_compare.svg", bbox_inches="tight", dpi=400)
plt.close(fig)