In [None]:
import numpy as np
import pandas as pd
import uuid
import pickle
from IPython.display import display


%load_ext autoreload
%autoreload 2

from vpop_calibration import *

## Define a reference ODE model


In [None]:
# Define the reference PK model we will use for training the GP and generating a synthetic data set
def equations(t, y, k_a, k_12, k_21, k_el):
    """Right-hand side of the three-compartment PK model.

    Args:
        t: Current time. Not used in the body (the system is autonomous).
        y: State sequence; y[0] is the absorption amount, y[1] the central
            amount and y[2] the peripheral amount.
        k_a: Rate constant for transfer from absorption to central.
        k_12: Rate constant for transfer from central to peripheral.
        k_21: Rate constant for transfer from peripheral to central.
        k_el: Rate constant for elimination from central.

    Returns:
        A list with the three time derivatives, in the same order as ``y``.
    """
    absorption, central, peripheral = y[0], y[1], y[2]

    # Individual fluxes between compartments.
    uptake = k_a * absorption
    back_transfer = k_21 * peripheral
    distribution = k_12 * central
    elimination = k_el * central

    return [
        -k_a * absorption,
        uptake + back_transfer - distribution - elimination,
        distribution - back_transfer,
    ]


# Labels for the model outputs and for the rate constants taken by `equations`.
variable_names = ["A0", "A1", "A2"]
parameter_names = ["k_a", "k_12", "k_21", "k_el"]

# Wrap the right-hand side into the library's ODE model object.
ode_model = OdeModel(equations, variable_names, parameter_names)
print(ode_model.variable_names)

# Time grid: `nb_steps` evenly spaced points over `time_span` (end points included).
time_span = (0, 24)
nb_steps = 10
time_steps = np.linspace(*time_span, nb_steps).tolist()

# One row per protocol arm; each arm fixes its own value of k_el.
protocol_design = pd.DataFrame(
    {
        "protocol_arm": ["arm-A", "arm-B"],
        "k_el": [0.5, 1],
    }
)
nb_protocols = protocol_design.shape[0]

# Initial state vector (presumably ordered like `variable_names` — confirm).
initial_conditions = np.array([10.0, 0.0, 0.0])

## Train or load an existing GP surrogate


In [None]:
# Location of the pickled GP surrogate.
folder_path = "./"
model_file = "gp_surrogate_pk_model.pkl"
model_full_path = f"{folder_path}{model_file}"

# Switches read by the train-or-load cell:
#   use_pickle               -> allow reading/writing the pickle file
#   override_existing_pickle -> retrain even when use_pickle is set
use_pickle = False
override_existing_pickle = False

In [None]:
# Train a fresh GP surrogate, or load a previously pickled one.
#   - train when pickling is disabled or an override is requested
#   - save the trained model only when both switches are set
#   - otherwise load the model from `model_full_path`
if override_existing_pickle or not use_pickle:
    # Simulate a training data set using parameters sampled via Sobol sequences
    log_nb_patients = 9  # 2**log_nb_patients patients are simulated
    param_ranges = {
        "k_12": {"low": -2.0, "high": 0.0, "log": True},
        "k_21": {"low": -2.0, "high": 0.0, "log": True},
        "k_a": {"low": -2.0, "high": 0.0, "log": True},
    }

    print(f"Simulating {2**log_nb_patients} patients on {nb_protocols} scenario arms")
    dataset = simulate_dataset_from_ranges(
        ode_model,
        log_nb_patients,
        param_ranges,
        initial_conditions,
        protocol_design,
        None,
        None,
        time_steps,
    )

    # The GP inputs are the sampled ODE parameters plus time.
    learned_ode_params = list(param_ranges.keys())
    descriptors = learned_ode_params + ["time"]

    # Instantiate a GP
    myGP = GP(
        dataset,
        descriptors,
        var_strat="IMV",  # either IMV (Independent Multitask Variational) or LMCV (Linear Model of Coregionalization Variational)
        kernel="RBF",  # Either RBF or SMK
        deep_kernel=False,
        data_already_normalized=False,  # default
        nb_inducing_points=100,
        mll="ELBO",  # default, otherwise PLL
        nb_training_iter=200,
        training_proportion=0.7,
        learning_rate=0.1,
        lr_decay=0.99,
        jitter=1e-6,
        nb_features=10,
        nb_latents=3,
        log_inputs=learned_ode_params,
        log_outputs=variable_names,
    )
    # Train the GP
    myGP.train()

    if use_pickle and override_existing_pickle:
        with open(model_full_path, "wb") as file:
            pickle.dump(myGP, file)
        print(f"Model saved to {model_file}")
else:
    # Reached only when use_pickle is set and no override is requested.
    # WARNING: pickle.load can execute arbitrary code; only load files you
    # produced yourself.
    try:
        with open(model_full_path, "rb") as file:
            myGP = pickle.load(file)
    except FileNotFoundError as err:
        # Fail here rather than leaving `myGP` undefined (or stale from an
        # earlier kernel run) for the cells below.
        raise FileNotFoundError(
            f"File not found. Please make sure '{model_full_path}' exists and is in the correct directory."
        ) from err

    print("Model loaded successfully!")

In [None]:
# Call the GP surrogate's loss plot (presumably the training loss curve — confirm).
myGP.plot_loss()

In [None]:
# Plot the surrogate's solutions for the "training" selection
# (presumably GP predictions against the training split — confirm).
myGP.plot_all_solutions("training")

## Generate a synthetic data set using an NLME model simulated directly from the ODEs


In [None]:
time_span_rw = (0, 24)
nb_steps_rw = 10

# For each output and for each patient, give a list of time steps to be simulated
time_steps_rw = np.linspace(time_span_rw[0], time_span_rw[1], nb_steps_rw).tolist()

# Parameter definitions ("true" values used to generate the synthetic data)
true_log_MI = {}
true_log_PDU = {"k_12": {"mean": -1.0, "sd": 0.25}, "k_a": {"mean": -1.0, "sd": 0.25}}
error_model_type = "additive"
true_res_var = [0.05, 0.01, 0.01]
true_covariate_map = {
    "k_12": {"foo": {"coef": "cov_foo_k12", "value": 0.1}},
    "k_a": {},
}

# Create a patient data frame
# It should contain at the very minimum one `id` per patient
nb_patients = 200
patients_df = pd.DataFrame({"id": [str(uuid.uuid4()) for _ in range(nb_patients)]})

# Seed the generator so the arm assignment and covariates are reproducible
# across "Restart & Run All". This generator is not passed to
# simulate_dataset_from_omega, so any sampling done inside that call is not
# controlled by this seed.
synthetic_data_seed = 42
rng = np.random.default_rng(synthetic_data_seed)

# Assign each patient to one of the two arms with probability 0.5.
patients_df["protocol_arm"] = rng.binomial(1, 0.5, nb_patients)
patients_df["protocol_arm"] = patients_df["protocol_arm"].apply(
    lambda x: "arm-A" if x == 0 else "arm-B"
)
# Per-patient values drawn from log-normal distributions.
patients_df["k_21"] = rng.lognormal(-1, 0.1, nb_patients)
patients_df["foo"] = rng.lognormal(-2, 0.5, nb_patients)
display(patients_df.head())

print(f"Simulating {nb_patients} patients on {nb_protocols} protocol arms")
# Use the observation grid defined in this cell (`time_steps_rw`), not the
# grid of the GP training cell.
obs_df = simulate_dataset_from_omega(
    ode_model,
    protocol_design,
    time_steps_rw,
    initial_conditions,
    true_log_MI,
    true_log_PDU,
    error_model_type,
    true_res_var,
    true_covariate_map,
    patients_df,
)

display(obs_df.head())
print(obs_df.shape)

## Optimize the GP surrogate using SAEM


In [None]:
# Initial population estimates used as the optimizer's starting point
# Parameter definitions
init_log_MI = {}
init_log_PDU = {
    "k_a": {"mean": -1.0, "sd": 0.1},
    "k_12": {"mean": -1.0, "sd": 0.1},
}
error_model_type = "additive"
init_res_var = [0.05, 0.05, 0.05]
init_covariate_map = {
    "k_12": {"foo": {"coef": "cov_foo_k12", "value": 0.0}},
    "k_a": {},
}

# Create a structural model
structural_gp = StructuralGp(myGP)
# Create an NLME model
nlme_surrogate = NlmeModel(
    structural_gp,
    patients_df,
    init_log_MI,
    init_log_PDU,
    init_res_var,
    init_covariate_map,
    error_model_type,
    pred_var_threshold=1e-2,
)

# Keep a random 80% of the observation rows (sampled without replacement).
# The seed makes the retained subset identical on every run.
subsample_seed = 42
obs_df_bootstrapped = obs_df.sample(frac=0.8, random_state=subsample_seed)
print(obs_df_bootstrapped.shape)

# Create an optimizer: here we use SAEM
# The true_* values come from the synthetic data generation cell
# (presumably used only as a reference in diagnostics — confirm).
optimizer = PySaem(
    nlme_surrogate,
    obs_df_bootstrapped,
    nb_phase1_iterations=500,
    nb_phase2_iterations=200,
    mcmc_nb_transitions=1,
    verbose=False,
    true_log_MI=true_log_MI,
    true_log_PDU=true_log_PDU,
    true_res_var=true_res_var,
    true_covariates=true_covariate_map,
    plot_frames=50,
)

In [None]:
# Run the SAEM optimizer configured in the previous cell.
optimizer.run()

In [None]:
# Plot the estimates held by the NLME surrogate after optimization
# (presumably per-patient MAP estimates, going by the function name — confirm).
plot_map_estimates(nlme_surrogate)

In [None]:
# Run the surrogate validity check on the NLME model; the call returns two values.
# NOTE(review): the name `warnings` shadows the standard-library `warnings`
# module if that module is imported in this kernel — consider renaming.
warnings, ranges = check_surrogate_validity_gp(nlme_surrogate)