In [2]:
import snowflake.connector
from snowflake.snowpark.functions import count, when, col, corr, array_cat, sum as sum_
import os
from snowflake.snowpark import Session
import pandas as pd

# Connection settings for the Snowpark session.
# Every value can be overridden through an environment variable so the notebook
# is not tied to a single analyst's account; when a variable is unset, the
# default reproduces the original hard-coded configuration.
# Authentication uses the 'externalbrowser' (SSO) flow, so no password is kept here.
cnxn_params = {
    "user": os.environ.get("SNOWFLAKE_USER", 'hartsingh'),
    "authenticator": os.environ.get("SNOWFLAKE_AUTHENTICATOR", 'externalbrowser'),
    "account": os.environ.get("SNOWFLAKE_ACCOUNT", 'vaa16628'),
    "region": os.environ.get("SNOWFLAKE_REGION", 'us-east-1'),
    "warehouse": os.environ.get("SNOWFLAKE_WAREHOUSE", "ACORN_CDS_PROD_COMMERCIALANALYTICS"),
    "database": os.environ.get("SNOWFLAKE_DATABASE", "ACORN_CDS_PROD_BIOXCEL_ADHOC"),
    "role": os.environ.get("SNOWFLAKE_ROLE", 'ACORN_CDS_PRD_CA_APP'),
}

# Opens the session; with the external-browser authenticator this triggers the
# identity-provider login prompt shown in the cell output.
session = Session.builder.configs(cnxn_params).create()

Initiating login request with your identity provider. A browser window should have opened for you to complete the login. If you can't see it, check existing browser windows, or your OS settings. Press CTRL+C to abort and try again...
Going to open: https://mdsol.okta.com/app/snowflake/exkugim5j5lsfZiBH0x7/sso/saml?SAMLRequest=lZJdb9owGIX%2FSuRdJ3YCodQCqrQIFY0yxMdUceclDnhx7Myvk9B%2FX4ePqbtopd1Fzjl%2Bjt%2Fzjh5OpfQabkBoNUZhQJDHVaozoQ5jtNvO%2FCHywDKVMakVH6M3DuhhMgJWyoomtT2qNf9Tc7Ceu0gB7X6MUW0U1QwEUMVKDtSmdJO8LGgUEMoAuLEOh66WDIRjHa2tKMZt2wZtL9DmgCNCCCb32Kk6yTf0AVF9zaiMtjrV8mY5uTd9gggx6XcIp3CE1dX4KNRlBF9Rfl1EQJ%2B325W%2F%2BrHZIi%2B5ve5JK6hLbjbcNCLlu%2FXiEgBcgoaxcDCIhkENPmdg%2FTAApdtcsoKnuqxq664N3BfOeYalPgg3rPl0jKpCZGShWCH3zcscjIzU8LW%2FaU51tGzy%2FX2zjHfF8vjdvq6LeJakyPt5qzbqqp0D1HyuukKtOyJRzyd9n%2FS2YUx7MQ0HQRTe7ZE3dYUKxezZeUtdZqBloAvLztFYVeG%2FqTE%2FFfVBlPHvWEK%2BF4%2FP5HSHATTu2kKXhaFnvJn89xhG%2BKP9unxL18d8utJSpG%2FeTJuS2c%2FrCoPwfCIyPz9LKS%2BZkEmWGQ7gapNSt0%2BGM%2Bt23JqaIzy5UP%2Fd8sk7&Relay

In [27]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Objective

In [None]:
# Accurately estimating MMSE can help fill in scores for patients who do not have one.
# The score helps us estimate the severity of their condition.
# Lasso can be a good fit here since we also want to see the directionality of the features.

# Pull raw data from snowflake

In [10]:
# Build one long-format table: one cognitive score per patient plus counts of
# that patient's most common Rx / procedure / diagnosis codes.
#
# CTEs, in order:
#   scores             - EHR result rows for LOINC codes '72107-6', '72133-2', '72172-0'
#                        whose value parses as a number <= 30.
#   scores2            - one row per patient: the row ranked 1 by result_date desc.
#                        row_number() gives no tie-breaker, so for several results on
#                        the same date the kept row is arbitrary.
#   top10_ndc          - 10 most frequent drug_ndc values among those patients (stg_rx).
#   top10_cpt          - 10 most frequent non-empty PROCEDURE_ARR values (stg_mx).
#   top10_dx           - 10 most frequent single-element DIAGNOSIS_CODE_ARR values (stg_dx).
#   scores2_dx/cpt/rx  - per patient and code, the number of matching rows restricted to
#                        the top-10 codes. The "date < result_date" filters are commented
#                        out, so rows dated after the score are counted as well.
#   scores2_rx_dx_mx   - left-joins the three count tables onto scores2 by patient_key.
#
# NOTE(review): each count table can hold several rows per patient, so the final
# join yields every rx x cpt x dx combination for a patient and repeats the *_FREQ
# values across rows. Confirm the downstream pivot does not sum them.
# NOTE(review): the string is an f-string but contains no placeholders.
df_sql = session.sql(f"""
                     with scores as(
                            select distinct patient_key, LOINC, panel, test, value, result_date
                            from dbo.TBLDRG_EHR_RESULTS 
                            where LOINC in ('72107-6', '72133-2', '72172-0') and try_to_number(value) is not NULL and try_to_number(value) <= 30
                            order by patient_key
                         ),
                          scores2 as(
                            select t.*
                            from(
                              select *, row_number() over(partition by patient_key order by result_date desc) as date_rank
                              from scores
                            ) t
                            where t.date_rank = 1
                            order by t.patient_key
                         ),
                          top10_ndc as(
                            select top 10 drug_ndc, count(*) as freq
                            from processing_cns.stg_rx
                            where patient_key in (select patient_key from scores2)
                            group by drug_ndc
                            order by 2 desc
                         ),
                          top10_cpt as(
                            select top 10 procedure_arr, count(*) as freq
                            from processing_cns.stg_mx
                            where patient_key in (select patient_key from scores2) and PROCEDURE_ARR != []
                            group by procedure_arr
                            order by 2 desc
                         ),
                          top10_dx as(
                            select top 10 DIAGNOSIS_CODE_ARR, count(*) as freq
                            from processing_cns.stg_dx
                            where patient_key in (select patient_key from scores2) and array_size(DIAGNOSIS_CODE_ARR) = 1
                            group by DIAGNOSIS_CODE_ARR
                            order by 2 desc
                         ),
                          scores2_dx as(
                            select t.patient_key,
                                   t.icd_code,
                                   count(t.date_dx) as dx_freq
                            from(
                              select a.*, b.icd_code, case when b.date_dx is NULL then '2222-01-01' else b.date_dx end as date_dx
                              from scores2 a
                              left join(
                                select patient_key,
                                       array_to_string(DIAGNOSIS_CODE_ARR,'') as icd_code,
                                       array_to_string(YEAR_OF_SERVICE_ARR,'') as date_dx
                                from processing_cns.stg_dx
                                where DIAGNOSIS_CODE_ARR in (select DIAGNOSIS_CODE_ARR from top10_dx)
                              ) b on a.patient_key = b.patient_key 
                            ) t
                            --where t.date_dx < t.result_date
                            where t.icd_code is not NULL
                            group by t.patient_key, t.icd_code
                         ),
                          scores2_cpt as(
                            select t.patient_key,
                                   t.cpt_code,
                                   count(t.date_cpt) as cpt_freq
                            from(
                              select a.*, b.cpt_code, case when b.date_cpt is NULL then '2222-01-01' else b.date_cpt end as date_cpt
                              from scores2 a
                              left join(
                                select patient_key,
                                       array_to_string(PROCEDURE_ARR,'') as cpt_code,
                                       array_to_string(PX_YEAR_OF_SERVICE_ARR,'') as date_cpt
                                from processing_cns.stg_mx
                                where PROCEDURE_ARR in (select PROCEDURE_ARR from top10_cpt)
                              ) b on a.patient_key = b.patient_key
                            ) t
                            --where t.date_cpt < t.result_date
                            where t.cpt_code is not NULL
                            group by t.patient_key, t.cpt_code
                         ),
                          scores2_rx as(
                            select t.patient_key,
                                   t.drug_ndc,
                                   count(t.date_rx) as ndc_freq
                            from(
                              select a.*, b.drug_ndc, case when b.date_rx is NULL then '2222-01-01' else b.date_rx end as date_rx
                              from scores2 a
                              left join(
                                select patient_key,drug_ndc,date_of_service as date_rx
                                from processing_cns.stg_rx
                                where drug_ndc in (select drug_ndc from top10_ndc)
                              ) b on a.patient_key = b.patient_key
                            ) t
                            --where t.date_rx < t.result_date
                            where t.drug_ndc is not NULL
                            group by t.patient_key, t.drug_ndc
                         ),
                          scores2_rx_dx_mx as(
                            select a.patient_key, a.value, b.drug_ndc, b.ndc_freq, c.cpt_code, c.cpt_freq, d.icd_code, d.dx_freq
                            from scores2 a
                            left join scores2_rx b on a.patient_key = b.patient_key
                            left join scores2_cpt c on a.patient_key = c.patient_key
                            left join scores2_dx d on a.patient_key = d.patient_key
                            order by a.patient_key
                         )
                     select * from scores2_rx_dx_mx;
                     """)
# Execute the query and load the full result into a pandas DataFrame.
df = df_sql.to_pandas()
# Display the raw long-format result.
df

Initiating login request with your identity provider. A browser window should have opened for you to complete the login. If you can't see it, check existing browser windows, or your OS settings. Press CTRL+C to abort and try again...
Going to open: https://mdsol.okta.com/app/snowflake/exkugim5j5lsfZiBH0x7/sso/saml?SAMLRequest=lZJfb9owFMW%2FSuQ9J3bCn4EFVGlRVSS6ZRD60DcvMeDFsVNfh8A%2B%2FZwEpu6hlfYWOef4d3zPnd2dS%2BmduAGh1RyFAUEeV5nOhTrM0S599CfIA8tUzqRWfI4uHNDdYgaslBWNa3tUG%2F5Wc7Ceu0gBbX%2FMUW0U1QwEUMVKDtRmdBs%2Fr2kUEMoAuLEOh66WHIRjHa2tKMZN0wTNINDmgCNCCCZT7FSt5At6h6g%2BZ1RGW51pebOc3Zs%2BQISYDFuEUzhCcjXeC9WP4DPKz14E9ClNEz%2F5vk2RF99e96AV1CU3W25OIuO7zboPAC7BibFwPI4mQQ0%2BZ2D9MAClm71kBc90WdXWXRu4L7znOZb6INywVss5qgqRR9v1dDNM3i7H58GxCnfTalnoHxCHo2x5%2Fg2FfLmsk3FYizzLkPdyqzZqq10B1Hyl2kKtOyLRwCdDnwzSiFBC6GgSTKPwFXlLV6hQzHbOW%2BoyBy0DXVjWRWNVhf%2Bmxvxc1AdRjn6NJOxfxf0TOX%2FFABq3baF%2BYWiHN4v%2FHsMMv7dfl%2B%2Bb62O1TLQU2cV71KZk9uO6wiDsTkTu7zsp5SUTMs5zwwFcbVLq5sFwZt2OW1NzhBc99d8tX%2FwB&RelayState=5938

Unnamed: 0,PATIENT_KEY,VALUE,DRUG_NDC,NDC_FREQ,CPT_CODE,CPT_FREQ,ICD_CODE,DX_FREQ
0,003f98c2-c495-581c-867d-1dd364120d29,27,,,,,,
1,004ce195-f105-5b1c-96f2-4860c44de5f8,25,,,,,,
2,00c98c5a-4705-578a-9bd9-069365a87d16,23,,,99214,1.0,F0151,27.0
3,00c98c5a-4705-578a-9bd9-069365a87d16,23,,,99213,5.0,F0151,27.0
4,01228727-7926-51c3-bb02-2d388430e4e7,25,,,,,,
...,...,...,...,...,...,...,...,...
2488,ff2396b1-71b9-5edb-9623-8f50745cf743,24,,,,,F0390,2.0
2489,ff41d362-a0fd-5134-a332-3f6e8432e74c,29,,,,,,
2490,ff792ca2-fb00-5dd1-9527-7372c189f450,26,,,,,,
2491,ffc343eb-a942-508b-84f8-537d4de1aae4,29,,,,,,


# Comments on data

In [None]:
# 1401 unique patients have MMSE or MoCA scores
# 9 patients have multiple scores on the same day - one is picked arbitrarily
# the latest score was taken for each patient
# the top 10 codes of each category were taken
# a lot of patients did not have rx, mx, or cpt data
# claims could occur after the score was delivered - this shouldn't matter much since we aren't trying to predict future scores

# Preprocessing

In [95]:
import preprocessing

# Reshape each code family (procedure, diagnosis, drug) from the long query
# result into its own wide frame keyed by PATIENT_KEY, using the project helper.
preprocess = preprocessing.Preprocessing()
df_cpt, df_icd, df_ndc = (
    preprocess.long_to_wide(df, 'PATIENT_KEY', code_col, freq_col, prefix)
    for code_col, freq_col, prefix in (
        ('CPT_CODE', 'CPT_FREQ', 'cpt'),
        ('ICD_CODE', 'DX_FREQ', 'icd'),
        ('DRUG_NDC', 'NDC_FREQ', 'ndc'),
    )
)

# One row per patient with the target score, then attach the three wide frames.
# Patients without a given code end up with NaN, which is filled with 0.
df_score = df[['PATIENT_KEY','VALUE']].drop_duplicates()
df_pp = (
    df_score
    .merge(df_cpt, on='PATIENT_KEY', how='left')
    .merge(df_icd, on='PATIENT_KEY', how='left')
    .merge(df_ndc, on='PATIENT_KEY', how='left')
    .astype({'VALUE': float})
    .fillna(0)
)

# make binary: any frequency of 1 or more collapses to 1
df_pp1 = df_pp[['PATIENT_KEY','VALUE']]
df_pp2 = df_pp.drop(columns=['PATIENT_KEY','VALUE'])
df_pp2 = df_pp2.mask(df_pp2 >= 1, 1)
df_pp = pd.concat([df_pp1, df_pp2], axis=1)
df_pp


Unnamed: 0,PATIENT_KEY,VALUE,cpt_99213,cpt_99214,cpt_99232,cpt_99233,cpt_99285,cpt_99308,cpt_99309,cpt_99336,...,ndc_13668010310,ndc_29300017205,ndc_29300017216,ndc_31722073805,ndc_33342029709,ndc_43547027503,ndc_43547027509,ndc_43547027603,ndc_43547027609,ndc_43547027611
0,003f98c2-c495-581c-867d-1dd364120d29,27.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1,004ce195-f105-5b1c-96f2-4860c44de5f8,25.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
2,00c98c5a-4705-578a-9bd9-069365a87d16,23.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
3,01228727-7926-51c3-bb02-2d388430e4e7,25.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
4,01381c39-92ad-5256-9a1b-79caed70b019,14.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1396,ff2396b1-71b9-5edb-9623-8f50745cf743,24.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1397,ff41d362-a0fd-5134-a332-3f6e8432e74c,29.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1398,ff792ca2-fb00-5dd1-9527-7372c189f450,26.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
1399,ffc343eb-a942-508b-84f8-537d4de1aae4,29.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0


# Some EDA

In [46]:
import plotly.express as px

# Distribution of the modeling target: the latest score kept per patient.
# A title and a descriptive axis label are set so the figure can be read on
# its own when the notebook is skimmed.
fig = px.histogram(
    df_pp,
    x="VALUE",
    title="Distribution of latest cognitive score per patient",
    labels={"VALUE": "Score (latest MMSE / MoCA result)"},
)
fig.show()

In [96]:
# Summary statistics for the target (VALUE) and the 0/1 indicator columns.
df_pp.describe()

Unnamed: 0,VALUE,cpt_99213,cpt_99214,cpt_99232,cpt_99233,cpt_99285,cpt_99308,cpt_99309,cpt_99336,cpt_H2032,...,ndc_13668010310,ndc_29300017205,ndc_29300017216,ndc_31722073805,ndc_33342029709,ndc_43547027503,ndc_43547027509,ndc_43547027603,ndc_43547027609,ndc_43547027611
count,1401.0,1401.0,1401.0,1401.0,1401.0,1401.0,1401.0,1401.0,1401.0,1401.0,...,1401.0,1401.0,1401.0,1401.0,1401.0,1401.0,1401.0,1401.0,1401.0,1401.0
mean,21.227695,0.095646,0.123483,0.025696,0.015703,0.038544,0.013562,0.017844,0.007852,0.000714,...,0.007138,0.009279,0.009279,0.007138,0.006424,0.009279,0.01142,0.007852,0.013562,0.017131
std,6.357019,0.29421,0.329109,0.158283,0.124368,0.192574,0.115704,0.132433,0.088292,0.026717,...,0.084213,0.095914,0.095914,0.084213,0.07992,0.095914,0.106292,0.088292,0.115704,0.129804
min,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
25%,18.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
50%,23.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
75%,26.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
max,30.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,...,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0


# Modeling - Lasso

In [97]:
import modeling

# Instantiate the project's modeling helper.
models = modeling.Modeling()

# Features are every column except the target and the patient identifier.
cols = df_pp.columns.drop(['VALUE', 'PATIENT_KEY']).tolist()

# Fit the lasso helper with VALUE as the dependent variable.
model_lasso = models.lasso_regression(df_pp, cols, 'VALUE')

In [98]:
# Lasso fit metrics (r2, rmse, mape) for the train and test splits.
model_lasso['performance']

{'train': {'r2': 0.22275497436392855,
  'rmse': 5.641198920720199,
  'mape': 0.35720047000453486},
 'test': {'r2': 0.19813660694367285,
  'rmse': 5.510436524394024,
  'mape': 0.3504713275799618}}

In [99]:
# Fitted lasso coefficients, one per feature.
# NOTE(review): presumably in the same order as `cols` - confirm in modeling.py.
model_lasso['model'].coef_

array([-1.84632555, -1.61818728, -0.        , -0.        , -0.50315553,
       -1.77083516, -0.40823594, -0.        ,  0.        , -0.        ,
       -0.        , -1.37475107, -1.80246965,  0.        , -2.65903964,
       -2.71763906, -0.        , -2.86498942, -2.91584668, -0.        ,
       -0.        , -0.        , -1.29803586, -2.4171191 , -0.        ,
       -0.        ,  0.        , -0.        , -2.83438866,  0.        ])

In [100]:
# Fitted intercept of the lasso model.
model_lasso['model'].intercept_

22.496981180948264

In [101]:
# Largest prediction on the training set. In the recorded run it equals the
# intercept: no coefficient is positive and the features are 0/1 indicators,
# so the model cannot predict above the intercept.
max(model_lasso['df_preds_train']['preds_train'])

22.496981180948264

# Random Forest

In [102]:
# Call the random-forest helper with the same frame, feature list and target
# that were passed to the lasso helper.
model_rf = models.random_forest_regression(df_pp, cols, 'VALUE')

In [103]:
# Metrics reported by the random-forest helper (r2 and rmse).
# NOTE(review): unlike the lasso result this is not split into train/test -
# confirm which data these metrics are computed on.
model_rf['performance']

{'r2': 0.19630374704440906, 'rmse': 5.699432658779989}

In [104]:
# Per-feature importance table (with a std column) returned by the random-forest helper.
model_rf['importance']

Unnamed: 0,feature,importance,std
1,cpt_99214,0.179421,0.095403
14,icd_F0390,0.167255,0.094145
0,cpt_99213,0.125569,0.092132
18,icd_G309,0.067879,0.029194
11,icd_F0151,0.053105,0.019936
15,icd_F0391,0.042684,0.026912
12,icd_F0280,0.042262,0.01608
10,icd_F0150,0.036779,0.019298
4,cpt_99285,0.028524,0.01174
13,icd_F0281,0.025391,0.010595


# Next steps

In [None]:
# include age and gender as features
# split the dependent variable into classes - low, medium, high severity
# troubleshoot lasso - why are all coefs negative
# make features binary
# build a model based on rx, dx, cpt, mmse, demo data to predict a patient's likelihood of being prescribed an Alzheimer's treatment
# pca all features