# Parameter Aware Summer

In [None]:
import numpy as np
import pandas as pd

pd.options.plotting.backend = "plotly"

from summer import CompartmentalModel, Stratification, StrainStratification
from summer.solver import SolverType
from summer.runner.jax import build_model_with_jax
from summer.parameters import Parameter, Function, Time, ComputedValue
from summer.adjust import Overwrite

In [None]:
# Short aliases used when writing parameter expressions in build_model.
param, func = Parameter, Function

In [None]:
# Run-time values for every Parameter(...) name referenced in build_model.
parameters = dict(
    age_split=dict(young=0.8, old=0.2),
    contact_rate=0.1,
)
# Per-strain infectiousness adjustments use dotted keys, which cannot be
# keyword arguments, so they are added as a literal mapping.
parameters.update(
    {
        "strain_infect_adjust.wild_type": 1.1,
        "strain_infect_adjust.variant1": 0.9,
        "strain_infect_adjust.variant2": 1.3,
    }
)

def get_ipop_dist(total, infected_prop):
    """Split a total population into initial S/I/R counts.

    A fraction ``infected_prop`` of ``total`` starts in I, the remainder
    starts in S, and R starts at 0.
    """
    infected = total * infected_prop
    susceptible = total - infected
    return dict(S=susceptible, I=infected, R=0)

def build_model(**kwargs):
    """Build the parameter-aware S/I/R demo model.

    The age split, contact rate and strain infectiousness adjustments are
    given as ``Parameter`` references by name rather than as numbers.
    NOTE(review): presumably these are resolved from the parameters dict
    supplied when the model is run (the names match the keys of the
    notebook's ``parameters`` dict) — confirm against summer's docs.

    Args:
        **kwargs: Accepted but not used in this body. NOTE(review): likely
            present to satisfy the calling convention of
            ``build_model_with_jax`` — confirm.

    Returns:
        The ``CompartmentalModel``, stratified by age (all compartments) and
        by strain (``I`` only).
    """
    # Time span 0-100; compartments S, I, R. NOTE(review): the third argument
    # ["I"] presumably marks I as the infectious compartment, and
    # takes_params=True presumably enables run-time parameter lookup — confirm.
    model = CompartmentalModel((0, 100), ["S", "I", "R"], ["I"], takes_params=True)

    # Total population of 1000, with 40% initially in I.
    model.set_initial_population(get_ipop_dist(1000.0, 0.4))

    # Age stratification over all compartments; the split is read from the
    # "age_split" parameter.
    strat = Stratification("age", ["young", "old"], ["S", "I", "R"])
    strat.set_population_split(Parameter("age_split"))
    model.stratify_with(strat)

    # Linearly increasing contact rate: base_rate at t=0, base_rate + 0.5 at
    # t=100. Only used by the commented-out alternative below.
    def scaled_contact_rate(time, base_rate):
        return base_rate + 0.5*(time/100.0)

    # Constant contact rate taken from the "contact_rate" parameter. To use
    # the time-varying scaled_contact_rate instead, swap in the commented line.
    contact_rate = param("contact_rate")
    #contact_rate = func(scaled_contact_rate, [Time, param("contact_rate")])

    # S -> I infection at contact_rate; I -> R recovery at a fixed rate of 0.1.
    model.add_infection_frequency_flow("infection", contact_rate, "S", "I")
    model.add_transition_flow("recovery", 0.1, "I", "R")

    # Strain stratification applied to the I compartment only.
    strain_strat = StrainStratification("strain", ["wild_type", "variant1", "variant2"], ["I"])

    # Per-strain infectiousness of I, each read from its own parameter.
    # NOTE(review): wild_type is passed as a bare Parameter while the variants
    # are wrapped in Overwrite. Presumably the bare value scales the existing
    # infectiousness and Overwrite replaces it — confirm in summer.adjust.
    strain_strat.add_infectiousness_adjustments(
        "I",
        {
            "wild_type": Parameter("strain_infect_adjust.wild_type"),
            "variant1": Overwrite(Parameter("strain_infect_adjust.variant1")),
            "variant2": Overwrite(Parameter("strain_infect_adjust.variant2")),
        },
    )

    model.stratify_with(strain_strat)

    # Death flow out of I at a fixed rate of 0.01. It is added after both
    # stratifications (NOTE(review): presumably it then applies to every
    # stratum of I).
    model.add_death_flow("death_after_infection", 0.01, "I")

    return model



In [None]:
# Build the model and a JAX-backed runner from the same builder function.
model, jaxrun = build_model_with_jax(build_model)

# Run via the model's own run() method and via the JAX runner, with identical parameters.
model.run(parameters=parameters)
joutputs = jaxrun(parameters)

# Sanity check: both paths must agree to within an absolute tolerance of 1e-5.
np.testing.assert_allclose(joutputs, model.outputs, atol=1e-5)


In [None]:
# Plot the model's outputs; pandas' plotting backend was set to plotly at the top of the notebook.
model.get_outputs_df().plot()

In [None]:
# Plot the per-compartment difference between the JAX runner's outputs and model.outputs.
pd.DataFrame(
    data=joutputs - model.outputs,
    columns=model.compartments,
).plot()

In [None]:
# Time one run through the model's own run() method.
%time model.run(parameters=parameters)

In [None]:
# Time one call of the JAX runner without jit.
%time jaxrun(parameters)

In [None]:
from jax import jit

In [None]:
# Wrap the JAX runner with jax.jit. Per the JAX docs, compilation happens on the
# first call and is reused for later calls with the same input structure.
jitrun = jit(jaxrun)

In [None]:
# First call of the jitted runner: presumably includes JIT compilation time.
%time jitrun(parameters)

In [None]:
# Repeat call with the same parameter structure: presumably reuses the compiled function, so it times execution only.
%time jitrun(parameters)

In [None]:
# Do a bunch of runs with varying parameters
def run_lots(n, parameters, runner=None):
    """Run the model ``n`` times, sweeping ``contact_rate`` as ``n / i``.

    Each run gets a fresh shallow copy of ``parameters`` with only
    ``"contact_rate"`` replaced, so the caller's dict is never mutated.

    Args:
        n: Number of runs to perform. Run ``i`` (1-based, ``1..n``) uses
            ``contact_rate = n / i``, so rates go from ``n`` down to ``1.0``.
            NOTE(review): for large ``n`` these rates are far above the
            0.1 baseline — confirm this sweep is the intended one.
            ``n <= 0`` performs no runs.
        parameters: Base parameter dict shared by every run.
        runner: Callable taking a parameter dict (positionally) and returning
            that run's outputs. Defaults to the notebook's ``jitrun``, which
            is looked up at call time.

    Returns:
        list: One output per run, in sweep order.
    """
    if runner is None:
        runner = jitrun
    # range(1, n + 1) so exactly n runs happen; the previous range(1, n) did n - 1.
    return [runner({**parameters, "contact_rate": n / i}) for i in range(1, n + 1)]

In [None]:
# Time a full contact-rate sweep with run_lots; every run's outputs are kept in `outputs`.
%time outputs = run_lots(10000, parameters)