In [None]:
try:
  import google.colab
  IN_COLAB = True
  %pip install summerepi2
except:
  IN_COLAB = False

In [None]:
import pandas as pd
from jax import numpy as jnp
from typing import Dict
import copy

from summer2 import CompartmentalModel
from summer2.parameters import Parameter as param
from summer2.parameters import Function as func
from summer2.parameters import Time

# Use plotly (instead of the pandas default backend) for every DataFrame.plot() call below
pd.options.plotting.backend = "plotly"

In [None]:
def get_series_comps_model(
    parameters: Dict[str, float],
    n_comps: int,
) -> CompartmentalModel:
    """
    Build a model made of a linear chain of compartments.

    The whole population starts in the first compartment ("comp_0") and is
    moved along "comp_1", "comp_2", ... by transition flows that all share
    the same fractional rate.

    Args:
        parameters: Parameter values for the model. Not read while the model
            is being built; the "transition_rate" entry is referenced by name
            and is expected to be supplied when the model is run.
        n_comps: Number of compartments in the chain; must be at least one.

    Returns:
        The constructed model, not yet run.

    Raises:
        ValueError: If n_comps is less than one, because the chain would not
            contain the "comp_0" compartment that receives the population.
    """
    if n_comps < 1:
        raise ValueError(f"n_comps must be at least 1, got {n_comps}")

    compartments = [f"comp_{i_comp}" for i_comp in range(n_comps)]
    analysis_times = (0, 20)
    model = CompartmentalModel(
        times=analysis_times,
        compartments=compartments,
        infectious_compartments=[],
    )

    # Start everyone from the first compartment
    model.set_initial_population(
        distribution={"comp_0": 1.}
    )

    # Despite its name, "transition_rate" is used as a divisor here, so it
    # behaves like a duration: each flow runs at n_comps / transition_rate.
    # NOTE(review): only n_comps - 1 flows are created below, so the mean time
    # to reach the final compartment is (n_comps - 1) / n_comps of that
    # duration rather than the full value - confirm this is intended.
    progression_rate = 1. / param("transition_rate") * n_comps

    # Join up all the sequential compartments with transition flows
    for i_comp in range(n_comps - 1):
        model.add_transition_flow(
            f"progression_{i_comp}",
            fractional_rate=progression_rate,
            source=f"comp_{i_comp}",
            dest=f"comp_{i_comp + 1}",
        )

    return model

In [None]:
# Single parameter set shared by every model run below
parameters = {"transition_rate": 10.}

# Chain lengths to compare: 2, 3, 4, then 6 and 8, then 10 to 100 in steps of 5
n_comp_requests = [2, 3, 4, *range(6, 10, 2), *range(10, 105, 5)]

# Maps each chain length to the time series of its final compartment
outputs = {}
for i in n_comp_requests:
    transition_model = get_series_comps_model(parameters, i)
    transition_model.run(parameters)
    comp_sizes = transition_model.get_outputs_df()
    # Compartments are numbered from zero, so the last one is comp_{i - 1}
    final_comp_name = f"comp_{i - 1}"
    outputs[i] = comp_sizes[final_comp_name]

In [None]:
# One column per requested chain length, each holding the final compartment's size
outputs_df = pd.DataFrame(data=outputs)
outputs_df.plot(title="Last compartment size")

In [None]:
# Row-to-row change in the final compartment's size, i.e. arrivals per output step
arrivals_per_step = outputs_df.diff()
arrivals_per_step.plot(title="Transition time")