## How to transition code from notebooks into modules

In [None]:
import os
from datetime import datetime
from pathlib import Path

import numpy as np
import pandas as pd
from jax import numpy as jnp

from summer2 import CompartmentalModel, Stratification
from summer2.parameters import Parameter as param

from summer2.functions import time as stf
from summer2.population import calculate_initial_population


In [None]:
# A minimal SIR model over model times 0-100. ref_date ties model time to calendar dates
# (see m.get_epoch() further down). Infectious compartments are given as a list of names.
m = CompartmentalModel([0, 100], ["S", "I", "R"], ["I"], ref_date=datetime(2019, 1, 1))

# Only S is given an initial population; infection enters via the seeding flow below
m.set_initial_population({"S": 1000.0})

# Seed at seed_rate, from a parameterised start time, for a fixed length of 7 days
seed_func = stf.get_piecewise_function([param("seed_start"), param("seed_start") + 7.0], [0.0, param("seed_rate"), 0.0])

m.add_importation_flow("seed", seed_func, "I", split_imports=True)

m.add_infection_frequency_flow("infection", param("contact_rate"), "S", "I")
m.add_transition_flow("recovery", param("recovery_rate"), "I", "R")

m.request_output_for_flow("infection", "infection")

# One output per compartment. The `compartments` argument takes a list of compartment names;
# passing the bare string only worked because every name here is a single character.
for c in ["S", "I", "R"]:
    m.request_output_for_compartments(c, [c])

### Rewriting as a function

The first step in tidying up a notebook is usually identifying code that can be moved out of `__main__` (i.e. code run directly in the notebook) and into separate functions.

In [None]:
# Let's take a look at our earlier definition of the seed function

seed_func = stf.get_piecewise_function([param("seed_start"), param("seed_start") + 7.0], [0.0, param("seed_rate"), 0.0])

# The initial "naive" function - it's just our code from before, moved into a function.
# This can still be extremely valuable for things that we might want to call frequently (especially in an interactive context),
# and is much better than cutting and pasting the original code each time we want to call it.
# It gets its own name so that the generalised version below doesn't silently shadow it.
def get_seed_function_naive():
    """Return the 7-day seeding function driven by the "seed_start" and "seed_rate" parameters.

    This is exactly the expression used for `seed_func` above: 0 before seed_start,
    seed_rate for 7 days, then 0 again.
    """
    return stf.get_piecewise_function([param("seed_start"), param("seed_start") + 7.0], [0.0, param("seed_rate"), 0.0])

# What if we wanted multiple seed functions for different strains?
# Here we move everything we might reasonably need into an argument
def get_seed_function(seed_start, seed_duration, seed_rate):
    """Return a piecewise function that is `seed_rate` during the seeding window and 0 otherwise.

    Args:
        seed_start: Model time at which seeding begins; a plain float or a summer2 Parameter
            (or an expression built from one)
        seed_duration: Length of the seeding window, in model time units
        seed_rate: Value of the function inside the window; a float or a summer2 Parameter
    """
    return stf.get_piecewise_function([seed_start, seed_start + seed_duration], [0.0, seed_rate, 0.0])

In [None]:
# Plain floats only (no summer2 Parameters involved), so the result is easy to evaluate and test directly
seed_func_fixed = get_seed_function(seed_start=50.0, seed_duration=7.0, seed_rate=1.0)

# get_time_callable turns the ComputeGraph Function into an ordinary Python callable whose
# first argument is time; evaluate it at every model time
seed_callable = stf.get_time_callable(seed_func_fixed)
values = seed_callable(m.times)

pd.Series(data=values, index=m.times).plot()

In [None]:
# A more complicated (but useful) example: two strains share the same 7-day seeding window,
# but each has its own separately named start-time and rate Parameters
seed_func_strain1 = get_seed_function(param("strain1_seed_start"), 7.0, param("strain1_seed_rate"))
seed_func_strain2 = get_seed_function(param("strain2_seed_start"), 7.0, param("strain2_seed_rate"))

# Values for those Parameters, supplied when the callables are evaluated
seed_params = dict(
    strain1_seed_start=20.0,
    strain1_seed_rate=1.0,
    strain2_seed_start=50.0,
    strain2_seed_rate=2.5,
)

# Use calendar dates rather than raw model times on the x-axis
seed_dates = m.get_epoch().index_to_dti(m.times)
strain_series = {
    strain: stf.get_time_callable(strain_func)(m.times, seed_params)
    for strain, strain_func in [("strain1", seed_func_strain1), ("strain2", seed_func_strain2)]
}
pd.DataFrame(strain_series, index=seed_dates).plot()

### A note on Epochs

In [None]:
# An Epoch converts between calendar time and model time.
# The parenthesised assignment both binds `epoch` and displays it as the cell output.
(epoch := m.get_epoch())

In [None]:
# Which calendar date corresponds to model time 55?
model_time = 55
epoch.number_to_datetime(model_time)

In [None]:
# Convert the whole model time axis to calendar dates in one call (usable as a pandas index, as in the strain plot above)
model_times = m.times
epoch.index_to_dti(model_times)

### ...back to running the model

In [None]:
# Seeding starts on 1 Feb 2019; the epoch converts that calendar date into model time
seed_start_date = datetime(2019, 2, 1)
parameters = dict(
    seed_start=epoch.datetime_to_number(seed_start_date),
    seed_rate=2.0,
    contact_rate=0.4,
    recovery_rate=0.1,
)

In [None]:
# Run the notebook-built model with the parameter values defined above
m.run(parameters)

In [None]:
# Plot every requested derived output (the infection flow plus the S, I and R compartment sizes)
derived_outputs = m.get_derived_outputs_df()
derived_outputs.plot()

In [None]:
from learningmodel import model

In [None]:
# Build the model from the learningmodel.model module instead of the notebook cells above;
# the parenthesised assignment binds `m` and displays it
(m := model.build_model())

In [None]:
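# If we're still actively editing this model code (but it no longer needs to live in the notebook itself),
# keep importlib.reload handy so edits to the module can be picked up without restarting the kernel
# (%autoreload 2 is an alternative that does this automatically)...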

In [None]:
from importlib import reload
# Re-execute learningmodel.model so edits saved to the .py file since it was imported take effect
reload(model)

In [None]:
# Rebuild the model so it reflects the reloaded module code (the previous `m` was built by the old code)
m = model.build_model()

In [None]:
# Run with the default parameter set shipped alongside the model code
default_params = model.DEFAULT_PARAMS
m.run(default_params)

In [None]:
# Plot the derived outputs of the module-built model
derived_outputs = m.get_derived_outputs_df()
derived_outputs.plot()

In [None]:
# A relative path is resolved against the kernel's current working directory (typically the
# notebook's own folder), not the project root, so from here this read is expected to fail.
# Catch and show the error rather than halting "Restart & Run All".
try:
    infections_relative = pd.read_csv("learningmodel/data/infections.csv")
except FileNotFoundError as err:
    infections_relative = None
    print(f"Not found relative to {Path.cwd()}: {err}")
infections_relative

In [None]:
# Moving the kernel up to the project root makes the relative path resolve, but it is fragile:
# a bare `cd ..` moves up one more level every time the cell is re-run. Only change directory
# when the data file is not already reachable, so re-running this cell is harmless.
if not Path("learningmodel/data/infections.csv").exists():
    os.chdir("..")
Path.cwd()

In [None]:
# Now that the working directory is the project root, the same relative path resolves
pd.read_csv(filepath_or_buffer="learningmodel/data/infections.csv")

In [None]:
from learningmodel.helpers import DATA_PATH

In [None]:
# DATA_PATH comes from the package (presumably built relative to the package itself), so this read
# should not depend on where the kernel's working directory happens to be
infections = pd.read_csv(DATA_PATH / "infections.csv", parse_dates=["date"], index_col="date")
infections.plot()