In [None]:
# Standard imports, plotting option and constant definition
import pandas as pd
from scipy import stats
import numpy as np
from jax import numpy as jnp

pd.options.plotting.backend = "plotly"

import numpyro
from numpyro import distributions as dist

from numpyro.infer.initialization import init_to_median, init_to_value, init_to_sample
from numpyro.infer import MCMC
from jax import random

from autumn.core.project import get_project
from autumn.settings.region import Region

# Build a model object

In [None]:
# Load the sm_covid2 project for France and take its baseline parameter set
project = get_project("sm_covid2", Region.FRANCE)
params = project.param_set.baseline
# Run the baseline model once; the returned model object is reused by later
# cells for its times, its runner and its default parameters
sm_model = project.run_baseline_model(params)

In [None]:
# Build a runner from the baseline parameters, declaring contact_rate (the only
# parameter given a prior in this notebook) as a dynamic parameter
runner = sm_model.get_runner(params.to_dict(), dyn_params=["contact_rate"])

### Convert targets to JAX arrays and model-time indices

In [None]:
# Calibration targets defined by the project; each one exposes its observations
# as a named series through target.data
targets = project.calibration.targets

In [None]:
def get_target_indices(target_series, model_times):
    """Locate each target time within the model's time axis.

    Args:
        target_series: Series whose index holds the observation times.
        model_times: Array-like of model evaluation times.

    Returns:
        An integer jax array holding, for every entry of ``target_series.index``
        (in order), the position of its first occurrence in ``model_times``.
        An empty series yields an empty integer array, so the result can always
        be used for indexing.

    Raises:
        IndexError: If a target time does not occur in ``model_times``.
    """
    model_times = np.asarray(model_times)
    positions = []
    for target_time in target_series.index:
        matches = np.flatnonzero(model_times == target_time)
        if matches.size == 0:
            # Name the offending time instead of failing with a bare
            # "index 0 is out of bounds" error
            raise IndexError(
                f"Target time {target_time!r} (series {target_series.name!r}) "
                "is not present in model_times"
            )
        positions.append(int(matches[0]))
    # An explicit integer dtype keeps the empty case usable as an index array
    return jnp.array(positions, dtype=int)

In [None]:
# For every calibration target, keep (keyed by the target's series name):
#   np_targets     - the observed values as a jax array
#   target_indices - where each observation time sits on the model time axis
np_targets = {}
target_indices = {}
for target in targets:
    target_name = target.data.name
    np_targets[target_name] = jnp.array(target.data.to_numpy())
    target_indices[target_name] = get_target_indices(target.data, sm_model.times)

In [None]:
# Prior distributions keyed by parameter name; model() turns each entry into a
# numpyro sample site with the same name
priors = {
    "contact_rate": dist.Uniform(0.01, 0.5),
}

In [None]:
# Full default parameter dictionary; model() merges the sampled values over it
all_params = sm_model.builder.get_default_parameters()

In [None]:
def model():
    """Numpyro model: sample the priors, run the compiled model, score it against the targets.

    For every entry of ``np_targets`` a Normal log-likelihood of the observations
    around the modelled values is added to the joint density via ``numpyro.factor``
    and also recorded as a deterministic site named ``"<target>_ll"``.
    """
    # numpyro.sample draws a value from the supplied distribution using the
    # PRNGKey managed by the inference algorithm. Example models usually spell
    # out each sample site; here one site is created per entry of `priors`.
    sampled = {}
    for name, prior in priors.items():
        sampled[name] = numpyro.sample(name, prior)

    # runner._run_func is a pure compiled jax function - it bypasses any
    # additional CPython code, so it can be called directly inside a numpyro model
    outputs = runner._run_func(all_params | sampled)
    derived = outputs["derived_outputs"]

    # Assemble the log-likelihood and collect its components
    for target_name, observed in np_targets.items():
        # Ordinary-looking indexing works on the model outputs; internally
        # these are all jax types
        predicted = derived[target_name][target_indices[target_name]]

        # Some models would use numpyro.sample with an `obs` argument here.
        # The factor primitive instead allows a customised log-likelihood, and
        # the deterministic primitive saves each component.
        # Alternative likelihood:
        #   dist.TruncatedNormal(predicted, jnp.std(observed), low=0.0).log_prob(observed).sum()
        noise_sd = jnp.std(observed)
        log_like = dist.Normal(predicted, noise_sd).log_prob(observed).sum()

        numpyro.factor(target_name, log_like)
        numpyro.deterministic(f"{target_name}_ll", log_like)
        # Optional extra diagnostic:
        #   numpyro.deterministic(f"{target_name}_sum", jnp.sum(predicted))

In [None]:
# Explicit starting value for contact_rate; only used by the init_to_value
# alternative noted below
initial_parameters = {"contact_rate": jnp.array((0.2,))}

# Sample Adaptive (SA) kernel, initialised from the prior median.
# Alternative: numpyro.infer.SA(model, init_strategy=init_to_value(values=initial_parameters))
sa_kernel = numpyro.infer.SA(
    model,
    dense_mass=True,
    adapt_state_size=8,
    init_strategy=init_to_median,
)

In [None]:
# Two chains, each with 1000 warmup steps followed by 1000 retained draws
mcmc = MCMC(
    sa_kernel,
    num_warmup=1000,
    num_samples=1000,
    num_chains=2,
)
# Fixed seed for the sampler
rng_key = random.PRNGKey(1)
# Also collect the acceptance probability of each step for the diagnostic plot
# Alternative without extra fields: mcmc.run(rng_key)  (optionally init_params=initial_parameters)
mcmc.run(rng_key, extra_fields=("accept_prob",))

In [None]:
# Plot the acceptance probabilities recorded during mcmc.run
pd.Series(mcmc.get_extra_fields()['accept_prob']).plot()

In [None]:
# Posterior summary table; deterministic sites (the "<target>_ll" values
# recorded in model()) are included as well
mcmc.print_summary(exclude_deterministic=False)

In [None]:
class SampleWrapper:
    """Index into a dict of per-site sample arrays, keeping only model parameters.

    Args:
        samples: Mapping of site name -> indexable collection of draws.
        params: Container of parameter names; only sites whose name is in it
            are returned by indexing.
    """

    def __init__(self, samples, params):
        self.samples = samples
        self.params = params

    def __getitem__(self, idx):
        """Return ``{name: draws[idx]}`` for every sampled site that is a model parameter."""
        # Read from the wrapped samples, not from a notebook-level global of the same name
        return {
            name: draws[idx]
            for name, draws in self.samples.items()
            if name in self.params
        }

In [None]:
# Collected draws keyed by site name; the positional False is presumably
# group_by_chain, i.e. draws from all chains are returned flattened together
# NOTE(review): confirm against the installed numpyro version
samples = mcmc.get_samples(False)

In [None]:
# Wrap the draws so that indexing returns only the model's input parameters.
# The model object built in this notebook is `sm_model`.
sw = SampleWrapper(samples, sm_model.get_input_parameters())

In [None]:
# Index of the draw with the highest recorded log-likelihood for active_cases
# NOTE(review): assumes "active_cases" is one of the calibration target names, so
# that model() recorded an "active_cases_ll" site — confirm against `targets`
best_idx = pd.Series(sw.samples['active_cases_ll']).idxmax()

In [None]:
# Re-run the model at the best sampled parameter values
runner.run(sw[best_idx])
# Observed values come from the calibration targets, i.e. the same series the
# likelihood in model() is scored against
observed_by_name = {target.data.name: target.data for target in targets}
# Plot the model outputs against the data
output_df = pd.DataFrame({
    "modelled": runner.get_derived_outputs_df()["active_cases"],
    "observed": observed_by_name["active_cases"],
})
output_df.plot(kind='scatter')

In [None]:
sampled_df = pd.DataFrame(index=runner.get_derived_outputs_df().index)

# Re-run the model for up to 100 draws spread evenly over the draws that were
# actually collected, so every index stays inside the sample arrays
n_draws = len(next(iter(sw.samples.values())))
draw_indices = np.linspace(0, n_draws - 1, min(100, n_draws), dtype=int)

for i, draw_idx in enumerate(draw_indices):
    runner.run(sw[int(draw_idx)])
    sampled_df[i] = runner.get_derived_outputs_df()['active_cases']

In [None]:
# One line per re-run draw: modelled active_cases over the model times
sampled_df.plot()