In [2]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scipy.stats as stats
import pickle
import seaborn as sns
pd.set_option('display.max_columns', None)
import statsmodels.formula.api as sf
from sklearn.metrics import roc_auc_score
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import auc
from sklearn import metrics

In [3]:
import plotly.graph_objects as go


In [4]:
import dash
import dash_html_components as html
import dash_core_components as dcc
from dash.dependencies import Input, Output
import pathlib


In [5]:
# Load the production model bundle: the fitted model plus the feature/column
# metadata and the age/sex baseline table (df_Esummary) used by the app.
# NOTE: pickle.load can execute arbitrary code - only open trusted files.
# The `with` block closes the file itself; no explicit close() is needed.
with open('./models/prod_RF1.pickle','rb') as f:
    model,var_sel_ml,var_columns,var_risk_ATC7,var_risk_DGN3,df_Esummary=pickle.load(f)

In [19]:
# Pickled plotly figure returned as the "plot-sankey" output of get_results.
# NOTE(review): the file name repeats "_4th_version" ("...pickley_4th_version.pickle"),
# which looks like a paste slip - confirm the real file name on disk before renaming.
# NOTE: pickle.load can execute arbitrary code - only open trusted files.
with open('./plots/Sankey_4th_version.pickley_4th_version.pickle','rb') as f:
    plot_sankey=pickle.load(f)

In [6]:
# Prescription checklist options (label -> value). Both sides are the code itself,
# taken from var_risk_ATC7 in the model pickle; get_patient_risk uses the checked
# values directly as feature column names.
vars_meds={a:a for a in var_risk_ATC7}

In [7]:
# Diagnosis checklist options (label -> value). Both sides are the code itself,
# taken from var_risk_DGN3 in the model pickle; get_patient_risk uses the checked
# values directly as feature column names.
vars_dgn={a:a for a in var_risk_DGN3}

In [8]:
# Lifestyle checklist: label shown in the UI -> value passed to the callback.
vars_lifestyle = dict([
    ('Smoking', 'Smoking'),
    ('Obesity', 'Obesity'),
    ('Unemployed in last 12 months', 'Unemployed'),
])

In [9]:
# Options for the sex RadioItems; get_patient_risk compares against these exact strings.
vars_sex = [
    'Male',
    'Female',
]

In [10]:
# District dropdown options: the district names encoded in the model's one-hot
# columns. get_patient_risk sets 'okres_nazov_' + <district>, so filter and strip
# with that same prefix (the old code compared 5 chars but sliced off 12).
DISTRICT_COLUMN_PREFIX = 'okres_nazov_'
vars_district = (
    var_columns[var_columns.str.startswith(DISTRICT_COLUMN_PREFIX)]
    .str[len(DISTRICT_COLUMN_PREFIX):]
    .tolist()
)

In [11]:
# Sample inputs named like get_patient_risk's parameters (presumably for manual testing).
var_age = 80
var_sex = 'Female'
var_meds = ['C01DX12', 'H02AB07']
var_dgn = ['Z13', 'H36']
var_district = 'Banská Bystrica'

In [12]:
def get_patient_risk(var_age,var_sex,var_meds,var_dgn,var_district):
    """Score one patient with the model and look up the age/sex baseline risk.

    Builds a one-row feature frame with exactly the model's columns
    (``var_columns``) and sets the patient's one-hot indicators. Returns the
    model's positive-class probability together with the ``df_Esummary``
    baseline for the same age band and sex.

    Parameters
    ----------
    var_age : int or float
        Age in years. Ages are binned into 10-year bands ``[0, 10)`` ...
        ``[90, 100)``. Values outside that range are clipped into it, so they
        always map onto a ``vek_kat10_*`` band label.
    var_sex : {'Female', 'Male'}
        Selects the ``pohlavie_Z`` / ``pohlavie_M`` feature and the ``'Z'`` /
        ``'M'`` column of ``df_Esummary``.
    var_meds : list of str
        Prescription codes; each is used directly as a column name.
    var_dgn : list of str
        Diagnosis codes; each is used directly as a column name.
    var_district : str or None
        District name, set as column ``'okres_nazov_' + var_district``.
        None or ``''`` means no district.

    Returns
    -------
    tuple
        ``(risk, group_risk)``: ``model.predict_proba(...)[0][1]`` and
        ``df_Esummary.loc[<age band>, <'Z' or 'M'>]``.

    Raises
    ------
    ValueError
        If ``var_sex`` is neither 'Female' nor 'Male', or ``var_age`` is None.
    """
    # sex label -> (one-hot feature column, df_Esummary column)
    sex_columns = {'Female': ('pohlavie_Z', 'Z'), 'Male': ('pohlavie_M', 'M')}
    if var_sex not in sex_columns:
        raise ValueError(f"var_sex must be 'Female' or 'Male', got {var_sex!r}")
    if var_age is None:
        raise ValueError("var_age is required")
    sex_feature, sex_key = sex_columns[var_sex]

    # All-zero row in the model's column order, built in one step (integer dtype)
    # rather than by enlarging an empty frame.
    data_patient = pd.DataFrame(0, index=[0], columns=var_columns)

    # With these bins pd.cut returns NaN for ages < 0 or >= 100 (the age input
    # allows 100). That would add an unknown 'vek_kat10_nan' column and break the
    # df_Esummary lookup, so clip into the covered range first.
    age_for_band = min(max(var_age, 0), 99)
    age_band = pd.cut([age_for_band], bins=list(range(0, 110, 10)), right=False).astype(str)[0]
    data_patient.loc[0, 'vek_kat10_' + age_band] = 1

    data_patient.loc[0, sex_feature] = 1
    if var_meds:
        data_patient.loc[0, list(var_meds)] = 1
    if var_dgn:
        data_patient.loc[0, list(var_dgn)] = 1
    # A cleared Dropdown can deliver None, where len() would raise; test truthiness.
    if var_district:
        data_patient.loc[0, 'okres_nazov_' + var_district] = 1

    risk = model.predict_proba(data_patient)[0][1]
    group_risk = df_Esummary.loc[age_band, sex_key]
    return risk, group_risk

In [22]:

app = dash.Dash(
    __name__,
    meta_tags=[{"name": "viewport", "content": "width=device-width, initial-scale=1"}],
)
app.title = "CKD Risk App"

server = app.server

app_color = {"graph_bg": "#082255", "graph_line": "#007ACE"}

# District pre-selected in the dropdown; if the model's districts do not include
# it, the dropdown starts empty instead of holding a value that is not an option.
DEFAULT_DISTRICT = 'Banská Bystrica'


def _blank_figure():
    """Return a fresh empty figure dict with the app's dark plot/paper background."""
    return dict(
        layout=dict(
            plot_bgcolor=app_color["graph_bg"],
            paper_bgcolor=app_color["graph_bg"],
        )
    )


def _checklist_section(title, component_id, options):
    """Return a titled Checklist built from a label -> value dict, nothing pre-checked."""
    return html.Div(
        [
            html.H5(title, className="var__title"),
            dcc.Checklist(
                id=component_id,
                options=[{'label': k, 'value': options[k]} for k in options],
                value=[],
                inputClassName="auto__checkbox",
                labelClassName="auto__label",
            ),
        ],
        className="auto__container",
    )


app.layout = html.Div(
    [
        # Header: title and description on the left, links on the right.
        html.Div(
            [
                html.Div(
                    [
                        html.H4("CHRONIC KIDNEY DISEASE", className="app__header__title"),
                        html.P(
                            "This app calculates a risk score for chronic kidney disease based on medical history and demographics.",
                            className="app__header__title--grey",
                        ),
                    ],
                    className="app__header__desc",
                ),
                html.Div(
                    [
                        # TODO(review): this link points at Plotly's dash-wind-streaming
                        # sample, not at this app's repository - replace with the real URL.
                        html.A(
                            html.Button("SOURCE CODE", className="link-button"),
                            href="https://github.com/plotly/dash-sample-apps/tree/main/apps/dash-wind-streaming",
                        ),
                        html.A(
                            html.Img(
                                src=app.get_asset_url("dash-new-logo.png"),
                                className="app__menu__img",
                            ),
                            href="https://www.vszp.sk/",
                        ),
                    ],
                    className="app__header__logo",
                ),
            ],
            className="app__header",
        ),
        html.Div(
            [
                # Left (two-thirds) column: risk indicators and Sankey diagram,
                # both filled in by the get_results callback.
                html.Div(
                    [
                        html.Div(
                            [html.H6("ML model risk results for people outside nephrology care", className="graph__title")]
                        ),
                        dcc.Graph(id="risk-results", figure=_blank_figure()),
                        dcc.Graph(id="plot-sankey", figure=_blank_figure()),
                    ],
                    className="two-thirds column wind__speed__container",
                ),
                # Right (one-third) column: patient inputs and the forecast panel.
                html.Div(
                    [
                        html.Div(
                            [
                                html.Div(
                                    [
                                        html.H6(
                                            "Information about patient",
                                            className="graph__title",
                                        )
                                    ]
                                ),
                                html.Div(
                                    [
                                        html.H5("Age", className="var__title"),
                                        dcc.Input(id='vars_age', type='number', min=1, max=100, step=1, value=60),
                                    ],
                                    className="slider",
                                ),
                                html.Div(
                                    [
                                        html.H5("Sex", className="var__title"),
                                        dcc.RadioItems(
                                            id="vars_sex",
                                            options=[{'label': k, 'value': k} for k in vars_sex],
                                            value='Female',
                                            labelStyle={'display': 'inline-block'},
                                        ),
                                    ],
                                    className="sex__container",
                                ),
                                html.Br(),
                                _checklist_section("Prescription history", "vars_meds", vars_meds),
                                html.Br(),
                                html.Br(),
                                _checklist_section("Diagnosis history", "vars_dgn", vars_dgn),
                                html.Div(
                                    [
                                        html.H5("District", className="var__title"),
                                        dcc.Dropdown(
                                            id="vars_district",
                                            options=[{'label': k, 'value': k} for k in vars_district],
                                            value=DEFAULT_DISTRICT if DEFAULT_DISTRICT in vars_district else None,
                                        ),
                                    ],
                                    className="category__container",
                                ),
                                html.Br(),
                                _checklist_section("Lifestyle", "vars_lifestyle", vars_lifestyle),
                                # NOTE(review): no callback in this notebook writes to this graph.
                                dcc.Graph(id="wind-histogram", figure=_blank_figure()),
                            ],
                            className="graph__container first",
                        ),
                        html.Div(
                            [
                                html.Div(
                                    [
                                        html.H6(
                                            "Patient journey forecast", className="graph__title"
                                        )
                                    ]
                                ),
                                # NOTE(review): no callback in this notebook writes to this graph.
                                dcc.Graph(id="wind-direction", figure=_blank_figure()),
                            ],
                            className="graph__container second",
                        ),
                    ],
                    className="one-third column histogram__direction",
                ),
            ],
            className="app__content",
        ),
    ],
    className="app__container",
)

# Multiplicative lifestyle adjustments applied, in this order, on top of the model
# probability. 'Unemployed' has no adjustment.
# NOTE(review): these factors are hard-coded, not produced by the model - confirm their source.
LIFESTYLE_RISK_MULTIPLIERS = {'Obesity': 3, 'Smoking': 1.5}


@app.callback(
    Output("risk-results", "figure"),
    Output("plot-sankey", "figure"),
    Input("vars_age", "value"),
    Input("vars_sex", "value"),
    Input("vars_meds", "value"),
    Input("vars_dgn", "value"),
    Input("vars_district", "value"),
    Input("vars_lifestyle", "value"),
)
def get_results(var_age,var_sex,var_meds,var_dgn,var_district,var_lifestyle):
    """Recompute the risk figures whenever any patient input changes.

    Parameters
    ----------
    var_age : int or None
        Value of the age input. None (e.g. a cleared field) skips the update.
    var_sex : str
        'Female' or 'Male' from the sex RadioItems.
    var_meds, var_dgn : list of str or None
        Checked prescription / diagnosis codes.
    var_district : str or None
        Selected district; None when the dropdown is cleared.
    var_lifestyle : list of str or None
        Checked lifestyle values ('Smoking', 'Obesity', 'Unemployed').

    Returns
    -------
    tuple of two plotly figures
        The indicator figure for "risk-results" and the shared ``plot_sankey``
        figure for "plot-sankey".

    Raises
    ------
    dash.exceptions.PreventUpdate
        When no age is entered, so the previous figures stay on screen.
    """
    from dash.exceptions import PreventUpdate

    if var_age is None:
        raise PreventUpdate

    out_risk_pct,group_risk=get_patient_risk(var_age,var_sex,var_meds or [],var_dgn or [],var_district)

    selected_lifestyle = set(var_lifestyle or [])
    for factor, multiplier in LIFESTYLE_RISK_MULTIPLIERS.items():
        if factor in selected_lifestyle:
            out_risk_pct = out_risk_pct * multiplier
    # The multipliers can push the value above 1 (e.g. 0.3 * 3 * 1.5); cap it so it
    # stays a probability and the "in 10000" display never exceeds 10000.
    out_risk_pct = min(out_risk_pct, 1.0)

    # Relative risk versus the age/sex baseline. The displayed df_Esummary has a
    # 0.0 baseline (M, [90, 100)), where the ratio is undefined - show no value.
    out_risk = out_risk_pct / group_risk if group_risk > 0 else None

    fig = go.Figure()
    fig.add_trace(go.Indicator(
        title='Relative risk score in age and sex category',
        value=out_risk,
        mode="number+gauge",
        gauge={
            'axis': {'visible': False},
            'bar': {'color': "grey"},
            'steps': [
                {'range': [0, 3], 'color': 'green'},
                {'range': [3, 10], 'color': 'yellow'},
                {'range': [10, 100], 'color': 'red'},
            ],
        },
        domain={'row': 0, 'column': 0}))
    fig.add_trace(go.Indicator(
        title='Probability of dialysis need in 3 years:',
        mode="number",
        number={'suffix': " in 10000"},
        # round() rather than int() so e.g. 0.99995 shows as 10000, not 9999.
        value=int(round(out_risk_pct*10000)),
        domain={'row': 0, 'column': 1}))

    fig.update_layout(
        grid={'rows': 1, 'columns': 2, 'pattern': "independent"},
        paper_bgcolor=app_color["graph_bg"],
        font={'color':'#fff'},
    )
    plot_sankey.update_layout(
        paper_bgcolor=app_color["graph_bg"],
        font={'color':'#fff'},
    )
    return fig,plot_sankey

#
#    Input("vars_age", "value"),
#     Input("vars_district", "value"),
#     Input("vars_meds", "value"),
#     Input("vars_dgn", "value")


#def get_results(vars_sex,vars_age,vars_district,vars_meds,vars_dgn):


# Running the server
if __name__ == "__main__":
    # Blocking call: serves the app on http://127.0.0.1:8050/ until interrupted
    # (see the log output below).
    app.run_server(port=8050, debug=False)

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/

Dash is running on http://127.0.0.1:8050/



INFO:__main__:Dash is running on http://127.0.0.1:8050/



 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:8050/ (Press CTRL+C to quit)
INFO:werkzeug: * Running on http://127.0.0.1:8050/ (Press CTRL+C to quit)
127.0.0.1 - - [21/Nov/2021 10:39:17] "[37mGET / HTTP/1.1[0m" 200 -
INFO:werkzeug:127.0.0.1 - - [21/Nov/2021 10:39:17] "[37mGET / HTTP/1.1[0m" 200 -
127.0.0.1 - - [21/Nov/2021 10:39:18] "[37mGET /_dash-layout HTTP/1.1[0m" 200 -
INFO:werkzeug:127.0.0.1 - - [21/Nov/2021 10:39:18] "[37mGET /_dash-layout HTTP/1.1[0m" 200 -
127.0.0.1 - - [21/Nov/2021 10:39:18] "[37mGET /_dash-dependencies HTTP/1.1[0m" 200 -
INFO:werkzeug:127.0.0.1 - - [21/Nov/2021 10:39:18] "[37mGET /_dash-dependencies HTTP/1.1[0m" 200 -
127.0.0.1 - - [21/Nov/2021 10:39:18] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
INFO:werkzeug:127.0.0.1 - - [21/Nov/2021 10:39:18] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
127.0.0.1 - - [21/Nov/2021 10:39:52] "[37mPOST /_dash-update-component HTTP/1.1[0m" 200 -
INFO:werkzeug:127.0.0.1 - - [21/Nov/2021 10:39:52] "[37mPOS

In [33]:
# NOTE(review): `data_patient` is only assigned inside get_patient_risk (a local
# variable), so on a fresh kernel this cell raises NameError. The output below has
# raw columns (pohlavie, okres_nazov, vek_kat10, ...) and appears to come from a
# cell that is no longer in this notebook.
data_patient

Unnamed: 0,pohlavie,okres_nazov,vek_kat10,praca_mesiace_2016,J01CR04,C03CA01,M04AA01,C02CA06,A11CC05,B02BA01,H02AB07,C01DX12,B03AE01,E10,N11,Z13,H36,E11,N28,H50,Z24,N40,I25
0,Z,Banská Bystrica,"[80, 90)",0,0,0,0,0,0,0,1,1,0,0,0,1,1,0,0,0,0,0,0


pohlavie,M,Z
vek_kat10,Unnamed: 1_level_1,Unnamed: 2_level_1
"[0, 10)",3.7e-05,2e-05
"[10, 20)",6.4e-05,7.6e-05
"[20, 30)",0.00015,8.8e-05
"[30, 40)",0.000287,0.000174
"[40, 50)",0.000527,0.000154
"[50, 60)",0.00114,0.000476
"[60, 70)",0.002106,0.000964
"[70, 80)",0.002637,0.001257
"[80, 90)",0.001789,0.000855
"[90, 100)",0.0,0.000155


## Testing `vars_district`

In [19]:
# Re-loads the same bundle as the cell near the top of the notebook, overwriting
# model, var_sel_ml, var_columns, var_risk_ATC7, var_risk_DGN3 and df_Esummary.
# NOTE: pickle.load can execute arbitrary code - only open trusted files.
# The `with` block closes the file itself; no explicit close() is needed.
with open('./models/prod_RF1.pickle','rb') as f:
    model,var_sel_ml,var_columns,var_risk_ATC7,var_risk_DGN3,df_Esummary=pickle.load(f)

In [20]:
# Feature list stored in the model pickle (presumably the raw variables selected
# for the ML model, before one-hot encoding into var_columns - confirm).
var_sel_ml

['pohlavie',
 'okres_nazov',
 'vek_kat10',
 'praca_mesiace_2016',
 'J01CR04',
 'C03CA01',
 'M04AA01',
 'C02CA06',
 'A11CC05',
 'B02BA01',
 'H02AB07',
 'C01DX12',
 'B03AE01',
 'E10',
 'N11',
 'Z13',
 'H36',
 'E11',
 'N28',
 'H50',
 'Z24',
 'N40',
 'I25']

In [35]:
vars_age=60
# Scratch check: a single gauge Indicator in a 2x2 grid, with default title/mode
# supplied through a layout template, written out to test.html.
# NOTE(review): the trace sets delta reference 100 and the template sets 90;
# presumably the explicit trace value wins - confirm.
fig = go.Figure(
    data=[
        go.Indicator(
            value=vars_age,
            delta={'reference': 100},
            gauge={'axis': {'visible': False}},
            domain={'row': 0, 'column': 0},
        )
    ]
)
fig.update_layout(
    grid={'rows': 2, 'columns': 2, 'pattern': "independent"},
    template={
        'data': {
            'indicator': [
                {
                    'title': {'text': "Speed"},
                    'mode': "number+delta+gauge",
                    'delta': {'reference': 90},
                }
            ]
        }
    },
)
fig.write_html("test.html")
  