# Pharmacokinetics models with TensorFlow Probability

Copyright 2021 Google LLC.

In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License"); { display-mode: "form" }
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

This notebook demonstrates how to fit a pharmacokinetic model with TensorFlow Probability. This includes defining the relevant joint distribution and working through the basic steps of a Bayesian workflow, e.g. prior and posterior predictive checks, diagnostics for the inference, etc.

There are three main components when building a pharmacokinetic model:


1.   The pharmacokinetics of the system involves solving ordinary differential equations with varying levels of sophistication.
2.   We need to describe the treatment the patient undergoes, using a _clinical event schedule_.
3.  The data we have comes from multiple patients. To model the heterogeneity between patients, we use a hierarchical model (also termed a population model).

We'll first tackle a one compartment model with a first-order absorption from the gut.
The ODE describing this system is simple enough to be solved analytically.
We'll start with a one-dose model for one patient and build our way up to a population model with an event schedule.

Next on the to-do list will be expanding these models to nonlinear ODEs. Examples of ODEs that arise in PK models can be found in this [Stan notebook](https://mc-stan.org/events/stancon2017-notebooks/stancon2017-margossian-gillespie-ode.html).






ToDo list:
- one cpt model (with analytical solution, numerical integrator an option), one patient, one dose $\checkmark$
- one cpt model, one patient, multiple doses $\checkmark$
- one cpt model, multiple patients, multiple doses $\checkmark$
- Michaelis-Menten PK model, one patient, one dose
- Michaelis-Menten PK model, multiple patients, multiple doses.
- Friberg-Karlsson PKPD model, one patient, multiple doses.
- Friberg-Karlsson PKPD model, multiple patients, multiple doses.

In [None]:
import tensorflow as tf
tf.executing_eagerly()

import numpy as np
from matplotlib.pyplot import *
%config InlineBackend.figure_format = 'retina'
matplotlib.pyplot.style.use("dark_background")

import jax
from jax import random
from jax import numpy as jnp

from colabtools import adhoc_import

# import tensforflow_datasets
from inference_gym import using_jax as gym

# import tensorflow as tf
from tensorflow_probability.python.internal import prefer_static as ps
from tensorflow_probability.python.internal import unnest

import tensorflow_probability as _tfp
tfp = _tfp.substrates.jax
tfd = tfp.distributions
tfb = tfp.bijectors

tfp_np = _tfp.substrates.numpy
tfd_np = tfp_np.distributions 

from jax.experimental.ode import odeint
from jax import vmap

import arviz as az
from tensorflow_probability.python.internal.unnest import get_innermost

In [None]:
# Define nested Rhat for one parameter.
# Assume for now the indexed parameter is a scalar.
def nested_rhat(result_state, num_super_chains, index_param, num_samples,
                warmup_length = 0):
  '''
  Nested R-hat for one scalar parameter.

  Chains are grouped into super chains of consecutive chains: chain c belongs
  to super chain c // num_sub_chains. The statistic compares the variance
  between super-chain means to the average variance within a super chain.

  Args:
    result_state: sequence of arrays, one per parameter, each of shape
      (num_draws, num_chains, 1).
    num_super_chains: number of super chains; must divide num_chains.
    index_param: index of the parameter in result_state.
    num_samples: number of post-warmup draws to use.
    warmup_length: number of leading draws to discard.

  Returns:
    sqrt(1 + B / W), where B is the between-super-chain variance and W is
    the within-super-chain variance averaged over super chains.
  '''
  state_param = np.asarray(result_state[index_param])[
      warmup_length:(warmup_length + num_samples), :, :]
  num_samples = state_param.shape[0]
  num_chains = state_param.shape[1]
  num_sub_chains = num_chains // num_super_chains

  # Axes: (draw, super chain, sub-chain within super chain, 1).
  state_param = state_param.reshape(num_samples, -1, num_sub_chains, 1)

  mean_chain = np.mean(state_param, axis = (0, 3))
  # Variance between the chains of a super chain; taken to be zero when a
  # super chain holds a single chain (the sample variance is undefined).
  if num_sub_chains > 1:
    between_chain_var = np.var(mean_chain, axis = 1, ddof = 1)
  else:
    between_chain_var = np.zeros(mean_chain.shape[0])
  # Variance within each chain; taken to be zero for a single draw.
  if num_samples > 1:
    within_chain_var = np.var(state_param, axis = (0, 3), ddof = 1)
  else:
    within_chain_var = np.zeros(mean_chain.shape)
  total_chain_var = between_chain_var + np.mean(within_chain_var, axis = 1)

  # Average over draws and over the sub-chains (axis 2), leaving one mean per
  # super chain. Averaging over axis 1 would mix the super chains together.
  mean_super_chain = np.mean(state_param, axis = (0, 2, 3))
  between_super_chain_var = np.var(mean_super_chain, ddof = 1)

  return np.sqrt(1 + between_super_chain_var / np.mean(total_chain_var))

# WARNING: this is a very poor estimate for ESS, and we should note
# W / B isn't typically used to estimate ESS.
def ess_per_super_chain(nRhat):
  '''
  Crude effective sample size per super chain implied by nested R-hat.

  Inverts nRhat = sqrt(1 + B / W) and returns W / B = 1 / (nRhat^2 - 1).
  '''
  between_over_within = np.square(nRhat) - 1
  return 1 / between_over_within

## 1 One compartment model with absorption from the gut

A patient orally takes in a drug, which lands in the gut and is then absorbed into a central compartment (e.g. the blood). This process is described by a differential equation:
\begin{eqnarray*}
  y_0' & = & -k_0 y_0, \\
  y_1' & = & k_0 y_0 - k_1 y_1,
\end{eqnarray*}
with each state corresponding to the drug mass in the gut and the central compartment.
This system can be solved analytically for initial conditions $(y_0^I, y_1^I)$ at time $t = 0$:
\begin{eqnarray*}
  y_0 & = & y_0^I e^{-k_0 t}, \\
  y_1 & = & \frac{e^{-k_1 t}}{k_0 - k_1} \left [ y_0^I k_0(1 - e^{(k_1 - k_0)t}) + (k_0 - k_1) y^I_1 \right ], 
\end{eqnarray*}
provided $k_0 \neq k_1$.

 We can also use JAX's `odeint` and solve the equation numerically. This will set us up for more complicated problems. The data are noisy observations of $y_1$ (in practice we should use $y_1 / V$ where $V$ is the volume of the central compartment, but I'll omit this for now).


In [None]:
# NOTE: need to pass the initial time as the first element of t.
# Time grid; t[0] = 0 is the initial time, the remaining entries are the
# observation times.
t = np.array([0., 0.5, 0.75, 1, 1.25, 1.5, 2, 3, 4, 5, 6])
# Initial drug mass in the gut (y0[0]) and the central compartment (y0[1]).
y0 = np.array([100.0, 0.0])

# Rate constants, read by `system` as k1 = theta[0] and k2 = theta[1].
theta = np.array([1.5, 0.25])
def system(state, time, theta):
  '''
  Right-hand side of the one-compartment ODE with first-order absorption.

  Args:
    state: drug mass in the gut (state[0]) and central compartment (state[1]).
    time: current time (unused, the system is autonomous).
    theta: rate constants; theta[0] drives the gut -> central transfer and
      theta[1] the elimination from the central compartment.

  Returns:
    Array with the time derivative of each state.
  '''
  absorption_rate, elimination_rate = theta[0], theta[1]
  gut_outflow = absorption_rate * state[0]

  return jnp.array([
    - gut_outflow,
    gut_outflow - elimination_rate * state[1]
  ])


In [None]:
# Toggle between the closed-form solution and the numerical integrator.
use_analytical_sln = True

if (use_analytical_sln):
  def ode_map(k1, k2):
    '''
    Drug mass in the central compartment at times t[1:], from the
    closed-form solution (requires k1 != k2). Reads the globals t and y0.
    '''
    decay = jnp.exp(- k2 * t)
    absorbed = y0[0] * k1 * (1 - jnp.exp((k2 - k1) * t))
    carried_over = (k1 - k2) * y0[1]
    central_mass = decay / (k1 - k2) * (absorbed + carried_over)
    # Drop the initial time, at which there is no observation.
    return central_mass[1:]
else:
  def ode_map(k1, k2):
    '''
    Drug mass in the central compartment at times t[1:], obtained by
    integrating `system` numerically. Reads the globals t and y0.
    '''
    rates = jnp.array([k1, k2])
    trajectory = odeint(system, y0, t, rates, mxstep = 1e6)
    # Keep the central compartment (column 1) and drop the initial time.
    return trajectory[1:, 1]

In [None]:
# Scratch check of the ingredients used to simulate data in the next section:
# the ODE solution, a standard-normal draw and the log of the solution.
# Only the last expression is displayed; the normal draw is discarded.
states = ode_map(k1 = theta[0], k2 = theta[1])
random.normal(random.PRNGKey(37272710), (states.shape[0],))
jnp.log(states)

## 1.1 Model for one patient receiving a single dose

In [None]:
# Simulate data
# Noiseless drug mass in the central compartment at the times t[1:].
states = ode_map(k1 = theta[0], k2 = theta[1])
# Standard deviation of the measurement noise on the log scale.
sigma = 0.1
# Log-normal measurement model: Gaussian noise added to log(states).
log_y = sigma * random.normal(random.PRNGKey(37272710), (states.shape[0],)) \
  + jnp.log(states)

y = jnp.exp(log_y)
# print(y)

# Noiseless solution (line) against the simulated observations (dots).
figure(figsize = [6, 6])
plot(t[1:], states)
plot(t[1:], y, 'o')
show()

### 1.1.1 Run model with TFP
The model runs faster on a CPU than a GPU, because of the ODE integrator.

In [None]:
# Joint distribution over (k1, k2, sigma, y) for one patient and one dose.
model = tfd.JointDistributionSequentialAutoBatched([
    # Priors
    tfd.LogNormal(loc = jnp.log(1.), scale = 0.5, name = "k1"),
    tfd.LogNormal(loc = jnp.log(1.), scale = 0.5, name = "k2"),
    tfd.HalfNormal(scale = 1., name = "sigma"),

    # Likelihood: log-normal noise around the ODE solution. The lambda
    # receives the variables defined above, most recent first.
    lambda sigma, k2, k1: (
      tfd.LogNormal(loc = jnp.log(ode_map(k1, k2)),
                    scale = sigma[..., jnp.newaxis], name = "y"))
])

def target_log_prob_fn(k1, k2, sigma):
  '''
  Joint log density of the parameters and the observed data `y` under
  `model`; used as the (unnormalized) posterior log density by the sampler.
  '''
  joint_value = (k1, k2, sigma, y)
  return model.log_prob(joint_value)


In [None]:
num_dimensions = 3
def initialize (shape, key = random.PRNGKey(37272709)):
  '''
  Draw log-normal initial values for the `num_dimensions` parameters.

  Args:
    shape: tuple with the leading (batch) shape of the draws.
    key: JAX PRNG key.

  Returns:
    Array of shape `shape + (num_dimensions,)` with positive entries.
  '''
  log_location = jnp.log(jnp.array([1., 1., 1.]))
  log_scale = jnp.array([0.5, 0.5, 0.5])
  standard_normal = random.normal(key, shape + (num_dimensions,))
  return jnp.exp(log_scale * standard_normal + log_location)

# initial_state = initialize((4, ), key = random.PRNGKey(1954))
# Draw the initial states from the model itself and keep the first three
# components (k1, k2, sigma), dropping the simulated observations.
initial_state = model.sample(sample_shape = (4, 1), seed = random.PRNGKey(1954))[:3]

In [None]:
# Stack the three initial-state components into a (3, 4) array and print the
# first row, i.e. the initial values of the first component (k1).
x = jnp.array(initial_state).reshape(3, 4)
print(x[0, :])

In [None]:
# TODO: find a way to do this when the init is a list!! 
# Check call to target_log_prob_fn works
# target = target_log_prob_fn(initial_state)
# print(target)

In [None]:
# Prior predictive checks
# The first entries of the sample are the parameter draws (k1, k2, sigma);
# the last entry holds the simulated observations.
num_prior_samples = 1000
*prior_samples, prior_predictive = model.sample(num_prior_samples,
                                                seed = random.PRNGKey(37272709))

In [None]:
# Prior predictive check: observations (dots) against the median (solid) and
# the 5th / 95th percentiles (dotted) of the prior predictive draws.
figure(figsize = [6, 6])
plot(t[1:], y, 'o')
plot(t[1:], np.median(prior_predictive, axis = 0), color = 'yellow')
plot(t[1:], np.quantile(prior_predictive, q = 0.95, axis = 0), linestyle = ':', color = 'yellow')
plot(t[1:], np.quantile(prior_predictive, q = 0.05, axis = 0), linestyle = ':', color = 'yellow')
show()

In [None]:
# Implement ChEES transition kernel.
# Initial step size and number of adaptation (warmup) steps.
init_step_size = 1
warmup_length = 1000

# HMC started with a single leapfrog step; the trajectory length and the
# step size are then adapted during the first `warmup_length` steps.
kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob_fn, init_step_size, 1)
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(kernel, warmup_length)
# Acceptance statistics are pooled across chains with a harmonic mean.
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
     kernel, warmup_length, target_accept_prob = 0.75,
     reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)

def trace_fn(current_state, pkr):
  '''
  Trace a divergence indicator for each transition.

  Returns a boolean array (not a tuple) flagging transitions whose log
  acceptance ratio falls below -1000.
  '''
  log_accept_ratio = get_innermost(pkr, 'log_accept_ratio')
  # proxy for divergent transitions
  return log_accept_ratio < -1000

In [None]:
# NOTE: num_chains is not passed to sample_chain below; the chains are set up
# through initial_state, which was sampled with sample_shape = (4, 1).
num_chains = 4

# num_results counts every kept transition, warmup included.
# `diverged` receives whatever trace_fn returns at each step.
mcmc_states, diverged = tfp.mcmc.sample_chain(
    num_results = 2000,
    current_state = initial_state, 
    kernel = kernel,
    trace_fn = trace_fn,
    seed = random.PRNGKey(1954))

In [None]:
# remove warmup samples
# NOTE: not idempotent -- running this cell twice drops another
# `warmup_length` draws; re-run the sampling cell first.
mcmc_states = [state[warmup_length:] for state in mcmc_states]

In [None]:
# get draws for posterior predictive checks
# The parameters are pinned to the MCMC draws through `value`; only the last
# component (the observations) is sampled and kept.
*_, posterior_predictive = model.sample(value = mcmc_states, 
                                        seed = random.PRNGKey(37272709))

### 1.1.2 Analyze results


In [None]:
# Count flagged transitions after the first 1000 (warmup) steps. In this
# section trace_fn returns a single array, so `diverged` can be sliced directly.
print("Divergent transition(s):", np.sum(diverged[1000:]))

To convert TFP's output to something compatible with Arviz, we'll follow the example in https://jeffpollock9.github.io/bayesian-workflow-with-tfp-and-arviz/.

In [None]:
# Names of the model components (relies on a private TFP method).
# zip() below stops at the shorter sequence, so any name without a matching
# array of draws is ignored.
parameter_names = model._flat_resolve_names()

az_states = az.from_dict(
    # Prior draws get a leading axis of length one (a single "chain").
    prior = {k: v[tf.newaxis, ...] for k, v in zip(parameter_names, prior_samples)},
    # MCMC draws are stored draw-first; swap the first two axes so the chain
    # axis comes first.
    posterior={
        k: np.swapaxes(v, 0, 1) for k, v in zip(parameter_names, mcmc_states)
    },
)

In [None]:
# Request a 90% HDI so that the summary actually contains the "hdi_5%" and
# "hdi_95%" columns selected below (filter() silently ignores missing labels,
# and the default interval produces differently named columns).
print(az.summary(az_states, hdi_prob = 0.9).filter(
    items=["mean", "sd", "mcse_sd", "hdi_5%",
           "hdi_95%", "ess_bulk", "ess_tail",
           "r_hat"]))

In [None]:
# Trace plots, one panel per chain-separated variable (no combining/compacting).
axs = az.plot_trace(az_states, combined = False, compact = False)

In [None]:
# TODO: include potential divergent transitions.
# NOTE(review): az_states is built without sample statistics, so
# `divergences = True` presumably has nothing to draw -- confirm.
az.plot_pair(az_states, figsize = (6, 6), kind = 'hexbin', divergences = True);

In [None]:
# Flatten every leading (draw / chain) axis and keep the observation axis, so
# the cell does not depend on hard-coded numbers of draws and observations.
ppc_data = posterior_predictive.reshape((-1, posterior_predictive.shape[-1]))

In [None]:
# Posterior predictive check: observations (dots) against the median (solid)
# and the 5th / 95th percentiles (dotted) of the posterior predictive draws.
figure(figsize = [6, 6])
plot(t[1:], y, 'o')
plot(t[1:], np.median(ppc_data, axis = 0), color = 'yellow')
plot(t[1:], np.quantile(ppc_data, q = 0.95, axis = 0), linestyle = ':', color = 'yellow')
plot(t[1:], np.quantile(ppc_data, q = 0.05, axis = 0), linestyle = ':', color = 'yellow')
show()

## 1.2 Clinical event schedule

Let's now suppose the patient receives a bolus dose every $12$ hours for a total of $16$ doses.
The first dose is administered at time $t = 0$ and the final dose at time $t = 180$ (hours).
We take many observations after the first, second and fifteenth doses. For all other dosing events, we only record the drug plasma concentration at the time of the dosing event (i.e. right before the dosing), which is also 12 hours after the previous dose was administered.


In [None]:
# Construct event times, and identify dosing times (all other times correspond
# to measurement events).
# Offsets (in hours) of the dense measurements taken after a dose.
time_after_dose = np.array([0.083, 0.167, 0.25, 0.5, 0.75, 1, 1.5, 2, 3, 4, 6, 8])

# Event times: dense sampling after the doses at t = 0, 12 and 168, one event
# at every dosing time from 24 to 156, and final events at 180 and 192.
# NOTE(review): the jnp.append call goes through JAX -- presumably the source
# of the small offsets that `eps` below works around; confirm.
t = np.append(
    np.append(np.append(np.append(0., time_after_dose),
                          np.append(12., time_after_dose + 12)),
               np.linspace(start = 24, stop = 156, num = 12)),
               np.append(jnp.append(168., 168. + time_after_dose),
               np.array([180, 192])))


# Indices into t of the dosing events.
start_event = np.array([], dtype = int)
# Doses at 0, 12, ..., 180.
dosing_time = range(0, 192, 12)

# Use dosing events to determine times of integration between
# exterior interventions on the system.
eps = 1e-4  # hack to deal with some t being slightly offset.
for t_dose in dosing_time:
  start_event = np.append(start_event, np.where(abs(t - t_dose) <= eps))

# Dose added to the state at each dosing event: 1000 in the gut, 0 in the
# central compartment.
amt = jnp.array([1000., 0.])
n_dose = start_event.shape[0]

# Append the index of the last event so that the final integration interval
# has an end point.
start_event = np.append(start_event, t.shape[0] - 1)

In [None]:
def ode_map (theta, dt, current_state):
  '''
  Closed-form solution of the one-compartment model over a time step.

  Args:
    theta: rate constants (k1, k2); requires k1 != k2.
    dt: elapsed time since `current_state` (scalar or array of times).
    current_state: drug mass in the gut and in the central compartment.

  Returns:
    Array stacking the gut mass and the central-compartment mass after dt.
  '''
  k1, k2 = theta[0], theta[1]
  gut_mass, central_mass = current_state[0], current_state[1]

  y0_hat = jnp.exp(- k1 * dt) * gut_mass
  # Mass transferred from the gut, and mass already in the central
  # compartment, both still subject to elimination.
  transferred = gut_mass * k1 * (1 - jnp.exp((k2 - k1) * dt))
  retained = (k1 - k2) * central_mass
  y1_hat = jnp.exp(- k2 * dt) / (k1 - k2) * (transferred + retained)
  return jnp.array([y0_hat, y1_hat])


In [None]:
# Quick check: central-compartment mass 1, 2 and 3 time units after the
# initial state y0.
ode_map(theta, np.array([1, 2, 3]), y0)[1, :]

We now wrap our ODE solver (whether it be via an analytical solution or a numerical integrator) inside an event schedule handler. For starters, we'll go through the events using a `for` loop. This, it turns out, is fairly inefficient, and we'll later revise this code using `jax.lax.scan`.

In [None]:
def ode_map_event (theta):
  '''
  Wrapper around the ODE solver, based on the event schedule.
  NOTE: if using the ode integrator, need to adjust the shape of mass.

  Integrates from one dosing event to the next, adding the dose `amt` to the
  state at each dosing event (the first dose is the initial state).

  Args:
    theta: rate constants (k1, k2).

  Returns:
    Drug mass in the central compartment at every event time but the first.
  '''
  y_hat = jnp.array([])
  current_state = amt
  for i in range(0, n_dose):
    # Times from this dosing event up to and including the next one.
    t_integration = jax.lax.dynamic_slice(t, (start_event[i], ),
                           (start_event[i + 1] - start_event[i] + 1, ))

    mass = ode_map(theta, t_integration - t_integration[0], current_state)
    # mass = odeint(system, current_state, t_integration,
    #               theta, rtol = 1e-6, atol = 1e-6, mxstep = 1e3)

    # Skip the first column: it is the state at the dosing time itself.
    y_hat = jnp.append(y_hat, mass[1, 1:])
    # State at the end of the interval plus the next dose. Index the last
    # column explicitly rather than relying on JAX clamping an
    # out-of-bounds index.
    current_state = mass[:, -1] + amt
  return y_hat

# Noiseless central-compartment mass at the event times t[1:].
y_hat = ode_map_event(theta)
# NOTE(review): log_y_hat drops the first element of y_hat and is not used
# in the rest of this cell.
log_y_hat = jnp.log(y_hat[1:])

# Standard deviation of the measurement noise on the log scale.
sigma = 0.5
# NOTE: no observation at time t = 0.
log_y = sigma * random.normal(random.PRNGKey(1954), (y_hat.shape[0],)) \
  + jnp.log(y_hat)
y_obs = jnp.exp(log_y)


In [None]:
# Noiseless trajectory (line) against the simulated observations (dots).
figure(figsize = [6, 6])
plot(t[1:], y_hat)
plot(t[1:], y_obs, 'o', markersize = 2)
show()

The code above works fine to simulate data but we can do better using `jax.lax.scan`.

In [None]:
# JAX copy of the event times, indexed inside the scan.
t_jax = jnp.array(t)
# Dose administered at each event: 1000 at the dosing events, 0 elsewhere.
amt_vec = np.repeat(0., t.shape[0])
amt_vec[start_event] = 1000
# start_event also holds the index of the last event, which is not a dosing
# event, so its dose is reset to zero.
amt_vec[amt_vec.shape[0] - 1] = 0.
amt_vec_jax = jnp.array(amt_vec)

# Overwrite definition of ode_map_event.
def ode_map_event(theta):
  '''
  Event-schedule wrapper around `ode_map`, written with `jax.lax.scan`.

  Steps from one event time to the next; the dose scheduled at an event is
  added to the gut after the state at that event has been recorded.

  Args:
    theta: rate constants (k1, k2).

  Returns:
    Drug mass in the central compartment at every event time but the first.
  '''
  def ode_map_step (current_state, event_index):
    elapsed = t_jax[event_index] - t_jax[event_index - 1]
    state_before_dose = ode_map(theta, elapsed, current_state)
    dose = jnp.array([amt_vec_jax[event_index], 0.])
    return state_before_dose + dose, state_before_dose[1,]

  # The state at the first event is the first dose; scan over the others.
  event_indices = np.array(range(1, t.shape[0]))
  _, yhat = jax.lax.scan(ode_map_step, amt, event_indices)
  return yhat


In [None]:
# Recompute the noiseless trajectory with the scan-based implementation and
# plot it against the observations simulated earlier.
y_hat = ode_map_event(theta) 

figure(figsize = [6, 6])
plot(t[1:], y_hat)
plot(t[1:], y_obs, 'o', markersize = 2)
show()

In [None]:
# Remark: using more informative priors helps ensure the chains mix
# reasonably well. (Could be interesting to examine with nested-rhat
# the case where they don't).
# Joint distribution over (k1, k2, sigma, y_obs) for one patient on the
# multiple-dose schedule.
model = tfd.JointDistributionSequentialAutoBatched([
    # Priors
    tfd.LogNormal(loc = jnp.log(1.), scale = 0.5, name = "k1"),
    tfd.LogNormal(loc = jnp.log(.5), scale = 0.25, name = "k2"),
    tfd.HalfNormal(scale = 1., name = "sigma"),

    # Likelihood: log-normal noise around the event-schedule solution. The
    # lambda receives the variables defined above, most recent first.
    lambda sigma, k2, k1: (
      tfd.LogNormal(loc = jnp.log(ode_map_event(jnp.array([k1, k2]))),
                    scale = sigma[..., jnp.newaxis], name = "y_obs"))
])

def target_log_prob_fn(k1, k2, sigma):
  '''
  Joint log density of the parameters and the observed data `y_obs` under
  `model`; used as the (unnormalized) posterior log density by the sampler.
  '''
  joint_value = (k1, k2, sigma, y_obs)
  return model.log_prob(joint_value)


In [None]:
# Draw the initial states from the model and keep the first three components
# (k1, k2, sigma), dropping the simulated observations.
initial_state = model.sample(sample_shape = (4, 1), seed = random.PRNGKey(1954))[:3]

In [None]:
# TODO: find a way to test target_log_prob_fn with init as a list
# print(initial_state)
# target_log_prob_fn(initial_state)

In [None]:
# Implement ChEES transition kernel.
# Initial step size and number of adaptation (warmup) steps.
init_step_size = 0.1
warmup_length = 1000

# HMC started with a single leapfrog step; the trajectory length and the
# step size are then adapted during the first `warmup_length` steps.
kernel = tfp.mcmc.HamiltonianMonteCarlo(target_log_prob_fn, init_step_size, 1)
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(kernel, warmup_length)
# Acceptance statistics are pooled across chains with a harmonic mean.
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
     kernel, warmup_length, target_accept_prob = 0.75,
     reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)


In [None]:
def trace_fn(current_state, pkr):
  '''
  Trace sampler diagnostics for each transition.

  Returns a 3-tuple: a divergence indicator (log acceptance ratio below
  -1000), the step size and the maximum trajectory length.
  '''
  traced_fields = ('log_accept_ratio', 'step_size', 'max_trajectory_length')
  log_accept_ratio, step_size, max_trajectory_length = (
      get_innermost(pkr, field) for field in traced_fields)
  # proxy for divergent transitions
  return (log_accept_ratio < -1000, step_size, max_trajectory_length)

In [None]:
# NOTE: num_chains is not passed to sample_chain below; the chains are set up
# through initial_state, which was sampled with sample_shape = (4, 1).
num_chains = 4

# num_results counts every kept transition, warmup included.
# Despite its name, `diverged` is now the 3-tuple returned by trace_fn:
# (divergence indicator, step size, max trajectory length).
mcmc_states, diverged = tfp.mcmc.sample_chain(
    num_results = 2000, 
    current_state = initial_state, 
    kernel = kernel,
    trace_fn = trace_fn,
    seed = random.PRNGKey(1954))

In [None]:
# Adaptation traces (log scale): entries 1 and 2 of the trace tuple are the
# step size and the maximum trajectory length.
semilogy(diverged[1], label = "step size")
semilogy(diverged[2], label = "max_trajectory length")
legend(loc = "best")
show()

In [None]:
# remove warmup samples
# NOTE: not idempotent -- running this cell twice drops another
# `warmup_length` draws; re-run the sampling cell first.
mcmc_states = [state[warmup_length:] for state in mcmc_states]

We'll only look at some essential diagnostics. For more, we can follow the code in the single-dose example.

In [None]:
# `diverged` is the 3-tuple returned by trace_fn, so the divergence indicator
# is its first entry. Slicing the tuple itself (diverged[1000:]) yields an
# empty tuple and always reports zero divergences.
print("Divergent transition(s):", np.sum(diverged[0][warmup_length:]))

In [None]:
# Names of the model components (relies on a private TFP method).
parameter_names = model._flat_resolve_names()

# Draw the prior samples from the current (multiple-dose) model: reusing the
# draws of the single-dose section would show priors that differ from the
# ones this model uses.
*prior_samples, _ = model.sample(1000, seed = random.PRNGKey(37272709))

az_states = az.from_dict(
    # Prior draws get a leading axis of length one (a single "chain").
    prior = {k: v[tf.newaxis, ...] for k, v in zip(parameter_names, prior_samples)},
    # MCMC draws are stored draw-first; swap the first two axes so the chain
    # axis comes first.
    posterior={
        k: np.swapaxes(v, 0, 1) for k, v in zip(parameter_names, mcmc_states)
    },
)

print(az.summary(az_states).filter(items=["mean", "sd", "mcse_sd", "hdi_3%", 
                                       "hdi_97%", "ess_bulk", "ess_tail", 
                                       "r_hat"]))

In [None]:
# get draws for posterior predictive checks
# The parameters are pinned to the MCMC draws through `value`; only the last
# component (the observations) is sampled and kept.
*_, posterior_predictive = model.sample(value = mcmc_states, 
                                        seed = random.PRNGKey(37272709))

In [None]:
# ppc_data = posterior_predictive.reshape(1000, 4, 52)

# az_data = az.from_dict(
#     posterior = dict(x = ppc_data.transpose((1, 0, 2)))
# )
# print(az.summary(az_data).filter(items=["mean", "hdi_3%", 
#                                        "hdi_97%", "ess_bulk", "ess_tail", 
#                                        "r_hat"]))

In [None]:
# REMARK: hmmm... the ppc's look odd. Not sure why. Everything else looks fine.
# Posterior predictive check on a log scale: observations (dots) against the
# median (solid) and 5th / 95th percentiles (dotted), pooled over the three
# leading axes of posterior_predictive.
figure(figsize = [6, 6])
semilogy(t[1:], y_obs, 'o')
semilogy(t[1:], np.median(posterior_predictive, axis = (0, 1, 2)), color = 'yellow')
semilogy(t[1:], np.quantile(posterior_predictive, q = 0.95, axis = (0, 1, 2)), linestyle = ':', color = 'yellow')
semilogy(t[1:], np.quantile(posterior_predictive, q = 0.05, axis = (0, 1, 2)), linestyle = ':', color = 'yellow')
show()

## 1.3 Population models

We now model data from multiple patients and use a hierarchical model to describe inter-individual heterogeneity. For simplicity, we assume the patients all undergo the same treatment.

### 1.3.1 Simulate data

In [None]:
# (Code from previous cells, rewritten here to make
# section 1.3 self-contained).
# TODO: replace this with a function.
# Offsets (in hours) of the dense measurements taken after a dose.
time_after_dose = np.array([0.083, 0.167, 0.25, 0.5, 0.75, 1, 1.5, 2, 3, 4, 6, 8])

# Event times: dense sampling after the doses at t = 0, 12 and 168, one event
# at every dosing time from 24 to 156, and final events at 180 and 192.
t = np.append(
    np.append(np.append(np.append(0., time_after_dose),
                          np.append(12., time_after_dose + 12)),
               np.linspace(start = 24, stop = 156, num = 12)),
               np.append(jnp.append(168., 168. + time_after_dose),
               np.array([180, 192])))

# Indices into t of the dosing events.
start_event = np.array([], dtype = int)
# Doses at 0, 12, ..., 180.
dosing_time = range(0, 192, 12)

# Use dosing events to determine times of integration between
# exterior interventions on the system.
eps = 1e-4  # hack to deal with some t being slightly offset.
for t_dose in dosing_time:
  start_event = np.append(start_event, np.where(abs(t - t_dose) <= eps))

# Dose added to the state at each dosing event: 1000 in the gut, 0 in the
# central compartment.
amt = jnp.array([1000., 0.])
n_dose = start_event.shape[0]

# Append the index of the last event so that the final integration interval
# has an end point.
start_event = np.append(start_event, t.shape[0] - 1)

In [None]:
# NOTE: need to run the first cell under Section 1.2
# (Clinical event schedule)

n_patients = 100
# Population-level location (log scale) and scale of the rate constants.
pop_location = jnp.log(jnp.array([1.5, 0.25]))
# pop_location = jnp.log(jnp.array([0.5, 1.0]))
pop_scale = jnp.array([0.15, 0.35])
# Patient-level rate constants, log-normal around the population values;
# shape (n_patients, 2).
theta_patient = jnp.exp(pop_scale * random.normal(random.PRNGKey(37272709), 
                          (n_patients, ) + (2,)) + pop_location)

# Initial state (first dose) replicated across patients; shape (2, n_patients).
amt = np.array([1000., 0.])
amt_patient = np.append(np.repeat(amt[0], n_patients),
                        np.repeat(amt[1], n_patients))
amt_patient = amt_patient.reshape(2, n_patients)

# redefine variables from previous section (in case we only run population model)
t_jax = jnp.array(t)
# Dose administered at each event: 1000 at the dosing events, 0 elsewhere.
amt_vec = np.repeat(0., t.shape[0])
amt_vec[start_event] = 1000
# start_event also holds the index of the last event, which is not a dosing
# event, so its dose is reset to zero.
amt_vec[amt_vec.shape[0] - 1] = 0.
amt_vec_jax = jnp.array(amt_vec)

We rewrite the ode_map, so that, rather than returning the mass for one patient, it returns the mass across multiple patients. The function `ode_map` now takes in the physiological parameters for all patients, as well as the initial states for each patient.

The call to `jax.lax.scan` now takes in an additional argument, `unroll`, which is used to unroll the for loop and make its call on an accelerator more efficient. By default, `unroll = 1` (no unrolling); we observe a major speedup when using `unroll = 10`, and an additional (more minor) speedup when `unroll = 20`.

In [None]:
# Rewrite ode_map_event for population case.
# TODO: remove 'use_second_axis' hack.
def ode_map (theta, dt, current_state, use_second_axis = False):
  '''
  Closed-form solution of the one-compartment model for several patients.

  Args:
    theta: rate constants; indexed as theta[patient, rate] by default, or as
      theta[rate, patient] when use_second_axis is True.
    dt: elapsed time since `current_state`.
    current_state: row 0 is the gut mass and row 1 the central-compartment
      mass, one column per patient.
    use_second_axis: whether patients are indexed along the second axis
      of theta.

  Returns:
    Array stacking the gut masses and the central-compartment masses after dt.
  '''
  if (use_second_axis):
    k1, k2 = theta[0, :], theta[1, :]
  else:
    k1, k2 = theta[:, 0], theta[:, 1]

  gut_mass = current_state[0, :]
  central_mass = current_state[1, :]

  y0_hat = jnp.exp(- k1 * dt) * gut_mass
  # Mass transferred from the gut, and mass already in the central
  # compartment, both still subject to elimination.
  transferred = gut_mass * k1 * (1 - jnp.exp((k2 - k1) * dt))
  retained = (k1 - k2) * central_mass
  y1_hat = jnp.exp(- k2 * dt) / (k1 - k2) * (transferred + retained)
  return jnp.array([y0_hat, y1_hat])

# @jax.jit  # Cannot use jit if function has an IF statement.
def ode_map_event(theta, use_second_axis = False):
  '''
  Event-schedule wrapper around the population `ode_map`.

  Steps every patient from one event time to the next; the dose scheduled at
  an event is added to each patient's gut after the state at that event has
  been recorded.

  Args:
    theta: rate constants of all patients (see `ode_map` for the layout).
    use_second_axis: forwarded to `ode_map`.

  Returns:
    Central-compartment mass at every event time but the first, one row per
    event and one column per patient.
  '''
  def ode_map_step (current_state, event_index):
    dt = t_jax[event_index] - t_jax[event_index - 1]
    y_sln = ode_map(theta, dt, current_state, use_second_axis)
    # Same dose for every patient, added to the gut only.
    dose = jnp.repeat(amt_vec_jax[event_index], n_patients)
    y_after_dose = y_sln + jnp.stack([dose, jnp.zeros_like(dose)])
    return (y_after_dose, y_sln[1, ])

  # The state at the first event is the first dose; scan over the others.
  (__, yhat) = jax.lax.scan(ode_map_step, amt_patient,
                            np.arange(1, t.shape[0]),
                            unroll = 20)
  return yhat

In [None]:
# Simulate some data
# Noiseless central-compartment mass: one row per event time in t[1:] and one
# column per patient.
y_hat = ode_map_event(theta_patient)

# Standard deviation of the measurement noise on the log scale.
sigma = 0.1
# NOTE: no observation at time t = 0.
log_y = sigma * random.normal(random.PRNGKey(1954), y_hat.shape) \
  + jnp.log(y_hat)
y_obs = jnp.exp(log_y)

# One line (noiseless) and one set of dots (observations) per patient.
figure(figsize = [6, 6])
plot(t[1:], y_hat)
plot(t[1:], y_obs, 'o', markersize = 2)
show()

### 1.3.2 Fit the model with TFP 

This is an adaptation of the previous model, except we're now only working with parameters on the unconstrained scale. This makes it easier for HMC and it is good practice.

In [None]:
# Hierarchical (population) model. Every parameter lives on the unconstrained
# (log) scale, and the patient-level effects use a non-centered
# parameterization.
pop_model = tfd.JointDistributionSequentialAutoBatched([
    # tfd.LogNormal(loc = jnp.log(1.), scale = 0.25, name = "k1_pop"),
    # tfd.LogNormal(loc = jnp.log(0.3), scale = 0.1, name = "k2_pop"),
    # tfd.Normal(loc = jnp.log(1.), scale = 0.25, name = "log_k1_pop"),
    # Population-level log rate constants.
    tfd.Normal(loc = jnp.log(1.), scale = 0.1, name = "log_k1_pop"),
    tfd.Normal(loc = jnp.log(0.3), scale = 0.1, name = "log_k2_pop"),
    # Log of the between-patient scale of each log rate constant.
    tfd.Normal(loc = jnp.log(0.15), scale = 0.1, name = "log_scale_k1"),
    tfd.Normal(loc = jnp.log(0.35), scale = 0.1, name = "log_scale_k2"),
    # tfd.HalfNormal(scale = 1., name = "sigma"),
    # Log of the measurement noise scale.
    tfd.Normal(loc = -1., scale = 1., name = "log_sigma"),

    # non-centered parameterization for hierarchy
    tfd.Independent(tfd.Normal(loc = jnp.zeros(n_patients),
                               scale = jnp.ones(n_patients),
                               name = "eta_k1"),
                    reinterpreted_batch_ndims = 1),
    
    tfd.Independent(tfd.Normal(loc = jnp.zeros(n_patients),
                               scale = jnp.ones(n_patients),
                               name = "eta_k2"),
                    reinterpreted_batch_ndims = 1),

    # Likelihood. The lambda receives the variables defined above, most
    # recent first. Patient rate constants are
    # exp(log_k_pop + eta * exp(log_scale)), stacked with the rate index
    # first, hence use_second_axis = True.
    lambda eta_k2, eta_k1, log_sigma, log_scale_k2, log_scale_k1,
           log_k2_pop, log_k1_pop: (
      tfd.Independent(tfd.LogNormal(
          loc = jnp.log(
              ode_map_event(theta = jnp.array([
                  jnp.exp(log_k1_pop[..., jnp.newaxis] + eta_k1 * jnp.exp(log_scale_k1[..., jnp.newaxis])),
                  jnp.exp(log_k2_pop[..., jnp.newaxis] + eta_k2 * jnp.exp(log_scale_k2[..., jnp.newaxis]))]),
                  use_second_axis = True)),
          scale = jnp.exp(log_sigma[..., jnp.newaxis]), name = "y_obs")))

    # lambda eta_k2, eta_k1, sigma, log_scale_k2, log_scale_k1,
    #        k2_pop, k1_pop: (
    #   tfd.Independent(tfd.LogNormal(
    #       loc = jnp.log(
    #           ode_map_event(theta = jnp.array(
    #           [jnp.exp(jnp.log(k1_pop[..., jnp.newaxis]) + eta_k1 * jnp.exp(log_scale_k1[..., jnp.newaxis])),
    #            jnp.exp(jnp.log(k2_pop[..., jnp.newaxis]) + eta_k2 * jnp.exp(log_scale_k2[..., jnp.newaxis]))]),
    #            use_second_axis = True)),
    #       scale = sigma[..., jnp.newaxis], name = "y_obs")))
])

def pop_target_log_prob_fn(log_k1_pop, log_k2_pop, log_scale_k1, log_scale_k2,
                           log_sigma, eta_k1, eta_k2):
  '''
  Joint log density of the population parameters and the observed data
  `y_obs` under `pop_model`.
  '''
  # The values are handed over as one tuple, in the order in which
  # pop_model defines its components.
  joint_value = (log_k1_pop, log_k2_pop, log_scale_k1, log_scale_k2,
                 log_sigma, eta_k1, eta_k2, y_obs)
  return pop_model.log_prob(joint_value)
  # CHECK -- do we need to parenthesis?



# def pop_target_log_prob_fn(k1_pop, k2_pop, log_scale_k1, log_scale_k2,
#                            sigma, eta_k1, eta_k2):
#   return pop_model.log_prob((k1_pop, k2_pop, log_scale_k1, log_scale_k2,
#                            sigma, eta_k1, eta_k2, y_obs))

def pop_target_log_prob_fn_flat(x):
  '''
  Same log density as `pop_target_log_prob_fn`, for a flat state array.

  The columns of x hold, in order: the two population log rate constants,
  the log noise scale, the two log between-patient scales, then the
  n_patients effects for k1 followed by the n_patients effects for k2.
  '''
  eta_start = 5
  eta_split = eta_start + n_patients
  eta_end = eta_start + 2 * n_patients

  log_k1_pop, log_k2_pop, log_sigma, log_scale_k1, log_scale_k2 = (
      x[:, column] for column in range(eta_start))
  eta_k1 = x[:, eta_start:eta_split]
  eta_k2 = x[:, eta_split:eta_end]

  # pop_model expects the two scales before the noise parameter.
  return pop_model.log_prob((log_k1_pop, log_k2_pop, log_scale_k1,
                             log_scale_k2, log_sigma, eta_k1, eta_k2, y_obs))


If we want to run many chains in parallel and use $n\hat R$ (nested $\hat R$), we need to specify the number of chains and the number of super chains.
The number of super chains determines the number of distinct starting points, since within each super chain, every chain starts at the same location. 

In [None]:
# Sample initial states from prior
num_chains = 128
num_super_chains = 4  #  num_chains  #  128

n_parm = 5 + 2 * n_patients
# One starting point per super chain; keep the seven parameter components
# and drop the simulated observations.
initial_state_raw = pop_model.sample(sample_shape = (num_super_chains, 1),\
                                     seed = random.PRNGKey(37272710))[:7]

# Replicate each starting point so that all the chains of a super chain start
# at the same location. A new list is built on purpose: assigning
# `initial_state = initial_state_raw` would only create a second name for the
# same list, and filling it in place would overwrite the raw starting points.
initial_state = [
    np.repeat(state, num_chains // num_super_chains, axis = 0)
    for state in initial_state_raw]


Some care is required when setting the tuning parameters for ChEES-HMC, in particular the initial step size. In the [ChEES-HMC paper](http://proceedings.mlr.press/v130/hoffman21a/hoffman21a.pdf), the following procedure is used: "Initial step sizes were chosen by repeatedly halving the step size (starting from a consistently too-large value of 1.0) until an HMC proposal with a single leapfrog step achieved a harmonic-mean acceptance probability of at least 0.5."

TODO: implement this.

For now we note that an "appropriate" initial step size depends on the number of chains (for reasons I don't quite understand...).

In [None]:
# Implement ChEES transition kernel. Increase the target acceptance rate
# to avoid divergent transitions.
# NOTE: increasing the target acceptance probability can lead to poor performance.
init_step_size = 0.001  # CHECK -- how to best tune this?
warmup_length = 1000 # 1000

# HMC started with 10 leapfrog steps; the trajectory length and the step size
# are then adapted during the first `warmup_length` steps.
kernel = tfp.mcmc.HamiltonianMonteCarlo(pop_target_log_prob_fn, 
                                        step_size = init_step_size, 
                                        num_leapfrog_steps = 10)
kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(kernel, warmup_length)
# Acceptance statistics are pooled across chains with a harmonic mean.
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
     kernel, warmup_length, target_accept_prob = 0.75,
     reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)

def trace_fn(current_state, pkr):
  '''
  Trace sampler diagnostics for each transition.

  Returns a 3-tuple: a divergence indicator (log acceptance ratio below
  -1000), the step size and the maximum trajectory length.
  '''
  traced_fields = ('log_accept_ratio', 'step_size', 'max_trajectory_length')
  log_accept_ratio, step_size, max_trajectory_length = (
      get_innermost(pkr, field) for field in traced_fields)
  # proxy for divergent transitions
  return (log_accept_ratio < -1000, step_size, max_trajectory_length)

In [None]:
# Run warmup_length adaptation steps followed by 1000 sampling steps; all of
# them are kept in mcmc_states. `diverged` is the 3-tuple returned by
# trace_fn: (divergence indicator, step size, max trajectory length).
mcmc_states, diverged = tfp.mcmc.sample_chain(
    num_results = warmup_length + 1000,
    current_state = initial_state,
    kernel = kernel,
    trace_fn = trace_fn,
    seed = random.PRNGKey(1954))


In [None]:
# Keep a separate list of the raw draws. A plain assignment
# (mcmc_states_raw = mcmc_states) only creates a second name for the same
# list, which is why modifying mcmc_states used to modify mcmc_states_raw too.
# The copy is shallow: the arrays themselves are shared, not duplicated.
mcmc_states_raw = list(mcmc_states)

### 1.3.3 Traditional diagnostics

In [None]:
# remove warmup samples
# NOTE: not a good idea. It's better to store all the states.
# Disabled on purpose: the cells below slice off the warmup themselves.
if False:
  for i in range(0, len(mcmc_states)):
    mcmc_states[i] = mcmc_states_raw[i][warmup_length:]

In [None]:
# Adaptation traces (log scale): entries 1 and 2 of the trace tuple are the
# step size and the maximum trajectory length.
semilogy(diverged[1], label = "step size")
semilogy(diverged[2], label = "max_trajectory length")
legend(loc = "best")
show()

In [None]:
# Shape of the draws of the first parameter (log_k1_pop).
mcmc_states[0].shape

In [None]:
# Use this to search for points where the step size changes
# dramatically and divergences that might be happening there.
if False:
  index_l = 219  # 225
  index_u = index_l + 1  # 235
  print("Max L:" , diverged[2][index_l:index_u])
  print("Divergence:", np.sum(diverged[0][index_l:index_u]),
        "at", np.where(diverged[0][index_l:index_u] == 1))

  chain = 0
  # Patient-level log rate constants, using the same non-centered transform
  # as pop_model: log_k_pop + eta * exp(log_scale). States 2 and 3 hold the
  # *log* of the scales, so they must be exponentiated first.
  log_k1_patient = mcmc_states[0][index_l, chain, 0] +\
    mcmc_states[5][index_l, chain, :] *\
    np.exp(mcmc_states[2][index_l, chain, 0])
  log_k2_patient = mcmc_states[1][index_l, chain, 0] +\
    mcmc_states[6][index_l, chain, :] *\
    np.exp(mcmc_states[3][index_l, chain, 0])

  k0_state = np.exp(log_k1_patient)
  k1_state = np.exp(log_k2_patient)
  # The analytical solution divides by this difference, so values close to
  # zero are numerically dangerous.
  print(k0_state - k1_state)

In [None]:
# Count flagged transitions after warmup; entry 0 of the trace tuple is the
# divergence indicator.
print("Divergent transition(s):", np.sum(diverged[0][warmup_length:]))

In [None]:
# NOTE: the last parameter is an 'x': not sure where this comes from...
# (presumably the placeholder name given to the final, lambda-defined
# component -- to be confirmed). It is dropped here.
parameter_names = pop_model._flat_resolve_names()[:-1]

# mcmc_states still holds the warmup draws (the removal cell is disabled), so
# slice them off here: the summary should describe the posterior only.
az_states = az.from_dict(
    #prior = {k: v[tf.newaxis, ...] for k, v in zip(parameter_names, prior_samples)},
    # Swap the first two axes so the chain axis comes first.
    posterior={
        k: np.swapaxes(v[warmup_length:], 0, 1)
        for k, v in zip(parameter_names, mcmc_states)
    },
)

print(az.summary(az_states).filter(items=["mean", "sd", "mcse_sd", "hdi_3%", 
                                       "hdi_97%", "ess_bulk", "ess_tail", 
                                       "r_hat"]))

In [None]:
# Only plot the population parameters.
# (The first five names are the population-level parameters.)
axs = az.plot_trace(az_states, combined = False, compact = False,
                    var_names = parameter_names[:5])

In [None]:
# posterior predictive checks
# NOTE: for 100 patients, running this exhausts memory
# mcmc_states still holds the warmup draws, so slice them off first: they do
# not belong in a posterior predictive check and they double the memory use.
posterior_draws = [state[warmup_length:] for state in mcmc_states]
*_, posterior_predictive = pop_model.sample(value = posterior_draws,
                                        seed = random.PRNGKey(37272709))
# Flatten the leading (draw / chain) axes and keep the trailing
# (observation time, patient) axes, without hard-coding any size.
ppc_data = posterior_predictive.reshape(
    (-1,) + posterior_predictive.shape[-2:])

In [None]:
# NOTE: unclear why the confidence interval is so small...
# One panel per patient, stacked vertically.
fig, axes = subplots(n_patients, 1, figsize=(8, 4 * n_patients))

for i in range(0, n_patients):
  # Predictive draws for patient i (patients are indexed on the last axis).
  patient_ppc = posterior_predictive[:, :, :, :, i]
  # Observations (dots) against the median (solid) and 5th / 95th
  # percentiles (dotted) of the predictive draws.
  axes[i].semilogy(t[1:], y_obs[:, i], 'o')
  axes[i].semilogy(t[1:], np.median(patient_ppc, axis = (0, 1, 2)), color = 'yellow')
  axes[i].semilogy(t[1:], np.quantile(patient_ppc, q = 0.95, axis = (0, 1, 2)), linestyle = ':', color = 'yellow')
  axes[i].semilogy(t[1:], np.quantile(patient_ppc, q = 0.05, axis = (0, 1, 2)), linestyle = ':', color = 'yellow')
show()

### 1.3.4 Diagnostic using $n \hat R$.

For starters, let's examine estimates in the short regime, i.e. using only the first few iterations from each chain. We'll focus on $\log k_{1,\text{pop}}$ which seems to have the most difficult expectation value to estimate (given its relatively low ESS).

In [None]:
# Assumes mcmc_states contains all the samples (including warmup)
parameter_index = 0
num_samples = 500
# Post-warmup draws of the selected parameter.
post_warmup = mcmc_states[parameter_index][warmup_length:]
# Estimate based on the first num_samples post-warmup draws of every chain.
mc_mean = np.mean(post_warmup[:num_samples])

print("Mean:", mc_mean)
# Reference value: the mean over all the post-warmup draws.
print("Estimated squared error:",
      np.square(mc_mean - np.mean(post_warmup)))
# Variance of the selected parameter over the post-warmup draws only (warmup
# draws are not draws from the posterior), divided by the number of chains.
print("Upper bound on expected squared error for one iteration:",
      np.var(post_warmup) / num_chains)

In [None]:
# Nested R-hat for the selected parameter over the first num_samples
# post-warmup draws.
nRhat = nested_rhat(result_state = mcmc_states, 
                    num_super_chains = num_super_chains,
                    index_param = parameter_index, 
                    num_samples = num_samples,
                    warmup_length = warmup_length)

print("num_samples:", num_samples)
print("nRhat:", nRhat)
# Classic R-hat on the same parameter and the same window of draws, for
# comparison (indexing with parameter_index rather than a hard-coded 0).
print("Rhat:",
      tfp.mcmc.potential_scale_reduction(
          mcmc_states[parameter_index][
              warmup_length:(num_samples + warmup_length), :, :]))

## 2 Michaelis-Menten pharmacokinetics (Incomplete)



Nonlinear PK model with absorption from the gut.

\begin{eqnarray*}
  y_0' & = & - k_a y_0 \\
  y_1' & = & k_a y_0 - \frac{V_m C}{K_m + C},
\end{eqnarray*}
where $C = y_1 / V$.

In [None]:
# Time grid; t[0] = 0 is the initial time.
t = np.array([0.0, 0.5, 0.75, 1, 1.25, 1.5, 10, 20, 30, 40, 50, 60, 70, 80, 90, 100])
# Initial drug mass in the gut and in the central compartment.
y0 = np.array([100.0, 0.0])
# Parameters, read by `system` as (ka, V, Vm, Km).
theta = np.array([0.5, 27, 10, 14])

def system(state, time, theta):
  """Right-hand side of the Michaelis-Menten ODE with first-order absorption.

  Args:
    state: pair (y_0, y_1); y_0 is depleted at rate ka and feeds y_1.
    time: current time; unused because the system is autonomous.
    theta: parameters indexed as (ka, V, Vm, Km).

  Returns:
    Array holding the time derivatives (y_0', y_1').
  """
  ka, V, Vm, Km = (theta[i] for i in range(4))
  source, central = state[0], state[1]
  concentration = central / V

  # Saturable (Michaelis-Menten) elimination term.
  elimination = Vm * concentration / (Km + concentration)

  return jnp.array([- ka * source, ka * source - elimination])

# Solve the ODE on the time grid `t` for the parameters `theta`.
states = odeint(system, y0, t, theta, mxstep = 1000)

# Simulate observations of the second state by adding Gaussian noise on the
# log scale. The initial time point is dropped.
sigma = 0.5
log_y = jnp.log(states[1:, 1]) + sigma * random.normal(
    random.PRNGKey(37272709), (states.shape[0] - 1,))
y = jnp.exp(log_y)

# ODE solution (line) against the simulated observations (dots).
figure(figsize = [6, 6])
plot(t[1:], states[1:, 1])
plot(t[1:], y, 'o');

In [None]:
def ode_map(ka, V, Vm, Km):
  """Solves the ODE for the given parameters and returns the second state.

  The system is integrated from the module-level initial state `y0` over the
  module-level time grid `t`.

  Args:
    ka, V, Vm, Km: parameters, stacked in the order expected by `system`.

  Returns:
    The second state component at every time in `t` except the first.
  """
  solution = odeint(system, y0, t, jnp.array([ka, V, Vm, Km]), mxstep = 1e3)
  # Row 0 is the initial condition; column 1 is the second state.
  return solution[1:, 1]

# Joint distribution over (ka, V, Vm, Km, sigma, y) for the Michaelis-Menten
# model: independent priors on the parameters followed by the likelihood of
# the observations.
model = tfd.JointDistributionSequentialAutoBatched([
    # Priors
    tfd.LogNormal(loc = jnp.log(1), scale = 0.5, name = "ka"),
    tfd.LogNormal(loc = jnp.log(35), scale = 0.5, name = "V"),
    tfd.LogNormal(loc = jnp.log(10), scale = 0.5, name = "Vm"),
    tfd.LogNormal(loc = jnp.log(2.5), scale = 1, name = "Km"),
    tfd.HalfNormal(scale = 1., name = "sigma"),

    # Likelihood: the ODE solution is divided by V to get a concentration,
    # which is the median of the log-normal observation model.
    # The lambda lists the previously defined variables in reverse order of
    # definition (most recent first).
    lambda sigma, Km, Vm, V, ka: (
      tfd.LogNormal(loc = jnp.log(ode_map(ka, V, Vm, Km) / V),
                   scale = sigma[..., jnp.newaxis], name = "y"))
])

def target_log_prob_fn(x):
  """Joint log density of `model` at the parameters in `x` and the data `y`.

  Args:
    x: 2-D array with one row per chain and columns ordered as
      (ka, V, Vm, Km, sigma).

  Returns:
    The value of `model.log_prob` with the observations fixed to `y`.
  """
  parameters = tuple(x[:, column] for column in range(5))
  return model.log_prob(parameters + (y,))

# Number of sampled parameters: ka, V, Vm, Km and sigma.
num_dimensions = 5

def initialize (shape, key = random.PRNGKey(37272709)):
  """Draws log-normal initial values for the sampled parameters.

  Args:
    shape: tuple of leading (batch) dimensions, e.g. `(4, )`.
    key: JAX PRNG key used for the draw.

  Returns:
    Strictly positive array of shape `shape + (num_dimensions,)`.
  """
  log_center = jnp.log(jnp.array([1.5, 35, 10, 2.5, 0.5]))
  log_spread = jnp.array([3, 0.5, 0.5, 3, 1.])
  standard_draws = random.normal(key, shape + (num_dimensions,))
  return jnp.exp(log_spread * standard_draws + log_center)

# One row of initial values per chain.
initial_state = initialize((4, ), key = random.PRNGKey(1954))


In [None]:
# Test target probability density can be computed
# Evaluates the target log density at the initial state and prints the result
# (the state is sliced along its first axis, so one value per row is expected).
target = target_log_prob_fn(initial_state)
print(target)

In [None]:
# Implement ChEES transition kernel: an HMC kernel wrapped first by the
# gradient-based trajectory length adaptation, then by dual averaging step
# size adaptation. Both adaptations run for `warmup_length` steps.
init_step_size = 1
warmup_length = 250

hmc_kernel = tfp.mcmc.HamiltonianMonteCarlo(
    target_log_prob_fn = target_log_prob_fn,
    step_size = init_step_size,
    num_leapfrog_steps = 1)
trajectory_kernel = tfp.experimental.mcmc.GradientBasedTrajectoryLengthAdaptation(
    inner_kernel = hmc_kernel,
    num_adaptation_steps = warmup_length)
kernel = tfp.mcmc.DualAveragingStepSizeAdaptation(
    inner_kernel = trajectory_kernel,
    num_adaptation_steps = warmup_length,
    target_accept_prob = 0.75,
    reduce_fn = tfp.math.reduce_log_harmonic_mean_exp)


In [None]:
# Not referenced by the sample_chain call below.
# NOTE(review): the number of chains actually run presumably comes from the
# leading dimension of initial_state; keep the two in sync.
num_chains = 4

# NOTE: It takes 29 seconds to run one iteration. So running 500 iterations
# would take ~4 hours :(
# QUESTION: why does JAX struggle so much to solve this type of problem??
# Only a single draw is requested here (num_results = 1) because of the cost
# noted above.
result = tfp.mcmc.sample_chain(
    num_results = 1, 
    current_state = initial_state, 
    kernel = kernel,
    seed = random.PRNGKey(1954))

## Draft Code


In [None]:
# Scratch calculation: evaluates 1 / (R^2 - 1) for R = 1.62 and displays it.
# NOTE(review): R presumably stands for an Rhat-type value; confirm.
R = 1.62
1 / (R * R - 1)

In [None]:
# Grid of values for `a` (multiples of 4 below 1024) and a constant d = 6.
a = np.arange(4, 1024, 4)
d = np.full(len(a), 6.)

# The optimization reduces to a quadratic equation with a "+" and a "-" root.
# Only the "-" root is kept: the "+" root gives a negative upper bound for
# delta_u.
shifted = 2 * a + d / 2
alpha_1 = shifted - np.sqrt(np.square(shifted) - 2 * a)
alpha_2 = a - alpha_1
delta_u = (np.square(alpha_1 + d / 2) / (alpha_1 * alpha_2)) / 2

In [None]:
# delta = (1 + eps)^2 - 1. With eps = 0.01 this is the value drawn as the
# "delta for 1.01 threshold" reference line in the plot of delta_u.
eps = 0.01
delta = np.square(1 + eps) - 1
print(delta)

In [None]:
# Upper bound delta_u as a function of a / d, on a log scale, against the
# reference level delta.
ratio = a / d
semilogy(ratio, delta_u, label = "delta_u")
hlines(delta, ratio[0], ratio[-1], linestyles = '--',
       label = "delta for 1.01 threshold")
xlabel("a / d")
ylabel("delta_u")
# Without a legend the label passed to hlines is never displayed.
legend(loc = 'best')

In [None]:
# Each solution as a fraction of `a`, plotted against a / d on a log scale.
for alpha, alpha_name in ((alpha_1, "alpha_1"), (alpha_2, "alpha_2")):
  semilogy(a / d, alpha / a, label = alpha_name)
legend(loc = 'best')
xlabel("a / d")
ylabel("alpha")

In [None]:
# Locate the grid point where a / d equals 100 and compare the bound delta_u
# at that point with the threshold delta.
# The variable was previously misspelled `aindex_location`, so the prints
# below raised a NameError. np.isclose avoids exact float comparison.
index_location = np.where(np.isclose(a / d, 100))
print(index_location)
print(delta_u[index_location])
delta

In [None]:
# Hierarchical (population) model. Variables are defined in the order
# k1_pop, k2_pop, log_scale_k1, log_scale_k2, sigma, eta_k1, eta_k2, y_obs.
pop_model = tfd.JointDistributionSequentialAutoBatched([
    # Population-level rate constants (positive).
    tfd.LogNormal(loc = jnp.log(1.), scale = 0.5, name = "k1_pop"),
    tfd.LogNormal(loc = jnp.log(.5), scale = 0.25, name = "k2_pop"),
    # Log of the between-patient scale for each rate constant (unconstrained).
    tfd.Normal(loc = jnp.log(0.5), scale = 1., name = "log_scale_k1"),
    tfd.Normal(loc = jnp.log(0.5), scale = 1., name = "log_scale_k2"),
    # Scale of the log-normal observation noise.
    tfd.HalfNormal(scale = 1., name = "sigma"),

    # non-centered parameterization for hierarchy
    # One standard-normal variable per patient and per rate constant.
    tfd.Independent(tfd.Normal(loc = jnp.zeros(n_patients),
                               scale = jnp.ones(n_patients),
                               name = "eta_k1"),
                    reinterpreted_batch_ndims = 1),
    
    tfd.Independent(tfd.Normal(loc = jnp.zeros(n_patients),
                               scale = jnp.ones(n_patients),
                               name = "eta_k2"),
                    reinterpreted_batch_ndims = 1),

    # Likelihood. The lambda lists the previously defined variables in reverse
    # order of definition (most recent first). The per-patient rate constants
    # are exp(log(k_pop) + eta * exp(log_scale)).
    # NOTE(review): ode_map_event is defined elsewhere; presumably it returns
    # the predicted observations for every patient. Confirm, together with the
    # meaning of use_second_axis.
    lambda eta_k2, eta_k1, sigma, log_scale_k2, log_scale_k1,
           k2_pop, k1_pop: (
      tfd.Independent(tfd.LogNormal(
          loc = jnp.log(
              ode_map_event(theta = jnp.array(
              [jnp.exp(jnp.log(k1_pop[..., jnp.newaxis]) + eta_k1 * jnp.exp(log_scale_k1[..., jnp.newaxis])),
               jnp.exp(jnp.log(k2_pop[..., jnp.newaxis]) + eta_k2 * jnp.exp(log_scale_k2[..., jnp.newaxis]))]),
               use_second_axis = True)),
          scale = sigma[..., jnp.newaxis], name = "y_obs")))
])


In [None]:
# Number of population-level parameters in pop_model:
# k1_pop, k2_pop, log_scale_k1, log_scale_k2 and sigma.
num_hyper = 5
# Length of the flattened state: the population-level parameters plus one
# eta_k1 and one eta_k2 entry per patient.
num_dimensions = num_hyper + 2 * n_patients

def pop_initialize(shape, key = random.PRNGKey(37272710)) :
  """Draws random initial states for the population model.

  The last axis of the result follows the order in which `pop_model` defines
  its variables: k1_pop, k2_pop, log_scale_k1, log_scale_k2, sigma, then the
  `n_patients` entries of eta_k1 followed by the `n_patients` entries of
  eta_k2.

  Args:
    shape: tuple of leading (batch) dimensions, e.g. `(4, )`.
    key: JAX PRNG key. It is split so that each group of parameters is drawn
      from its own key instead of reusing the same one.

  Returns:
    Array of shape `shape + (5 + 2 * n_patients, )`.
  """
  key_hyper, key_scale, key_eta = random.split(key, 3)

  # init for k1_pop, k2_pop, and sigma: drawn on the log scale and
  # exponentiated, so they are strictly positive.
  hyper_prior_location = jnp.array([jnp.log(1.5), jnp.log(0.25), 0.])
  hyper_prior_scale = jnp.array([0.5, 0.1, 0.5])
  init_hyper_param = jnp.exp(
      hyper_prior_scale * random.normal(key_hyper, shape + (3, ))
      + hyper_prior_location)

  # init for log_scale_k1 and log_scale_k2 (unconstrained)
  scale_prior_location = jnp.array([-1., -1.])
  scale_prior_scale = jnp.array([0.25, 0.25])
  init_scale = (scale_prior_scale * random.normal(key_scale, shape + (2, ))
                + scale_prior_location)

  # inits for the etas (standard normal)
  init_eta = random.normal(key_eta, shape + (2 * n_patients, ))

  # sigma goes after the log scales to match pop_model. Placing it right
  # after k2_pop would hand the (negative) log scales to sigma downstream.
  return jnp.concatenate([
      init_hyper_param[..., :2],   # k1_pop, k2_pop
      init_scale,                  # log_scale_k1, log_scale_k2
      init_hyper_param[..., 2:],   # sigma
      init_eta,                    # eta_k1, eta_k2
  ], axis = -1)

# Initial states with a leading dimension of 4 (presumably one row per chain),
# drawn with the default PRNG key of pop_initialize.
initial_state = pop_initialize((4, ))

In [None]:
# Split the flat initial state into one array per model variable, in the
# order k1_pop, k2_pop, log_scale_k1, log_scale_k2, sigma, eta_k1, eta_k2.
# NOTE(review): this assumes the columns of initial_state follow that same
# order; verify against pop_initialize.
eta_k1_stop = 5 + n_patients
eta_k2_stop = 5 + 2 * n_patients
initial_list = [initial_state[:, column] for column in range(5)] + [
    initial_state[:, 5:eta_k1_stop],            # eta_k1
    initial_state[:, eta_k1_stop:eta_k2_stop],  # eta_k2
]