In [None]:
from autumn.infrastructure.remote import springboard

In [None]:
from datetime import datetime
import numpy as np

In [None]:
from summer2 import CompartmentalModel
from summer2.parameters import Parameter as param

In [None]:
def get_model():
    """Build the example S/I/R compartmental model used throughout this notebook.

    The model covers times 0 to 100 (reference date 2001-01-01), starts with
    990 people in "S" and 10 in "I", and exposes two derived outputs:
    "incidence" (the infection flow) and "notifications" (incidence scaled by
    the "cdr" parameter).

    Returns:
        The configured CompartmentalModel, with default parameters set but
        not yet run.
    """
    compartments = ["S", "I", "R"]
    sir_model = CompartmentalModel(
        [0, 100], compartments, "I", ref_date=datetime(2001, 1, 1)
    )
    sir_model.set_initial_population({"S": 990.0, "I": 10.0})

    # Flows: S -> I at the contact rate, I -> R at the recovery rate
    sir_model.add_infection_frequency_flow(
        "infection", param("contact_rate"), "S", "I"
    )
    sir_model.add_transition_flow("recovery", param("recovery_rate"), "I", "R")

    # Derived outputs: notifications are a parameterised fraction of incidence
    infection_output = sir_model.request_output_for_flow("incidence", "infection")
    sir_model.request_function_output(
        "notifications", infection_output * param("cdr")
    )

    sir_model.set_default_parameters(
        {"contact_rate": 0.4, "recovery_rate": 0.1, "cdr": 0.2}
    )
    return sir_model


In [None]:
# Seed a dedicated generator so the synthetic observations (and therefore the
# calibration run against them) are reproducible after Restart Kernel -> Run All.
rng = np.random.default_rng(0)

m = get_model()

# Run with parameters that differ from the model defaults to produce the
# "true" signal we will later try to recover.
m.run({"contact_rate": 0.5, "recovery_rate": 0.4})
do_def = m.get_derived_outputs_df()

# Take the first 50 rows of incidence as the clean signal, then apply
# multiplicative log-normal noise (standard deviation 0.2 on the log scale).
obs_clean = do_def["incidence"].iloc[0:50]
obs_noisy = obs_clean * np.exp(rng.normal(0.0, 0.2, len(obs_clean)))

# Both series are named "incidence", so label them explicitly and draw them
# on the same axes.
ax = obs_clean.plot(label="incidence (clean)")
obs_noisy.plot(ax=ax, style='.', label="incidence (noisy)")
ax.legend();



In [None]:
from estival import targets as est
from estival import priors as esp
from estival.model import BayesianCompartmentalModel

In [None]:
# Likelihood: a truncated normal target on incidence, bounded to [0, inf),
# with a free dispersion parameter. The dispersion gets a uniform prior
# running from 0.1 up to 10% of the largest noisy observation.
dispersion_prior = esp.UniformPrior(
    "incidence_dispersion", (0.1, obs_noisy.max() * 0.1)
)
incidence_target = est.TruncatedNormalTarget(
    "incidence", obs_noisy, (0.0, np.inf), dispersion_prior
)
targets = [incidence_target]

# Priors over our 2 model parameters: uniform for the contact rate and a
# truncated normal (bounded to [0.01, 1.0]) for the recovery rate.
contact_rate_prior = esp.UniformPrior("contact_rate", (0.01, 1.0))
recovery_rate_prior = esp.TruncNormalPrior("recovery_rate", 0.5, 0.2, (0.01, 1.0))
priors = [contact_rate_prior, recovery_rate_prior]

In [None]:
from estival.calibration import pymc as epm
import pymc as pm
import arviz as az

In [None]:
def calibrate_model(targets, priors, draws: int, chains: int):
    """Calibrate the example model with PyMC using a DEMetropolis step.

    Args:
        targets: Targets passed to the BayesianCompartmentalModel.
        priors: Priors passed to the BayesianCompartmentalModel.
        draws: Number of draws requested from ``pm.sample`` (no tuning draws).
        chains: Number of chains; the same value is used for ``cores``.

    Returns:
        The object returned by ``pm.sample``.
    """
    # Fresh model plus its default parameters
    base_model = get_model()
    default_params = base_model.get_default_parameters()

    # Combine model, defaults, priors and targets into the BCM
    bcm = BayesianCompartmentalModel(base_model, default_params, priors, targets)

    # One core per chain, no tuning phase
    sample_kwargs = dict(draws=draws, tune=0, cores=chains, chains=chains)

    with pm.Model():
        variables = epm.use_model(bcm)
        stepper = pm.DEMetropolis(variables)
        idata = pm.sample(step=[stepper], **sample_kwargs)

    return idata

In [None]:
def remote_calibration_task(bridge: springboard.task.TaskBridge, targets, priors, draws, chains):
    """Task entry point: run the calibration, save the result, log r_hat.

    Args:
        bridge: Task bridge supplied by the runner; its ``logger`` and
            ``out_path`` attributes are used here.
        targets: Targets forwarded to ``calibrate_model``.
        priors: Priors forwarded to ``calibrate_model``.
        draws: Number of draws forwarded to ``calibrate_model``.
        chains: Number of chains forwarded to ``calibrate_model``.
    """
    # Imported locally and configured before any sampling starts
    import multiprocessing

    multiprocessing.set_start_method("forkserver")

    bridge.logger.info(f"Calibrating {chains} chains for {draws} draws")

    # Run the calibration and write the result into the task's output path
    inference_data = calibrate_model(targets, priors, draws, chains)
    output_file = bridge.out_path / "idata.nc"
    inference_data.to_netcdf(output_file)

    # Record the r_hat column of the arviz summary in the task log
    r_hat = az.summary(inference_data)["r_hat"]
    bridge.logger.info(r_hat)

    bridge.logger.info("Calibration complete")

In [None]:
# Number of chains to sample; reused for the machine spec and the task kwargs
N_CHAINS = 4

In [None]:
# Describe the machine to run on: a "compute" EC2 spec whose first argument
# tracks the number of chains (presumably the core count - confirm against
# EC2MachineSpec).
mspec = springboard.EC2MachineSpec(N_CHAINS, 4, "compute")

# Wrap our function for remote execution. The runner supplies the first
# argument (the bridge) itself; everything else must be passed by keyword.
task_kwargs = dict(
    targets=targets,
    priors=priors,
    draws=2000,
    chains=N_CHAINS,
)

tspec = springboard.TaskSpec(remote_calibration_task, task_kwargs)

In [None]:
# Build the run path for this task from the given name components
# (presumably project / region / run description - confirm against springboard),
# then display it.
run_path = springboard.launch.get_autumn_project_run_path("testing", "exampleworld", "calibration_longalarm")
run_path

In [None]:
# Launch the task on the requested machine.
# This should be fairly resilient to failure; however, if you do receive any kind of error or Exception here,
# please tell David ASAP!
# You may also see some "SSH connection ... waiting, retrying" messages; please report these as well
# - they're not a 'failure', just AWS being a bit slow...

runner = springboard.launch.launch_synced_autumn_task(tspec, mspec, run_path,"springboard_refactor")

In [None]:
# Display the instance attached to the runner
runner.instance

In [None]:
# Use 'tail' to report the last n (default 10) lines of command output.
# This is only valid while the task is actually running, and is included here to demonstrate its use in debugging.
# Typical runs would just use the wait method instead (see the runner.wait() cell further down).

print(runner.tail())

In [None]:
# Print the runner's 'top' report; "+%CPU" is presumably a sort key - confirm against the runner API
print(runner.top("+%CPU"))

In [None]:
# Query the current status through the runner's S3 interface
runner.s3.get_status()

In [None]:
# Wait on the runner (presumably blocks until the remote task has finished - confirm)
runner.wait()

In [None]:
# Print the complete 'command line' output captured from the run
print(runner.get_iodump())

In [None]:
# This should print a log containing the messages written via bridge.logger in our wrapped task
print(runner.get_log())

In [None]:
from autumn.core.runs import ManagedRun

In [None]:
# Wrap the run path in a ManagedRun, which exposes the remote listing/download
# helpers and the local path used in the cells below
mr = ManagedRun(run_path)

In [None]:
# List what the run has stored remotely
mr.remote.list_contents()

In [None]:
# Download every entry reported by the remote listing, one at a time
remote_files = mr.remote.list_contents()
for remote_file in remote_files:
    mr.remote.download(remote_file)

In [None]:
import arviz as az

In [None]:
# Load the downloaded calibration output; presumably the "idata.nc" file the
# remote task wrote to bridge.out_path - confirm the output/ layout
idata = az.from_netcdf(mr.local_path / "output/idata.nc")

In [None]:
# Summary table for the loaded calibration output (includes the r_hat column logged by the task)
az.summary(idata)

In [None]:
# Trace plots for the loaded calibration output; the trailing ';' suppresses the returned value's repr
az.plot_trace(idata, compact=False);

In [None]:
# Print the task spec recorded with the run. The context manager closes the
# file handle, which a bare open(...).read() would leave open.
with open(mr.local_path / ".taskmeta/task_spec.yml") as task_spec_file:
    print(task_spec_file.read())