# Provider Fraud Detection Model - Data Wrangling
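### This notebook collects, organizes, defines, and cleans the source datasets for the Provider Fraud Detection model — the NPPES NPI Registry, the Medicare Physician & Other Practitioners (by Provider and Service) file, and the OIG List of Excluded Individuals and Entities (LEIE) — and joins them into a single provider-level table.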
#### 1.) Ingesting & Inspecting Files
#### 2.) Standardizing Column Names & Types
#### 3.) Cleaning Values
#### 4.) Deduping (Removing Duplicates)
#### 5.) Merging/Joining Tables

In [46]:
# loading needed modules
import builtins as py  # Python built-ins; min/max/pow are shadowed by the pyspark.sql.functions imports below
import re
from collections import Counter
from datetime import date
from functools import reduce

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    IntegerType,
    DoubleType,
    DateType
)

from pyspark.sql.functions import (
    col,
    expr,
    lit,
    trim,
    to_date,
    current_date,
    datediff,
    sum as sum_,
    avg,
    stddev,
    countDistinct,
    when,
    try_divide,
    row_number,
    min,
    max,
    mean,
    isnan,
    isnull,
    regexp_replace,
    substring,
    upper,
    coalesce,
    sha2,
    length,
    pow,
    sqrt,
    greatest,
    broadcast,
    desc_nulls_last,
)


In [2]:
# Column-name normalizer shared by every dataset below, so all frames are
# addressed with consistent snake_case names.
_PAREN_NOTE_RE = re.compile(r'\(.*?\)')
_NON_ALNUM_RUN_RE = re.compile(r'[^0-9a-z]+')


def normalize_col(column_name: str) -> str:
    """Convert a raw column header into a snake_case identifier.

    Steps: trim and lowercase, drop parenthetical notes such as "(Legal Name)",
    replace every run of non-alphanumeric characters with one underscore, then
    strip leading/trailing underscores.

    Example: "Provider Last Name (Legal Name)" -> "provider_last_name".
    """
    lowered = column_name.strip().lower()
    without_notes = _PAREN_NOTE_RE.sub('', lowered)
    # '_' is itself non-alphanumeric, so this single pass already collapses any
    # mix of spaces, punctuation and underscores into exactly one underscore.
    snake = _NON_ALNUM_RUN_RE.sub('_', without_notes)
    return snake.strip('_')

In [3]:
# starting (or reusing) a local Spark session
# NOTE(review): spark.driver.memory presumably only takes effect if no Spark JVM
# is already running in this kernel — restart the kernel after changing it.
SPARK_CONF = {
    "spark.driver.memory": "12g",
    "spark.sql.shuffle.partitions": "16",
    "spark.sql.adaptive.enabled": "true",
    "spark.memory.fraction": "0.6",
}

spark_builder = SparkSession.builder.appName("ProviderFraudDetection-NN").master("local[*]")
for conf_key, conf_value in SPARK_CONF.items():
    spark_builder = spark_builder.config(conf_key, conf_value)
spark = spark_builder.getOrCreate()

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/18 20:47:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [4]:
# time freeze to prevent leakage (aligns with my Medicare Physician and Practitioners dataset):
# every as-of filter and age feature below is evaluated relative to this date.
AS_OF_STR = "2023-12-31"

# Fail fast on a malformed date. Spark's to_date(lit(...)) does not raise on a bad
# string in non-ANSI mode (it yields null), which would quietly turn every later
# "<= AS_OF" comparison into null instead of erroring here.
date.fromisoformat(AS_OF_STR)

print("Training as-of date:", AS_OF_STR)

Training as-of date: 2023-12-31


## NPPES NPI Registry Dataset

In [5]:
# creating a csv path variable
# NPPES full dissemination file (July 2025 release, per the folder name). Relative
# path — resolved from the kernel's current working directory.
npi_csv_path = "NPPES_Data_Dissemination_July_2025_V2/npidata_pfile_20050523-20250713.csv"


In [6]:
# reading the raw NPPES CSV; schema inference is off, so every column arrives as
# a string and is cast explicitly during cleaning
full_npi_df = spark.read.csv(npi_csv_path, header=True, inferSchema=False)

In [7]:
# inspecting the npi dataset column names (repr exposes stray whitespace in headers)
print("Available columns (first 50 shown):")
print(*map(repr, full_npi_df.columns[:50]), sep="\n")
    

Available columns (first 50 shown):
'NPI'
'Entity Type Code'
'Replacement NPI'
'Employer Identification Number (EIN)'
'Provider Organization Name (Legal Business Name)'
'Provider Last Name (Legal Name)'
'Provider First Name'
'Provider Middle Name'
'Provider Name Prefix Text'
'Provider Name Suffix Text'
'Provider Credential Text'
'Provider Other Organization Name'
'Provider Other Organization Name Type Code'
'Provider Other Last Name'
'Provider Other First Name'
'Provider Other Middle Name'
'Provider Other Name Prefix Text'
'Provider Other Name Suffix Text'
'Provider Other Credential Text'
'Provider Other Last Name Type Code'
'Provider First Line Business Mailing Address'
'Provider Second Line Business Mailing Address'
'Provider Business Mailing Address City Name'
'Provider Business Mailing Address State Name'
'Provider Business Mailing Address Postal Code'
'Provider Business Mailing Address Country Code (If outside U.S.)'
'Provider Business Mailing Address Telephone Number'
'Provider B

In [8]:
# selecting the relevant NPPES columns to keep.
# NPPES repeats the taxonomy block 15 times; keep each slot's code and its
# primary-taxonomy switch, interleaved slot by slot.
tax_cols = [
    header
    for slot in range(1, 16)
    for header in (
        f"Healthcare Provider Taxonomy Code_{slot}",
        f"Healthcare Provider Primary Taxonomy Switch_{slot}",
    )
]

npi_base_cols = [
    "NPI",
    "Entity Type Code",
    # person/org name columns (kept so wrangled_npi_df has names)
    "Provider First Name",
    "Provider Last Name (Legal Name)",
    "Provider Middle Name",  # optional
    "Provider Organization Name (Legal Business Name)",
    # practice location
    "Provider Business Practice Location Address State Name",
    "Provider Business Practice Location Address Postal Code",
    # organization structure
    "Is Organization Subpart",
    "Parent Organization TIN",
    "Parent Organization LBN",
    # lifecycle dates
    "Provider Enumeration Date",
    "Last Update Date",
    "NPI Deactivation Date",
    "NPI Reactivation Date",
    "Is Sole Proprietor",
]
keep_cols_npi = [*npi_base_cols, *tax_cols]


In [9]:
# keeping only the selected NPPES columns (select accepts the list directly)
npi_df = full_npi_df.select(keep_cols_npi)


In [10]:
# having a sanity check to ensure I loaded the right columns, the correct types, and the data looks okay
# (every column is a string at this point because inferSchema is off)
print("Schema:")
npi_df.printSchema()
print("Sample:")
npi_df.limit(5).show(truncate=False)

Schema:
root
 |-- NPI: string (nullable = true)
 |-- Entity Type Code: string (nullable = true)
 |-- Provider First Name: string (nullable = true)
 |-- Provider Last Name (Legal Name): string (nullable = true)
 |-- Provider Middle Name: string (nullable = true)
 |-- Provider Organization Name (Legal Business Name): string (nullable = true)
 |-- Provider Business Practice Location Address State Name: string (nullable = true)
 |-- Provider Business Practice Location Address Postal Code: string (nullable = true)
 |-- Is Organization Subpart: string (nullable = true)
 |-- Parent Organization TIN: string (nullable = true)
 |-- Parent Organization LBN: string (nullable = true)
 |-- Provider Enumeration Date: string (nullable = true)
 |-- Last Update Date: string (nullable = true)
 |-- NPI Deactivation Date: string (nullable = true)
 |-- NPI Reactivation Date: string (nullable = true)
 |-- Is Sole Proprietor: string (nullable = true)
 |-- Healthcare Provider Taxonomy Code_1: string (nullable 

25/08/18 20:47:30 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


In [11]:
# getting the total number of rows in the DataFrame
# (count() is a Spark action; nothing is cached, so it re-reads the CSV)
total_rows = npi_df.count()
print(f'Total rows: {total_rows}')



Total rows: 9026996


                                                                                

In [12]:
# getting the null counts for each column in a single pass (sum of 0/1 isNull flags).
# Note: this counts true nulls only, not blank strings.
npi_null_counts = [sum_(col(name).isNull().cast("int")).alias(name) for name in npi_df.columns]
npi_df.select(*npi_null_counts).show()



+---+----------------+-------------------+-------------------------------+--------------------+------------------------------------------------+------------------------------------------------------+-------------------------------------------------------+-----------------------+-----------------------+-----------------------+-------------------------+----------------+---------------------+---------------------+------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------

                                                                                

In [13]:
# computing normalized names and checking for collisions, so two original
# headers never end up mapped to the same cleaned name
new_names = list(map(normalize_col, keep_cols_npi))
if len(set(new_names)) < len(new_names):
    raise RuntimeError("Column name conflict after normalization in NPPES.")

rename_pairs = zip(keep_cols_npi, new_names)
normalized_npi_df = npi_df.select(*(col(raw).alias(clean) for raw, clean in rename_pairs))

In [14]:
# doing a quick sanity check to ensure the new DataFrame has the correct columns
# (expect the snake_case names produced by normalize_col, still all strings)
normalized_npi_df.printSchema()
normalized_npi_df.show(5, truncate=False)

root
 |-- npi: string (nullable = true)
 |-- entity_type_code: string (nullable = true)
 |-- provider_first_name: string (nullable = true)
 |-- provider_last_name: string (nullable = true)
 |-- provider_middle_name: string (nullable = true)
 |-- provider_organization_name: string (nullable = true)
 |-- provider_business_practice_location_address_state_name: string (nullable = true)
 |-- provider_business_practice_location_address_postal_code: string (nullable = true)
 |-- is_organization_subpart: string (nullable = true)
 |-- parent_organization_tin: string (nullable = true)
 |-- parent_organization_lbn: string (nullable = true)
 |-- provider_enumeration_date: string (nullable = true)
 |-- last_update_date: string (nullable = true)
 |-- npi_deactivation_date: string (nullable = true)
 |-- npi_reactivation_date: string (nullable = true)
 |-- is_sole_proprietor: string (nullable = true)
 |-- healthcare_provider_taxonomy_code_1: string (nullable = true)
 |-- healthcare_provider_primary_ta

In [None]:
# casting and standardizing the NPPES data types
# - NPPES dates arrive as M/D/YYYY strings; values without that shape become
#   null rather than being mis-parsed.
# - Y/N indicator strings become 1/0 integers; anything else becomes null.
# - Digit-only normalizations (zip5, parent TIN) return null when no digits are
#   left, so missing values are null (as state_abbr already is) rather than "".
NPPES_DATE_PATTERN = r"^\d{1,2}/\d{1,2}/\d{4}$"


def parse_nppes_date(colname: str):
    """Parse an 'M/d/yyyy' string column to a date; non-matching values -> null."""
    return when(col(colname).rlike(NPPES_DATE_PATTERN), to_date(col(colname), "M/d/yyyy"))


def yes_no_to_int(column):
    """Map 'Y' -> 1 and 'N' -> 0 (case- and whitespace-insensitive); anything else -> null int."""
    flag = upper(trim(column))
    return when(flag == "Y", lit(1)).when(flag == "N", lit(0)).otherwise(lit(None).cast("int"))


def digits_or_null(column):
    """Strip every non-digit character; return null instead of an empty string."""
    digits = regexp_replace(column, r"\D", "")
    return when(length(digits) > 0, digits)


state_raw = trim(col("provider_business_practice_location_address_state_name"))

clean_npi = (
    normalized_npi_df
    .withColumn("npi", regexp_replace(trim(col("npi")), r"\D", ""))
    .withColumn("entity_type_code", col("entity_type_code").cast("int"))
    .withColumn("is_organization_subpart", yes_no_to_int(col("is_organization_subpart")))
    # sole-proprietor status is only kept for entity_type_code == 1; null otherwise
    .withColumn(
        "is_sole_proprietor",
        when(col("entity_type_code") == 1, yes_no_to_int(col("is_sole_proprietor"))),
    )
    .withColumn("provider_enumeration_date", parse_nppes_date("provider_enumeration_date"))
    .withColumn("last_update_date", parse_nppes_date("last_update_date"))
    .withColumn("npi_deactivation_date", parse_nppes_date("npi_deactivation_date"))
    .withColumn("npi_reactivation_date", parse_nppes_date("npi_reactivation_date"))
    # zip/state standardization
    .withColumn(
        "zip5",
        substring(digits_or_null(col("provider_business_practice_location_address_postal_code")), 1, 5),
    )
    .withColumn("state_abbr", when(state_raw.rlike(r"^[A-Za-z]{2}$"), upper(state_raw)))
    .withColumn("parent_org_tin_norm", digits_or_null(col("parent_organization_tin")))
)


In [None]:
# --- Primary taxonomy: Y wins; else first non-null; track how it was chosen ---
# NPPES carries 15 taxonomy slots, each with a code and a primary-taxonomy switch.
TAXONOMY_SLOTS = range(1, 16)


def _taxonomy_switch(slot: int):
    """Upper-cased primary-taxonomy switch column for one slot."""
    return upper(col(f"healthcare_provider_primary_taxonomy_switch_{slot}"))


def _taxonomy_code(slot: int):
    """Taxonomy code column for one slot."""
    return col(f"healthcare_provider_taxonomy_code_{slot}")


# per slot: the code, but only if the slot is switched "Y" and actually has a code
cand_exprs = [
    when((_taxonomy_switch(s) == "Y") & _taxonomy_code(s).isNotNull(), _taxonomy_code(s))
    for s in TAXONOMY_SLOTS
]
# true when any slot's switch is "X" (only feeds the diagnostic flag below)
any_x = reduce(lambda acc, cond: acc | cond, [_taxonomy_switch(s) == "X" for s in TAXONOMY_SLOTS])

explicit_expr = coalesce(*cand_exprs)                                # first slot marked Y
first_code = coalesce(*[_taxonomy_code(s) for s in TAXONOMY_SLOTS])  # first populated slot

clean_npi = (
    clean_npi
    # explicit Y if present, else fall back to the first code regardless of X
    .withColumn("primary_taxonomy", coalesce(explicit_expr, first_code))
    # Diagnostics for QA/EDA
    .withColumn("primary_taxonomy_explicit", when(explicit_expr.isNotNull(), lit(1)).otherwise(lit(0)))
    .withColumn(
        "primary_taxonomy_inferred",
        when(explicit_expr.isNull() & first_code.isNotNull(), lit(1)).otherwise(lit(0)),
    )
    .withColumn(
        "primary_taxonomy_from_x",
        when(explicit_expr.isNull() & any_x & first_code.isNotNull(), lit(1)).otherwise(lit(0)),
    )
)

In [17]:
# add two provider lifecycle features (referenced later in npi_feat), both as of AS_OF_STR:
#   npi_age_days - days from enumeration to the as-of date (negative when the NPI
#                  was enumerated later; those NPIs are filtered out before the join)
#   is_active    - 1 unless the NPI was deactivated on/before the as-of date without
#                  a same-day-or-later reactivation that is also on/before it
# Deactivation/reactivation dates after AS_OF_STR are ignored: they were not known
# at the as-of date, and using them would leak future information into the feature.
AS_OF = to_date(lit(AS_OF_STR))

deactivated_by_as_of = when(col("npi_deactivation_date") <= AS_OF, col("npi_deactivation_date"))
reactivated_by_as_of = when(col("npi_reactivation_date") <= AS_OF, col("npi_reactivation_date"))

clean_npi = (
    clean_npi
    .withColumn("npi_age_days", datediff(AS_OF, col("provider_enumeration_date")))
    .withColumn(
        "is_active",
        when(
            deactivated_by_as_of.isNull()
            | (reactivated_by_as_of.isNotNull() & (reactivated_by_as_of >= deactivated_by_as_of)),
            1,
        ).otherwise(0),
    )
)


In [18]:
# filtering out invalid NPIs: npi was already reduced to digits only, so anything
# that is not exactly 10 digits (empty, too short, too long) is dropped
clean_npi = clean_npi.where(col("npi").rlike(r"^\d{10}$"))

In [19]:
# Deduplicate to one row per NPI (latest update, then most complete).
# completeness_score counts how many key fields are populated. primary_taxonomy
# already falls back to the first listed taxonomy code, so a null here means the
# record has no taxonomy code at all.
completeness_fields = ["primary_taxonomy", "state_abbr", "provider_enumeration_date", "entity_type_code"]
comp_score = reduce(
    lambda total, flag: total + flag,
    [when(col(field).isNotNull(), 1).otherwise(0) for field in completeness_fields],
).alias("completeness_score")

npi_scored = clean_npi.withColumn("completeness_score", comp_score)

# most recent last_update_date first (missing dates last), ties broken by completeness
w = Window.partitionBy("npi").orderBy(
    col("last_update_date").desc_nulls_last(),
    col("completeness_score").desc(),
)

wrangled_npi_df = (
    npi_scored
    .withColumn("rn", row_number().over(w))
    .where(col("rn") == 1)
    .drop("rn")
)

print("NPPES wrangled rows:", wrangled_npi_df.count())



NPPES wrangled rows: 9026996


                                                                                

#### The NPPES NPI dataset is now mostly wrangled as `wrangled_npi_df` (one row per valid 10-digit NPI)
#### Pausing here to wrangle the other datasets...

## Medicare Physician & Other Practitioners - by Provider and Service Dataset

In [20]:
# creating a csv path variable
# NOTE(review): "D23" in the file name presumably denotes the 2023 data year,
# matching AS_OF_STR — verify against the CMS release notes.
phys_pract_csv_path = "Medicare Physician & Other Practitioners - by Provider and Service/MUP_PHY_R25_P05_V20_D23_Prov_Svc.csv"

In [21]:
# reading the Medicare PUF CSV with every column as a string (no schema inference)
full_physician_practitioner_df = spark.read.csv(phys_pract_csv_path, header=True, inferSchema=False)

In [22]:
# inspecting the PUF column names (repr exposes stray whitespace in headers)
print("Available columns (first 50 shown):")
print(*map(repr, full_physician_practitioner_df.columns[:50]), sep="\n")

Available columns (first 50 shown):
'Rndrng_NPI'
'Rndrng_Prvdr_Last_Org_Name'
'Rndrng_Prvdr_First_Name'
'Rndrng_Prvdr_MI'
'Rndrng_Prvdr_Crdntls'
'Rndrng_Prvdr_Ent_Cd'
'Rndrng_Prvdr_St1'
'Rndrng_Prvdr_St2'
'Rndrng_Prvdr_City'
'Rndrng_Prvdr_State_Abrvtn'
'Rndrng_Prvdr_State_FIPS'
'Rndrng_Prvdr_Zip5'
'Rndrng_Prvdr_RUCA'
'Rndrng_Prvdr_RUCA_Desc'
'Rndrng_Prvdr_Cntry'
'Rndrng_Prvdr_Type'
'Rndrng_Prvdr_Mdcr_Prtcptg_Ind'
'HCPCS_Cd'
'HCPCS_Desc'
'HCPCS_Drug_Ind'
'Place_Of_Srvc'
'Tot_Benes'
'Tot_Srvcs'
'Tot_Bene_Day_Srvcs'
'Avg_Sbmtd_Chrg'
'Avg_Mdcr_Alowd_Amt'
'Avg_Mdcr_Pymt_Amt'
'Avg_Mdcr_Stdzd_Amt'


In [23]:
# selecting the relevant columns I want to keep
keep_cols_puf = [
    # rendering provider: identity, location, type, participation
    "Rndrng_NPI","Rndrng_Prvdr_Ent_Cd","Rndrng_Prvdr_State_Abrvtn","Rndrng_Prvdr_Zip5",
    "Rndrng_Prvdr_RUCA","Rndrng_Prvdr_RUCA_Desc","Rndrng_Prvdr_Mdcr_Prtcptg_Ind",
    "Rndrng_Prvdr_Type","Rndrng_Prvdr_Cntry",
    # service line (HCPCS code, drug indicator, place of service)
    "HCPCS_Cd","HCPCS_Desc","HCPCS_Drug_Ind","Place_Of_Srvc",
    # volume counts (summed per provider later)
    "Tot_Benes","Tot_Srvcs","Tot_Bene_Day_Srvcs",
    # per-line average amounts (services-weighted per provider later)
    "Avg_Sbmtd_Chrg","Avg_Mdcr_Alowd_Amt","Avg_Mdcr_Pymt_Amt","Avg_Mdcr_Stdzd_Amt"
]

In [24]:
# keeping only the selected PUF columns (select accepts the list directly)
phys = full_physician_practitioner_df.select(keep_cols_puf)

In [25]:
# getting the total number of rows in the DataFrame
# (one row per provider/service line, so this is larger than the provider count)
total_rows_phys_pract = phys.count()
print(f'Total rows: {total_rows_phys_pract}')



Total rows: 9660647


                                                                                

In [26]:
# getting the null counts for each column in a single pass (sum of 0/1 isNull flags)
puf_null_counts = [sum_(col(name).isNull().cast("int")).alias(name) for name in phys.columns]
phys.select(*puf_null_counts).show()



+----------+-------------------+-------------------------+-----------------+-----------------+----------------------+-----------------------------+-----------------+------------------+--------+----------+--------------+-------------+---------+---------+------------------+--------------+------------------+-----------------+------------------+
|Rndrng_NPI|Rndrng_Prvdr_Ent_Cd|Rndrng_Prvdr_State_Abrvtn|Rndrng_Prvdr_Zip5|Rndrng_Prvdr_RUCA|Rndrng_Prvdr_RUCA_Desc|Rndrng_Prvdr_Mdcr_Prtcptg_Ind|Rndrng_Prvdr_Type|Rndrng_Prvdr_Cntry|HCPCS_Cd|HCPCS_Desc|HCPCS_Drug_Ind|Place_Of_Srvc|Tot_Benes|Tot_Srvcs|Tot_Bene_Day_Srvcs|Avg_Sbmtd_Chrg|Avg_Mdcr_Alowd_Amt|Avg_Mdcr_Pymt_Amt|Avg_Mdcr_Stdzd_Amt|
+----------+-------------------+-------------------------+-----------------+-----------------+----------------------+-----------------------------+-----------------+------------------+--------+----------+--------------+-------------+---------+---------+------------------+--------------+------------------+------

                                                                                

In [27]:
# computing normalized names for every PUF column I keep and checking for
# collisions, so two original headers never map to the same cleaned name
puf_names = [normalize_col(raw) for raw in keep_cols_puf]
if len(set(puf_names)) < len(puf_names):
    raise RuntimeError("Column name conflict after normalization in PUF.")

In [28]:
# applying the normalized names (result: p) and sanity-checking the schema
puf_rename = dict(zip(keep_cols_puf, puf_names))
p = phys.select(*(col(src).alias(dst) for src, dst in puf_rename.items()))
p.printSchema()
p.show(5, truncate=False)

root
 |-- rndrng_npi: string (nullable = true)
 |-- rndrng_prvdr_ent_cd: string (nullable = true)
 |-- rndrng_prvdr_state_abrvtn: string (nullable = true)
 |-- rndrng_prvdr_zip5: string (nullable = true)
 |-- rndrng_prvdr_ruca: string (nullable = true)
 |-- rndrng_prvdr_ruca_desc: string (nullable = true)
 |-- rndrng_prvdr_mdcr_prtcptg_ind: string (nullable = true)
 |-- rndrng_prvdr_type: string (nullable = true)
 |-- rndrng_prvdr_cntry: string (nullable = true)
 |-- hcpcs_cd: string (nullable = true)
 |-- hcpcs_desc: string (nullable = true)
 |-- hcpcs_drug_ind: string (nullable = true)
 |-- place_of_srvc: string (nullable = true)
 |-- tot_benes: string (nullable = true)
 |-- tot_srvcs: string (nullable = true)
 |-- tot_bene_day_srvcs: string (nullable = true)
 |-- avg_sbmtd_chrg: string (nullable = true)
 |-- avg_mdcr_alowd_amt: string (nullable = true)
 |-- avg_mdcr_pymt_amt: string (nullable = true)
 |-- avg_mdcr_stdzd_amt: string (nullable = true)

+----------+-------------------+

In [29]:
# cleaning and casting the PUF data types; indicator flags are 0/1 integers


def _flag(condition):
    """1 where `condition` is true, 0 otherwise (a null condition also gives 0)."""
    return when(condition, 1).otherwise(0)


def _as_double(name):
    """Cast a string column to DoubleType (unparseable values become null)."""
    return col(name).cast(DoubleType())


p = (
    p
    .withColumn("rndrng_npi", trim(col("rndrng_npi")))
    .withColumn("npi_valid", _flag(col("rndrng_npi").rlike(r"^\d{10}$")))
    .withColumn("is_individual", _flag(col("rndrng_prvdr_ent_cd") == "I"))
    .withColumn("rndrng_prvdr_state_abrvtn", trim(col("rndrng_prvdr_state_abrvtn")))
    .withColumn("rndrng_prvdr_zip5", trim(col("rndrng_prvdr_zip5")))
    .withColumn("missing_state", _flag(col("rndrng_prvdr_state_abrvtn").isNull()))
    .withColumn("missing_zip", _flag(col("rndrng_prvdr_zip5").isNull()))
    .withColumn("rndrng_prvdr_ruca", _as_double("rndrng_prvdr_ruca"))
    .withColumn("medicare_participation", _flag(col("rndrng_prvdr_mdcr_prtcptg_ind") == "Y"))
    .withColumn("hcpcs_cd", trim(col("hcpcs_cd")))
    .withColumn("place_of_srvc", trim(col("place_of_srvc")))
    .withColumn("avg_sbmtd_chrg", _as_double("avg_sbmtd_chrg"))
    .withColumn("avg_mdcr_alowd_amt", _as_double("avg_mdcr_alowd_amt"))
    .withColumn("avg_mdcr_pymt_amt", _as_double("avg_mdcr_pymt_amt"))
    .withColumn("avg_mdcr_stdzd_amt", _as_double("avg_mdcr_stdzd_amt"))
    .withColumn("is_drug", _flag(col("hcpcs_drug_ind") == "Y"))
    .withColumn("rndrng_prvdr_cntry", upper(trim(col("rndrng_prvdr_cntry"))))
    # keep domestic providers (or rows with no country recorded)
    .filter(col("rndrng_prvdr_cntry").isNull() | col("rndrng_prvdr_cntry").isin("US", "USA"))
)


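##### Because of the nature of this dataset, I will not be doing a traditional deduping process
##### This dataset is SUPPOSED to have multiple rows (service lines) per provider; the next cell aggregates them to one row per NPI
##### TODO: check for true duplicates (identical service lines) before aggregating — not yet implemented below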

In [None]:
# Aggregate to provider level (weighted avgs + diversity): one row per rendering NPI,
# built from the valid-NPI PUF service lines.
p_agg_in = (
    p.filter(col("npi_valid") == 1)
     .withColumn("srvcs_d", col("tot_srvcs").cast("double"))
     .withColumn("benes_d", col("tot_benes").cast("double"))
     .withColumn("bene_days_d", col("tot_bene_day_srvcs").cast("double"))
)


def _services_weighted_sum(value_col):
    """Sum of value * services; rows where either side is null drop out."""
    return sum_(col(value_col) * col("srvcs_d"))


def _services_with_value(value_col):
    """Sum of services over rows where `value_col` is non-null.

    This is the denominator that matches _services_weighted_sum(value_col).
    Dividing by total_services instead would count services whose average
    amount is missing, dragging the weighted mean (and variance) toward zero.
    """
    return sum_(when(col(value_col).isNotNull(), col("srvcs_d")))


def _ratio(num_col, den_col):
    """num / den when den > 0, else null."""
    return when(col(den_col) > 0, col(num_col) / col(den_col))


agg = (
    p_agg_in.groupBy("rndrng_npi").agg(
        sum_("srvcs_d").alias("total_services"),
        sum_("benes_d").alias("total_beneficiaries"),
        sum_("bene_days_d").alias("total_bene_day_services"),
        _services_weighted_sum("avg_sbmtd_chrg").alias("sum_submitted_w"),
        _services_weighted_sum("avg_mdcr_alowd_amt").alias("sum_allowed_w"),
        _services_weighted_sum("avg_mdcr_pymt_amt").alias("sum_payment_w"),
        sum_(pow(col("avg_sbmtd_chrg"), 2) * col("srvcs_d")).alias("sum_submitted_sq_w"),
        _services_with_value("avg_sbmtd_chrg").alias("srvcs_submitted_w"),
        _services_with_value("avg_mdcr_alowd_amt").alias("srvcs_allowed_w"),
        _services_with_value("avg_mdcr_pymt_amt").alias("srvcs_payment_w"),
        countDistinct("hcpcs_cd").alias("num_unique_procedures"),
        # share of service *lines* (HCPCS rows) flagged as drugs — not services-weighted
        avg(col("is_drug")).alias("frac_drug_services"),
        avg(col("missing_zip")).alias("frac_missing_zip"),
    )
)

provider_agg = (
    agg
    # services-weighted means of the per-line average amounts
    .withColumn("w_avg_submitted_charge", _ratio("sum_submitted_w", "srvcs_submitted_w"))
    .withColumn("w_avg_allowed", _ratio("sum_allowed_w", "srvcs_allowed_w"))
    .withColumn("w_avg_payment", _ratio("sum_payment_w", "srvcs_payment_w"))
    # services-weighted variance/stddev of submitted charge: E[x^2] - E[x]^2,
    # clamped at 0 before sqrt since floating-point cancellation can go slightly negative
    .withColumn(
        "w_var_submitted_charge",
        when(
            col("srvcs_submitted_w") > 0,
            col("sum_submitted_sq_w") / col("srvcs_submitted_w")
            - pow(col("sum_submitted_w") / col("srvcs_submitted_w"), 2),
        ),
    )
    .withColumn(
        "w_stddev_submitted_charge",
        when(col("w_var_submitted_charge").isNotNull(),
             sqrt(greatest(lit(0.0), col("w_var_submitted_charge")))),
    )
    # ratios of services-weighted totals
    .withColumn("charge_allowed_ratio", _ratio("sum_submitted_w", "sum_allowed_w"))
    .withColumn("payment_allowed_ratio", _ratio("sum_payment_w", "sum_allowed_w"))
    # drop helper sums now that the features are derived
    .drop(
        "sum_submitted_sq_w", "sum_submitted_w", "sum_allowed_w", "sum_payment_w",
        "srvcs_submitted_w", "srvcs_allowed_w", "srvcs_payment_w",
    )
    # tidy up + per-beneficiary rates
    .withColumnRenamed("rndrng_npi", "npi")
    .withColumn("services_per_bene", _ratio("total_services", "total_beneficiaries"))
    .withColumn("bene_days_per_bene", _ratio("total_bene_day_services", "total_beneficiaries"))
)


print("PUF aggregated providers:", provider_agg.count())



PUF aggregated providers: 1175213


                                                                                

## 'List of Excluded Individuals and Entities' (LEIE) - OIG Dataset

In [31]:
# creating a csv path variable
# LEIE snapshot dated 2025-07-10 (per the file name); exclusions dated after
# AS_OF_STR are filtered out further down to avoid label leakage.
leie_csv_path = "Office of Inspector General - Excluded Individuals and Entities/20250710 LEIE.csv"

In [32]:
# reading the LEIE CSV with every column as a string (no schema inference)
full_leie_df = spark.read.csv(leie_csv_path, header=True, inferSchema=False)

In [33]:
# inspecting all LEIE column names (repr exposes stray whitespace in headers)
print("Available columns:")
print(*map(repr, full_leie_df.columns), sep="\n")

Available columns:
'LASTNAME'
'FIRSTNAME'
'MIDNAME'
'BUSNAME'
'GENERAL'
'SPECIALTY'
'UPIN'
'NPI'
'DOB'
'ADDRESS'
'CITY'
'STATE'
'ZIP'
'EXCLTYPE'
'EXCLDATE'
'REINDATE'
'WAIVERDATE'
'WVRSTATE'


In [34]:
# selecting the relevant columns I want to keep
# (every LEIE column except GENERAL; the date columns are parsed from strings later)
keep_cols_leie = ["NPI","LASTNAME","FIRSTNAME","MIDNAME","BUSNAME","ADDRESS","CITY","STATE",
                  "ZIP","SPECIALTY","EXCLTYPE","EXCLDATE","REINDATE","DOB","WAIVERDATE","WVRSTATE","UPIN"]

In [35]:
# keeping only the selected LEIE columns (select accepts the list directly)
leie_df = full_leie_df.select(keep_cols_leie)

In [36]:
# having a sanity check to ensure I loaded the right columns, the correct types, and the data looks okay
# (all columns are strings here because inferSchema is off)
print("Schema:")
leie_df.printSchema()
print("Sample:")
leie_df.limit(5).show(truncate=False)

Schema:
root
 |-- NPI: string (nullable = true)
 |-- LASTNAME: string (nullable = true)
 |-- FIRSTNAME: string (nullable = true)
 |-- MIDNAME: string (nullable = true)
 |-- BUSNAME: string (nullable = true)
 |-- ADDRESS: string (nullable = true)
 |-- CITY: string (nullable = true)
 |-- STATE: string (nullable = true)
 |-- ZIP: string (nullable = true)
 |-- SPECIALTY: string (nullable = true)
 |-- EXCLTYPE: string (nullable = true)
 |-- EXCLDATE: string (nullable = true)
 |-- REINDATE: string (nullable = true)
 |-- DOB: string (nullable = true)
 |-- WAIVERDATE: string (nullable = true)
 |-- WVRSTATE: string (nullable = true)
 |-- UPIN: string (nullable = true)

Sample:
+----------+--------+---------+-------+---------------------------+-----------------------------+----------+-----+-----+------------------+--------+--------+--------+----+----------+--------+----+
|NPI       |LASTNAME|FIRSTNAME|MIDNAME|BUSNAME                    |ADDRESS                      |CITY      |STATE|ZIP  |SPECIA

In [37]:
# getting the total number of rows in the DataFrame
# (raw LEIE rows; a provider can appear more than once)
total_rows_leie = leie_df.count()
print(f'Total rows: {total_rows_leie}')

Total rows: 81774


In [38]:
# getting the null counts for each column in a single pass (sum of 0/1 isNull flags)
leie_null_counts = [sum_(col(name).isNull().cast("int")).alias(name) for name in leie_df.columns]
leie_df.select(*leie_null_counts).show()

+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+
|NPI|LASTNAME|FIRSTNAME|MIDNAME|BUSNAME|ADDRESS|CITY|STATE|ZIP|SPECIALTY|EXCLTYPE|EXCLDATE|REINDATE| DOB|WAIVERDATE|WVRSTATE| UPIN|
+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+
|  0|    3358|     3358|  24009|  78416|      4|   0|    0|  0|     4088|       0|       0|       0|4222|         0|   81762|75791|
+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+



In [39]:
# computing normalized names and checking for collisions, so two original
# headers never end up mapped to the same cleaned name
leie_cols = [normalize_col(raw) for raw in keep_cols_leie]
if len(set(leie_cols)) < len(leie_cols):
    raise RuntimeError("Column name conflict after normalization in LEIE.")

leie_rename = dict(zip(keep_cols_leie, leie_cols))
leie = leie_df.select(*(col(src).alias(dst) for src, dst in leie_rename.items()))

In [40]:
# Trim, normalize keys, parse dates (guarded)

# 1) Every string column: trim it, and turn blank strings into null.
for c_, t_ in leie.dtypes:
    if t_ == "string":
        leie = leie.withColumn(c_, when(trim(col(c_)) == "", None).otherwise(trim(col(c_))))

# 2) Key normalization.
leie_zip_digits = regexp_replace(col("zip"), r"[^0-9]", "")
leie = (
    leie
    .withColumn("npi", regexp_replace(col("npi"), r"[^0-9]", ""))
    # all-zero NPIs are treated as a placeholder (no NPI) and marked invalid
    .withColumn("npi_valid", when(col("npi").rlike(r"^(?!0{10})\d{10}$"), 1).otherwise(0).cast(IntegerType()))
    # ZIPs with no digits become null rather than "", like other missing values
    .withColumn("zip5", when(length(leie_zip_digits) > 0, substring(leie_zip_digits, 1, 5)))
    .drop("zip")
    .withColumn("state", when(col("state").rlike(r"^[A-Za-z]{2}$"), upper(col("state"))))
    .withColumnRenamed("wvrstate", "waiverstate")
    .withColumn("waiverstate", when(col("waiverstate").rlike(r"^[A-Za-z]{2}$"), upper(col("waiverstate"))))
    .withColumn("excltype", upper(col("excltype")))
    .withColumn("specialty", upper(col("specialty")))
)


def zeroish(colname):
    """Return the column with missing-date placeholders mapped to null.

    Placeholders: values whose digits are exactly "00000000", or that contain no
    digits at all. Any other value is returned unchanged.
    """
    digits = regexp_replace(col(colname), r"[^0-9]", "")
    return when((digits == "00000000") | (length(digits) == 0), lit(None)).otherwise(col(colname))


def parse_date_guarded(colname: str):
    """Parse an LEIE date string column to a date; unrecognized values -> null.

    Accepted shapes (after zeroish() removes placeholders):
      * year-first, optionally with single separators: 20231231, 2023-12-31, 2023/12/31
      * month-first with slashes: 12/31/2023, 1/5/2023

    Each parser only receives values that already have its shape. In particular,
    the digit-only yyyyMMdd parse is never given a month-first value such as
    '12/31/2020' (digits '12312020' = year 1231, month 20). Depending on ANSI
    mode and spark.sql.legacy.timeParserPolicy, that parse may raise instead of
    returning null.
    """
    raw = zeroish(colname)
    digits = regexp_replace(raw, r"[^0-9]", "")
    year_first = when(raw.rlike(r"^\d{4}[^0-9]?\d{2}[^0-9]?\d{2}$"), to_date(digits, "yyyyMMdd"))
    month_first = when(raw.rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"), to_date(raw, "M/d/yyyy"))
    return coalesce(year_first, month_first)


# 3) Parse dates, replacing the raw string columns with *_dt date columns.
leie = (
    leie
    .withColumn("excldate_dt",   parse_date_guarded("excldate"))
    .withColumn("reindate_dt",   parse_date_guarded("reindate"))
    .withColumn("waiverdate_dt", parse_date_guarded("waiverdate"))
    .withColumn("dob_dt",        parse_date_guarded("dob"))
    .drop("excldate","reindate","waiverdate","dob")
)

In [41]:
# As-of filter: keep only exclusions effective on/before AS_OF_STR to avoid leakage.
# Rows whose exclusion date could not be parsed (null) are kept.
AS_OF = to_date(lit(AS_OF_STR))
excl_known_by_as_of = col("excldate_dt").isNull() | (col("excldate_dt") <= AS_OF)
leie = leie.where(excl_known_by_as_of)

In [None]:
# --- LEIE: episode-aware status as-of AS_OF (build-only; no join here) ---
# A provider can have several exclusion/reinstatement rows, so status is derived
# from the latest exclusion and the latest reinstatement dated on/before AS_OF.


def _on_or_before_as_of(date_col: str):
    """The date value when it is on/before AS_OF; null otherwise (including null input)."""
    return when(col(date_col) <= AS_OF, col(date_col))


leie_with_npi = leie.where(col("npi_valid") == 1)

excl_pre = _on_or_before_as_of("excldate_dt")
leie_status = (
    leie_with_npi
    .groupBy("npi")
    .agg(
        max(excl_pre).alias("last_excl_pre"),
        max(_on_or_before_as_of("reindate_dt")).alias("last_rein_pre"),
        # number of exclusion rows dated on/before AS_OF
        sum_(when(excl_pre.isNotNull(), lit(1)).otherwise(lit(0))).cast("int").alias("excl_events_pre"),
    )
    .withColumn("ever_excluded_asof", when(col("excl_events_pre") > 0, lit(1)).otherwise(lit(0)))
    .withColumn(
        "excluded_asof",
        # still excluded: an exclusion exists and no reinstatement on/after it
        when(
            col("last_excl_pre").isNotNull()
            & (col("last_rein_pre").isNull() | (col("last_rein_pre") < col("last_excl_pre"))),
            lit(1),
        ).otherwise(lit(0)),
    )
    .select("npi", "ever_excluded_asof", "excluded_asof")
)

print("LEIE status NPIs:", leie_status.count())


LEIE labeled NPIs: 6711


In [None]:
# -------------------------------
# JOINING (Option B: NPPES base)
# -------------------------------
# Base population: NPPES providers enumerated on/before AS_OF (no future
# enumeration leakage); PUF features and LEIE labels are left-joined onto it.
AS_OF = to_date(lit(AS_OF_STR))

# wrangled_npi_df is already one row per NPI and "npi" is part of this projection,
# so the result is unique by NPI without a .distinct() (which would add a full
# shuffle over the ~9M NPPES rows for no change in output).
npi_feat = (
    wrangled_npi_df
    .filter(col("provider_enumeration_date").isNull() | (col("provider_enumeration_date") <= AS_OF))
    .select(
        "npi", "state_abbr", "zip5", "entity_type_code", "primary_taxonomy",
        "npi_age_days", "is_active", "is_organization_subpart", "is_sole_proprietor",
    )
)

# NPPES as base; PUF features are optional (has_puf marks providers with PUF data)
base = (
    npi_feat.alias("n")
    .join(provider_agg.alias("p"), on="npi", how="left")
    .withColumn("has_puf", when(col("total_services").isNotNull(), lit(1)).otherwise(lit(0)))
)

# Label join (explicit LEIE only; consider adding inferred later if you want more positives).
# leie_status has one row per LEIE NPI (a few thousand rows), so it is broadcast.
merged = (
    base.join(broadcast(leie_status.alias("l")), on="npi", how="left")
        .withColumn("is_fraud",         coalesce(col("ever_excluded_asof"), lit(0)).cast("int"))  # training target
        .withColumn("is_excluded_asof", coalesce(col("excluded_asof"),     lit(0)).cast("int"))  # for reporting/slices
)

# Fill PUF numeric columns with 0 to avoid nulls in the NN; has_puf distinguishes
# "no PUF data" from a genuine zero.
# NOTE(review): w_var_submitted_charge is not listed here, so it stays null for
# providers without PUF data — confirm that is intended.
num_cols = [
    "total_services", "total_beneficiaries", "total_bene_day_services",
    "w_avg_submitted_charge", "w_avg_allowed", "w_avg_payment",
    "charge_allowed_ratio", "payment_allowed_ratio",
    "num_unique_procedures", "w_stddev_submitted_charge",
    "frac_drug_services", "frac_missing_zip",
    "services_per_bene", "bene_days_per_bene",
]

for num_col in num_cols:
    if num_col in merged.columns:
        merged = merged.withColumn(num_col, coalesce(col(num_col), lit(0.0)))

print("Merged row sample (NPPES base):")
merged.select("npi","has_puf","is_fraud").limit(10).show(truncate=False)
print("Positives (distinct NPIs):", merged.filter(col("is_fraud")==1).select("npi").distinct().count())
print("NPIs with PUF features:", merged.filter(col("has_puf")==1).select("npi").distinct().count())

# --------------------------------
# SAMPLING (robust + builtins.max)
# --------------------------------
TARGET = 1_000_000  # adjust for your machine
pos_df = merged.filter(col("is_fraud") == 1).select("npi").distinct()
pos_ct = pos_df.count()
print("Positives (distinct NPIs):", pos_ct)

SAMPLE_F = 0.001
sample_est_val = merged.sample(False, SAMPLE_F, seed=42).count()
est_total = int(sample_est_val / SAMPLE_F)
est_neg = py.max(est_total - pos_ct, 1)
keep_neg = py.max(TARGET - pos_ct, 0)
frac_neg = py.min(1.0, float(keep_neg) / float(est_neg)) if est_neg > 0 else 0.0
print("Estimated total:", est_total, "Negatives est:", est_neg, "Neg keep:", keep_neg, "Neg fraction:", frac_neg)

if frac_neg <= 0:
    # If positives exceed target, keep a thin negative slice for calibration
    neg_slice = merged.filter(col("is_fraud")==0).sample(False, 0.02, seed=42)
    final_df = merged.filter(col("is_fraud")==1).unionByName(neg_slice.select(merged.columns))
else:
    neg_sample = merged.filter(col("is_fraud")==0).sample(False, frac_neg, seed=42)
    pos_full = pos_df.join(merged, on="npi", how="inner")
    final_df = pos_full.unionByName(neg_sample.select(merged.columns))


Merged row sample (NPPES base):


                                                                                

+----------+-------+--------+
|npi       |has_puf|is_fraud|
+----------+-------+--------+
|1003000134|1      |0       |
|1003000183|0      |0       |
|1003000191|0      |0       |
|1003000431|0      |0       |
|1003000548|0      |0       |
|1003000662|0      |0       |
|1003000738|1      |0       |
|1003000951|0      |0       |
|1003000977|0      |0       |
|1003001017|0      |0       |
+----------+-------+--------+



                                                                                

Positives (distinct NPIs): 6711


                                                                                

NPIs with PUF features: 1175210


                                                                                

Positives (distinct NPIs): 6711


                                                                                

Estimated total: 8172000 Negatives est: 8165289 Neg keep: 993289 Neg fraction: 0.12164774571971672


In [44]:
from pyspark import StorageLevel

# Co-locate rows by NPI in 8 partitions and cache them (spilling to disk if
# needed), so the class-balance preview and the Parquet write share one computation.
final_df = final_df.repartition(8, "npi")
final_df = final_df.persist(StorageLevel.MEMORY_AND_DISK)

print("Final class balance (preview):")
final_df.groupBy("is_fraud").count().show()

# Parquet writer settings: snappy compression, <= 750k rows per file,
# 64 MiB row groups, 1 MiB pages.
_MIB = 1024 * 1024
WRITE_OPTS = {
    "compression": "snappy",
    "maxRecordsPerFile": "750000",
    "parquet.block.size": str(64 * _MIB),
    "parquet.page.size": str(1 * _MIB),
}
outpath = f"curated/training/providers_nn_asof_{AS_OF_STR}.parquet"
_writer = final_df.write.mode("overwrite").options(**WRITE_OPTS)
_writer.parquet(outpath)
print("Saved training table to:", outpath)


Final class balance (preview):


                                                                                

+--------+------+
|is_fraud| count|
+--------+------+
|       1|  6711|
|       0|988572|
+--------+------+



[Stage 134:>                                                        (0 + 8) / 8]

Saved training table to: curated/training/providers_nn_asof_2023-12-31.parquet


                                                                                

In [45]:
# 7) Cleanup big caches (be kind to RAM)
# -----------------------------
# Best-effort release of the persisted training frame. A failure must not abort
# the notebook, but it is reported rather than silently swallowed. A leaked cache,
# or a NameError because final_df was never built on a partial run, stays visible.
try:
    final_df.unpersist()
except Exception as exc:  # deliberate best-effort cleanup
    print(f"WARNING: could not unpersist final_df: {exc!r}")

25/08/19 04:44:03 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 977797 ms exceeds timeout 120000 ms
25/08/19 04:44:03 WARN SparkContext: Killing executors is not supported by current scheduler.
25/08/19 04:59:19 WARN Executor: Issue communicating with driver in heartbeater
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:53)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:342)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:101)
	at org.apache.spark.rpc.RpcEndpointRef.askSync(RpcEndpointRef.scala:85)
	at org.apache.spark.storage.BlockManagerMaster.registerBlockManager(BlockManagerMaster.scala:81)
	at org.apache.spark.storage.BlockManager.reregister(BlockManager.scala:669)
	at org.apache.spark.executor.Executor.reportHeartBeat(Executor.scala:1296)
	at o