# Postprocessing/Visualizing scalar results using reduce functions

## Introduction

This cookbook will guide you through the creation of a simple visualization from an existing trial in jinko.  
In particular, you will be able to retrieve scalar results and plot them using plotly.  

Linked resources: [Jinko](https://jinko.ai/project/e0fbb5bb-8929-439a-bad6-9e12d19d9ae4?labels=d59e57e6-889b-427a-b4ee-bed52a6c6ca2).

In [13]:
# Jinko-specific imports & initialization

import sys

# Make ../lib importable; this must run before the `jinko_helpers` import
# below, which presumably lives in ../lib (confirm in the repository layout).
sys.path.insert(0, "../lib")
import jinko_helpers as jinko

# Connect to Jinko (see README.md for more options)
# Later cells call jinko.makeRequest / jinko.getCoreItemId; presumably they
# rely on the session set up here — confirm in jinko_helpers.
jinko.initialize()

In [195]:
# Cookbook-specific imports

import datetime
import io
import json
import pandas as pd
import plotly.express as px
import re
import zipfile
import plotly.graph_objects as go

# Cookbook-specific constants

# Short Id of the Trial to post-process (e.g. tr-EKRx-3HRt).
# The next cell raises if this is left as None; it later rebinds `trialId`
# to the trial's core item Id.
trialId = "tr-Fzt9-uO98"

## Let's use the API and plot the data

### Load the trial

In [196]:
# Resolve the trial to post-process: convert its short Id into a core item Id,
# list the trial's versions and keep a completed one. Afterwards `trialId`
# holds the core item Id (no longer the short Id) and `trialSnapshotId` the
# snapshot Id; both are used to build the API URLs in the following cells.
if trialId is None:
    raise Exception("Please specify a Trial Id")
print(f"Using Trial ID: {trialId}")

# Convert short Id to coreItemId
try:
    coreItemId = jinko.getCoreItemId(trialId, 1)
except Exception:
    print("Failed to find corresponding trial, check the trialId")
    raise

# List all Trial versions (https://doc.jinko.ai/api/#/paths/core-v2-trial_manager-trial-status/post)
try:
    trialVersions = jinko.makeRequest(
        f'/core/v2/trial_manager/trial/{coreItemId["id"]}/status'
    ).json()
    print(f"Fetched {len(trialVersions)} versions for the trial.")
except Exception as e:
    print(f"Error fetching trial versions: {e}")
    raise

# Get the latest completed version.
# NOTE(review): `next` returns the FIRST entry whose status is "completed";
# that is the latest one only if the status endpoint lists versions
# newest-first — confirm against the API.
try:
    latestCompletedVersion = next(
        (item for item in trialVersions if item["status"] == "completed"), None
    )
    if latestCompletedVersion is None:
        # Raised inside the try so the handler below reports it before re-raising.
        raise Exception("No completed Trial version found")
    print(
        "Successfully fetched this simulation:\n",
        json.dumps(latestCompletedVersion, indent=1),
    )
    # Store the trial Id and the snapshot Id to use in the API requests
    simulationId = latestCompletedVersion["simulationId"]
    trialId = simulationId["coreItemId"]
    trialSnapshotId = simulationId["snapshotId"]
except Exception as e:
    print(f"Error processing trial versions: {e}")
    raise

### Display a result summary

In [197]:
# Fetch the results summary of the selected trial snapshot
# (https://doc.jinko.ai/api/#/paths/core-v2-trial_manager-trial-trialId--snapshots--trialIdSnapshot--results_summary/get)
# and print which patients, arms, scalars and categoricals it contains.
response = jinko.makeRequest(
    f"/core/v2/trial_manager/trial/{trialId}/snapshots/{trialSnapshotId}/results_summary",
    method="GET",
)
responseSummary = json.loads(response.content)

defaultScalars = [
    '__jinkoAllocationMiB.tmax',
    '__jinkoSolvingTime.tmax',
    'SimulationTMax',
    'SimulationTMin',
]

# Print a summary of the results content
print("Keys in the results summary:\n", list(responseSummary.keys()), "\n")
# Patients and arms are printed exactly as returned...
for summaryHeading, summaryField in (
    ("Available patients:\n", "patients"),
    ("Available arms:\n", "arms"),
):
    print(summaryHeading, responseSummary[summaryField], "\n")
# ...while scalar-like entries are reduced to their ids
for summaryHeading, summaryField in (
    ("Available scalars:\n", "scalars"),
    ("Available cross-arm scalars:\n", "scalarsCrossArm"),
    ("Available categorical parameters:\n", "categoricals"),
    ("Available cross-arm categorical parameters:\n", "categoricalsCrossArm"),
):
    print(summaryHeading, [entry["id"] for entry in responseSummary[summaryField]], "\n")

# Scenario overrides can be scalars or categoricals: keep the ids of the
# entries labelled "ScenarioOverride" to use them later
scenarioOverrides = [
    entry["id"]
    for entry in [*responseSummary["scalars"], *responseSummary["categoricals"]]
    if "ScenarioOverride" in entry["type"]["labels"]
]
print("List of scenario overrides:\n", scenarioOverrides, "\n")

# Result scalars are the entries labelled "Custom" (not scenario descriptors)
resultScalars = [
    entry["id"]
    for entry in responseSummary["scalars"]
    if "Custom" in entry["type"]["labels"]
]
print("List of result scalars:\n", resultScalars, "\n")

### Download scalar results data

In [198]:
# Retrieve scalar results (https://doc.jinko.ai/api/#/paths/core-v2-result_manager-scalars_summary/post)

# Scalar ids to download, grouped by kind. Each key is passed to
# retrieve_scalars() below, which requests the listed ids; replace these
# lists to download a different selection of scalars.
idsForScalars = {
    "scalars": resultScalars
    , "scenarioOverrides": scenarioOverrides
}
# Decoded CSV text of each downloaded group, keyed like idsForScalars
# (filled in by retrieve_scalars).
csvScalars = {}
def retrieve_scalars(scalar_type):
    """Download one group of scalar results and store it as CSV text.

    Posts the ids listed in ``idsForScalars[scalar_type]`` for the selected
    trial simulation to the scalars_summary endpoint. The response is read as
    a zip archive; its first file is decoded as UTF-8 and stored in
    ``csvScalars[scalar_type]``.

    Args:
        scalar_type: Key of ``idsForScalars`` to download (e.g. "scalars" or
            "scenarioOverrides").

    Raises:
        Exception: Any request, HTTP-status or archive error is printed and
            re-raised, so a failed download never leaves ``csvScalars``
            silently without the entry.
    """
    try:
        print("Retrieving scalar results...")
        response = jinko.makeRequest(
            "/core/v2/result_manager/scalars_summary",
            method="POST",
            json={
                "select": idsForScalars[scalar_type],
                "trialId": latestCompletedVersion["simulationId"],
            },
        )
        if response.status_code != 200:
            print(
                f"Failed to retrieve scalar results: {response.status_code} - {response.reason}"
            )
            response.raise_for_status()
            # raise_for_status() (requests) only raises for 4xx/5xx; any other
            # unexpected status must not be treated as a success either.
            raise RuntimeError(
                f"Unexpected status {response.status_code} when retrieving scalar results"
            )
        print("Scalar results retrieved successfully.")
        # Close the in-memory archive once its CSV has been read.
        with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
            names = archive.namelist()
            if not names:
                raise ValueError("The scalar results archive is empty")
            filename = names[0]
            print(f"Extracted scalar results file: {filename}")
            csvScalars[scalar_type] = archive.read(filename).decode("utf-8")
    except Exception as e:
        print(f"Error during scalar results retrieval or processing: {e}")
        raise

# Download every group listed in idsForScalars: result scalars first, then
# the scenario overrides (each call fills the matching key of csvScalars).
for scalarGroup in ("scalars", "scenarioOverrides"):
    retrieve_scalars(scalarGroup)

### Postprocess the data in a pandas dataframe

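The data is post-processed with the pandas library and reshaped into a table whose columns hold the pieces of information needed for plotting: the observable, the reduce function, the start and end times of each measure, and the scenario overrides.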

In [199]:
# All currently available reduce functions.
# Order matters: the alternation built below tries them left to right, so
# "max-slope"/"min-slope" must stay before "max"/"min"; otherwise
# "X-max-slope-..." would match "max" and leave "slope-..." in the time text.
reduceFuncs = ["time-of-max", "time-of-min", "max-slope", "min-slope", "amplitude", "end-minus-start", "at", "auc", "avg", "max", "min"]
reduceFuncRegex = '('+'|'.join(reduceFuncs)+')'

# Load scalars into a dataframe
dfScalars = pd.read_csv(io.StringIO(csvScalars["scalars"]))
print("\nRaw scalar data (first rows):\n")
# `display` is not imported here; presumably the IPython/Jupyter builtin.
display(dfScalars.head())

# Postprocessing - Split the scalarId column into pieces depending on the applicable pattern
# Expected shape: "<observable>-<reduceFunction>-[from-]<time range>". The
# lazy first group keeps the observable as short as possible. "startTime"
# temporarily receives the whole time-range text (e.g. "<start>-to-<end>" or a
# single time point); it is split into start/end times further down. Ids that
# do not match get NaN in all three new columns.
dfScalars[['generalScalarId','reduceFunction', "startTime"]] = dfScalars['scalarId'].str.extract(r'(.*?)-' + reduceFuncRegex + '(?:-from-|-)(.*)')

# Postprocessing - Extract start and end times for each measure created on jinkō (for point measures like `at`, start and end time are equal)
def extract_endtime(value):
    """Return the end of a time-range text such as ``"P7D-to-P10D"``.

    A value without ``-to-`` (a single time point, e.g. for ``at`` measures)
    is returned unchanged, so its start and end coincide.

    Args:
        value: Time-range text, or a missing value (NaN) for scalar ids that
            did not match the reduce-function pattern.

    Returns:
        The end-time text, or None when `value` is not a string or does not
        match (``re.match`` would otherwise raise TypeError on NaN).
    """
    if not isinstance(value, str):
        return None
    match = re.match(r'(?:.*?-to-(.*)|^(.*)$)', value)
    if match:
        return match.group(1) or match.group(2)
    return None

def extract_starttime(value):
    """Return the start of a time-range text such as ``"P7D-to-P10D"``.

    A value without ``-to-`` (a single time point, e.g. for ``at`` measures)
    is returned unchanged, so its start and end coincide.

    Args:
        value: Time-range text, or a missing value (NaN) for scalar ids that
            did not match the reduce-function pattern.

    Returns:
        The start-time text, or None when `value` is not a string or does not
        match (``re.match`` would otherwise raise TypeError on NaN).
    """
    if not isinstance(value, str):
        return None
    match = re.match(r'(?:(.*?)-to-.*|^(.*)$)', value)
    if match:
        return match.group(1) or match.group(2)
    return None

# endTime must be derived first: at this point startTime still holds the whole
# time-range text, and the next statement overwrites it with the start part only.
dfScalars['endTime'] = dfScalars['startTime'].apply(extract_endtime)

dfScalars['startTime'] = dfScalars['startTime'].apply(extract_starttime)


# Load scenarioOverrides into a dataframe
dfOverrides = pd.read_csv(io.StringIO(csvScalars["scenarioOverrides"]))

# Remove the unnecessary trailing ".tmin" suffix from the override ids.
# str.strip('.tmin') would remove any of the characters '.', 't', 'm', 'i',
# 'n' from BOTH ends (e.g. "initialDose.tmin" -> "alDose"), so the suffix is
# removed with an end-anchored regex instead.
dfOverrides['scalarId'] = dfOverrides['scalarId'].str.replace(r'\.tmin$', '', regex=True)

# Pivot to a wide format to obtain protocol overrides in columns
# (one row per (armId, patientId), one column per override id)
dfOverrides = dfOverrides.pivot(
    index=["armId", "patientId"], columns="scalarId", values="value"
)

print("\nPivotted scenario override table (first rows): \n")
display(dfOverrides.head())

# Merge dataframes to have scenario overrides as columns in the results dataframe
# (inner join: result rows whose (patientId, armId) has no overrides are dropped)
dfScalars = dfScalars.merge(dfOverrides, how='inner', on=['patientId', 'armId'])
print("\nScalar result table with overrides (first rows): \n")
display(dfScalars.head())

def parse_isoduration(str):
    """Convert an ISO 8601 duration text into a number of days.

    Adapted from
    https://stackoverflow.com/questions/36976138/is-there-an-easy-way-to-convert-iso-8601-duration-to-timedelta
    Years, months and weeks are converted into seconds with the approximations
    1 Y = 31557600 s, 1 M = 2629800 s and 1 W = 604800 s; the result is
    expressed in days of 86400 s. A decimal comma is accepted (e.g. "P0,5Y").

    Examples: "PT1H30M15.460S", "P5DT4M", "P2WT3H".

    The date part (before "T") and the time part (after "T") are parsed
    separately because "M" means months in the former and minutes in the
    latter; searching the whole text for months would misread "PT30M".

    Args:
        str: Duration text. (The name shadows the builtin; it is kept so that
            existing keyword callers keep working.)

    Returns:
        The duration in days as a float, or NaN when the input is not a
        string (e.g. a missing value in a DataFrame column).
    """
    import builtins  # the parameter shadows the builtin `str`

    if not isinstance(str, builtins.str):
        return float("nan")

    def take(part, designator):
        """Split off the number written before `designator` ('0' if absent)."""
        if designator in part:
            number, part = part.split(designator, 1)
        else:
            number = '0'
        return float(number.replace(',', '.')), part  # "," as decimal mark

    body = str.split('P', 1)[-1]  # Remove prefix
    date_part, _, time_part = body.partition('T')

    years, date_part = take(date_part, 'Y')
    months, date_part = take(date_part, 'M')
    weeks, date_part = take(date_part, 'W')
    days, date_part = take(date_part, 'D')
    hours, time_part = take(time_part, 'H')
    minutes, time_part = take(time_part, 'M')
    seconds, time_part = take(time_part, 'S')

    dt = datetime.timedelta(
        days=days,
        hours=hours,
        minutes=minutes,
        seconds=seconds + years * 31557600 + months * 2629800 + weeks * 604800,
    )
    return dt.total_seconds() / 86400

# Parse ISO 8601 durations into time in days
dfScalars['endTime'] = dfScalars['endTime'].apply(parse_isoduration)
dfScalars['startTime'] = dfScalars['startTime'].apply(parse_isoduration)

# Values of reduce functions containing 'time' ("time-of-max", "time-of-min")
# are in seconds (default); convert them into days. na=False keeps rows with
# a missing reduceFunction (unmatched scalar ids) out of the mask; a mask
# containing NaN would make .loc fail.
isTimeMeasure = dfScalars['reduceFunction'].str.contains('time', na=False)
dfScalars.loc[isTimeMeasure, 'value'] /= 86400
dfScalars.loc[isTimeMeasure, 'unit'] = 'days'

# Sort values by starting time (rows without a start time come first)
dfScalars = dfScalars.sort_values(by='startTime', ascending=True, na_position='first')

print("\n Final scalar result table with overrides (first rows): \n")
display(dfScalars.head())


### Plot the data

Finally, we plot the time-series data, faceting over the scenario overrides.

In [297]:
# Color of each administration mode, shared by curves, markers and label text
COLOR_MAP = dict(
    intravenous='#1f77b4',   # blue
    subcutaneous='#ff7f0e',  # orange
)

# Vertical slot of each administration mode's text boxes
OFFSET_MAP = dict(intravenous=0, subcutaneous=1)

# Human-readable description of each measure (iteration order drives plotting)
SUMMARY_MEASURE_MAP = dict(
    auc='AUC from day 7 to day 10',
    avg='Average from day 7 to day 10',
    max='Maximum from day 7 to day 10',
)

# Vertical slot of each measure within an administration mode's block
MEASURE_OFFSET_MAP = dict(auc=1, avg=0, max=2)

# Measures written as text boxes (SUMMARY) vs. marked on the curves (PLOT)
SUMMARY_MEASURES = ['avg', 'auc']
PLOT_MEASURES = ['max']

# Columns used for faceting and coloring the plot
FACET_COLUMN, FACET_ROW, COLOR_COLUMN = 'fullDose', 'primingDose', 'administrationMode'

# Rows drawn as curves: the `at` measures of Blood.Drug
QUERY_BASE = 'generalScalarId == "Blood.Drug" and reduceFunction == "at"'

def create_line_plot(data):
    """Build (without showing) the faceted Blood.Drug concentration chart.

    Facet columns follow FACET_COLUMN, facet rows follow FACET_ROW, and each
    COLOR_COLUMN value gets its own line colored from COLOR_MAP.
    """
    axis_labels = {
        "endTime": "Time [days]",
        "value": "Concentration of Drug in Blood [mg/ml]",
        "fullDose": "Full dose [mg]",
        "primingDose": "Priming dose [mg]",
        "administrationMode": "Administration mode",
    }
    figure = px.line(
        data,
        x="endTime",
        y="value",
        color=COLOR_COLUMN,
        facet_row=FACET_ROW,
        facet_col=FACET_COLUMN,
        color_discrete_map=COLOR_MAP,
        labels=axis_labels,
        height=600,
    )
    return figure

def get_facet_indices(data, row):
    """Locate the facet cell of `row` in the grid built from `data`.

    Facet values are numbered in order of first appearance in `data`
    (``unique()`` keeps that order). Returns ``(col_idx, row_idx)``, both
    0-based; ``list.index`` raises ValueError if a value of `row` is absent.
    """
    def position(column):
        return data[column].unique().tolist().index(row[column])

    return position(FACET_COLUMN), position(FACET_ROW)

def add_summary_annotations(annotations, data, key, max_value):
    """Append a text box for every Blood.Drug scalar of the `key` measure.

    Each row of `data` with generalScalarId "Blood.Drug" and reduceFunction
    `key` becomes a box colored by administration mode, anchored at x=10 in
    the facet of its scenario. Boxes are stacked downward from `max_value`
    by administration mode (OFFSET_MAP) and by measure (MEASURE_OFFSET_MAP).

    Args:
        annotations: List of plotly annotation dicts; extended in place.
        data: Post-processed scalar table containing all measures.
        key: Reduce function to annotate; must be a key of
            SUMMARY_MEASURE_MAP and MEASURE_OFFSET_MAP.
        max_value: Top of the text stack (largest plotted value).
    """
    # The facet grid is built from the QUERY_BASE rows: compute them once.
    base_data = data.query(QUERY_BASE)
    n_cols = len(base_data[FACET_COLUMN].unique())
    query_key = f'generalScalarId == "Blood.Drug" and reduceFunction == "{key}"'
    for _, row in data.query(query_key).iterrows():
        col_idx, row_idx = get_facet_indices(base_data, row)
        # Subplot axes are numbered across the whole grid, row by row
        # (x, x2, ..., y, y2, ...), matching fig.add_trace(row=..., col=...).
        # Using x{col+1}/y{row+1} sent rows > 1 to another subplot's axes
        # whenever the grid had several columns.
        # NOTE(review): assumes plotly express numbers facet rows in order of
        # first appearance, like get_facet_indices — confirm on a multi-row grid.
        axis_num = row_idx * n_cols + col_idx + 1
        suffix = "" if axis_num == 1 else f"{axis_num}"

        annotations.append(dict(
            x=10,
            y=(max_value - max_value / 5 * OFFSET_MAP[row[COLOR_COLUMN]] - max_value / 10 * MEASURE_OFFSET_MAP[key]),
            xref=f"x{suffix}",
            yref=f"y{suffix}",
            text=f'{row[COLOR_COLUMN]}: {SUMMARY_MEASURE_MAP[key]} [{row["unit"]}]: {"%.2E" % row["value"]}',
            showarrow=False,
            font=dict(color=COLOR_MAP[row[COLOR_COLUMN]]),
            bgcolor="white",
            bordercolor="black",
            textangle=0,
            xanchor="right",
            yanchor="top"
        ))

def add_plot_annotations(fig, annotations, data, key):
    """Mark the `key` extremum of every Blood.Drug curve and label it.

    For each row of `data` with generalScalarId "Blood.Drug" and reduceFunction
    `key` (e.g. "max"), the "time-of-<key>" scalar of the same scenario (same
    COLOR_COLUMN, FACET_COLUMN and FACET_ROW values) gives the x position. An
    "x" marker is added to `fig` in that scenario's facet, and an arrow
    annotation with the value and time is appended to `annotations`.

    Args:
        fig: Faceted figure (see create_line_plot); modified in place.
        annotations: List of plotly annotation dicts; extended in place.
        data: Post-processed scalar table containing all measures.
        key: Reduce function to mark; must be a key of SUMMARY_MEASURE_MAP.

    Raises:
        IndexError: If a scenario has no matching "time-of-<key>" scalar.
    """
    # The facet grid is built from the QUERY_BASE rows: compute them once.
    base_data = data.query(QUERY_BASE)
    n_cols = len(base_data[FACET_COLUMN].unique())
    blood_drug = data[data['generalScalarId'] == 'Blood.Drug']
    time_of_key = blood_drug[blood_drug['reduceFunction'] == f'time-of-{key}']

    for _, row in blood_drug[blood_drug['reduceFunction'] == key].iterrows():
        col_idx, row_idx = get_facet_indices(base_data, row)
        # Subplot axes are numbered across the whole grid, row by row
        # (x, x2, ..., y, y2, ...), matching the row/col given to add_trace
        # below; x{col+1}/y{row+1} pointed rows > 1 at another subplot.
        # NOTE(review): assumes plotly express numbers facet rows in order of
        # first appearance, like get_facet_indices — confirm on a multi-row grid.
        axis_num = row_idx * n_cols + col_idx + 1
        suffix = "" if axis_num == 1 else f"{axis_num}"

        # Time of the extremum for the same scenario. Boolean masks compare the
        # raw values, so this works whatever the dtype of the facet/color
        # columns (a query string quoting them only matched text columns).
        same_scenario = (
            (time_of_key[COLOR_COLUMN] == row[COLOR_COLUMN])
            & (time_of_key[FACET_COLUMN] == row[FACET_COLUMN])
            & (time_of_key[FACET_ROW] == row[FACET_ROW])
        )
        x_value = float(time_of_key.loc[same_scenario, 'value'].iloc[0])
        y_value = row["value"]

        # Add a scatter plot point for the measure
        fig.add_trace(go.Scatter(
            x=[x_value],
            y=[y_value],
            mode='markers',
            marker_symbol="x",
            marker=dict(color=COLOR_MAP[row[COLOR_COLUMN]]),
            showlegend=False
        ), row=row_idx + 1, col=col_idx + 1)

        # Add an annotation for the measure
        annotations.append(dict(
            x=x_value,
            y=y_value,
            xref=f"x{suffix}",
            yref=f"y{suffix}",
            text=f'{row[COLOR_COLUMN]}: {SUMMARY_MEASURE_MAP[key]} [{row["unit"]}]:<br> {"%.2E" % y_value} at {x_value} days',
            showarrow=True,
            font=dict(color=COLOR_MAP[row[COLOR_COLUMN]]),
            bgcolor="white",
            bordercolor="black",
            textangle=0,
            xanchor="left",
            yanchor="bottom"
        ))

def plot_here(data):
    """Draw the Blood.Drug concentration curves with their measures and show it.

    Builds the faceted line chart from the QUERY_BASE rows of `data`, writes
    the SUMMARY_MEASURES as text boxes and marks the PLOT_MEASURES on the
    curves (measures are visited in SUMMARY_MEASURE_MAP order).
    """
    filtered_data = data.query(QUERY_BASE)
    fig = create_line_plot(filtered_data)

    # Get the maximum value in the filtered data for annotation positioning
    max_value = max(filtered_data['value'])
    annotations = []

    # Add annotations for summary and plot measures
    for key in SUMMARY_MEASURE_MAP:
        if key in SUMMARY_MEASURES:
            add_summary_annotations(annotations, data, key, max_value)
        if key in PLOT_MEASURES:
            add_plot_annotations(fig, annotations, data, key)

    # Append to the existing annotations instead of replacing them: plotly
    # express stores the facet titles (e.g. "Full dose [mg]=...") as layout
    # annotations, and assigning the list wiped them out.
    for annotation in annotations:
        fig.add_annotation(annotation)
    fig.show()

# Call the main function with the dataframe `dfScalars`
plot_here(dfScalars)



In [298]:
# Second figure: Tumor.CancerCell time series faceted by scenario overrides,
# with summary measures written as text boxes and the minimum of each curve
# marked with an "x".

# Define a color map for different administration modes
color_map = {
    'intravenous': '#1f77b4',  # blue
    'subcutaneous': '#ff7f0e',  # orange
}

# Define an offset map for annotation handling (vertical slot per administration mode)
offset_map = {
    'intravenous': 0,
    'subcutaneous': 1,
}

# Define descriptions for various summary measures
summary_measure_map = {
    'end-minus-start': 'Value after three weeks minus baseline',
    'amplitude': 'Amplitude from baseline to week 3',
    'min': 'Minimum from baseline to week 3',
    'min-slope': 'Minimum slope from baseline to week 3',
    'max-slope': 'Maximum slope from baseline to week 3',
}

# Define offsets for summary measures (vertical slot per measure)
measure_offset_map = {
    'end-minus-start': 1,
    'amplitude': 0,
    'min-slope': 2,
    'max-slope': 3,
    'min': 4,
}

# Lists of measures to summarize (text boxes) and to plot (markers on the curves)
summary_measures = ["min-slope", "max-slope", "amplitude", "end-minus-start"]
plot_measures = ['min']

# Columns used for faceting and coloring the plot
facet_column = 'fullDose'
facet_row = 'primingDose'
color_column = 'administrationMode'

# Base query to filter data for the main plot
query_here = 'generalScalarId == "Tumor.CancerCell" and reduceFunction == "at"'

# Rows drawn as curves; the facet grid (and thus the subplot axes) is built
# from them, so compute them once instead of once per annotated row.
plotted_data = dfScalars.query(query_here)

# Create the main line plot using Plotly Express
fig = px.line(
    plotted_data,
    x="endTime",
    y="value",
    facet_col=facet_column,
    facet_row=facet_row,
    color=color_column,
    labels={
        "endTime": "Time [days]",
        "value": "Amount of Tumor.CancerCell [cell count]",
        "fullDose": "Full dose [mg]",
        "primingDose": "Priming dose [mg]",
        "administrationMode": "Administration mode",
    },
    height=600, color_discrete_map=color_map
)

# Determine the maximum and minimum values in the filtered data for annotation positioning
max_value = max(plotted_data['value'])
min_value = min(plotted_data['value'])

# Facet values in order of first appearance, used to locate each scenario's subplot
facet_col_values = plotted_data[facet_column].unique().tolist()
facet_row_values = plotted_data[facet_row].unique().tolist()

annotations = []

# Iterate over each summary measure to add annotations
for key in summary_measure_map.keys():
    query_here_key = f'generalScalarId == "Tumor.CancerCell" and reduceFunction == "{key}"'

    # Iterate over each row in the filtered data for the current summary measure
    for i, row in dfScalars.query(query_here_key).iterrows():
        col_idx = facet_col_values.index(row[facet_column])
        row_idx = facet_row_values.index(row[facet_row])
        # Subplot axes are numbered across the whole grid, row by row
        # (x, x2, ..., y, y2, ...), matching fig.add_trace(row=..., col=...).
        # x{col+1}/y{row+1} pointed rows > 1 at another subplot's axes
        # whenever the grid had several columns.
        # NOTE(review): assumes plotly express numbers facet rows in order of
        # first appearance — confirm on a multi-row grid.
        axis_num = row_idx * len(facet_col_values) + col_idx + 1
        xref = "x" if axis_num == 1 else f"x{axis_num}"
        yref = "y" if axis_num == 1 else f"y{axis_num}"

        # Add annotations for summary measures
        if key in summary_measures:
            annotations.append(dict(
                x=35,
                y=(max_value - (max_value - min_value) / 5 * offset_map[row[color_column]] -
                   (max_value - min_value) / 20 * measure_offset_map[key]),
                xref=xref,
                yref=yref,
                text=f'{row[color_column]}: {summary_measure_map[key]} [{row["unit"]}]: {"%.2E" % (row["value"])}',
                showarrow=False,
                font=dict(color=color_map[row[color_column]]),  # Use the color from the color map
                bgcolor="white",
                bordercolor="black",
                textangle=0,
                xanchor="left",
                yanchor="top"
            ))

        # Add annotations and plot points for specific measures
        if key in plot_measures:
            # Time of the extremum for the same scenario. Boolean masks compare
            # raw values, so this works whatever the dtype of the facet/color
            # columns (a query string quoting them only matched text columns).
            time_rows = dfScalars[
                (dfScalars['generalScalarId'] == 'Tumor.CancerCell')
                & (dfScalars['reduceFunction'] == f'time-of-{key}')
                & (dfScalars[color_column] == row[color_column])
                & (dfScalars[facet_column] == row[facet_column])
                & (dfScalars[facet_row] == row[facet_row])
            ]
            x_value = float(time_rows['value'].iloc[0])
            y_value = [row["value"]]

            # Add scatter plot points to the main plot
            fig.add_trace(go.Scatter(
                x=[x_value],
                y=y_value,
                mode='markers',
                marker_symbol="x",
                marker=dict(color=color_map[row[color_column]]),
                showlegend=False
            ), row=row_idx + 1, col=col_idx + 1)

            # Add annotations for the plot points
            annotations.append(dict(
                x=x_value,
                y=y_value[0],
                xref=xref,
                yref=yref,
                text=f'{row[color_column]}: {summary_measure_map[key]} [{row["unit"]}]:<br> {"%.2E" % y_value[0]} at {x_value} days',
                showarrow=True,
                font=dict(color=color_map[row[color_column]]),  # Use the color from the color map
                bgcolor="white",
                bordercolor="black",
                textangle=0,
                xanchor="left",
                yanchor="top",
                ay=50 - offset_map[row[color_column]] * 20,
                ax=20
            ))

# Append the annotations instead of replacing layout.annotations, which also
# holds the facet titles added by plotly express
for annotation in annotations:
    fig.add_annotation(annotation)

# Display the plot
fig.show()
