In [None]:
import ast
import sys
from pathlib import Path

import matplotlib
import matplotlib.gridspec as gs
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from scipy.stats import norm
from torch import no_grad

sys.path.append("..")
from src.dataset import PhyloSimulator
from src.utils import get_results

In [None]:
# Simulation configuration.
# True transmission rates, one per level in the order used throughout this notebook:
# facility, floors 1-5, room.
beta_true = [.02, .03, .06, .03, .03, .12, .03]
# Remaining PhyloSimulator arguments.
mu_eo = -3
sigma_eo = 1
observed_seed = 30
# Number of genomic clusters; each level contributes n_clusters rows of genomic data.
n_clusters = 34
# Patient capacity of each level, same order as beta_true.
capacity = [100, 20, 20, 20, 20, 20, 2]

# Relative per-capita risk of each level, normalised so the facility level equals 1.
np.divide(beta_true, capacity) / np.divide(beta_true[0], capacity[0])

In [None]:
# Build the simulator from the true parameters and extract the observed dataset x_o.
phylo = PhyloSimulator(
    beta_true,
    mu_eo,
    sigma_eo,
    observed_seed,
    n_sample=None,
    notebook_mode=True,
    time_first=False,
)
x_o = phylo.get_observed_data()

In [None]:
# Split the genomic part of the observed data into one array of n_clusters rows per level.
# Rows 0-6 of x_o are the epidemiological series plotted below (row 0 facility,
# rows 1-5 floors, row 6 room); the genomic blocks follow them back to back.
n_epi_rows = 7
levels = ["facility"] + [f"Floor {i}" for i in range(1, 6)] + ["room"]

# Fail loudly instead of silently producing truncated slices if x_o is shorter than expected.
expected_rows = n_epi_rows + len(levels) * n_clusters
assert len(x_o) >= expected_rows, f"x_o has {len(x_o)} rows, expected at least {expected_rows}"

genomic = {}
for i, level in enumerate(levels):
    start = n_epi_rows + i * n_clusters
    genomic[level] = np.array(x_o[start:start + n_clusters])

In [None]:
# Facility-level infected vs. susceptible patients per week.
# NOTE(review): susceptible = capacity[0] - infected assumes the facility holds capacity[0]
# (= 100) patients, the value previously hard-coded here. Confirm against PhyloSimulator.
fig, ax = plt.subplots()
ax.plot(x_o[0], label="Infected")
ax.plot(capacity[0] - x_o[0], label="Susceptible")
ax.set(xlabel="Weeks", ylabel="Patients")
ax.legend()
Path("images/phylo-sim").mkdir(parents=True, exist_ok=True)  # savefig fails if the dir is missing
fig.savefig("images/phylo-sim/si_counts.png")
plt.show()

In [None]:
# One line per floor: rows 1-5 of the observed data.
labels = [f"Floor {n}" for n in range(1, 6)]
floor_rates = x_o[1:6]
for label, floor_series in zip(labels, floor_rates):
    plt.plot(floor_series, label=label, linestyle="-", alpha=0.6)
plt.legend()
plt.xlabel("Weeks")
plt.ylabel("Infected Patients")
plt.savefig("images/phylo-sim/floor_counts.png")
plt.show()

In [None]:
# Room-level series (row 6 of the observed data), halved.
# NOTE(review): the /2 presumably converts patient counts to 2-bed rooms (capacity[-1] == 2).
# Confirm against PhyloSimulator.
fig, ax = plt.subplots()
ax.plot(x_o[6] / 2)
ax.set(xlabel="Weeks", ylabel="Room-level count / 2")
plt.show()

### Visualizing clustering

In [None]:
# Per-cluster totals for every non-room level: each level's array is summed over its second
# axis, giving one column per level and one row per cluster.
cluster_agg = pd.DataFrame(
    {level: counts.sum(1) for level, counts in genomic.items() if level != "room"}
).rename_axis("cluster")

In [None]:
# Stacked per-cluster totals across the five floors (the facility-level column is excluded).
ax = cluster_agg.drop(columns="facility").plot.bar(stacked=True)
ax.set(xlabel="Cluster", ylabel="Infected Weeks")
plt.show()

In [None]:
# Display the room-level genomic array (one row per cluster).
genomic["room"]

In [None]:
# Sum the room-level genomic array over clusters (axis 0) and halve it
# (presumably one count per roommate).
# Observation: up to six rooms have infected roommates belonging to the same cluster at any point.
np.sum(genomic["room"], axis=0) / 2

In [None]:
# Looks like R_0 is around 2:
# roughly speaking, \bar\beta ~ 0.15, and
# the average period of infectiousness is theoretically 13.3, so R_0 ~ 0.15 * 13.3 ~ 2.

# NPE Results

## Gaussian density network

### Ablation experiment: epidemiological data only

In [None]:
# Ablation (epidemiological data only): keep the run with the lowest validation loss.
df_abl = get_results("../multirun/2025-05-08/10-48-53")
df_abl = df_abl.sort_values("val_loss")
mu, sigma = np.array(df_abl["mu"].iloc[0][0]), np.array(df_abl["sigma"].iloc[0])

In [None]:
# Posterior mean of beta: if log(beta) ~ N(mu, Sigma), then E[beta_j] = exp(mu_j + Sigma_jj / 2).
np.exp(mu + 0.5 * np.diag(sigma))

#### Log scale

In [None]:
# Gaussian density network: keep log-scale runs, best first by validation loss.
# (An earlier sweep is at 2025-04-24/15-59-52.)
df_gdn = get_results("../multirun/2025-04-25/11-39-55")
df_gdn = df_gdn.loc[df_gdn["log_scale"].eq(True)].sort_values("val_loss")
mu, sigma = np.array(df_gdn["mu"].iloc[0][0]), np.array(df_gdn["sigma"].iloc[0])

In [None]:
# Posterior mean of beta: if log(beta) ~ N(mu, Sigma), then E[beta_j] = exp(mu_j + Sigma_jj / 2).
np.exp(mu + 0.5 * np.diag(sigma))

In [None]:
# Validation loss of the selected run.
df_gdn.iloc[0].loc["val_loss"]

In [None]:
# Posterior correlation matrix: corr = D Sigma D with D = diag(1 / sqrt(diag(Sigma))).
# Rows index beta_i and columns beta_j. Colour limits are fixed to [-1, 1] so this figure is
# comparable with the other correlation plots.
D = np.diag(1 / np.sqrt(np.diag(sigma)))
corr = D @ sigma @ D
plt.matshow(corr, cmap="rocket", vmin=-1, vmax=1)
plt.colorbar(label="Correlation")
plt.xlabel(r"$\beta_j$")
plt.ylabel(r"$\beta_i$")
plt.show()

#### Natural scale

In [None]:
# df_gdn_n = get_results("../multirun/2025-04-25/11-39-55")
# df_gdn_n = df_gdn_n[df_gdn_n["log_scale"] == False].sort_values("val_loss")
# mu = np.array(df_gdn_n.iloc[0]["mu"][0])
# sigma = np.array(df_gdn_n.iloc[0]["sigma"])

# Takeaway: on the natural scale, point estimates are less accurate and the covariance structure is off.

### Mean-field estimation

In [None]:
# Mean-field variant: keep log-scale runs, best first by validation loss.
df_mf = get_results("../multirun/2025-04-25/12-48-36")
df_mf = df_mf.loc[df_mf["log_scale"].eq(True)].sort_values("val_loss")
mu, sigma = np.array(df_mf["mu"].iloc[0][0]), np.array(df_mf["sigma"].iloc[0])

In [None]:
# Posterior mean of beta: if log(beta) ~ N(mu, Sigma), then E[beta_j] = exp(mu_j + Sigma_jj / 2).
np.exp(mu + 0.5 * np.diag(sigma))

In [None]:
# Validation loss of the selected run.
df_mf.iloc[0].loc["val_loss"]

### RNN

In [None]:
# RNN encoder: best run by validation loss.
df_rnn = get_results("../multirun/2025-04-25/16-45-51").sort_values("val_loss")
mu, sigma = np.array(df_rnn["mu"].iloc[0][0]), np.array(df_rnn["sigma"].iloc[0])

In [None]:
# Posterior mean of beta: if log(beta) ~ N(mu, Sigma), then E[beta_j] = exp(mu_j + Sigma_jj / 2).
np.exp(mu + 0.5 * np.diag(sigma))

In [None]:
# Validation loss of the selected run.
df_rnn.iloc[0].loc["val_loss"]

### Transformer

In [None]:
# Transformer results were exported to CSV, so list-valued columns come back as strings.
# Parse them with ast.literal_eval, which only accepts Python literals, instead of eval,
# which would execute arbitrary code found in the file.
# The stable sort keeps an already-sorted file unchanged and matches the other sections'
# "best run by val_loss" selection.
df_tf = pd.read_csv("df_phylo_tf.csv", index_col=0).sort_values("val_loss", kind="stable")
mu = np.array(ast.literal_eval(df_tf.iloc[0]["mu"])[0])
sigma = np.array(ast.literal_eval(df_tf.iloc[0]["sigma"]))

In [None]:
# Validation loss of the selected run.
df_tf.iloc[0].loc["val_loss"]

In [None]:
# Posterior mean of beta: if log(beta) ~ N(mu, Sigma), then E[beta_j] = exp(mu_j + Sigma_jj / 2).
np.exp(mu + 0.5 * np.diag(sigma))

In [None]:
# Posterior correlation matrix: corr = D Sigma D with D = diag(1 / sqrt(diag(Sigma))).
# Rows index beta_i and columns beta_j. Colour limits are fixed to [-1, 1] so this figure is
# comparable with the other correlation plots.
D = np.diag(1 / np.sqrt(np.diag(sigma)))
corr = D @ sigma @ D
plt.matshow(corr, vmin=-1, vmax=1)
plt.colorbar(label="Correlation")
plt.xlabel(r"$\beta_j$")
plt.ylabel(r"$\beta_i$")
plt.show()

In [None]:
# Display the correlation matrix computed in the previous cell.
corr

## Normalizing flow

In [None]:
# Normalizing-flow sweep from 2025-04-26; the top row is the best run by validation loss.
df_nf = get_results("../multirun/2025-04-26/13-34-42").sort_values("val_loss")
df_nf.head(1)

In [None]:
# Display the full 2025-04-26 normalizing-flow sweep, sorted by validation loss.
df_nf

In [None]:
# Later normalizing-flow sweep (2025-04-29), best run first.
# It gets its own name so it does not overwrite `df_nf`: that frame holds the 2025-04-26 sweep,
# which is where the checkpoint loaded below comes from.
df_nf_0429 = get_results("../multirun/2025-04-29/10-21-17").sort_values("val_loss")
df_nf_0429.head(1)

In [None]:
# Hyperparameters passed to RealNVP when loading the checkpoint: 'd_theta', 'n_layers', 'd_model', 'lr', 'weight_decay', and 'embed_dim'.

In [None]:
from src.model import RealNVP

# Best checkpoint from the 2025-04-26 normalizing-flow sweep.
checkpt = "../multirun/2025-04-26/13-34-42/11/crkp/umnreprh/checkpoints/epoch=241-step=242.ckpt"
# NOTE(review): Lightning's load_from_checkpoint passes extra kwargs to the constructor /
# overrides saved hyperparameters. These must match the architecture that was trained.
nf_kwargs = dict(n_layers=4, weight_decay=0.02, lr=1e-3, d_model=80, embed_dim=16, d_theta=7)
nf = RealNVP.load_from_checkpoint(checkpt, d_x=x_o.shape[::-1], **nf_kwargs)

In [None]:
# Draw M posterior samples from the flow, conditioned on the observed data.
# NOTE(review): on_fit_start() is normally invoked by a Trainer. It is called by hand here,
# presumably to set up state that sample() needs. Confirm in src.model.
nf.on_fit_start()
nf.eval()  # inference: disable any train-only behaviour (dropout, batch-norm statistic updates)
torch.manual_seed(0)  # reproducible Monte Carlo draws on Restart & Run All
with no_grad():
    M = 100
    sample = nf.sample(M, x_o.T.unsqueeze(0))

In [None]:
# Monte Carlo posterior mean of exp(theta) over the M flow draws (axis 0 indexes draws).
# Uses the Tensor method, so the cell does not need a module-level `torch` name
# (previously a NameError, since only `no_grad` was imported from torch).
sample.exp().mean(0).cpu().numpy()

In [None]:
# sns.pairplot(pd.DataFrame(sample.cpu().numpy()))

In [None]:
# Sample covariance of the flow draws. cov() treats rows as variables, and `sample` has one
# draw per row, so transpose first. Tensor.cov avoids the NameError on an unimported `torch`.
sigma_nf = sample.T.cov().cpu().numpy()

In [None]:
# Correlation matrix of the normalizing-flow samples: corr = D Sigma D with
# D = diag(1 / sqrt(diag(Sigma))). Rows index beta_i and columns beta_j. Colour limits are
# fixed to [-1, 1] so this figure is comparable with the other correlation plots.
D = np.diag(1 / np.sqrt(np.diag(sigma_nf)))
corr = D @ sigma_nf @ D
plt.matshow(corr, cmap="rocket", vmin=-1, vmax=1)
plt.colorbar(label="Correlation")
plt.xlabel(r"$\beta_j$")
plt.ylabel(r"$\beta_i$")
plt.show()

### glossing results

In [None]:
# 2k run: per-level means (m) and spreads (s) on the log scale; prints exp(m + s/2) per level.
# NOTE(review): exp(m + s/2) is the lognormal mean only if s is a variance. If s is a standard
# deviation it should be exp(m + s**2 / 2). Confirm what s holds.
m = [-3.865, -3.673, -3.137, -3.829, -3.605, -2.753, -3.737]
s = [0.195, 0.661, 0.581, 0.727, 0.931, 0.375, 0.344]

for log_mean, spread in zip(m, s):
    print(np.exp(log_mean + spread / 2))

In [None]:
# True simulation parameters, for comparison with the estimates above.
beta_true

In [None]:
# Normalizing-flow posterior means (hard-coded), for comparison.
[0.0168, 0.0118, 0.0785, 0.0382, 0.0178, 0.0736, 0.0172]

In [None]:
# 4k run: per-level means (m) and spreads (s) on the log scale; prints exp(m + s/2) per level.
# NOTE(review): exp(m + s/2) is the lognormal mean only if s is a variance. If s is a standard
# deviation it should be exp(m + s**2 / 2). Confirm what s holds.
m = [-3.513, -3.436, -3.02, -2.847, -3.681, -2.758, -3.665]
s = [0.156, 0.629, 0.427, 0.306, 0.518, 0.321, 0.336]
for log_mean, spread in zip(m, s):
    print(np.exp(log_mean + spread / 2))

# Visualization

In [None]:
# Display names for the seven levels, in parameter order.
labels_full = ["Facility"] + [f"Floor {k}" for k in range(1, 6)] + ["Room"]

In [None]:
# Per-level marginal posteriors on the log scale, one stacked panel per level:
# "Old" = ablation run (epidemiological data only), "New" = transformer run.
# Dedicated names keep the simulator parameters mu_eo / sigma_eo from the config cell intact.
# Overwriting them would hand arrays to PhyloSimulator if that cell were re-run.
mu_epi = np.array(df_abl.iloc[0]["mu"][0])
sigma_epi = np.array(df_abl.iloc[0]["sigma"])
# The transformer results come from CSV, so list-valued cells are strings. Parse them as
# literals rather than with eval.
mu_gen = np.array(ast.literal_eval(df_tf.iloc[0]["mu"])[0])
sigma_gen = np.array(ast.literal_eval(df_tf.iloc[0]["sigma"]))

n_levels = len(labels_full)
# Line and fill colours are set together so each density's outline matches its shading.
color_old, color_new = "#1f77b4", "#ff7f0e"
alpha = 0.5
x_min, x_max = -6, 0
x = np.arange(x_min, x_max, 0.05)

grid = gs.GridSpec(n_levels + 1, 1)  # one extra empty row below the panels
fig = plt.figure(figsize=(5, 7))
axes = []
for i in range(n_levels):
    ax = fig.add_subplot(grid[i, 0])
    axes.append(ax)
    show_legend = i == 0

    y0 = norm.pdf(x, mu_epi[i], np.sqrt(sigma_epi[i, i]))
    y1 = norm.pdf(x, mu_gen[i], np.sqrt(sigma_gen[i, i]))
    sns.lineplot(x=x, y=y0, label="Old Density", ax=ax, legend=show_legend, alpha=alpha, color=color_old)
    sns.lineplot(x=x, y=y1, label="New Density", ax=ax, legend=show_legend, alpha=alpha, color=color_new)
    ax.fill_between(x, y1, color=color_new, alpha=alpha)
    ax.fill_between(x, y0, color=color_old, alpha=alpha)

    ax.patch.set_alpha(0)  # transparent axes background
    ax.set_ylabel("")
    ax.set_yticks([])
    ax.set_ylim(0, 2)
    ax.set_xlim(x_min, x_max)
    if i < n_levels - 1:
        ax.tick_params(labelbottom=False)  # x tick labels only on the bottom panel
    else:
        ax.set_xlabel("Infection Rate (Log-scale)")

    ax.text(x_min - 0.8, 0.3, labels_full[i])
grid.update(hspace=0)
plt.tight_layout()
# plt.savefig("images/crkp/crkp_compare_het.png")
plt.show()

In [None]:
# Per-level marginal posteriors on the log scale, one stacked panel per level:
# (E) = ablation run using epidemiological data only; (E/G) = log-scale Gaussian density
# network run (presumably epidemiological + genomic data).
# Dedicated names keep the simulator parameters mu_eo / sigma_eo from the config cell intact.
mu_epi = np.array(df_abl.iloc[0]["mu"][0])
sigma_epi = np.array(df_abl.iloc[0]["sigma"])
mu_gen = np.array(df_gdn.iloc[0]["mu"][0])
sigma_gen = np.array(df_gdn.iloc[0]["sigma"])

n_levels = len(labels_full)
# Line and fill colours are set together so each density's outline matches its shading.
color_e, color_eg = "#1f77b4", "#ff7f0e"
alpha = 0.5
x_min, x_max = -6, 0
x = np.arange(x_min, x_max, 0.05)

grid = gs.GridSpec(n_levels + 1, 1)  # one extra empty row below the panels
fig = plt.figure(figsize=(5, 7))
axes = []
for i in range(n_levels):
    ax = fig.add_subplot(grid[i, 0])
    axes.append(ax)
    show_legend = i == 0

    y0 = norm.pdf(x, mu_epi[i], np.sqrt(sigma_epi[i, i]))
    y1 = norm.pdf(x, mu_gen[i], np.sqrt(sigma_gen[i, i]))
    sns.lineplot(x=x, y=y0, label="Posterior Density (E)", ax=ax, legend=show_legend, alpha=alpha, color=color_e)
    sns.lineplot(x=x, y=y1, label="Posterior Density (E/G)", ax=ax, legend=show_legend, alpha=alpha, color=color_eg)
    ax.fill_between(x, y1, color=color_eg, alpha=alpha)
    ax.fill_between(x, y0, color=color_e, alpha=alpha)

    ax.patch.set_alpha(0)  # transparent axes background
    ax.set_ylabel("")
    ax.set_yticks([])
    ax.set_ylim(0, 2)
    ax.set_xlim(x_min, x_max)
    if i < n_levels - 1:
        ax.tick_params(labelbottom=False)  # x tick labels only on the bottom panel
    else:
        ax.set_xlabel("Infection Rate (Log-scale)")

    ax.text(x_min - 0.8, 0.3, labels_full[i])
grid.update(hspace=0)
plt.tight_layout()
Path("images/phylo-sim").mkdir(parents=True, exist_ok=True)  # savefig fails if the dir is missing
fig.savefig("images/phylo-sim/ablation.png")
plt.show()

In [None]:
# True transmission rates on the log scale, for comparison with the log-scale posterior means.
# Derived from beta_true instead of a duplicated literal, so the two cannot drift apart.
np.log(beta_true)