In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import Callable, Dict, List, Any
from functools import partial
from jaxopt import Bisection, FixedPointIteration
import sys
import os
import seaborn as sns 
from scipy.stats import nbinom

We want to explain the results in Figure 3 using the implicit function theorem, i.e. by implicitly differentiating the fixed-point equation $u = G(u)$.
We consider a negative binomial offspring distribution with mean $R_0$ and dispersion parameter $\kappa$, whose PGF is
$$G(x;R_0,\kappa) = \bigg[ 1 + \tfrac{R_0}{\kappa}(1-x) \bigg]^{-\kappa} $$

In [None]:
@jax.jit
def G_nbinom(x: float, R0, k) -> float:
    """
    Probability generating function (PGF) of a negative binomial offspring distribution.

    Evaluates G(x; R0, k) = (1 + (R0 / k) * (1 - x)) ** (-k). Differentiating gives
    G'(1) = R0, so R0 is the mean; k is the dispersion parameter (kappa in the notes).
    All operations are elementwise, so broadcastable arrays work as well as scalars.
    """
    base = 1 + R0 / k * (1 - x)
    exponent = -k
    return base ** exponent



def nbinom_degree_sequence_non_vec(r0, k, N_max = 300):
    """
    Truncated negative binomial pmf [P(X=0), ..., P(X=N_max)] for one (r0, k) pair.

    SciPy parameterises nbinom by (n, p) with mean n * (1 - p) / p; choosing
    n = k and p = k / (r0 + k) makes the mean equal to r0. Mass beyond N_max is
    dropped, not renormalised.

    Returns a NumPy array of length N_max + 1.
    """
    success_prob = k / (r0 + k)
    frozen = nbinom(n=k, p=success_prob)
    return frozen.pmf(range(N_max + 1))

def nbinom_degree_sequence(r0_row, k_row, N_max = 300):
    """
    Truncated negative binomial pmf for one or many (r0, k) pairs.

    Parameters
    ----------
    r0_row, k_row : scalar or array-like
        Means and dispersion parameters; they are broadcast against each other.
    N_max : int
        Largest count included; the pmf is evaluated at 0, 1, ..., N_max.

    Returns
    -------
    numpy.ndarray
        Shape ``broadcast(r0_row, k_row).shape + (N_max + 1,)``. A scalar pair
        gives a 1-D coefficient vector suitable for ``PGF``.

    Uses the same (n, p) = (k, k / (r0 + k)) parameterisation as
    ``nbinom_degree_sequence_non_vec``. Evaluated with SciPy, so inputs must be
    concrete values (not JAX tracers) and the result is not differentiable by JAX.
    """
    # SciPy cannot consume JAX tracers, and jax.vmap would also try to map over the
    # scalar N_max, so batch with NumPy broadcasting: the trailing axis holds counts.
    r0_arr = np.asarray(r0_row, dtype=float)
    k_arr = np.asarray(k_row, dtype=float)
    counts = np.arange(N_max + 1)
    success_prob = k_arr / (r0_arr + k_arr)
    return nbinom.pmf(counts, k_arr[..., np.newaxis], success_prob[..., np.newaxis])





In [None]:
# Solve u = G_nbinom(u; R0, k) by fixed-point iteration with jaxopt.
@jax.jit
def find_fixed_point(R0: float, k: float, x0: float = 0.5, tol: float = 1e-5, maxiter: int = 100) -> float:
    """
    Fixed point of G_nbinom for the given R0 and k values.

    Starts at x0 and repeatedly applies u -> G_nbinom(u, R0, k) using jaxopt's
    FixedPointIteration, returning the solver's final iterate. The solver stops once
    its error drops below tol or after maxiter iterations, whichever comes first.
    NOTE(review): convergence slows as G'(u) approaches 1 (e.g. R0 near 1 with
    u near 1), so results there may not reach tol within maxiter - confirm if needed.
    """
    def _pgf_map(u):
        return G_nbinom(u, R0, k)

    solver = FixedPointIteration(fixed_point_fun=_pgf_map, maxiter=maxiter, tol=tol)
    solution = solver.run(x0)
    return solution.params


In [None]:


class PGF:
    """
    Probability generating function G(x) = sum_i c_i * x**i given by its coefficients.

    ``poly_coef[i]`` is the coefficient of x**i (e.g. P(X = i) for a count
    distribution). ``__call__`` and ``derivative`` accept a scalar or an array of
    evaluation points and evaluate elementwise over x.
    """

    def __init__(self, poly_coef):
        """Initialize a PGF using polynomial coefficients (lowest power first)."""
        self.exponents = jnp.arange(0, len(poly_coef))
        self.coefficients = jnp.array(poly_coef)

    def __call__(self, x):
        """Evaluate G at x (scalar or array); the result has the shape of x."""
        # A trailing axis enumerates the polynomial terms so array-valued x works elementwise.
        x_terms = jnp.asarray(x)[..., None]
        return jnp.sum(self.coefficients * jnp.power(x_terms, self.exponents), axis=-1)

    def derivative(self, x):
        """Evaluate G'(x) = sum_{i>=1} i * c_i * x**(i-1) at x (scalar or array)."""
        # Skip the constant term: its derivative is 0, and keeping it would evaluate
        # 0 * x**(-1), which is nan at x = 0.
        exps = self.exponents[1:]
        x_terms = jnp.asarray(x)[..., None]
        return jnp.sum(exps * self.coefficients[1:] * jnp.power(x_terms, exps - 1), axis=-1)


def implicit_diff_non_parametric(pgf, u):
    """
    Sensitivity of a fixed point u = G(u) to each PGF coefficient.

    With G(x) = sum_i a_i * x**i, implicitly differentiating u - G(u; a) = 0 gives
        du/da_i = -(dG/da_i) / (G'(u) - 1) = u**i / (1 - G'(u)).

    Parameters
    ----------
    pgf : object with ``coefficients`` (1-D array) and ``derivative(x)``, e.g. PGF.
    u : float
        A fixed point of ``pgf``.

    Returns
    -------
    Array with one entry per coefficient. Unbounded when G'(u) == 1.
    """
    powers = u ** jnp.arange(len(pgf.coefficients))  # dG/da_i evaluated at u
    p_prime = pgf.derivative(u)
    # Fixed-point (not root) condition, so the denominator is G'(u) - 1, not G'(u).
    dr_da = -powers / (p_prime - 1)
    return dr_da

def fixed_point_solver_non_parametric(coeffs, x0=0.5, tol=1e-10, max_iter=100):
    """
    Solve u = G(u) for the PGF G with polynomial coefficients ``coeffs``.

    Uses jaxopt's FixedPointIteration with jit and implicit differentiation enabled;
    ``coeffs`` is passed to the solver as an explicit argument rather than closed over.
    Returns the solver's final iterate, starting from x0.
    """
    def _apply_pgf(u, coefficients):
        return PGF(coefficients)(u)

    fixed_point_iter = FixedPointIteration(
        fixed_point_fun=_apply_pgf,
        maxiter=max_iter,
        tol=tol,
        jit=True,
        implicit_diff=True,
    )
    solution = fixed_point_iter.run(x0, coeffs)
    return solution.params

def find_fixed_point_non_parametric_wrapper(R0, k, x0=0.5, tol=1e-10, max_iter=100, N_max=300):
    """
    Fixed point of the truncated negative binomial PGF for a single (R0, k) pair.

    Builds the coefficient vector P(X = 0), ..., P(X = N_max) and solves u = G(u)
    with fixed_point_solver_non_parametric.

    Parameters
    ----------
    R0, k : float
        Negative binomial mean and dispersion for one parameter pair.
    x0, tol, max_iter : solver start point, tolerance and iteration cap.
    N_max : int
        Truncation of the coefficient vector.
    """
    # A single (R0, k) pair needs no batching; the scalar helper returns the 1-D
    # coefficient vector the solver expects.
    coeffs = nbinom_degree_sequence_non_vec(R0, k, N_max)
    return fixed_point_solver_non_parametric(coeffs, x0, tol, max_iter)

def G_binom_param_to_non_param(R0, k, N_max=300):
    """
    Convert negative binomial parameters (R0, k) into a non-parametric PGF.

    Returns a PGF whose coefficients are the truncated pmf P(X = i), i = 0..N_max,
    for a single (R0, k) pair.
    """
    # Scalar helper: a single pair yields the 1-D coefficient vector PGF expects.
    coeffs = nbinom_degree_sequence_non_vec(R0, k, N_max)
    return PGF(coeffs)

In [None]:
# Parameter ranges scanned below: 3000 evenly spaced values each of R0 and kappa.
R_0_list = jnp.linspace(0.8, 4, num=3000)
kappa_list = jnp.linspace(0.01, 10, num=3000)

def compute_fixed_points(R0_list, k):
    """
    Fixed points of G_nbinom for each R0 in R0_list at a fixed k.

    Calls find_fixed_point once per R0 and stacks the results into a 1-D array.
    """
    return jnp.array([find_fixed_point(R0, k) for R0 in R0_list])

def implicit_diff_parametric(G: Callable, u: float, R0: float, k: float):
    """
    du/dR0 at a fixed point u = G(u, R0, k), via the implicit function theorem.

    Parameters
    ----------
    G : callable (u, R0, k) -> float, differentiable by JAX in u and R0.
    u : float
        A fixed point of G for the given R0 and k.
    R0, k : float
        Parameters of G.
    """
    grad_u, grad_R0 = jax.grad(G, argnums=(0, 1))(u, R0, k)
    # Differentiate u - G(u, R0, k) = 0:  du/dR0 = -(dG/dR0) / (dG/du - 1).
    return -grad_R0 / (grad_u - 1)

# Implicit-function-theorem sensitivities of a fixed point u = G(u; R0, k).
@partial(jax.jit, static_argnames=("PGF_func",))
def compute_derivatives(fixed_point, R0, k, PGF_func = None):
    """
    Absolute sensitivities |du/dR0| and |du/dk| of a fixed point of a PGF.

    Parameters
    ----------
    fixed_point : float
        A fixed point u with u = G(u, R0, k).
    R0, k : float
        PGF parameters.
    PGF_func : callable (x, R0, k) -> float, optional
        PGF to differentiate; defaults to G_nbinom. Static under jit, so it must be
        hashable and a new function triggers recompilation.

    Returns
    -------
    (sce_R0, sce_k) : |du/dR0| and |du/dk|, unbounded where dG/du == 1.
    """
    G = G_nbinom if PGF_func is None else PGF_func
    dG_du, dG_dR0, dG_dk = jax.jacrev(G, argnums=(0, 1, 2))(fixed_point, R0, k)
    # du/dtheta = -(dG/dtheta) / (dG/du - 1); only the magnitude is kept.
    denom = dG_du - 1
    sce_R0 = jnp.abs(dG_dR0 / denom)
    sce_k = jnp.abs(dG_dk / denom)
    return sce_R0, sce_k

# def compute_derivatives_non_parametric(fixed_point, R0, k):
#     """Compute derivatives of fixed points with respect to R0 and k"""
#     #compute the derivatives of G with respect to u, R0 and k
#     dG_du, dG_dR0, dG_dk = jax.jacrev(G_nbinom_, argnums=(0, 1, 2))(fixed_point, R0, k)
   
#     # Calculate the derivatives as the gradient norm 
#     sce_R0 = jnp.abs(dG_dR0 / (dG_du - 1))
#     sce_k = jnp.abs(dG_dk / (dG_du - 1))
#     return sce_R0, sce_k
    


# Calculate fixed points and their derivatives for a range of R0 and kappa values
# R_0_list = jnp.linspace(0.1, 2, 100)
# kappa_list = jnp.linspace(0.1, 2, 100)

# 2-D parameter grid. indexing="ij" puts R0 on axis 0 and kappa on axis 1:
# R0_grid[i, j] == R_0_list[i] and k_grid[i, j] == kappa_list[j].
R0_grid, k_grid = jnp.meshgrid(
    R_0_list,
    kappa_list,
    indexing="ij",
)

# Double vmap: the decorator maps over grid rows, the inner vmap over entries in a row.
@jax.vmap
def row_fixed_points(r0_row, k_row):
    """Fixed points of G_nbinom for every (R0, kappa) pair in one grid row."""
    per_entry = jax.vmap(find_fixed_point)
    return per_entry(r0_row, k_row)

# Fixed point of G_nbinom at every (R0, kappa) grid point; same shape as R0_grid.
fp = row_fixed_points(R0_grid, k_grid)

# NOTE(review): not called in any visible cell. G_binom_param_to_non_param has the
# signature (R0, k, N_max) and returns a PGF object - confirm it is the intended PGF_func.
compute_derivatives_partial = partial(compute_derivatives, PGF_func=G_binom_param_to_non_param)

# Adapter with (r0, k, fixed_pt) argument order for the vmapped grid code.
@jax.jit
def compute_single_derivative(r0, k, fixed_pt):
    """Sensitivities at one grid point: forwards to compute_derivatives(fixed_pt, r0, k)."""
    result = compute_derivatives(fixed_pt, r0, k)
    return result

# Double vmap: the decorator maps over grid rows, the inner vmap over entries in a row.
@jax.vmap
def compute_row_derivatives(r0_row, k_row, fp_row):
    """Apply compute_single_derivative to every entry of one grid row."""
    per_point = jax.vmap(compute_single_derivative)
    return per_point(r0_row, k_row, fp_row)



# Sensitivities over the whole grid: a pair (|du/dR0|, |du/dk|) of grid-shaped arrays.
derivatives = compute_row_derivatives(R0_grid, k_grid, fp)
derivatives_R0, derivatives_k = derivatives

# data = {
#     'R0': np.repeat(R_0_list, len(kappa_list)),
#     'kappa': np.tile(kappa_list, len(R_0_list)),
#     'fixed_point': fp.flatten(),
#     'derivative_R0': derivatives_R0.flatten(),
#     'derivative_kappa': derivatives_k.flatten()
# }
# df = pd.DataFrame(data)

# df['S'] =  1- df['fixed_point']

In [None]:
# compute_derivatives(0.1,0.1,0.1,PGF_func = G_binom_param_to_non_param)

In [None]:
# Heatmap of the R0-sensitivity |du/dR0| over the (R0, kappa) grid.
fig, ax = plt.subplots(figsize=(12, 10))
im = ax.pcolormesh(R0_grid, k_grid, derivatives_R0, cmap='viridis', shading='auto')
fig.colorbar(im, ax=ax, label='Sensitivity to R₀')
ax.set_xlabel('R₀')
ax.set_ylabel('κ')
ax.set_title('Sensitivity of Fixed Points to R₀ Changes')
fig.tight_layout()
plt.show()

In [None]:
# Long-format results table: one row per (R0, kappa) grid point. The grid uses
# 'ij' indexing, so flattening is R0-major: repeat R0 values, tile kappa values.
data = {
    'R0': np.repeat(R_0_list, len(kappa_list)),
    'kappa': np.tile(kappa_list, len(R_0_list)),
    'fixed_point': fp.flatten(),
    'derivative_R0': derivatives_R0.flatten(),
    'derivative_kappa': derivatives_k.flatten(),
}
df = pd.DataFrame(data).assign(S=lambda frame: 1 - frame['fixed_point'])

# NOTE(review): not used in the visible cells.
argmax_sce_by_kappa = df.groupby('kappa').agg(argmax_sce = ('derivative_R0','idxmax'))

# |du/dR0| as a (kappa x R0) table. df is built in ascending R0/kappa order, so the
# pivot's rows follow kappa_list and its columns follow R_0_list.
df_r0_kappa = df.pivot(index='kappa', columns='R0', values='derivative_R0')
sce_r0_grid = df_r0_kappa.to_numpy()

# For each R0 (column): peak over kappa and the kappa where it occurs.
# argmax over axis 0 yields kappa (row) indices, so it indexes kappa_list;
# kappa_vals is aligned with R_0_list.
r0_maxs = df_r0_kappa.max(axis=0)
r0_maxs_indices = np.argmax(sce_r0_grid, axis=0)
kappa_vals = kappa_list[r0_maxs_indices]

# For each kappa (row): peak over R0 and the R0 where it occurs.
# argmax over axis 1 yields R0 (column) indices, so it indexes R_0_list;
# r0_vals is aligned with kappa_list.
k_maxs = df_r0_kappa.max(axis=1)
k_maxs_indices = np.argmax(sce_r0_grid, axis=1)
r0_vals = R_0_list[k_maxs_indices]


# Curves of maximal SCE: kappa_vals is plotted against R_0_list and r0_vals
# against kappa_list; the dashed line marks R0 = 1.
fig, ax = plt.subplots(figsize=(12, 10))
ax.scatter(R_0_list, kappa_vals, c='C0', s=1.5, label='Max SCE w/r/t R₀')
ax.scatter(r0_vals, kappa_list, c='C1', s=1.5, label='Max SCE w/r/t κ')
ax.axvline(1, color='black', linestyle='--', linewidth=0.5)
ax.set(xlabel='R₀', ylabel='κ', yscale='log', title='Location of maximum SCE')
ax.legend();


In [None]:
# S = 1 - fixed point over the grid (rows: kappa, columns: R0). sns.heatmap draws
# cells at integer row/column positions, so a log kappa axis would need pcolormesh.
fig, ax = plt.subplots(1, 1, figsize=(12, 10))
sns.heatmap(
    df.pivot(index='R0', columns='kappa', values='S').T,
    cmap='viridis',
    ax=ax,
    cbar_kws={'label': 'S = 1 - fixed point'},
)
ax.invert_yaxis()
ax.set_title('S over the (R₀, κ) grid')

In [None]:
# |du/dR0| on a log-scaled kappa axis. sns.heatmap places cells at integer row
# positions, so yscale='log' / ylim=(0.01, 10) on it act on row indices rather than
# kappa values; pcolormesh uses the actual R0 and kappa coordinates.
sce_r0_table = df.pivot(index='kappa', columns='R0', values='derivative_R0')
fig, ax = plt.subplots(1, 1, figsize=(12, 10))
mesh = ax.pcolormesh(
    sce_r0_table.columns.to_numpy(),
    sce_r0_table.index.to_numpy(),
    sce_r0_table.to_numpy(),
    cmap='Blues',
    vmax=0.1,
    shading='auto',
)
fig.colorbar(mesh, ax=ax, label='|du/dR₀|')
ax.set(xlabel='R₀', ylabel='κ', yscale='log', ylim=(0.01, 10))

In [None]:
# |du/dk| over the grid (rows: kappa, columns: R0) on linear axes. sns.heatmap draws
# cells at integer row/column positions, so a log kappa axis would need pcolormesh.
fig, ax = plt.subplots(1, 1, figsize=(12, 10))
sns.heatmap(
    df.pivot(index='R0', columns='kappa', values='derivative_kappa').T,
    cmap='Blues',
    ax=ax,
    cbar_kws={'label': '|du/dκ|'},
)
ax.invert_yaxis()
ax.set_title('Sensitivity of fixed points to κ')

In [None]:
# Combined sensitivity sqrt(|du/dR0|^2 + |du/dk|^2).
# NOTE(review): R0 and kappa span different ranges, so this mixes scales - confirm intent.
df['derivative_combo'] = np.sqrt(df['derivative_R0']**2 + df['derivative_kappa']**2)
combo_table = df.pivot(index='kappa', columns='R0', values='derivative_combo')
fig, ax = plt.subplots(1, 1, figsize=(12, 10))
# pcolormesh (not sns.heatmap) so the log y-scale and ylim act on kappa values, not row indices.
mesh = ax.pcolormesh(
    combo_table.columns.to_numpy(),
    combo_table.index.to_numpy(),
    combo_table.to_numpy(),
    cmap='Blues',
    shading='auto',
)
fig.colorbar(mesh, ax=ax, label='sqrt(|du/dR₀|² + |du/dκ|²)')
ax.set(xlabel='R₀', ylabel='κ', yscale='log', ylim=(0.01, 10))