In [None]:
#@title Step 1: Download necessary files from our [GitHub page](https://github.com/cmb-chula/fetal-artery-doppler-percentile)
#@markdown Hit the Play button and proceed

#@markdown Please follow the instruction if there is an "Error" message

### clone github
!rm -rf fetal-doppler-percentile
!git clone https://github.com/cmb-chula/fetal-pulm-artery-doppler-percentile

print('')
print('**********************')
print('Finished downloading required data. Please proceed to the next code block')

In [None]:
#@title Step 2: Setup Python library and necessary codes
#@markdown Hit the play button and proceed

### install scikit-learn 1.2.0 to maintain compatability with trained models
!pip install --upgrade scikit-learn==1.2.0 pandas==1.5.2

import pickle
import pandas as pd
import numpy as np

from ipywidgets import HBox, VBox, Label, BoundedIntText, IntText, FloatText, Text, Checkbox, Layout, Button, HTML
from math import erfc, sqrt

# Inputs whose z-score against the reference reaches +/- this many SDs get a
# warning (1.96 SD bounds the central 95% of a normal distribution).
warning_sd_range = 1.96

# NOTE(review): these two styles are not referenced elsewhere in this notebook.
active_slider_style = {'handle_color': '#FF7F50'}
inactive_slider_style = {'handle_color': '#999999'}

### load data into Python
# Directory created by the `git clone` in Step 1.
repo_dir = 'fetal-pulm-artery-doppler-percentile'

# pickle.load can execute arbitrary code: only load these files from the
# trusted repository cloned in Step 1. `with` closes each file after loading.
with open(f'{repo_dir}/svr_models.pkl', 'rb') as models_file:
    models = pickle.load(models_file)  # parameter name -> fitted model with .predict()
data_stats = pd.read_csv(f'{repo_dir}/data_stats.csv', index_col = 0)
with open(f'{repo_dir}/data_stats.pkl', 'rb') as stats_file:
    # GA-specific reference statistics; presumably a DataFrame indexed by GA
    # in weeks with '<name> Mean' / '<name> SD' columns -- confirm.
    data_stats_time = pickle.load(stats_file)

# Parameters offered in the UI: every data_stats row whose Info is not 'Input'.
all_params = sorted(data_stats.index[data_stats['Info'] != 'Input'])

### trimming decimal digits
def trim_decimal_digits(value, n_decimal, mode = 'str'):
    """Round `value` half-up to `n_decimal` decimal places.

    Args:
        value: number to round.
        n_decimal: number of decimal places to keep (>= 0).
        mode: 'str' returns the rounded number as text with exactly
            `n_decimal` digits after the point (e.g. '0.5', '12.30');
            any other value returns it as a float.

    Returns:
        str or float, depending on `mode`.
    """
    from math import floor

    scale = 10 ** n_decimal
    # Round half up on the scaled value; floor() (unlike int(), which
    # truncates toward zero) is also correct for negative values.
    scaled = floor(value * scale + 0.5)
    if mode != 'str':
        return scaled / scale

    # Build the text from integer parts so values below 1 keep their leading
    # zero ('0.5', not '.5') and every requested decimal digit is present.
    whole, frac = divmod(abs(scaled), scale)
    sign = '-' if scaled < 0 else ''
    if n_decimal == 0:
        return f'{sign}{whole}'
    return f'{sign}{whole}.{frac:0{n_decimal}d}'

### create input user interface
# `indent` is a Checkbox-only option, so it is not passed to the text/HTML
# widgets below (unrecognized keyword arguments are forwarded by traitlets
# and produce a DeprecationWarning).
input_header1 = Label(value = '[1] Enter gestation age (GA), and if necessary, fetal heart rate (FHR)')
input_header2 = Label(value = 'Only parameters with *** require FHR as input')
input_header3 = Label(value = 'Warnings will be shown if input values are too high or too low compared to our database')

ga_label = Label(value = 'GA:  ')
ga_w_label = Label(value = ' week ')
ga_d_label = Label(value = ' day ')

# GA is entered as whole weeks (20-40) plus days (0-6).
ga_w_input = BoundedIntText(value = 30, min = 20, max = 40)
ga_w_input.layout.width = '80px'

ga_d_input = BoundedIntText(value = 0, min = 0, max = 6)
ga_d_input.layout.width = '80px'

fhr_label = Label(value = 'FHR (bpm):  ')
# NOTE(review): FHR is unbounded; process_input takes log(FHR), so values
# <= 0 produce -inf/NaN features for the GA+FHR models.
fhr_input = IntText(value = 145)
fhr_input.layout.width = '80px'

fhr_warning = HTML(value = '')
fhr_warning.layout.width = '400px'

# Right-aligned label column next to the GA row and the FHR row (+ its warning).
input_labels = VBox([ga_label, fhr_label], layout = Layout(width = '120px', align_items = 'flex-end'))
input_ga_ui = HBox([ga_w_input, ga_w_label, ga_d_input, ga_d_label])
input_fhr_ui = HBox([fhr_input, fhr_warning])
input_ui = VBox([input_ga_ui, input_fhr_ui])
input_ui = HBox([input_labels, input_ui])
input_ui = VBox([input_header1, input_header2, input_header3, input_ui])

### create output user interface
# One set of widgets per parameter, each dict keyed by parameter name.
labels = {}       # "name (unit): " label
inputs = {}       # measured value; enabled only while the parameter is selected
checkboxes = {}   # "Selected" toggle
textfields = {}   # read-only estimated percentile
warnings = {}     # out-of-range warning (HTML); note: shadows the stdlib module name

for p in all_params:
    mean_p = data_stats.loc[p, 'Mean']

    if not pd.isna(data_stats.loc[p, 'Unit']):
        description = p + ' (' + data_stats.loc[p, 'Unit'] + ')'
    else:
        description = p

    # Parameters that also need FHR are marked with *** (see input_header2).
    if data_stats.loc[p, 'Info'] == 'GA+FHR':
        description = '***' + description

    labels[p] = Label(value = description + ': ')
    # Pre-filled with the overall mean rounded to one decimal place.
    # `indent` is a Checkbox-only option, so it is omitted for the other widgets.
    inputs[p] = FloatText(value = trim_decimal_digits(mean_p, 1, mode = 'float'), disabled = True)
    inputs[p].layout.width = '80px'

    checkboxes[p] = Checkbox(value = False, indent = False)
    checkboxes[p].layout.width = '20px'
    checkboxes[p].name = p ## Use name to define relationship to other UI elements

    textfields[p] = Text(value = '', disabled = True)
    textfields[p].layout.width = '70px'

    warnings[p] = HTML(value = '')
    warnings[p].layout.width = '400px'

output_header1 = Label(value = '[2] Select parameter(s) of interest, enter their values, and click the "Predict" button')
run_button = Button(description = 'Predict', tooltip = 'Click to get estimated percentiles')

def _param_column(make_widget, width, **layout_options):
    """Return a fixed-width VBox holding make_widget(p) for every p in all_params, in order."""
    return VBox([make_widget(p) for p in all_params], layout = Layout(width = width, **layout_options))

# Columns are placed side by side, so row i of every column belongs to all_params[i].
output_labels = _param_column(lambda p: labels[p], '120px', align_items = 'flex-end')
output_inputs = _param_column(lambda p: inputs[p], '100px')

output_cb_labels = _param_column(lambda _p: Label(value = 'Selected: '), '70px')
output_cbs = _param_column(lambda p: checkboxes[p], '50px')

output_tf_labels = _param_column(lambda _p: Label(value = 'Percentile: '), '70px')
output_tfs = _param_column(lambda p: textfields[p], '100px')

output_warnings = _param_column(lambda p: warnings[p], '450px')

output_row = HBox([
    output_labels, output_inputs,
    output_cb_labels, output_cbs,
    output_tf_labels, output_tfs,
    output_warnings,
])
output_ui = VBox([output_header1, output_row, run_button])

### define interaction between user interface elements
def gen_warning(m_name, mode):
    """Return bold red HTML warning that measurement `m_name` is out of range.

    Args:
        m_name: display name of the measurement (e.g. 'FHR' or a parameter name).
        mode: 'high' for a too-HIGH warning; any other value gives the too-LOW warning.

    Returns:
        str: HTML markup for an ipywidgets ``HTML`` widget.
    """
    direction = 'HIGH' if mode == 'high' else 'LOW'
    # Close tags in reverse order of opening so the markup is well formed.
    return f"<b><font color='red'>Warning! {m_name} value is too {direction}, please check your input</font></b>"

def toggle_input(sender): ## from checkboxes
    """Checkbox observer: enable the toggled parameter's value box, or disable it and clear its percentile.

    `sender` is the change notification; `sender.owner` is the Checkbox, whose
    `name` attribute holds the parameter key (assigned when the widgets are built).
    """
    box = sender.owner
    param = box.name
    selected = box.value == True
    inputs[param].disabled = not selected
    if not selected:
        # A stale percentile must not remain next to a deselected parameter.
        textfields[param].value = ''

def reset_warning():
    """Clear the FHR warning and the warning of every parameter in all_params."""
    for widget in [fhr_warning, *(warnings[name] for name in all_params)]:
        widget.value = ''

def toggle_warning():
    """Show warnings for selected inputs lying outside +/- warning_sd_range SDs.

    GA (weeks) = week box + day box / 7. For every selected parameter:

    * if its `Info` is 'GA+FHR', FHR is compared with 'FHR Mean'/'FHR SD' of
      `data_stats_time` at this GA, and `fhr_warning` is set when out of range;
    * 'CONS' parameters are compared with their overall Mean/SD in
      `data_stats`; all others with '<p> Mean'/'<p> SD' of `data_stats_time`
      at this GA.

    Warning widgets are only set here, never cleared (reset_warning does that).
    NOTE(review): `data_stats_time.loc[ga, ...]` looks GA up by exact float
    equality; this assumes its index was built with the same week + day / 7
    arithmetic -- confirm.

    Returns:
        dict: parameter -> 1 (too high) or -1 (too low), only for out-of-range
        parameters whose Info is not 'CONS'.
    """
    def _out_of_range(z):
        # 1 at/above the upper limit, -1 at/below the lower limit, else 0 (NaN included).
        if z >= warning_sd_range:
            return 1
        if z <= -warning_sd_range:
            return -1
        return 0

    ga = float(ga_w_input.value) + float(ga_d_input.value) / 7
    fhr = float(fhr_input.value)
    flags = {}

    for p in all_params:
        if checkboxes[p].value != True:
            continue
        info = data_stats.loc[p, 'Info']
        measurement = inputs[p].value

        if info == 'GA+FHR': ## check FHR
            fhr_side = _out_of_range((fhr - data_stats_time.loc[ga, 'FHR Mean']) / data_stats_time.loc[ga, 'FHR SD'])
            if fhr_side:
                fhr_warning.value = gen_warning('FHR', 'high' if fhr_side > 0 else 'low')

        if info == 'CONS':
            z = (measurement - data_stats.loc[p, 'Mean']) / data_stats.loc[p, 'SD']
        else:
            z = (measurement - data_stats_time.loc[ga, p + ' Mean']) / data_stats_time.loc[ga, p + ' SD']

        side = _out_of_range(z)
        if side:
            warnings[p].value = gen_warning(p, 'high' if side > 0 else 'low')
            if info != 'CONS':
                flags[p] = side

    return flags

def process_input(): ## GA, Log GA, GA^2, FHR, Log FHR, FHR^2
    """Return the standardized GA/FHR feature row used as model input.

    GA (weeks) = week box + day box / 7. For GA and for FHR, the raw value,
    its natural log and its square are each z-scored with the 'Mean'/'SD' of
    the identically named row of `data_stats`.

    Returns:
        pd.DataFrame: a single row (index '0') with columns
        ['GA', 'log GA', 'GA^2', 'FHR', 'log FHR', 'FHR^2'].
    """
    feature_names = ['GA', 'log GA', 'GA^2', 'FHR', 'log FHR', 'FHR^2']

    ga = float(ga_w_input.value) + float(ga_d_input.value) / 7
    fhr = float(fhr_input.value)

    raw_values = []
    for base in (ga, fhr):
        raw_values.extend([base, np.log(base), base ** 2])

    standardized = [
        (raw - data_stats.loc[name, 'Mean']) / data_stats.loc[name, 'SD']
        for name, raw in zip(feature_names, raw_values)
    ]
    return pd.DataFrame([standardized], index = ['0'], columns = feature_names)
        
def predict_param(p, z_score, input_std):
    """Run the fitted model of parameter `p` and return its raw prediction array.

    Args:
        p: parameter name; key into `models` and row of `data_stats`.
        z_score: the standardized measurement of `p` (as computed by predict_all).
        input_std: one-row feature DataFrame, as returned by process_input().

    Parameters whose `Info` is 'GA+FHR' receive every column of `input_std`;
    all others receive only 'GA', 'log GA' and 'GA^2'. The z-score is appended
    as a last column named `p`. `input_std` itself is not modified.
    """
    if data_stats.loc[p, 'Info'] == 'GA+FHR':
        feature_cols = list(input_std.columns)
    else:
        feature_cols = ['GA', 'log GA', 'GA^2']

    features = input_std.loc[:, feature_cols].assign(**{p: [z_score]})
    return models[p].predict(features)
    
def predict_all(sender): ## from button
    """"Predict" button callback: estimate a percentile for every selected parameter.

    Standardizes GA/FHR (process_input), refreshes the out-of-range warnings
    (reset_warning, then toggle_warning), and for each selected parameter:

    * 'CONS' parameters: percentile = standard-normal CDF of the z-score
      against the overall Mean/SD, i.e. 0.5 * erfc(-z / sqrt(2)) * 100.
    * other parameters: if toggle_warning flagged the value as too high/low,
      it is first clipped to Mean +/- warning_sd_range * SD of the GA-specific
      reference in `data_stats_time`; its z-score against the overall
      Mean/SD is then passed to predict_param, whose first output is
      multiplied by 100 (so presumably a fraction in [0, 1] -- TODO confirm).

    The percentile is clamped so that, after rounding to the displayed
    precision, it stays strictly between 0 and 100, then written to the
    parameter's percentile text field.

    Args:
        sender: the clicked Button (unused).
    """
    display_decimals = 1
    # Clamp one display step inside (0, 100): bounds such as 0.01 / 99.99
    # would still render as "0.0" / "100.0" once rounded to one decimal place.
    lowest_pct = 10.0 ** -display_decimals
    highest_pct = 100.0 - lowest_pct

    input_std = process_input()
    ga = float(ga_w_input.value) + float(ga_d_input.value) / 7

    reset_warning()
    out_of_range_flags = toggle_warning()

    for p in all_params:
        if not checkboxes[p].value:
            continue

        p_mean = data_stats.loc[p, 'Mean']
        p_sd = data_stats.loc[p, 'SD']

        if data_stats.loc[p, 'Info'] == 'CONS': ## constant model
            z_score = (inputs[p].value - p_mean) / p_sd
            prediction = 0.5 * erfc(-z_score / sqrt(2)) * 100
        else:
            measurement = inputs[p].value
            direction = out_of_range_flags.get(p, 0)  # 1 too high, -1 too low, 0 in range
            if direction:
                # Out-of-range inputs are clipped to the warning boundary at this GA.
                measurement = (data_stats_time.loc[ga, p + ' Mean']
                               + direction * warning_sd_range * data_stats_time.loc[ga, p + ' SD'])
            z_score = (measurement - p_mean) / p_sd
            prediction = predict_param(p, z_score, input_std)[0] * 100

        prediction = min(max(prediction, lowest_pct), highest_pct)
        textfields[p].value = trim_decimal_digits(prediction, display_decimals, mode = 'str')

### link user interface elements to interaction
# Observe only the checkbox `value`: without `names`, the handler is called
# for every trait change on the widget, not just (de)selection.
for p in all_params:
    checkboxes[p].observe(toggle_input, names = 'value')

run_button.on_click(predict_all)

print('')
print('**********************')
print('Finished updating packages and preparing user interface. Please proceed to the next code block')

In [None]:
#@title Step 3: Launch the predictor interface
#@markdown Hit the run button and follow the instruction on the user interface
# Stack section [1] (GA/FHR input) above section [2] (parameters + Predict) and render it.
main_ui = VBox(children = [input_ui, output_ui])
display(main_ui)