In [1]:
import json
import os 
import time
from pyspark.sql.types import *
import pyspark.sql.functions as f
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.window import Window
from pyspark import SparkContext, SparkConf
from pyspark.sql import SQLContext,SparkSession
import time
import datetime

In [2]:
spark_submit_str = ('--driver-memory 30g --executor-memory 1g --packages org.apache.spark:spark-sql_2.11:2.3.0 '
'--jars /home/jovyan/jars/spark-cassandra-connector.jar'
' --conf spark.shuffle.partitions=50,spark.executor.extraJavaOptions="-XX:+UseG1GC -XX:+PrintFlagsFinal -XX:+PrintReferenceGC -verbose:gc -XX:+PrintGCDetails -XX:+PrintGCTimeStamps -XX:+PrintAdaptiveSizePolicy -XX:+UnlockDiagnosticVMOptions -XX:+G1SummarizeConcMark -Xms88g -Xmx88g -XX:InitiatingHeapOccupancyPercent=35 -XX:ConcGCThread=20" pyspark-shell')
os.environ['PYSPARK_SUBMIT_ARGS'] = spark_submit_str

In [3]:
# get flat obs from cassandra
def get_flat_obs(spark): 
    return spark.read\
    .format("org.apache.spark.sql.cassandra")\
    .options(table="flat_obs_rebuild", keyspace="etl")\
    .option("partitionColumn", "patient_id")\
    .option("fetchsize", 10000)\
    .option("lowerBound", 1)\
    .option("upperBound", 8000000)\
    .option("numPartitions", 800)\
    .load()

In [4]:
def get_obs_schema():
    schema = ArrayType(
            StructType([
                StructField('obs_id'     , IntegerType()   , True),
                StructField('concept_id',  IntegerType()   , True),
                StructField('obs_group_id',IntegerType()   , True),
                StructField('parent_concept_id', IntegerType()   , True),
                StructField('value',  StringType()   , True),
                StructField('value_type',StringType()   , True),
                StructField('obs_datetime' ,TimestampType()    , True)
            ])
        )
    return schema

In [5]:
def get_orders_schema():
    schema = ArrayType(
            StructType([
                StructField('order_id',  IntegerType()   , True),
                StructField('order_concept_id',  IntegerType()   , True),
                StructField('date_activated',  TimestampType()   , True),
                StructField('voided',  IntegerType()   , True)
            ]))
    return schema

In [6]:
def load_to_cassandra(hiv_summary_df):
             hiv_summary_df.write.format("org.apache.spark.sql.cassandra")\
            .options(table="hiv_summary", keyspace="etl")\
            .mode("append")\
            .save()

In [7]:
conf = SparkConf()
sc = SparkContext(master="local[*]",appName="flatHivSummary")
spark = SparkSession(sc)

In [8]:
start_time = datetime.datetime.utcnow()
print("started app")

#extract flat_obs from cassandra
obs_order_df = get_flat_obs(spark)

#convert stringified obs to parsed obs
obs_parsed = obs_order_df.withColumn("parsed_obs", f.from_json("obs", get_obs_schema()))

#convert stringified orders to parsed obs
orders_parsed = obs_parsed.withColumn("parsed_orders", f.from_json("orders", get_orders_schema()))

#explode nested obs structure
exploded_obs = orders_parsed.withColumn("flattened_obs", f.explode("parsed_obs"))

exploded_obs.withColumn("flattened_orders", f.explode("parsed_orders")).createOrReplaceTempView("obs")

# transformations
spark.sql("""select
            *,
            case 
                when flattened_obs.concept_id = 9082 and flattened_obs.value = '9036' then 'negative'
                else null end as patient_care_status,
            case 
                when flattened_obs.concept_id = 5303 and flattened_obs.value= '822' then 'exposed'
                when flattened_obs.concept_id = 5303 and flattened_obs.value = '664' then 'negative'
                when flattened_obs.concept_id = 5303 and flattened_obs.value = '1067' then 'unknown'
                else null end as child_hiv_status,
            case
              when encounter_type in (1,2,3,4,10,14,15,17,19,26,32,
              33,34,47,105,106,112,113,114,117,120,127,128,129,138,153,154,158) then 1
            else null end as is_clinical_encounter
            from obs
        """).createOrReplaceTempView("hiv_summary_stage_0")

spark.sql("""select 
             *,
             case 
                 when flattened_obs.concept_id = 7013 and flattened_obs.value is not null then to_date(flattened_obs.value)
                 when flattened_obs.concept_id = 7015 and flattened_obs.value is not null then to_date(flattened_obs.value)
                 when encounter_type not in (21,99999) then encounter_datetime
                 else null end as enrollment_date,
             case 
                 when flattened_obs.concept_id = 1946 and flattened_obs.value = "1065" then 1
                 when flattened_obs.concept_id = 1285 and flattened_obs.value in ("1287","9068") then 1
                 when flattened_obs.concept_id = 1596 then 1
                 when flattened_obs.concept_id = 9082 and flattened_obs.value in ("159","9036","9083","1287","9068","9079", "9504", "1285") then 1
                 when encounter_type = 31 then 1
                 else null end as out_of_care,
             case 
                 when flattened_orders.order_concept_id = 856 then 1
                 else null end as viral_load_ordered,
             case 
                 when flattened_obs.concept_id = 1255 and flattened_obs.value is not null then location_id
                 when flattened_obs.concept_id IN (1250, 1088, 2154) then location_id
                 else null end as arv_location_id
             from hiv_summary_stage_0
             where 
                 child_hiv_status is null 
                 and patient_care_status is null 
                 and is_clinical_encounter = 1 
             """).createOrReplaceTempView("hiv_summary_stage_1")


spark.sql("""select patient_id,
                    encounter_id,
                    first(location_id) as location_id,
                    first(encounter_datetime) as encounter_datetime,
                    first(encounter_type) as encounter_type,
                    first(enrollment_date, true) as enrollment_date,
                    first(out_of_care, true) as out_of_care,
                    first(location_id) as enrollment_location_id,
                    first(viral_load_ordered, true) as viral_load_ordered,
                    first(arv_location_id, true) as arv_location_id,
                    first(gender) as gender,
                    first(birthdate) as birthdate,
                    first(death_date) as death_date,
                    first(is_clinical_encounter) as is_clinical_encounter
                    from hiv_summary_stage_1
                    group by patient_id, encounter_id
        """).createOrReplaceTempView("hiv_summary_stage_2")

hiv_summary_3 = spark.sql("""select patient_id,
                    null as patient_uuid,
                    encounter_id,
                    encounter_datetime,
                    location_id,
                    gender,
                    birthdate,
                    death_date,
                    is_clinical_encounter,
                    current_timestamp() as analysis_date,
                    first(out_of_care, true) over p as out_of_care,
                    first(enrollment_date, true) over p as enrollment_date,
                    case when enrollment_date is not null then location_id
                    else null end as enrollment_location_id,
                    first(arv_location_id, true) over p as arv_start_location_id,
                    viral_load_ordered
                    from hiv_summary_stage_2
                    window p as (partition by patient_id order by encounter_datetime)
        """)

hiv_summary_3.createOrReplaceTempView("hiv_summary_stage_3")

hiv_summary = spark.sql("""select patient_id,
                    null as patient_uuid,
                    encounter_id,
                    encounter_datetime,
                    location_id,
                    gender,
                    birthdate,
                    death_date,
                    is_clinical_encounter,
                    analysis_date,
                    out_of_care,
                    enrollment_date,
                    first(enrollment_location_id, true) over p as enrollment_location_id,
                    arv_start_location_id,
                    viral_load_ordered
                    from hiv_summary_stage_3
                    window p as (partition by patient_id order by encounter_datetime)
        """)

hiv_summary.createOrReplaceTempView("hiv_sum")


#hiv_summary = spark.sql("select *,  from hiv_sum")
#spark.sql("select * from hiv_sum where ordered_viral_load is not null").show(50)
spark.sql("select * from hiv_sum where patient_id = 13460").show(50)

#print(hiv_summary.count())
load_to_cassandra(hiv_summary)

end_time = datetime.datetime.utcnow()
print("Finished: " + time.ctime()) 
print("Took {0} seconds".format((end_time - start_time).total_seconds()))

started app
+----------+------------+------------+-------------------+-----------+------+---------+----------+---------------------+--------------------+-----------+-------------------+----------------------+---------------------+------------------+
|patient_id|patient_uuid|encounter_id| encounter_datetime|location_id|gender|birthdate|death_date|is_clinical_encounter|       analysis_date|out_of_care|    enrollment_date|enrollment_location_id|arv_start_location_id|viral_load_ordered|
+----------+------------+------------+-------------------+-----------+------+---------+----------+---------------------+--------------------+-----------+-------------------+----------------------+---------------------+------------------+
|     13460|        null|     6932376|2017-06-07 10:48:43|         13|     M|     null|      null|                    1|2018-06-27 13:22:...|       null|2017-06-07 10:48:43|                    13|                   13|                 1|
|     13460|        null|     713107