
# Step 1 - Data Generation

Run the Setup notebook first to generate the synthetic data, if it has not been generated already.

# Step 2 - Data Governance

### Cells 3-8 show how to mask sensitive fields so that only users in certain user groups will be able to see the information.

<img src="https://raw.githubusercontent.com/databricks-demos/dbdemos-resources/main/images/hls/patient-readmission/hls-patient-readmision-flow-2.png" style="float: right; margin-left: 10px; margin-top:10px" width="650px" />

In [0]:
# Create the notebook widgets used to choose the source and destination catalog / schema names.
# Each widget is a free-text field whose default value is the empty string.

dbutils.widgets.text(name = "source_catalog_name", defaultValue="", label="1. Source Catalog Name")
dbutils.widgets.text(name = "source_schema_name", defaultValue="", label="2. Source Schema Name")
dbutils.widgets.text(name = "destination_catalog_name", defaultValue="", label="3. Destination Catalog Name")
dbutils.widgets.text(name = "destination_schema_name", defaultValue="", label="4. Destination Schema Name")

In [0]:
# Read the current value of each widget into a Python variable of the same name.
source_catalog_name = dbutils.widgets.get("source_catalog_name")
source_schema_name = dbutils.widgets.get("source_schema_name")
destination_catalog_name = dbutils.widgets.get("destination_catalog_name")
destination_schema_name = dbutils.widgets.get("destination_schema_name")

In [0]:
# Make sure the destination catalog and schema exist before anything is written to them.
# The previous version referenced `catalog_name`, which is never defined in this notebook;
# the destination names read from the widgets are the ones later used as a write target.
spark.sql(f'CREATE CATALOG IF NOT EXISTS {destination_catalog_name}')
spark.sql(f'CREATE SCHEMA IF NOT EXISTS {destination_catalog_name}.{destination_schema_name}')


In [0]:
%sql
-- Make the source catalog and schema (taken from the notebook widgets) the session defaults,
-- so that unqualified table names in the following cells resolve against them.
USE CATALOG IDENTIFIER(:source_catalog_name);
USE SCHEMA IDENTIFIER(:source_schema_name);
-- List the tables of the schema that was just selected.
SHOW TABLES;

In [0]:
%sql
-- Let's grant our ANALYSTS a SELECT permission on the patients table:
-- Note: make sure the `analysts` and `dataengineers` groups have been created first.
GRANT SELECT ON TABLE patients TO `analysts`;
-- GRANT SELECT ON TABLE condition_occurrence TO `analysts`;
-- GRANT SELECT ON TABLE encounters TO `analysts`;

-- We'll grant an extra MODIFY to our Data Engineer
-- GRANT SELECT, MODIFY ON SCHEMA dbdemos_hls_readmission TO `dataengineers`;

In [0]:
# Copy the `patients` table (unqualified, so resolved in the current catalog/schema) into
# <destination catalog>.<destination schema>.protected_patients, replacing any existing table of that name.
spark.sql(f"CREATE OR REPLACE TABLE {destination_catalog_name}.{destination_schema_name}.protected_patients AS SELECT * FROM patients;")

In [0]:
%sql
-- hls_admin group will have access to all data, all other users will see masks over the sensitive fields.
-- The function is created in the current (source) schema; it returns the value unchanged for members of
-- the `hls_admin` account group and the literal '****' for everybody else.
CREATE OR REPLACE FUNCTION simple_mask(column_value STRING)
   RETURN IF(is_account_group_member('hls_admin'), column_value, '****');

-- ALTER FUNCTION simple_mask OWNER TO `account users`; -- grant access to all user to the function for the demo - don't do it in production

-- Mask all PII information.
-- protected_patients lives in the DESTINATION catalog/schema, so it is addressed by its fully qualified
-- name built from the widget values; an unqualified name would be looked up in the source schema.
ALTER TABLE IDENTIFIER(:destination_catalog_name || '.' || :destination_schema_name || '.protected_patients') ALTER COLUMN FIRST SET MASK simple_mask;
ALTER TABLE IDENTIFIER(:destination_catalog_name || '.' || :destination_schema_name || '.protected_patients') ALTER COLUMN LAST SET MASK simple_mask;
ALTER TABLE IDENTIFIER(:destination_catalog_name || '.' || :destination_schema_name || '.protected_patients') ALTER COLUMN PASSPORT SET MASK simple_mask;
ALTER TABLE IDENTIFIER(:destination_catalog_name || '.' || :destination_schema_name || '.protected_patients') ALTER COLUMN DRIVERS SET MASK simple_mask;
ALTER TABLE IDENTIFIER(:destination_catalog_name || '.' || :destination_schema_name || '.protected_patients') ALTER COLUMN SSN SET MASK simple_mask;
ALTER TABLE IDENTIFIER(:destination_catalog_name || '.' || :destination_schema_name || '.protected_patients') ALTER COLUMN ADDRESS SET MASK simple_mask;

-- Show the table as seen by the current user (masked unless the user is in hls_admin).
SELECT * FROM IDENTIFIER(:destination_catalog_name || '.' || :destination_schema_name || '.protected_patients')

In [0]:
# %sql
# CREATE SHARE IF NOT EXISTS mcutini_diabetes_readmissions_share
#   COMMENT 'Share the patients table in mcutini.arizona_patients for diabetes research.';

# ALTER SHARE mcutini_diabetes_readmissions_share OWNER TO `HOOLI`;

# ALTER SHARE mcutini_diabetes_readmissions_share ADD TABLE patients WITH HISTORY;

In [0]:
# %sql
# DESCRIBE SHARE mcutini_diabetes_readmissions_share;

# Step 3 - Data Analysis

### Cells 10-35 show the experience of analyzing your data using Python and SQL.

In [0]:
%sql
-- Browse the raw patients table.
SELECT *
FROM patients

In [0]:
%sql
-- Number of patients per gender and 5-year age bucket (age in years, rounded up to a multiple of 5).
-- Can use the built-in visualization (Area: Key: age, Group: GENDER, Values: count)
SELECT
  GENDER,
  CEIL(MONTHS_BETWEEN(CURRENT_DATE(), BIRTHDATE) / 12 / 5) * 5 AS age,
  COUNT(*) AS count
FROM patients
GROUP BY GENDER, age
ORDER BY age

In [0]:
import plotly.express as px
# Area chart of count by age, one series per GENDER.
# NOTE(review): `_sqldf` is not assigned anywhere in this notebook's Python code; presumably it is the
# implicit DataFrame holding the result of the preceding %sql cell — confirm before reordering cells.
px.area(_sqldf.toPandas(), x="age", y="count", color="GENDER", line_group="GENDER")

In [0]:
# We can also leverage pure Python to access data
from pyspark.sql.functions import col, desc

# The 20 most frequent (gender, condition description) pairs, collected to pandas for plotting.
df = (
    spark.table("patients")
    .join(spark.table("conditions"), col("Id") == col("PATIENT"))
    .groupBy(['GENDER', 'conditions.DESCRIPTION'])
    .count()
    .orderBy(desc('count'))
    .limit(20)
    .toPandas()
)

# And use our usual plot libraries
px.bar(df, x="DESCRIPTION", y="count", color="GENDER", barmode="group")

## Cohort Definition

Let's define a cohort that we can do analysis on.

In [0]:
%sql
-- (Re)create the empty cohort table in the current schema.
-- Each row links one patient to a named cohort together with a start and end date.
CREATE OR REPLACE TABLE cohort (
  id INT,
  name STRING,
  patient STRING,
  cohort_start_date DATE,
  cohort_end_date DATE
);
-- Hand ownership of the table to the `account users` principal.
ALTER TABLE cohort OWNER TO `account users`;

In [0]:
import random
from pyspark.sql.functions import lit
def create_save_cohort(name, condition_codes = []):
  """Append a cohort called `name` to the `cohort` table.

  Rows are taken from the `conditions` table (patient, start date, stop date). Every row of
  one call receives the same randomly drawn integer `id` and the given `name`. When
  `condition_codes` is non-empty, only conditions whose CODE is in that list are kept;
  otherwise all conditions are included.
  """
  members = spark.sql('select patient, to_date(start) as cohort_start_date, to_date(stop) as cohort_end_date from conditions')
  cohort_id = random.randint(999999, 99999999)
  members = members.withColumn('id', lit(cohort_id)).withColumn('name', lit(name))
  if len(condition_codes) > 0:
    # CODE is not part of the projection above; it is referenced by name on the filter only.
    members = members.where(col('CODE').isin(condition_codes))
  members.write.mode("append").saveAsTable('cohort')

#Create cohorts based on patient condition (for ex: 840539006 is COVID)
# Each call appends its rows to the `cohort` table; the last call passes no codes and therefore keeps every condition row.
create_save_cohort('COVID-19-cohort', [840539006])
create_save_cohort('heart-condition-cohort', [1505002, 32485007, 305351004, 76464004])
create_save_cohort('all_patients')

# Step 4 - Data Science and Machine Learning

### Cells 38-50 show how to do data science on your data.


<img src="https://raw.githubusercontent.com/databricks-demos/dbdemos-resources/main/images/hls/patient-readmission/hls-patient-readmision-flow-4.png" style="float: right; margin-left: 30px; margin-top:10px" width="650px" />

We have cleaned and secured our data. We have also created cohorts of patients to analyze. We can now predict 30-day readmissions.

In [0]:
%pip install mlflow==3.1.1 #2.19.0
# Restart the Python process after the install.
# NOTE(review): the restart presumably discards Python variables and imports created by earlier cells —
# confirm that the cells below re-create everything they rely on.
dbutils.library.restartPython()

## Feature engineering

In [0]:
# Let's create our label: we'll predict the 30 days readmission risk
from pyspark.sql import Window
import pyspark.sql.functions as F

# Encounters of one patient, ordered by start time, so that lag() yields the previous encounter.
windowSpec = Window.partitionBy("PATIENT").orderBy("START")

# Seconds elapsed between the STOP of the patient's previous encounter and the START of this one.
# `to_timestamp` is accessed through `F`: it is not imported under its bare name anywhere in this notebook.
# For a patient's first encounter lag() is NULL, the comparison below is NULL and the label falls back to 0.
seconds_since_previous_stop = (
    F.to_timestamp(F.col('START')).cast('long')
    - F.lag(F.to_timestamp(F.col('STOP'))).over(windowSpec).cast('long')
)

labels = (
    spark.table('encounters')
    .select("PATIENT", "Id", "START", "STOP")
    .withColumn('30_DAY_READMISSION', F.when(seconds_since_previous_stop < 30*24*60*60, 1).otherwise(0))
)
display(labels)

Databricks data profile. Run in Databricks to view.

In [0]:
import pyspark.pandas as ps

# Define Patient Features logic
def compute_pat_features(data):
  """One-hot encode the patient categorical columns.

  `data` is converted to a pandas-on-Spark frame, the MARITAL, RACE, ETHNICITY and GENDER
  columns are replaced by int64 dummy columns, and the result is returned as a Spark DataFrame.
  """
  categorical_columns = ['MARITAL', 'RACE', 'ETHNICITY', 'GENDER']
  encoded = ps.get_dummies(data.pandas_api(), columns=categorical_columns, dtype = 'int64')
  return encoded.to_spark()

In [0]:
# Name of the cohort (as stored in the `cohort` table) to build features for.
cohort_name = 'all_patients' #or could be 'COVID-19-cohort'
# Patients belonging to that cohort, one row per patient id.
cohort = spark.sql(f"SELECT p.* FROM cohort c INNER JOIN patients p on c.patient=p.id WHERE c.name='{cohort_name}'").dropDuplicates(["id"])
# One-hot encode the patients' categorical columns and show the result.
cohort_features_df = compute_pat_features(cohort)
cohort_features_df.display()

In [0]:
from pyspark.sql.functions import col

def compute_enc_features(data):
  """Build the per-encounter feature set.

  Keeps one row per encounter Id, adds `enc_length` (stop minus start, in seconds),
  one-hot encodes ENCOUNTERCLASS into int64 columns and returns a Spark DataFrame holding
  the encounter id (renamed to ENCOUNTER_ID), the cost columns, `enc_length` and five
  ENCOUNTERCLASS_* dummy columns.
  """
  encounters = (
    data
    .dropDuplicates(["Id"])
    .withColumn('enc_length', F.unix_timestamp(col('stop'))- F.unix_timestamp(col('start')))
  )
  encoded = ps.get_dummies(encounters.pandas_api(), columns=['ENCOUNTERCLASS'],dtype = 'int64').to_spark()

  feature_columns = [
    'BASE_ENCOUNTER_COST',
    'TOTAL_CLAIM_COST',
    'PAYER_COVERAGE',
    'enc_length',
    'ENCOUNTERCLASS_ambulatory',
    'ENCOUNTERCLASS_emergency',
    'ENCOUNTERCLASS_inpatient',
    'ENCOUNTERCLASS_outpatient',
    'ENCOUNTERCLASS_wellness',
  ]
  return encoded.select(col('Id').alias('ENCOUNTER_ID'), *feature_columns)

# Compute the encounter features for the whole `encounters` table and preview them.
enc_features_df = compute_enc_features(spark.table('encounters'))
display(enc_features_df)

In [0]:
enc_features_df = compute_enc_features(spark.table('encounters'))

# Patient features -> labels (one row per encounter) -> encounter features.
# Identifier, technical and directly identifying columns are removed after the joins.
training_dataset = (
    cohort_features_df
    .join(labels, [labels.PATIENT==cohort_features_df.Id], "inner")
    .join(enc_features_df, [labels.Id==enc_features_df.ENCOUNTER_ID], "inner")
    .drop("Id", "_rescued_data", "SSN", "DRIVERS", "PASSPORT", "FIRST", "LAST", "ADDRESS", "BIRTHPLACE")
)

# Adding extra feature such as patient age at encounter (in years, from the day difference)
training_dataset = (
    training_dataset
    .withColumnRenamed("PATIENT", "patient_id")
    .withColumn("age_at_encounter", ((F.datediff(col('START'), col('BIRTHDATE'))) / 365.25))
)

# Persist the dataset and display what was written.
training_dataset.write.mode('overwrite').saveAsTable("training_dataset")
display(spark.table("training_dataset"))

## AutoML

In [0]:
# Columns handed to AutoML, in order: patient features, encounter features, derived feature, label.
feature_names = [
    # Patient features (one-hot encoded demographics plus income)
    'MARITAL_M', 'MARITAL_S',
    'RACE_asian', 'RACE_black', 'RACE_white',
    'ETHNICITY_hispanic', 'ETHNICITY_nonhispanic',
    'GENDER_F', 'GENDER_M',
    'INCOME',
    # Encounter features
    'BASE_ENCOUNTER_COST', 'TOTAL_CLAIM_COST', 'PAYER_COVERAGE', 'enc_length',
    'ENCOUNTERCLASS_ambulatory', 'ENCOUNTERCLASS_emergency', 'ENCOUNTERCLASS_inpatient',
    'ENCOUNTERCLASS_outpatient', 'ENCOUNTERCLASS_wellness',
    # Derived feature
    'age_at_encounter',
    # Label column
    '30_DAY_READMISSION',
]

In [0]:
def set_experiment_permission(experiment_path):
    """Give the `users` group CAN_MANAGE permission on the experiment stored at `experiment_path`.

    The workspace object behind the path is looked up first and its object id is used for the
    permission call. This is best-effort: any exception raised by the lookup or the permission
    call is printed and swallowed, and the success message is printed only when both calls
    completed without raising.

    Args:
        experiment_path: workspace path of the MLflow experiment.
    """
    from databricks.sdk import WorkspaceClient
    from databricks.sdk.service import iam
    w = WorkspaceClient()
    try:
        status = w.workspace.get_status(experiment_path)
        w.permissions.set("experiments", request_object_id=status.object_id, access_control_list=[
                              iam.AccessControlRequest(group_name="users", permission_level=iam.PermissionLevel.CAN_MANAGE)])
    except Exception as e:
        print(f"error setting up shared experiment {experiment_path} permission: {e}")
    else:
        # Only claim success when no error was reported above.
        print(f"Experiment on {experiment_path} was set public")

In [0]:
import mlflow
from datetime import datetime

# `catalog_name` / `schema_name` are used by this cell and by the model registration, scoring and
# serving cells below, but were never assigned anywhere in the notebook. They are read from the
# widgets here, which also works after the Python restart triggered by the %pip cell.
# NOTE(review): the destination catalog/schema is assumed to be where the model should be registered — confirm.
catalog_name = dbutils.widgets.get("destination_catalog_name")
schema_name = dbutils.widgets.get("destination_schema_name")

model_name = f"{catalog_name}_diabetes_readmissions"
# NOTE(review): hard-coded personal workspace folder; other users presumably need to change it.
xp_path = "/Workspace/Users/matt.cutini@databricks.com/diabetes_readmission"
xp_name = f"automl_churn_{datetime.now().strftime('%Y-%m-%d_%H:%M:%S')}"
try:
    from databricks import automl
    # Train a classifier on the selected feature columns; the label is 30_DAY_READMISSION.
    automl_run = automl.classify(
        experiment_name = xp_name,
        experiment_dir = xp_path,
        dataset = training_dataset.select(feature_names),
        target_col = "30_DAY_READMISSION",
        primary_metric="roc_auc",
        timeout_minutes = 10
    )
    #Make sure all users can access dbdemos shared experiment
    set_experiment_permission(f"{xp_path}/{xp_name}")
except Exception as e:
    if "cannot import name 'automl'" in str(e):
        # Note: cannot import name 'automl' from 'databricks' likely means you're using serverless. Dbdemos doesn't support autoML serverless API - this will be improved soon.
        # Adding a temporary workaround to make sure it works well for now - ignore this for classic run
        # NOTE(review): `DBDemos` is not defined in this notebook; it must be provided by a setup notebook for this branch to work.
        automl_run = DBDemos.create_mockup_automl_run(f"{xp_path}/{xp_name}", training_dataset.select(feature_names).toPandas(), model_name = model_name, target_col = "30_DAY_READMISSION")
    else:
        # Re-raise unchanged, keeping the original traceback.
        raise

In [0]:
#Enable Unity Catalog with mlflow registry
mlflow.set_registry_uri('databricks-uc')
    
# Register the model logged by the best AutoML trial under the three-level name <catalog>.<schema>.<model>.
model_registered = mlflow.register_model(f"runs:/{automl_run.best_trial.mlflow_run_id}/model", f"{catalog_name}.{schema_name}.{model_name}")

#Move the model in production
# (the concatenation below requires model_registered.version to be a str)
print("registering model version "+model_registered.version+" as production model")
client = mlflow.tracking.MlflowClient()
# Point the "prod" alias at the version that was just registered.
client.set_registered_model_alias(name=f"{catalog_name}.{schema_name}.{model_name}", alias="prod", version=model_registered.version)

#Make sure all other users can access the model for our demo
#DBDemos.set_model_permission(f"{catalog}.{db}.{model_name}", "ALL_PRIVILEGES", "account users")

## Batch Scoring

In [0]:
# Load model as a Spark UDF.
# The version is selected through the "prod" alias and the UDF is declared to return doubles.
loaded_model = mlflow.pyfunc.spark_udf(spark, model_uri=f"models:/{catalog_name}.{schema_name}.{model_name}@prod", result_type='double')

In [0]:
# Input column names, taken from the signature stored with the model.
features = loaded_model.metadata.get_input_schema().input_names()

#For this demo, reuse our dataset to test the batch inferences
test_dataset = spark.table('training_dataset')

# Score every row with the model UDF and keep only the keys and the prediction.
patient_risk_df = (
    test_dataset
    .withColumn("risk_prediction", loaded_model(F.struct(*features)))
    .select('ENCOUNTER_ID', 'PATIENT_ID', 'risk_prediction')
)

display(patient_risk_df)

In [0]:
# Persist the batch predictions, replacing any previous content of the table.
(
    patient_risk_df.write
    .mode("overwrite")
    .saveAsTable("patient_readmission_prediction")
)

## Model Serving

In [0]:
# Three-level Unity Catalog name of the registered model.
full_model_name = f"{catalog_name}.{schema_name}.{model_name}"
import mlflow
from mlflow import MlflowClient

#Enable Unity Catalog with mlflow registry
mlflow.set_registry_uri("databricks-uc")
client = MlflowClient(registry_uri="databricks-uc")

#Get the model version carrying the "prod" alias (the alias is set by the model registration cell earlier in this notebook)
latest_model = client.get_model_version_by_alias(full_model_name, "prod")

In [0]:
from databricks.sdk import WorkspaceClient
from databricks.sdk.errors import NotFound
from databricks.sdk.service.serving import ServedEntityInput, EndpointCoreConfigInput, AutoCaptureConfigInput

serving_endpoint_name = "mcutini_diabetes_readmissions_endpoint"
w = WorkspaceClient()

# Serve the model version currently aliased "prod" on a small workload that can scale to zero.
endpoint_config = EndpointCoreConfigInput(
    name=serving_endpoint_name,
    served_entities=[
        ServedEntityInput(
            entity_name=full_model_name,
            entity_version=latest_model.version,
            scale_to_zero_enabled=True,
            workload_size="Small"
        )
    ]
)

force_update = False #Set this to True to release a newer version (the demo won't update the endpoint to a newer model version by default)

# Only a "not found" answer means the endpoint has to be created. Any other failure
# (permissions, network, a failed update, ...) is allowed to surface instead of being
# mistaken for a missing endpoint.
try:
    existing_endpoint = w.serving_endpoints.get(serving_endpoint_name)
except NotFound:
    existing_endpoint = None

if existing_endpoint is None:
    print(f"Creating the endpoint {serving_endpoint_name}, this will take a few minutes to package and deploy the endpoint...")
    w.serving_endpoints.create_and_wait(name=serving_endpoint_name, config=endpoint_config)
else:
    print(f"endpoint {serving_endpoint_name} already exist - force update = {force_update}...")
    if force_update:
        w.serving_endpoints.update_config_and_wait(served_entities=endpoint_config.served_entities, name=serving_endpoint_name)

In [0]:
from mlflow.store.artifact.models_artifact_repo import ModelsArtifactRepository
from mlflow.models.model import Model

# Download the artifacts of the model version aliased "prod"; `p` is the value returned by the download call.
p = ModelsArtifactRepository(f"models:/{full_model_name}@prod").download_artifacts("") 
# Wrap the model's stored input example in a {"dataframe_split": ...} payload (split-oriented dict).
dataset =  {"dataframe_split": Model.load(p).load_input_example(p).to_dict(orient='split')}

In [0]:
from mlflow import deployments
# Send the example payload to the serving endpoint through the MLflow "databricks" deployment client.
deployment_client = mlflow.deployments.get_deploy_client("databricks")
predictions = deployment_client.predict(endpoint=serving_endpoint_name, inputs=dataset)

print(f"Patient readmission risk: {predictions}.")

## Model Explainability

In [0]:
#For this demo, reuse the saved training dataset as the data whose predictions are explained
dataset_to_explain = spark.table('training_dataset')
dataset_to_explain.display()

In [0]:
import mlflow
#Enable Unity Catalog with mlflow registry
mlflow.set_registry_uri('databricks-uc')
client = mlflow.tracking.MlflowClient()

# Load the model version aliased "prod" as a generic pyfunc model and read its input column names.
model = mlflow.pyfunc.load_model(model_uri=f"models:/{catalog_name}.{schema_name}.{model_name}@prod")
features = model.metadata.get_input_schema().input_names()

In [0]:
# numpy and pandas are used below (and by the following SHAP cells) but were not imported
# by any earlier cell, so they are imported here together with shap.
import numpy as np
import pandas as pd
import shap

# Turn MLflow autologging off for the explainability cells.
mlflow.autolog(disable=True)
mlflow.sklearn.autolog(disable=True)

# Pull roughly 10% of the dataset to the driver as a pandas DataFrame.
df = dataset_to_explain.sample(fraction=0.1).toPandas()

# Keep at most 100 rows of the model's input columns; the seed makes this second sampling repeatable.
train_sample = df[features].sample(n=np.minimum(100, df.shape[0]), random_state=42)

# Use Kernel SHAP to explain feature importance on the sampled rows.
# KernelExplainer calls `predict` with plain arrays, so the column names and dtypes of the sample are restored first.
predict = lambda x: model.predict(pd.DataFrame(x, columns=features).astype(train_sample.dtypes.to_dict()))

explainer = shap.KernelExplainer(predict, train_sample, link="identity")
shap_values = explainer.shap_values(train_sample, l1_reg=False, nsamples=100)

In [0]:
import plotly.express as px
# Mean absolute SHAP value per feature (averaged over the explained rows, axis 0).
mean_abs_shap = np.absolute(shap_values).mean(axis=0).tolist()
# NOTE: `df` is re-bound here to a small (SHAP_value, feature) frame and no longer holds the sampled dataset.
df = pd.DataFrame(list(zip(mean_abs_shap,features)), columns=['SHAP_value', 'feature'])
# Bar chart of the 10 features with the largest mean absolute SHAP value.
px.bar(df.sort_values('SHAP_value', ascending=False).head(10), x='feature', y='SHAP_value')

In [0]:
# SHAP summary plot for the explained sample.
shap.summary_plot(shap_values, train_sample)

In [0]:
#We'll need to add shap bundle js to display nice graph
# The bundle path is built from the shap package location: everything before the last '/' of shap.__file__.
with open(shap.__file__[:shap.__file__.rfind('/')]+"/plots/resources/bundle.js", 'r') as file:
   shap_bundle_js = '<script type="text/javascript">'+file.read()+'</script>'

# Force plot for the first explained row (index 0), rendered together with the inlined JS bundle.
html = shap.force_plot(explainer.expected_value, shap_values[0,:], train_sample.iloc[0,:])
displayHTML(shap_bundle_js + html.html())

In [0]:
# Force plot over all explained rows; relies on `shap_bundle_js` built in the previous cell.
plot_html = shap.force_plot(explainer.expected_value, shap_values, train_sample)
displayHTML(shap_bundle_js + plot_html.html())

In [0]:
# Dependence plot for INCOME, using TOTAL_CLAIM_COST as the interaction feature.
shap.dependence_plot("INCOME", shap_values, train_sample[features], interaction_index="TOTAL_CLAIM_COST")

In [0]:
import pandas as pd
def compute_shap_values(iterator):
  """Yield one pandas DataFrame of SHAP values per input batch.

  `iterator` yields pandas DataFrames. Each batch is reduced to the model input columns
  listed in the notebook-level `features`, in that order, before it is handed to the
  notebook-level `explainer`; the explainer was built on `train_sample`, which has exactly
  those columns. Passing the whole batch would feed it unrelated columns by position.
  The yielded frame keeps default integer column labels.
  """
  for X in iterator:
    yield pd.DataFrame(explainer.shap_values(X[features], check_additivity=False))

# Lazily compute SHAP values for the whole dataset; the output schema has one float column
# named "<feature>_shap_value" per model input feature.
df = dataset_to_explain.mapInPandas(compute_shap_values, schema=", ".join([x+"_shap_value float" for x in features]))

# Skip as this can take some time to run
#display(df)