## New summer things - there are a lot of them

Google Colab demo of some new idioms, processes, libraries, etc.

In [None]:
#!pip install summerepi==4.0.1a
#!pip install numpyro

# graphviz is already installed, but its development library is needed too
#!apt install libgraphviz-dev
#!pip install pygraphviz

In [None]:
from summer2 import CompartmentalModel
from summer2.parameters import Parameter, Function, ModelVariable, Time, Data
from summer2.experimental.model_builder import ModelBuilder
from summer2.experimental import model_builder as mb 
#from autumn.model_features import functional

import numpy as np
from jax import numpy as jnp

In [None]:
import pandas as pd

pd.options.plotting.backend = "plotly"

In [None]:
from numpyro.distributions import constraints
from numbers import Real

In [None]:
# Add a "non_negative" entry to numpyro's constraints module (an interval from
# 0 to infinity) so it can be used in the parameter class declarations.
constraints.non_negative = constraints.interval(0.0, np.inf)

In [None]:
# Short alias for mb.parameter_class, which is called with a constraint and/or
# description to produce the field annotations of the parameter classes.
pclass = mb.parameter_class

In [None]:
from pydantic import BaseModel as PydanticBaseModel

class BaseModel(PydanticBaseModel, mb.ParamStruct):
    """Common base for the parameter structures: a pydantic model that is also an ``mb.ParamStruct``."""
    class Config:
        # pydantic setting permitting field types it has no built-in validator
        # for (np.ndarray is used as a field type in Matrices)
        arbitrary_types_allowed = True

In [None]:
from typing import List, Dict

class SeedParams(BaseModel):
    """Parameters of a seeding function.

    The field names match the non-time arguments of ``triangular_seed``, to
    which they are mapped in ``apply_strain_stratification``.
    """
    peak_t: pclass(desc="Time of peak value")
    spread: pclass(constraints.positive, "Spread in time units")
    scale: pclass(constraints.positive, "Total value to integrate to")

class Strain(BaseModel):
    """Parameters for a single strain; holds only its seeding parameters."""
    seed: SeedParams
        
class Matrices(BaseModel):
    """Two matrices and the proportion used to blend them (see ``get_blended_mm``)."""
    # The two matrices that are blended into the mixing matrix
    matrix_a: np.ndarray
    matrix_b: np.ndarray
    # Weight applied to matrix_a; matrix_b is weighted by (1 - ratio)
    ratio: pclass(constraints.unit_interval, "Matrix blending proportion")
        
class BaseEpiParams(BaseModel):
    """Rates used for the infection, progression and recovery flows in ``create_builder``."""
    contact_rate: pclass(constraints.non_negative, "Contact rate")
    progression_rate: pclass(constraints.non_negative, "Progression rate")
    recovery_rate: pclass(constraints.non_negative, "Recovery rate")

class ModelOptions(BaseModel):
    """Top-level parameter schema handed to ``ModelBuilder`` together with the parameter dict."""
    base_epi: BaseEpiParams
    # Keyed by strain name; the keys become the strata of the strain stratification
    strains: Dict[str, Strain]
    matrices: Matrices


In [None]:
import pickle
# Load the matrices used to populate pdict["matrices"].
# NOTE(security): unpickling can execute arbitrary code; only load this file
# from a trusted source.
# The context manager closes the file handle as soon as loading finishes
# (the bare open() call previously left it open).
with open("MYS_matrices.pkl", "rb") as matrices_file:
    mm_dict = pickle.load(matrices_file)

In [None]:
# Default parameter values, laid out to match the ModelOptions schema
# (this dict is passed to ModelBuilder alongside ModelOptions in create_builder).
pdict = {
    "base_epi": {
        "contact_rate": 1.0,
        "progression_rate": 0.33,
        "recovery_rate": 0.2, 
    },
    # One entry per strain, keyed by strain name
    "strains": {
        "wild_type": {
            "seed": {
                "peak_t": 10.0,
                "spread": 14.0,
                "scale": 1.0,
            }
        },
    },
    # The two matrices blended by get_blended_mm, and the weight on matrix_a
    "matrices": {
        "matrix_a": mm_dict["home"],
        "matrix_b": mm_dict["other_locations"],
        "ratio": 0.5
    }
}

In [None]:
from summer2.stratification import StrainStratification, Stratification

In [None]:
def triangular_seed(t, peak_t, spread, scale):
    """Evaluate a triangular pulse centred on ``peak_t`` at time ``t``.

    The pulse is zero further than ``spread / 2`` from ``peak_t`` and rises
    linearly to ``scale / (spread / 2)`` at the peak, so the area under it
    equals ``scale``.

    Args:
        t: Time at which to evaluate the pulse.
        peak_t: Time of the peak.
        spread: Full width of the pulse's base.
        scale: Area under the pulse.

    Returns:
        The pulse value at ``t`` (0.0 outside the base).
    """
    half_width = spread * 0.5
    inv_half_width = 1.0 / half_width
    offset = jnp.abs(peak_t - t)
    # Linear ramp: 1 at the peak, falling to 0 at the edge of the base
    ramp = 1.0 - offset * inv_half_width
    height = ramp * scale * inv_half_width
    return jnp.where(offset > half_width, 0.0, height)

In [None]:
def apply_strain_stratification(builder: ModelBuilder, strains: Dict[str, Strain]):
    """Stratify the builder's model by strain and add a seeding flow per strain.

    Args:
        builder: ModelBuilder whose ``model`` is stratified in place.
        strains: Strain parameters keyed by strain name; the keys become the
            strata of the "strain" stratification, applied to E and I.
    """
    model = builder.model
    strain_names = list(strains)
    model.stratify_with(StrainStratification("strain", strain_names, ["E", "I"]))

    for strain_name, strain_params in strains.items():
        # Bind this strain's seed parameters to triangular_seed; 't' is mapped to Time
        seed_rate = builder.get_mapped_func(
            triangular_seed, strain_params.seed, {"t": Time}
        )
        # NOTE(review): the trailing 16 is presumably the expected number of
        # flows (one per age group) - confirm against add_importation_flow.
        model.add_importation_flow(
            f"seed_{strain_name}", seed_rate, "E", True, {"strain": strain_name}, 16
        )

In [None]:
def apply_age_stratification(builder: ModelBuilder, mixing_matrix):
    """Stratify every compartment of the builder's model into 5-year age bands.

    Args:
        builder: ModelBuilder whose ``model`` is stratified in place.
        mixing_matrix: Mixing matrix set on the age stratification.
    """
    age_groups = [str(lower) for lower in range(0, 80, 5)]
    n_groups = len(age_groups)

    # Let's just say there are 5 times as many 0-5 year olds as 75+,
    # and linearly interpolate in between
    pop_weights = np.linspace(5, 1, n_groups)
    pop_split = pop_weights / pop_weights.sum()

    # Recovery flow adjustments, falling linearly from 1.5 (youngest) to 0.5 (oldest)
    recovery_scale = np.linspace(1.5, 0.5, n_groups)

    age_strat = Stratification("age", age_groups, ["S", "E", "I", "R"])
    age_strat.set_population_split(dict(zip(age_groups, pop_split)))
    age_strat.set_mixing_matrix(mixing_matrix)
    age_strat.set_flow_adjustments("recovery", dict(zip(age_groups, recovery_scale)))

    builder.model.stratify_with(age_strat)

In [None]:
def get_blended_mm(mat_a, mat_b, ratio):
    """Return the weighted blend ``mat_a * ratio + mat_b * (1 - ratio)``.

    Args:
        mat_a: First matrix, weighted by ``ratio``.
        mat_b: Second matrix, weighted by ``1 - ratio``.
        ratio: Blending proportion applied to ``mat_a``.
    """
    weighted_a = mat_a * ratio
    weighted_b = mat_b * (1.0 - ratio)
    return weighted_a + weighted_b

In [None]:
def create_builder(pdict):
    """Build a ModelBuilder around an age- and strain-stratified S/E/I/R model.

    Args:
        pdict: Parameter dictionary matching the ``ModelOptions`` schema.

    Returns:
        The ModelBuilder, with its model fully configured and a "progression"
        flow output requested.
    """
    compartments = ["S", "E", "I", "R"]
    seir_model = CompartmentalModel([0, 300], compartments, ["I"])
    seir_model.set_initial_population({"S": 30000000.0, "E": 0, "I": 0, "R": 0})

    builder = ModelBuilder(pdict, ModelOptions)
    builder.set_model(seir_model)
    params = builder.params
    epi = params.base_epi
    matrices = params.matrices

    # The mixing matrix is the blend of the two supplied matrices, with the
    # blending ratio referenced by its parameter key
    blend_inputs = [
        Data(matrices.matrix_a),
        Data(matrices.matrix_b),
        Parameter("matrices.ratio"),
    ]
    mixing_matrix = Function(get_blended_mm, blend_inputs)

    # Base flows are added before any stratification is applied
    seir_model.add_infection_frequency_flow("infection", epi.contact_rate, "S", "E")
    seir_model.add_transition_flow("progression", epi.progression_rate, "E", "I")
    seir_model.add_transition_flow("recovery", epi.recovery_rate, "I", "R")

    # Age first, then strain (the strain seeding flows are added after both)
    apply_age_stratification(builder, mixing_matrix)
    apply_strain_stratification(builder, params.strains)

    seir_model.request_output_for_flow("progression", "progression")

    return builder

In [None]:
# Build the model (wrapped in its ModelBuilder) from the default parameter dict
b = create_builder(pdict)

In [None]:
# Display the input parameters reported by the model
b.model.get_input_parameters()

In [None]:
# Default parameters from the builder; used as the base for every run below
defp = b.get_default_parameters()

defp

In [None]:
# Runner obtained from the model for the default parameters; reused for all later runs
runner = b.model.get_runner(defp)

In [None]:
# Run the model with the default parameters
runner.run(defp)

In [None]:
# Same run via the underscore-prefixed _run_func, which is what model() calls.
# NOTE(review): _run_func looks like non-public API - confirm it is the
# intended entry point for use inside a numpyro model.
runner._run_func(defp)

In [None]:
from jax import random

import numpyro
import numpyro.distributions as dist

In [None]:
# Parameter values used to generate the synthetic calibration target
target_param_updates = {
    "base_epi.recovery_rate": 0.2,
    "base_epi.progression_rate": 0.1,
    "base_epi.contact_rate": 0.5,
    "matrices.ratio": 0.32
}

# `|` merges the overrides into the defaults (assuming defp is a plain dict,
# this is the dict union operator and needs Python 3.9+)
results = runner.run(defp | target_param_updates)

# The "progression" derived output at the target parameters is the data that
# model() is scored against
targets = {
    "progression": results["derived_outputs"]["progression"],
}

In [None]:
# Plot the target series (pandas plotting backend was set to plotly above)
pd.DataFrame(targets).plot()

In [None]:
# Candidate priors, keyed by parameter path.
priors = {
    "base_epi.recovery_rate": dist.Uniform(0.1, 0.3),
    "base_epi.progression_rate": dist.Uniform(0.05, 0.3),
    # The support must contain the value used to generate the synthetic target
    # (0.5 in target_param_updates); an upper bound of 0.1 excluded it, making
    # the true value unrecoverable.
    "base_epi.contact_rate": dist.Uniform(0.01, 1.0),
    "strains.wild_type.seed.peak_t": dist.Uniform(0, 100),
    "matrices.ratio": dist.Uniform(0.0, 1.0),
}

# Only calibrate the parameters that were varied to produce the target
priors = {k: v for k, v in priors.items() if k in target_param_updates}

# Sanity checks: every varied parameter has a prior, and its true value lies
# inside that prior's [low, high] range (all priors here are Uniform).
for k, target_value in target_param_updates.items():
    assert k in priors, k
    assert priors[k].low <= target_value <= priors[k].high, (
        f"Target value {target_value} for {k} is outside its prior "
        f"[{priors[k].low},{priors[k].high}]"
    )
priors

In [None]:
def log_uniform(low, high):
    """Return a log-uniform distribution on ``[low, high]``.

    Built as a Uniform over ``[log(low), log(high)]`` pushed through an
    exponential transform.

    Args:
        low: Lower bound; must be strictly positive.
        high: Upper bound; must be strictly greater than ``low``.

    Returns:
        A numpyro ``TransformedDistribution``.

    Raises:
        ValueError: If ``low`` is not positive or ``high`` is not greater than
            ``low`` (either would otherwise give -inf/nan or empty bounds).
    """
    low_arr = np.asarray(low)
    high_arr = np.asarray(high)
    # Written as positive checks so that nan bounds are rejected too
    if not (np.all(low_arr > 0.0) and np.all(high_arr > low_arr)):
        raise ValueError(
            f"log_uniform requires 0 < low < high, got low={low}, high={high}"
        )

    log_low = np.log(low)
    log_high = np.log(high)
    return dist.TransformedDistribution(
        dist.Uniform(log_low, log_high), dist.transforms.ExpTransform()
    )

In [None]:
# Fixed-seed PRNG key, used for the prior-sampling demo below
rng_key = random.PRNGKey(777)

In [None]:
# Example log-uniform distribution between 0.01 and 1.0
d = log_uniform(0.01,1.0)

In [None]:
# Histogram of 10000 draws from the example distribution
pd.Series(d.sample(rng_key, (10000,))).hist()

In [None]:
#transforms = numpyro.distributions.transforms

In [None]:
# Check that both bounds of each prior satisfy the constraint declared on the
# corresponding parameter (assumes every prior exposes .low and .high).
for key, prior_dist in priors.items():
    print(key)
    plow = prior_dist.low
    phigh = prior_dist.high
    param_obj = mb.find_obj_from_key(key, b.params)
    cfunc = param_obj.constraint
    # Lazily evaluated, so the upper bound is only tested if the lower one passes
    is_constrained = all(cfunc(bound) for bound in (plow, phigh))
    msg = f"Prior for {key} ({prior_dist},[{plow},{phigh}]) samples outside of constraint {cfunc}"
    assert is_constrained, msg

In [None]:
def model():
    """numpyro model: sample the calibrated parameters and score a run against ``targets``.

    Draws one value per entry of ``priors``, runs the model with those values
    overriding the defaults, and adds a log-likelihood factor for each target
    series. The log-likelihood and the sum of each modelled series are also
    recorded as deterministic sites.
    """
    sampled = {}
    for name, prior in priors.items():
        sampled[name] = numpyro.sample(name, prior)

    run_output = runner._run_func(defp | sampled)
    derived = run_output["derived_outputs"]

    for output_name, observed in targets.items():
        modelled = derived[output_name]
        # Truncated normal (lower bound 0) located at the modelled series, with
        # the standard deviation of the observed series as its scale
        likelihood = dist.TruncatedNormal(modelled, jnp.std(observed), low=0.0)
        log_like = likelihood.log_prob(observed).sum()
        numpyro.factor(output_name, log_like)
        numpyro.deterministic(f"{output_name}_ll", log_like)
        numpyro.deterministic(f"{output_name}_sum", jnp.sum(modelled))

In [None]:
from numpyro.infer import MCMC

In [None]:
# numpyro's SA kernel wrapping the model function defined above
sa_kernel = numpyro.infer.SA(model)

In [None]:
# MCMC configuration: 4 chains, 10000 warm-up steps, 20000 samples
mcmc = MCMC(sa_kernel, num_chains=4, num_samples=20000,num_warmup=10000)
# Fresh key for the run; this replaces the rng_key created for the prior-sampling demo
rng_key = random.PRNGKey(1)
# Also collect accept_prob alongside the samples
mcmc.run(rng_key, extra_fields=("accept_prob",))

In [None]:
# Include the deterministic sites recorded in model() (the *_ll and *_sum values)
mcmc.print_summary(exclude_deterministic=False)

In [None]:
import arviz as az

In [None]:
# Convert the numpyro MCMC results into an arviz object for plotting
az_data = az.from_numpyro(mcmc)

In [None]:
# Trace plots; the return value is assigned to _ so the notebook does not echo it
_ = az.plot_trace(az_data, compact=False, figsize=(15,20))

In [None]:
# Posterior plots for the arviz data
az.plot_posterior(az_data)