In [0]:
import requests
import re
from datetime import datetime
from pyspark.sql import functions as F
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

In [0]:
# Notebook parameter "batch_date": free-text widget with an empty default.
# Later cells parse a non-empty value as YYYY-MM-DD and fall back to today's date when it is empty.
dbutils.widgets.text("batch_date", "")

In [0]:
# Dataset discovery helper. It is defined ahead of the global variables because
# they call it while this cell executes (required for Restart & Run All).
def get_hospital_datasets(timeout=60):
    """Fetch the CMS Provider Data catalog and keep the hospital-related datasets.

    Args:
        timeout: Seconds to wait for the CMS API before giving up, so a stalled
            connection cannot hang the notebook indefinitely.

    Returns:
        list: The catalog entries whose ``theme`` contains "Hospitals" or whose
        ``title`` contains "hospital" (case-insensitive).
        Assumes the endpoint returns a JSON array of dataset objects - TODO confirm.

    Raises:
        requests.HTTPError: If the API answers with a 4xx/5xx status.
        requests.RequestException: On connection problems or timeout.
    """
    cms_url = "https://data.cms.gov/provider-data/api/1/metastore/schemas/dataset/items"
    resp = requests.get(cms_url, timeout=timeout)
    # Fail loudly on an error response instead of trying to parse an error body as the catalog.
    resp.raise_for_status()
    cms_data = resp.json()

    hospital_datasets = [
        ds for ds in cms_data
        if "Hospitals" in str(ds.get("theme", "")) or "hospital" in str(ds.get("title", "")).lower()
    ]

    return hospital_datasets


# Global Variables
hospital_datasets = get_hospital_datasets()
max_parallel = 8  # worker-thread limit for the ThreadPoolExecutor
user_input = dbutils.widgets.get("batch_date")
current_catalog = spark.sql("SELECT current_catalog() AS current_catalog").collect()[0].current_catalog
if user_input:
    try:
        dt = datetime.strptime(user_input, "%Y-%m-%d")
    except ValueError as err:
        # Keep the original parsing error attached as the cause for easier debugging.
        raise ValueError("Date must be in YYYY-MM-DD format") from err
else:
    # No batch date supplied: run for the current date.
    dt = datetime.today()
print(dt)
print(current_catalog)

In [0]:
def get_csv_links(hospital_datasets):
    """Collect the unique CSV download URLs from a list of dataset entries.

    Args:
        hospital_datasets: Iterable of dataset dicts. Each may carry a
            ``distribution`` list whose items may have a ``downloadURL``.

    Returns:
        list[str]: URLs ending in ``.csv`` (case-insensitive), in first-seen
        order, with duplicates removed. Missing or ``None`` distributions and
        entries without a URL are skipped.
    """
    csv_links = []
    seen = set()
    for dataset in hospital_datasets:
        # `or []` also covers an explicit `"distribution": None`.
        for dist in dataset.get("distribution") or []:
            download_url = dist.get("downloadURL")
            if not download_url:
                continue
            # Case-insensitive, matching how download_data strips the extension.
            if not download_url.lower().endswith(".csv"):
                continue
            # The same URL listed twice would otherwise be downloaded twice
            # (and written to the same paths) by parallel workers.
            if download_url in seen:
                continue
            seen.add(download_url)
            csv_links.append(download_url)
    return csv_links

In [0]:
def create_schemas():
    """Create the `metadata` and `bronze` schemas if they do not exist yet."""
    ddl_statements = (
        "CREATE SCHEMA IF NOT EXISTS metadata",
        "CREATE SCHEMA IF NOT EXISTS bronze",
    )
    # Executed in order: metadata first, then bronze.
    for statement in ddl_statements:
        spark.sql(statement)

In [0]:
def create_hospital_meta_table():
    """Create the Delta table `metadata.hospital_meta` unless it already exists.

    Columns: table_name, download_url, last_update_date, batch_date,
    file_count and update_flag (see the inline SQL comments for their meaning).
    """
    meta_table_ddl = """
    CREATE TABLE IF NOT EXISTS metadata.hospital_meta (
        table_name STRING,
        download_url STRING,
        last_update_date TIMESTAMP, -- modified date
        batch_date TIMESTAMP, --  current date
        file_count INT,
        update_flag STRING -- if 'Y" means we need to update the table, else 'N' means we don't need to update the table
    ) USING DELTA 
    """
    spark.sql(meta_table_ddl)

In [0]:
def to_snake_case(name):
    """Convert a column header to a lower-case, underscore-separated identifier.

    Steps, applied in order:
      1. runs of whitespace and/or hyphens become a single underscore;
      2. an underscore is inserted between a lower-case letter or digit and a
         following upper-case letter (camelCase boundary);
      3. every character other than ASCII letters, digits and underscores is removed;
      4. the result is lower-cased.
    """
    substitutions = (
        (r'[\s\-]+', '_'),
        (r'([a-z0-9])([A-Z])', r'\1_\2'),
        (r'[^a-zA-Z0-9_]', ''),
    )
    result = name
    for pattern, replacement in substitutions:
        result = re.sub(pattern, replacement, result)
    return result.lower()

In [0]:
# Downloading data from api and store it as a delta file in the landing zone

meta_update_records = []  # List of records to update in metadata table
# download_data runs on several worker threads; this lock makes the
# "append only if not already present" step on the shared list atomic.
meta_update_records_lock = Lock()


def _find_modified_date(url):
    """Return the `modified` value of a dataset that lists `url` as a download URL, or None."""
    for ds in hospital_datasets:
        if any(dist.get("downloadURL") == url for dist in ds.get("distribution") or []):
            modified = ds.get("modified")
            # Keep searching when the matching dataset carries no modified date.
            if modified:
                return modified
    return None


def download_data(url):
    """Download one CSV, land it as a Delta folder and queue its metadata record.

    The download is skipped when the metadata table already holds a
    `last_update_date` for this dataset and the catalog's modified date is
    missing, unparseable, or not newer than that stored date.

    Relies on notebook globals: `hospital_datasets`, `current_catalog`, `dt`,
    `spark`, `dbutils`, `meta_update_records`.

    Args:
        url: Download URL of the CSV file.

    Returns:
        The shared `meta_update_records` list when the file was landed,
        otherwise None (skipped or failed). Failures are printed, not raised,
        so one bad URL does not stop the other workers.
    """
    try:
        file_name = url.split("/")[-1]
        if file_name.lower().endswith('.csv'):
            file_name = file_name[:-4]

        # Replace any '-' with '_'
        file_name = file_name.replace("-", "_")

        # Normalised name: metadata key and landing folder name.
        table_name = "_".join(file_name.lower().split())

        dbfs_path = f"dbfs:/tmp/{file_name}.csv"

        # Find modified date for this url and convert it to a datetime.
        modified_date = _find_modified_date(url)
        modified_date_dt = None
        if modified_date:
            try:
                modified_date_dt = datetime.strptime(modified_date, "%Y-%m-%d")
            except (TypeError, ValueError):
                # Unexpected type or format: treat the date as unknown.
                modified_date_dt = None

        # Check metadata table for last update date
        meta_table = f"{current_catalog}.metadata.hospital_meta"
        meta_row = (
            spark.table(meta_table)
            .filter(F.col("table_name") == table_name)
            .select("last_update_date")
            .collect()
        )
        last_update_date = meta_row[0]["last_update_date"] if meta_row else None

        # Only download if record doesn't exist or modified_date is greater
        if last_update_date:
            if not modified_date_dt or modified_date_dt <= last_update_date:
                print(f"Skipping {url} (not newer than last update)")
                return None

        # Download directly to DBFS
        dbutils.fs.cp(url, dbfs_path)

        # Read CSV
        df = spark.read.option("header", "true").csv(dbfs_path)

        # Convert to snake_case
        new_columns = [to_snake_case(c) for c in df.columns]
        df = df.toDF(*new_columns)

        delta_path = f"/mnt/landing/hospital/{dt.year}/{dt.month}/{dt.day}/{table_name}"

        # Save as Delta
        count = df.count()
        df.write.format("delta").mode("overwrite").save(delta_path)

        # Collect metadata record for later batch update
        record = (table_name, url, modified_date, count)
        with meta_update_records_lock:
            if record not in meta_update_records:
                meta_update_records.append(record)

        print(f"Saved {url} to {delta_path} and collected metadata")
        return meta_update_records

    except Exception as e:
        print(f"Failed {url}: {e}")
        return None

In [0]:
def update_metadata(meta_update_records):
    """Merge the collected download records into the hospital metadata table.

    Relies on notebook globals: `current_catalog`, `dt`, `spark`.

    Args:
        meta_update_records: Mutable list of tuples
            (table_name, download_url, last_update_date, file_count).
            The list is emptied in place after a successful merge so the same
            records are not merged again later.

    Returns:
        None. Does nothing when the list is empty.
    """
    # Nothing to merge; also avoids building a DataFrame from an empty list.
    if not meta_update_records:
        return

    meta_table = f"{current_catalog}.metadata.hospital_meta"

    # Add batch_date and update_flag ('Y' by default) to every record.
    batch_date = dt
    records = [
        (
            table_name,
            download_url,
            # Keep the date as text here; it is cast to a timestamp below.
            None if last_update_date is None else str(last_update_date),
            batch_date,
            file_count,
            "Y",
        )
        for table_name, download_url, last_update_date, file_count in meta_update_records
    ]

    # Explicit schema: type inference fails when every last_update_date is None.
    schema = (
        "table_name string, download_url string, last_update_date string, "
        "batch_date timestamp, file_count bigint, update_flag string"
    )
    meta_update_df = spark.createDataFrame(records, schema=schema)
    meta_update_df = meta_update_df.withColumn("last_update_date", F.to_timestamp(F.col("last_update_date")))

    meta_update_df.createOrReplaceTempView("meta_update")

    # merge below updates metadata table as follows:
    # - If a record with same table_name exists:
    #   - If the stored last_update_date is NULL or the new one is more recent,
    #     update all fields and set update_flag to 'Y'
    #   - If new last_update_date is the same or older, do not update any fields
    # - If no record exists for table_name, insert a new row
    spark.sql(f"""
    MERGE INTO {meta_table} t
    USING meta_update s
    ON t.table_name = s.table_name
    WHEN MATCHED AND (t.last_update_date IS NULL OR s.last_update_date > t.last_update_date) THEN
      UPDATE SET
        t.download_url = s.download_url,
        t.last_update_date = s.last_update_date,
        t.batch_date = s.batch_date,
        t.file_count = s.file_count,
        t.update_flag = 'Y'
    WHEN NOT MATCHED THEN
      INSERT (table_name, download_url, last_update_date, batch_date, file_count, update_flag)
      VALUES (s.table_name, s.download_url, s.last_update_date, s.batch_date, s.file_count, s.update_flag)
    """)
    # Clear the caller's list in place; rebinding the parameter would leave it untouched.
    meta_update_records.clear()

In [0]:
def landing_to_bronze(landing_base, catalog, schema):
    """Load today's flagged landing folders into bronze tables.

    For every directory under `{landing_base}/{year}/{month}/{day}/` whose name
    matches a metadata row with `update_flag = 'Y'`, the Delta data is written
    (overwrite, schema included) to `{catalog}.{schema}.{table_name}` and the
    flag is reset to 'N'. A failure on one table is printed and does not stop
    the others.

    Relies on notebook globals: `current_catalog`, `dt`, `max_parallel`,
    `spark`, `dbutils`.

    Args:
        landing_base: Root path of the landing zone.
        catalog: Target catalog name.
        schema: Target schema name.
    """
    from pathlib import Path  # not imported at file level

    meta_table = f"{current_catalog}.metadata.hospital_meta"
    landing_path = f"{landing_base}/{dt.year}/{dt.month}/{dt.day}/"

    # Get tables with update_flag = 'Y'
    flagged_tables_df = spark.sql(f"SELECT table_name FROM {meta_table} WHERE update_flag = 'Y'")
    flagged_tables = {row['table_name'] for row in flagged_tables_df.collect()}

    loaded_count = 0
    if not flagged_tables:
        # Nothing to load; skip listing a landing folder that may not exist for this date.
        print(f"Total tables loaded: {loaded_count}")
        return

    files = dbutils.fs.ls(landing_path)
    loaded_count_lock = Lock()

    def process_folder(f):
        """Load one landing folder into its bronze table if it is flagged."""
        nonlocal loaded_count
        if not f.isDir():
            return
        table_name = Path(f.name).stem.lower()
        if table_name not in flagged_tables:
            return

        delta_path = f.path
        uc_table = f"{catalog}.{schema}.{table_name}"
        try:
            df = spark.read.format("delta").load(delta_path)
            df.write.format("delta").mode("overwrite").option("overwriteSchema", "true").saveAsTable(uc_table)
            with loaded_count_lock:
                # Set update_flag to 'N' after loading. Done under the lock so the
                # workers do not issue concurrent UPDATEs against the same Delta table.
                spark.sql(f"UPDATE {meta_table} SET update_flag = 'N' WHERE table_name = '{table_name}'")
                loaded_count += 1
                print(f"Loaded {delta_path} into {uc_table} (Total loaded: {loaded_count})")
        except Exception as e:
            # Report the failure; an exception left inside an unconsumed
            # executor.map result would otherwise never be seen.
            print(f"Failed to load {delta_path} into {uc_table}: {e}")

    with ThreadPoolExecutor(max_workers=max_parallel) as executor:
        # Drain the iterator so every task has been awaited.
        list(executor.map(process_folder, files))

    print(f"Total tables loaded: {loaded_count}")

In [0]:
def main():
    """Run the pipeline: download CSVs, update the metadata table, load bronze tables.

    Relies on notebook globals: `hospital_datasets`, `max_parallel`,
    `current_catalog`, `meta_update_records`.
    """
    # Start every run with an empty record list so entries left over from an
    # earlier run in the same session are never merged again.
    meta_update_records.clear()

    # initializing
    csv_links = get_csv_links(hospital_datasets)
    create_schemas()
    create_hospital_meta_table()

    # step 1 - download data
    with ThreadPoolExecutor(max_workers=max_parallel) as executor:
        # download_data reports its own failures and fills meta_update_records;
        # draining the iterator just makes the wait explicit.
        list(executor.map(download_data, csv_links))

    # step 2 - update metadata
    if meta_update_records:
        print("Updating metadata:")
        update_metadata(meta_update_records)

    # step 3 - move data to bronze layer tbls
    landing_base = "/mnt/landing/hospital"
    schema = "bronze"
    landing_to_bronze(landing_base, current_catalog, schema)

In [0]:
# Run the full pipeline: download -> metadata update -> bronze load.
main()

In [0]:
# Display the full contents of the hospital metadata table for inspection.
display(spark.sql(f"select * from {current_catalog}.metadata.hospital_meta"))

In [0]:
# Test setup: simulate datasets that have been updated at the source.
# Every row whose last_update_date is on or before 2024-12-03 is backdated to 2020-12-01,
# so a dataset modified after 2020-12-01 is no longer skipped by download_data on the next run.
spark.sql(f"""UPDATE {current_catalog}.metadata.hospital_meta
SET last_update_date = timestamp('2020-12-01T00:00:00.000+00:00')
WHERE last_update_date <= timestamp('2024-12-03T00:00:00.000+00:00')""")

In [0]:
# Second run: only files whose modified date is newer than the last_update_date stored
# in the metadata table for their dataset are downloaded again; the rest are skipped.
main()