# Simulating Onset Age Distribution of anti-GABABR Autoimmune Encephalitis from Published Summary Statistics


## Import required libraries


In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import lognorm, weibull_min, gamma, genextreme # Importing necessary libraries for statistical distributions 
from scipy.optimize import minimize # Optimization for parameter fitting
from scipy.stats import probplot # Probability plot for visual assessment
from sklearn.metrics import mean_squared_error # Mean Squared Error for goodness-of-fit
from scipy.stats import gaussian_kde # Kernel Density Estimation for smooth CDF

## Sensitivity analysis with Monte Carlo simulation

### Purpose

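The current sensitivity analysis, based on a fixed ±10% grid, has notable limitations: each parameter is evaluated only at low, central, and high values, and the calculation relies on inner resampling.

To address these, I plan to implement a Monte Carlo simulation over the fitted distribution parameters, which offers three key advantages: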

- Continuous uncertainty representation rather than relying on only low, central, and high values.
- Faster and smoother calculations through the cumulative distribution function (CDF), without the need for inner resampling.
- More stable confidence intervals and compatibility with tornado analysis for identifying key drivers.
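To illustrate the tornado-analysis point, here is a minimal sketch of one common approach: rank each parameter by its Spearman rank correlation with a Monte Carlo output. It assumes the fitted values reported later in this notebook, the same ±10% log-normal priors, and, purely for illustration, the median onset age per draw as the output of interest. The helper name `spearman` is hypothetical, and this is not necessarily how the tornado chart will be built here.

```python
import numpy as np
from scipy.stats import gengamma, rankdata

# Fitted generalized gamma parameters (SciPy parametrization), as reported in this notebook
a_hat, c_hat, scale_hat = 26.611, 1.583, 8.409

rng = np.random.default_rng(0)
N = 2000
s_log = np.log(1.1) / 1.96  # same ~±10% log-normal prior width as used in this notebook

# Independent log-normal draws whose medians equal the fitted values
draws = {
    "a": rng.lognormal(np.log(a_hat), s_log, N),
    "c": rng.lognormal(np.log(c_hat), s_log, N),
    "scale": rng.lognormal(np.log(scale_hat), s_log, N),
}

# Output of interest (illustrative assumption): median onset age for each draw
median_age = gengamma.ppf(0.5, draws["a"], draws["c"], scale=draws["scale"])

def spearman(x, y):
    """Spearman rank correlation = Pearson correlation of the ranks (hypothetical helper)."""
    return np.corrcoef(rankdata(x), rankdata(y))[0, 1]

# Tornado ordering: parameters sorted by |rank correlation| with the output
tornado = sorted(
    ((name, spearman(vals, median_age)) for name, vals in draws.items()),
    key=lambda item: abs(item[1]),
    reverse=True,
)
for name, rho in tornado:
    print(f"{name:>5}: rho = {rho:+.2f}")
```

Sorting by |rho| gives the bar order of a tornado chart; the sign shows whether increasing a parameter raises or lowers the output.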

### Procedure

1. Draw N parameter triplets from continuous prior distributions.
2. Use CDF differences to get exact band probabilities for each draw.
3. Aggregate across draws to get the mean, median, and 95% CI.

### Code for implementation

In [41]:
import numpy as np
import pandas as pd
from scipy.stats import gengamma, norm

In [42]:
# The fitted model object from previous analysis
result_gengamma


  message: CONVERGENCE: REL_REDUCTION_OF_F_<=_FACTR*EPSMCH
  success: True
   status: 0
      fun: 0.10243077858152738
        x: [ 2.661e+01  1.583e+00  8.409e+00]
      nit: 88
      jac: [-1.463e-02  1.069e+00 -8.571e-02]
     nfev: 456
     njev: 114
 hess_inv: <3x3 LbfgsInvHessProduct with dtype=float64>

In [53]:
print(result_gengamma.hess_inv)

<3x3 LbfgsInvHessProduct with dtype=float64>
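Printing `hess_inv` only shows the operator type, because L-BFGS-B returns a `LinearOperator` rather than a matrix. This minimal sketch on a hypothetical toy quadratic (not the notebook's `gengamma_objective`) shows how to materialize it with `.todense()`. The limited-memory approximation need not match the true inverse Hessian, which is one reason to compute the Hessian numerically instead.

```python
import numpy as np
from scipy.optimize import minimize

# Hypothetical toy objective: a convex quadratic whose exact inverse Hessian is known
A = np.array([[4.0, 1.0], [1.0, 3.0]])

def toy_objective(x):
    return 0.5 * x @ A @ x

res = minimize(toy_objective, x0=np.array([1.0, -2.0]), method="L-BFGS-B")

# res.hess_inv is an LbfgsInvHessProduct; .todense() turns the approximation into an array
H_inv_approx = res.hess_inv.todense()
print(H_inv_approx)
print(np.linalg.inv(A))  # exact inverse Hessian, for comparison
```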


In [55]:
# Get the MLE
theta_hat = result_gengamma.x

print(theta_hat)

[26.61103626  1.58349596  8.40915197]


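According to the output of the fitted generalized gamma model, the estimated parameters, in the SciPy gengamma parametrization used with these values below, are:
- a (shape) = 26.61
- c (power shape) = 1.58
- scale = 8.41

In this parametrization X = scale * Y^(1/c) with Y ~ Gamma(a, 1): scale stretches the distribution along the age axis, while a and c jointly determine spread, skewness, and hazard shape, and therefore the sub-family (c = 1 gives the gamma distribution, a = 1 the Weibull). They are not the location (mu), log-scale (sigma), and shape (Q) parameters of the Prentice parametrization; a location of 26.61 on the log-age scale would correspond to an implausible age.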
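For readers more used to the Prentice (mu, sigma, Q) form (used, for example, by flexsurv), the fitted values, read as SciPy gengamma parameters (a, c, scale) as in the later cells, can be mapped across. This sketch assumes a > 0 and c > 0 and uses a hypothetical helper name; the mapping follows from matching log X = log(scale) + log(Y)/c with the Prentice log-time representation.

```python
import numpy as np

def scipy_gengamma_to_prentice(a, c, scale):
    """Map SciPy gengamma(a, c, scale), with a > 0 and c > 0, to Prentice (mu, sigma, Q).

    SciPy:    log X = log(scale) + log(Y) / c,                  Y ~ Gamma(a, 1)
    Prentice: log X = mu + (sigma / Q) * (log(U) - log(Q**-2)),  U ~ Gamma(Q**-2, 1)
    Matching terms: Q = 1/sqrt(a), sigma = Q/c, mu = log(scale) + log(a)/c.
    """
    Q = 1.0 / np.sqrt(a)
    sigma = Q / c
    mu = np.log(scale) + np.log(a) / c
    return mu, sigma, Q

mu, sigma, Q = scipy_gengamma_to_prentice(26.611, 1.583, 8.409)
print(f"mu = {mu:.3f} (exp(mu) = {np.exp(mu):.1f} years), sigma = {sigma:.3f}, Q = {Q:.3f}")
```

On this scale exp(mu) lands in a realistic age range, in contrast to a log-age location of 26.61.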

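Because the hess_inv returned by L-BFGS-B is only a limited-memory approximation of the inverse Hessian, compute the Hessian numerically at the optimum instead.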

In [57]:
import numpy as np
from statsmodels.tools.numdiff import approx_hess

# gengamma_objective is the objective minimized in the earlier fitting step
H = approx_hess(theta_hat, gengamma_objective)  # by default, central differences
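To show what a central-difference Hessian computes, here is a minimal hand-rolled sketch. It is not statsmodels' implementation (approx_hess chooses its own step sizes); the function name and the fixed step `eps` are assumptions. On a quadratic, whose Hessian is known exactly, the approximation should recover the matrix A.

```python
import numpy as np

def central_diff_hessian(f, x, eps=1e-4):
    """Central-difference approximation of the Hessian of scalar f at x.

    H[i, j] ~ [f(x+ei+ej) - f(x+ei-ej) - f(x-ei+ej) + f(x-ei-ej)] / (4 * eps**2)
    """
    x = np.asarray(x, dtype=float)
    n = x.size
    H = np.empty((n, n))
    E = np.eye(n) * eps  # row i is the step vector eps * e_i
    for i in range(n):
        for j in range(i, n):
            H[i, j] = (
                f(x + E[i] + E[j]) - f(x + E[i] - E[j])
                - f(x - E[i] + E[j]) + f(x - E[i] - E[j])
            ) / (4 * eps**2)
            H[j, i] = H[i, j]  # the Hessian is symmetric
    return H

# Usage on a quadratic with known Hessian A
A = np.array([[2.0, 0.5], [0.5, 1.0]])
quad = lambda x: 0.5 * x @ A @ x
H = central_diff_hessian(quad, [0.3, -0.7])
print(np.round(H, 6))
```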

In [58]:
# Invert the Hessian to get the variance-covariance matrix
# (a valid asymptotic covariance only if gengamma_objective is the total negative log-likelihood)
vcov_matrix = np.linalg.inv(H)

print(vcov_matrix)

[[ 5.95129550e+02 -1.77608293e+01 -3.15062237e+02]
 [-1.77608293e+01  5.51055889e-01  9.63377691e+00]
 [-3.15062237e+02  9.63377691e+00  1.69340808e+02]]


In [60]:
# Standard errors are the square roots of the diagonal elements
std_errors = np.sqrt(np.diag(vcov_matrix))

print(std_errors)

[24.3952772   0.74233139 13.01310141]
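The standard errors are large relative to the estimates. Rescaling the printed covariance matrix into a correlation matrix, as in this short sketch, makes the dependence between parameters explicit. Off-diagonal correlations close to ±1 suggest that a, c, and scale are only weakly identifiable separately, and that the independent priors used below do not reflect this dependence.

```python
import numpy as np

# Variance-covariance matrix as printed above (parameter order: a, c, scale)
vcov = np.array([
    [5.95129550e02, -1.77608293e01, -3.15062237e02],
    [-1.77608293e01, 5.51055889e-01, 9.63377691e00],
    [-3.15062237e02, 9.63377691e00, 1.69340808e02],
])

se = np.sqrt(np.diag(vcov))
corr = vcov / np.outer(se, se)  # rho_ij = cov_ij / (se_i * se_j)
print(np.round(corr, 3))
```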


Fitted parameters of the generalized gamma distribution

In [43]:
# Fitted parameters of the generalized gamma distribution
# (SciPy gengamma parametrization: shape a, shape c, scale; rounded from theta_hat)

a_hat = 26.611
c_hat = 1.583
scale_hat = 8.409

In [44]:
# Age bands in years (lower edge inclusive, upper edge exclusive).
# The last band stops at 100, so probability mass above age 100 is not counted.
age_bands = [(0, 12), (12, 18), (18, 100)]

In [45]:
# Monte Carlo settings
N = 1000  # Number of Monte Carlo iterations (parameter draws)
rng = np.random.default_rng(42) # For reproducibility

In [46]:
# Log-normal priors that place ~95% of their mass within a multiplicative factor of 1.1 (about ±10%) of each fitted parameter
def calc_sigma_for_pm10():
    return np.log(1.1) / 1.96

This custom function calculates the log-scale standard deviation of a log-normal prior such that approximately 95% of its mass falls within a multiplicative factor of 1.1 (about ±10%) of the median.

Derivation

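For Y ~ Normal(mu, sigma), the central 95% interval is mu ± 1.96*sigma, so on the original scale exp(mu ± 1.96*sigma) extends by a multiplicative factor of exp(1.96*sigma) on either side of the median exp(mu). Setting that factor to 1.1 (i.e., +10%) gives sigma = ln(1.1) / 1.96 ≈ 0.0486. The resulting interval is [median / 1.1, 1.1 * median], so its lower bound sits about 9.1% (not exactly 10%) below the median.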
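A quick simulation can check this derivation: with sigma = ln(1.1)/1.96, about 95% of log-normal draws should fall within [theta/1.1, 1.1*theta]. This sketch uses the fitted scale as an arbitrary centre and an arbitrary seed.

```python
import numpy as np

sigma = np.log(1.1) / 1.96   # log-scale SD from the derivation
theta = 8.409                # any positive centre works; the fitted scale is used here
rng = np.random.default_rng(123)
x = rng.lognormal(mean=np.log(theta), sigma=sigma, size=200_000)

# Fraction of draws inside [theta / 1.1, theta * 1.1]
inside = np.mean((x >= theta / 1.1) & (x <= theta * 1.1))
print(f"sigma = {sigma:.4f}, fraction within band = {inside:.3f}")
```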

In [47]:
s_log = calc_sigma_for_pm10()
print(f"Log-normal standard deviation for ±10%: {s_log:.4f}")

Log-normal standard deviation for ±10%: 0.0486


In [48]:
# Log-scale means of the priors, so that each prior's median equals the fitted value
mu_a = np.log(a_hat)
mu_c = np.log(c_hat)
mu_scale = np.log(scale_hat)

In [49]:
# Sample parameters from log-normal priors
a_samples = rng.lognormal(mean=mu_a, sigma=s_log, size=N)
c_samples = rng.lognormal(mean=mu_c, sigma=s_log, size=N)
scale_samples = rng.lognormal(mean=mu_scale, sigma=s_log, size=N)

rng.lognormal() is a NumPy random Generator method that draws samples from a log-normal distribution, i.e., a distribution whose logarithm follows a normal (Gaussian) distribution.

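rng.lognormal(mean=mu_a, sigma=s_log, size=N) draws samples X = exp(Y), where Y ~ Normal(mu_a, s_log). Note that mean and sigma are the mean and standard deviation of the underlying normal Y, not of X itself.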

Note that rng.lognormal(mean, sigma, size) is equivalent in distribution to np.exp(rng.normal(mean, sigma, size)).
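This equivalence can be checked empirically with two independent generators: both approaches should give a median near exp(mu) and a standard deviation of the log-values near sigma. The values of mu and sigma (fitted a and the ±10% prior width) are chosen only for illustration.

```python
import numpy as np

mu, sigma, n = np.log(26.611), np.log(1.1) / 1.96, 100_000

x_lognormal = np.random.default_rng(1).lognormal(mean=mu, sigma=sigma, size=n)
x_exp_normal = np.exp(np.random.default_rng(2).normal(loc=mu, scale=sigma, size=n))

# Both should have median ~ exp(mu) and SD of log-values ~ sigma
for label, x in [("lognormal", x_lognormal), ("exp(normal)", x_exp_normal)]:
    print(f"{label:>12}: median = {np.median(x):.3f}, SD of log = {np.log(x).std():.4f}")
```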

In [50]:
# Compute band probabilities for each parameter set via CDF
def band_probs_for_draw(a, c, s):
    F = gengamma(a=a, c=c, scale=s).cdf
    p0_12 = F(12.0) - F(0.0)
    p12_18 = F(18.0) - F(12.0)
    p18_100 = F(100.0) - F(18.0)
    return p0_12, p12_18, p18_100

gengamma(a=a, c=c, scale=s) constructs a "frozen" SciPy generalized gamma distribution with the given parameters. Appending .cdf returns that distribution's cumulative distribution function as a callable (i.e., it can be called like a function).

F is therefore a function F(x) that returns P(X <= x) for X ~ GenGamma(a, c, scale=s). It accepts scalars or NumPy arrays and returns probabilities in [0, 1].
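Because SciPy distribution methods broadcast over parameter arrays, the per-draw loop can also be written without a Python loop. This sketch regenerates draws the same way as the notebook cells (seed 42, draw order a, c, scale), evaluates the CDF at the band edges for all draws at once, and differences adjacent edges. It is an alternative to band_probs_for_draw, checked against a loop over frozen distributions.

```python
import numpy as np
from scipy.stats import gengamma

rng = np.random.default_rng(42)
N = 1000
s_log = np.log(1.1) / 1.96
a = rng.lognormal(np.log(26.611), s_log, N)
c = rng.lognormal(np.log(1.583), s_log, N)
s = rng.lognormal(np.log(8.409), s_log, N)

edges = np.array([0.0, 12.0, 18.0, 100.0])  # band edges, as in age_bands

# Broadcast (N, 1) parameter columns against the (4,) edges -> (N, 4) table of CDF values
F = gengamma.cdf(edges, a[:, None], c[:, None], scale=s[:, None])
P_vec = np.diff(F, axis=1)  # (N, 3): P(0-12), P(12-18), P(18-100) per draw

# Cross-check against a loop over frozen distributions
P_loop = np.array([
    np.diff(gengamma(a=ai, c=ci, scale=si).cdf(edges)) for ai, ci, si in zip(a, c, s)
])
print(np.allclose(P_vec, P_loop))
```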



In [51]:
P = np.array([band_probs_for_draw(a, c, s) for a, c, s in zip(a_samples, c_samples, scale_samples)]) # Shape (N, 3)

In [52]:
# Summarize results (note: the '18+' row is the [18, 100) band, so mass above age 100 is excluded)
summary = pd.DataFrame({
    'Age Band': ['0-12', '12-18', '18+'],
    'Mean': np.char.mod('%.2f%%', P.mean(axis=0)*100),
    "SD": np.char.mod('%.2f%%', P.std(axis=0)*100),
    "Median": np.char.mod('%.2f%%', np.median(P, axis=0)*100),
    "CI Lower (2.5%)": np.char.mod('%.2f%%', np.percentile(P, 2.5, axis=0)*100),
    "CI Upper (97.5%)": np.char.mod('%.2f%%', np.percentile(P, 97.5, axis=0)*100)
})

print(summary)


  Age Band    Mean     SD  Median CI Lower (2.5%) CI Upper (97.5%)
0     0-12   0.00%  0.00%   0.00%           0.00%            0.00%
1    12-18   0.00%  0.00%   0.00%           0.00%            0.00%
2      18+  98.87%  4.31%  99.99%          88.34%          100.00%
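The rows of the table do not reach 100% for every draw because the last band stops at age 100. This sketch, which regenerates draws as in the cells above, splits each draw's probability into below 18, 18 to 100, and the upper tail beyond 100 (via gengamma.sf). This should show where the mass missing from the '18+' row goes.

```python
import numpy as np
from scipy.stats import gengamma

# Re-create the draws the same way as the notebook (seed 42; order a, c, scale)
rng = np.random.default_rng(42)
N, s_log = 1000, np.log(1.1) / 1.96
a = rng.lognormal(np.log(26.611), s_log, N)
c = rng.lognormal(np.log(1.583), s_log, N)
s = rng.lognormal(np.log(8.409), s_log, N)

p_under_18 = gengamma.cdf(18.0, a, c, scale=s)
p_18_100 = gengamma.cdf(100.0, a, c, scale=s) - p_under_18
p_over_100 = gengamma.sf(100.0, a, c, scale=s)  # mass the [18, 100) band leaves out

print(f"mean P(X < 18)   = {p_under_18.mean():.4%}")
print(f"mean P(18-100)   = {p_18_100.mean():.2%}")
print(f"mean P(X >= 100) = {p_over_100.mean():.2%}")
```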
