## New summer things - there are a lot of them

Google Colab demo of some new idioms, processes, libraries, etc.

In [None]:
!pip install summerepi==4.0.1a
!pip install numpyro

# graphviz is installed already, but need lib too
!apt install libgraphviz-dev
!pip install pygraphviz

In [None]:
from summer import jaxify
jaxify.set_using_jax(True)

In [None]:
from summer import CompartmentalModel
from summer.parameters.params import Parameter, Function, ComputedValue, ModelVariable, Time
from summer.experimental.model_builder import ModelBuilder
from summer.experimental import model_builder as mb 
#from autumn.model_features import functional

import numpy as np
from jax import numpy as jnp

In [None]:
import pandas as pd

pd.options.plotting.backend = "plotly"

In [None]:
from numpyro.distributions import constraints
from numbers import Real

In [None]:
# Attach a closed-at-zero interval constraint to numpyro's constraints module under
# the name `non_negative`, so the parameter classes below can refer to it the same
# way they refer to the built-in `constraints.positive` / `constraints.unit_interval`.
constraints.non_negative = constraints.interval(0.0, np.inf)

In [None]:
pclass = mb.parameter_class

In [None]:
from pydantic import BaseModel as PydanticBaseModel

class BaseModel(PydanticBaseModel):
    """Pydantic base class shared by the parameter schemas in this notebook."""

    class Config:
        # Permit field types that pydantic has no built-in validator for
        # (presumably needed for the array-typed fields — TODO confirm).
        arbitrary_types_allowed = True

In [None]:
from typing import List, Dict

class SeedParams(BaseModel):
    """Parameters of a seeding pulse; field names match the arguments of `triangular_seed`."""

    # Each field type is produced by `pclass` (mb.parameter_class), optionally with a
    # numpyro constraint and always with a human-readable description.
    peak_t: pclass(description="Time of peak value")
    spread: pclass(constraints.positive, description="Spread in time units")
    scale: pclass(constraints.positive, description="Total value to integrate to")

class Strain(BaseModel):
    """Per-strain parameters; currently only the seeding pulse for that strain."""

    seed: SeedParams
        
class Matrices(BaseModel):
    """Two matrices and the proportion used to blend them (see `get_blended_mm`)."""

    matrix_a: jaxify.Array
    matrix_b: jaxify.Array
    # Weight given to matrix_a when blending; matrix_b receives (1 - ratio)
    ratio: pclass(constraints.unit_interval, description="Matrix blending proportion")
        
class BaseEpiParams(BaseModel):
    """Rates used by the infection, progression and recovery flows of the base model."""

    # NOTE(review): the description is passed positionally here, whereas the other
    # parameter schemas use `description=` — confirm that the second positional
    # argument of `parameter_class` is the description.
    contact_rate: pclass(constraints.non_negative, "Contact rate")
    progression_rate: pclass(constraints.non_negative, "Progression rate")
    recovery_rate: pclass(constraints.non_negative, "Recovery rate")

class ModelOptions(BaseModel):
    """Top-level parameter schema handed to `ModelBuilder` alongside the raw parameter dict."""

    base_epi: BaseEpiParams
    # Keyed by strain name
    strains: Dict[str, Strain]
    matrices: Matrices


In [None]:
!curl https://raw.githubusercontent.com/monash-emu/AuTuMN/master/notebooks/capacity_building/malaysia/MYS_matrices.pkl > MYS_matrices.pkl

In [None]:
import pickle

# NOTE(review): unpickling can execute arbitrary code. This file is downloaded by the
# curl cell above, so only load it when that source is trusted.
# Use a context manager so the file handle is closed once the data has been read.
with open("MYS_matrices.pkl", "rb") as matrices_file:
    mm_dict = pickle.load(matrices_file)

In [None]:
# Raw parameter values for the model. The nesting (base_epi / strains / matrices)
# follows the ModelOptions schema, and values are addressed elsewhere in the notebook
# by dotted path, e.g. "base_epi.recovery_rate" or "matrices.ratio".
pdict = {
    "base_epi": {
        "contact_rate": 1.0,
        "progression_rate": 0.33,
        "recovery_rate": 0.2, 
    },
    "strains": {
        # One entry per strain name; each carries its own seeding pulse
        "wild_type": {
            "seed": {
                "peak_t": 10.0,
                "spread": 14.0,
                "scale": 1.0,
            }
        },
    },
    "matrices": {
        # Entries taken from the unpickled matrices dictionary
        "matrix_a": mm_dict["home"],
        "matrix_b": mm_dict["other_locations"],
        "ratio": 0.5
    }
}

In [None]:
from summer.stratification import StrainStratification

In [None]:
def triangular_seed(t, peak_t, spread, scale):
    """Evaluate a triangular pulse centred on ``peak_t`` at time ``t``.

    The pulse is zero further than ``spread / 2`` from ``peak_t``, rises linearly
    to ``scale / (spread / 2)`` at the peak, and therefore has total area ``scale``.

    Args:
        t: Time at which to evaluate the pulse.
        peak_t: Time of the peak value.
        spread: Full width of the triangle's base, in time units.
        scale: Total value the pulse integrates to.

    Returns:
        The pulse value at ``t`` (as produced by ``jnp.where``).
    """
    half_width = spread * 0.5
    inv_half_width = 1.0 / half_width
    offset = jnp.abs(peak_t - t)
    # Linear fall-off from the peak, normalised so the area equals `scale`
    height = (1.0 - offset * inv_half_width) * scale * inv_half_width
    return jnp.where(offset > half_width, 0.0, height)

In [None]:
def apply_strain_stratification(builder: ModelBuilder, strains: Dict[str, Strain]):
    """Stratify the builder's model by strain and add a seeding flow per strain.

    Args:
        builder: Builder whose ``model`` is stratified in place.
        strains: Strain parameters keyed by strain name; the keys become the strata.
    """
    target_model = builder.model
    strain_strat = StrainStratification("strain", list(strains), ["E", "I"])
    target_model.stratify_with(strain_strat)

    for strain_name, strain_params in strains.items():
        # Bind the seed parameters to triangular_seed, with `t` supplied by model time
        importation_func = builder.get_mapped_func(
            triangular_seed, strain_params.seed, {"t": Time}
        )
        target_model.add_importation_flow(
            f"seed_{strain_name}",
            importation_func,
            "E",
            True,
            {"strain": strain_name},
            16,
        )

In [None]:
from summer import Stratification

In [None]:
def apply_age_stratification(builder: ModelBuilder):
    """Stratify the builder's model into 5-year age bands from 0 up to 75.

    Sets the population split, a parameterised mixing matrix and per-age
    adjustments to the "recovery" flow, then applies the stratification to
    ``builder.model`` in place.
    """
    age_groups = [str(lower_bound) for lower_bound in range(0, 80, 5)]
    n_groups = len(age_groups)

    age_strat = Stratification("age", age_groups, ["S", "E", "I", "R"])

    # Illustrative population shape: the youngest band is five times the size of
    # the oldest, with linear interpolation in between, normalised to sum to one
    population_weights = np.linspace(5, 1, n_groups)
    population_weights = population_weights / population_weights.sum()
    age_strat.set_population_split(dict(zip(age_groups, population_weights)))

    # The matrix itself is looked up by name at run time
    age_strat.set_mixing_matrix(Parameter("mixing_matrix"))

    # Recovery multiplier falls linearly from 1.5 (youngest) to 0.5 (oldest)
    recovery_multipliers = np.linspace(1.5, 0.5, n_groups)
    age_strat.set_flow_adjustments(
        "recovery", dict(zip(age_groups, recovery_multipliers))
    )

    builder.model.stratify_with(age_strat)

In [None]:
def get_blended_mm(mat_a, mat_b, ratio):
    """Return the weighted blend ``ratio * mat_a + (1 - ratio) * mat_b``."""
    weighted_a = mat_a * ratio
    weighted_b = mat_b * (1.0 - ratio)
    return weighted_a + weighted_b

In [None]:
def create_builder(pdict):
    """Construct the SEIR model and return the ``ModelBuilder`` wrapping it.

    Args:
        pdict: Raw parameter dictionary, passed to ``ModelBuilder`` together
            with the ``ModelOptions`` schema.

    Returns:
        The builder, with flows, the blended mixing matrix output, age and
        strain stratifications and the "progression" output all configured.
    """
    model = CompartmentalModel(
        [0, 300], ["S", "E", "I", "R"], ["I"], takes_params=True
    )
    model.set_initial_population({"S": 30000000.0, "E": 0, "I": 0, "R": 0})

    builder = ModelBuilder(model, pdict, ModelOptions)
    params = builder.params

    # Expose the blended matrix under the name the age stratification refers to
    blend_inputs = [
        Parameter("matrices.matrix_a"),
        Parameter("matrices.matrix_b"),
        Parameter("matrices.ratio"),
    ]
    builder.add_output("mixing_matrix", Function(get_blended_mm, blend_inputs))

    # The three flows reference their parameters in three different ways:
    # a bare Parameter, a params-object attribute, and a dotted key via the builder
    model.add_infection_frequency_flow(
        "infection", Parameter("base_epi.contact_rate"), "S", "E"
    )
    progression_rate = builder.get_param(params.base_epi.progression_rate)
    model.add_transition_flow("progression", progression_rate, "E", "I")
    recovery_rate = builder.get_param("base_epi.recovery_rate")
    model.add_transition_flow("recovery", recovery_rate, "I", "R")

    # Stratify only after the base flows exist
    apply_age_stratification(builder)
    apply_strain_stratification(builder, params.strains)

    model.request_output_for_flow("progression", "progression")

    return builder

In [None]:
from computegraph import ComputeGraph

In [None]:
b = create_builder(pdict)

In [None]:
b.model

In [None]:
list(b.model.get_input_parameters())

In [None]:
b.input_graph

In [None]:
ComputeGraph(b.input_graph).draw()

In [None]:
runner = b.get_jax_runner()

In [None]:
from jax import random

import numpyro
import numpyro.distributions as dist

In [None]:
# Parameter values (by dotted path) used to generate the calibration target;
# the commented-out entries are left unchanged for this run.
target_param_updates = {
    "base_epi.recovery_rate": 0.2,
    #"base_epi.progression_rate": 0.1,
    #"base_epi.contact_rate": 0.03,
    "matrices.ratio": 0.1
}

# Run the model with these values passed to the runner
results = runner(target_param_updates)

# The "progression" derived output of this run is the series the calibration
# below is compared against.
targets = {
    "progression": results["derived_outputs"]["progression"],
}

In [150]:
# Start 16, stop 32, step 15: selects indices 16 and 31 only. Used below to pick
# columns of the outputs array and the matching entries of the compartment list.
cidx = slice(16,32,15)

In [None]:
pd.DataFrame(results["outputs"][:,cidx],columns=b.model.compartments[cidx]).plot()

In [None]:
pd.DataFrame(results["outputs"][:,cidx],columns=b.model.compartments[cidx]).plot()

In [None]:
pd.DataFrame(targets).plot()

In [None]:
# Candidate prior distributions, keyed by dotted parameter path
priors = {
    "base_epi.recovery_rate": dist.Uniform(0.1, 0.3),
    "base_epi.progression_rate": dist.Uniform(0.05, 0.3),
    "base_epi.contact_rate": dist.Uniform(0.01, 0.1),
    "strains.wild_type.seed.peak_t": dist.Uniform(0,100),
    "matrices.ratio": dist.Uniform(0.0,1.0)
}

# Keep only the priors for parameters that were set when generating the targets,
# then check that every such parameter actually has a prior defined.
priors = {k:v for k,v in priors.items() if k in target_param_updates}
for k in target_param_updates:
    assert(k in priors), k
# Show the filtered dictionary as the cell output
priors

In [None]:
# Check that both bounds of every prior satisfy the constraint declared on the
# corresponding model parameter, so sampling cannot propose a disallowed value.
for key, prior_dist in priors.items():
    print(key)
    plow, phigh = prior_dist.low, prior_dist.high
    # Resolve the dotted key to the parameter object to read its constraint
    param_obj = mb.find_obj_from_key(key, b.params)
    cfunc = param_obj.constraint
    # Use the unpacked bounds for the check, matching what the message reports
    is_constrained = cfunc(plow) and cfunc(phigh)
    msg = f"Prior for {key} ({prior_dist},[{plow},{phigh}]) samples outside of constraint {cfunc}"
    assert is_constrained, msg

In [None]:
def model():
    """NumPyro model: sample the priors, run the model, score against the targets.

    Each entry of ``priors`` becomes a sample site named by its key. The sampled
    values are passed to ``runner``, and for every series in ``targets`` a factor
    is added holding the summed log-probability of the observations under a
    normal truncated at zero, centred on the modelled series, with standard
    deviation equal to that of the observations.
    """
    sampled_params = {}
    for param_name, prior in priors.items():
        sampled_params[param_name] = numpyro.sample(param_name, prior)

    run_output = runner(sampled_params)

    for target_name, observed in targets.items():
        modelled_series = run_output["derived_outputs"][target_name]
        likelihood = dist.TruncatedNormal(
            modelled_series, jnp.std(observed), low=0.0
        )
        numpyro.factor(target_name, likelihood.log_prob(observed).sum())

In [None]:
import arviz as az

In [None]:
from numpyro.infer import MCMC

In [None]:
sa_kernel = numpyro.infer.SA(model)

In [None]:
# Two chains with the SA kernel: 1000 warm-up steps then 10000 retained samples
# each, with no thinning at this stage.
mcmc = MCMC(sa_kernel, num_chains=2, num_samples=10000,num_warmup=1000,thinning=1)
# Fixed seed so the run is repeatable
rng_key = random.PRNGKey(1)
mcmc.run(rng_key)

In [None]:
mcmc.print_summary()

In [None]:
target_param_updates

In [None]:
samples = mcmc.get_samples(True)

In [None]:
arviz_data = az.from_dict(samples)

In [None]:
_ = az.plot_trace(arviz_data, compact=False, figsize=(15,10))

In [None]:
def thin_samples(samples, thinning):
    """Keep every ``thinning``-th draw of each sampled site and flatten the result.

    Args:
        samples: Mapping from site name to an array whose second axis is thinned;
            assumed to be laid out as (chain, draw), as when samples are grouped
            by chain — TODO confirm for sites with extra dimensions.
        thinning: Step between retained entries along the second axis.

    Returns:
        Dict with the same keys, each holding the thinned values as a 1-D array.
    """
    return {
        name: draws[:, ::thinning].flatten() for name, draws in samples.items()
    }

In [None]:
thinned_samples = thin_samples(samples,20)

In [None]:
ll = numpyro.infer.util.log_likelihood(model,thinned_samples)

In [None]:
def sample_model(samples):
    """Run the model once per sampled parameter set.

    Args:
        samples: Mapping from parameter name to equal-length sequences of
            sampled values; row ``i`` across all keys forms one parameter set.

    Returns:
        List with the ``derived_outputs`` of each run, in row order.
    """
    params_df = pd.DataFrame(samples)
    return [
        runner(dict(param_row))["derived_outputs"]
        for _, param_row in params_df.iterrows()
    ]

In [None]:
runs = sample_model(thinned_samples)

In [None]:
# Stack the "progression" series from every sampled run into a single array
data = jnp.array([samp["progression"] for samp in runs])
quantiles = jnp.array((0.01,0.25,0.5,0.75,0.99))
# Quantiles are taken across axis 0 (the runs); the transpose puts one column
# per quantile, matching how the DataFrame is built in the next cell.
q = jnp.quantile(data, quantiles,axis=0).T

In [None]:
q_df = pd.DataFrame(np.array(q), columns=quantiles)
q_df['target'] = targets["progression"]

In [None]:
q_df.plot()

In [None]:
inspect_df = pd.DataFrame(thinned_samples)

In [None]:
ll_tot = ll["progression"]

In [None]:
inspect_df['ll'] = ll_tot

In [None]:
from plotly import express as px

In [None]:
px.scatter(inspect_df, x="base_epi.recovery_rate", y="matrices.ratio", hover_data=inspect_df.columns, color="ll" )