In [7]:
import os
import sys

sys.path.insert(0, '..')

import pyarrow.parquet as pq

from pysalient import io as io
from pysalient import visualisation as vis
from pysalient.evaluation import evaluation

In [2]:
# Quick inspection of the raw sample file before handing it to pysalient.
sample_data_path = os.path.join("data", "anonymised_sample.parquet")
table = pq.read_table(sample_data_path)

# Convert to pandas once; every statistic below is derived from this frame.
df = table.to_pandas()

print(f"Number of rows: {table.num_rows}")
print(table.column_names)

# Count rows whose label is exactly 1.
true_label_series = df["true_label"]
true_count = (true_label_series == 1).sum()
print(f"Number of true labels (1): {true_count}")

# Preview the first rows with rich display rather than plain-text print.
display(df.head(5))

# Per-encounter statistics.
grouped = df.groupby("encounter_id")
num_groups = df["encounter_id"].nunique()
print(f"Number of unique encounter groups: {num_groups}")

# An encounter counts as positive when its label sum is > 0
# (i.e. at least one positive row, assuming labels are 0/1).
group_sums = grouped["true_label"].sum()
groups_with_positives = (group_sums > 0).sum()
print(
    f"Number of encounter groups with at least one true positive: {groups_with_positives}"
)

Number of rows: 11794
['encounter_id', 'event_timestamp', 'true_label', 'prediction_probability']
Number of true labels (1): 1468
                                        encounter_id  event_timestamp  \
0  bf969c647159506779e9776a62f087139bd8662c04c5be...              2.0   
1  bf969c647159506779e9776a62f087139bd8662c04c5be...              4.0   
2  bf969c647159506779e9776a62f087139bd8662c04c5be...              5.0   
3  bf969c647159506779e9776a62f087139bd8662c04c5be...              6.0   
4  bf969c647159506779e9776a62f087139bd8662c04c5be...              7.0   

   true_label  prediction_probability  
0           0                0.026304  
1           0                0.053344  
2           0                0.060045  
3           0                0.039508  
4           0                0.041009  
Number of unique encounter groups: 100
Number of encounter groups with at least one true positive: 50


In [3]:
# Load the sample with pysalient. The path is relative to the current working
# directory, so run the notebook from the directory that contains `data/`.
sample_data_path = os.path.join("data", "anonymised_sample.parquet")

# Stays None if the file is missing; the evaluation cell checks for this.
assigned_table_events = None

if os.path.exists(sample_data_path):
    # Column names match those printed by the inspection cell.
    # No task/model columns are read from the source, and the optional
    # assign_task_name / assign_model_name arguments are not passed,
    # so the loaded table carries no 'task' or 'model' columns.
    assigned_table_events = io.load_evaluation_data(
        source=sample_data_path,
        y_proba_col="prediction_probability",
        y_label_col="true_label",
        aggregation_cols=None,
        timeseries_col="event_timestamp",
    )

    print("\nSuccessfully loaded data (Example 1):")
    print(assigned_table_events.schema)
    print(f"\nNumber of rows: {assigned_table_events.num_rows}")

    # Preview the first rows to verify the loaded columns.
    print("\nFirst 5 rows:")
    display(assigned_table_events.slice(0, 5).to_pandas())

else:
    print(
        f"Skipping data loading (Example 1) as file was not found: {sample_data_path}"
    )



Successfully loaded data with assigned names (Example 1):
encounter_id: string
event_timestamp: double
true_label: int64
prediction_probability: float
-- schema metadata --
pysalient.io.y_proba_col: 'prediction_probability'
pysalient.io.y_label_col: 'true_label'
pysalient.io.timeseries_col: 'event_timestamp'
pysalient.io.aggregation_cols: '[]'

Number of rows: 11794

First 5 rows (with added 'task' and 'model' columns):
                                        encounter_id  event_timestamp  \
0  bf969c647159506779e9776a62f087139bd8662c04c5be...              2.0   
1  bf969c647159506779e9776a62f087139bd8662c04c5be...              4.0   
2  bf969c647159506779e9776a62f087139bd8662c04c5be...              5.0   
3  bf969c647159506779e9776a62f087139bd8662c04c5be...              6.0   
4  bf969c647159506779e9776a62f087139bd8662c04c5be...              7.0   

   true_label  prediction_probability  
0           0                0.026304  
1           0                0.053344  
2           0   

In [4]:
# Evaluation parameters.
eval_modelid = "LogRegress_01"  # Generic ID: no model column was assigned at load time
eval_filter = "ExampleFilterDummy"  # Free-text description of the data subset
eval_thresholds = (0.1, 0.9, 0.1)  # (start, stop, step) -> 0.1, 0.2, ..., 0.9
# An explicit list of thresholds may be passed instead,
# e.g. [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9].

# Fail fast with a clear message if the loading cell skipped (file missing),
# instead of passing None into evaluation().
if assigned_table_events is None:
    raise RuntimeError(
        "No evaluation data loaded: run the loading cell with the sample file available."
    )

# Run the evaluation.
evaluation_results = evaluation(
    data=assigned_table_events,  # Table returned by io.load_evaluation_data
    modelid=eval_modelid,
    filter_desc=eval_filter,
    thresholds=eval_thresholds,
    decimal_places=3,  # Rounding of output floats. TODO(review): confirm whether -1 disables rounding.
    calculate_au_ci=True,  # Enable AU CI calculation (uses bootstrap)
    calculate_threshold_ci=True,
    threshold_ci_method="bootstrap",  # Method for threshold CIs (ignored if calculate_threshold_ci=False)
    ci_alpha=0.05,  # 95% CI
    bootstrap_seed=42,  # For reproducible CIs
    bootstrap_rounds=500,  # Fewer rounds for notebook speed
    force_threshold_zero=True,  # Output includes an extra 0.0 threshold row
    verbosity=1,
)

In [8]:
# Visualisation: render the evaluation output as a styled table, rounded to
# 3 decimal places. ci_column=False presumably keeps CIs out of a dedicated
# column — confirm against pysalient.visualisation.
styled_results = vis.format_evaluation_table(
    evaluation_results,
    ci_column=False,
    decimal_places=3,
)
display(styled_results)

Unnamed: 0,modelid,filter_desc,threshold,time_to_first_alert_value,time_to_first_alert_unit,AUROC,AUPRC,Prevalence,Sample_Size,Label_Count,TP,TN,FP,FN,PPV,Sensitivity,Specificity,NPV,Accuracy,F1_Score
0,LogRegress_01,ExampleFilterDummy,0.0,,,0.651,0.294,0.124,11794,1468,1468,0,10326,0,0.124,1.0,0.0,0.0,0.124,0.221
1,LogRegress_01,ExampleFilterDummy,0.1,,,0.651,0.294,0.124,11794,1468,267,9835,491,1201,0.352,0.182,0.952,0.891,0.857,0.24
2,LogRegress_01,ExampleFilterDummy,0.2,,,0.651,0.294,0.124,11794,1468,92,10309,17,1376,0.844,0.063,0.998,0.882,0.882,0.117
3,LogRegress_01,ExampleFilterDummy,0.3,,,0.651,0.294,0.124,11794,1468,39,10317,9,1429,0.812,0.027,0.999,0.878,0.878,0.051
4,LogRegress_01,ExampleFilterDummy,0.4,,,0.651,0.294,0.124,11794,1468,27,10320,6,1441,0.818,0.018,0.999,0.877,0.877,0.036
5,LogRegress_01,ExampleFilterDummy,0.5,,,0.651,0.294,0.124,11794,1468,27,10321,5,1441,0.844,0.018,1.0,0.877,0.877,0.036
6,LogRegress_01,ExampleFilterDummy,0.6,,,0.651,0.294,0.124,11794,1468,23,10326,0,1445,1.0,0.016,1.0,0.877,0.877,0.031
7,LogRegress_01,ExampleFilterDummy,0.7,,,0.651,0.294,0.124,11794,1468,16,10326,0,1452,1.0,0.011,1.0,0.877,0.877,0.022
8,LogRegress_01,ExampleFilterDummy,0.8,,,0.651,0.294,0.124,11794,1468,4,10326,0,1464,1.0,0.003,1.0,0.876,0.876,0.005
9,LogRegress_01,ExampleFilterDummy,0.9,,,0.651,0.294,0.124,11794,1468,0,10326,0,1468,0.0,0.0,1.0,0.876,0.876,0.0
