### SVI

This section runs stochastic variational inference (SVI), but not to convergence. Run it on only a small subset of the data.

In [None]:
from matplotlib import pyplot as plt
from jax import numpy as jnp
import numpyro
numpyro.set_host_device_count(8)
import jax
import numpy as np
import pandas as pd
import os
import glob

from tfscreen.analysis.hierarchical.growth_model import GrowthModel
from tfscreen.analysis.hierarchical.run_inference import RunInference
from tfscreen.analysis.hierarchical.analyze_theta import _run_svi
import tfscreen

In [None]:
# Preview ten genotypes picked at random from the full growth table.
growth_df = tfscreen.util.read_dataframe("growth.csv")
v = pd.unique(growth_df["genotype"])

# In-place shuffle using NumPy's global RNG (not seeded in this cell); the
# trailing bare expression is what the notebook displays.
np.random.shuffle(v)
v[0:10]

In [None]:


# Genotypes kept for this small SVI run. The first four are also passed to
# GrowthModel below as spiked_genotypes.
to_get_list = ["wt","M42I","H74A","K84L","I64N","L45P","I79C","T68V","A81C",
               'S77I', 'I64R', 'S85V', 'T34P', 'A75G', 'K37I', 'A67S', 'E39L',
               'S69P', 'A92Y']

growth_df = tfscreen.util.read_dataframe("growth.csv")
growth_df_subset = growth_df[growth_df["genotype"].isin(to_get_list)].reset_index(drop=True)

bind_df = tfscreen.util.read_dataframe("binding.csv")
bind_df_subset = bind_df[bind_df["genotype"].isin(to_get_list)].reset_index(drop=True)

gm = GrowthModel(growth_df=growth_df_subset,
                 binding_df=bind_df_subset,
                 batch_size=8,
                 transformation="congression",
                 condition_growth="hierarchical",
                 ln_cfu0="hierarchical",
                 dk_geno="hierarchical",
                 activity="horseshoe",
                 theta="hill",
                 theta_binding_noise="none",
                 theta_growth_noise="none",
                 spiked_genotypes=["wt","M42I","H74A","K84L"])

ri = RunInference(gm,seed=42)

# Output prefix handed to _run_svi. The cleanup glob and the posterior path
# below are built from this single value so they cannot drift apart.
svi_out_root = "svi-tmp"

# Delete files left over from a previous run that used this prefix.
for f in glob.glob(f"{svi_out_root}_*"):
    os.remove(f)

# Settings shared by both _run_svi calls; only the stopping, checkpointing
# and posterior-sampling settings differ between the two runs.
svi_shared_kwargs = dict(init_params=None,
                         checkpoint_file=None,
                         out_root=svi_out_root,
                         adam_step_size=1e-3,
                         adam_final_step_size=1e-4,
                         adam_clip_norm=1.0,
                         elbo_num_particles=2,
                         sampling_batch_size=100,
                         forward_batch_size=512)

# Optimisation run; posterior sampling is not forced here.
svi_state, params, converged = _run_svi(ri,
                                        **svi_shared_kwargs,
                                        convergence_tolerance=1e-9,
                                        convergence_window=10000,
                                        checkpoint_interval=10000,
                                        num_steps=10000000,
                                        num_posterior_samples=10000,
                                        always_get_posterior=False)

# This call forces getting posteriors (num_steps=0, always_get_posterior=True).
# NOTE(review): init_params is None here, exactly as in the first call.
# Confirm that _run_svi / RunInference reuses the parameters fitted above;
# if it does not, pass init_params=params so samples come from the trained
# guide rather than a fresh initialisation.
svi_state, params, converged = _run_svi(ri,
                                        **svi_shared_kwargs,
                                        convergence_tolerance=1e-6,
                                        convergence_window=100,
                                        checkpoint_interval=10,
                                        num_steps=0,
                                        num_posterior_samples=5000,
                                        always_get_posterior=True)

# Replace the raw SVI params with the per-parameter summary tables used by
# the plotting cell.
params = gm.extract_parameters(f"{svi_out_root}_posterior.npz")

In [None]:

# Plot human-readable losses: first column of the text log, one point per row.
df = pd.read_csv("svi-tmp_losses.txt",header=None)
plt.plot(df[0],'o')
plt.xlabel('step')
plt.ylabel("ELBO")
plt.show()


# Plot raw losses. The .bin file is read as a flat float64 array; this
# assumes that is the format _run_svi writes -- TODO confirm.
losses_array = np.fromfile("svi-tmp_losses.bin", dtype=np.float64)
plt.plot(losses_array)
plt.xlabel('step')
plt.ylabel('loss (raw)')
# NOTE(review): a log y-scale clips non-positive values; if the loss can be
# negative, 'symlog' would show it faithfully.
plt.yscale('log')
plt.show()


In [None]:
def plot_fit_param(df,title=None,ax=None):
    """
    Bar plot of per-row parameter medians with 95% error bars.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per bar. Must contain "median", "lower_95" and "upper_95"
        columns, plus either a "genotype" column or a "condition" column
        that supplies the x tick labels.
    title : str, optional
        Title for the axes.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. When omitted, a new 6x6 inch figure is created.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure that owns ``ax``; ``tight_layout`` has been applied to it.
    ax : matplotlib.axes.Axes
        Axes that were drawn on.
    """

    if ax is None:
        fig, ax = plt.subplots(1,figsize=(6,6))

    # Use plain arrays so bar i always takes row i. Indexing the Series with
    # y[i] is a *label* lookup, which fails or selects the wrong row whenever
    # df does not carry a default 0..n-1 index.
    y = df["median"].to_numpy()
    x = np.arange(len(y))

    # errorbar expects distances below/above each point, not absolute bounds.
    y_low_95 = y - df["lower_95"].to_numpy()
    y_high_95 = df["upper_95"].to_numpy() - y

    # Each bar is a unit-wide filled polygon from 0 up to the median, so bar i
    # spans [i, i + 1] and is centred on i + 0.5.
    for xi, yi in zip(x, y):
        ax.fill([xi, xi, xi + 1, xi + 1],
                [0, yi, yi, 0],
                facecolor='gray',edgecolor='black')

    # lw=0 hides the line joining the points; only the error bars are drawn.
    ax.errorbar(x=x+0.5,
                y=y,
                yerr=[y_low_95,y_high_95],
                lw=0,
                elinewidth=1,
                capsize=5,
                color='black')

    ax.set_xticks(x + 0.5)

    if "genotype" in df.columns:
        ax.set_xticklabels(df["genotype"])
    else:
        # Condition labels are rotated vertically.
        ax.set_xticklabels(df["condition"])
        ax.tick_params(axis='x', labelrotation=90)
    ax.set_title(title)

    # Remove the top and right frame lines.
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    fig = ax.get_figure()
    fig.tight_layout()

    return fig, ax

# One summary figure per fitted parameter, visited in sorted key order.
param_keys = sorted(params)
for p in param_keys:
    fig, ax = plot_fit_param(params[p], p)

    
    