# Provider Fraud Detection Model - Data Wrangling
### In this notebook we will focus on collecting, organizing, defining, and cleaning the relevant datasets for the Provider Fraud Detection model
#### 1.) Ingesting & Inspecting Files
#### 2.) Standardizing Column Names & Types
#### 3.) Cleaning Values
#### 4.) Deduping (Removing Duplicates)
#### 5.) Merging/Joining Tables

In [1]:
# loading needed modules
import re
from collections import Counter
from functools import reduce

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    IntegerType,
    DoubleType,
    DateType
)

from pyspark.sql.functions import (
    col,
    expr,
    lit,
    trim,
    to_date,
    current_date,
    datediff,
    sum as sum_,
    avg,
    stddev,
    countDistinct,
    when,
    try_divide,
    row_number,
    min,
    max,
    mean,
    isnan,
    isnull,
    regexp_replace,
    substring,
    upper,
    coalesce,
    sha2
)

In [2]:
# this function will help me normalize my column names to ensure a smooth workflow
# this function will...
    # lowercase, trim whitespace
    # drop parenthetical notes
    # replace non-alphanumeric chars with underscores
    # collapse runs of underscores
    # strip leading/trailing underscores

def normalize_col(column_name: str) -> str:
    text = column_name.strip().lower()
    text = re.sub(r'\(.*?\)', '', text)
    text = re.sub(r'[^0-9a-z]+', '_', text)
    text = re.sub(r'_+', '_', text)
    return text.strip('_')

In [3]:
# starting Spark session
spark = SparkSession.builder.appName('ProviderFraudDetection').getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/12 18:42:02 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


## NPPES NPI Registry Dataset

In [4]:
# creating a csv path variable
npi_csv_path = "NPPES_Data_Dissemination_July_2025_V2/npidata_pfile_20050523-20250713.csv"


In [5]:
# reading the CSV file into a DataFrame
full_npi_df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "false")
    .csv(npi_csv_path)
)

In [6]:
# inspecting the npi dataset column names
print("Available columns (first 50 shown):")
for c in full_npi_df.columns[:50]:
    print(repr(c))
    

Available columns (first 50 shown):
'NPI'
'Entity Type Code'
'Replacement NPI'
'Employer Identification Number (EIN)'
'Provider Organization Name (Legal Business Name)'
'Provider Last Name (Legal Name)'
'Provider First Name'
'Provider Middle Name'
'Provider Name Prefix Text'
'Provider Name Suffix Text'
'Provider Credential Text'
'Provider Other Organization Name'
'Provider Other Organization Name Type Code'
'Provider Other Last Name'
'Provider Other First Name'
'Provider Other Middle Name'
'Provider Other Name Prefix Text'
'Provider Other Name Suffix Text'
'Provider Other Credential Text'
'Provider Other Last Name Type Code'
'Provider First Line Business Mailing Address'
'Provider Second Line Business Mailing Address'
'Provider Business Mailing Address City Name'
'Provider Business Mailing Address State Name'
'Provider Business Mailing Address Postal Code'
'Provider Business Mailing Address Country Code (If outside U.S.)'
'Provider Business Mailing Address Telephone Number'
'Provider B

In [7]:
# selecting the relevant columns I want to keep
tax_cols = []
for i in range(1, 16):
    tax_cols += [f"Healthcare Provider Taxonomy Code_{i}",
                 f"Healthcare Provider Primary Taxonomy Switch_{i}"]

keep_cols_npi = [
    "NPI",
    "Entity Type Code",
    "Provider Business Practice Location Address State Name",
    "Provider Business Practice Location Address Postal Code",
    "Is Organization Subpart",
    "Parent Organization TIN",
    "Parent Organization LBN",
    "Provider Enumeration Date",
    "Last Update Date",
    "NPI Deactivation Date",
    "NPI Reactivation Date",
    "Is Sole Proprietor",
] + tax_cols

In [8]:
# creating a new DataFrame with only the selected columns
npi_df = full_npi_df.select(*keep_cols_npi)


In [9]:
# having a sanity check to ensure I loaded the right columns, the correct types, and the data looks okay
print("Schema:")
npi_df.printSchema()
print("Sample:")
npi_df.limit(5).show(truncate=False)

Schema:
root
 |-- NPI: string (nullable = true)
 |-- Entity Type Code: string (nullable = true)
 |-- Provider Business Practice Location Address State Name: string (nullable = true)
 |-- Provider Business Practice Location Address Postal Code: string (nullable = true)
 |-- Is Organization Subpart: string (nullable = true)
 |-- Parent Organization TIN: string (nullable = true)
 |-- Parent Organization LBN: string (nullable = true)
 |-- Provider Enumeration Date: string (nullable = true)
 |-- Last Update Date: string (nullable = true)
 |-- NPI Deactivation Date: string (nullable = true)
 |-- NPI Reactivation Date: string (nullable = true)
 |-- Is Sole Proprietor: string (nullable = true)
 |-- Healthcare Provider Taxonomy Code_1: string (nullable = true)
 |-- Healthcare Provider Primary Taxonomy Switch_1: string (nullable = true)
 |-- Healthcare Provider Taxonomy Code_2: string (nullable = true)
 |-- Healthcare Provider Primary Taxonomy Switch_2: string (nullable = true)
 |-- Healthcare P

25/08/12 18:42:12 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+----------+----------------+------------------------------------------------------+-------------------------------------------------------+-----------------------+-----------------------+-----------------------+-------------------------+----------------+---------------------+---------------------+------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+--------------------------

In [10]:
# getting the total number of rows in the DataFrame
total_rows = npi_df.count()
print(f'Total rows: {total_rows}')



Total rows: 9026996


                                                                                

In [11]:
# getting the null counts for each column
npi_df.select(
    *[
        sum_(col(c).isNull().cast('int')).alias(c) for c in npi_df.columns
    ]
).show()



+---+----------------+------------------------------------------------------+-------------------------------------------------------+-----------------------+-----------------------+-----------------------+-------------------------+----------------+---------------------+---------------------+------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+---------------------------------

                                                                                

In [12]:
# comoputing normalized names and checking for collisions
# checking for collisions as I don't want to accidently end up with two originals mapping to the same cleaned name

# applying the normalized names to every column I plan to keep
cleaned_names = []
for original_name in keep_cols_npi:
    new_name = normalize_col(original_name)
    cleaned_names.append(new_name)

# counting how many times each cleaned name shows up
name_counts = Counter(cleaned_names)

# pulling out any names that show up more than once
duplicates = [name for name, count in name_counts.items() if count > 1]

# bailing out early, if a coliision is detected
if duplicates:
    raise RuntimeError(f'Column name conflict after normalization: {duplicates}')

In [13]:
# building a new DataFrame with normalized column names

# building a list of expressions to rename the columns in the DataFrame
renamed_columns = [
    col(old).alias(new)
    for old, new in zip(keep_cols_npi, cleaned_names)
]

# applying 
normalized_npi_df = npi_df.select(*renamed_columns)

In [14]:
# doing a quick sanity check to ensure the new DataFrame has the correct columns
normalized_npi_df.printSchema()
normalized_npi_df.show(5, truncate=False)

root
 |-- npi: string (nullable = true)
 |-- entity_type_code: string (nullable = true)
 |-- provider_business_practice_location_address_state_name: string (nullable = true)
 |-- provider_business_practice_location_address_postal_code: string (nullable = true)
 |-- is_organization_subpart: string (nullable = true)
 |-- parent_organization_tin: string (nullable = true)
 |-- parent_organization_lbn: string (nullable = true)
 |-- provider_enumeration_date: string (nullable = true)
 |-- last_update_date: string (nullable = true)
 |-- npi_deactivation_date: string (nullable = true)
 |-- npi_reactivation_date: string (nullable = true)
 |-- is_sole_proprietor: string (nullable = true)
 |-- healthcare_provider_taxonomy_code_1: string (nullable = true)
 |-- healthcare_provider_primary_taxonomy_switch_1: string (nullable = true)
 |-- healthcare_provider_taxonomy_code_2: string (nullable = true)
 |-- healthcare_provider_primary_taxonomy_switch_2: string (nullable = true)
 |-- healthcare_provider_

In [None]:
# casting and standardizing the data types
clean = (normalized_npi_df
         
    # trimming the stray spaces so joins/comparisons work correctly
    .withColumn("npi", trim(col("npi")))
    .withColumn("npi", regexp_replace(col("npi"), r"\D", ""))
    .withColumn("entity_type_code", col("entity_type_code").cast("int"))

    # mapping yes/no to 1/0
    .withColumn("is_organization_subpart", when(col("is_organization_subpart") == "Y", 1).otherwise(0))

    # is_sole_proprietor: only meaningful for individuals (entity_type_code==1)
    # Y->1, N->0, anything else (incl. X) -> null; orgs -> null
    .withColumn("is_sole_proprietor_raw", upper(trim(col("is_sole_proprietor"))))
    .withColumn(
        "is_sole_proprietor",
        when(col("entity_type_code") == 1,
             when(col("is_sole_proprietor_raw") == "Y", 1)
            .when(col("is_sole_proprietor_raw") == "N", 0)
            .otherwise(None)
        ).otherwise(None)
    ).drop("is_sole_proprietor_raw")

    # parse dates strings into real Date columns
    .withColumn("provider_enumeration_date", to_date(col("provider_enumeration_date"), "MM/dd/yyyy"))
    .withColumn("last_update_date", to_date(col("last_update_date"), "MM/dd/yyyy"))
    .withColumn("npi_deactivation_date", to_date(col("npi_deactivation_date"), "MM/dd/yyyy"))
    .withColumn("npi_reactivation_date", to_date(col("npi_reactivation_date"), "MM/dd/yyyy"))
)

In [16]:
# normalizing ZIP & TIN

# creating a zip5 with just the first 5 digits of the postal code column
clean = clean.withColumn(
    "zip5",
    substring(regexp_replace(col("provider_business_practice_location_address_postal_code"), r"[^0-9]", ""), 1, 5)
)

# normalizing state names to 2-letter codes
clean = clean.withColumn(
    "state_abbr",
    when(col("provider_business_practice_location_address_state_name").rlike(r"^[A-Za-z]{2}$"),
         upper(col("provider_business_practice_location_address_state_name"))
    )
    .otherwise(None)   # optional: keep None if it's a full name or messy
)

# stripping non-digits from TIN for consistent grouping
clean = (clean
    .withColumn("parent_org_tin_norm", regexp_replace(col("parent_organization_tin"), r"\D", ""))
)


In [17]:
# primary taxonomy extraction with nuance

# NPPES has 15 slots so building an expression that picks the code where the corresponding switch is "Y"
explicit_expr = None
for i in range(1, 16):
    sw = f"healthcare_provider_primary_taxonomy_switch_{i}"
    cd = f"healthcare_provider_taxonomy_code_{i}"
    condY = (col(sw) == "Y") & col(cd).isNotNull()
    explicit_expr = when(condY, col(cd)) if explicit_expr is None else explicit_expr.when(condY, col(cd))

# first non-null taxonomy code across all slots (for fallback)
code_cols = [col(f"healthcare_provider_taxonomy_code_{i}") for i in range(1, 16)]
first_code = coalesce(*code_cols)

# any 'X' present across switches? (for your "unknown → fallback" behavior)
any_x = None
for i in range(1, 16):
    sw = f"healthcare_provider_primary_taxonomy_switch_{i}"
    condX = (col(sw) == "X")
    any_x = condX if any_x is None else (any_x | condX)

clean = (clean
    .withColumn(
        "primary_taxonomy",
        when(explicit_expr.isNotNull(), explicit_expr)
        .when(any_x & first_code.isNotNull(), first_code)   # your fallback rule
        .otherwise(lit(None))
    )
    .withColumn(
        "primary_taxonomy_explicit",
        when(explicit_expr.isNotNull(), lit(1)).otherwise(lit(0))
    )
    .withColumn(
        "primary_taxonomy_unknown",
        when(explicit_expr.isNull() & any_x, lit(1)).otherwise(lit(0))
    )
    .withColumn(
        "primary_taxonomy_source",
        when(explicit_expr.isNotNull(), lit("explicit"))
        .when(any_x & first_code.isNotNull(), lit("fallback"))
        .otherwise(lit(None))
    )
)


In [None]:
# lifecycle / derived features
clean = (clean
    # this will show how many days since the Provider first got their NPI
    .withColumn(
        "npi_age_days", datediff(current_date(), col("provider_enumeration_date")))

    # this will show if the NPI is 'active' today...
    # if there's no deactivation date -> it's active
    # or if the Provider was deactivated and then reactivated after the deactivation date -> it's active
    .withColumn(
        "is_active",
        when(
            (col("npi_deactivation_date").isNull()) |
            ((col("npi_reactivation_date").isNotNull()) & (col("npi_reactivation_date") >= col("npi_deactivation_date"))),
            lit(1)
        ).otherwise(lit(0))
    )

    # this will show if a Provider was deactivated and then reactivated at some point
    .withColumn("was_reactivated", when(col("npi_reactivation_date").isNotNull(), 1).otherwise(0))
    .withColumn(
        "deactivated_then_reactivated",
        when((col("npi_deactivation_date").isNotNull()) & (col("npi_reactivation_date").isNotNull()), 1).otherwise(0)
    )

    # this will show if there are any location info (ex. if state field is not null)
    .withColumn(
        "has_location",
        when(col("state_abbr").isNotNull(), 1).otherwise(0)
)

    # this will show if the entity type code is missing
    .withColumn(
        "missing_entity_type",
        when(col("entity_type_code").isNull(), 1).otherwise(0)
    )
)



In [None]:
# ensuring 'primary_taxonomy' exists using whatever slots are present

slots = []
for i in range(1, 16):
    sw = f"healthcare_provider_primary_taxonomy_switch_{i}"
    cd = f"healthcare_provider_taxonomy_code_{i}"
    if sw in clean.columns and cd in clean.columns:
        slots.append((sw, cd))

if "primary_taxonomy" not in clean.columns:
    if slots:
        explicit = None
        for sw, cd in slots:
            cond = (col(sw) == "Y") & col(cd).isNotNull()
            explicit = when(cond, col(cd)) if explicit is None else explicit.when(cond, col(cd))

        code_any = None
        for _, cd in slots:
            code_any = col(cd) if code_any is None else coalesce(code_any, col(cd))

        any_x = None
        for sw, _ in slots:
            any_x = (col(sw) == "X") if any_x is None else (any_x | (col(sw) == "X"))

        clean = clean.withColumn(
            "primary_taxonomy",
            when(explicit.isNotNull(), explicit)
            .when(any_x & code_any.isNotNull(), code_any)
            .otherwise(lit(None))
        )
    else:
        clean = clean.withColumn("primary_taxonomy", lit(None).cast("string"))

# dropping raw taxonomy slot columns
slot_cols = [c for c in clean.columns
             if c.startswith("healthcare_provider_taxonomy_code_")
             or c.startswith("healthcare_provider_primary_taxonomy_switch_")]
clean = clean.drop(*slot_cols)


# building a fallback 'any taxonomy code' for completeness scoring
code_any = None
for i in range(1, 16):
    cd = f"healthcare_provider_taxonomy_code_{i}"
    if cd in clean.columns:
        code_any = col(cd) if code_any is None else coalesce(code_any, col(cd))
if code_any is None:
    code_any = lit(None)

# computing completeness_score
clean = clean.withColumn(
    "completeness_score",
    (when(coalesce(col("primary_taxonomy"), code_any).isNotNull(), 1).otherwise(0)
     + when(col("provider_business_practice_location_address_state_name").isNotNull(), 1).otherwise(0)
     + when(col("provider_enumeration_date").isNotNull(), 1).otherwise(0)
     + when(col("entity_type_code").isNotNull(), 1).otherwise(0))
)


# Quick sanity
print("primary_taxonomy present?", "primary_taxonomy" in clean.columns)

primary_taxonomy present? True


In [None]:
# validating the NPIs, they should be 10 digits long
npi_regex = r'^\d{10}$'
expected_npi_count = npi_regex

# counting how many NPIs are malformed
invalid_npi_count = (
    clean.filter(~col("npi").rlike(expected_npi_count))
    .count()
)

print(f"Invalid NPIs found: {invalid_npi_count}")



Invalid NPIs found: 0


                                                                                

In [21]:
# deduplicating based off of latest "last_update_date" and then highest completeness score
w = Window.partitionBy("npi").orderBy(col("last_update_date").desc_nulls_last(), col("completeness_score").desc())
wrangled_npi_df = (
    clean
    .withColumn("rn", row_number().over(w))
    .filter(col("rn") == 1)
    .drop("rn")
)

In [22]:
# comparing to make sure I have all my changes

print(sorted(clean.columns))
print(sorted(wrangled_npi_df.columns))

['completeness_score', 'deactivated_then_reactivated', 'entity_type_code', 'has_location', 'healthcare_provider_primary_taxonomy_switch_1', 'healthcare_provider_primary_taxonomy_switch_10', 'healthcare_provider_primary_taxonomy_switch_11', 'healthcare_provider_primary_taxonomy_switch_12', 'healthcare_provider_primary_taxonomy_switch_13', 'healthcare_provider_primary_taxonomy_switch_14', 'healthcare_provider_primary_taxonomy_switch_15', 'healthcare_provider_primary_taxonomy_switch_2', 'healthcare_provider_primary_taxonomy_switch_3', 'healthcare_provider_primary_taxonomy_switch_4', 'healthcare_provider_primary_taxonomy_switch_5', 'healthcare_provider_primary_taxonomy_switch_6', 'healthcare_provider_primary_taxonomy_switch_7', 'healthcare_provider_primary_taxonomy_switch_8', 'healthcare_provider_primary_taxonomy_switch_9', 'healthcare_provider_taxonomy_code_1', 'healthcare_provider_taxonomy_code_10', 'healthcare_provider_taxonomy_code_11', 'healthcare_provider_taxonomy_code_12', 'healthca

In [None]:
# doing sanity checks after cleaning my NPI dataframe values
print("Distinct is_sole_proprietor values:")
wrangled_npi_df.select("is_sole_proprietor").distinct().show()


print("Primary taxonomy source distinct values:")
wrangled_npi_df.select("primary_taxonomy_source").distinct().show()


print("Null counts after transformation:")
nulls = wrangled_npi_df.select(*[
    sum_(when(col(c).isNull(), 1).otherwise(0)).alias(c)
    for c in wrangled_npi_df.columns
])
nulls.show(truncate=False)

Distinct is_sole_proprietor values:


                                                                                

+------------------+
|is_sole_proprietor|
+------------------+
|                 1|
|                 0|
|              NULL|
+------------------+

Primary taxonomy switch distinct modes:


                                                                                

+---------------------------------------------+
|healthcare_provider_primary_taxonomy_switch_1|
+---------------------------------------------+
|                                            Y|
|                                            N|
|                                            X|
|                                         NULL|
+---------------------------------------------+

Null counts after transformation:




+---+----------------+------------------------------------------------------+-------------------------------------------------------+-----------------------+-----------------------+-----------------------+-------------------------+----------------+---------------------+---------------------+------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+---------------------------------

                                                                                

#### The NPPES NPI dataset has almost been fully wrangled as "wrangled_npi_df"
#### Stopping here to wrangle others...

## Medicare Physician & Other Practitioners - by Provider and Service Dataset

In [24]:
# creating a csv path variable
phys_pract_csv_path = "Medicare Physician & Other Practitioners - by Provider and Service/MUP_PHY_R25_P05_V20_D23_Prov_Svc.csv"

In [25]:
# reading the CSV file into a DataFrame
full_physician_practitioner_df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "false")
    .csv(phys_pract_csv_path)
)

In [26]:
# inspecting the column names
print("Available columns (first 50 shown):")
for c in full_physician_practitioner_df.columns[:50]:
    print(repr(c))

Available columns (first 50 shown):
'Rndrng_NPI'
'Rndrng_Prvdr_Last_Org_Name'
'Rndrng_Prvdr_First_Name'
'Rndrng_Prvdr_MI'
'Rndrng_Prvdr_Crdntls'
'Rndrng_Prvdr_Ent_Cd'
'Rndrng_Prvdr_St1'
'Rndrng_Prvdr_St2'
'Rndrng_Prvdr_City'
'Rndrng_Prvdr_State_Abrvtn'
'Rndrng_Prvdr_State_FIPS'
'Rndrng_Prvdr_Zip5'
'Rndrng_Prvdr_RUCA'
'Rndrng_Prvdr_RUCA_Desc'
'Rndrng_Prvdr_Cntry'
'Rndrng_Prvdr_Type'
'Rndrng_Prvdr_Mdcr_Prtcptg_Ind'
'HCPCS_Cd'
'HCPCS_Desc'
'HCPCS_Drug_Ind'
'Place_Of_Srvc'
'Tot_Benes'
'Tot_Srvcs'
'Tot_Bene_Day_Srvcs'
'Avg_Sbmtd_Chrg'
'Avg_Mdcr_Alowd_Amt'
'Avg_Mdcr_Pymt_Amt'
'Avg_Mdcr_Stdzd_Amt'


In [27]:
# selecting the relevant columns I want to keep
keep_cols_physician_practitioners = [
    "Rndrng_NPI",                 
    "Rndrng_Prvdr_Ent_Cd",        
    "Rndrng_Prvdr_State_Abrvtn",  
    "Rndrng_Prvdr_Zip5",          
    "Rndrng_Prvdr_RUCA",          
    "Rndrng_Prvdr_RUCA_Desc",     
    "Rndrng_Prvdr_Mdcr_Prtcptg_Ind", 
    "Rndrng_Prvdr_Type",
    "Rndrng_Prvdr_Cntry",

    "HCPCS_Cd",                   
    "HCPCS_Desc",                 
    "HCPCS_Drug_Ind",             
    "Place_Of_Srvc",              
    "Tot_Benes",                  
    "Tot_Srvcs",                  
    "Tot_Bene_Day_Srvcs",         

    "Avg_Sbmtd_Chrg",             
    "Avg_Mdcr_Alowd_Amt",        
    "Avg_Mdcr_Pymt_Amt",          
    "Avg_Mdcr_Stdzd_Amt"        
]

In [28]:
# creating a new DataFrame with only the selected columns
phys_pract_df = full_physician_practitioner_df.select(*keep_cols_physician_practitioners)

In [29]:
# having a sanity check to ensure I loaded the right columns, the correct types, and the data looks okay
print("Schema:")
phys_pract_df.printSchema()
print("Sample:")
phys_pract_df.limit(5).show(truncate=False)

Schema:
root
 |-- Rndrng_NPI: string (nullable = true)
 |-- Rndrng_Prvdr_Ent_Cd: string (nullable = true)
 |-- Rndrng_Prvdr_State_Abrvtn: string (nullable = true)
 |-- Rndrng_Prvdr_Zip5: string (nullable = true)
 |-- Rndrng_Prvdr_RUCA: string (nullable = true)
 |-- Rndrng_Prvdr_RUCA_Desc: string (nullable = true)
 |-- Rndrng_Prvdr_Mdcr_Prtcptg_Ind: string (nullable = true)
 |-- Rndrng_Prvdr_Type: string (nullable = true)
 |-- Rndrng_Prvdr_Cntry: string (nullable = true)
 |-- HCPCS_Cd: string (nullable = true)
 |-- HCPCS_Desc: string (nullable = true)
 |-- HCPCS_Drug_Ind: string (nullable = true)
 |-- Place_Of_Srvc: string (nullable = true)
 |-- Tot_Benes: string (nullable = true)
 |-- Tot_Srvcs: string (nullable = true)
 |-- Tot_Bene_Day_Srvcs: string (nullable = true)
 |-- Avg_Sbmtd_Chrg: string (nullable = true)
 |-- Avg_Mdcr_Alowd_Amt: string (nullable = true)
 |-- Avg_Mdcr_Pymt_Amt: string (nullable = true)
 |-- Avg_Mdcr_Stdzd_Amt: string (nullable = true)

Sample:
+----------+----

In [30]:
# getting the total number of rows in the DataFrame
total_rows_phys_pract = phys_pract_df.count()
print(f'Total rows: {total_rows_phys_pract}')



Total rows: 9660647


                                                                                

In [31]:
# getting the null counts for each column
phys_pract_df.select(
    *[
        sum_(col(c).isNull().cast('int')).alias(c) for c in phys_pract_df.columns
    ]
).show()



+----------+-------------------+-------------------------+-----------------+-----------------+----------------------+-----------------------------+-----------------+------------------+--------+----------+--------------+-------------+---------+---------+------------------+--------------+------------------+-----------------+------------------+
|Rndrng_NPI|Rndrng_Prvdr_Ent_Cd|Rndrng_Prvdr_State_Abrvtn|Rndrng_Prvdr_Zip5|Rndrng_Prvdr_RUCA|Rndrng_Prvdr_RUCA_Desc|Rndrng_Prvdr_Mdcr_Prtcptg_Ind|Rndrng_Prvdr_Type|Rndrng_Prvdr_Cntry|HCPCS_Cd|HCPCS_Desc|HCPCS_Drug_Ind|Place_Of_Srvc|Tot_Benes|Tot_Srvcs|Tot_Bene_Day_Srvcs|Avg_Sbmtd_Chrg|Avg_Mdcr_Alowd_Amt|Avg_Mdcr_Pymt_Amt|Avg_Mdcr_Stdzd_Amt|
+----------+-------------------+-------------------------+-----------------+-----------------+----------------------+-----------------------------+-----------------+------------------+--------+----------+--------------+-------------+---------+---------+------------------+--------------+------------------+------

                                                                                

In [32]:
# comoputing normalized names and checking for collisions
# checking for collisions as I don't want to accidently end up with two originals mapping to the same cleaned name

# applying the normalized names to every column I plan to keep
normalized_names = []
for original_name in keep_cols_physician_practitioners:
    new_name = normalize_col(original_name)
    normalized_names.append(new_name)

# counting how many times each cleaned name shows up
name_counts = Counter(normalized_names)

# pulling out any names that show up more than once
duplicates = [name for name, count in name_counts.items() if count > 1]

# bailing out early, if a coliision is detected
if duplicates:
    raise RuntimeError(f'Column name conflict after normalization: {duplicates}')

In [33]:
# building a new DataFrame with normalized column names

# building a list of expressions to rename the columns in the DataFrame
renamed_columns_phys_pract = [
    col(old).alias(new)
    for old, new in zip(keep_cols_physician_practitioners, normalized_names)
]

# applying 
normalized_phys_pract_df = phys_pract_df.select(*renamed_columns_phys_pract)

In [34]:
# doing a quick sanity check to ensure the new DataFrame has the correct columns
normalized_phys_pract_df.printSchema()
normalized_phys_pract_df.show(5, truncate=False)

root
 |-- rndrng_npi: string (nullable = true)
 |-- rndrng_prvdr_ent_cd: string (nullable = true)
 |-- rndrng_prvdr_state_abrvtn: string (nullable = true)
 |-- rndrng_prvdr_zip5: string (nullable = true)
 |-- rndrng_prvdr_ruca: string (nullable = true)
 |-- rndrng_prvdr_ruca_desc: string (nullable = true)
 |-- rndrng_prvdr_mdcr_prtcptg_ind: string (nullable = true)
 |-- rndrng_prvdr_type: string (nullable = true)
 |-- rndrng_prvdr_cntry: string (nullable = true)
 |-- hcpcs_cd: string (nullable = true)
 |-- hcpcs_desc: string (nullable = true)
 |-- hcpcs_drug_ind: string (nullable = true)
 |-- place_of_srvc: string (nullable = true)
 |-- tot_benes: string (nullable = true)
 |-- tot_srvcs: string (nullable = true)
 |-- tot_bene_day_srvcs: string (nullable = true)
 |-- avg_sbmtd_chrg: string (nullable = true)
 |-- avg_mdcr_alowd_amt: string (nullable = true)
 |-- avg_mdcr_pymt_amt: string (nullable = true)
 |-- avg_mdcr_stdzd_amt: string (nullable = true)

+----------+-------------------+

In [35]:
# cleaning the values of the physician practitioner DataFrame

clean_phys_pract = (
    normalized_phys_pract_df

    # trim stray spaces and making sure the NPIs are exactly 10 digits
    .withColumn("rndrng_npi", trim(col("rndrng_npi")))
    .withColumn("npi_valid",
        when(col("rndrng_npi").rlike(r"^\d{10}$"), lit(1))
        .otherwise(lit(0))
    )

    # mapping the provider type: 
        # “I” → individual (1), everything else → org (0)
    .withColumn("is_individual",
        when(col("rndrng_prvdr_ent_cd") == "I", lit(1))
        .otherwise(lit(0))
    )

    # tidying up geographic strings
    .withColumn("rndrng_prvdr_state_abrvtn", trim(col("rndrng_prvdr_state_abrvtn")))
    .withColumn("rndrng_prvdr_zip5",          trim(col("rndrng_prvdr_zip5")))

    # flagging if the state is missing
    .withColumn("missing_state",
        when(col("rndrng_prvdr_state_abrvtn").isNull(), lit(1)).otherwise(lit(0))
    )

    # flagging if the zip code is missing
    .withColumn("missing_zip",
        when(col("rndrng_prvdr_zip5").isNull(), lit(1)).otherwise(lit(0))
    )

    # casting RUCA (rural‐urban code) to double type
    .withColumn("rndrng_prvdr_ruca", col("rndrng_prvdr_ruca").cast(DoubleType())
    )

    # medicare‐participation Y/N → 1/0
    .withColumn("medicare_participation",
        when(col("rndrng_prvdr_mdcr_prtcptg_ind") == "Y", lit(1))
        .otherwise(lit(0))
    )

    # clean HCPCS and place‐of‐service keys
    .withColumn("hcpcs_cd",     trim(col("hcpcs_cd")))
    .withColumn("place_of_srvc", trim(col("place_of_srvc")))


    # cast dollar amounts so ratios and averages behave
    .withColumn("avg_sbmtd_chrg",    col("avg_sbmtd_chrg").cast(DoubleType()))
    .withColumn("avg_mdcr_alowd_amt",col("avg_mdcr_alowd_amt").cast(DoubleType()))
    .withColumn("avg_mdcr_pymt_amt", col("avg_mdcr_pymt_amt").cast(DoubleType()))
    .withColumn("avg_mdcr_stdzd_amt",col("avg_mdcr_stdzd_amt").cast(DoubleType()))

    # drug flag: “Y” → 1, else → 0
    .withColumn("is_drug",
        when(col("hcpcs_drug_ind") == "Y", lit(1)).otherwise(lit(0))
    )
)

In [None]:
# keeping only United States rows
clean_phys_pract = (clean_phys_pract
    .withColumn("rndrng_prvdr_cntry", upper(trim(col("rndrng_prvdr_cntry"))))
    .filter(col("rndrng_prvdr_cntry").isNull() | col("rndrng_prvdr_cntry").isin("US","USA"))
)

In [36]:
# checking to make sure my cleaning worked
clean_phys_pract.select("hcpcs_drug_ind", "is_drug").distinct().show()



+--------------+-------+
|hcpcs_drug_ind|is_drug|
+--------------+-------+
|             Y|      1|
|             N|      0|
+--------------+-------+



                                                                                

In [37]:
# few other santity checks

# no “bad” NPIs should pass
clean_phys_pract.filter(col("npi_valid") == 0).count()

# ensuring dollar‐amount columns have no unexpected negatives or huge outliers
clean_phys_pract.select(
  min("avg_mdcr_alowd_amt"),
  max("avg_mdcr_alowd_amt"),
  mean("avg_mdcr_alowd_amt")
).show()

# checking that my “missing” flags line up with nulls
clean_phys_pract.filter(col("missing_zip") == 1).select("rndrng_prvdr_zip5").show()


                                                                                

+-----------------------+-----------------------+-----------------------+
|min(avg_mdcr_alowd_amt)|max(avg_mdcr_alowd_amt)|avg(avg_mdcr_alowd_amt)|
+-----------------------+-----------------------+-----------------------+
|                    0.0|           52794.040667|     106.10699656654631|
+-----------------------+-----------------------+-----------------------+





+-----------------+
|rndrng_prvdr_zip5|
+-----------------+
+-----------------+



                                                                                

##### Because of the nature of this dataset, I will not be doing a traditional deduping process
##### This dataset is SUPPOSED to have multiple lines per provider
##### I will be checking for true duplicates

In [38]:
# safe-casting to make sure my volume columns are really integers
safe_phys = (
    clean_phys_pract
    .withColumn("tot_benes", expr("try_cast(tot_benes AS INT)")) 
    .withColumn("tot_srvcs", expr("try_cast(tot_srvcs AS INT)")) 
    .withColumn("tot_bene_day_srvcs", expr("try_cast(tot_bene_day_srvcs AS INT)"))
)

# checking for exact duplicates across the key combinations
    # (provider + procedure + place + drug flag)
total_rows = safe_phys.count()

unique_rows = (
    safe_phys 
    .dropDuplicates([
        "rndrng_npi",
        "hcpcs_cd",
        "place_of_srvc",
        "is_drug"
    ]) 
    .count()
)

duplicate_rows = total_rows - unique_rows


print(f"Total rows:{total_rows:,}")
print(f"Unique prov×proc×place×drug:{unique_rows:,}")
print(f"Exact duplicate rows:{duplicate_rows:,}")


[Stage 56:>                                                         (0 + 8) / 9]

Total rows:9,660,647
Unique prov×proc×place×drug:9,660,647
Exact duplicate rows:0


                                                                                

In [None]:
# aggregating by provider for modeling purposes
# more reasoning -> the GNN model will be using the provider as a node, so I want to aggregate all the features by provider

# making sure my per-row volumes are truly DOUBLE
clean_for_agg = (
    clean_phys_pract
    .filter(col("npi_valid") == 1)   # <<< ADD THIS
    .withColumn("srvcs_d", col("tot_srvcs").cast("double"))
    .withColumn("benes_d", col("tot_benes").cast("double"))
    .withColumn("bene_days_d", col("tot_bene_day_srvcs").cast("double"))
)

clean_for_agg.count()

# grouping by NPI and rolling up all the per-procedure metrics into provider totals/averages

# aggregate weighted sums
agg_w = (
    clean_for_agg
    .groupBy("rndrng_npi")
    .agg(
        # volumes
        sum_("srvcs_d").alias("total_services"),
        sum_("benes_d").alias("total_beneficiaries"),
        sum_("bene_days_d").alias("total_bene_day_services"),

        # service-weighted sums for charges/allowed/payments
        sum_(col("avg_sbmtd_chrg") * col("srvcs_d")).alias("sum_submitted_w"),
        sum_(col("avg_mdcr_alowd_amt") * col("srvcs_d")).alias("sum_allowed_w"),
        sum_(col("avg_mdcr_pymt_amt") * col("srvcs_d")).alias("sum_payment_w"),

        # diversity & quality signals
        countDistinct("hcpcs_cd").alias("num_unique_procedures"),
        stddev(col("avg_sbmtd_chrg")).alias("stddev_submitted_charge"),
        avg(col("is_drug")).alias("frac_drug_services"),
        avg(col("missing_zip")).alias("frac_missing_zip")
    )
)

# deriving weighted averages & ratios of totals
provider_agg = (
    agg_w
    .withColumn("w_avg_submitted_charge",
                when(col("total_services") > 0, col("sum_submitted_w") / col("total_services")))
    .withColumn("w_avg_allowed",
                when(col("total_services") > 0, col("sum_allowed_w") / col("total_services")))
    .withColumn("w_avg_payment",
                when(col("total_services") > 0, col("sum_payment_w") / col("total_services")))
    .withColumn("charge_allowed_ratio",
                when(col("sum_allowed_w") > 0, col("sum_submitted_w") / col("sum_allowed_w")))
    .withColumn("payment_allowed_ratio",
                when(col("sum_allowed_w") > 0, col("sum_payment_w") / col("sum_allowed_w")))
    .drop("sum_submitted_w", "sum_allowed_w", "sum_payment_w")
    .withColumnRenamed("rndrng_npi", "npi")  # unify key name for joins later
)


                                                                                

In [None]:
# adding some derived features to provider_agg
provider_agg = (provider_agg
    .withColumn("services_per_bene",
        when(col("total_beneficiaries") > 0,
             col("total_services") / col("total_beneficiaries")))
    .withColumn("bene_days_per_bene",
        when(col("total_beneficiaries") > 0,
             col("total_bene_day_services") / col("total_beneficiaries")))
)

In [40]:
# doing a sanity check on the aggregation
provider_agg.printSchema()
provider_agg.show(5, truncate=False)

root
 |-- npi: string (nullable = true)
 |-- total_services: double (nullable = true)
 |-- total_beneficiaries: double (nullable = true)
 |-- total_bene_day_services: double (nullable = true)
 |-- num_unique_procedures: long (nullable = false)
 |-- stddev_submitted_charge: double (nullable = true)
 |-- frac_drug_services: double (nullable = true)
 |-- frac_missing_zip: double (nullable = true)
 |-- w_avg_submitted_charge: double (nullable = true)
 |-- w_avg_allowed: double (nullable = true)
 |-- w_avg_payment: double (nullable = true)
 |-- charge_allowed_ratio: double (nullable = true)
 |-- payment_allowed_ratio: double (nullable = true)





+----------+--------------+-------------------+-----------------------+---------------------+-----------------------+-------------------+----------------+----------------------+------------------+------------------+--------------------+---------------------+
|npi       |total_services|total_beneficiaries|total_bene_day_services|num_unique_procedures|stddev_submitted_charge|frac_drug_services |frac_missing_zip|w_avg_submitted_charge|w_avg_allowed     |w_avg_payment     |charge_allowed_ratio|payment_allowed_ratio|
+----------+--------------+-------------------+-----------------------+---------------------+-----------------------+-------------------+----------------+----------------------+------------------+------------------+--------------------+---------------------+
|1003846908|1832.0        |1445.0             |1809.0                 |16                   |474.3959654350813      |0.0                |0.0             |260.098799126607      |86.9379475978571  |63.06741812205404 |2.991775

                                                                                

In [41]:
# checking the total after the aggregation
total_providers = provider_agg.count()
print(f"Total providers after aggregation: {total_providers:,}")



Total providers after aggregation: 1,175,281


                                                                                

In [42]:
# doing some more sanity checks


# checking the ranges
print("=== Range summary ===")
provider_agg.select(
    min("total_services").alias("min_services"),
    max("total_services").alias("max_services"),
    min("total_beneficiaries").alias("min_beneficiaries"),
    max("total_beneficiaries").alias("max_beneficiaries"),
    min("total_bene_day_services").alias("min_bene_day_services"),
    max("total_bene_day_services").alias("max_bene_day_services"),
    min("charge_allowed_ratio").alias("min_charge_ratio"),
    max("charge_allowed_ratio").alias("max_charge_ratio"),
    min("payment_allowed_ratio").alias("min_payment_ratio"),
    max("payment_allowed_ratio").alias("max_payment_ratio")
).show(truncate=False)

# checking for any negatives or absurd outliers
provider_agg.filter(
    (col("total_services") < 0) |
    (col("total_beneficiaries") < 0) |
    (col("total_bene_day_services") < 0) |
    (col("charge_allowed_ratio") < 0) |
    (col("payment_allowed_ratio") < 0) |
    (col("charge_allowed_ratio") > 100) |
    (col("payment_allowed_ratio") > 100)
).select(
    "npi",
    "total_services",
    "total_beneficiaries",
    "total_bene_day_services",
    "charge_allowed_ratio",
    "payment_allowed_ratio"
).show(10, truncate=False)


# checking frac_drug_services
print("=== Distinct frac_drug_services values ===")
provider_agg.select("frac_drug_services").distinct().orderBy("frac_drug_services").show()

# checking how many providers have zero drug‐share
zero_drug_count = provider_agg.filter(col("frac_drug_services") == 0.0).count()
print(f"Providers with frac_drug_services = 0.0: {zero_drug_count:,}")


=== Range summary ===


                                                                                

+------------+------------+-----------------+-----------------+---------------------+---------------------+------------------+----------------+-----------------+-----------------+
|min_services|max_services|min_beneficiaries|max_beneficiaries|min_bene_day_services|max_bene_day_services|min_charge_ratio  |max_charge_ratio|min_payment_ratio|max_payment_ratio|
+------------+------------+-----------------+-----------------+---------------------+---------------------+------------------+----------------+-----------------+-----------------+
|11.0        |1.6019747E7 |11.0             |1.0403747E7      |11.0                 |1.5177802E7          |0.8987661160749534|10638.461538    |0.0              |1.0              |
+------------+------------+-----------------+-----------------+---------------------+---------------------+------------------+----------------+-----------------+-----------------+



                                                                                

+----------+--------------+-------------------+-----------------------+--------------------+---------------------+
|npi       |total_services|total_beneficiaries|total_bene_day_services|charge_allowed_ratio|payment_allowed_ratio|
+----------+--------------+-------------------+-----------------------+--------------------+---------------------+
|1063088466|72.0          |72.0               |72.0                   |116.22288990192418  |0.7967847769030011   |
|1083682405|14.0          |14.0               |14.0                   |101.85851588546967  |0.8004288306050098   |
|1154792729|109.0         |107.0              |109.0                  |110.06187916211955  |0.7920176326586722   |
|1215149265|1464.0        |1080.0             |1091.0                 |112.61769308360151  |0.796413036679571    |
|1295876969|92.0          |89.0               |92.0                   |118.12411101282882  |0.7967593545258765   |
|1366087421|86.0          |82.0               |86.0                   |125.15443

                                                                                

+--------------------+
|  frac_drug_services|
+--------------------+
|                 0.0|
|0.007246376811594203|
|0.007462686567164179|
|0.007692307692307693|
|0.007751937984496124|
|0.007874015748031496|
|0.007936507936507936|
|               0.008|
|0.008064516129032258|
| 0.00819672131147541|
|0.008264462809917356|
| 0.00847457627118644|
|0.008620689655172414|
|0.008695652173913044|
|0.008771929824561403|
|0.008849557522123894|
|0.008928571428571428|
|0.009009009009009009|
| 0.00909090909090909|
|0.009259259259259259|
+--------------------+
only showing top 20 rows




Providers with frac_drug_services = 0.0: 953,913


                                                                                

## 'List of Excluded Individuals and Entities' (LEIE) - OIG Dataset

In [43]:
# creating a csv path variable
leie_csv_path = "Office of Inspector General - Excluded Individuals and Entities/20250710 LEIE.csv"

In [44]:
# reading the CSV file into a DataFrame
full_leie_df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "false")
    .csv(leie_csv_path)
)

In [45]:
# inspecting the column names
print("Available columns:")
for c in full_leie_df.columns[:]:
    print(repr(c))

Available columns:
'LASTNAME'
'FIRSTNAME'
'MIDNAME'
'BUSNAME'
'GENERAL'
'SPECIALTY'
'UPIN'
'NPI'
'DOB'
'ADDRESS'
'CITY'
'STATE'
'ZIP'
'EXCLTYPE'
'EXCLDATE'
'REINDATE'
'WAIVERDATE'
'WVRSTATE'


In [46]:
# selecting the relevant columns I want to keep
keep_cols_leie = [
    "NPI",        
    "LASTNAME",   
    "FIRSTNAME",  
    "MIDNAME",
    "BUSNAME",    
    "ADDRESS",    
    "CITY",       
    "STATE",      
    "ZIP",        
    "SPECIALTY",  
    "EXCLTYPE",   
    "EXCLDATE",   
    "REINDATE",  
    "DOB",
    "WAIVERDATE",
    "WVRSTATE",
    "UPIN"
]

In [47]:
# creating a new DataFrame with only the selected columns
leie_df = full_leie_df.select(*keep_cols_leie)

In [48]:
# having a sanity check to ensure I loaded the right columns, the correct types, and the data looks okay
print("Schema:")
leie_df.printSchema()
print("Sample:")
leie_df.limit(5).show(truncate=False)

Schema:
root
 |-- NPI: string (nullable = true)
 |-- LASTNAME: string (nullable = true)
 |-- FIRSTNAME: string (nullable = true)
 |-- MIDNAME: string (nullable = true)
 |-- BUSNAME: string (nullable = true)
 |-- ADDRESS: string (nullable = true)
 |-- CITY: string (nullable = true)
 |-- STATE: string (nullable = true)
 |-- ZIP: string (nullable = true)
 |-- SPECIALTY: string (nullable = true)
 |-- EXCLTYPE: string (nullable = true)
 |-- EXCLDATE: string (nullable = true)
 |-- REINDATE: string (nullable = true)
 |-- DOB: string (nullable = true)
 |-- WAIVERDATE: string (nullable = true)
 |-- WVRSTATE: string (nullable = true)
 |-- UPIN: string (nullable = true)

Sample:
+----------+--------+---------+-------+---------------------------+-----------------------------+----------+-----+-----+------------------+--------+--------+--------+----+----------+--------+----+
|NPI       |LASTNAME|FIRSTNAME|MIDNAME|BUSNAME                    |ADDRESS                      |CITY      |STATE|ZIP  |SPECIA

In [49]:
# getting the total number of rows in the DataFrame
total_rows_leie = leie_df.count()
print(f'Total rows: {total_rows_leie}')

Total rows: 81774


In [50]:
# getting the null counts for each column
leie_df.select(
    *[sum_(col(c).isNull().cast('int')).alias(c) for c in leie_df.columns]
).show()

+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+
|NPI|LASTNAME|FIRSTNAME|MIDNAME|BUSNAME|ADDRESS|CITY|STATE|ZIP|SPECIALTY|EXCLTYPE|EXCLDATE|REINDATE| DOB|WAIVERDATE|WVRSTATE| UPIN|
+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+
|  0|    3358|     3358|  24009|  78416|      4|   0|    0|  0|     4088|       0|       0|       0|4222|         0|   81762|75791|
+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+



In [51]:
# comoputing normalized names and checking for collisions
# checking for collisions as I don't want to accidently end up with two originals mapping to the same cleaned name

# applying the normalized names to every column I plan to keep
normalized_names = []
for original_name in keep_cols_leie:
    new_name = normalize_col(original_name)
    normalized_names.append(new_name)

# counting how many times each cleaned name shows up
name_counts = Counter(normalized_names)

# pulling out any names that show up more than once
duplicates = [name for name, count in name_counts.items() if count > 1]

# bailing out early, if a coliision is detected
if duplicates:
    raise RuntimeError(f'Column name conflict after normalization: {duplicates}')

In [52]:
# building a new DataFrame with normalized column names

# building a list of expressions to rename the columns in the DataFrame
renamed_columns_leie = [
    col(old).alias(new)
    for old, new in zip(keep_cols_leie, normalized_names)
]

# applying 
normalized_leie = leie_df.select(*renamed_columns_leie)

In [53]:
# doing a quick sanity check to ensure the new DataFrame has the correct columns
normalized_leie.printSchema()
normalized_leie.show(5, truncate=False)

root
 |-- npi: string (nullable = true)
 |-- lastname: string (nullable = true)
 |-- firstname: string (nullable = true)
 |-- midname: string (nullable = true)
 |-- busname: string (nullable = true)
 |-- address: string (nullable = true)
 |-- city: string (nullable = true)
 |-- state: string (nullable = true)
 |-- zip: string (nullable = true)
 |-- specialty: string (nullable = true)
 |-- excltype: string (nullable = true)
 |-- excldate: string (nullable = true)
 |-- reindate: string (nullable = true)
 |-- dob: string (nullable = true)
 |-- waiverdate: string (nullable = true)
 |-- wvrstate: string (nullable = true)
 |-- upin: string (nullable = true)

+----------+--------+---------+-------+---------------------------+-----------------------------+----------+-----+-----+------------------+--------+--------+--------+----+----------+--------+----+
|npi       |lastname|firstname|midname|busname                    |address                      |city      |state|zip  |specialty         |exc

In [54]:
# moving onto the cleaning values step of the LEIE DataFrame

leie = normalized_leie


# trimming strings and converting empty strings to nulls
for c, t in leie.dtypes:
    if t == "string":
        leie = leie.withColumn(c, when(trim(col(c)) == "", None).otherwise(trim(col(c))))

In [55]:
# keeping NPI and ZIP as strings, but normalizing them
leie = (leie
    .withColumn("npi", regexp_replace(col("npi"), r"[^0-9]", ""))
    .withColumn("npi_valid", when(col("npi").rlike(r"^(?!0{10})\d{10}$"), 1).otherwise(0).cast(IntegerType()))
    .withColumn("zip5", substring(regexp_replace(col("zip"), r"[^0-9]", ""), 1, 5))
    .drop("zip")
)

In [56]:
# standardizing codes/geo (UPPER only where it matters)
leie = (leie
    .withColumn("state", when(col("state").rlike(r"^[A-Za-z]{2}$"), upper(col("state"))).otherwise(lit(None)))
    .withColumnRenamed("wvrstate", "waiverstate")
    .withColumn("waiverstate", when(col("waiverstate").rlike(r"^[A-Za-z]{2}$"), upper(col("waiverstate"))).otherwise(lit(None)))
    .withColumn("excltype", upper(col("excltype")))         # canonicalize code
    .withColumn("specialty", upper(col("specialty")))       # optional but helps grouping
)

In [57]:
# handling the dates
# treating 00000000 as null
def parse_date(cname):
    digits = regexp_replace(col(cname), r"[^0-9]", "")
    yyyymmdd = when(digits.rlike(r"^\d{8}$") & (digits != lit("00000000")), to_date(digits, "yyyyMMdd"))
    return coalesce(yyyymmdd, to_date(col(cname), "MM/dd/yyyy"),
                    to_date(col(cname), "M/d/yyyy"), to_date(col(cname), "yyyy-MM-dd"))

leie = (leie
    .withColumn("excldate_dt",   parse_date("excldate").cast(DateType()))
    .withColumn("reindate_dt",   parse_date("reindate").cast(DateType()))
    .withColumn("waiverdate_dt", parse_date("waiverdate").cast(DateType()))
    .withColumn("dob_dt",        parse_date("dob").cast(DateType()))
    .drop("excldate","reindate","waiverdate","dob")
)

In [58]:
# creating a flag for whether the record is an organization or not
leie = leie.withColumn("is_org", when(col("busname").isNotNull() & (col("busname") != ""), 1).otherwise(0).cast(IntegerType()))
as_of = to_date(lit(None))  # if you want "today", use current_date()
as_of = lit(None)           # or set to a fixed as-of date for reproducibility

as_of = current_date()

leie = (leie
    .withColumn("has_excl", when(col("excldate_dt").isNotNull(), 1).otherwise(0).cast(IntegerType()))
    .withColumn("is_currently_excluded",
        when(
            (col("excldate_dt").isNotNull()) &
            (col("waiverdate_dt").isNull()) &
            (col("reindate_dt").isNull() | (col("reindate_dt") > as_of)),
            1
        ).otherwise(0).cast(IntegerType())
    )
    .withColumn("exclusion_start", col("excldate_dt").cast(DateType()))
    .withColumn("exclusion_end_effective", coalesce(col("reindate_dt"), as_of).cast(DateType()))
    .withColumn("exclusion_duration_days",
        when(col("excldate_dt").isNotNull(), datediff(col("exclusion_end_effective"), col("exclusion_start"))).cast(IntegerType())
    )
)

In [59]:
# aggregating and making one row per NPI
leie_with_npi = leie.filter(col("npi_valid") == 1)

leie_by_npi = (leie_with_npi
    .groupBy("npi")
    .agg(
        countDistinct("excldate_dt").cast(IntegerType()).alias("exclusion_count"),
        min("excldate_dt").cast(DateType()).alias("first_excldate"),
        max("excldate_dt").cast(DateType()).alias("most_recent_excldate"),
        max("is_currently_excluded").cast(IntegerType()).alias("is_currently_excluded"),
        max("waiverdate_dt").cast(DateType()).alias("last_waiverdate"),
        max("reindate_dt").cast(DateType()).alias("last_reindate")
    )
)

### Working on joining the three datasets

In [None]:
wrangled_npi_df.write.mode("overwrite").parquet("curated/npi_nodes.parquet")
provider_agg.write.mode("overwrite").parquet("curated/puf_provider_agg.parquet")
leie_by_npi.write.mode("overwrite").parquet("curated/leie_by_npi.parquet")
