# Accelerated Python
## - Numpy, Numba, Jax
## - JIT theory and application

In [None]:
import numpy as np
import pandas as pd
from numba import jit as njit

## Python is a virtual machine

### ...but what does that mean?

In [None]:
# Our extremely simple calculating virtual machine
# It has 4 possible operations (represented as 'opcode' strings),
# and processes these in a ticker-tape like fashion

def trivial_vm(inputs):
    """
    Run a tape of (opcode, operand) instructions against a single accumulator

    The accumulator starts at 0.0 and every instruction applies its operand to it,
    in the order given.  Opcodes other than '+', '-', '*' and '/' are skipped
    without any effect.

    Args:
        inputs: Iterable of (opcode, operand) pairs

    Returns:
        The accumulator value after the last instruction has been processed
    """
    accumulator = 0.0
    for instruction in inputs:
        opcode, operand = instruction
        if opcode == '+':
            accumulator += operand
            continue
        if opcode == '-':
            accumulator -= operand
            continue
        if opcode == '*':
            accumulator *= operand
            continue
        if opcode == '/':
            accumulator /= operand
    return accumulator

In [None]:
trivial_vm([('+',2.0),('*',3.0),('-',1.0),('/',5.0)])

## A motivating example

In [None]:
def accumulate_positive_flow_contributions(
    flow_rates: np.ndarray,
    comp_rates: np.ndarray,
    pos_flow_map: np.ndarray,
):
    """
    Add each mapped flow rate onto the compartment it feeds, updating comp_rates in place

    Args:
        flow_rates (np.ndarray): Flow rates, indexed by flow
        comp_rates (np.ndarray): Compartment rates; modified in place, nothing is returned
        pos_flow_map (np.ndarray): Rows of (flow index, compartment index) pairs
    """
    for pair in pos_flow_map:
        flow_idx, comp_idx = pair
        comp_rates[comp_idx] += flow_rates[flow_idx]

In [None]:
# Benchmark inputs: 1000 flow rates evenly spaced over [0, 1] and a zeroed output array
flow_rates = np.linspace(0,1.0,1000)
comp_rates = np.zeros(1000)
# Each row is a (flow index, compartment index) pair.  Both columns are drawn without
# replacement from 0..999, so each column is an independent random permutation
flow_map = np.array([np.random.choice(range(1000),1000,False),np.random.choice(range(1000),1000,False)],dtype=np.int32).T

In [None]:
%timeit accumulate_positive_flow_contributions(flow_rates,comp_rates,flow_map)

In [None]:
numba_accum = njit(accumulate_positive_flow_contributions)

In [None]:
%time numba_accum(flow_rates,comp_rates,flow_map)

In [None]:
%timeit numba_accum(flow_rates,comp_rates,flow_map)

In [None]:
def accumulate_natural(
    flow_rates: np.ndarray,
    comp_rates: np.ndarray,
    pos_flow_map: dict[int,np.ndarray]
):
    """
    Add the summed rates of each compartment's contributing flows onto comp_rates, in place

    Args:
        flow_rates (np.ndarray): Flow rates, indexed by flow
        comp_rates (np.ndarray): Compartment rates; modified in place, nothing is returned
        pos_flow_map (dict): Maps a compartment index to the indices of the flows feeding it
    """
    for compartment_idx, flow_indices in pos_flow_map.items():
        inflow = flow_rates[flow_indices].sum()
        comp_rates[compartment_idx] += inflow

In [None]:
accumulate_natural(flow_rates, comp_rates, {0: [1,5,17]})

In [None]:
acc_nat_numba = njit(accumulate_natural)

%time acc_nat_numba(flow_rates, comp_rates, {0: [1,5,17]})

In [None]:
from jax import numpy as jnp, scipy as jsp
from jax import jit, grad, mask, lax, vmap

from jax.config import config
config.update("jax_enable_x64", True)

In [None]:
import scipy

In [None]:
def calculate_rates_for_interval(
        start_props: pd.core.series.Series, end_props: pd.core.series.Series, delta_t: float, strata: list[str],
        active_flows: dict
):
    """
    Build the objective whose root gives the transition rates for one time interval

    The returned function assembles a transition matrix M from a candidate set of
    rates and measures how far expm(M * delta_t) @ start_props lands from end_props.

    Args:
        start_props: Proportions in each stratum at the start of the interval, ordered as strata
        end_props: Proportions in each stratum at the end of the interval, ordered as strata
        delta_t: Width of the time interval
        strata: Names of the strata; their order defines the matrix rows and columns
        active_flows: Maps flow name to a (source stratum, destination stratum) tuple

    Returns:
        function_to_zero(params) -> float, where params holds one rate per flow in the
        iteration order of active_flows, and the result is the norm of the mismatch
    """
    n_strata = len(strata)

    # Position of each flow's rate within params (dict iteration order)
    param_index = {f_name: i_param for i_param, f_name in enumerate(active_flows)}

    # Convert once here rather than on every objective evaluation; values are used positionally
    start_vec = np.asarray(start_props, dtype=float)
    end_vec = np.asarray(end_props, dtype=float)

    # Create the function that we need to find the root of
    def function_to_zero(params):
        # params is ordered in the same order as the keys of active_flows

        # Create the transition matrix associated with a given set of transition parameters
        m = np.zeros((n_strata, n_strata))
        for i_row, stratum_row in enumerate(strata):
            for i_col, stratum_col in enumerate(strata):
                for f_name, f_ends in active_flows.items():
                    if i_row == i_col:
                        # Diagonal components capture flows starting from the associated stratum
                        if f_ends[0] == stratum_row:
                            m[i_row, i_col] -= params[param_index[f_name]]
                    elif f_ends == (stratum_col, stratum_row):
                        # Off-diagonal components capture flows from stratum_col to stratum_row
                        m[i_row, i_col] = params[param_index[f_name]]

        # Calculate the matrix exponential, accounting for the time interval width
        exp_mt = scipy.linalg.expm(m * delta_t)

        # Calculate the difference between the left and right terms of the equation
        diff = np.matmul(exp_mt, start_vec) - end_vec

        # Return the norm of the vector to make the minimised function a scalar function
        return scipy.linalg.norm(diff)

    return function_to_zero


In [None]:
# Example problem: three strata, with everything starting in "A"
strata = ["A","B","C"]
start_props = pd.Series([1.0,0.0,0.0],index=strata)
end_props = pd.Series([0.3,0.5,0.2],index=strata)

delta_t = 100.0
# Flow name -> (source stratum, destination stratum); here a cycle A -> B -> C -> A
active_flows = {
    'a_to_b': ("A","B"),
    'b_to_c': ("B","C"),
    'c_to_a': ("C","A"),    
}

# ftz is the scalar objective; it takes one rate per flow, in the order listed above
ftz = calculate_rates_for_interval(start_props, end_props, delta_t, strata, active_flows)

In [None]:
%timeit ftz((0.0,0.0,0.0))

In [None]:
def calc_rates_int_jax(
        start_props: pd.core.series.Series, end_props: pd.core.series.Series, delta_t: float, strata: list[str],
        active_flows: dict
):
    """
    JAX version of the interval-rate objective: returns a jitted, differentiable function

    The returned function assembles a transition matrix M from a candidate set of
    rates and measures how far expm(M * delta_t) @ start_props lands from end_props.

    Args:
        start_props: Proportions in each stratum at the start of the interval, ordered as strata
        end_props: Proportions in each stratum at the end of the interval, ordered as strata
        delta_t: Width of the time interval
        strata: Names of the strata; their order defines the matrix rows and columns
        active_flows: Maps flow name to a (source stratum, destination stratum) tuple

    Returns:
        Jitted function_to_zero(params), where params holds one rate per flow in the
        iteration order of active_flows, and the result is the norm of the mismatch
    """
    n_strata = len(strata)

    # Position of each flow's rate within params (dict iteration order)
    param_index = {f_name: i_param for i_param, f_name in enumerate(active_flows)}

    # Go through NumPy first so pandas inputs are converted positionally
    start_props = jnp.asarray(np.asarray(start_props))
    end_props = jnp.asarray(np.asarray(end_props))

    # Create the function that we need to find the root of
    @jit
    def function_to_zero(params):
        # params is ordered in the same order as the keys of active_flows

        # Create the transition matrix associated with a given set of transition parameters.
        # These Python loops run once, at trace time; only the array updates are compiled.
        m = jnp.zeros((n_strata, n_strata))
        for i_row, stratum_row in enumerate(strata):
            for i_col, stratum_col in enumerate(strata):
                for f_name, f_ends in active_flows.items():
                    if i_row == i_col:
                        # Diagonal components capture flows starting from the associated stratum
                        if f_ends[0] == stratum_row:
                            # JAX arrays are immutable: .at[...].add returns a new array,
                            # so the result has to be rebound or the update is lost
                            m = m.at[i_row, i_col].add(-params[param_index[f_name]])
                    elif f_ends == (stratum_col, stratum_row):
                        # Off-diagonal components capture flows from stratum_col to stratum_row
                        m = m.at[i_row, i_col].set(params[param_index[f_name]])

        # Calculate the matrix exponential, accounting for the time interval width
        exp_mt = jsp.linalg.expm(m * delta_t)

        # Calculate the difference between the left and right terms of the equation
        diff = jnp.matmul(exp_mt, start_props) - end_props

        # Return the norm of the vector to make the minimised function a scalar function
        return jnp.linalg.norm(diff)

    return function_to_zero


In [None]:
ftzj = calc_rates_int_jax(start_props, end_props, delta_t, strata, active_flows)

In [None]:
%timeit ftz((0.0,0.0,0.0))

In [None]:
%timeit ftzj((0.0,0.0,0.0))

In [None]:
np.random.uniform(size=(1000,3)).shape

In [None]:
vparams = jnp.array(np.random.uniform(size=(500,3)))

In [None]:
f = jit(vmap(ftzj))

%time _ = f(vparams)

In [None]:
%timeit _ = f(vparams)

In [None]:
grad(ftzj)((0.1,0.0,0.0))

# Jax syntax and application

In [None]:
def condf_py(t):
    """Piecewise scalar function: t * -1.5 for negative t, t ** 2.5 otherwise"""
    return t * -1.5 if t < 0.0 else t ** 2.5

In [None]:
def condf_jax(t):
    """
    Traceable counterpart of condf_py: t * -1.5 for negative t, t ** 2.5 otherwise

    The branch is selected with lax.cond rather than a Python `if`, so the function
    can be passed through jit, vmap and grad.
    """
    def negative_branch(value):
        return value * -1.5

    def non_negative_branch(value):
        return value ** 2.5

    is_negative = t < 0.0
    return lax.cond(is_negative, negative_branch, non_negative_branch, t)

In [None]:
condf_py(0.5)

In [None]:
condf_jax(0.5)

In [None]:
j_condf_jax = jit(condf_jax)
j_condf_jax(0.5,)

In [None]:
pd.Series([float(condf_py(x)) for x in np.linspace(-1.0,1.0,100)]).plot()

In [None]:
pd.Series([float(j_condf_jax(x)) for x in np.linspace(-1.0,1.0,100)]).plot()
pd.Series([float(grad(j_condf_jax)(x)) for x in np.linspace(-1.0,1.0,100)]).plot()

In [None]:
%time for x in np.linspace(-10,1.0,1000000): condf_py(x)

In [None]:
vec_func = jit(vmap(condf_jax))

In [None]:
mapped_space = jnp.linspace(-1.0,1.0,1000000)
_ = vec_func(mapped_space)

In [None]:
%timeit vec_func(mapped_space)

In [None]:
def condf_np(t):
    """
    Vectorised NumPy counterpart of condf_py, applied elementwise to an array

    Elements below zero map to t * -1.5, all others to t ** 2.5.  The two masks are
    exact complements, so every slot of the output is written; in particular NaN
    inputs yield NaN instead of leaving uninitialised memory in the result.

    Args:
        t (np.ndarray): Input values

    Returns:
        np.ndarray: Array of the same shape and dtype as t
    """
    out_arr = np.empty_like(t)

    negative = t < 0.0
    out_arr[negative] = t[negative] * -1.5

    # Complement of the mask above (this is where NaN lands, since NaN < 0 is False)
    non_negative = ~negative
    out_arr[non_negative] = t[non_negative] ** 2.5

    return out_arr

In [None]:
%timeit _ = condf_np(np.linspace(-1.0,1.0,1000000))

In [None]:
vec_grad = jit(vmap(grad(condf_jax)))
grad_res = vec_grad(mapped_space)

In [None]:
%timeit _ = vec_grad(mapped_space)

In [None]:
pd.Series(grad_res, index=mapped_space).plot()

# Real world applications, long term goals...<br>
#
#
#
#
#
#
#
#
#
#
#
# <b>Summer in Jax!</b>

In [None]:
import summer

In [None]:
from dataclasses import dataclass

In [None]:
@dataclass
class Param:
    """Named placeholder used in place of a numeric flow parameter when building a model"""
    # Key used to look the value up in a params dict (see params[f.param.name] in extract_summer)
    name: str

In [None]:
params = {"contact_rate": 1.0, "recovery_rate": 0.01}

In [None]:
def build_model_param():
    """
    Build the S/I/R example model with flow parameters given as named Param placeholders

    Returns:
        The constructed summer.CompartmentalModel
    """
    times = [0,100]
    compartments = ["S","I","R"]
    infectious_compartments = ["I"]

    model = summer.CompartmentalModel(times, compartments, infectious_compartments)
    model.set_initial_population({"S": 90, "I": 10, "R": 0})

    contact_rate = Param("contact_rate")
    model.add_infection_frequency_flow("infection", contact_rate, "S", "I")

    recovery_rate = Param("recovery_rate")
    model.add_transition_flow("recovery", recovery_rate, "I", "R")

    return model

In [None]:
def build_model_static():
    """
    Build the S/I/R example model with numeric flow parameters read from the module-level params dict

    Returns:
        The constructed summer.CompartmentalModel
    """
    times = [0,100]
    compartments = ["S","I","R"]
    infectious_compartments = ["I"]

    model = summer.CompartmentalModel(times, compartments, infectious_compartments)
    model.set_initial_population({"S": 90, "I": 10, "R": 0})

    contact_rate = params['contact_rate']
    model.add_infection_frequency_flow("infection", contact_rate, "S", "I")

    recovery_rate = params['recovery_rate']
    model.add_transition_flow("recovery", recovery_rate, "I", "R")

    return model

In [None]:
from summer import Stratification

In [None]:
# Age strata labelled 0, 5, ..., 75 (16 groups)
age_strata = list(range(0,80,5))
print(age_strata)

# The stratification is applied to all three compartments
age_strat = Stratification(name="age", strata=age_strata, compartments=["S", "I", "R"])

# Build the Param-based model and apply the stratification we just defined
m_strat = build_model_param()
m_strat.stratify_with(age_strat)

# Same stratification on the model built with numeric parameters
m_strat_static = build_model_static()
m_strat_static.stratify_with(age_strat)

In [None]:
m_static = build_model_static()

In [None]:
%timeit m_static.run()

In [None]:
%timeit m_strat_static.run()

In [None]:
m_static.get_outputs_df().plot()

In [None]:
m_param = build_model_param()

In [None]:
m_param._flows[1].param

In [None]:
from summer import CompartmentalModel, flows

In [None]:
def extract_summer(m: CompartmentalModel):
    """
    Build a jitted rate function for the flows of a summer model

    Only InfectionFrequencyFlow and TransitionFlow are handled, and each flow's
    `param` must expose a `.name` that is a key of the params mapping passed to the
    returned function.

    Args:
        m: Model whose `_flows` and `compartments` define the system

    Returns:
        get_rates(y, t, params): net rate of change for each compartment, with the
        same shape as y.  t is accepted but unused.

    Raises:
        NotImplementedError: When the returned function is traced for a model that
            contains a flow of any other type
    """
    @jit
    def get_rates(y, t, params):
        out_rates = jnp.zeros_like(y)
        # This loop runs at trace time; the index lookups below are not part of the compiled code
        for f in m._flows:
            src_idx = m.compartments.index(f.source)
            dst_idx = m.compartments.index(f.dest)

            if isinstance(f, flows.InfectionFrequencyFlow):
                S, I = y[src_idx], y[dst_idx]
                # NOTE(review): the infectious fraction is taken over S + I only, not the
                # whole population - confirm this matches summer's frequency-dependent flow
                rate = S * params[f.param.name] * (I / (S+I))

            elif isinstance(f, flows.TransitionFlow):
                rate = y[src_idx] * params[f.param.name]

            else:
                # Without this branch an unsupported flow would silently reuse the
                # previous flow's rate (or hit a NameError if it came first)
                raise NotImplementedError(
                    f"Unsupported flow type: {type(f).__name__}"
                )

            # Move the flow from its source compartment to its destination
            out_rates = out_rates.at[dst_idx].add(rate)
            out_rates = out_rates.at[src_idx].add(-rate)

        return out_rates
    return get_rates

In [None]:
m = m_param
# Initial state and evaluation times taken from the summer model, as JAX arrays
y0 = jnp.array(m.initial_population)
t = jnp.array(m.times)

# NOTE(review): extract_summer already returns a jitted function, so this outer jit looks redundant - confirm
get_rates = jit(extract_summer(m))

In [None]:
from jax.experimental.ode import odeint as jodeint

In [None]:
jres = jodeint(get_rates,y0,t,params)

In [None]:
pd.DataFrame(jres, columns=["S","I","R"]).plot()

In [None]:
%timeit _ = jodeint(get_rates,y0,t,params)

In [None]:
@jit
def full_jit_run(params):
    """Integrate the model from y0 over the times in t for the given params, as one jitted call"""
    return jodeint(get_rates, y0, t, params)
    

In [None]:
res = full_jit_run(params)

In [None]:
pd.DataFrame(res, columns=["S","I","R"]).plot()

In [None]:
%timeit _ = full_jit_run(params)

In [None]:
%time for x in np.linspace(0.0,1.0,1000): p = {'contact_rate': x, 'recovery_rate': x}; _ = full_jit_run(p)

In [None]:
vec_summer = jit(vmap(full_jit_run))

In [None]:
vparams = {
    'contact_rate': jnp.linspace(0,1.0,1000),
    'recovery_rate': jnp.linspace(0,0.1,1000),
}

In [None]:
%timeit _ = vec_summer(vparams)

In [None]:
@jit
def objf_jit_run(params):
    """Scalar objective: mean over all time points of column 1 of the integrated solution"""
    trajectory = jodeint(get_rates, y0, t, params)
    column_one = trajectory[:,1]
    return column_one.mean()
    

In [None]:
grad_summer = jit(grad(objf_jit_run))

In [None]:
grad_summer(params)

In [None]:
grad_summer_vmap = jit(vmap(grad(objf_jit_run)))

In [None]:
from jax import value_and_grad

In [None]:
vg_summer_vmap = jit(vmap(value_and_grad(objf_jit_run)))

In [None]:
# Sweep contact_rate over [0, 1] in 1000 steps while holding recovery_rate fixed at 0.1
vparams = {
    'contact_rate': jnp.linspace(0,1.0,1000),
    'recovery_rate': jnp.array([0.1] * 1000),
}

# Gradient of the objective with respect to both parameters, for each of the 1000 parameter sets
gres = grad_summer_vmap(vparams)

In [None]:
pd.DataFrame(gres,index=vparams['contact_rate']).plot()

In [None]:
vres, gres = vg_summer_vmap(vparams)

In [None]:
pd.DataFrame(gres,index=vparams['contact_rate']).plot()

In [None]:
vres.shape

In [None]:
pd.DataFrame(vres,index=vparams['contact_rate']).plot()