# Provider Fraud Detection Model - Data Wrangling
### In this notebook we will focus on collecting, organizing, defining, and cleaning the relevant datasets for the Provider Fraud Detection model
#### 1.) Ingesting & Inspecting Files
#### 2.) Standardizing Column Names & Types
#### 3.) Cleaning Values
#### 4.) Deduping (Removing Duplicates)
#### 5.) Merging/Joining Tables

In [3]:
# core imports
import re
import builtins as py  # for py.max / py.min in sampling

# essential spark imports
from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark import StorageLevel

# type imports
from pyspark.sql.types import IntegerType, DoubleType, DateType

# column operations imports
from pyspark.sql.functions import (
    col, lit, trim, upper, substring,
    to_date, datediff, coalesce, regexp_replace, length,
    when, row_number, min, max, countDistinct, avg,
    sum as sum_, pow, sqrt, greatest,
    broadcast, desc_nulls_last
)


In [4]:
# this function will help me normalize my column names to ensure a smooth workflow
# this function will...
    # lowercase, trim whitespace
    # drop parenthetical notes
    # replace non-alphanumeric chars with underscores
    # collapse runs of underscores
    # strip leading/trailing underscores

def normalize_col(column_name: str) -> str:
    text = column_name.strip().lower()
    text = re.sub(r'\(.*?\)', '', text)
    text = re.sub(r'[^0-9a-z]+', '_', text)
    text = re.sub(r'_+', '_', text)
    return text.strip('_')

In [5]:
# starting Spark session
spark = (
    SparkSession.builder
    .appName("ProviderFraudDetection-NN")
    .master("local[*]")
    .config("spark.driver.memory", "12g")
    .config("spark.sql.shuffle.partitions", "16")
    .config("spark.sql.adaptive.enabled", "true")
    .config("spark.memory.fraction", "0.6")
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/20 18:54:13 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [6]:
# time freeze to prevent leakage (aligns with my Medicare Physician and Practitioners dataset)
AS_OF_STR = "2023-12-31"
print("Training as-of date:", AS_OF_STR)

Training as-of date: 2023-12-31


## NPPES NPI Registry Dataset

In [7]:
# creating a csv path variable
npi_csv_path = "NPPES_Data_Dissemination_July_2025_V2/npidata_pfile_20050523-20250713.csv"


In [8]:
# reading the CSV file into a DataFrame
full_npi_df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "false")
    .csv(npi_csv_path)
)

In [9]:
# inspecting the npi dataset column names
print("Available columns (first 50 shown):")
for c in full_npi_df.columns[:50]:
    print(repr(c))
    

Available columns (first 50 shown):
'NPI'
'Entity Type Code'
'Replacement NPI'
'Employer Identification Number (EIN)'
'Provider Organization Name (Legal Business Name)'
'Provider Last Name (Legal Name)'
'Provider First Name'
'Provider Middle Name'
'Provider Name Prefix Text'
'Provider Name Suffix Text'
'Provider Credential Text'
'Provider Other Organization Name'
'Provider Other Organization Name Type Code'
'Provider Other Last Name'
'Provider Other First Name'
'Provider Other Middle Name'
'Provider Other Name Prefix Text'
'Provider Other Name Suffix Text'
'Provider Other Credential Text'
'Provider Other Last Name Type Code'
'Provider First Line Business Mailing Address'
'Provider Second Line Business Mailing Address'
'Provider Business Mailing Address City Name'
'Provider Business Mailing Address State Name'
'Provider Business Mailing Address Postal Code'
'Provider Business Mailing Address Country Code (If outside U.S.)'
'Provider Business Mailing Address Telephone Number'
'Provider B

In [10]:
# selecting the relevant columns I want to keep
tax_cols = []

for i in range(1, 16):
    tax_cols += [f"Healthcare Provider Taxonomy Code_{i}",
                 f"Healthcare Provider Primary Taxonomy Switch_{i}"]

keep_cols_npi = [
    "NPI",
    "Entity Type Code",

    # person/org name columns 
    "Provider First Name",
    "Provider Last Name (Legal Name)",
    "Provider Middle Name",                    
    "Provider Organization Name (Legal Business Name)",

    "Provider Business Practice Location Address State Name",
    "Provider Business Practice Location Address Postal Code",
    "Is Organization Subpart",
    "Parent Organization TIN",
    "Parent Organization LBN",
    "Provider Enumeration Date",
    "Last Update Date",
    "NPI Deactivation Date",
    "NPI Reactivation Date",
    "Is Sole Proprietor",
] + tax_cols


In [11]:
# creating a new DataFrame with only the selected columns
npi_df = full_npi_df.select(*keep_cols_npi)


In [12]:
# having a sanity check to ensure I loaded the right columns, the correct types, and the data looks okay
print("Schema:")
npi_df.printSchema()
print("Sample:")
npi_df.limit(5).show(truncate=False)

Schema:
root
 |-- NPI: string (nullable = true)
 |-- Entity Type Code: string (nullable = true)
 |-- Provider First Name: string (nullable = true)
 |-- Provider Last Name (Legal Name): string (nullable = true)
 |-- Provider Middle Name: string (nullable = true)
 |-- Provider Organization Name (Legal Business Name): string (nullable = true)
 |-- Provider Business Practice Location Address State Name: string (nullable = true)
 |-- Provider Business Practice Location Address Postal Code: string (nullable = true)
 |-- Is Organization Subpart: string (nullable = true)
 |-- Parent Organization TIN: string (nullable = true)
 |-- Parent Organization LBN: string (nullable = true)
 |-- Provider Enumeration Date: string (nullable = true)
 |-- Last Update Date: string (nullable = true)
 |-- NPI Deactivation Date: string (nullable = true)
 |-- NPI Reactivation Date: string (nullable = true)
 |-- Is Sole Proprietor: string (nullable = true)
 |-- Healthcare Provider Taxonomy Code_1: string (nullable 

25/08/20 18:56:20 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


+----------+----------------+-------------------+-------------------------------+--------------------+------------------------------------------------+------------------------------------------------------+-------------------------------------------------------+-----------------------+-----------------------+-----------------------+-------------------------+----------------+---------------------+---------------------+------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+--------------------------------

In [13]:
# getting the total number of rows in the DataFrame
total_rows = npi_df.count()
print(f'Total rows: {total_rows}')



Total rows: 9026996


                                                                                

In [14]:
# getting the null counts for each column
npi_df.select(
    *[
        sum_(col(c).isNull().cast('int')).alias(c) for c in npi_df.columns
    ]
).show()



+---+----------------+-------------------+-------------------------------+--------------------+------------------------------------------------+------------------------------------------------------+-------------------------------------------------------+-----------------------+-----------------------+-----------------------+-------------------------+----------------+---------------------+---------------------+------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------

                                                                                

In [15]:
# comoputing normalized names and checking for collisions
# checking for collisions as I don't want to accidently end up with two originals mapping to the same cleaned name

new_names = [normalize_col(c) for c in keep_cols_npi]
if len(new_names) != len(set(new_names)):
    raise RuntimeError("Column name conflict after normalization in NPPES.")

normalized_npi_df = npi_df.select(*[col(o).alias(n) for o,n in zip(keep_cols_npi, new_names)])

In [16]:
# doing a quick sanity check to ensure the new DataFrame has the correct columns
normalized_npi_df.printSchema()
normalized_npi_df.show(5, truncate=False)

root
 |-- npi: string (nullable = true)
 |-- entity_type_code: string (nullable = true)
 |-- provider_first_name: string (nullable = true)
 |-- provider_last_name: string (nullable = true)
 |-- provider_middle_name: string (nullable = true)
 |-- provider_organization_name: string (nullable = true)
 |-- provider_business_practice_location_address_state_name: string (nullable = true)
 |-- provider_business_practice_location_address_postal_code: string (nullable = true)
 |-- is_organization_subpart: string (nullable = true)
 |-- parent_organization_tin: string (nullable = true)
 |-- parent_organization_lbn: string (nullable = true)
 |-- provider_enumeration_date: string (nullable = true)
 |-- last_update_date: string (nullable = true)
 |-- npi_deactivation_date: string (nullable = true)
 |-- npi_reactivation_date: string (nullable = true)
 |-- is_sole_proprietor: string (nullable = true)
 |-- healthcare_provider_taxonomy_code_1: string (nullable = true)
 |-- healthcare_provider_primary_ta

In [17]:
# casting and standardizing the data types
clean_npi = (
    normalized_npi_df
    .withColumn("npi", regexp_replace(trim(col("npi")), r"\D", ""))
    .withColumn("entity_type_code", col("entity_type_code").cast("int"))
    .withColumn(
        "is_organization_subpart",
        when(upper(trim(col("is_organization_subpart"))) == "Y", lit(1))
        .when(upper(trim(col("is_organization_subpart"))) == "N", lit(0))
        .otherwise(lit(None).cast("int"))
    )
    .withColumn("is_sole_proprietor_raw", upper(trim(col("is_sole_proprietor"))))
    .withColumn(
        "is_sole_proprietor",
        when(col("entity_type_code") == 1,
             when(col("is_sole_proprietor_raw") == "Y", 1)
             .when(col("is_sole_proprietor_raw") == "N", 0)
             .otherwise(None)
        ).otherwise(None)
    ).drop("is_sole_proprietor_raw")

    # parsing dates 
    .withColumn("provider_enumeration_date",
        when(col("provider_enumeration_date").rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"),
             to_date(col("provider_enumeration_date"), "M/d/yyyy"))
    )
    .withColumn("last_update_date",
        when(col("last_update_date").rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"),
             to_date(col("last_update_date"), "M/d/yyyy"))
    )
    .withColumn("npi_deactivation_date",
        when(col("npi_deactivation_date").rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"),
             to_date(col("npi_deactivation_date"), "M/d/yyyy"))
    )
    .withColumn("npi_reactivation_date",
        when(col("npi_reactivation_date").rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"),
             to_date(col("npi_reactivation_date"), "M/d/yyyy"))
    )

    # zip/state standardization
    .withColumn("zip5",
        substring(regexp_replace(col("provider_business_practice_location_address_postal_code"), r"[^0-9]", ""), 1, 5)
    )
    .withColumn("state_abbr",
        when(trim(col("provider_business_practice_location_address_state_name")).rlike(r"^[A-Za-z]{2}$"),
             upper(trim(col("provider_business_practice_location_address_state_name")))
        )
    )
    .withColumn("parent_org_tin_norm", regexp_replace(col("parent_organization_tin"), r"\D", ""))
)


In [18]:
# picking the "best" taxonomy for each NPI
# or if none is marked as primary, the first non-null one
cand_exprs = []

any_x = None
for i in range(1, 16):
    sw_u = upper(col(f"healthcare_provider_primary_taxonomy_switch_{i}"))
    cd   = col(f"healthcare_provider_taxonomy_code_{i}")
    cand_exprs.append(when((sw_u == "Y") & cd.isNotNull(), cd))
    any_x = (sw_u == "X") if any_x is None else (any_x | (sw_u == "X"))

explicit_expr = coalesce(*cand_exprs)  # first slot marked Y
first_code    = coalesce(*[col(f"healthcare_provider_taxonomy_code_{i}") for i in range(1, 16)])


clean_npi = (
    clean_npi
    .withColumn(
        "primary_taxonomy",
        when(explicit_expr.isNotNull(), explicit_expr)     # explicit Y
        .when(first_code.isNotNull(), first_code)          # fallback regardless of X
        .otherwise(lit(None))
    )
    
    # adding diagnostic flags for QA or feature slicing
    .withColumn("primary_taxonomy_explicit", when(explicit_expr.isNotNull(), lit(1)).otherwise(lit(0)))
    .withColumn("primary_taxonomy_inferred", when(explicit_expr.isNull() & first_code.isNotNull(), lit(1)).otherwise(lit(0)))
    .withColumn("primary_taxonomy_from_x",   when(explicit_expr.isNull() & any_x & first_code.isNotNull(), lit(1)).otherwise(lit(0)))
)

In [19]:
# add two provider lifecycle features (referenced later in npi_feat)
# npi_age_days: days since enumeration
# is_active: 1 if active, 0 if deactivated or reactivated after deactivation
clean_npi = (
    clean_npi
    .withColumn("npi_age_days", datediff(to_date(lit(AS_OF_STR)), col("provider_enumeration_date")))
    .withColumn(
        "is_active",
        when(
            (col("npi_deactivation_date").isNull()) |
            ((col("npi_reactivation_date").isNotNull()) & (col("npi_reactivation_date") >= col("npi_deactivation_date"))),
            1
        ).otherwise(0)
    )
)


In [20]:
# filtering out invalid NPIs (not 10 digits)
clean_npi = clean_npi.filter(col("npi").rlike(r"^\d{10}$"))

In [21]:
# deduplicate to one row per NPI (latest update, then most complete)

# scoring how complete each row is
comp_score = (
    when(coalesce(col("primary_taxonomy"), first_code).isNotNull(), 1).otherwise(0)
    + when(col("state_abbr").isNotNull(), 1).otherwise(0)
    + when(col("provider_enumeration_date").isNotNull(), 1).otherwise(0)
    + when(col("entity_type_code").isNotNull(), 1).otherwise(0)
).alias("completeness_score")

npi_scored = clean_npi.withColumn("completeness_score", comp_score)

from pyspark.sql.functions import desc_nulls_last

w = Window.partitionBy("npi").orderBy(
    desc_nulls_last("last_update_date"),
    col("completeness_score").desc()
)

# collapsing multiple NPPES rows per NPI down to a single row
wrangled_npi_df = (
    npi_scored
    .withColumn("rn", row_number().over(w))
    .filter(col("rn") == 1)
    .drop("rn")
)

print("NPPES wrangled rows:", wrangled_npi_df.count())



NPPES wrangled rows: 9026996


                                                                                

## Medicare Physician & Other Practitioners - by Provider and Service Dataset

In [22]:
# creating a csv path variable
phys_pract_csv_path = "Medicare Physician & Other Practitioners - by Provider and Service/MUP_PHY_R25_P05_V20_D23_Prov_Svc.csv"

In [23]:
# reading the CSV file into a DataFrame
full_physician_practitioner_df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "false")
    .csv(phys_pract_csv_path)
)

In [24]:
# inspecting the column names
print("Available columns (first 50 shown):")
for c in full_physician_practitioner_df.columns[:50]:
    print(repr(c))

Available columns (first 50 shown):
'Rndrng_NPI'
'Rndrng_Prvdr_Last_Org_Name'
'Rndrng_Prvdr_First_Name'
'Rndrng_Prvdr_MI'
'Rndrng_Prvdr_Crdntls'
'Rndrng_Prvdr_Ent_Cd'
'Rndrng_Prvdr_St1'
'Rndrng_Prvdr_St2'
'Rndrng_Prvdr_City'
'Rndrng_Prvdr_State_Abrvtn'
'Rndrng_Prvdr_State_FIPS'
'Rndrng_Prvdr_Zip5'
'Rndrng_Prvdr_RUCA'
'Rndrng_Prvdr_RUCA_Desc'
'Rndrng_Prvdr_Cntry'
'Rndrng_Prvdr_Type'
'Rndrng_Prvdr_Mdcr_Prtcptg_Ind'
'HCPCS_Cd'
'HCPCS_Desc'
'HCPCS_Drug_Ind'
'Place_Of_Srvc'
'Tot_Benes'
'Tot_Srvcs'
'Tot_Bene_Day_Srvcs'
'Avg_Sbmtd_Chrg'
'Avg_Mdcr_Alowd_Amt'
'Avg_Mdcr_Pymt_Amt'
'Avg_Mdcr_Stdzd_Amt'


In [25]:
# selecting the relevant columns I want to keep
keep_cols_puf = [
    "Rndrng_NPI","Rndrng_Prvdr_Ent_Cd","Rndrng_Prvdr_State_Abrvtn","Rndrng_Prvdr_Zip5",
    "Rndrng_Prvdr_RUCA","Rndrng_Prvdr_RUCA_Desc","Rndrng_Prvdr_Mdcr_Prtcptg_Ind",
    "Rndrng_Prvdr_Type","Rndrng_Prvdr_Cntry",
    "HCPCS_Cd","HCPCS_Desc","HCPCS_Drug_Ind","Place_Of_Srvc",
    "Tot_Benes","Tot_Srvcs","Tot_Bene_Day_Srvcs",
    "Avg_Sbmtd_Chrg","Avg_Mdcr_Alowd_Amt","Avg_Mdcr_Pymt_Amt","Avg_Mdcr_Stdzd_Amt"
]

In [26]:
# creating a new DataFrame with only the selected columns
phys = full_physician_practitioner_df.select(*keep_cols_puf)

In [27]:
# getting the total number of rows in the DataFrame
total_rows_phys_pract = phys.count()
print(f'Total rows: {total_rows_phys_pract}')



Total rows: 9660647


                                                                                

In [28]:
# getting the null counts for each column
phys.select(
    *[
        sum_(col(c).isNull().cast('int')).alias(c) for c in phys.columns
    ]
).show()



+----------+-------------------+-------------------------+-----------------+-----------------+----------------------+-----------------------------+-----------------+------------------+--------+----------+--------------+-------------+---------+---------+------------------+--------------+------------------+-----------------+------------------+
|Rndrng_NPI|Rndrng_Prvdr_Ent_Cd|Rndrng_Prvdr_State_Abrvtn|Rndrng_Prvdr_Zip5|Rndrng_Prvdr_RUCA|Rndrng_Prvdr_RUCA_Desc|Rndrng_Prvdr_Mdcr_Prtcptg_Ind|Rndrng_Prvdr_Type|Rndrng_Prvdr_Cntry|HCPCS_Cd|HCPCS_Desc|HCPCS_Drug_Ind|Place_Of_Srvc|Tot_Benes|Tot_Srvcs|Tot_Bene_Day_Srvcs|Avg_Sbmtd_Chrg|Avg_Mdcr_Alowd_Amt|Avg_Mdcr_Pymt_Amt|Avg_Mdcr_Stdzd_Amt|
+----------+-------------------+-------------------------+-----------------+-----------------+----------------------+-----------------------------+-----------------+------------------+--------+----------+--------------+-------------+---------+---------+------------------+--------------+------------------+------

                                                                                

In [29]:
# comoputing normalized names and checking for collisions
# checking for collisions as I don't want to accidently end up with two originals mapping to the same cleaned name

# applying the normalized names to every column I plan to keep
puf_names = [normalize_col(c) for c in keep_cols_puf]
if len(puf_names) != len(set(puf_names)):
    raise RuntimeError("Column name conflict after normalization in PUF.")

In [30]:
# doing a quick sanity check to ensure the new DataFrame has the correct columns
p = phys.select(*[col(o).alias(n) for o,n in zip(keep_cols_puf, puf_names)])
p.printSchema()
p.show(5, truncate=False)

root
 |-- rndrng_npi: string (nullable = true)
 |-- rndrng_prvdr_ent_cd: string (nullable = true)
 |-- rndrng_prvdr_state_abrvtn: string (nullable = true)
 |-- rndrng_prvdr_zip5: string (nullable = true)
 |-- rndrng_prvdr_ruca: string (nullable = true)
 |-- rndrng_prvdr_ruca_desc: string (nullable = true)
 |-- rndrng_prvdr_mdcr_prtcptg_ind: string (nullable = true)
 |-- rndrng_prvdr_type: string (nullable = true)
 |-- rndrng_prvdr_cntry: string (nullable = true)
 |-- hcpcs_cd: string (nullable = true)
 |-- hcpcs_desc: string (nullable = true)
 |-- hcpcs_drug_ind: string (nullable = true)
 |-- place_of_srvc: string (nullable = true)
 |-- tot_benes: string (nullable = true)
 |-- tot_srvcs: string (nullable = true)
 |-- tot_bene_day_srvcs: string (nullable = true)
 |-- avg_sbmtd_chrg: string (nullable = true)
 |-- avg_mdcr_alowd_amt: string (nullable = true)
 |-- avg_mdcr_pymt_amt: string (nullable = true)
 |-- avg_mdcr_stdzd_amt: string (nullable = true)

+----------+-------------------+

In [31]:
# cleaning and casting the data types

p = (
    p
    .withColumn("rndrng_npi", trim(col("rndrng_npi")))
    .withColumn("npi_valid", when(col("rndrng_npi").rlike(r"^\d{10}$"), 1).otherwise(0))
    .withColumn("is_individual", when(col("rndrng_prvdr_ent_cd") == "I", 1).otherwise(0))
    .withColumn("rndrng_prvdr_state_abrvtn", trim(col("rndrng_prvdr_state_abrvtn")))
    .withColumn("rndrng_prvdr_zip5", trim(col("rndrng_prvdr_zip5")))
    .withColumn("missing_state", when(col("rndrng_prvdr_state_abrvtn").isNull(), 1).otherwise(0))
    .withColumn("missing_zip", when(col("rndrng_prvdr_zip5").isNull(), 1).otherwise(0))
    .withColumn("rndrng_prvdr_ruca", col("rndrng_prvdr_ruca").cast(DoubleType()))
    .withColumn("medicare_participation", when(col("rndrng_prvdr_mdcr_prtcptg_ind") == "Y", 1).otherwise(0))
    .withColumn("hcpcs_cd", trim(col("hcpcs_cd")))
    .withColumn("place_of_srvc", trim(col("place_of_srvc")))
    .withColumn("avg_sbmtd_chrg", col("avg_sbmtd_chrg").cast(DoubleType()))
    .withColumn("avg_mdcr_alowd_amt", col("avg_mdcr_alowd_amt").cast(DoubleType()))
    .withColumn("avg_mdcr_pymt_amt", col("avg_mdcr_pymt_amt").cast(DoubleType()))
    .withColumn("avg_mdcr_stdzd_amt", col("avg_mdcr_stdzd_amt").cast(DoubleType()))
    .withColumn("is_drug", when(col("hcpcs_drug_ind") == "Y", 1).otherwise(0))
    .withColumn("rndrng_prvdr_cntry", upper(trim(col("rndrng_prvdr_cntry"))))
    .filter(col("rndrng_prvdr_cntry").isNull() | col("rndrng_prvdr_cntry").isin("US","USA"))
)


In [32]:
# aggregating to provider level (weighted avgs + diversity)

# keeping clean rows and casting weights
# only aggregating valid NPIs
p_agg_in = (
    p.filter(col("npi_valid") == 1)
     .withColumn("srvcs_d", col("tot_srvcs").cast("double"))
     .withColumn("benes_d", col("tot_benes").cast("double"))
     .withColumn("bene_days_d", col("tot_bene_day_srvcs").cast("double"))
)

# grouping to provider and building helper sums
agg = (
    p_agg_in.groupBy("rndrng_npi").agg(
        sum_("srvcs_d").alias("total_services"),
        sum_("benes_d").alias("total_beneficiaries"),
        sum_("bene_days_d").alias("total_bene_day_services"),

        # weighted sums for means
        sum_(col("avg_sbmtd_chrg") * col("srvcs_d")).alias("sum_submitted_w"),
        sum_(col("avg_mdcr_alowd_amt") * col("srvcs_d")).alias("sum_allowed_w"),
        sum_(col("avg_mdcr_pymt_amt") * col("srvcs_d")).alias("sum_payment_w"),

        # weighted second moment for variance
        sum_(pow(col("avg_sbmtd_chrg"), 2) * col("srvcs_d")).alias("sum_submitted_sq_w"),

        # service mix/ data quality signals
        countDistinct("hcpcs_cd").alias("num_unique_procedures"),
        avg(col("is_drug")).alias("frac_drug_services"),
        avg(col("missing_zip")).alias("frac_missing_zip")
    )
)


provider_agg = (
    agg
    # weighted means (per service)
    .withColumn("w_avg_submitted_charge",
        when(col("total_services") > 0, col("sum_submitted_w") / col("total_services")))
    .withColumn("w_avg_allowed",
        when(col("total_services") > 0, col("sum_allowed_w") / col("total_services")))
    .withColumn("w_avg_payment",
        when(col("total_services") > 0, col("sum_payment_w") / col("total_services")))

    # weighted variance/stddev of submitted charge
    .withColumn("w_var_submitted_charge",
        when(col("total_services") > 0,
             (col("sum_submitted_sq_w") / col("total_services")) -
             pow(col("sum_submitted_w") / col("total_services"), 2)))
    .withColumn("w_stddev_submitted_charge",
        when(col("w_var_submitted_charge").isNotNull(),
             sqrt(greatest(lit(0.0), col("w_var_submitted_charge")))))

    # pricing intensity vs. medicare benchmarks
    .withColumn("charge_allowed_ratio",
        when(col("sum_allowed_w") > 0, col("sum_submitted_w") / col("sum_allowed_w")))
    .withColumn("payment_allowed_ratio",
        when(col("sum_allowed_w") > 0, col("sum_payment_w") / col("sum_allowed_w")))

    # dropping helper sums now that we’ve derived features
    .drop("sum_submitted_sq_w", "sum_submitted_w", "sum_allowed_w", "sum_payment_w")

    # tidying up + existing derived rates
    .withColumnRenamed("rndrng_npi", "npi")
    .withColumn("services_per_bene",
        when(col("total_beneficiaries") > 0, col("total_services") / col("total_beneficiaries")))
    .withColumn("bene_days_per_bene",
        when(col("total_beneficiaries") > 0, col("total_bene_day_services") / col("total_beneficiaries")))
)




print("PUF aggregated providers:", provider_agg.count())



PUF aggregated providers: 1175213


                                                                                

## 'List of Excluded Individuals and Entities' (LEIE) - OIG Dataset

In [33]:
# creating a csv path variable
leie_csv_path = "Office of Inspector General - Excluded Individuals and Entities/20250710 LEIE.csv"

In [34]:
# reading the CSV file into a DataFrame
full_leie_df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "false")
    .csv(leie_csv_path)
)

In [35]:
# inspecting the column names
print("Available columns:")
for c in full_leie_df.columns[:]:
    print(repr(c))

Available columns:
'LASTNAME'
'FIRSTNAME'
'MIDNAME'
'BUSNAME'
'GENERAL'
'SPECIALTY'
'UPIN'
'NPI'
'DOB'
'ADDRESS'
'CITY'
'STATE'
'ZIP'
'EXCLTYPE'
'EXCLDATE'
'REINDATE'
'WAIVERDATE'
'WVRSTATE'


In [36]:
# selecting the relevant columns I want to keep
keep_cols_leie = ["NPI","LASTNAME","FIRSTNAME","MIDNAME","BUSNAME","ADDRESS","CITY","STATE",
                  "ZIP","SPECIALTY","EXCLTYPE","EXCLDATE","REINDATE","DOB","WAIVERDATE","WVRSTATE","UPIN"]

In [37]:
# creating a new DataFrame with only the selected columns
leie_df = full_leie_df.select(*keep_cols_leie)

In [38]:
# having a sanity check to ensure I loaded the right columns, the correct types, and the data looks okay
print("Schema:")
leie_df.printSchema()
print("Sample:")
leie_df.limit(5).show(truncate=False)

Schema:
root
 |-- NPI: string (nullable = true)
 |-- LASTNAME: string (nullable = true)
 |-- FIRSTNAME: string (nullable = true)
 |-- MIDNAME: string (nullable = true)
 |-- BUSNAME: string (nullable = true)
 |-- ADDRESS: string (nullable = true)
 |-- CITY: string (nullable = true)
 |-- STATE: string (nullable = true)
 |-- ZIP: string (nullable = true)
 |-- SPECIALTY: string (nullable = true)
 |-- EXCLTYPE: string (nullable = true)
 |-- EXCLDATE: string (nullable = true)
 |-- REINDATE: string (nullable = true)
 |-- DOB: string (nullable = true)
 |-- WAIVERDATE: string (nullable = true)
 |-- WVRSTATE: string (nullable = true)
 |-- UPIN: string (nullable = true)

Sample:
+----------+--------+---------+-------+---------------------------+-----------------------------+----------+-----+-----+------------------+--------+--------+--------+----+----------+--------+----+
|NPI       |LASTNAME|FIRSTNAME|MIDNAME|BUSNAME                    |ADDRESS                      |CITY      |STATE|ZIP  |SPECIA

In [39]:
# getting the total number of rows in the DataFrame
total_rows_leie = leie_df.count()
print(f'Total rows: {total_rows_leie}')

Total rows: 81774


In [40]:
# getting the null counts for each column
leie_df.select(
    *[sum_(col(c).isNull().cast('int')).alias(c) for c in leie_df.columns]
).show()

+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+
|NPI|LASTNAME|FIRSTNAME|MIDNAME|BUSNAME|ADDRESS|CITY|STATE|ZIP|SPECIALTY|EXCLTYPE|EXCLDATE|REINDATE| DOB|WAIVERDATE|WVRSTATE| UPIN|
+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+
|  0|    3358|     3358|  24009|  78416|      4|   0|    0|  0|     4088|       0|       0|       0|4222|         0|   81762|75791|
+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+



In [41]:
# comoputing normalized names and checking for collisions
# checking for collisions as I don't want to accidently end up with two originals mapping to the same cleaned name

leie_cols = [normalize_col(c) for c in keep_cols_leie]
if len(leie_cols) != len(set(leie_cols)):
    raise RuntimeError("Column name conflict after normalization in LEIE.")
leie = leie_df.select(*[col(o).alias(n) for o,n in zip(keep_cols_leie, leie_cols)])

In [42]:
# cleaning and casting the data types
# trimming, normalizing keys, parsing dates (guarded)

# stripping whitespace and turning empty strings into nulls
for c_, t_ in leie.dtypes:
    if t_ == "string":
        leie = leie.withColumn(c_, when(trim(col(c_)) == "", None).otherwise(trim(col(c_))))

# normalizing keys and categorical values
leie = (
    leie
    .withColumn("npi", regexp_replace(col("npi"), r"[^0-9]", ""))
    .withColumn("npi_valid", when(col("npi").rlike(r"^(?!0{10})\d{10}$"), 1).otherwise(0).cast(IntegerType()))
    .withColumn("zip5", substring(regexp_replace(col("zip"), r"[^0-9]", ""), 1, 5))
    .drop("zip")
    .withColumn("state", when(col("state").rlike(r"^[A-Za-z]{2}$"), upper(col("state"))))
    .withColumnRenamed("wvrstate", "waiverstate")
    .withColumn("waiverstate", when(col("waiverstate").rlike(r"^[A-Za-z]{2}$"), upper(col("waiverstate"))))
    .withColumn("excltype", upper(col("excltype")))
    .withColumn("specialty", upper(col("specialty")))
)


def zeroish(colname):
    digits = regexp_replace(col(colname), r"[^0-9]", "")
    return when((digits == "00000000") | (length(digits) == 0), lit(None)).otherwise(col(colname))


def parse_date_guarded(colname: str):
    raw = zeroish(colname)
    digits = regexp_replace(raw, r"[^0-9]", "")
    ymd8 = when(digits.rlike(r"^\d{8}$"), to_date(digits, "yyyyMMdd"))
    mdyyyy = when(raw.rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"), to_date(raw, "M/d/yyyy"))
    ymd_dash = when(raw.rlike(r"^\d{4}-\d{2}-\d{2}$"), to_date(raw, "yyyy-MM-dd"))
    return coalesce(ymd8, mdyyyy, ymd_dash)

# producing typed date columns and dropping the original string columns
leie = (
    leie
    .withColumn("excldate_dt",   parse_date_guarded("excldate"))
    .withColumn("reindate_dt",   parse_date_guarded("reindate"))
    .withColumn("waiverdate_dt", parse_date_guarded("waiverdate"))
    .withColumn("dob_dt",        parse_date_guarded("dob"))
    .drop("excldate","reindate","waiverdate","dob")
)

In [43]:
# putting a time gate on LEIE so my labels reflect only what was knowable on the training cut-off date
# this date was picked because my Medicare Physician and Practitioners dataset was frozen at this date

AS_OF = to_date(lit(AS_OF_STR))
leie = leie.filter(col("excldate_dt").isNull() | (col("excldate_dt") <= AS_OF))

In [44]:
# keeping only valid NPIs
leie_with_npi = leie.filter(col("npi_valid") == 1)


# collapsing LEIE rows to one row per NPI
leie_status = (
    leie_with_npi
    .groupBy("npi")
    .agg(
        max(when(col("excldate_dt") <= AS_OF, col("excldate_dt"))).alias("last_excl_pre"),
        max(when(col("reindate_dt") <= AS_OF, col("reindate_dt"))).alias("last_rein_pre"),
        sum_(when(col("excldate_dt") <= AS_OF, lit(1)).otherwise(lit(0))).cast("int").alias("excl_events_pre")
    )


    # buidling two labels
    .withColumn("ever_excluded_asof", when(col("excl_events_pre") > 0, lit(1)).otherwise(lit(0)))
    .withColumn(
        "excluded_asof",
        when(
            col("last_excl_pre").isNotNull() &
            (col("last_rein_pre").isNull() | (col("last_rein_pre") < col("last_excl_pre"))),
            lit(1)
        ).otherwise(lit(0))
    )

    # keeping only the fields I need for the final join
    .select("npi","ever_excluded_asof","excluded_asof")
)


print("LEIE status NPIs:", leie_status.count())


[Stage 37:>                                                         (0 + 4) / 4]

LEIE status NPIs: 6711


                                                                                

## Joining the datasets together for one consolidated dataset
## I will also be downsampling in the code below so that my 'Fraud' consists of around 2-5% of the total data

In [46]:
# joining the datasets and using NPPES as the base


# time-safe NPPES features (no leakage past AS_OF)
AS_OF = to_date(lit(AS_OF_STR))
npi_feat = (
    wrangled_npi_df
    .filter( (col("provider_enumeration_date").isNull()) | (col("provider_enumeration_date") <= AS_OF) )
    .select(
        "npi","state_abbr","zip5","entity_type_code","primary_taxonomy",
        "npi_age_days","is_active","is_organization_subpart","is_sole_proprietor"
    )
    .distinct()
)







# building base row per NPI (NPPES base, optional PUF features)
base = (
    npi_feat.alias("n")
    .join(provider_agg.alias("p"), on="npi", how="left")
    .withColumn("has_puf", when(col("total_services").isNotNull(), lit(1)).otherwise(lit(0)))
)






# joining LEIE status (as-of labels)
merged = (
    base.join(broadcast(leie_status.alias("l")), on="npi", how="left")
        .withColumn("is_fraud",         coalesce(col("ever_excluded_asof"), lit(0)).cast("int"))  # training target
        .withColumn("is_excluded_asof", coalesce(col("excluded_asof"),     lit(0)).cast("int"))  # for reporting/slices
)






# filling numeric PUF (medicare physicians and practitioners data) columns to avoid nulls in the NN
num_cols = [
    "total_services","total_beneficiaries","total_bene_day_services",
    "w_avg_submitted_charge","w_avg_allowed","w_avg_payment",
    "charge_allowed_ratio","payment_allowed_ratio",
    "num_unique_procedures","w_stddev_submitted_charge",
    "frac_drug_services","frac_missing_zip",
    "services_per_bene","bene_days_per_bene"
]
for c in [c for c in num_cols if c in merged.columns]:
    merged = merged.withColumn(c, coalesce(col(c), lit(0.0)))


print("Merged row sample (NPPES base):")
merged.select("npi","has_puf","is_fraud").limit(10).show(truncate=False)
print("Class counts BEFORE sampling:")
merged.groupBy("is_fraud").count().show()





# downsampling negatives to target ~3–4% positives
TARGET_POS_RATE = 0.035   
SEED = 42

pos_ct = merged.filter(col("is_fraud") == 1).count()
neg_ct = merged.filter(col("is_fraud") == 0).count()
print(f"Before sampling — positives: {pos_ct:,}  negatives: {neg_ct:,}")

if pos_ct == 0:
    # no positives — keeping a modest slice just to avoid huge data (or skip sampling entirely)
    final_df = merged.sample(False, 0.1, seed=SEED)
else:
    # solving for how many negatives to keep so: pos / (pos + neg_kept) = TARGET_POS_RATE
    # => neg_kept = pos * (1 - r) / r
    neg_keep = int(pos_ct * (1.0 - TARGET_POS_RATE) / TARGET_POS_RATE)
    # converting to a sampling fraction over available negatives
    neg_frac = py.min(1.0, neg_keep / float(py.max(neg_ct, 1)))
    # keeping ALL positives; sample negatives to the fraction
    final_df = merged.sampleBy("is_fraud", fractions={1: 1.0, 0: neg_frac}, seed=SEED)

print("Class counts AFTER sampling:")
final_df.groupBy("is_fraud").count().show()







# quick prevalence check
total_after = final_df.count()
pos_after = final_df.filter(col("is_fraud") == 1).count()
if total_after > 0:
    print(f"Post-sample prevalence ≈ {100.0 * pos_after/total_after:.2f}%  "
          f"({pos_after:,}/{total_after:,})")


Merged row sample (NPPES base):


                                                                                

+----------+-------+--------+
|npi       |has_puf|is_fraud|
+----------+-------+--------+
|1003000134|1      |0       |
|1003000183|0      |0       |
|1003000191|0      |0       |
|1003000233|0      |0       |
|1003000274|0      |0       |
|1003000308|0      |0       |
|1003000431|0      |0       |
|1003000548|0      |0       |
|1003000662|0      |0       |
|1003000670|0      |0       |
+----------+-------+--------+

Class counts BEFORE sampling:


                                                                                

+--------+-------+
|is_fraud|  count|
+--------+-------+
|       1|   6711|
|       0|8119224|
+--------+-------+



                                                                                

Before sampling — positives: 6,711  negatives: 8,119,224
Class counts AFTER sampling:


                                                                                

+--------+------+
|is_fraud| count|
+--------+------+
|       1|  6711|
|       0|185442|
+--------+------+





Post-sample prevalence ≈ 3.49%  (6,711/192,153)


                                                                                

In [47]:
# checking that 'has_puf' distribution isn't wildly different post-sample
merged.groupBy("has_puf").count().show()
final_df.groupBy("has_puf").count().show()

# checking prevalence by entity type didn't skew
final_df.groupBy("is_fraud", "entity_type_code").count().show()


                                                                                

+-------+-------+
|has_puf|  count|
+-------+-------+
|      1|1175210|
|      0|6950725|
+-------+-------+



                                                                                

+-------+------+
|has_puf| count|
+-------+------+
|      1| 26944|
|      0|165209|
+-------+------+





+--------+----------------+------+
|is_fraud|entity_type_code| count|
+--------+----------------+------+
|       0|               1|140027|
|       1|            NULL|   772|
|       1|               2|   443|
|       1|               1|  5496|
|       0|               2| 38359|
|       0|            NULL|  7056|
+--------+----------------+------+



                                                                                

In [48]:
# quick sanity check (no need to persist/repartition at this size)
final_df.groupBy("is_fraud").count().show()

# writing a small number of Parquet files (snappy is default-friendly)
outpath = f"curated/training/providers_nn_asof_{AS_OF_STR}.parquet"
(final_df.coalesce(4)              # 1–4 files is fine
         .write
         .mode("overwrite")
         .option("compression", "snappy")
         .parquet(outpath))

print("Saved training table to:", outpath)


                                                                                

+--------+------+
|is_fraud| count|
+--------+------+
|       1|  6711|
|       0|185442|
+--------+------+



[Stage 189:>                                                        (0 + 4) / 4]

Saved training table to: curated/training/providers_nn_asof_2023-12-31.parquet


                                                                                

In [49]:
# stopping the Spark session
spark.stop()
print("Spark session stopped.")

Spark session stopped.
