In [1]:
# The code was removed by Watson Studio for sharing.

Waiting for a Spark session to start...
Spark Initialization Done! ApplicationId = app-20210506093546-0003
KERNEL_ID = a893abc8-1668-4e93-a7f0-a0ef2d02d4dc


# Hospital Research

The research team [GICOS](http://www.saber.ula.ve/handle/123456789/2453) (Group of Investigation in Community and Healthcare) is working in partnership with the HULA hospital, a university hospital located in [Merida, Venezuela](https://www.google.com/maps/place/Instituto+Aut%C3%B3nomo+Hospital+Universitario+de+Los+Andes+-+I.A.H.ULA/@8.582205,-71.1584343,17z/data=!4m5!3m4!1s0x8e6487088631fa61:0xf5daba20de41c28e!8m2!3d8.5821997!4d-71.1562456).

## The problem
The hospital has a large number of patients who were readmitted after surgery and who had limited or no insurance, so they had to apply for the public insurance offered at the hospital. The insurance company was absorbing too much cost and would have to raise its prices to cover these patients, so it needed to predict the expected cost of each new customer.

## The Solution

Train a model that uses four beneficiary-related datasets to predict each beneficiary's total paid amount for the following year. This notebook uses Python 3 and Apache Spark (PySpark).

## Table of Contents

1. [Pre-requisites](#pre)
2. [Load the data](#load-data)
3. [Extract and transform the data](#etl)
4. [Analyze the data](#analyze)
5. [Train the model](#train)

<a id = "pre"></a>
## Pre-requisites

### Start the Spark session

In [2]:
from pyspark.sql import SparkSession

spark = SparkSession.builder \
        .appName("spark-model-training") \
        .getOrCreate()

### Authentication

In [4]:
# The code was removed by Watson Studio for sharing.

In [4]:
# The code was removed by Watson Studio for sharing.

<a id = "load-data"></a>
## Load the data set

The claims data codes diagnoses only as ICD-9 codes, so we use the AHRQ Clinical Classifications Software (CCS) categories to group them into roughly 250 CCS buckets.

The CCS mapping file must already be uploaded to the project; the cell below reads it from the project's storage bucket as `$dxref 2015.csv`.

In [5]:
from pyspark.sql.functions import col, lit, trim

ccs_df = (
    spark.read
    .format("csv")
    .option("header", "true")
    .option("comment", "N")
    .option("quote", "'")
    .load(cos.url('$dxref 2015.csv', 'predictinghealthcarecostsfromclai-donotdelete-pr-zgfwvlnfa1en0m'))
    .withColumn("dx_code", trim(col("ICD-9-CM CODE")))
    .withColumn("ccs_code", trim(col("CCS CATEGORY")))
    .withColumn("ontology", lit("ICD9"))
    .select("dx_code", "ontology", "ccs_code")
)
ccs_df.show()

+-------+--------+--------+
|dx_code|ontology|ccs_code|
+-------+--------+--------+
|       |    ICD9|       0|
|  01000|    ICD9|       1|
|  01001|    ICD9|       1|
|  01002|    ICD9|       1|
|  01003|    ICD9|       1|
|  01004|    ICD9|       1|
|  01005|    ICD9|       1|
|  01006|    ICD9|       1|
|  01010|    ICD9|       1|
|  01011|    ICD9|       1|
|  01012|    ICD9|       1|
|  01013|    ICD9|       1|
|  01014|    ICD9|       1|
|  01015|    ICD9|       1|
|  01016|    ICD9|       1|
|  01080|    ICD9|       1|
|  01081|    ICD9|       1|
|  01082|    ICD9|       1|
|  01083|    ICD9|       1|
|  01084|    ICD9|       1|
+-------+--------+--------+
only showing top 20 rows
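The inner join against `ccs_df` used later in the ETL behaves like a dictionary lookup: each raw ICD-9 code is looked up, and codes with no CCS category are dropped. A minimal pure-Python sketch, assuming a tiny hypothetical excerpt of the mapping (the three rows are copied from the output above):

```python
# Hypothetical three-row excerpt of the ICD-9 -> CCS mapping, copied from the ccs_df output above
CCS_LOOKUP = {"01000": "1", "01001": "1", "01002": "1"}


def to_ccs(dx_codes, lookup=CCS_LOOKUP):
    """Return the sorted distinct CCS categories for a list of ICD-9 codes.

    Codes with no mapping are dropped, mirroring the inner join on dx_code.
    """
    return sorted({lookup[code] for code in dx_codes if code in lookup})


print(to_ccs(["01000", "01002", "99999"]))  # -> ['1']
```

Because the join is inner, a diagnosis code missing from the reference file silently disappears rather than raising an error.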



In [5]:
from datetime import datetime
from pyspark.sql.functions import lit, struct

# Claims (ExplanationOfBenefit resources)
eob = spark.read.json(project_cos.url("ExplanationOfBenefit.ndjson", project.get_project_bucket_name()))
eob.show(5)

# Coverage: replace each period with a placeholder window (1900-01-01 to 2099-01-01) that spans every claim date
coverage = spark.read.json(project_cos.url("Coverage.ndjson", project.get_project_bucket_name()))
coverage = coverage.withColumn("period", struct(
    lit(datetime(1900, 1, 1)).alias("start"),
    lit(datetime(2099, 1, 1)).alias("end")
))
coverage.show(5)

# Patient: overwrite birthDate with the same placeholder date (1960-01-01) for every patient
patient = spark.read.json(project_cos.url("Patient.ndjson", project.get_project_bucket_name()))
patient = patient.withColumn("birthDate", lit(datetime(1960, 1, 1)))
patient.show(5)

+--------------+--------------------+--------+--------------------+--------------------+--------+---------------+-------------------+--------------------+-----------+--------------------+--------------------+--------------------+--------------------+---------+--------+--------------------+------+---------+--------------------+
|benefitBalance|      billablePeriod|careTeam|           diagnosis|           extension|facility|hospitalization|                 id|          identifier|information|           insurance|                item|             patient|             payment|procedure|provider|        resourceType|status|totalCost|                type|
+--------------+--------------------+--------+--------------------+--------------------+--------+---------------+-------------------+--------------------+-----------+--------------------+--------------------+--------------------+--------------------+---------+--------+--------------------+------+---------+--------------------+

<a id = "etl"></a>
## Extract and transform the data

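Extract age and gender from the Patient resource file.

Age is computed as of 2000-01-01 because that is what this synthetic dataset supports. Because the loading cell overwrote every `birthDate` with 1960-01-01, every patient comes out as 40 years old, so age carries no signal in this dataset.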

In [6]:
from pyspark.sql.functions import floor, lit, months_between

demographics = (
    patient.select(
        col("id").alias("patient"),
        "gender",
        floor(months_between(lit(datetime(2000, 1, 1)), "birthDate") / 12).alias("age")
    )
)
demographics.show()

+--------------+------+---+
|       patient|gender|age|
+--------------+------+---+
|19990000000140|  male| 40|
|19990000000141|  male| 40|
|19990000000142|  male| 40|
|19990000000144|female| 40|
|19990000000145|female| 40|
|19990000000146|  male| 40|
|19990000000147|  male| 40|
|19990000000149|  male| 40|
|19990000000150|female| 40|
|19990000000151|female| 40|
|19990000000153|female| 40|
|19990000000154|female| 40|
|19990000000155|female| 40|
|19990000000156|  male| 40|
|19990000000158|female| 40|
|19990000000159|female| 40|
|19990000000160|female| 40|
|19990000000162|  male| 40|
|19990000000163|  male| 40|
|19990000000164|  male| 40|
+--------------+------+---+
only showing top 20 rows
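The age column is `floor(months_between(reference_date, birthDate) / 12)`. A pure-Python sketch of the same whole-years calculation, assuming calendar dates without times; it counts completed months instead of Spark's fractional months, so it approximates rather than replicates `months_between`:

```python
from datetime import date


def age_on(birth: date, as_of: date) -> int:
    """Whole years of age on `as_of`, approximating floor(months_between(as_of, birth) / 12)."""
    # Count fully completed months between the two dates
    months = (as_of.year - birth.year) * 12 + (as_of.month - birth.month)
    if as_of.day < birth.day:
        months -= 1  # the last month is not complete yet
    return months // 12


print(age_on(date(1960, 1, 1), date(2000, 1, 1)))  # -> 40, matching the output above
```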



We predict the total amount paid in year n+1 from features calculated on data from year n. Here we split the claims, by the start of each claim's `billablePeriod`, into `features_period` (year n, 1999) and `outcome_period` (year n+1, 2000) for our ETL.

In [7]:
from pyspark.sql.functions import col

features_period = eob.where(
    (col("billablePeriod.start") >= datetime(1999, 1, 1)) &
    (col("billablePeriod.start") < datetime(2000, 1, 1))
)

outcome_period = eob.where(
    (col("billablePeriod.start") >= datetime(2000, 1, 1)) &
    (col("billablePeriod.start") < datetime(2001, 1, 1)) 
)
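Both filters use half-open windows (`>=` start, `<` end), so a claim starting exactly at midnight on 2000-01-01 falls only in the outcome period and no claim is counted twice. A pure-Python sketch of that assignment rule, using hypothetical constant names for the boundaries:

```python
from datetime import datetime
from typing import Optional

FEATURES_START = datetime(1999, 1, 1)
OUTCOME_START = datetime(2000, 1, 1)
OUTCOME_END = datetime(2001, 1, 1)


def assign_period(billable_start: datetime) -> Optional[str]:
    """Classify a claim by its billablePeriod.start using half-open [start, end) windows."""
    if FEATURES_START <= billable_start < OUTCOME_START:
        return "features"
    if OUTCOME_START <= billable_start < OUTCOME_END:
        return "outcome"
    return None  # outside both windows: the claim is not used


print(assign_period(datetime(1999, 12, 31, 23, 59)))  # -> features
print(assign_period(datetime(2000, 1, 1)))            # -> outcome
```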

Build one row per patient with a 0/1 indicator column for every CCS category that appears among the patient's diagnoses during the `features_period`.

In [8]:
from pyspark.sql.functions import countDistinct, explode, substring

ccs_categories = ccs_df.select("ccs_code").distinct().rdd.flatMap(lambda x: x).collect()

ccs_codes = (
    # Explode arrays to get to codes
    features_period.select(
        substring("patient.reference", 9, 255).alias("patient"),  # drop the "Patient/" prefix (Spark positions are 1-based)
        explode("diagnosis").alias("diagnosis")
    ).select(
        "patient",
        col("diagnosis.sequence").alias("sequence"),
        explode("diagnosis.diagnosisCodeableConcept.coding").alias("coding")
    ).select(
        "patient",
        "sequence",
        col("coding.system").alias("system"),
        col("coding.code").alias("dx_code")
    )
    # Map ICD -> CCS
    .join(ccs_df, on="dx_code", how="inner")
    .select("patient", "ccs_code")
    .distinct()
    # Groupby pivot to get 1/0 columns for CCS codes
    .groupby("patient")
    .pivot("ccs_code", ccs_categories)
    .agg(countDistinct("patient"))
    .fillna(0)
)
ccs_codes.show(truncate=False)

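The `distinct()` / `groupby().pivot().agg(countDistinct(...))` / `fillna(0)` chain produces a 1 for every CCS category a patient has and a 0 for every other category. A pure-Python sketch of that indicator encoding, assuming the diagnoses have already been flattened into hypothetical `(patient, ccs_code)` pairs:

```python
def pivot_indicators(pairs, categories):
    """Turn (patient, ccs_code) pairs into {patient: [0/1 per category]}.

    Duplicate pairs collapse to a single 1 (like .distinct()), and categories
    the patient never had become 0 (like .fillna(0)).
    """
    codes_by_patient = {}
    for patient, code in pairs:
        codes_by_patient.setdefault(patient, set()).add(code)
    return {
        patient: [1 if category in codes else 0 for category in categories]
        for patient, codes in codes_by_patient.items()
    }


print(pivot_indicators([("a", "1"), ("a", "1"), ("a", "3"), ("b", "2")], ["1", "2", "3"]))
# -> {'a': [1, 0, 1], 'b': [0, 1, 0]}
```

Passing the explicit `ccs_categories` list to `pivot` fixes the column set and order, which is why the same list can be reused as feature names in the `VectorAssembler` stage.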

<a id = "analyze"></a>
## Analyze the data

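Count the number of office visits each patient had during the `features_period`, measured as distinct service dates with an office-visit evaluation and management (E/M) code.

Office visits are defined by CPT codes 99201-99205 (new patient) and 99211-99215 (established patient); see this [AAFP article on coding office visits](https://www.aafp.org/journals/fpm/blogs/inpractice/entry/coding_office_visits_the_easy_way.html).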

In [9]:
office_visit_em_codes = [str(i) for i in range(99201, 99206)] + [str(i) for i in range(99211, 99216)]

em_dates = (
    features_period.select(
        substring("patient.reference", 9, 255).alias("patient"),
        explode("item").alias("item")
    ).select(
        "patient",
        col("item.servicedPeriod.start").alias("service_date"),
        col("item.sequence").alias("sequence"),
        explode("item.service.coding").alias("coding")
    ).select(
        "patient",
        "service_date",
        "sequence",
        col("coding.system").alias("system"),
        col("coding.code").alias("code")
    ).where(
        (col("system") == "https://bluebutton.cms.gov/resources/codesystem/hcpcs") &
        (col("code").isin(office_visit_em_codes))
    ).groupby("patient")
    .agg(countDistinct("service_date").alias("em_dates"))
)
em_dates.show()

+--------------+--------+
|       patient|em_dates|
+--------------+--------+
|19990000000166|       2|
|19990000000138|       1|
|19990000000151|       1|
|19990000000142|       1|
|19990000000181|       2|
|19990000000141|       1|
|19990000000155|       2|
|19990000000174|       6|
|19990000000161|       1|
|19990000000137|       4|
|19990000000182|       3|
|19990000000157|       2|
|19990000000176|       5|
|19990000000169|       2|
|19990000000170|       1|
|19990000000145|       3|
|19990000000183|       2|
|19990000000152|       2|
|19990000000144|       3|
|19990000000175|       5|
+--------------+--------+
only showing top 20 rows
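The cell counts distinct service dates, not line items, so two E/M codes billed on the same day count as one visit. A pure-Python sketch of that rule, assuming the claim items have been flattened into hypothetical `(patient, service_date, system, code)` tuples:

```python
# Same E/M code list as the notebook: 99201-99205 and 99211-99215 (range's upper bound is exclusive)
OFFICE_VISIT_EM_CODES = {str(c) for c in range(99201, 99206)} | {str(c) for c in range(99211, 99216)}
HCPCS_SYSTEM = "https://bluebutton.cms.gov/resources/codesystem/hcpcs"


def count_em_dates(items):
    """Return {patient: number of distinct service dates with an office-visit E/M code}."""
    dates = {}
    for patient, service_date, system, code in items:
        if system == HCPCS_SYSTEM and code in OFFICE_VISIT_EM_CODES:
            dates.setdefault(patient, set()).add(service_date)
    return {patient: len(d) for patient, d in dates.items()}


items = [
    ("p1", "1999-03-01", HCPCS_SYSTEM, "99213"),
    ("p1", "1999-03-01", HCPCS_SYSTEM, "99214"),  # same day: counted once
    ("p1", "1999-05-10", HCPCS_SYSTEM, "99203"),
    ("p1", "1999-06-01", HCPCS_SYSTEM, "99206"),  # not an office-visit code in this list
]
print(count_em_dates(items))  # -> {'p1': 2}
```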



Count the number of unique days each patient has claims on during the `features_period`.

In [10]:
from pyspark.sql.functions import countDistinct

num_claims = (
    features_period
    .groupby(substring("patient.reference", 9, 255).alias("patient"))
    .agg(countDistinct("billablePeriod.start").alias("unique_days"))
)
num_claims.show()

+--------------+-----------+
|       patient|unique_days|
+--------------+-----------+
|19990000000166|          4|
|19990000000138|          1|
|19990000000151|          1|
|19990000000142|          2|
|19990000000181|          2|
|19990000000141|          2|
|19990000000155|          3|
|19990000000174|          8|
|19990000000161|          2|
|19990000000137|          6|
|19990000000182|          7|
|19990000000157|          3|
|19990000000176|          7|
|19990000000169|          3|
|19990000000170|          2|
|19990000000145|          6|
|19990000000183|          3|
|19990000000152|          5|
|19990000000144|          3|
|19990000000175|          7|
+--------------+-----------+
only showing top 20 rows
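Every aggregation keys on `substring("patient.reference", 9, 255)`. Spark string positions are 1-based, so this keeps everything after the first 8 characters. A pure-Python sketch of that ID extraction plus the distinct-start-date count, assuming references look like `Patient/<id>` (the standard FHIR relative-reference form, consistent with the IDs above) and that each start value is a date:

```python
def patient_id(reference: str) -> str:
    """Mirror substring(reference, 9, 255): Spark position 9 is Python index 8."""
    return reference[8:8 + 255]


def unique_claim_days(claims):
    """claims: iterable of (patient_reference, billable_start) pairs (a hypothetical flattened shape)."""
    days = {}
    for reference, start in claims:
        days.setdefault(patient_id(reference), set()).add(start)
    return {pid: len(d) for pid, d in days.items()}


print(patient_id("Patient/19990000000166"))  # -> 19990000000166
print(unique_claim_days([
    ("Patient/1", "1999-01-05"),
    ("Patient/1", "1999-01-05"),  # same start date: counted once
    ("Patient/1", "1999-02-01"),
    ("Patient/2", "1999-03-03"),
]))  # -> {'1': 2, '2': 1}
```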



Create the outcome variable we'll be trying to predict: for each patient, sum `payment.amount.value` across all claims in the `outcome_period` into `total_paid_amt`.

In [11]:
from pyspark.sql.functions import sum as spark_sum  # alias avoids shadowing Python's built-in sum

total_paid = (
    outcome_period
    .groupby(substring("patient.reference", 9, 255).alias("patient"))
    .agg(spark_sum("payment.amount.value").alias("total_paid_amt"))
)
total_paid.show()

+--------------+--------------+
|       patient|total_paid_amt|
+--------------+--------------+
|19990000000166|           650|
|19990000000138|          2930|
|19990000000151|           270|
|19990000000142|            90|
|19990000000181|           610|
|19990000000141|           250|
|19990000000155|           130|
|19990000000174|          1440|
|19990000000161|         10900|
|19990000000137|           230|
|19990000000182|            20|
|19990000000157|          2530|
|19990000000176|          2060|
|19990000000169|           510|
|19990000000170|           500|
|19990000000145|           360|
|19990000000183|           400|
|19990000000152|           140|
|19990000000144|           340|
|19990000000175|           920|
+--------------+--------------+
only showing top 20 rows
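A patient only gets a `total_paid_amt` if they have at least one claim in the outcome period, and the training step joins `total_paid` with an inner join, so patients without outcome-period claims drop out of the modeling data. A pure-Python sketch of the per-patient sum, assuming hypothetical `(patient_id, amount)` pairs in which every claim has an amount:

```python
from collections import defaultdict


def total_paid_by_patient(payments):
    """payments: iterable of (patient_id, amount) pairs for outcome-period claims (hypothetical shape)."""
    totals = defaultdict(float)
    for patient, amount in payments:
        totals[patient] += amount
    return dict(totals)


totals = total_paid_by_patient([("a", 100.0), ("a", 50.0), ("b", 20.0)])
print(totals)  # -> {'a': 150.0, 'b': 20.0}
# A patient with features but no outcome-period claims (say "c") has no total at all,
# so an inner join on patient would drop them.
print("c" in totals)  # -> False
```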



<a id = "train"></a>
## Train the model

Join everything together into a "one row per prediction" dataset to feed through a machine learning algorithm.

We then split the dataset randomly into roughly 80% training and 20% test data.

In [12]:
modeling_df = (
    demographics
    .join(ccs_codes, on="patient", how="left_outer")
    .join(em_dates, on="patient", how="left_outer")
    .join(num_claims, on="patient", how="left_outer")
    .fillna(0, subset=ccs_categories + em_dates.columns + num_claims.columns)
    .join(total_paid, on="patient", how="inner")
)
modeling_df.cache()
modeling_df.show()

train_df, test_df = modeling_df.randomSplit([0.8, 0.2], seed=72)

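`randomSplit([0.8, 0.2], seed=72)` assigns rows to the splits randomly, so the split sizes are only approximately 80/20; on a small dataset the actual proportions can drift noticeably. A pure-Python sketch of the same idea, using the standard `random` module rather than Spark's sampler, so it will not reproduce Spark's exact assignment:

```python
import random


def random_split(rows, train_fraction=0.8, seed=72):
    """Assign each row independently to train (probability train_fraction) or test.

    Seeded, so the same input always yields the same split.
    """
    rng = random.Random(seed)
    train, test = [], []
    for row in rows:
        (train if rng.random() < train_fraction else test).append(row)
    return train, test


train, test = random_split(list(range(50)))
print(len(train), len(test))  # close to 40 and 10, but not guaranteed to be exact
```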

Use the `pyspark.ml` module to convert the modeling data above into the appropriate data types and formats to feed through the `LinearRegression` training algorithm.
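The first two pipeline stages below turn `gender` into a numeric vector. By default `StringIndexer` assigns indices in descending order of label frequency, and `OneHotEncoder` drops the last category (`dropLast=True`), so two genders become a one-slot vector. A pure-Python sketch of that behavior using plain lists instead of Spark vectors; the alphabetical tie-break is an assumption of the sketch:

```python
from collections import Counter


def index_and_encode(values):
    """Sketch of StringIndexer (frequencyDesc order) followed by OneHotEncoder (dropLast=True)."""
    counts = Counter(values)
    # Most frequent label gets index 0; ties are broken alphabetically (an assumption of this sketch)
    labels = sorted(counts, key=lambda label: (-counts[label], label))
    index = {label: i for i, label in enumerate(labels)}
    size = len(labels) - 1  # dropLast=True: the last category has no slot of its own
    vectors = []
    for value in values:
        vector = [0.0] * size
        if index[value] < size:
            vector[index[value]] = 1.0
        vectors.append(vector)
    return index, vectors


print(index_and_encode(["male", "female", "female"]))
# -> ({'female': 0, 'male': 1}, [[0.0], [1.0], [1.0]])
```

Dropping the last category keeps the encoded columns from being perfectly collinear with the intercept that `fitIntercept=True` adds.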

In [13]:
from pyspark.ml import Pipeline
from pyspark.ml.feature import Bucketizer, OneHotEncoder, StringIndexer, VectorAssembler
from pyspark.ml.regression import LinearRegression

age_splits = [
    -float("inf"),
    0,
    10,
    20,
    30,
    40,
    50,
    60,
    70,
    80,
    float("inf")
]

pipeline = Pipeline(stages=[
    StringIndexer(inputCol="gender", outputCol="gender_indexed"),
    OneHotEncoder(inputCol="gender_indexed", outputCol="gender_encoded"),
    Bucketizer(splits=age_splits, inputCol="age", outputCol="age_encoded"),   
    VectorAssembler(
        inputCols=["age_encoded", "em_dates", "gender_encoded", "unique_days"] + ccs_categories,
        outputCol="features"
    ),
    LinearRegression(featuresCol="features", labelCol="total_paid_amt", fitIntercept=True)
])

trained_model_pipeline = pipeline.fit(train_df)
print(f"Training Data r2={trained_model_pipeline.stages[-1].summary.r2}")

Training Data r2=0.9999999999804761
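`Bucketizer` replaces each age with the index of its bucket: bucket `i` covers `[splits[i], splits[i + 1])`, and the last bucket also includes its upper bound. The index is a single number, so `VectorAssembler` passes the age bucket to `LinearRegression` as one ordinal feature with one coefficient, not as one-hot columns. A pure-Python sketch of the bucket lookup with the same split points:

```python
from bisect import bisect_right

# Same split points as age_splits in the cell above
AGE_SPLITS = [-float("inf"), 0, 10, 20, 30, 40, 50, 60, 70, 80, float("inf")]


def bucket_index(value, splits=AGE_SPLITS):
    """Index i of the bucket [splits[i], splits[i + 1]) that contains value.

    The last bucket also includes its upper bound.
    """
    if value == splits[-1]:
        return len(splits) - 2
    return bisect_right(splits, value) - 1


print(bucket_index(40))  # -> 5, the [40, 50) bucket
```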


In [14]:
from pyspark.ml.evaluation import RegressionEvaluator

predictions = trained_model_pipeline.transform(test_df)
evaluator = RegressionEvaluator(
    predictionCol="prediction",
    labelCol="total_paid_amt",
    metricName="r2"
)
print(f"Testing Data r2={evaluator.evaluate(predictions)}")

Testing Data r2=-75.01896166725219
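R² is `1 - SS_res / SS_tot`, where `SS_tot` is the squared error of always predicting the mean label. A value below 0 therefore means the predictions are worse than that constant guess. A near-perfect training R² together with a strongly negative test R², as above, most likely indicates overfitting. A pure-Python sketch of the formula:

```python
def r2_score(labels, predictions):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean = sum(labels) / len(labels)
    ss_tot = sum((y - mean) ** 2 for y in labels)  # error of always predicting the mean
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))  # error of the model
    return 1 - ss_res / ss_tot


print(r2_score([1, 2, 3], [2, 2, 2]))  # -> 0.0 (no better than the mean)
print(r2_score([1, 2, 3], [3, 2, 1]))  # -> -3.0 (worse than the mean)
```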
