In [1]:
import numpy as np
import pandas as pd
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt

# --- Input / output locations ---
data_location = "df_with_predictions.feather"  # feather file read into the working DataFrame
plot_save_location = "/"  # prefix prepended to the figure file names when saving

# --- Column names ---
data_set_col = "data_set"  # column compared against "train" / "test" to flag the split

# --- Analysis settings ---
threshold = 0.02875984678242102  # probability threshold for the full model
n_boot_iters = 50_000  # forwarded as n_boot_iters to plot_roc_by_model_subset

In [2]:
# Execute the helper script inside this kernel so the names it defines are usable below.
# NOTE(review): plot_roc_by_model_subset (called later) is not defined in this notebook,
# so it presumably comes from this script — confirm.
%run ./fig5a_helper.py

In [3]:
# Global matplotlib styling, applied in one call to every figure created below.
plt.rcParams.update({
    "font.size": 10,
    "lines.markersize": 7,
    "figure.facecolor": (1, 1, 1, 0),  # RGBA: white with alpha 0
    "axes.facecolor": (1, 1, 1, 1),    # RGBA: white with alpha 1
})

In [4]:
# Read the feather file at `data_location` into the working DataFrame.
df = pd.read_feather(data_location)

In [5]:
# Keep only rows flagged both pass_all_cv and is_nmrf (i.e. drop MRF samples and
# failing samples), then renumber the index from zero.
passes_cv = df["pass_all_cv"].eq(True)
is_nmrf_sample = df["is_nmrf"].eq(True)
df = df.loc[passes_cv & is_nmrf_sample].reset_index(drop=True).copy()

In [6]:
# Add one boolean indicator column per split, named after the split label itself.
for split_name in ("train", "test"):
    df[split_name] = df[data_set_col].eq(split_name)

# NMRF AMA ROC 

In [7]:
# LR+ for NMRF AMA
# Rows in the test split with is_age_gte35 set. The mask is built once and reused for
# both labels and predictions so the two selections cannot drift apart.
ama_test_rows = (df['test'] == True) & (df.is_age_gte35 == True)
y_true_ama = df.loc[ama_test_rows, 'is_green_triangle'].astype(int)
y_pred_ama = df.loc[ama_test_rows, 'pe_pred_class_full'].astype(int)

# labels=[0, 1] pins the matrix to 2x2 even when one class is absent from the subset;
# without it .ravel() can return fewer than four values and the unpacking below raises.
tn, fp, fn, tp = confusion_matrix(y_true_ama, y_pred_ama, labels=[0, 1]).ravel()
specificity = tn / (tn + fp)
sensitivity = tp / (tp + fn)

# Positive likelihood ratio = sensitivity / (1 - specificity); shown as the cell output.
sensitivity / (1 - specificity)

3.4937611408199643

In [8]:
# LR+ for NMRF AMA using USPSTF
# A sample counts as USPSTF-predicted when its risk level is one of the two high-risk labels.
df['is_pred_uspstf'] = df['uspstf_risk_level'].isin(["high_pe_risk_1high", "high_pe_risk_2mod"])

# Rows in the test split with is_age_gte35 set. The mask is built once and reused for
# both labels and predictions so the two selections cannot drift apart.
ama_test_rows = (df['test'] == True) & (df.is_age_gte35 == True)
y_true_ama = df.loc[ama_test_rows, 'is_green_triangle'].astype(int)
y_pred_uspstf = df.loc[ama_test_rows, 'is_pred_uspstf'].astype(int)

# labels=[0, 1] pins the matrix to 2x2 even when one class is absent from the subset;
# without it .ravel() can return fewer than four values and the unpacking below raises.
tn, fp, fn, tp = confusion_matrix(y_true_ama, y_pred_uspstf, labels=[0, 1]).ravel()
specificity = tn / (tn + fp)
sensitivity = tp / (tp + fn)

# Positive likelihood ratio = sensitivity / (1 - specificity); shown as the cell output.
sensitivity / (1 - specificity)

1.1684053651266764

# DGA (delivery gestational age) <= 35 weeks: all non-early GT (green-triangle) cases are reclassified as controls

In [9]:
# Flag green-triangle rows whose delivery_ga is at most 35; every other row stays False.
early_gt_rows = df['is_green_triangle'].eq(True) & df['delivery_ga'].le(35)
df['preterm_pe_35'] = False
df.loc[early_gt_rows, 'preterm_pe_35'] = True

# Show how many rows fall on each side of the new flag.
df['preterm_pe_35'].value_counts()

False    6598
True       47
Name: preterm_pe_35, dtype: int64

# Combined plot

In [10]:
# Build one long table holding every ROC subgroup so they can share a single plot.
# Each row carries boolean plot_* columns saying which subgroup curve(s) it belongs to.
roc_columns = ["pe_pred_prob_full", "is_green_triangle", "is_age_gte35", "preterm_pe_35"]

# Test-split rows with the original outcome definition: used for the "all" and "AMA" curves.
# Rows and columns are selected in a single .loc call rather than chained indexing.
df_gt = df.loc[df['test'] == True, roc_columns].reset_index(drop=True)
df_gt['plot_all'] = True
# Explicit boolean mask, built the same way the AMA subset is selected for the LR+ cells,
# so all three plot_* columns are plain booleans whatever dtype is_age_gte35 is stored as.
df_gt['plot_ama'] = df_gt['is_age_gte35'] == True
df_gt['plot_preterm35'] = False

# Map preterm_pe_35 into is_green_triangle for the <= 35 thresholding plot (required for plot_roc_by_model_subset).
# Set subgroup columns such that only preterm35 uses this alternative definition.
df_preterm35 = df_gt.copy()
# Cast the replacement labels to the original label column's dtype so both halves of the
# concat agree and pandas does not have to coerce between bool and another dtype.
df_preterm35["is_green_triangle"] = df_preterm35["preterm_pe_35"].astype(df_gt["is_green_triangle"].dtype)
df_preterm35['plot_all'] = False
df_preterm35['plot_ama'] = False
df_preterm35['plot_preterm35'] = True

# Combine: stack both versions; ignore_index=True yields a fresh 0..n-1 index directly.
df_combined = pd.concat([df_gt, df_preterm35], axis=0, ignore_index=True)

  df_combined = pd.concat([df_gt, df_preterm35], axis = 0).reset_index(drop = True).copy()


In [None]:
# First three colours of the tab10 palette, one per subgroup curve.
cm = plt.get_cmap('tab10')
cm_colors = list(cm.colors[:3])

# Labels and full-model probabilities in the column layout passed as model_outputs.
roc_inputs = pd.DataFrame({
    'y_true': df_combined['is_green_triangle'],
    'y_prob_full': df_combined['pe_pred_prob_full'],
})

# Colour assigned to each subgroup filter column.
subgroup_colors = {
    'plot_all': cm_colors[0],
    'plot_preterm35': cm_colors[2],
    'plot_ama': cm_colors[1],
}

# NOTE(review): no random seed is set anywhere in this notebook; if the helper's bootstrap
# draws from a global RNG the results will differ between runs — confirm.
plot_roc_by_model_subset(
    model_outputs=roc_inputs,
    prob_thresholds={'full': threshold},
    subgroup_filters=df_combined[['plot_all', 'plot_preterm35', 'plot_ama']],
    n_boot_iters=n_boot_iters,
    color_dict=subgroup_colors,
    linestyles={'full': '-'},
    mark_sensitivity_thres=None,
    mark_specificity_thres=None,
    sens_spec_thresholds=None,
    title="ROC: performance in NMRF samples, with key subgroups",
)

# Save the current figure once per output format.
for figure_name, figure_format in (
    ("roc_nmrf_combined_all_preterm35_ama.svg", "svg"),
    ("roc_nmrf_combined_all_preterm35_ama.pdf", "pdf"),
):
    plt.savefig(plot_save_location + figure_name, format=figure_format, bbox_inches="tight")

  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_plus / lr_minus,
  "DOR": lr_pl