In [None]:
import pandas as pd
import zipfile
import io
import jinko_helpers as jinko
from IPython.display import display
from plotnine import *
import numpy as np
import pickle

import vpop_calibration

# jinko set-up: initialize the jinko_helpers client before any request is sent to the API.
# NOTE(review): presumably this picks up the API credentials / project configuration
# from the environment — confirm against the jinko_helpers documentation.
jinko.initialize()

# Calibrating a population using a QSP model on jinko

This notebook demonstrates how to use the [`vpop-calibration`](https://git.novadiscovery.net/jinko/api/vpop-calibration) module to calibrate a QSP model that is simulated on jinko.

For the context of this notebook we will use a simple two-compartment PK model that was implemented in the cookbooks project. The reference trial is available here: https://jinko.ai/tr-OJvV-CPhT


## 1. Training a surrogate model

### 1.1 Generate training data

In order to train a surrogate model, we need to generate some data using the QSP model. Some care needs to be taken if we want the training to be efficient, in particular when it comes to choosing the parameter values. The most important aspect to consider here is how to choose the descriptors to include. Two questions are of importance:

- What are the parameters I want to estimate the distributions of?
- What are the parameters that will play a role in predicting the patients' behavior (PDK)?

This should give a good idea of the total list of parameters we want to explore in the training data set.

The trial setup is important. Make sure you have loaded the right protocol design and reasonable solving options, in particular the solving times, as the GP will be trained on every one of the time steps included in the solving times (too many points will lead to a slow and difficult training). If the goal is to calibrate on observed data, it makes sense to filter the training set to only include the observation time points.


In [None]:
# INPUT REQUIRED HERE
# Enter your info:

# the trial short id, the outputs of interest and the descriptors to be studied
# - trial_sid: short id of the reference trial; it is resolved into a core item id
#   and a snapshot id in the next cell
# - descriptors: names of the model parameters sampled in the training vpop; they are
#   used (together with "time") as the inputs of the GP surrogate
# - output_names: names of the model outputs whose time series are downloaded and
#   that the GP surrogate is trained to predict
trial_sid = "tr-OJvV-CPhT"
descriptors = ["k12", "k21", "k_el"]
output_names = ["A1", "A2"]

# A folder that will be used to dump any new project item in jinko
# (only used when a new training vpop is posted, i.e. when generate_vpop is True)
folder_id = "a1032e99-2d28-4d7d-a1b8-a9bb9eeb0c68"

# Turn this flag to true to generate a new training vpop and rerun the trial
# When False, the vpop already attached to the trial is downloaded and reused as is
generate_vpop = False

In [None]:
# Resolve the trial short id into the core item id / snapshot id pair used by the API paths
trial_project_item = jinko.get_project_item(sid=trial_sid)
trial_core_item_id = trial_project_item["coreId"]["id"]
trial_snapshot_id = trial_project_item["coreId"]["snapshotId"]

# Fetching the trial content
trial_info = jinko.make_request(
    path=f"/core/v2/trial_manager/trial/{trial_core_item_id}/snapshots/{trial_snapshot_id}",
).json()

# Fetching the CM info
# NOTE(review): these two ids are not used anywhere else in the visible part of the notebook
model_core_item_id = trial_info["computationalModelId"]["coreItemId"]
model_snapshot_id = trial_info["computationalModelId"]["snapshotId"]

# Fetching the protocol design info
protocol_design_core_id = trial_info["protocolDesignId"]["coreItemId"]
protocol_design_snapshot_id = trial_info["protocolDesignId"]["snapshotId"]

# NOTE(review): the "Accept" entry is passed through the `json` argument, i.e. as the
# request body of a GET, not as an HTTP header — confirm this is what make_request expects.
response = jinko.make_request(
    path=f"/core/v2/scenario_manager/protocol_design/{protocol_design_core_id}/snapshots/{protocol_design_snapshot_id}",
    method="GET",
    json={
        "Accept": "application/json;charset=utf-8, text/csv",
    },
)
protocol_design = response.json()
# Names of all the arms declared in the protocol design
protocol_arms = [arm["armName"] for arm in protocol_design["scenarioArms"]]
# Only the first three arms (at most) are kept to build the training set
selected_protocol_arms = protocol_arms[:3]
print("selected protocol arms: ", selected_protocol_arms)

# Fetching the solving times
# NOTE(review): solving_times is not used anywhere else in the visible part of the notebook
solving_times = trial_info["solvingOptions"]["solvingTimes"]

In [None]:
def to_vpop(df):
    """Convert a patient data frame into a jinko vpop payload.

    Args:
        df: pandas DataFrame holding one row per patient. It must contain an
            "id" column; every other column is exported as a patient attribute.

    Returns:
        A dict ``{"patients": [...]}`` where each patient is
        ``{"patientIndex": <id>, "patientCategoricalAttributes": [],
        "patientAttributes": [{"id": <column name>, "val": <value>}, ...]}``.
        Attributes are listed in the column order of ``df``.

    Raises:
        KeyError: if ``df`` has no "id" column.
    """
    if "id" not in df.columns:
        raise KeyError('to_vpop expects an "id" column identifying each patient')

    # A list (rather than a set) keeps the attribute order stable across runs
    # and identical to the column order of the data frame.
    vpop_descriptors = [column for column in df.columns if column != "id"]

    def to_patient(row):
        return {
            "patientIndex": row["id"],
            "patientCategoricalAttributes": [],
            "patientAttributes": [
                {"id": param, "val": row[param]} for param in vpop_descriptors
            ],
        }

    # to_dict("records") keeps the dtype of each column, whereas iterrows() casts a
    # whole row to one common dtype (e.g. an integer id would be turned into a float).
    vpop = {"patients": [to_patient(record) for record in df.to_dict("records")]}
    return vpop

In [None]:
# Two ways of obtaining the training patients (`patients` data frame):
# - generate_vpop is True: sample a new vpop, post it to jinko, attach it to the trial and run the trial
# - generate_vpop is False: download the vpop that is already attached to the trial
if generate_vpop:
    # Generate an exploration vpop
    # The vpop calibration module provides a tool for that, the only thing we need is a set of parameter ranges
    # Here we use the same range for all studied descriptors
    # NOTE(review): with "log": True the bounds are presumably exponents (log scale) — confirm
    # against the vpop_calibration documentation.
    ranges = {desc: {"low": -2.0, "high": 0.0, "log": True} for desc in descriptors}

    # The generator receives the base-2 logarithm of the number of patients;
    # nb_patients is only used in the summary message printed below
    log_nb_patients = 5
    nb_patients = 2**log_nb_patients
    patients = vpop_calibration.generate_vpop_from_ranges(log_nb_patients, ranges)

    # Convert the data frame into the jinko vpop payload and post it in the chosen folder
    vpop = to_vpop(patients)

    response = jinko.make_request(
        path="/core/v2/vpop_manager/vpop",
        method="POST",
        json=vpop,
        options={
            "name": "Training vpop for GP",
            "folder_id": folder_id,
        },
    )
    vpop_item_info = jinko.get_project_item_info_from_response(response)
    assert vpop_item_info is not None
    vpop_core_item_id = vpop_item_info["coreItemId"]["id"]
    vpop_snapshot_id = vpop_item_info["coreItemId"]["snapshotId"]

    print(f"Generated a Vpop of {nb_patients} patients, sampling {descriptors}")
    print(f"Vpop link: {jinko.get_project_item_url_from_response(response)}")

    # Attach the new vpop to the trial
    response = jinko.make_request(
        path=f"/core/v2/trial_manager/trial/{trial_core_item_id}/snapshots/{trial_snapshot_id}",
        method="PATCH",
        json={
            "vpopId": {"coreItemId": vpop_core_item_id, "snapshotId": vpop_snapshot_id}
        },
    )
    print(f"Updated trial link: {jinko.get_project_item_url_from_response(response)}")
    new_trial_info = jinko.get_project_item_info_from_response(response)
    assert new_trial_info is not None
    # From here on the trial ids point to the item returned by the PATCH request.
    # Note that `trial_info` itself is not refreshed and still describes the previous snapshot.
    trial_core_item_id = new_trial_info["coreItemId"]["id"]
    trial_snapshot_id = new_trial_info["coreItemId"]["snapshotId"]

    # Run the trial
    response = jinko.make_request(
        path=f"/core/v2/trial_manager/trial/{trial_core_item_id}/snapshots/{trial_snapshot_id}/run",
        method="POST",
    )
    # Wait for the simulation to finish before moving on to the results download
    jinko.monitor_trial_until_completion(trial_core_item_id, trial_snapshot_id)
else:
    # Retrieve the vpop information from jinko directly
    vpop_core_item_id = trial_info["vpopId"]["coreItemId"]
    vpop_snapshot_id = trial_info["vpopId"]["snapshotId"]

    # Get the vpop content
    # NOTE(review): the request path only contains the core item id; vpop_snapshot_id is
    # not used — confirm that the returned vpop is the snapshot attached to the trial.
    # NOTE(review): "Accept" is sent through the `json` argument (request body), not as a header.
    response = jinko.make_request(
        path=f"/core/v2/vpop_manager/vpop/{vpop_core_item_id}",
        method="GET",
        json={
            "Accept": "application/json;charset=utf-8, text/csv",
        },
    )
    vpop_data = response.json()
    # Flatten the payload into one record per patient: the patient index under "id",
    # plus one key per patient attribute
    patient_attributes_list = [
        {
            "id": p["patientIndex"],
            **{desc["id"]: desc["val"] for desc in p["patientAttributes"]},
        }
        for p in vpop_data["patients"]
    ]
    patients = pd.DataFrame(patient_attributes_list)
    print("exploration vpop:")
    display(patients)

In [None]:
# Retrieve results
# Ask for the time series of every output of interest, on every protocol arm of the trial.
# The service answers with a zip archive whose first member is read as a CSV text.
timeseries_json = {"timeseries": {output: protocol_arms for output in output_names}}
csvTimeSeries = ""
try:
    response = jinko.make_request(
        path=f"/core/v2/result_manager/trial/{trial_core_item_id}/snapshots/{trial_snapshot_id}/timeseries/download",
        method="POST",
        json=timeseries_json,
    )
    if response.status_code != 200:
        print(
            f"Failed to retrieve time series data: {response.status_code} - {response.reason}"
        )
        # Raises for 4xx / 5xx answers
        response.raise_for_status()
        # Any other non-200 answer carries no archive: fail here rather than leaving
        # csvTimeSeries empty and failing later with an obscure CSV parsing error
        raise RuntimeError(
            f"Unexpected status code {response.status_code} while downloading the time series"
        )

    print("Time series data retrieved successfully.")
    # The context manager makes sure the archive is closed once the CSV has been read
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        archive_members = archive.namelist()
        if not archive_members:
            raise RuntimeError("The time series archive returned by jinko is empty")
        filename = archive_members[0]
        print(f"Extracted time series file: {filename}")

        csvTimeSeries = archive.read(filename).decode("utf-8")
except Exception as e:
    # Report the failure in the cell output, then let the error stop the notebook
    print(f"Error during time series retrieval or processing: {e}")
    raise

In [None]:
# Build the GP training set: one row per (patient, protocol arm, output, time point),
# carrying both the simulated value and the descriptors of the patient
df_time_series = pd.read_csv(io.StringIO(csvTimeSeries)).rename(
    columns={"Patient Id": "id"}
)
training_df = (
    # Attach the patient descriptors to every simulated point
    df_time_series.merge(patients, on="id")
    # Switch to the column names used in the rest of the notebook
    .rename(
        columns={
            "Arm": "protocol_arm",
            "Value": "value",
            "Descriptor": "output_name",
            "Time": "time",
        }
    )
    # OPTIONAL: keep only the protocol arms we want to train the GP on.
    # The point t=0 is dropped as well: it does not vary in the vpop, so it is
    # not informative and makes the GP training harder
    .loc[
        lambda frame: frame["protocol_arm"].isin(selected_protocol_arms)
        & (frame["time"] > 0)
    ]
)
print("training dataframe:")
display(training_df)

In [None]:
# Visualize the training set quickly
# One panel per (protocol arm, output); the y axis is log10-scaled and the legend is hidden.
# NOTE(review): the lines are only told apart through color="id" — if "id" is numeric,
# plotnine may not split the lines per patient; consider adding group="id" (confirm
# against the dtype of the "id" column).
(
    ggplot(training_df, aes(x="time", y="value", color="id"))
    + geom_line()
    + facet_grid(rows="protocol_arm", cols="output_name")
    + theme(legend_position="none")
    + scale_y_continuous(trans="log10")
)

### 1.2 Create and train a GP model on the training data

We will now instantiate a GP surrogate model and train it using the dataframe that was gathered from the jinko simulation


In [None]:
# Flags controlling the persistence of the trained surrogate (see the next cell):
# - use_pickle False                      : train a new GP, nothing is written to disk
# - use_pickle True, override_pickle False: load the GP stored in gp_file, no training
# - use_pickle True, override_pickle True : train a new GP and overwrite gp_file with it
use_pickle = False
override_pickle = False
# Path of the pickle file the surrogate is saved to / loaded from
gp_file = "gp_surrogate.pkl"

In [None]:
if (not use_pickle) or override_pickle:
    # Train a new surrogate: the GP inputs are the descriptors plus the time,
    # descriptors and outputs are both declared as log-transformed
    myGP = vpop_calibration.GP(
        training_df,
        descriptors + ["time"],
        nb_training_iter=200,
        nb_inducing_points=100,
        learning_rate=0.05,
        lr_decay=0.99,
        min_delta=0.01,
        log_inputs=descriptors,
        log_outputs=output_names,
    )

    myGP.train()
    # The freshly trained model is only written to disk when both flags are set
    if use_pickle and override_pickle:
        with open(gp_file, "wb") as file:
            pickle.dump(myGP, file)
        print(f"Model saved to {gp_file}")
else:
    # Only reached when use_pickle is True and override_pickle is False:
    # reuse the surrogate stored on disk instead of training a new one
    try:
        with open(gp_file, "rb") as file:
            # SECURITY: pickle.load can execute arbitrary code, only load a file
            # that you produced yourself with this notebook
            myGP = pickle.load(file)
    except FileNotFoundError:
        print(
            f"File not found. Please make sure '{gp_file}' exists and is in the correct directory."
        )
        # Stop here: continuing would leave myGP undefined (or bound to a stale model
        # from a previous run) and break the following cells in a confusing way
        raise

    print("Model loaded successfully!")

In [None]:
# Diagnostic plot of the surrogate on its "training" data set
# (presumably simulated values against GP predictions — confirm in the vpop_calibration docs)
myGP.plot_obs_vs_predicted("training")

In [None]:
# Plot the solutions of the "training" data set
# (presumably one curve per patient, overlaying data and GP prediction — confirm in the vpop_calibration docs)
myGP.plot_all_solutions("training")

In [None]:
# Plot a single solution, selected by the argument 0
# (presumably the index of a patient in the training set — confirm in the vpop_calibration docs)
myGP.plot_individual_solution(0)

# 2. Load the calibration data set

Now that we have a surrogate model ready, we are going to be able to launch SAEM on it using some observed data.
The things to keep in mind are the following:

- observed time steps should be the same as the ones provided in the training data (adapt your training trial if necessary)
- observed outputs should be the outputs predicted by the GP (remove unnecessary outputs from the training data)
- observed protocols should be the same as the training protocols

In the context of this tutorial, we cheat by getting some patients from the training data frame and adding noise to the observations.


In [None]:
# Build a synthetic "observed" data set out of the training simulations.
# A fixed seed and a single generator shared by every random step make this
# data set (and therefore the calibration below) reproducible from one run to the next.
calib_seed = 42
rng = np.random.default_rng(calib_seed)

n_patients_calib = 12
# Sample patients (without replacement) among the patients of the training set
calib_patients = pd.DataFrame(
    {
        "id": training_df["id"]
        .drop_duplicates()
        .sample(n_patients_calib, random_state=rng)
    }
)
# Choose one protocol arm per patient
calib_protocols = np.array(selected_protocol_arms)[
    rng.integers(0, len(selected_protocol_arms), n_patients_calib)
]
# Create a data set describing the patients (here only id and protocol arm are included, but covariates or PDK would be included here)
calib_patients["protocol_arm"] = calib_protocols

# Create a calibration data frame: keep, for each selected patient, the simulated
# points of the protocol arm that was drawn for this patient
calib_df = calib_patients.merge(training_df, on=["id", "protocol_arm"])
# Add proportional noise: each value is multiplied by a factor drawn from N(1, proportional_error)
proportional_error = 0.1
noise = rng.normal(1, proportional_error, calib_df.shape[0])
calib_df["value"] = calib_df["value"] * noise
# Remove some observations at random (80% of the points are kept)
calib_df = calib_df.sample(frac=0.8, random_state=rng)
display(calib_df)

In [None]:
# Visualize the calib set quickly
# Same layout as the training set plot (one panel per protocol arm and output,
# log10 y axis, no legend), with points instead of lines since these are noisy observations
(
    ggplot(calib_df, aes(x="time", y="value", color="id"))
    + geom_point()
    + facet_grid(rows="protocol_arm", cols="output_name")
    + theme(legend_position="none")
    + scale_y_continuous(trans="log10")
)

# 3. Calibrate the surrogate model using SAEM

In order to use SAEM, we need to define a Non-Linear Mixed Effects (NLME) model. This is the combination of a statistical model, a structural model and an error model.


In [None]:
# Define a structural model: the trained GP surrogate stands in for the QSP model
gp_structural_model = vpop_calibration.StructuralGp(myGP)

# Define a parameter distribution model
# We consider no model intrinsic parameters
init_log_mi = {}
# All 3 parameters are to be calibrated: initial guess of the mean and standard
# deviation of each of them
init_log_pdu = {
    "k12": {"mean": -1.0, "sd": 0.5},
    "k21": {"mean": -1.0, "sd": 0.5},
    "k_el": {"mean": -1.0, "sd": 0.5},
}
# No covariates are described here
init_covariate_map = None

# Error model
error_model = "additive"
# Initial residual variance, one value per output
# (assumed to follow the order of output_names — TODO confirm)
init_res_var = [0.05, 0.05]

# Pass the covariate map through its variable instead of a literal None, so that
# editing init_covariate_map above actually reaches the model.
# NOTE(review): this assumes the 6th positional argument of NlmeModel is the covariate
# map — confirm against the vpop_calibration signature.
nlme_surrogate = vpop_calibration.NlmeModel(
    gp_structural_model,
    calib_patients,
    init_log_mi,
    init_log_pdu,
    init_res_var,
    init_covariate_map,
    error_model,
)

# SAEM optimizer fitting the NLME model on the calibration data set
optimizer = vpop_calibration.PySaem(
    nlme_surrogate,
    calib_df,
    nb_phase1_iterations=500,
    nb_phase2_iterations=200,
    mcmc_nb_transitions=1,
    verbose=False,
    plot_frames=50,
)

In [None]:
# Run the SAEM optimizer configured in the previous cell
# (500 phase-1 iterations followed by 200 phase-2 iterations, as set above)
optimizer.run()

In [None]:
# Check that the calibrated NLME fits the observed data
# (presumably plots the predictions obtained with the MAP estimates of each patient
# against the observations — confirm in the vpop_calibration docs)
vpop_calibration.plot_map_estimates(nlme_surrogate)

In [None]:
# Validity check of the GP surrogate for the calibrated NLME model
# (presumably flags estimates falling outside of the range explored by the training
# set — confirm in the vpop_calibration docs).
# Beware of the names: `ranges` is also the name given to the sampling ranges in the
# vpop generation cell, and `warnings` is the name of a standard library module.
warnings, ranges = vpop_calibration.check_surrogate_validity_gp(nlme_surrogate)

# 4. Posterior validation

Now that we have a calibrated NLME (surrogate) model, we can verify that the QSP model is also calibrated. A straightforward method of verifying this is to extract the maximum a posteriori (MAP) parameter estimates and generate a vpop from them.


In [None]:
# Extract the MAP estimates of the descriptors from the calibrated model
# (assumes map_estimates_descriptors() returns a data frame with one row per patient — TODO confirm)
map_estimates = nlme_surrogate.map_estimates_descriptors()
# Attach the patient identifiers: to_vpop requires an "id" column
map_estimates["id"] = nlme_surrogate.patients
# Convert the estimates into the jinko vpop payload format
new_vpop = to_vpop(map_estimates)

# Now you only need to push this vpop to jinko and run the model on it