In [0]:
# One row per person in `person` with two 0/1 flags:
#   exposed - has a drug exposure starting on/after the birth date and before
#             the first birthday;
#   obese   - has an obesity diagnosis or a BMI >= 30 measurement dated on/after
#             the 5th birthday and before the 11th birthday (ages 5 through 10).
# Age windows use exact birth dates via ADD_MONTHS; comparing calendar years
# (YEAR(date) - YEAR(birth)) would admit some 4-year-olds and drop some 10-year-olds.
cohort = spark.sql("""
WITH births AS (
  -- Birth date per person, computed once and shared by every age window.
  -- NOTE(review): persons with a NULL birth_datetime match no window below and
  -- are reported as exposed = 0, obese = 0.
  SELECT person_id, TO_DATE(birth_datetime) AS birth_date
  FROM person
),

antibiotic_exposed AS (
  -- TODO(review): there is no drug_concept_id filter, so ANY drug exposure
  -- counts as antibiotic exposure; restrict to an antibiotic concept set.
  SELECT DISTINCT d.person_id
  FROM drug_exposure d
  JOIN births b ON d.person_id = b.person_id
  WHERE d.drug_exposure_start_date >= b.birth_date
    AND d.drug_exposure_start_date <  ADD_MONTHS(b.birth_date, 12)
),

obese_diagnosis AS (
  -- Only obesity diagnoses count. 433736 is assumed to be the OMOP standard
  -- "Obesity" condition concept; TODO(review): confirm, and consider including
  -- descendant concepts via concept_ancestor.
  SELECT DISTINCT c.person_id
  FROM condition_occurrence c
  JOIN births b ON c.person_id = b.person_id
  WHERE c.condition_concept_id = 433736
    AND c.condition_start_date >= ADD_MONTHS(b.birth_date, 5 * 12)
    AND c.condition_start_date <  ADD_MONTHS(b.birth_date, 11 * 12)
),

obese_bmi AS (
  -- NOTE(review): BMI >= 30 is the adult cutoff; for ages 5-10 a BMI-for-age
  -- percentile is the usual definition; confirm this threshold is intended.
  SELECT DISTINCT m.person_id
  FROM measurement m
  JOIN births b ON m.person_id = b.person_id
  WHERE m.measurement_concept_id = 3038553
    AND m.value_as_number >= 30
    AND m.measurement_date >= ADD_MONTHS(b.birth_date, 5 * 12)
    AND m.measurement_date <  ADD_MONTHS(b.birth_date, 11 * 12)
),

obesity_combined AS (
  -- UNION (not UNION ALL) already removes duplicate person_ids.
  SELECT person_id FROM obese_diagnosis
  UNION
  SELECT person_id FROM obese_bmi
)

-- TODO(review): persons whose records end before age 10 are counted as
-- obese = 0; consider restricting to persons with adequate follow-up.
SELECT
  p.person_id,
  CASE WHEN a.person_id IS NOT NULL THEN 1 ELSE 0 END AS exposed,
  CASE WHEN o.person_id IS NOT NULL THEN 1 ELSE 0 END AS obese
FROM person p
LEFT JOIN antibiotic_exposed a ON p.person_id = a.person_id
LEFT JOIN obesity_combined o ON p.person_id = o.person_id
""")

# Persist the per-person cohort as a Delta table, replacing any previous run,
# then print a preview of its first rows.
# NOTE(review): the OMOP CDM also defines a table named `cohort`; if the current
# database is the CDM schema this overwrite would replace it. Confirm the target.
(
    cohort.write
    .format("delta")
    .mode("overwrite")
    .saveAsTable("cohort")
)

cohort.show()

In [0]:
from pyspark.sql.functions import col, count, sum

# Obesity counts per exposure group (exposed = 0 / 1):
#   total         - persons in the group
#   obesity_cases - persons with obese = 1
#   rate          - obesity_cases / total * 100, i.e. a percentage
# Rows are ordered by `exposed` so the displayed table is deterministic.
# NOTE(review): `sum` is pyspark.sql.functions.sum from the import above; that
# import shadows Python's builtin `sum` for the rest of the notebook.
cohort_agg = (
    cohort.groupBy("exposed")
    .agg(
        count("*").alias("total"),
        sum("obese").alias("obesity_cases"),
    )
    # Spark's `/` on integral columns already returns a double, so no cast is needed.
    .withColumn("rate", col("obesity_cases") / col("total") * 100)
    .orderBy("exposed")
)

cohort_agg.write.mode("overwrite").format("delta").saveAsTable("survey_result")

cohort_agg.show()