In [6]:
from treeflow_pipeline.util import yaml_input
import treeflow_pipeline.model

model_file = "../config/model.yaml"
model = treeflow_pipeline.model.Model(yaml_input(model_file))
model.clock_model, model.clock_params

('relaxed_lognormal_conjugate',
 {'rate_loc_precision': {'normalgamma': {'loc': -8.117321296021004,
    'precision_scale': 0.6145984625809157,
    'concentration': 2.0264094157104164,
    'rate': 0.055894814133092476}}})

In [54]:
import tensorflow_probability as tfp
import treeflow.priors
import numpy as np

cov_quantiles = np.array([0.2, 0.7])
precision_quantiles = 1.0 / np.log(np.square(cov_quantiles)+1.0)[::-1]
precision_prior_params, res = treeflow.priors.get_params_for_quantiles(tfp.distributions.Gamma, precision_quantiles)
precision_prior_params

{'concentration': 3.3164221870890347, 'rate': 0.3028646047813991}

In [56]:
prior_params = {
    'loc': -8.117321296021004,
    'precision_scale': 0.6145984625809157,
    'concentration': precision_prior_params["concentration"],
    'rate': precision_prior_params["rate"]
}
prior_params

{'loc': -8.117321296021004,
 'precision_scale': 0.6145984625809157,
 'concentration': 3.3164221870890347,
 'rate': 0.3028646047813991}

In [102]:
from ipywidgets import interact, IntSlider

import matplotlib.pyplot as plt
import seaborn as sns

def prior_plot(loc, precision_scale):
    params = {
        'loc': loc,
        'precision_scale': precision_scale,
        'concentration': precision_prior_params["concentration"],
        'rate': precision_prior_params["rate"]
    }
    
    n_plot_samples = 10000
    
    dist = tfp.distributions.JointDistributionNamed(treeflow.priors.get_normal_conjugate_prior_dict(**params))
    prior_samples = dist.sample(n_plot_samples)
    prior_loc = prior_samples["loc"].numpy()
    prior_scale = treeflow.priors.precision_to_scale(prior_samples["precision"]).numpy()
    prior_cov = np.sqrt(np.exp(np.square(prior_scale))-1)
    prior_mean = np.exp(prior_loc + np.square(prior_scale) / 2.0)

    fig, ax = plt.subplots(figsize=(10, 10))

    ax.scatter(prior_mean, prior_cov, alpha=0.5)
    ax.set_xlabel("Prior mean")
    ax.set_ylabel("Prior CoV")

    ps = np.array([0.025, 0.975])
    cov_quantiles = np.quantile(prior_cov, ps)
    mean_quantiles = np.quantile(prior_mean, ps)
    return dict(cov_quantiles=cov_quantiles, mean_quantiles=mean_quantiles), params
    
interact(
    prior_plot,
    loc=-8.117321296021004 + np.log(10),
    precision_scale=0.3
)

interactive(children=(FloatSlider(value=-5.814736203026958, description='loc', max=5.814736203026958, min=-17.…

<function __main__.prior_plot(loc, precision_scale)>

In [104]:
pop_size_quantiles = np.array([1.0, 2.0])
treeflow.priors.get_params_for_quantiles(tfp.distributions.LogNormal, pop_size_quantiles)

({'loc': 0.06367702938213937, 'scale': 0.13341435492327347},
       fun: 0.41134913257826833
  hess_inv: <2x2 LbfgsInvHessProduct with dtype=float64>
       jac: array([-1.99921818, -2.76387171])
   message: b'CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH'
      nfev: 13
       nit: 8
      njev: 13
    status: 0
   success: True
         x: array([0.06367703, 0.13341435]))