# Quantifying uncertainty in simulation results
## Introduction

The purpose of this cookbook is to demonstrate how one can compute and plot the 95% percentile prediction interval (PPI) of the mean for multiple time series at once. 

To compute a PPI for a given summary metric (e.g. the mean), a given number of random samples of size n are drawn from a larger population. The metric of interest is computed for each sample, which makes it possible to estimate the empirical distribution of that metric. The 2.5% and 97.5% percentiles of this distribution are then estimated. 

The PPI is a relevant metric to assess the degree of uncertainty embedded in the model and, eventually, to compare it with the uncertainty observed in a real-life setting. Indeed, standard confidence intervals are not well suited to the in silico context, as they tend to become very narrow as the Virtual Population (VP) size increases. The PPI, on the other hand, is computed for a sample size that one can choose (when the VP is much larger than a real-life clinical trial, one can use the same sample size as the one used for the real-life observations).

In [None]:
# Jinko specifics imports & initialization
# Please fold this section and do not edit it

import sys

# Put "../lib" first on the module search path so that the import below can
# resolve modules located there (presumably where jinko_helpers lives).
sys.path.insert(0, "../lib")
import jinko_helpers as jinko

# Connect to Jinko (see README.md for more options)

jinko.initialize()

In [None]:
# Cookbook specifics imports
import io
import json
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import zipfile


# Cookbook specifics constants:
# put here the constants that are specific to your cookbook like
# the reference to the Jinko items, the name of the model, etc.

# @param {"name":"trialId", "type": "string"}
# trial short id can be retrieved in the url, pattern is `https://jinko.ai/<trial_sid>`
trial_sid = "tr-9Bid-BL1I"

## Step 1: Loading the trial and getting the last completed version

In [None]:
# Resolve the trial short id into its core item id
trial_core_item_id = jinko.getCoreItemId(trial_sid, 1)

# Fetch the status of every version of the trial
# https://doc.jinko.ai/api/#/paths/core-v2-trial_manager-trial-trialId--status/get
status_path = f'/core/v2/trial_manager/trial/{trial_core_item_id["id"]}/status'
response = jinko.makeRequest(status_path)
versions = response.json()

# Keep the first version whose status is "completed"
# (assumes the API lists the most recent versions first - TODO confirm)
try:
    latest_completed_version = next(
        filter(lambda item: item["status"] == "completed", versions), None
    )
    if latest_completed_version is None:
        raise Exception("No completed trial version found")

    print(
        "Successfully fetched this simulation:\n",
        json.dumps(latest_completed_version, indent=1),
    )
    # From here on, the core item id and snapshot id are those of the
    # simulation attached to the completed version
    simulation_id = latest_completed_version["simulationId"]
    trial_core_item_id = simulation_id["coreItemId"]
    trial_snapshot_id = simulation_id["snapshotId"]
except Exception as e:
    # Report the problem in the notebook output, then let the error propagate
    print(f"Error processing trial versions: {e}")
    raise

## Step 2: Displaying a summary of the data content

In [None]:
# https://doc.jinko.ai/api/#/paths/core-v2-trial_manager-trial-trialId--snapshots--trialIdSnapshot--results_summary/get
summary_path = (
    f"/core/v2/trial_manager/trial/{trial_core_item_id}"
    f"/snapshots/{trial_snapshot_id}/results_summary"
)
response = jinko.makeRequest(summary_path, method="GET")
response_summary = json.loads(response.content)

# Print a summary of the results content
print("Keys in the results summary:\n", list(response_summary.keys()), "\n")

# Entries printed as they are
for heading, key in (
    ("Available patients:\n", "patients"),
    ("Available arms:\n", "arms"),
):
    print(heading, response_summary[key], "\n")

# Entries for which only the ids are printed
for heading, key in (
    ("Available scalars:\n", "scalars"),
    ("Available cross-arm scalars:\n", "scalarsCrossArm"),
    ("Available categorical scalars:\n", "categoricals"),
    ("Available cross-arm categorical scalars:\n", "categoricalsCrossArm"),
):
    print(heading, [scalar["id"] for scalar in response_summary[key]], "\n")

# Store the ids of the scenario descriptors, i.e. the scalars and categoricals
# whose type labels contain "ScenarioOverride"
candidate_descriptors = response_summary["scalars"] + response_summary["categoricals"]
scenario_descriptors = [
    descriptor["id"]
    for descriptor in candidate_descriptors
    if "ScenarioOverride" in descriptor["type"]["labels"]
]
print("List of scenario overrides:\n", scenario_descriptors, "\n")

## Step 3: Retrieving time series

In [None]:
# Listing the time series to retrieve
time_series_ids = ["Blood.Drug", "Tumor.CancerCell"]

# Download the selected time series as a zip archive and keep the content of
# its first file as CSV text in `csv_time_series`.
try:
    print("Retrieving time series data...")
    response = jinko.makeRequest(
        "/core/v2/result_manager/timeseries_summary",
        method="POST",
        json={
            "select": time_series_ids,
            "trialId": latest_completed_version["simulationId"],
        },
    )
    if response.status_code != 200:
        print(
            f"Failed to retrieve time series data: {response.status_code} - {response.reason}"
        )
        # Raises for 4xx / 5xx responses
        response.raise_for_status()
        # Any other non-200 status is not rejected by raise_for_status(), yet
        # carries no archive to read: fail here rather than leaving
        # `csv_time_series` undefined for the following cells.
        raise RuntimeError(
            f"Unexpected status code {response.status_code} while retrieving time series data"
        )

    print("Time series data retrieved successfully.")
    # The context manager closes the archive once its content has been read
    with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
        filename = archive.namelist()[0]
        print(f"Extracted time series file: {filename}")
        csv_time_series = archive.read(filename).decode("utf-8")
except Exception as e:
    # Report the problem in the notebook output, then let the error propagate
    print(f"Error during time series retrieval or processing: {e}")
    raise

## Step 4: Post-processing the time series

In [None]:
# Parse the CSV text retrieved above into a dataframe and preview it
df_time_series = pd.read_csv(io.StringIO(csv_time_series))
print("Raw timeseries data (first rows): \n")
display(df_time_series.head())

# Number of rows recorded for each distinct value of "Time"
counts = df_time_series["Time"].value_counts()

# True when every time point has the same number of rows
all_equal = counts.nunique() == 1

if not all_equal:
    # Show how many time points share each observation count
    print(f"Time points have varying numbers of observations:\n{counts.value_counts()}")
else:
    print("All time points have the same number of observations.")

## Step 5: Computing mean value by time point, for each arm and each descriptor 

In [None]:
# Mean of "Value" for every (arm, descriptor, time point) combination
grouping_keys = ["Arm", "Descriptor", "Time"]
mean_by_group = df_time_series.groupby(grouping_keys)["Value"].mean()

# Turn the group keys back into regular columns
df_means_grouped = pd.DataFrame(mean_by_group.reset_index())
display(df_means_grouped.head())

## Step 6: Computing the 95% percentiles (basic version)

### Defining a function to compute the 2.5% and 97.5% percentiles and another to perform the bootstrap

In [None]:
def percentiles(x):
    """Return the 2.5% and 97.5% quantiles of ``x``.

    ``x`` must provide a pandas-style ``quantile`` method (e.g. a Series).
    The result is the ``.values`` array of that call: lower bound first,
    upper bound second.
    """
    lower_q, upper_q = 0.025, 0.975
    bounds = x.quantile([lower_q, upper_q])
    return bounds.values


def generate_subsample_means(group, num_subsamples, sample_size):
    """Return the means of repeated random subsamples of ``group["Value"]``.

    Draws ``num_subsamples`` subsamples of ``sample_size`` values each,
    without replacement within a subsample, and returns the list of their
    means (one entry per subsample, in drawing order).
    """
    return [
        group["Value"].sample(n=sample_size, replace=False).mean()
        for _ in range(num_subsamples)
    ]

### Applying both functions to the data

In [None]:
# Number of subsamples drawn per group, and number of values in each subsample
num_subsamples = 50
sample_size = 50

# For every (descriptor, arm, time point) group, draw the subsamples and
# collect the list of their means
values_by_group = df_time_series.groupby(["Descriptor", "Arm", "Time"])
means_by_group = values_by_group.apply(
    lambda group: generate_subsample_means(
        group[["Value"]], num_subsamples, sample_size
    ),
    include_groups=False,
)
df_subsample_means = means_by_group.reset_index()

# Unfold each list of means so that every subsample mean gets its own row
# (after reset_index, the column holding the lists is labelled 0)
df_subsample_means = df_subsample_means.explode(0).reset_index(drop=True)
df_subsample_means.columns = ["Descriptor", "Arm", "Time", "Subsample_Mean"]

print(df_subsample_means.head())

# 2.5% / 97.5% percentiles of the subsample means of each group
means_by_arm_descriptor_time = df_subsample_means.groupby(
    ["Arm", "Descriptor", "Time"]
)["Subsample_Mean"]
df_percentiles_grouped = means_by_arm_descriptor_time.apply(percentiles).reset_index()

# Each cell of "Subsample_Mean" now holds a (low, high) pair: spread the pairs
# over two dedicated columns, then discard the pair column
bound_pairs = pd.DataFrame(
    df_percentiles_grouped["Subsample_Mean"].tolist(),
    index=df_percentiles_grouped.index,
)
df_percentiles_grouped[["LoBound", "HiBound"]] = bound_pairs
df_percentiles_grouped = df_percentiles_grouped.drop(columns=["Subsample_Mean"])
display(df_percentiles_grouped.head())

## Step 7: Merging the two data-frames together

In [None]:
# Attach the percentile bounds to the means, matching rows on
# arm, descriptor and time point
df_ppi = df_means_grouped.merge(
    df_percentiles_grouped, on=["Arm", "Descriptor", "Time"]
)
display(df_ppi)

## Step 8: Plotting the outputs

In [None]:
## Creating subplots: one column per descriptor
unique_variables = df_ppi["Descriptor"].unique()
fig = make_subplots(
    rows=1,
    cols=len(unique_variables),
    shared_yaxes=False,
    # make_subplots documents subplot_titles as a list of str, so the array
    # returned by unique() is converted to a plain list
    subplot_titles=list(unique_variables),
)

## Defining colors for different arms
palette = px.colors.qualitative.Plotly

## Creating a dictionary to map each arm to a color
## (colors are reused cyclically when there are more arms than palette entries)
unique_arm = df_ppi["Arm"].unique()
color_map = {
    category: palette[i % len(palette)] for i, category in enumerate(unique_arm)
}

## Curves drawn for every (descriptor, arm) pair:
## (column of df_ppi, legend suffix, dash style - None keeps the default line)
trace_specs = [
    ("Value", "Mean", None),
    ("LoBound", "2.5% PPI", "dot"),
    ("HiBound", "97.5% PPI", "dot"),
]

## Looping through each descriptor and adding traces for mean, lower bound, and upper bound stratified by arm
for col, group in enumerate(unique_variables, start=1):
    group_df = df_ppi[df_ppi["Descriptor"] == group]

    for arm in unique_arm:
        subset = group_df[group_df["Arm"] == arm]

        for column, label, dash in trace_specs:
            # All curves of an arm share its color; only the bounds are dotted
            line_style = dict(color=color_map[arm])
            if dash is not None:
                line_style["dash"] = dash

            fig.add_trace(
                go.Scatter(
                    x=subset["Time"],
                    y=subset[column],
                    mode="lines",
                    name=f"{group} {arm} {label}",
                    line=line_style,
                ),
                row=1,
                col=col,
            )

## Updating the layout
fig.update_layout(
    title="Mean and Bootstrapped 95% Prediction Interval Stratified by Variable and Arm",
    legend_title="Legend",
)
## xaxis_title / yaxis_title in update_layout only label the first subplot;
## update_xaxes / update_yaxes without row/col label every subplot
fig.update_xaxes(title_text="Time")
fig.update_yaxes(title_text="Values")

## Show the plot
fig.show()