In [0]:
# Build the study cohort: one row per person in synthea.omop.person with two 0/1 flags.
#   exposed = 1 if the person has a drug exposure that descends from an antibacterial
#             concept in ANTIBIOTIC_ANCESTOR_CONCEPT_IDS, starting on/after the birth date
#             and before the birth date + EXPOSURE_WINDOW_MONTHS.
#   obese   = 1 if, between the OUTCOME_MIN_AGE_YEARS birthday (inclusive) and the
#             (OUTCOME_MAX_AGE_YEARS + 1) birthday (exclusive), the person has either a
#             condition descending from OBESITY_ANCESTOR_CONCEPT_IDS or a BMI measurement
#             >= OBESITY_BMI_THRESHOLD.
#
# NOTE(review): the concept IDs below are OMOP standard concepts expanded through
# concept_ancestor. Verify them against the vocabulary loaded into synthea.omop
# (e.g. look them up in Athena) before trusting the results.
# NOTE(review): persons who were never observed through the outcome window still get
# obese = 0. Consider restricting to persons whose observation_period covers both windows.
ANTIBIOTIC_ANCESTOR_CONCEPT_IDS = [21602796]  # intended: ATC J01 "Antibacterials for systemic use" (TODO confirm)
OBESITY_ANCESTOR_CONCEPT_IDS = [433736]       # intended: SNOMED "Obesity" (TODO confirm)
BMI_CONCEPT_ID = 3038553                      # BMI measurement concept, as used by the original query
OBESITY_BMI_THRESHOLD = 30                    # NOTE(review): adult cut-off; paediatric obesity is usually BMI-for-age >= 95th percentile
EXPOSURE_WINDOW_MONTHS = 12                   # first year of life
OUTCOME_MIN_AGE_YEARS = 5                     # outcome window opens on this birthday (inclusive)
OUTCOME_MAX_AGE_YEARS = 10                    # outcome window closes on the following birthday (exclusive)
COHORT_TABLE = "synthea.analytics.cohort"

# The values are rendered as integer literals, so only numeric IDs can reach the SQL text.
antibiotic_ids = ", ".join(str(int(i)) for i in ANTIBIOTIC_ANCESTOR_CONCEPT_IDS)
obesity_ids = ", ".join(str(int(i)) for i in OBESITY_ANCESTOR_CONCEPT_IDS)
outcome_start_months = int(OUTCOME_MIN_AGE_YEARS) * 12
outcome_end_months = (int(OUTCOME_MAX_AGE_YEARS) + 1) * 12

cohort_query = spark.sql(f"""
WITH person_birth AS (
  -- birth_datetime is optional in the OMOP CDM; fall back to the year/month/day_of_birth fields
  SELECT
    person_id,
    COALESCE(
      CAST(birth_datetime AS DATE),
      MAKE_DATE(year_of_birth, COALESCE(month_of_birth, 1), COALESCE(day_of_birth, 1))
    ) AS birth_date
  FROM synthea.omop.person
),

antibiotic_exposed AS (
  SELECT DISTINCT d.person_id
  FROM synthea.omop.drug_exposure d
  JOIN synthea.omop.concept_ancestor ca ON d.drug_concept_id = ca.descendant_concept_id
  JOIN person_birth b ON d.person_id = b.person_id
  WHERE ca.ancestor_concept_id IN ({antibiotic_ids})
    AND d.drug_exposure_start_date >= b.birth_date
    AND d.drug_exposure_start_date < ADD_MONTHS(b.birth_date, {int(EXPOSURE_WINDOW_MONTHS)})
),

obese_diagnosis AS (
  SELECT DISTINCT c.person_id
  FROM synthea.omop.condition_occurrence c
  JOIN synthea.omop.concept_ancestor ca ON c.condition_concept_id = ca.descendant_concept_id
  JOIN person_birth b ON c.person_id = b.person_id
  WHERE ca.ancestor_concept_id IN ({obesity_ids})
    AND c.condition_start_date >= ADD_MONTHS(b.birth_date, {outcome_start_months})
    AND c.condition_start_date < ADD_MONTHS(b.birth_date, {outcome_end_months})
),

obese_bmi AS (
  SELECT DISTINCT m.person_id
  FROM synthea.omop.measurement m
  JOIN person_birth b ON m.person_id = b.person_id
  WHERE m.measurement_concept_id = {int(BMI_CONCEPT_ID)}
    AND m.value_as_number >= {OBESITY_BMI_THRESHOLD}
    AND m.measurement_date >= ADD_MONTHS(b.birth_date, {outcome_start_months})
    AND m.measurement_date < ADD_MONTHS(b.birth_date, {outcome_end_months})
),

-- UNION (not UNION ALL) already de-duplicates person_id
obesity_combined AS (
  SELECT person_id FROM obese_diagnosis
  UNION
  SELECT person_id FROM obese_bmi
)

SELECT
  p.person_id,
  CASE WHEN a.person_id IS NOT NULL THEN 1 ELSE 0 END AS exposed,
  CASE WHEN o.person_id IS NOT NULL THEN 1 ELSE 0 END AS obese
FROM synthea.omop.person p
LEFT JOIN antibiotic_exposed a ON p.person_id = a.person_id
LEFT JOIN obesity_combined o ON p.person_id = o.person_id
""")

cohort_query.write.mode("overwrite").format("delta").saveAsTable(COHORT_TABLE)

# Read the persisted table back so show() and later cells don't re-run the joins above.
cohort = spark.table(COHORT_TABLE)

cohort.show()

+--------------------+-------+-----+
|           person_id|exposed|obese|
+--------------------+-------+-----+
|001bf5aa-89a9-4db...|      0|    0|
|0076c218-1d8d-41a...|      0|    0|
|00a2421c-80c4-444...|      0|    0|
|00d4f791-d903-490...|      0|    0|
|00d6d2b3-ed74-446...|      0|    0|
|00ee5ffc-06ca-43f...|      0|    0|
|00f84f11-7ddf-4ab...|      0|    0|
|01c31137-599b-462...|      0|    0|
|01d5cc50-3fc2-432...|      0|    0|
|02164be8-33fb-4fd...|      0|    0|
|02916a8b-caec-4d3...|      0|    0|
|03ac1f8c-0f3d-438...|      0|    0|
|03b25307-b7d2-488...|      0|    0|
|03bf37b6-bf38-427...|      0|    0|
|03c796d7-373d-47f...|      0|    0|
|03cbebc3-0abb-492...|      0|    0|
|0438df3d-d086-46e...|      0|    0|
|044d12e2-a9c9-497...|      0|    0|
|049c880d-2651-484...|      0|    0|
|06405f78-8d04-443...|      0|    0|
+--------------------+-------+-----+
only showing top 20 rows


In [0]:
# Import the functions module under an alias so PySpark's sum/count don't shadow the
# Python builtins for the rest of the notebook.
from pyspark.sql import functions as F

# Obesity rate (percent) per exposure group, computed from the `cohort` DataFrame built
# in the previous cell. groupBy emits no row for an empty group: if nobody is exposed,
# only the exposed = 0 row appears.
cohort_agg = (
    cohort.groupBy("exposed")
    .agg(
        F.count("*").alias("total"),
        F.sum("obese").alias("obesity_cases"),
    )
    .withColumn("rate", (F.col("obesity_cases") / F.col("total") * 100).cast("double"))
    .orderBy("exposed")  # deterministic display order: unexposed first, then exposed
)

cohort_agg.write.mode("overwrite").format("delta").saveAsTable("synthea.analytics.survey_result")

cohort_agg.show()

+-------+-----+-------------+----+
|exposed|total|obesity_cases|rate|
+-------+-----+-------------+----+
|      0|  471|            0| 0.0|
+-------+-----+-------------+----+

