# Postprocessing/Visualizing scalar results using reduce functions

## Introduction

This cookbook will guide you through the creation of a simple visualization from an existing trial in Jinko.  
In particular, you will retrieve scalar results and plot them using Plotly.  

Linked resources: [Jinko](https://jinko.ai/project/e0fbb5bb-8929-439a-bad6-9e12d19d9ae4?labels=24574ece-6bde-4d76-896a-187426965a51).

In [None]:
# Jinko-specific imports & initialization

import sys

# Put "../lib" first on the module search path before importing
# jinko_helpers (presumably that is where the helper module lives - confirm)
sys.path.insert(0, "../lib")
import jinko_helpers as jinko

# Connect to Jinko (see README.md for more options)

jinko.initialize()

In [None]:
# Cookbook specifics imports

import io
import json
import pandas as pd
import plotly.express as px
import plotly.io as pio
import zipfile

# Cookbook specifics constants

# Fill the short Id of your Trial (ex: tr-EKRx-3HRt)
trialId = "tr-Fzt9-uO98"

## Let's use the API and plot the data

### Load the trial

In [None]:
# Nothing below can work without a trial short Id
if trialId is None:
    raise Exception("Please specify a Trial Id")
print(f"Using Trial ID: {trialId}")

# Resolve the short Id into the matching core item Id
try:
    coreItemId = jinko.getCoreItemId(trialId, 1)
except Exception:
    print("Failed to find corresponding trial, check the trialId")
    raise

# List all Trial versions (https://doc.jinko.ai/api/#/paths/core-v2-trial_manager-trial-status/post)
try:
    trialVersions = jinko.makeRequest(
        f"/core/v2/trial_manager/trial/{coreItemId['id']}/status"
    ).json()
    print(f"Fetched {len(trialVersions)} versions for the trial.")
except Exception as e:
    print(f"Error fetching trial versions: {e}")
    raise

# Keep the first version of the list whose status is "completed"
try:
    latestCompletedVersion = next(
        filter(lambda version: version["status"] == "completed", trialVersions),
        None,
    )
    if latestCompletedVersion is None:
        raise Exception("No completed Trial version found")

    print(
        "Successfully fetched this simulation:\n",
        json.dumps(latestCompletedVersion, indent=1),
    )
    # From here on trialId holds the core item Id (no longer the short Id);
    # together with the snapshot Id it is used in the following API requests
    simulationId = latestCompletedVersion["simulationId"]
    trialId = simulationId["coreItemId"]
    trialSnapshotId = simulationId["snapshotId"]
except Exception as e:
    print(f"Error processing trial versions: {e}")
    raise

### Display a results summary

In [None]:
# Retrieve results summary (https://doc.jinko.ai/api/#/paths/core-v2-trial_manager-trial-trialId--snapshots--trialIdSnapshot--results_summary/get)
response = jinko.makeRequest(
    "/core/v2/trial_manager/trial/%s/snapshots/%s/results_summary"
    % (trialId, trialSnapshotId),
    method="GET",
)
responseSummary = json.loads(response.content)

# Kept for reference only; not used in the rest of this cell
# (presumably the scalars Jinko adds to every trial - confirm)
defaultScalars = [
    "__jinkoAllocationMiB.tmax",
    "__jinkoSolvingTime.tmax",
    "SimulationTMax",
    "SimulationTMin",
]

# Print a summary of the results content
print("Keys in the results summary:\n", list(responseSummary.keys()), "\n")
print("Available patients:\n", responseSummary["patients"], "\n")
print("Available arms:\n", responseSummary["arms"], "\n")
print(
    "Available scalars:\n",
    [scalar["id"] for scalar in responseSummary["scalars"]],
    "\n",
)
print(
    "Available cross-arm scalars:\n",
    [scalar["id"] for scalar in responseSummary["scalarsCrossArm"]],
    "\n",
)
print(
    "Available categorical parameters:\n",
    [scalar["id"] for scalar in responseSummary["categoricals"]],
    "\n",
)
print(
    "Available cross-arm categorical parameters:\n",
    [scalar["id"] for scalar in responseSummary["categoricalsCrossArm"]],
    "\n",
)

# Store the ids of the scalars and categoricals labelled "ScenarioOverride"
# to use them later
scenarioOverrides = [
    scalar["id"]
    for scalar in (responseSummary["scalars"] + responseSummary["categoricals"])
    if "ScenarioOverride" in scalar["type"]["labels"]
]
# Print the list that was just built (the previous name, scenarioDescriptors,
# was never defined and raised a NameError)
print("List of scenario overrides:\n", scenarioOverrides, "\n")

# Store the ids of the scalars labelled "Custom" to use them later
resultScalars = [
    scalar["id"]
    for scalar in responseSummary["scalars"]
    if "Custom" in scalar["type"]["labels"]
]
print("List of result scalars:\n", resultScalars, "\n")

### Download scalars results data

In [None]:
# Retrieve scalar results (https://doc.jinko.ai/api/#/paths/core-v2-result_manager-scalars_summary/post)

# Ids to download, grouped by kind - replace them by the scalar ids list you want
idsForScalars = {
    "scalars": resultScalars,
    "scenarioOverrides": scenarioOverrides,
}
# CSV text of each download, stored under the same keys as idsForScalars
csvScalars = {}
def retrieve_scalars(scalar_type):
    """Download one group of scalar results and keep its CSV text.

    Posts the ids found in ``idsForScalars[scalar_type]`` to the
    ``scalars_summary`` endpoint for the simulation referenced by
    ``latestCompletedVersion``, reads the first file of the zip archive
    returned in the response body and stores its UTF-8 decoded content in
    ``csvScalars[scalar_type]``.

    Args:
        scalar_type: key of ``idsForScalars`` selecting the ids to download;
            the same key is used to store the result in ``csvScalars``.

    Raises:
        Exception: any error met while requesting, unzipping or decoding is
            printed and re-raised. A response whose status code is not 200
            always ends in an exception, so the function never returns
            without having filled ``csvScalars[scalar_type]``.
    """
    try:
        print("Retrieving scalar results...")
        response = jinko.makeRequest(
            "/core/v2/result_manager/scalars_summary",
            method="POST",
            json={
                "select": idsForScalars[scalar_type],
                "trialId": latestCompletedVersion["simulationId"],
            },
        )
        if response.status_code != 200:
            print(
                f"Failed to retrieve scalar results: {response.status_code} - {response.reason}"
            )
            response.raise_for_status()
            # Assuming a requests-style response, raise_for_status() only
            # raises for 4xx/5xx codes; fail explicitly for any other non-200
            # code instead of silently leaving csvScalars without this entry.
            raise RuntimeError(
                f"Unexpected status code {response.status_code} while retrieving scalar results"
            )

        print("Scalar results retrieved successfully.")
        # The context manager closes the archive once the file has been read
        with zipfile.ZipFile(io.BytesIO(response.content)) as archive:
            filename = archive.namelist()[0]
            print(f"Extracted scalar results file: {filename}")
            csvScalars[scalar_type] = archive.read(filename).decode("utf-8")
    except Exception as e:
        print(f"Error during scalar results retrieval or processing: {e}")
        raise

# Download both groups of ids; each call stores its CSV text in csvScalars
# under the key passed as argument
retrieve_scalars("scalars")
retrieve_scalars("scenarioOverrides")

### Postprocess the data in a pandas dataframe

The data is post-processed using the pandas library and transformed into a table that can easily be plotted.

In [None]:
import re

# Parse the downloaded "scalars" CSV text into a dataframe
dfScalars = pd.read_csv(io.StringIO(csvScalars["scalars"]))
print("\nRaw scalar data (first rows):\n")

# All currently available reduce functions
reduceFuncs = [
    "time-of-max",
    "time-of-min",
    "max-slope",
    "min-slope",
    "amplitude",
    "end-minus-start",
    "at",
    "auc",
    "avg",
    "max",
    "min",
]
# Single capturing group matching any one of the reduce function names
reduceFuncRegex = f"({'|'.join(reduceFuncs)})"

# Postprocessing - split the scalarId column into three new columns:
# what precedes the reduce function, the reduce function itself, and what
# follows it (after "-from-" or "-"). The last piece is stored as startTime.
dfScalars[["generalScalarId", "reduceFunction", "startTime"]] = dfScalars[
    "scalarId"
].str.extract(rf"(.*?)-{reduceFuncRegex}(?:-from-|-)(.*)")

def extract_endtime(value):
    """Return the end of a ``<start>-to-<end>`` time specification.

    When *value* contains ``-to-``, the text following its first occurrence
    is returned; otherwise *value* itself is returned. ``None`` is returned
    when the pattern does not match or when nothing follows ``-to-``.
    """
    found = re.match(r'(?:.*?-to-(.*)|^(.*)$)', value)
    if found is None:
        return None
    after_to, whole_value = found.groups()
    # An empty or missing first group falls back to the second one
    return after_to if after_to else whole_value

def extract_starttime(value):
    """Return the start of a ``<start>-to-<end>`` time specification.

    When *value* contains ``-to-``, the text preceding its first occurrence
    is returned; otherwise *value* itself is returned. ``None`` is returned
    when the pattern does not match or when nothing precedes ``-to-``.
    """
    found = re.match(r'(?:(.*?)-to-.*|^(.*)$)', value)
    if found is None:
        return None
    before_to, whole_value = found.groups()
    # An empty or missing first group falls back to the second one
    return before_to if before_to else whole_value

# Order matters: endTime is derived from the raw startTime text first,
# because the next statement overwrites startTime with its start part only
dfScalars['endTime'] = dfScalars['startTime'].apply(extract_endtime)

dfScalars['startTime'] = dfScalars['startTime'].apply(extract_starttime)

display(dfScalars.head())

# Load the scenario override values downloaded earlier
dfDescriptors = pd.read_csv(io.StringIO(csvScalars["scenarioOverrides"]))

# Pivot to a wide format to obtain protocol overrides in columns
# (one row per armId/patientId pair, one column per scalarId)
dfDescriptors = dfDescriptors.pivot(
    index=["armId", "patientId"], columns="scalarId", values="value"
)

print("\nPivotted scenario override table (first rows): \n")
display(dfDescriptors.head())

# Attach the override columns to every scalar row; the inner join keeps only
# the rows whose patientId/armId pair exists in both tables
dfScalars = dfScalars.merge(dfDescriptors, how='inner', on=['patientId', 'armId'])

display(dfScalars.head())

import datetime

def parse_isoduration(str):
    """Convert an ISO 8601 duration string into a number of weeks.

    Adapted from https://stackoverflow.com/questions/36976138/is-there-an-easy-way-to-convert-iso-8601-duration-to-timedelta

    The duration is parsed as years, months, weeks, days, hours, minutes and
    seconds. Years and months are approximated (31557600 and 2629800 seconds).
    The date part (before ``T``) and the time part (after ``T``) are parsed
    separately, so that ``M`` means months in the former and minutes in the
    latter.

    Args:
        str: the duration, e.g. "PT1H30M15.460S", "P5DT4M" or "P2WT3H".
            A decimal comma is accepted ("P0,5Y"). The parameter keeps its
            historical name even though it shadows the builtin.

    Returns:
        The duration in weeks of exactly 7 days of 86400 seconds.

    Raises:
        ValueError: if a component is not a valid number.
    """
    def get_isosplit(text, designator):
        # Cut off the number that precedes `designator`; '0' when absent
        if designator in text:
            number, text = text.split(designator, 1)
        else:
            number = '0'
        return number.replace(',', '.'), text  # to handle like "P0,5Y"

    duration = str.split('P', 1)[-1]  # Remove prefix
    date_part, time_marker, time_part = duration.partition('T')

    # Date designators are only looked up before the 'T' separator
    s_yr, date_part = get_isosplit(date_part, 'Y')
    s_mo, date_part = get_isosplit(date_part, 'M')
    s_wk, date_part = get_isosplit(date_part, 'W')
    s_dy, date_part = get_isosplit(date_part, 'D')

    if not time_marker:
        # Without a 'T' separator, keep accepting time designators in what is
        # left of the string, as this function has always done
        time_part = date_part

    s_hr, time_part = get_isosplit(time_part, 'H')
    s_mi, time_part = get_isosplit(time_part, 'M')
    s_sc, time_part = get_isosplit(time_part, 'S')

    n_yr = float(s_yr) * 31557600   # approx seconds for year, month, week
    n_mo = float(s_mo) * 2629800
    n_wk = float(s_wk) * 604800
    dt = datetime.timedelta(
        days=float(s_dy),
        hours=float(s_hr),
        minutes=float(s_mi),
        seconds=float(s_sc) + n_yr + n_mo + n_wk,
    )
    return dt.total_seconds() / 86400 / 7


# Replace the textual end/start times by the numbers returned by
# parse_isoduration (durations expressed in weeks)
dfScalars['endTime'] = dfScalars['endTime'].apply(parse_isoduration)

dfScalars['startTime'] = dfScalars['startTime'].apply(parse_isoduration)

# Values of the reduce functions whose name contains "time" are divided by
# the number of seconds in a week
# (presumably those values are expressed in seconds - confirm)
dfScalars.loc[dfScalars['reduceFunction'].str.contains('time'), 'value'] /= (86400*7)

# Order the rows by start time, missing start times first
dfScalars = dfScalars.sort_values(by='startTime', ascending=True, na_position='first')

display(dfScalars)


### Plot the data

Finally, we plot the time series data, faceting over scenario overrides.

In [None]:
# adapt the plot to your ids
# NOTE(review): "fullDose.tmin", "primingDose.tmin" and "administrationMode"
# are presumably scenario override columns brought in by the merge above;
# replace them by the columns available in your own trial

# First figure: rows of Tumor.CancerCell reduced with "time-of-max",
# plotted against endTime, one panel per fullDose/primingDose combination
# and one line colour per administrationMode
fig = px.line(
    dfScalars.query('generalScalarId == "Tumor.CancerCell" and reduceFunction == "time-of-max"'),
    x="endTime",
    y="value",
    facet_col="fullDose.tmin",
    facet_row="primingDose.tmin",
    color="administrationMode",
    labels={
        "endTime": "Time (weeks)",
        "value": "Time of max within each week for Tumor.CancerCell (weeks)",
        "fullDose.tmin": "Full dose (mg)",
        "primingDose.tmin": "Priming dose (mg)",
        "administrationMode": "Administration",
    },
    height=600,
)
fig.show()


# Second figure: same layout, for the rows of Tumor.CancerCell reduced with "at"
fig = px.line(
    dfScalars.query('generalScalarId == "Tumor.CancerCell" and reduceFunction == "at"'),
    x="endTime",
    y="value",
    facet_col="fullDose.tmin",
    facet_row="primingDose.tmin",
    color="administrationMode",
    labels={
        "endTime": "Time (weeks)",
        "value": "Amount of Tumor.CancerCell (cell count)",
        "fullDose.tmin": "Full dose (mg)",
        "primingDose.tmin": "Priming dose (mg)",
        "administrationMode": "Administration",
    },
    height=600,
)

fig.show()