In [133]:
import numpy as np
import torch
from mcspace.model import MCSPACE
from mcspace.utils import pickle_load, pickle_save, MODEL_FILE
from pathlib import Path
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns
import pandas as pd

# Paths

In [134]:
# Root directory of the analysis; every other path is resolved relative to it.
basepath = Path("./")
# Directory holding the outputs of the selected model run.
modelpath = basepath.joinpath("best_run", "seed_5")

# Load results

In [135]:
# Tabular outputs of the run. The assemblage table is read with its default
# integer index; the other three use their first column as the row index.
thetadf = pd.read_csv(modelpath.joinpath("assemblages.csv"))
betadf = pd.read_csv(modelpath.joinpath("assemblage_proportions.csv"), index_col=0)
pertsdf = pd.read_csv(modelpath.joinpath("perturbation_bayes_factors.csv"), index_col=0)
radf = pd.read_csv(modelpath.joinpath("relative_abundances.csv"), index_col=0)

# Fitted model object.
# NOTE: weights_only=False makes torch.load unpickle arbitrary Python objects,
# which can execute code on load — only use this on model files you trust.
model = torch.load(modelpath.joinpath(MODEL_FILE), weights_only=False)


In [136]:
# Threshold a probability must exceed for its entry to be kept.
gamma_percentile = 0.95

# Sparsity probabilities from the model, with a leading 1 prepended so that the
# first entry always passes the threshold.
# (presumably the first assemblage has no sparsity indicator — TODO confirm)
sparsity_q = model.beta_params.sparsity_params.q_probs.cpu().detach().clone().numpy()
gamma_probs = np.concatenate([[1], sparsity_q])

# Boolean mask of the entries whose probability exceeds the threshold.
gammasub = gamma_probs > gamma_percentile

In [137]:
# Perturbation-indicator probabilities for every row the model holds ...
pert_probs_all = model.beta_params.perturbation_indicators.q_probs.cpu().clone().detach().numpy()
# ... restricted to the rows selected by the sparsity mask.
pert_probs = pert_probs_all[gammasub, :]
# Single summary number: mean probability over all retained entries.
pert_p = pert_probs.mean()

In [138]:
# Distinct labels along each axis of the proportions table, in the order in
# which they first appear in their column.
assemblages = np.array(betadf['Assemblage'].unique())
subjects = np.array(betadf['Subject'].unique())
times = np.array(betadf['Time'].unique()).astype(int)  # time labels cast to int

# Length of each axis.
ncomm, nsubj, ntime = len(assemblages), len(subjects), len(times)

In [139]:
# Reshape the long-format proportions table into a dense array indexed as
# beta_vals[assemblage, subject, time].
beta_vals = np.zeros((ncomm, nsubj, ntime))

comm_col = betadf['Assemblage']
subj_col = betadf['Subject']
time_col = betadf['Time']

for i, comm in enumerate(assemblages):
    in_comm = comm_col == comm
    for j, subj in enumerate(subjects):
        in_comm_subj = in_comm & (subj_col == subj)
        for k, t in enumerate(times):
            # First matching row wins; a missing combination raises IndexError.
            matched = betadf.loc[in_comm_subj & (time_col == t), 'Value']
            beta_vals[i, j, k] = matched.values[0]


In [140]:
# Taxonomy columns that are moved from the columns into the row index.
taxlevels = ['Otu', 'Domain', 'Phylum', 'Class', 'Order', 'Family', 'Genus', 'Species']

# Index the assemblage table by taxonomy so that `.values` below contains only
# the remaining (non-taxonomy) columns.
# Guarded so the cell is idempotent: once the taxonomy columns are in the index,
# a second set_index call would raise KeyError because they are no longer columns.
if list(thetadf.index.names) != taxlevels:
    thetadf = thetadf.set_index(taxlevels)

# Raw value matrix, one row per taxon.
# (presumably one column per assemblage — TODO confirm against assemblages.csv)
theta_vals = thetadf.values

In [141]:
# Change in beta across each of the three perturbations, restricted to the rows
# whose perturbation-indicator probability exceeds 0.5. Perturbation p compares
# time index 2p+1 with time index 2p.
# (presumably the time points just after / before each perturbation — TODO confirm)
# NOTE(review): pert_probs was filtered with gammasub, so indexing beta_vals with
# these masks assumes its first axis holds the same assemblages in the same
# order — confirm.
delta_blocks = []
for p in range(3):
    perturbed = pert_probs[:, p] > 0.5
    idx_before, idx_after = 2 * p, 2 * p + 1
    delta_blocks.append(beta_vals[perturbed, :, idx_after] - beta_vals[perturbed, :, idx_before])

# Keep the per-perturbation arrays available under their original names.
delta_beta1, delta_beta2, delta_beta3 = delta_blocks

# Stack all perturbations, then average over the second axis (subjects).
delta_beta = np.concatenate(delta_blocks, axis=0)
delta_beta_vals = delta_beta.mean(axis=1)

# Mean and (population) variance of the per-row effect sizes.
delta_beta_mu = delta_beta_vals.mean()
delta_beta_var = delta_beta_vals.var()

print(delta_beta_mu, delta_beta_var, sep="\n")

0.06633539620829015
0.013078230388618295


In [142]:
# Process-variance parameter of the fitted model, copied off the model as a
# NumPy array that no longer tracks gradients.
process_var_param = model.beta_params.var_process
process_var = process_var_param.cpu().clone().detach().numpy()

In [143]:
# Monte Carlo estimate of the garbage weight: call model.garbage_weights()
# n_samples times, reduce each draw to a scalar, and average the scalars.
n_samples = 1000
garb_seed = 0  # fixed so that Restart & Run All reproduces the saved estimate

# Seed torch's global RNG before drawing. The draws are presumably random
# (hence the averaging); without a seed the value pickled as "pi_garb" would
# change on every run. TODO confirm garbage_weights() samples via torch's RNG.
torch.manual_seed(garb_seed)

garb_samples = np.zeros(n_samples)
for i in range(n_samples):
    pi_sample, _ = model.garbage_weights()
    # Mean of the first element along the leading axis of the sampled weights.
    garb_samples[i] = np.mean(pi_sample.cpu().detach().numpy()[0])
garb_weight = np.mean(garb_samples)

print(garb_weight)

0.018074205642566087


In [144]:
# Perturbation times
pert_times = [18, 43, 65]
# One state code per entry: a leading 0 followed by three (1, -1) pairs.
# (presumably 1 = perturbation switched on, -1 = switched off — TODO confirm)
pert_state = [0] + [1, -1] * 3

# Save model fit results
Save the beta parameters, theta parameters, process variance, times, subjects, perturbation probability, the mean and variance of the perturbation magnitude, the perturbation times and states, and the garbage weight.

In [145]:
# Bundle everything computed above into one dictionary and pickle it.
data = {
    # fitted parameter arrays
    "beta": beta_vals,
    "theta": theta_vals,
    # perturbation summary statistics
    "perturbation_probs": pert_p,
    "perturbation_magnitude_mean": delta_beta_mu,
    "perturbation_magnitude_var": delta_beta_var,
    # process variance and axis labels
    "process_var": process_var,
    "times": times,
    "subjects": subjects,
    # perturbation schedule
    "pert_times": pert_times,
    "pert_state": pert_state,
    # Monte Carlo estimate of the garbage weight
    "pi_garb": garb_weight,
}

params_file = basepath / "time_series_params.pkl"
pickle_save(params_file, data)