In [2]:
import snowflake.connector
from snowflake.snowpark.functions import count, when, col, corr, array_cat, sum as sum_
import os
from snowflake.snowpark import Session
import pandas as pd

cnxn_params = {
    "user":'hartsingh',
    "authenticator":'externalbrowser',
    "account":'vaa16628',
    "region":'us-east-1',
    "warehouse":"ACORN_CDS_PROD_COMMERCIALANALYTICS",
    "database":"ACORN_CDS_PROD_BIOXCEL_ADHOC",
    "role":'ACORN_CDS_PRD_CA_APP'
}

session = Session.builder.configs(cnxn_params).create()

Initiating login request with your identity provider. A browser window should have opened for you to complete the login. If you can't see it, check existing browser windows, or your OS settings. Press CTRL+C to abort and try again...
Going to open: https://mdsol.okta.com/app/snowflake/exkugim5j5lsfZiBH0x7/sso/saml?SAMLRequest=lZJfb9owFMW%2FSuQ9J04CpMwCKlrUkYptqNBJ65ubOMHDf4KvQ2Cffk5SqvahlSb5wbLPuf75nju5PknhHZkBrtUURUGIPKYynXNVTtHj9s4fIw8sVTkVWrEpOjNA17MJUCkqMq%2FtTj2wQ83Aeq6QAtJdTFFtFNEUOBBFJQNiM7KZf1%2BROAhJZbTVmRbojeVzBwVgxjrCiyUH7vB21lYE46ZpgmYQaFPiOAxDHH7FTtVKvlz0J%2FenD%2FQRDoet3imcfP3CdsNV34LPsJ57EZDldrv21z83W%2BTNL6i3WkEtmdkwc%2BQZe3xY9QDgCI6URkkSj4MafEbB%2BlEASjeFoHuWaVnV1pUN3A4XLMdCl9z9PF1MUbXneZwM7%2Be7076459%2BK9KpZHrheJTT%2FW5VDKcej9HksfpuDrc%2FjDHm%2FLtHGbbQpQM1S1QZq3VEYD%2Fxw6NY2jsgoIdEgGA2jJ%2BQtXKBcUds5L9QyBy0Cvbe0Q6NVhV%2BpMTvt65LL0Z%2BRgOKJ3yzD0xUG0LhNF%2FUDQ7rnzey%2F2zDBb%2B0vw%2FfD5ZEu1lrw7OzdaSOp%2FTiuKIi6E577RSclTFIu5nluGICLTQjd3BpGrZtxa2qG8Kx%2F9f2Uz%2F4B&RelaySt

In [3]:
%load_ext autoreload
%autoreload 2

# Objective

In [None]:
# accurately estimating mmse can be helpful to fill in scores for patients we do not have a score for 
# this score helps us estimate the severity of their condition
# lasso can be good hear since we also what to see directionality of features

# Pull raw data from snowflake

In [36]:
df_sql = session.sql(f"""
                     with scores as(
                           select distinct patient_key, result_date, value
                           from dbo.TBLDRG_EHR_RESULTS 
                           where LOINC in ('72107-6', '72133-2', '72172-0') and try_to_number(value) is not NULL and try_to_number(value) <= 30
                           order by patient_key, result_date
                        ),
                         top10_ndc as(
                           select top 10 drug_ndc, count(*) as freq
                           from processing_full.stg_ax_rx
                           where patient_key in (select distinct patient_key from scores)
                           group by drug_ndc
                           order by 2 desc
                        ),
                         top10_cpt as(
                           select top 10 procedure_arr, count(*) as freq
                           from processing_full.stg_ax_mx
                           where patient_key in (select distinct patient_key from scores) and PROCEDURE_ARR != []
                           group by procedure_arr
                           order by 2 desc
                        ),
                         top10_dx as(
                           select top 10 DIAGNOSIS_CODE_ARR, count(*) as freq
                           from processing_full.stg_ax_dx
                           where patient_key in (select distinct patient_key from scores) and array_size(DIAGNOSIS_CODE_ARR) = 1
                           group by DIAGNOSIS_CODE_ARR
                           order by 2 desc
                        ),
                         scores_dx as(
                           select patient_key,
                                  array_to_string(DIAGNOSIS_CODE_ARR,'') as icd_code,
                                  array_to_string(YEAR_OF_SERVICE_ARR,'') as date_dx
                           from processing_full.stg_ax_dx
                           where DIAGNOSIS_CODE_ARR in (select DIAGNOSIS_CODE_ARR from top10_dx) and patient_key in (select distinct patient_key from scores)
                           order by patient_key, date_dx
                        ),
                         scores_cpt as(
                           select patient_key,
                                  array_to_string(PROCEDURE_ARR,'') as cpt_code,
                                  left(array_to_string(PX_YEAR_OF_SERVICE_ARR,''), 10) as date_cpt
                           from processing_full.stg_ax_mx
                           where PROCEDURE_ARR in (select PROCEDURE_ARR from top10_cpt) and patient_key in (select distinct patient_key from scores)
                           order by patient_key, date_cpt
                        ),
                         scores_rx as(
                           select patient_key,drug_ndc,date_of_service as date_rx
                           from processing_full.stg_ax_rx
                           where drug_ndc in (select drug_ndc from top10_ndc) and patient_key in (select distinct patient_key from scores)
                           order by patient_key, date_rx
                        ),
                         scores_dob as(
                           select *
                           from(
                             select distinct patient_key, patient_dob
                             from processing_full.stg_ax_mx 
                             where patient_key in (select distinct patient_key from scores) and patient_dob is not null and patient_dob < '2023-01-01'
                           ) a
                           union
                           select *
                           from(
                             select distinct patient_key, patient_dob
                             from processing_full.stg_ax_rx
                             where patient_key in (select distinct patient_key from scores) and patient_dob is not null and patient_dob < '2023-01-01'
                           ) b
                        ),
                         scores_gender as(
                           select distinct patient_key, patient_gender
                           from processing_full.stg_ax_mx 
                           where patient_key in (select distinct patient_key from scores) and patient_gender is not NULL
                        ),
                         scores_dx_join as(
                           select distinct
                                  coalesce(a.patient_key, b.patient_key) as patient_key, 
                                  coalesce(a.result_date, b.date_dx) as date,
                                  value, icd_code
                           from scores a
                           full outer join scores_dx b on a.patient_key = b.patient_key and a.result_date = b.date_dx
                           order by 1,2
                        ),
                         scores_dx_cpt_join as(
                           select distinct
                                  coalesce(a.patient_key, b.patient_key) as patient_key, 
                                  coalesce(a.date, b.date_cpt) as date,
                                  a.value, a.icd_code, b.cpt_code
                           from scores_dx_join a
                           full outer join scores_cpt b on a.patient_key = b.patient_key and a.date = b.date_cpt
                           order by 1,2
                        ),
                         scores_dx_cpt_rx_join as(
                           select distinct
                                  coalesce(a.patient_key, b.patient_key) as patient_key, 
                                  coalesce(a.date, b.date_rx) as date,
                                  a.value, a.icd_code, a.cpt_code, b.drug_ndc
                           from scores_dx_cpt_join a
                           full outer join scores_rx b on a.patient_key = b.patient_key and a.date = b.date_rx
                           order by 1,2
                        ),
                         scores_dx_cpt_rx_demo_join as(
                           select distinct
                                  a.*, datediff(year, b.patient_dob, a.date) as age,
                                  case when c.patient_gender = 'F' then 1 
                                       when c.patient_gender = 'M' then 0 else c.patient_gender end as gender_female
                           from scores_dx_cpt_rx_join a
                           left join scores_dob b on a.patient_key = b.patient_key
                           left join scores_gender c on a.patient_key = c.patient_key
                           order by a.patient_key, a.date
                        )
                    
                    select * from scores_dx_cpt_rx_demo_join;
                     """)
df = df_sql.to_pandas()
df

Initiating login request with your identity provider. A browser window should have opened for you to complete the login. If you can't see it, check existing browser windows, or your OS settings. Press CTRL+C to abort and try again...
Going to open: https://mdsol.okta.com/app/snowflake/exkugim5j5lsfZiBH0x7/sso/saml?SAMLRequest=lZJRb9owFIX%2FSuQ9J3ZCocwCKihiRaMrgnTS%2BmaSC3hx7MzXgbBfPwfK1D200t4s%2Bzu%2B595zB3dNqYIDWJRGD0kcMRKAzkwu9W5IntNZ2CcBOqFzoYyGITkBkrvRAEWpKj6u3V6v4FcN6AL%2FkUbePgxJbTU3AiVyLUpA7jK%2BHj8ueBIxLhDBOl%2BOvJFUH2sqa5zJjLpKcpTe3t65ilN6PB6jYycydkcTxhhln6mnWuTTlW98T%2B%2FwMWU3Le8Jjy9fC02kvozgI1ebC4T8IU2X4fJpnZJgfO3u3misS7BrsAeZwfNqcTGA3sFBiLjXS%2FpRjSEIdGEcoTbHrRIFZKasaue%2FjfyJbiGnyuykH9Z8OiRVIXMFq2qW9RPQy83voigmzX6R%2FUj27uvTl0XTPy1g8zDepNOtZhkJvl%2BjTdpo54g1zHUbqPNXLOmE7CZkvTTucXbLk07UjTsvJJj6QKUW7qy8ui5zNCoyhRNna6Kq6F%2FXFJqi3smy%2B7OrcPsiJw%2BsuaWIhrbpksvC8HN5O%2FrvMQzoW%2Fnr8n3zecynS6NkdgpmxpbCvR9XHMXnG5mH2zPKoRRSjfPcAqKPTSlzvLcgnN9xZ2sgdHSp%2Bu%2BWj%2F4A&RelayState=59

Unnamed: 0,PATIENT_KEY,DATE,VALUE,ICD_CODE,CPT_CODE,DRUG_NDC,AGE,GENDER_FEMALE
0,003f98c2-c495-581c-867d-1dd364120d29,2016-04-10,,,99233,,58.0,1.0
1,003f98c2-c495-581c-867d-1dd364120d29,2016-04-11,,,99232,,58.0,1.0
2,003f98c2-c495-581c-867d-1dd364120d29,2016-04-12,,,99232,,58.0,1.0
3,003f98c2-c495-581c-867d-1dd364120d29,2016-04-13,,,99232,,58.0,1.0
4,003f98c2-c495-581c-867d-1dd364120d29,2016-04-14,,,99232,,58.0,1.0
...,...,...,...,...,...,...,...,...
14219,ffcc2e7b-e540-506d-84d4-e9ce769d1e50,2020-02-04,,,99214,,66.0,1.0
14220,ffcc2e7b-e540-506d-84d4-e9ce769d1e50,2020-03-16,,F419,,,66.0,1.0
14221,ffcc2e7b-e540-506d-84d4-e9ce769d1e50,2020-10-27,20,,,,66.0,1.0
14222,ffcc2e7b-e540-506d-84d4-e9ce769d1e50,2021-05-25,23,,,,67.0,1.0


In [38]:
len(np.unique(df['PATIENT_KEY']))

1401

In [39]:
sum(df['VALUE'].isna())

12477

In [40]:
sum(df['ICD_CODE'].isna())

6932

In [41]:
sum(df['CPT_CODE'].isna())

6362

In [42]:
sum(df['DRUG_NDC'].isna())

13324

# Comments on data

In [None]:
# 1401 unique patients who have mmse or moca scores
# top 10 of each category codes were taken 
# a lot of patients did not have rx, mx, or cpt data
# claims could occur after score was delivered 

# Preprocessing

In [None]:
# we will forward fill with the same score and codes until it is changed 
# remove patients with unknown score
# removed records rx, dx, and cpt is unknown or age or gender is unknown
# one-hot encode claims and gender features

In [90]:
import preprocessing_time_series
preprocess = preprocessing_time_series.PreprocessingTimeSeries()
df_pp = preprocess.clean(df)
df_pp = preprocess.forward_fill(df_pp, 'VALUE', 'PATIENT_KEY')
df_pp = preprocess.forward_fill(df_pp, 'ICD_CODE', 'PATIENT_KEY')
df_pp = preprocess.forward_fill(df_pp, 'CPT_CODE', 'PATIENT_KEY')
df_pp = preprocess.forward_fill(df_pp, 'DRUG_NDC', 'PATIENT_KEY')
df_pp = df_pp[~df_pp['VALUE'].isna()]
df_pp = preprocess.one_hot(df_pp, 'ICD_CODE')
df_pp = preprocess.one_hot(df_pp, 'CPT_CODE')
df_pp = preprocess.one_hot(df_pp, 'DRUG_NDC')
df_pp['VALUE'] = df_pp['VALUE'].astype(float)
df_pp = preprocess.lag_dv(df_pp, 'VALUE')
df_pp = df_pp.reset_index(drop=True)
df_pp

Unnamed: 0,PATIENT_KEY,DATE,VALUE,ICD_CODE,CPT_CODE,DRUG_NDC,AGE,GENDER_FEMALE,ICD_CODE_F0150,ICD_CODE_F0390,...,DRUG_NDC_00591024010,DRUG_NDC_00591024110,DRUG_NDC_16729013616,DRUG_NDC_16729013716,DRUG_NDC_67253090350,DRUG_NDC_67877024210,DRUG_NDC_69315090410,VALUE_lag1,VALUE_lag2,VALUE_lag3
0,004ce195-f105-5b1c-96f2-4860c44de5f8,2018-12-12,25.0,F329,,,66.0,1.0,0,0,...,0,0,0,0,0,0,0,0.0,0.0,0.0
1,004ce195-f105-5b1c-96f2-4860c44de5f8,2019-01-23,25.0,F329,99214,,67.0,1.0,0,0,...,0,0,0,0,0,0,0,25.0,0.0,0.0
2,0154e0c6-cd73-5308-99bf-9ff83a906761,2019-08-01,15.0,F0390,99214,,80.0,0.0,0,1,...,0,0,0,0,0,0,0,25.0,25.0,0.0
3,0154e0c6-cd73-5308-99bf-9ff83a906761,2019-09-05,15.0,F0390,99213,,80.0,0.0,0,1,...,0,0,0,0,0,0,0,15.0,25.0,25.0
4,0154e0c6-cd73-5308-99bf-9ff83a906761,2019-09-12,15.0,F0390,99213,,80.0,0.0,0,1,...,0,0,0,0,0,0,0,15.0,15.0,25.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2071,fe31016f-dbb5-5668-8bdb-a8bb6093d52e,2019-05-28,27.0,F418,99214,,71.0,1.0,0,0,...,0,0,0,0,0,0,0,27.0,25.0,25.0
2072,fe31016f-dbb5-5668-8bdb-a8bb6093d52e,2019-11-26,27.0,F418,99214,,71.0,1.0,0,0,...,0,0,0,0,0,0,0,27.0,27.0,25.0
2073,fe31016f-dbb5-5668-8bdb-a8bb6093d52e,2020-01-06,25.0,F418,99213,,72.0,1.0,0,0,...,0,0,0,0,0,0,0,27.0,27.0,27.0
2074,fe31016f-dbb5-5668-8bdb-a8bb6093d52e,2021-10-18,25.0,F418,99213,,73.0,1.0,0,0,...,0,0,0,0,0,0,0,25.0,27.0,27.0


In [None]:
# avg mmse measurements
# try lstm 
# train/test split should ne on patients
# multi headed attention

# Some EDA

In [60]:
import plotly.express as px
fig = px.histogram(df_pp, x="VALUE")
fig.show()

In [98]:
df_pp.describe()

Unnamed: 0,VALUE,AGE,GENDER_FEMALE,ICD_CODE_F0150,ICD_CODE_F0390,ICD_CODE_F17210,ICD_CODE_F329,ICD_CODE_F331,ICD_CODE_F411,ICD_CODE_F418,...,DRUG_NDC_00591024010,DRUG_NDC_00591024110,DRUG_NDC_16729013616,DRUG_NDC_16729013716,DRUG_NDC_67253090350,DRUG_NDC_67877024210,DRUG_NDC_69315090410,VALUE_lag1,VALUE_lag2,VALUE_lag3
count,2076.0,2076.0,2076.0,2076.0,2076.0,2076.0,2076.0,2076.0,2076.0,2076.0,...,2076.0,2076.0,2076.0,2076.0,2076.0,2076.0,2076.0,2076.0,2076.0,2076.0
mean,20.424855,71.264933,0.669557,0.046724,0.206647,0.056358,0.197495,0.08237,0.05395,0.096821,...,0.026493,0.019268,0.030829,0.027457,0.015896,0.010116,0.000963,20.411368,20.399326,20.387283
std,6.734313,11.463852,0.470486,0.211099,0.404998,0.230668,0.398205,0.274994,0.225973,0.295785,...,0.160636,0.137498,0.172895,0.163449,0.125103,0.10009,0.031031,6.747162,6.761263,6.775314
min,1.0,27.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,17.0,63.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,16.75,16.0,16.0
50%,22.0,73.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,22.0,22.0,22.0
75%,25.0,81.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,25.0,25.0,25.0
max,30.0,89.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,30.0,30.0,30.0


# Modeling - Rf

In [91]:
import modeling_time_series
models = modeling_time_series.ModelingTimeSeries()
lags = [x for x in df_pp.columns if 'lag' in x]
cols = df_pp.drop(['VALUE', 'PATIENT_KEY', 'DATE', 'ICD_CODE', 'CPT_CODE', 'DRUG_NDC'] + lags, axis=1).columns.tolist()
model_rf = models.rf_regressor(df_pp, cols, 'VALUE')

In [92]:
model_rf['performance']

{'full': {'r2': 0.8412910574223038,
  'rmse': 2.682189136451688,
  'mape': 0.13282823124340884},
 'train': {'r2': 0.8518661238315942,
  'rmse': 2.590571901852798,
  'mape': 0.13055978955053563},
 'test': {'r2': 0.4645409785856254,
  'rmse': 4.936004168994381,
  'mape': 0.24798696014985672}}

In [93]:
model_rf['importance']

Unnamed: 0,feature,importance,std
0,AGE,0.293346,0.021367
3,ICD_CODE_F0390,0.200521,0.018114
1,GENDER_FEMALE,0.064747,0.012238
23,DRUG_NDC_00591024001,0.038863,0.013325
2,ICD_CODE_F0150,0.03789,0.010407
16,CPT_CODE_99214,0.031532,0.008335
27,DRUG_NDC_16729013716,0.02849,0.010033
14,CPT_CODE_99204,0.027373,0.007381
18,CPT_CODE_99232,0.024593,0.013822
21,DRUG_NDC_00093083205,0.022935,0.006009


In [94]:
# scatter plot of predictions
plot_df = pd.DataFrame({'preds': model_rf['df_preds_test']['preds_test'],
                        'actual': model_rf['df_preds_test']['VALUE']})
fig = px.scatter(plot_df, x="preds", y='actual', title="Random Forest - Test Data")
fig.show()

# Modeling - Lasso

In [95]:
import modeling_time_series
models = modeling_time_series.ModelingTimeSeries()
lags = [x for x in df_pp.columns if 'lag' in x]
cols = df_pp.drop(['VALUE', 'PATIENT_KEY', 'DATE', 'ICD_CODE', 'CPT_CODE', 'DRUG_NDC'] + lags, axis=1).columns.tolist()
model_lasso = models.lasso(df_pp, cols, 'VALUE')

In [96]:
model_lasso['performance']

{'full': {'r2': 0.2614546757805669,
  'rmse': 5.785984642637888,
  'mape': 0.5719214014920246},
 'train': {'r2': 0.3122235463195666,
  'rmse': 5.582029263882624,
  'mape': 0.5428908330186553},
 'test': {'r2': 0.2283259710682709,
  'rmse': 5.925559180282234,
  'mape': 0.46597077476046567}}

In [97]:
model_lasso['importance']

Unnamed: 0,features,coefficients
0,Intercept,28.021661
14,CPT_CODE_90837,1.671837
9,ICD_CODE_F418,0.431165
10,ICD_CODE_F419,0.397268
11,ICD_CODE_F4310,0.182489
17,CPT_CODE_99214,0.022019
30,DRUG_NDC_67877024210,0.0
29,DRUG_NDC_67253090350,0.0
28,DRUG_NDC_16729013716,-0.0
27,DRUG_NDC_16729013616,-0.0
