### NUTS

This notebook runs NUTS (MCMC) to check that the model logic works independently of any SVI complexity. Run it on a small subset of the data only, because NUTS is slow on the full dataset.

In [None]:
from matplotlib import pyplot as plt
from numpyro.infer import MCMC, NUTS, init_to_value
from jax import numpy as jnp
import numpyro
numpyro.set_host_device_count(8)
import jax
import numpy as np

from numpyro.infer.util import log_density
from numpyro.infer import init_to_median, init_to_feasible

from tfscreen.analysis.hierarchical.growth_model import GrowthModel
#from tfscreen.analysis.hierarchical.growth_model.batch import sample_batch
import tfscreen

def run_nuts_diagnostic(gm, seed, num_warmup=500, num_samples=500, num_chains=1, target_accept_prob=0.9):
    """
    Run NUTS on a growth model as a sanity check of the model logic.

    A starting point is first found with ``init_to_median`` via
    ``numpyro.infer.util.initialize_model``, so initialization failures
    surface before sampling starts. The resulting constrained values are
    then used (through ``init_to_value``) as the starting point for every
    chain.

    Parameters
    ----------
    gm : GrowthModel
        Model wrapper. ``gm.jax_model`` is the NumPyro model; ``gm.priors``
        and ``gm.data`` are passed to it as the ``priors`` / ``data``
        keyword arguments.
    seed : int
        Seed for ``jax.random.PRNGKey``.
    num_warmup, num_samples, num_chains : int
        Passed straight through to ``numpyro.infer.MCMC``.
    target_accept_prob : float
        Passed straight through to ``numpyro.infer.NUTS``.

    Returns
    -------
    numpyro.infer.MCMC
        The MCMC object after ``run``; call ``get_samples()`` for draws.
    """

    # Derive independent keys for initialization and sampling so the same
    # PRNG key is never consumed twice.
    init_key, run_key = jax.random.split(jax.random.PRNGKey(seed))

    # Keyword arguments for the model, shared by initialization and sampling.
    jax_model_kwargs = {
        "priors": gm.priors,
        "data": gm.data,
    }

    # initialize_model returns ModelInfo(param_info, potential_fn,
    # postprocess_fn, model_trace). param_info is a ParamInfo namedtuple whose
    # ``z`` field holds *unconstrained* values, so neither it nor ``z`` can be
    # handed to init_to_value directly. (Passing the namedtuple itself matches
    # no site name, and init_to_value silently falls back to init_to_uniform.)
    param_info, _potential_fn, postprocess_fn, _model_trace = (
        numpyro.infer.util.initialize_model(
            init_key,
            gm.jax_model,
            model_args=(),
            model_kwargs=jax_model_kwargs,
            init_strategy=init_to_median,  # safer than uniform for hierarchical
        )
    )

    # Map back to constrained space, which is what init_to_value expects.
    # Extra (deterministic) keys in this dict are ignored by init_to_value.
    init_values = postprocess_fn(param_info.z)
    init_strategy = init_to_value(values=init_values)

    kernel = NUTS(gm.jax_model,
                  init_strategy=init_strategy,
                  target_accept_prob=target_accept_prob)

    mcmc = MCMC(kernel,
                num_warmup=num_warmup,
                num_samples=num_samples,
                num_chains=num_chains,
                progress_bar=True)

    print("Starting NUTS run...", flush=True)
    mcmc.run(run_key, **jax_model_kwargs)

    # Divergences flag regions where the sampler's geometry broke down.
    divergences = mcmc.get_extra_fields()["diverging"]
    num_divergences = int(jnp.sum(divergences))
    print(f"Number of divergences: {num_divergences}")

    return mcmc

# Genotypes kept for the diagnostic run; a small panel keeps NUTS tractable.
to_get_list = [
    "wt", "M42I", "H74A", "K84L", "I64N",
    "L45P", "I79C", "T68V", "A81C",
]

# Load each table and keep only rows for the selected genotypes,
# renumbering the rows 0..n-1 afterwards.
growth_df = tfscreen.util.read_dataframe("growth.csv")
growth_df_subset = (
    growth_df
    .loc[growth_df["genotype"].isin(to_get_list)]
    .reset_index(drop=True)
)

bind_df = tfscreen.util.read_dataframe("binding.csv")
bind_df_subset = (
    bind_df
    .loc[bind_df["genotype"].isin(to_get_list)]
    .reset_index(drop=True)
)

gm = GrowthModel(
    growth_df=growth_df_subset,
    binding_df=bind_df_subset,
    theta="hill",
    condition_growth="hierarchical",
    theta_binding_noise="none",
    theta_growth_noise="none",
    activity="horseshoe",
)

# Sample, then summarize the raw draws into per-parameter tables.
mcmc = run_nuts_diagnostic(
    gm,
    seed=1242,
    num_warmup=500,
    num_samples=1000,
    num_chains=1,
)
posteriors = mcmc.get_samples()
params = gm.extract_parameters(posteriors)

In [None]:
# List the parameter groups returned by gm.extract_parameters, then show one
# of them (the bare final expression is what the notebook renders).
print(params.keys())
params["theta_high"]

In [None]:
def plot_fit_param(df, title=None, ax=None):
    """
    Draw a bar chart of posterior medians with 95% interval error bars.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per bar. Must contain ``median``, ``lower_95`` and
        ``upper_95`` columns, plus either a ``genotype`` column or a
        ``condition`` column used for the x tick labels.
    title : str, optional
        Axes title.
    ax : matplotlib.axes.Axes, optional
        Axes to draw into; a new 6x6 figure is created when omitted.

    Returns
    -------
    fig, ax
        The figure containing ``ax``, and ``ax`` itself.
    """

    if ax is None:
        _, ax = plt.subplots(1, figsize=(6, 6))

    # Use positional numpy arrays: label-based Series indexing (``y[i]``)
    # fails or misaligns when df does not carry a 0..n-1 RangeIndex.
    median = df["median"].to_numpy()
    err_low = median - df["lower_95"].to_numpy()
    err_high = df["upper_95"].to_numpy() - median

    x = np.arange(len(df))
    centers = x + 0.5

    # One unit-wide bar per row spanning [i, i + 1], i.e. centred at i + 0.5.
    ax.bar(centers, median, width=1.0, color="gray", edgecolor="black")

    # fmt="none" draws only the error bars, no connecting line or markers.
    ax.errorbar(x=centers,
                y=median,
                yerr=[err_low, err_high],
                fmt="none",
                elinewidth=1,
                capsize=5,
                color="black")

    ax.set_xticks(centers)

    if "genotype" in df.columns:
        ax.set_xticklabels(df["genotype"])
    else:
        # Condition names tend to be long, so rotate them.
        ax.set_xticklabels(df["condition"])
        ax.tick_params(axis="x", labelrotation=90)
    ax.set_title(title)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

    # Works whether ax was passed in or created above.
    fig = ax.get_figure()
    fig.tight_layout()

    return fig, ax

# One figure per summarized parameter, drawn in alphabetical order of name.
param_keys = sorted(params.keys())
for p in param_keys:
    fig, ax = plot_fit_param(params[p], title=p)

    
    

In [None]:
# Raw draws for a single site straight from mcmc.get_samples(), i.e. before
# any summarizing by gm.extract_parameters (rendered as the cell's output).
posteriors["transformation_sigma_value"]