# Provider Fraud Detection Model - Data Wrangling
### In this notebook we will focus on collecting, organizing, defining, and cleaning the relevant datasets for the Provider Fraud Detection model
#### 1.) Ingesting & Inspecting Files
#### 2.) Standardizing Column Names & Types
#### 3.) Cleaning Values
#### 4.) Deduping (Removing Duplicates)
#### 5.) Merging/Joining Tables

In [1]:
# loading needed modules
import re
from collections import Counter
from functools import reduce

from pyspark.sql import SparkSession
from pyspark.sql.window import Window
from pyspark.sql.types import (
    StructType,
    StructField,
    StringType,
    IntegerType,
    DoubleType,
    DateType
)

from pyspark.sql.functions import (
    col,
    expr,
    lit,
    trim,
    to_date,
    current_date,
    datediff,
    sum as sum_,
    avg,
    stddev,
    countDistinct,
    when,
    try_divide,
    row_number,
    min,
    max,
    mean,
    isnan,
    isnull,
    regexp_replace,
    substring,
    upper,
    coalesce,
    sha2
)

In [2]:
# this function will help me normalize my column names to ensure a smooth workflow
# this function will...
    # lowercase, trim whitespace
    # drop parenthetical notes
    # replace non-alphanumeric chars with underscores
    # collapse runs of underscores
    # strip leading/trailing underscores

def normalize_col(column_name: str) -> str:
    text = column_name.strip().lower()
    text = re.sub(r'\(.*?\)', '', text)
    text = re.sub(r'[^0-9a-z]+', '_', text)
    text = re.sub(r'_+', '_', text)
    return text.strip('_')

In [3]:
# starting Spark session
spark = (
    SparkSession.builder
    .appName("ProviderFraudDetection")
    .master("local[*]")
    .config("spark.driver.memory", "12g")              # bump heap (use 8g–14g based on your RAM)
    .config("spark.sql.shuffle.partitions", "16")      # fewer, fatter tasks = less overhead
    .config("spark.sql.adaptive.enabled", "true")      # AQE helps with skew
    .config("spark.memory.fraction", "0.6")            # leave GC headroom
    .getOrCreate()
)

Using Spark's default log4j profile: org/apache/spark/log4j2-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
25/08/15 17:43:53 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


----------------------------------------
Exception occurred during processing of request from ('127.0.0.1', 52184)
Traceback (most recent call last):
  File "/opt/anaconda3/envs/ds-env/lib/python3.11/socketserver.py", line 317, in _handle_request_noblock
    self.process_request(request, client_address)
  File "/opt/anaconda3/envs/ds-env/lib/python3.11/socketserver.py", line 348, in process_request
    self.finish_request(request, client_address)
  File "/opt/anaconda3/envs/ds-env/lib/python3.11/socketserver.py", line 361, in finish_request
    self.RequestHandlerClass(request, client_address, self)
  File "/opt/anaconda3/envs/ds-env/lib/python3.11/socketserver.py", line 755, in __init__
    self.handle()
  File "/opt/anaconda3/envs/ds-env/lib/python3.11/site-packages/pyspark/accumulators.py", line 299, in handle
    poll(accum_updates)
  File "/opt/anaconda3/envs/ds-env/lib/python3.11/site-packages/pyspark/accumulators.py", line 271, in poll
    if self.rfile in r and func():
        

## NPPES NPI Registry Dataset

In [4]:
# creating a csv path variable
npi_csv_path = "NPPES_Data_Dissemination_July_2025_V2/npidata_pfile_20050523-20250713.csv"


In [5]:
# reading the CSV file into a DataFrame
full_npi_df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "false")
    .csv(npi_csv_path)
)

In [6]:
# inspecting the npi dataset column names
print("Available columns (first 50 shown):")
for c in full_npi_df.columns[:50]:
    print(repr(c))
    

Available columns (first 50 shown):
'NPI'
'Entity Type Code'
'Replacement NPI'
'Employer Identification Number (EIN)'
'Provider Organization Name (Legal Business Name)'
'Provider Last Name (Legal Name)'
'Provider First Name'
'Provider Middle Name'
'Provider Name Prefix Text'
'Provider Name Suffix Text'
'Provider Credential Text'
'Provider Other Organization Name'
'Provider Other Organization Name Type Code'
'Provider Other Last Name'
'Provider Other First Name'
'Provider Other Middle Name'
'Provider Other Name Prefix Text'
'Provider Other Name Suffix Text'
'Provider Other Credential Text'
'Provider Other Last Name Type Code'
'Provider First Line Business Mailing Address'
'Provider Second Line Business Mailing Address'
'Provider Business Mailing Address City Name'
'Provider Business Mailing Address State Name'
'Provider Business Mailing Address Postal Code'
'Provider Business Mailing Address Country Code (If outside U.S.)'
'Provider Business Mailing Address Telephone Number'
'Provider B

In [None]:
# selecting the relevant columns I want to keep
tax_cols = []
for i in range(1, 16):
    tax_cols += [f"Healthcare Provider Taxonomy Code_{i}",
                 f"Healthcare Provider Primary Taxonomy Switch_{i}"]

keep_cols_npi = [
    "NPI",
    "Entity Type Code",

    # person/org name columns (add these so wrangled_npi_df has names)
    "Provider First Name",
    "Provider Last Name (Legal Name)",
    "Provider Middle Name",                     # optional
    "Provider Organization Name (Legal Business Name)",

    "Provider Business Practice Location Address State Name",
    "Provider Business Practice Location Address Postal Code",
    "Is Organization Subpart",
    "Parent Organization TIN",
    "Parent Organization LBN",
    "Provider Enumeration Date",
    "Last Update Date",
    "NPI Deactivation Date",
    "NPI Reactivation Date",
    "Is Sole Proprietor",
] + tax_cols


In [8]:
# creating a new DataFrame with only the selected columns
npi_df = full_npi_df.select(*keep_cols_npi)


In [9]:
# having a sanity check to ensure I loaded the right columns, the correct types, and the data looks okay
print("Schema:")
npi_df.printSchema()
print("Sample:")
npi_df.limit(5).show(truncate=False)

Schema:
root
 |-- NPI: string (nullable = true)
 |-- Entity Type Code: string (nullable = true)
 |-- Provider Business Practice Location Address State Name: string (nullable = true)
 |-- Provider Business Practice Location Address Postal Code: string (nullable = true)
 |-- Is Organization Subpart: string (nullable = true)
 |-- Parent Organization TIN: string (nullable = true)
 |-- Parent Organization LBN: string (nullable = true)
 |-- Provider Enumeration Date: string (nullable = true)
 |-- Last Update Date: string (nullable = true)
 |-- NPI Deactivation Date: string (nullable = true)
 |-- NPI Reactivation Date: string (nullable = true)
 |-- Is Sole Proprietor: string (nullable = true)
 |-- Healthcare Provider Taxonomy Code_1: string (nullable = true)
 |-- Healthcare Provider Primary Taxonomy Switch_1: string (nullable = true)
 |-- Healthcare Provider Taxonomy Code_2: string (nullable = true)
 |-- Healthcare Provider Primary Taxonomy Switch_2: string (nullable = true)
 |-- Healthcare P

25/08/15 17:43:56 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


In [10]:
# getting the total number of rows in the DataFrame
total_rows = npi_df.count()
print(f'Total rows: {total_rows}')



Total rows: 9026996


                                                                                

In [11]:
# getting the null counts for each column
npi_df.select(
    *[
        sum_(col(c).isNull().cast('int')).alias(c) for c in npi_df.columns
    ]
).show()



+---+----------------+------------------------------------------------------+-------------------------------------------------------+-----------------------+-----------------------+-----------------------+-------------------------+----------------+---------------------+---------------------+------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+---------------------------------

                                                                                

In [12]:
# comoputing normalized names and checking for collisions
# checking for collisions as I don't want to accidently end up with two originals mapping to the same cleaned name

# applying the normalized names to every column I plan to keep
cleaned_names = []
for original_name in keep_cols_npi:
    new_name = normalize_col(original_name)
    cleaned_names.append(new_name)

# counting how many times each cleaned name shows up
name_counts = Counter(cleaned_names)

# pulling out any names that show up more than once
duplicates = [name for name, count in name_counts.items() if count > 1]

# bailing out early, if a coliision is detected
if duplicates:
    raise RuntimeError(f'Column name conflict after normalization: {duplicates}')

In [13]:
# building a new DataFrame with normalized column names

# building a list of expressions to rename the columns in the DataFrame
renamed_columns = [
    col(old).alias(new)
    for old, new in zip(keep_cols_npi, cleaned_names)
]

# applying 
normalized_npi_df = npi_df.select(*renamed_columns)

In [14]:
# doing a quick sanity check to ensure the new DataFrame has the correct columns
normalized_npi_df.printSchema()
normalized_npi_df.show(5, truncate=False)

root
 |-- npi: string (nullable = true)
 |-- entity_type_code: string (nullable = true)
 |-- provider_business_practice_location_address_state_name: string (nullable = true)
 |-- provider_business_practice_location_address_postal_code: string (nullable = true)
 |-- is_organization_subpart: string (nullable = true)
 |-- parent_organization_tin: string (nullable = true)
 |-- parent_organization_lbn: string (nullable = true)
 |-- provider_enumeration_date: string (nullable = true)
 |-- last_update_date: string (nullable = true)
 |-- npi_deactivation_date: string (nullable = true)
 |-- npi_reactivation_date: string (nullable = true)
 |-- is_sole_proprietor: string (nullable = true)
 |-- healthcare_provider_taxonomy_code_1: string (nullable = true)
 |-- healthcare_provider_primary_taxonomy_switch_1: string (nullable = true)
 |-- healthcare_provider_taxonomy_code_2: string (nullable = true)
 |-- healthcare_provider_primary_taxonomy_switch_2: string (nullable = true)
 |-- healthcare_provider_

In [15]:
# casting and standardizing the data types
clean = (normalized_npi_df
         
    # trimming the stray spaces so joins/comparisons work correctly
    .withColumn("npi", trim(col("npi")))
    .withColumn("npi", regexp_replace(col("npi"), r"\D", ""))
    .withColumn("entity_type_code", col("entity_type_code").cast("int"))

    # mapping yes/no to 1/0
    .withColumn("is_organization_subpart", when(col("is_organization_subpart") == "Y", 1).otherwise(0))

    # is_sole_proprietor: only meaningful for individuals (entity_type_code==1)
    # Y->1, N->0, anything else (incl. X) -> null; orgs -> null
    .withColumn("is_sole_proprietor_raw", upper(trim(col("is_sole_proprietor"))))
    .withColumn(
        "is_sole_proprietor",
        when(col("entity_type_code") == 1,
             when(col("is_sole_proprietor_raw") == "Y", 1)
            .when(col("is_sole_proprietor_raw") == "N", 0)
            .otherwise(None)
        ).otherwise(None)
    ).drop("is_sole_proprietor_raw")

    # parse dates strings into real Date columns
    .withColumn("provider_enumeration_date",
        when(col("provider_enumeration_date").rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"),
             to_date(col("provider_enumeration_date"), "M/d/yyyy"))
    )
    .withColumn("last_update_date",
        when(col("last_update_date").rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"),
             to_date(col("last_update_date"), "M/d/yyyy"))
    )
    .withColumn("npi_deactivation_date",
        when(col("npi_deactivation_date").rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"),
             to_date(col("npi_deactivation_date"), "M/d/yyyy"))
    )
    .withColumn("npi_reactivation_date",
        when(col("npi_reactivation_date").rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"),
             to_date(col("npi_reactivation_date"), "M/d/yyyy"))
    )
)

In [16]:
# normalizing ZIP & TIN

# creating a zip5 with just the first 5 digits of the postal code column
clean = clean.withColumn(
    "zip5",
    substring(regexp_replace(col("provider_business_practice_location_address_postal_code"), r"[^0-9]", ""), 1, 5)
)

# normalizing state names to 2-letter codes
clean = clean.withColumn(
    "state_abbr",
    when(col("provider_business_practice_location_address_state_name").rlike(r"^[A-Za-z]{2}$"),
         upper(col("provider_business_practice_location_address_state_name"))
    )
    .otherwise(None)   # optional: keep None if it's a full name or messy
)

# stripping non-digits from TIN for consistent grouping
clean = (clean
    .withColumn("parent_org_tin_norm", regexp_replace(col("parent_organization_tin"), r"\D", ""))
)


In [17]:
# primary taxonomy extraction with nuance

# NPPES has 15 slots so building an expression that picks the code where the corresponding switch is "Y"
explicit_expr = None
for i in range(1, 16):
    sw = f"healthcare_provider_primary_taxonomy_switch_{i}"
    cd = f"healthcare_provider_taxonomy_code_{i}"
    condY = (col(sw) == "Y") & col(cd).isNotNull()
    explicit_expr = when(condY, col(cd)) if explicit_expr is None else explicit_expr.when(condY, col(cd))

# first non-null taxonomy code across all slots (for fallback)
code_cols = [col(f"healthcare_provider_taxonomy_code_{i}") for i in range(1, 16)]
first_code = coalesce(*code_cols)

# any 'X' present across switches? (for your "unknown → fallback" behavior)
any_x = None
for i in range(1, 16):
    sw = f"healthcare_provider_primary_taxonomy_switch_{i}"
    condX = (col(sw) == "X")
    any_x = condX if any_x is None else (any_x | condX)

clean = (clean
    .withColumn(
        "primary_taxonomy",
        when(explicit_expr.isNotNull(), explicit_expr)
        .when(any_x & first_code.isNotNull(), first_code)   # your fallback rule
        .otherwise(lit(None))
    )
    .withColumn(
        "primary_taxonomy_explicit",
        when(explicit_expr.isNotNull(), lit(1)).otherwise(lit(0))
    )
    .withColumn(
        "primary_taxonomy_unknown",
        when(explicit_expr.isNull() & any_x, lit(1)).otherwise(lit(0))
    )
    .withColumn(
        "primary_taxonomy_source",
        when(explicit_expr.isNotNull(), lit("explicit"))
        .when(any_x & first_code.isNotNull(), lit("fallback"))
        .otherwise(lit(None))
    )
)


In [18]:
# lifecycle / derived features
clean = (clean
    # this will show how many days since the Provider first got their NPI
    .withColumn(
        "npi_age_days", datediff(current_date(), col("provider_enumeration_date")))

    # this will show if the NPI is 'active' today...
    # if there's no deactivation date -> it's active
    # or if the Provider was deactivated and then reactivated after the deactivation date -> it's active
    .withColumn(
        "is_active",
        when(
            (col("npi_deactivation_date").isNull()) |
            ((col("npi_reactivation_date").isNotNull()) & (col("npi_reactivation_date") >= col("npi_deactivation_date"))),
            lit(1)
        ).otherwise(lit(0))
    )

    # this will show if a Provider was deactivated and then reactivated at some point
    .withColumn("was_reactivated", when(col("npi_reactivation_date").isNotNull(), 1).otherwise(0))
    .withColumn(
        "deactivated_then_reactivated",
        when((col("npi_deactivation_date").isNotNull()) & (col("npi_reactivation_date").isNotNull()), 1).otherwise(0)
    )

    # this will show if there are any location info (ex. if state field is not null)
    .withColumn(
        "has_location",
        when(col("state_abbr").isNotNull(), 1).otherwise(0)
)

    # this will show if the entity type code is missing
    .withColumn(
        "missing_entity_type",
        when(col("entity_type_code").isNull(), 1).otherwise(0)
    )
)





In [19]:
# building a fallback 'any taxonomy code' for completeness scoring
code_any = None
for i in range(1, 16):
    cd = f"healthcare_provider_taxonomy_code_{i}"
    if cd in clean.columns:
        code_any = col(cd) if code_any is None else coalesce(code_any, col(cd))
if code_any is None:
    code_any = lit(None)

# computing completeness_score
clean = clean.withColumn(
    "completeness_score",
    (when(coalesce(col("primary_taxonomy"), code_any).isNotNull(), 1).otherwise(0)
     + when(col("state_abbr").isNotNull(), 1).otherwise(0)   
     + when(col("provider_enumeration_date").isNotNull(), 1).otherwise(0)
     + when(col("entity_type_code").isNotNull(), 1).otherwise(0))
)

In [20]:
# validating the NPIs, they should be 10 digits long
npi_regex = r'^\d{10}$'

# counting how many NPIs are malformed
invalid_npi_count = (
    clean.filter(~col("npi").rlike(npi_regex))
    .count()
)

print(f"Invalid NPIs found: {invalid_npi_count}")



Invalid NPIs found: 0


                                                                                

In [21]:
# deduplicating based off of latest "last_update_date" and then highest completeness score
w = Window.partitionBy("npi").orderBy(col("last_update_date").desc_nulls_last(), col("completeness_score").desc())
wrangled_npi_df = (
    clean
    .withColumn("rn", row_number().over(w))
    .filter(col("rn") == 1)
    .drop("rn")
)

In [22]:
# comparing to make sure I have all my changes

print(sorted(clean.columns))
print(sorted(wrangled_npi_df.columns))

['completeness_score', 'deactivated_then_reactivated', 'entity_type_code', 'has_location', 'healthcare_provider_primary_taxonomy_switch_1', 'healthcare_provider_primary_taxonomy_switch_10', 'healthcare_provider_primary_taxonomy_switch_11', 'healthcare_provider_primary_taxonomy_switch_12', 'healthcare_provider_primary_taxonomy_switch_13', 'healthcare_provider_primary_taxonomy_switch_14', 'healthcare_provider_primary_taxonomy_switch_15', 'healthcare_provider_primary_taxonomy_switch_2', 'healthcare_provider_primary_taxonomy_switch_3', 'healthcare_provider_primary_taxonomy_switch_4', 'healthcare_provider_primary_taxonomy_switch_5', 'healthcare_provider_primary_taxonomy_switch_6', 'healthcare_provider_primary_taxonomy_switch_7', 'healthcare_provider_primary_taxonomy_switch_8', 'healthcare_provider_primary_taxonomy_switch_9', 'healthcare_provider_taxonomy_code_1', 'healthcare_provider_taxonomy_code_10', 'healthcare_provider_taxonomy_code_11', 'healthcare_provider_taxonomy_code_12', 'healthca

In [23]:
# doing sanity checks after cleaning my NPI dataframe values
print("Distinct is_sole_proprietor values:")
wrangled_npi_df.select("is_sole_proprietor").distinct().show()


print("Primary taxonomy source distinct values:")
wrangled_npi_df.select("primary_taxonomy_source").distinct().show()


print("Null counts after transformation:")
nulls = wrangled_npi_df.select(*[
    sum_(when(col(c).isNull(), 1).otherwise(0)).alias(c)
    for c in wrangled_npi_df.columns
])
nulls.show(truncate=False)

Distinct is_sole_proprietor values:


                                                                                

+------------------+
|is_sole_proprietor|
+------------------+
|                 1|
|                 0|
|              NULL|
+------------------+

Primary taxonomy source distinct values:


                                                                                

+-----------------------+
|primary_taxonomy_source|
+-----------------------+
|               explicit|
|               fallback|
|                   NULL|
+-----------------------+

Null counts after transformation:




+---+----------------+------------------------------------------------------+-------------------------------------------------------+-----------------------+-----------------------+-----------------------+-------------------------+----------------+---------------------+---------------------+------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+-----------------------------------+---------------------------------------------+---------------------------------

                                                                                

#### The NPPES NPI dataset has almost been fully wrangled as "wrangled_npi_df"
#### Stopping here to wrangle others...

## Medicare Physician & Other Practitioners - by Provider and Service Dataset

In [24]:
# creating a csv path variable
phys_pract_csv_path = "Medicare Physician & Other Practitioners - by Provider and Service/MUP_PHY_R25_P05_V20_D23_Prov_Svc.csv"

In [25]:
# reading the CSV file into a DataFrame
full_physician_practitioner_df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "false")
    .csv(phys_pract_csv_path)
)

In [26]:
# inspecting the column names
print("Available columns (first 50 shown):")
for c in full_physician_practitioner_df.columns[:50]:
    print(repr(c))

Available columns (first 50 shown):
'Rndrng_NPI'
'Rndrng_Prvdr_Last_Org_Name'
'Rndrng_Prvdr_First_Name'
'Rndrng_Prvdr_MI'
'Rndrng_Prvdr_Crdntls'
'Rndrng_Prvdr_Ent_Cd'
'Rndrng_Prvdr_St1'
'Rndrng_Prvdr_St2'
'Rndrng_Prvdr_City'
'Rndrng_Prvdr_State_Abrvtn'
'Rndrng_Prvdr_State_FIPS'
'Rndrng_Prvdr_Zip5'
'Rndrng_Prvdr_RUCA'
'Rndrng_Prvdr_RUCA_Desc'
'Rndrng_Prvdr_Cntry'
'Rndrng_Prvdr_Type'
'Rndrng_Prvdr_Mdcr_Prtcptg_Ind'
'HCPCS_Cd'
'HCPCS_Desc'
'HCPCS_Drug_Ind'
'Place_Of_Srvc'
'Tot_Benes'
'Tot_Srvcs'
'Tot_Bene_Day_Srvcs'
'Avg_Sbmtd_Chrg'
'Avg_Mdcr_Alowd_Amt'
'Avg_Mdcr_Pymt_Amt'
'Avg_Mdcr_Stdzd_Amt'


In [27]:
# selecting the relevant columns I want to keep
keep_cols_physician_practitioners = [
    "Rndrng_NPI",                 
    "Rndrng_Prvdr_Ent_Cd",        
    "Rndrng_Prvdr_State_Abrvtn",  
    "Rndrng_Prvdr_Zip5",          
    "Rndrng_Prvdr_RUCA",          
    "Rndrng_Prvdr_RUCA_Desc",     
    "Rndrng_Prvdr_Mdcr_Prtcptg_Ind", 
    "Rndrng_Prvdr_Type",
    "Rndrng_Prvdr_Cntry",

    "HCPCS_Cd",                   
    "HCPCS_Desc",                 
    "HCPCS_Drug_Ind",             
    "Place_Of_Srvc",              
    "Tot_Benes",                  
    "Tot_Srvcs",                  
    "Tot_Bene_Day_Srvcs",         

    "Avg_Sbmtd_Chrg",             
    "Avg_Mdcr_Alowd_Amt",        
    "Avg_Mdcr_Pymt_Amt",          
    "Avg_Mdcr_Stdzd_Amt"        
]

In [28]:
# creating a new DataFrame with only the selected columns
phys_pract_df = full_physician_practitioner_df.select(*keep_cols_physician_practitioners)

In [29]:
# having a sanity check to ensure I loaded the right columns, the correct types, and the data looks okay
print("Schema:")
phys_pract_df.printSchema()
print("Sample:")
phys_pract_df.limit(5).show(truncate=False)

Schema:
root
 |-- Rndrng_NPI: string (nullable = true)
 |-- Rndrng_Prvdr_Ent_Cd: string (nullable = true)
 |-- Rndrng_Prvdr_State_Abrvtn: string (nullable = true)
 |-- Rndrng_Prvdr_Zip5: string (nullable = true)
 |-- Rndrng_Prvdr_RUCA: string (nullable = true)
 |-- Rndrng_Prvdr_RUCA_Desc: string (nullable = true)
 |-- Rndrng_Prvdr_Mdcr_Prtcptg_Ind: string (nullable = true)
 |-- Rndrng_Prvdr_Type: string (nullable = true)
 |-- Rndrng_Prvdr_Cntry: string (nullable = true)
 |-- HCPCS_Cd: string (nullable = true)
 |-- HCPCS_Desc: string (nullable = true)
 |-- HCPCS_Drug_Ind: string (nullable = true)
 |-- Place_Of_Srvc: string (nullable = true)
 |-- Tot_Benes: string (nullable = true)
 |-- Tot_Srvcs: string (nullable = true)
 |-- Tot_Bene_Day_Srvcs: string (nullable = true)
 |-- Avg_Sbmtd_Chrg: string (nullable = true)
 |-- Avg_Mdcr_Alowd_Amt: string (nullable = true)
 |-- Avg_Mdcr_Pymt_Amt: string (nullable = true)
 |-- Avg_Mdcr_Stdzd_Amt: string (nullable = true)

Sample:
+----------+----

In [30]:
# getting the total number of rows in the DataFrame
total_rows_phys_pract = phys_pract_df.count()
print(f'Total rows: {total_rows_phys_pract}')



Total rows: 9660647


                                                                                

In [31]:
# getting the null counts for each column
phys_pract_df.select(
    *[
        sum_(col(c).isNull().cast('int')).alias(c) for c in phys_pract_df.columns
    ]
).show()



+----------+-------------------+-------------------------+-----------------+-----------------+----------------------+-----------------------------+-----------------+------------------+--------+----------+--------------+-------------+---------+---------+------------------+--------------+------------------+-----------------+------------------+
|Rndrng_NPI|Rndrng_Prvdr_Ent_Cd|Rndrng_Prvdr_State_Abrvtn|Rndrng_Prvdr_Zip5|Rndrng_Prvdr_RUCA|Rndrng_Prvdr_RUCA_Desc|Rndrng_Prvdr_Mdcr_Prtcptg_Ind|Rndrng_Prvdr_Type|Rndrng_Prvdr_Cntry|HCPCS_Cd|HCPCS_Desc|HCPCS_Drug_Ind|Place_Of_Srvc|Tot_Benes|Tot_Srvcs|Tot_Bene_Day_Srvcs|Avg_Sbmtd_Chrg|Avg_Mdcr_Alowd_Amt|Avg_Mdcr_Pymt_Amt|Avg_Mdcr_Stdzd_Amt|
+----------+-------------------+-------------------------+-----------------+-----------------+----------------------+-----------------------------+-----------------+------------------+--------+----------+--------------+-------------+---------+---------+------------------+--------------+------------------+------

                                                                                

In [32]:
# comoputing normalized names and checking for collisions
# checking for collisions as I don't want to accidently end up with two originals mapping to the same cleaned name

# applying the normalized names to every column I plan to keep
normalized_names = []
for original_name in keep_cols_physician_practitioners:
    new_name = normalize_col(original_name)
    normalized_names.append(new_name)

# counting how many times each cleaned name shows up
name_counts = Counter(normalized_names)

# pulling out any names that show up more than once
duplicates = [name for name, count in name_counts.items() if count > 1]

# bailing out early, if a coliision is detected
if duplicates:
    raise RuntimeError(f'Column name conflict after normalization: {duplicates}')

In [33]:
# building a new DataFrame with normalized column names

# building a list of expressions to rename the columns in the DataFrame
renamed_columns_phys_pract = [
    col(old).alias(new)
    for old, new in zip(keep_cols_physician_practitioners, normalized_names)
]

# applying 
normalized_phys_pract_df = phys_pract_df.select(*renamed_columns_phys_pract)

In [34]:
# doing a quick sanity check to ensure the new DataFrame has the correct columns
normalized_phys_pract_df.printSchema()
normalized_phys_pract_df.show(5, truncate=False)

root
 |-- rndrng_npi: string (nullable = true)
 |-- rndrng_prvdr_ent_cd: string (nullable = true)
 |-- rndrng_prvdr_state_abrvtn: string (nullable = true)
 |-- rndrng_prvdr_zip5: string (nullable = true)
 |-- rndrng_prvdr_ruca: string (nullable = true)
 |-- rndrng_prvdr_ruca_desc: string (nullable = true)
 |-- rndrng_prvdr_mdcr_prtcptg_ind: string (nullable = true)
 |-- rndrng_prvdr_type: string (nullable = true)
 |-- rndrng_prvdr_cntry: string (nullable = true)
 |-- hcpcs_cd: string (nullable = true)
 |-- hcpcs_desc: string (nullable = true)
 |-- hcpcs_drug_ind: string (nullable = true)
 |-- place_of_srvc: string (nullable = true)
 |-- tot_benes: string (nullable = true)
 |-- tot_srvcs: string (nullable = true)
 |-- tot_bene_day_srvcs: string (nullable = true)
 |-- avg_sbmtd_chrg: string (nullable = true)
 |-- avg_mdcr_alowd_amt: string (nullable = true)
 |-- avg_mdcr_pymt_amt: string (nullable = true)
 |-- avg_mdcr_stdzd_amt: string (nullable = true)

+----------+-------------------+

In [35]:
# cleaning the values of the physician practitioner DataFrame

clean_phys_pract = (
    normalized_phys_pract_df

    # trim stray spaces and making sure the NPIs are exactly 10 digits
    .withColumn("rndrng_npi", trim(col("rndrng_npi")))
    .withColumn("npi_valid",
        when(col("rndrng_npi").rlike(r"^\d{10}$"), lit(1))
        .otherwise(lit(0))
    )

    # mapping the provider type: 
        # “I” → individual (1), everything else → org (0)
    .withColumn("is_individual",
        when(col("rndrng_prvdr_ent_cd") == "I", lit(1))
        .otherwise(lit(0))
    )

    # tidying up geographic strings
    .withColumn("rndrng_prvdr_state_abrvtn", trim(col("rndrng_prvdr_state_abrvtn")))
    .withColumn("rndrng_prvdr_zip5",          trim(col("rndrng_prvdr_zip5")))

    # flagging if the state is missing
    .withColumn("missing_state",
        when(col("rndrng_prvdr_state_abrvtn").isNull(), lit(1)).otherwise(lit(0))
    )

    # flagging if the zip code is missing
    .withColumn("missing_zip",
        when(col("rndrng_prvdr_zip5").isNull(), lit(1)).otherwise(lit(0))
    )

    # casting RUCA (rural‐urban code) to double type
    .withColumn("rndrng_prvdr_ruca", col("rndrng_prvdr_ruca").cast(DoubleType())
    )

    # medicare‐participation Y/N → 1/0
    .withColumn("medicare_participation",
        when(col("rndrng_prvdr_mdcr_prtcptg_ind") == "Y", lit(1))
        .otherwise(lit(0))
    )

    # clean HCPCS and place‐of‐service keys
    .withColumn("hcpcs_cd",     trim(col("hcpcs_cd")))
    .withColumn("place_of_srvc", trim(col("place_of_srvc")))


    # cast dollar amounts so ratios and averages behave
    .withColumn("avg_sbmtd_chrg",    col("avg_sbmtd_chrg").cast(DoubleType()))
    .withColumn("avg_mdcr_alowd_amt",col("avg_mdcr_alowd_amt").cast(DoubleType()))
    .withColumn("avg_mdcr_pymt_amt", col("avg_mdcr_pymt_amt").cast(DoubleType()))
    .withColumn("avg_mdcr_stdzd_amt",col("avg_mdcr_stdzd_amt").cast(DoubleType()))

    # drug flag: “Y” → 1, else → 0
    .withColumn("is_drug",
        when(col("hcpcs_drug_ind") == "Y", lit(1)).otherwise(lit(0))
    )
)

In [36]:
# keeping only United States rows
clean_phys_pract = (clean_phys_pract
    .withColumn("rndrng_prvdr_cntry", upper(trim(col("rndrng_prvdr_cntry"))))
    .filter(col("rndrng_prvdr_cntry").isNull() | col("rndrng_prvdr_cntry").isin("US","USA"))
)

In [37]:
# checking to make sure my cleaning worked
clean_phys_pract.select("hcpcs_drug_ind", "is_drug").distinct().show()



+--------------+-------+
|hcpcs_drug_ind|is_drug|
+--------------+-------+
|             Y|      1|
|             N|      0|
+--------------+-------+



                                                                                

In [38]:
# few other santity checks

# no “bad” NPIs should pass
clean_phys_pract.filter(col("npi_valid") == 0).count()

# ensuring dollar‐amount columns have no unexpected negatives or huge outliers
clean_phys_pract.select(
  min("avg_mdcr_alowd_amt"),
  max("avg_mdcr_alowd_amt"),
  mean("avg_mdcr_alowd_amt")
).show()

# checking that my “missing” flags line up with nulls
clean_phys_pract.filter(col("missing_zip") == 1).select("rndrng_prvdr_zip5").show()


                                                                                

+-----------------------+-----------------------+-----------------------+
|min(avg_mdcr_alowd_amt)|max(avg_mdcr_alowd_amt)|avg(avg_mdcr_alowd_amt)|
+-----------------------+-----------------------+-----------------------+
|                    0.0|           52794.040667|       106.107469676975|
+-----------------------+-----------------------+-----------------------+





+-----------------+
|rndrng_prvdr_zip5|
+-----------------+
+-----------------+



                                                                                

##### Because of the nature of this dataset, I will not be doing a traditional deduping process
##### This dataset is SUPPOSED to have multiple lines per provider
##### I will be checking for true duplicates

In [39]:
# safe-casting to make sure my volume columns are really integers
safe_phys = (
    clean_phys_pract
    .withColumn("tot_benes", expr("try_cast(tot_benes AS INT)")) 
    .withColumn("tot_srvcs", expr("try_cast(tot_srvcs AS INT)")) 
    .withColumn("tot_bene_day_srvcs", expr("try_cast(tot_bene_day_srvcs AS INT)"))
)

# checking for exact duplicates across the key combinations
    # (provider + procedure + place + drug flag)
total_rows = safe_phys.count()

unique_rows = (
    safe_phys 
    .dropDuplicates([
        "rndrng_npi",
        "hcpcs_cd",
        "place_of_srvc",
        "is_drug"
    ]) 
    .count()
)

duplicate_rows = total_rows - unique_rows


print(f"Total rows:{total_rows:,}")
print(f"Unique prov×proc×place×drug:{unique_rows:,}")
print(f"Exact duplicate rows:{duplicate_rows:,}")




Total rows:9,660,252
Unique prov×proc×place×drug:9,660,252
Exact duplicate rows:0


                                                                                

In [40]:
# aggregating by provider for modeling purposes
# more reasoning -> the GNN model will be using the provider as a node, so I want to aggregate all the features by provider

# making sure my per-row volumes are truly DOUBLE
clean_for_agg = (
    clean_phys_pract
    .filter(col("npi_valid") == 1)   # <<< ADD THIS
    .withColumn("srvcs_d", col("tot_srvcs").cast("double"))
    .withColumn("benes_d", col("tot_benes").cast("double"))
    .withColumn("bene_days_d", col("tot_bene_day_srvcs").cast("double"))
)

clean_for_agg.count()

# grouping by NPI and rolling up all the per-procedure metrics into provider totals/averages

# aggregate weighted sums
agg_w = (
    clean_for_agg
    .groupBy("rndrng_npi")
    .agg(
        # volumes
        sum_("srvcs_d").alias("total_services"),
        sum_("benes_d").alias("total_beneficiaries"),
        sum_("bene_days_d").alias("total_bene_day_services"),

        # service-weighted sums for charges/allowed/payments
        sum_(col("avg_sbmtd_chrg") * col("srvcs_d")).alias("sum_submitted_w"),
        sum_(col("avg_mdcr_alowd_amt") * col("srvcs_d")).alias("sum_allowed_w"),
        sum_(col("avg_mdcr_pymt_amt") * col("srvcs_d")).alias("sum_payment_w"),

        # diversity & quality signals
        countDistinct("hcpcs_cd").alias("num_unique_procedures"),
        stddev(col("avg_sbmtd_chrg")).alias("stddev_submitted_charge"),
        avg(col("is_drug")).alias("frac_drug_services"),
        avg(col("missing_zip")).alias("frac_missing_zip")
    )
)

# deriving weighted averages & ratios of totals
provider_agg = (
    agg_w
    .withColumn("w_avg_submitted_charge",
                when(col("total_services") > 0, col("sum_submitted_w") / col("total_services")))
    .withColumn("w_avg_allowed",
                when(col("total_services") > 0, col("sum_allowed_w") / col("total_services")))
    .withColumn("w_avg_payment",
                when(col("total_services") > 0, col("sum_payment_w") / col("total_services")))
    .withColumn("charge_allowed_ratio",
                when(col("sum_allowed_w") > 0, col("sum_submitted_w") / col("sum_allowed_w")))
    .withColumn("payment_allowed_ratio",
                when(col("sum_allowed_w") > 0, col("sum_payment_w") / col("sum_allowed_w")))
    .drop("sum_submitted_w", "sum_allowed_w", "sum_payment_w")
    .withColumnRenamed("rndrng_npi", "npi")  # unify key name for joins later
)


                                                                                

In [41]:
# adding some derived features to provider_agg
provider_agg = (provider_agg
    .withColumn("services_per_bene",
        when(col("total_beneficiaries") > 0,
             col("total_services") / col("total_beneficiaries")))
    .withColumn("bene_days_per_bene",
        when(col("total_beneficiaries") > 0,
             col("total_bene_day_services") / col("total_beneficiaries")))
)

In [42]:
# doing a sanity check on the aggregation
provider_agg.printSchema()
provider_agg.show(5, truncate=False)

root
 |-- npi: string (nullable = true)
 |-- total_services: double (nullable = true)
 |-- total_beneficiaries: double (nullable = true)
 |-- total_bene_day_services: double (nullable = true)
 |-- num_unique_procedures: long (nullable = false)
 |-- stddev_submitted_charge: double (nullable = true)
 |-- frac_drug_services: double (nullable = true)
 |-- frac_missing_zip: double (nullable = true)
 |-- w_avg_submitted_charge: double (nullable = true)
 |-- w_avg_allowed: double (nullable = true)
 |-- w_avg_payment: double (nullable = true)
 |-- charge_allowed_ratio: double (nullable = true)
 |-- payment_allowed_ratio: double (nullable = true)
 |-- services_per_bene: double (nullable = true)
 |-- bene_days_per_bene: double (nullable = true)





+----------+--------------+-------------------+-----------------------+---------------------+-----------------------+-------------------+----------------+----------------------+------------------+------------------+--------------------+---------------------+------------------+------------------+
|npi       |total_services|total_beneficiaries|total_bene_day_services|num_unique_procedures|stddev_submitted_charge|frac_drug_services |frac_missing_zip|w_avg_submitted_charge|w_avg_allowed     |w_avg_payment     |charge_allowed_ratio|payment_allowed_ratio|services_per_bene |bene_days_per_bene|
+----------+--------------+-------------------+-----------------------+---------------------+-----------------------+-------------------+----------------+----------------------+------------------+------------------+--------------------+---------------------+------------------+------------------+
|1003002254|2433.0        |788.0              |880.0                  |8                    |84.4099751660698

                                                                                

In [43]:
# checking the total after the aggregation
total_providers = provider_agg.count()
print(f"Total providers after aggregation: {total_providers:,}")



Total providers after aggregation: 1,175,213


                                                                                

In [44]:
# doing some more sanity checks


# checking the ranges
print("=== Range summary ===")
provider_agg.select(
    min("total_services").alias("min_services"),
    max("total_services").alias("max_services"),
    min("total_beneficiaries").alias("min_beneficiaries"),
    max("total_beneficiaries").alias("max_beneficiaries"),
    min("total_bene_day_services").alias("min_bene_day_services"),
    max("total_bene_day_services").alias("max_bene_day_services"),
    min("charge_allowed_ratio").alias("min_charge_ratio"),
    max("charge_allowed_ratio").alias("max_charge_ratio"),
    min("payment_allowed_ratio").alias("min_payment_ratio"),
    max("payment_allowed_ratio").alias("max_payment_ratio")
).show(truncate=False)

# checking for any negatives or absurd outliers
provider_agg.filter(
    (col("total_services") < 0) |
    (col("total_beneficiaries") < 0) |
    (col("total_bene_day_services") < 0) |
    (col("charge_allowed_ratio") < 0) |
    (col("payment_allowed_ratio") < 0) |
    (col("charge_allowed_ratio") > 100) |
    (col("payment_allowed_ratio") > 100)
).select(
    "npi",
    "total_services",
    "total_beneficiaries",
    "total_bene_day_services",
    "charge_allowed_ratio",
    "payment_allowed_ratio"
).show(10, truncate=False)


# checking frac_drug_services
print("=== Distinct frac_drug_services values ===")
provider_agg.select("frac_drug_services").distinct().orderBy("frac_drug_services").show()

# checking how many providers have zero drug‐share
zero_drug_count = provider_agg.filter(col("frac_drug_services") == 0.0).count()
print(f"Providers with frac_drug_services = 0.0: {zero_drug_count:,}")


=== Range summary ===


                                                                                

+------------+------------+-----------------+-----------------+---------------------+---------------------+------------------+----------------+-----------------+-----------------+
|min_services|max_services|min_beneficiaries|max_beneficiaries|min_bene_day_services|max_bene_day_services|min_charge_ratio  |max_charge_ratio|min_payment_ratio|max_payment_ratio|
+------------+------------+-----------------+-----------------+---------------------+---------------------+------------------+----------------+-----------------+-----------------+
|11.0        |1.6019747E7 |11.0             |1.0403747E7      |11.0                 |1.5177802E7          |0.8987661160749534|10638.461538    |0.0              |1.0              |
+------------+------------+-----------------+-----------------+---------------------+---------------------+------------------+----------------+-----------------+-----------------+



                                                                                

+----------+--------------+-------------------+-----------------------+--------------------+---------------------+
|npi       |total_services|total_beneficiaries|total_bene_day_services|charge_allowed_ratio|payment_allowed_ratio|
+----------+--------------+-------------------+-----------------------+--------------------+---------------------+
|1063551851|169.0         |137.0              |137.0                  |143.4770992393912   |0.7967357200377378   |
|1083682405|14.0          |14.0               |14.0                   |101.85851588546967  |0.8004288306050098   |
|1154375772|1540.0        |1140.0             |1157.0                 |141.13808279686486  |0.7842014118514078   |
|1215262720|1248.0        |1240.0             |1247.0                 |105.33816398391602  |0.7897113758618032   |
|1235447517|1582.0        |1209.0             |1232.0                 |142.30543513101554  |0.7853682283714345   |
|1225281827|25.0          |25.0               |25.0                   |127.46577

                                                                                

+--------------------+
|  frac_drug_services|
+--------------------+
|                 0.0|
|0.007246376811594203|
|0.007462686567164179|
|0.007692307692307693|
|0.007751937984496124|
|0.007874015748031496|
|0.007936507936507936|
|               0.008|
|0.008064516129032258|
| 0.00819672131147541|
|0.008264462809917356|
| 0.00847457627118644|
|0.008620689655172414|
|0.008695652173913044|
|0.008771929824561403|
|0.008849557522123894|
|0.008928571428571428|
|0.009009009009009009|
| 0.00909090909090909|
|0.009259259259259259|
+--------------------+
only showing top 20 rows




Providers with frac_drug_services = 0.0: 953,848


                                                                                

## 'List of Excluded Individuals and Entities' (LEIE) - OIG Dataset

In [45]:
# creating a csv path variable
leie_csv_path = "Office of Inspector General - Excluded Individuals and Entities/20250710 LEIE.csv"

In [46]:
# reading the CSV file into a DataFrame
full_leie_df = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "false")
    .csv(leie_csv_path)
)

In [47]:
# inspecting the column names
print("Available columns:")
for c in full_leie_df.columns[:]:
    print(repr(c))

Available columns:
'LASTNAME'
'FIRSTNAME'
'MIDNAME'
'BUSNAME'
'GENERAL'
'SPECIALTY'
'UPIN'
'NPI'
'DOB'
'ADDRESS'
'CITY'
'STATE'
'ZIP'
'EXCLTYPE'
'EXCLDATE'
'REINDATE'
'WAIVERDATE'
'WVRSTATE'


In [48]:
# selecting the relevant columns I want to keep
keep_cols_leie = [
    "NPI",        
    "LASTNAME",   
    "FIRSTNAME",  
    "MIDNAME",
    "BUSNAME",    
    "ADDRESS",    
    "CITY",       
    "STATE",      
    "ZIP",        
    "SPECIALTY",  
    "EXCLTYPE",   
    "EXCLDATE",   
    "REINDATE",  
    "DOB",
    "WAIVERDATE",
    "WVRSTATE",
    "UPIN"
]

In [49]:
# creating a new DataFrame with only the selected columns
leie_df = full_leie_df.select(*keep_cols_leie)

In [50]:
# having a sanity check to ensure I loaded the right columns, the correct types, and the data looks okay
print("Schema:")
leie_df.printSchema()
print("Sample:")
leie_df.limit(5).show(truncate=False)

Schema:
root
 |-- NPI: string (nullable = true)
 |-- LASTNAME: string (nullable = true)
 |-- FIRSTNAME: string (nullable = true)
 |-- MIDNAME: string (nullable = true)
 |-- BUSNAME: string (nullable = true)
 |-- ADDRESS: string (nullable = true)
 |-- CITY: string (nullable = true)
 |-- STATE: string (nullable = true)
 |-- ZIP: string (nullable = true)
 |-- SPECIALTY: string (nullable = true)
 |-- EXCLTYPE: string (nullable = true)
 |-- EXCLDATE: string (nullable = true)
 |-- REINDATE: string (nullable = true)
 |-- DOB: string (nullable = true)
 |-- WAIVERDATE: string (nullable = true)
 |-- WVRSTATE: string (nullable = true)
 |-- UPIN: string (nullable = true)

Sample:
+----------+--------+---------+-------+---------------------------+-----------------------------+----------+-----+-----+------------------+--------+--------+--------+----+----------+--------+----+
|NPI       |LASTNAME|FIRSTNAME|MIDNAME|BUSNAME                    |ADDRESS                      |CITY      |STATE|ZIP  |SPECIA

In [51]:
# getting the total number of rows in the DataFrame
total_rows_leie = leie_df.count()
print(f'Total rows: {total_rows_leie}')

Total rows: 81774


In [52]:
# getting the null counts for each column
leie_df.select(
    *[sum_(col(c).isNull().cast('int')).alias(c) for c in leie_df.columns]
).show()

+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+
|NPI|LASTNAME|FIRSTNAME|MIDNAME|BUSNAME|ADDRESS|CITY|STATE|ZIP|SPECIALTY|EXCLTYPE|EXCLDATE|REINDATE| DOB|WAIVERDATE|WVRSTATE| UPIN|
+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+
|  0|    3358|     3358|  24009|  78416|      4|   0|    0|  0|     4088|       0|       0|       0|4222|         0|   81762|75791|
+---+--------+---------+-------+-------+-------+----+-----+---+---------+--------+--------+--------+----+----------+--------+-----+



In [53]:
# comoputing normalized names and checking for collisions
# checking for collisions as I don't want to accidently end up with two originals mapping to the same cleaned name

# applying the normalized names to every column I plan to keep
normalized_names = []
for original_name in keep_cols_leie:
    new_name = normalize_col(original_name)
    normalized_names.append(new_name)

# counting how many times each cleaned name shows up
name_counts = Counter(normalized_names)

# pulling out any names that show up more than once
duplicates = [name for name, count in name_counts.items() if count > 1]

# bailing out early, if a coliision is detected
if duplicates:
    raise RuntimeError(f'Column name conflict after normalization: {duplicates}')

In [54]:
# building a new DataFrame with normalized column names

# building a list of expressions to rename the columns in the DataFrame
renamed_columns_leie = [
    col(old).alias(new)
    for old, new in zip(keep_cols_leie, normalized_names)
]

# applying 
normalized_leie = leie_df.select(*renamed_columns_leie)

In [55]:
# doing a quick sanity check to ensure the new DataFrame has the correct columns
normalized_leie.printSchema()
normalized_leie.show(5, truncate=False)

root
 |-- npi: string (nullable = true)
 |-- lastname: string (nullable = true)
 |-- firstname: string (nullable = true)
 |-- midname: string (nullable = true)
 |-- busname: string (nullable = true)
 |-- address: string (nullable = true)
 |-- city: string (nullable = true)
 |-- state: string (nullable = true)
 |-- zip: string (nullable = true)
 |-- specialty: string (nullable = true)
 |-- excltype: string (nullable = true)
 |-- excldate: string (nullable = true)
 |-- reindate: string (nullable = true)
 |-- dob: string (nullable = true)
 |-- waiverdate: string (nullable = true)
 |-- wvrstate: string (nullable = true)
 |-- upin: string (nullable = true)

+----------+--------+---------+-------+---------------------------+-----------------------------+----------+-----+-----+------------------+--------+--------+--------+----+----------+--------+----+
|npi       |lastname|firstname|midname|busname                    |address                      |city      |state|zip  |specialty         |exc

In [56]:
# moving onto the cleaning values step of the LEIE DataFrame

leie = normalized_leie


# trimming strings and converting empty strings to nulls
for c, t in leie.dtypes:
    if t == "string":
        leie = leie.withColumn(c, when(trim(col(c)) == "", None).otherwise(trim(col(c))))

In [57]:
# keeping NPI and ZIP as strings, but normalizing them
leie = (leie
    .withColumn("npi", regexp_replace(col("npi"), r"[^0-9]", ""))
    .withColumn("npi_valid", when(col("npi").rlike(r"^(?!0{10})\d{10}$"), 1).otherwise(0).cast(IntegerType()))
    .withColumn("zip5", substring(regexp_replace(col("zip"), r"[^0-9]", ""), 1, 5))
    .drop("zip")
)

In [58]:
# standardizing codes/geo (UPPER only where it matters)
leie = (leie
    .withColumn("state", when(col("state").rlike(r"^[A-Za-z]{2}$"), upper(col("state"))).otherwise(lit(None)))
    .withColumnRenamed("wvrstate", "waiverstate")
    .withColumn("waiverstate", when(col("waiverstate").rlike(r"^[A-Za-z]{2}$"), upper(col("waiverstate"))).otherwise(lit(None)))
    .withColumn("excltype", upper(col("excltype")))         # canonicalize code
    .withColumn("specialty", upper(col("specialty")))       # optional but helps grouping
)

In [59]:
# --- CLEAN & PARSE DATES SAFELY (no try_to_date; tolerant to junk like 00000000) ---
from pyspark.sql.functions import (
    col, lit, when, coalesce, to_date, regexp_replace, length
)

def zeroish(colname):
    # keep only digits; null if it's 00000000 or empty
    digits = regexp_replace(col(colname), r"[^0-9]", "")
    return when((digits == "00000000") | (length(digits) == 0), lit(None)).otherwise(col(colname))

def parse_date_guarded(colname: str):
    raw = zeroish(colname)
    digits = regexp_replace(raw, r"[^0-9]", "")

    # 1) yyyyMMdd when exactly 8 digits
    ymd_compact = when(digits.rlike(r"^\d{8}$"), to_date(digits, "yyyyMMdd"))

    # 2) M/d/yyyy (handles 1-2 digit month/day)
    mdyyyy = when(raw.rlike(r"^\d{1,2}/\d{1,2}/\d{4}$"), to_date(raw, "M/d/yyyy"))

    # 3) yyyy-MM-dd
    ymd_dash = when(raw.rlike(r"^\d{4}-\d{2}-\d{2}$"), to_date(raw, "yyyy-MM-dd"))

    # first successful parse wins; else NULL
    return coalesce(ymd_compact, mdyyyy, ymd_dash)

leie = (leie
    .withColumn("excldate_dt",   parse_date_guarded("excldate"))
    .withColumn("reindate_dt",   parse_date_guarded("reindate"))
    .withColumn("waiverdate_dt", parse_date_guarded("waiverdate"))
    .withColumn("dob_dt",        parse_date_guarded("dob"))
    .drop("excldate","reindate","waiverdate","dob")
)



# --- ORG FLAG + EXCLUSION FEATURES (single, consolidated block) ---
leie = leie.withColumn(
    "is_org",
    when(trim(col("busname")) != "", 1).otherwise(0).cast(IntegerType())
)

as_of_feat = current_date()  # for feature engineering (your label uses AS_OF later)

leie = (leie
    .withColumn("has_excl", (col("excldate_dt").isNotNull()).cast(IntegerType()))
    .withColumn("exclusion_start", col("excldate_dt"))
    .withColumn("exclusion_end_effective", coalesce(col("reindate_dt"), as_of_feat))

    # safe duration (null if start missing or negative)
    .withColumn("exclusion_duration_days_raw",
        datediff(col("exclusion_end_effective"), col("exclusion_start"))
    )
    .withColumn("exclusion_duration_days",
        when(col("excldate_dt").isNotNull() & (col("exclusion_duration_days_raw") >= 0),
             col("exclusion_duration_days_raw")
        ).otherwise(lit(None)).cast(IntegerType())
    ).drop("exclusion_duration_days_raw")

    .withColumn(
        "is_currently_excluded",
        when(
            col("excldate_dt").isNotNull() &
            col("waiverdate_dt").isNull() &
            (col("reindate_dt").isNull() | (col("reindate_dt") > as_of_feat)),
            1
        ).otherwise(0).cast(IntegerType())
    )
)


In [60]:
# Confirm no zero-ish values slipped through (sanity check)
for raw in ["excldate_dt","reindate_dt","waiverdate_dt","dob_dt"]:
    cnt = leie.filter(col(raw).isNull()).count()  # null is expected for bad inputs; just ensuring parse ran
    print(raw, "null_count_after_parse:", cnt)


excldate_dt null_count_after_parse: 7
reindate_dt null_count_after_parse: 81767
waiverdate_dt null_count_after_parse: 81770
dob_dt null_count_after_parse: 4222


In [61]:
# aggregating and making one row per NPI
leie_with_npi = leie.filter(col("npi_valid") == 1)

leie_by_npi = (leie_with_npi
    .groupBy("npi")
    .agg(
        countDistinct("excldate_dt").cast(IntegerType()).alias("exclusion_count"),
        min("excldate_dt").cast(DateType()).alias("first_excldate"),
        max("excldate_dt").cast(DateType()).alias("most_recent_excldate"),
        max("is_currently_excluded").cast(IntegerType()).alias("is_currently_excluded"),
        max("waiverdate_dt").cast(DateType()).alias("last_waiverdate"),
        max("reindate_dt").cast(DateType()).alias("last_reindate")
    )
)

In [None]:
# --- Build npi_names from full_npi_df using available name/org columns (drop this BEFORE "Cell A") ---
from pyspark.sql import functions as F
from pyspark.sql.functions import col, trim, regexp_replace, upper, split, array_remove, when, substring

# ordered fallbacks based on your inspection
last_name_fallbacks = [
    "Provider Last Name (Legal Name)", "Provider Other Last Name", "Provider Other Last Name Type Code", "provider_last_name", "last_name"
]
first_name_fallbacks = [
    "Provider First Name", "Provider Other First Name", "provider_first_name", "first_name"
]
org_name_fallbacks = [
    "Provider Organization Name (Legal Business Name)",
    "Provider Other Organization Name",
    "Provider Business Practice Location Name",
    "parent_organization_lbn",
    "provider_organization_name",
    "organization_name"
]
state_fallbacks = [
    "Provider Business Practice Location Address State Name",
    "Provider Business Mailing Address State Name",
    "Provider Business Mailing Address State Name"
]
zip_fallbacks = [
    "Provider Business Practice Location Address Postal Code",
    "Provider Business Mailing Address Postal Code"
]

def pick_first_existing(col_list, df_cols):
    for c in col_list:
        if c in df_cols:
            return c
    return None

last_col = pick_first_existing(last_name_fallbacks, full_npi_df.columns)
first_col = pick_first_existing(first_name_fallbacks, full_npi_df.columns)
org_col = pick_first_existing(org_name_fallbacks, full_npi_df.columns)
state_col = pick_first_existing(state_fallbacks, full_npi_df.columns)
zip_col = pick_first_existing(zip_fallbacks, full_npi_df.columns)

print("Selected columns for npi_names -> last:", last_col, "first:", first_col, "org:", org_col, "state:", state_col, "zip:", zip_col)

# helper normalizer for person names
def norm_person_expr(colname):
    # uppercase, remove non-alnum/space, collapse multi-space, trim
    return upper(regexp_replace(F.coalesce(col(colname), F.lit("")), r'[^A-Z0-9\s]', ''))

# helper normalizer for organization names (strip common suffixes to improve token overlap)
org_suffix_rx = r'\b(LLC|L\.L\.C\.|L L C|INC|CORP|CORPORATION|PC|P\.C\.|P C|P\.A\.|P A|PA|PLLC|LLP|LTD|ASSOCIATES|ASSOC|MEDICAL|HOSPITAL|INSTITUTE|GROUP|SERVICES|SERVICE|CENTER|CENTRE)\b'
def norm_org_expr(colname):
    base = upper(regexp_replace(F.coalesce(col(colname), F.lit("")), r'[^A-Z0-9\s]', ''))
    no_suffix = regexp_replace(base, org_suffix_rx, '')         # remove suffix words
    no_extra_spaces = regexp_replace(no_suffix, r'\s+', ' ')    # collapse spaces
    return F.trim(no_extra_spaces)

# Build the base npi_names DataFrame, using coalesce fallbacks for safety
npi_names = (
    full_npi_df
    .withColumn("npi", regexp_replace(trim(col("NPI")), r"\D", ""))
    .withColumn("npi", when(col("npi") == "", None).otherwise(col("npi")))
    .withColumn("npi_last_raw",
                F.coalesce(col(last_col) if last_col is not None else F.lit(""),
                           col("Provider Other Last Name") if "Provider Other Last Name" in full_npi_df.columns else F.lit(""),
                           F.lit("")))
    .withColumn("npi_first_raw",
                F.coalesce(col(first_col) if first_col is not None else F.lit(""),
                           col("Provider Other First Name") if "Provider Other First Name" in full_npi_df.columns else F.lit(""),
                           F.lit("")))
    .withColumn("npi_org_raw",
                F.coalesce(col(org_col) if org_col is not None else F.lit(""),
                           col("Provider Other Organization Name") if "Provider Other Organization Name" in full_npi_df.columns else F.lit(""),
                           F.lit("")))
    # optional geo fields if present
    .withColumn("npi_state",
                when(col(state_col).rlike(r"^[A-Za-z]{2}$"), upper(col(state_col))).otherwise(None) if state_col is not None else F.lit(None))
    .withColumn("npi_zip5",
                substring(regexp_replace(col(zip_col), r"[^0-9]", ""), 1, 5) if zip_col is not None else F.lit(None))
    # normalized forms
    .withColumn("npi_last_norm", norm_person_expr("npi_last_raw"))
    .withColumn("npi_first_norm", norm_person_expr("npi_first_raw"))
    .withColumn("npi_org_norm", norm_org_expr("npi_org_raw"))
    # tokens
    .withColumn("npi_last_tokens", array_remove(split(col("npi_last_norm"), r"\s+"), ""))
    .withColumn("npi_first_tokens", array_remove(split(col("npi_first_norm"), r"\s+"), ""))
    .withColumn("npi_org_tokens", array_remove(split(col("npi_org_norm"), r"\s+"), ""))
    .select("npi", "npi_state", "npi_zip5", "npi_last_norm", "npi_first_norm", "npi_org_norm",
            "npi_last_tokens", "npi_first_tokens", "npi_org_tokens")
    .cache()
)

print("Built npi_names rows:", npi_names.count())
npi_names.show(8, truncate=120)


In [None]:
# Cell A: build normalized name/location slices for matching
from pyspark.sql import functions as F
from pyspark.sql.functions import col, trim, upper, regexp_replace, substring, split, array_remove

# sanity checks
if "leie" not in globals():
    raise RuntimeError("leie DF not found — run the LEIE cleaning cell first.")
if "wrangled_npi_df" not in globals():
    raise RuntimeError("wrangled_npi_df not found — run the NPPES wrangle cells first.")

# helper normalizers (keep letters/numbers and collapse spaces)
def norm_text_expr(c):
    return F.upper(F.regexp_replace(F.coalesce(c, F.lit("")), r'[^A-Z0-9\s]', ''))

# LEIE name tokens
leie_names = (
    leie
    .select(
        col("npi").alias("leie_npi"),
        norm_text_expr(col("lastname")).alias("leie_last_norm"),
        norm_text_expr(col("firstname")).alias("leie_first_norm"),
        norm_text_expr(col("busname")).alias("leie_org_norm"),
        col("state").alias("leie_state"),
        col("zip5").alias("leie_zip5")
    )
    .withColumn("leie_last_tokens", array_remove(split(F.col("leie_last_norm"), r"\s+"), ""))
    .withColumn("leie_first_tokens", array_remove(split(F.col("leie_first_norm"), r"\s+"), ""))
    .withColumn("leie_org_tokens", array_remove(split(F.col("leie_org_norm"), r"\s+"), ""))
    .cache()
)


print("NPI name slices:", npi_names.count(), "rows (cached).")


In [None]:
# Cell B: candidate generation using cheap blocking (state/zip/org/lastname)
from pyspark.sql import functions as F
from pyspark.sql.functions import col, array_intersect, size

# Parameters: you can tune these
MAX_PER_BLOCK_SAMPLE = None   # if you want to sample each block for speed, set integer like 20000
# Blocking strategies:
# 1) exact npi join already handled in your main flow (explicit).
# 2) state + zip5 join (tight geo)
c1 = leie_names.alias("L").join(
    npi_names.alias("N"),
    ( (F.coalesce(col("L.leie_state"), F.lit("")) == F.coalesce(col("N.npi_state"), F.lit(""))) &
      (F.coalesce(col("L.leie_zip5"), F.lit("")) == F.coalesce(col("N.npi_zip5"), F.lit("")))
    ),
    how="inner"
).select(
    col("L.leie_npi"), col("L.leie_last_norm"), col("L.leie_first_norm"), col("L.leie_org_norm"),
    col("L.leie_last_tokens"), col("L.leie_first_tokens"), col("L.leie_org_tokens"),
    col("N.npi"), col("N.npi_state"), col("N.npi_zip5"),
    col("N.npi_last_norm"), col("N.npi_first_norm"), col("N.npi_org_norm"),
    col("N.npi_last_tokens"), col("N.npi_first_tokens"), col("N.npi_org_tokens")
)

# 3) state + org-token overlap (contracts/org matches)
c2 = leie_names.alias("L").join(
    npi_names.alias("N"),
    ( (F.coalesce(col("L.leie_state"), F.lit("")) == F.coalesce(col("N.npi_state"), F.lit(""))) &
      (size(array_intersect(col("L.leie_org_tokens"), col("N.npi_org_tokens"))) > 0)
    ),
    how="inner"
).select(
    col("L.leie_npi"), col("L.leie_last_norm"), col("L.leie_first_norm"), col("L.leie_org_norm"),
    col("L.leie_last_tokens"), col("L.leie_first_tokens"), col("L.leie_org_tokens"),
    col("N.npi"), col("N.npi_state"), col("N.npi_zip5"),
    col("N.npi_last_norm"), col("N.npi_first_norm"), col("N.npi_org_norm"),
    col("N.npi_last_tokens"), col("N.npi_first_tokens"), col("N.npi_org_tokens")
)

# 4) last name exact + state block (last name + first initial)
c3 = leie_names.alias("L").join(
    npi_names.alias("N"),
    ( (F.coalesce(col("L.leie_state"), F.lit("")) == F.coalesce(col("N.npi_state"), F.lit(""))) &
      (col("L.leie_last_norm") == col("N.npi_last_norm"))
    ),
    how="inner"
).select(
    col("L.leie_npi"), col("L.leie_last_norm"), col("L.leie_first_norm"), col("L.leie_org_norm"),
    col("L.leie_last_tokens"), col("L.leie_first_tokens"), col("L.leie_org_tokens"),
    col("N.npi"), col("N.npi_state"), col("N.npi_zip5"),
    col("N.npi_last_norm"), col("N.npi_first_norm"), col("N.npi_org_norm"),
    col("N.npi_last_tokens"), col("N.npi_first_tokens"), col("N.npi_org_tokens")
)

# union candidate sets and dedupe
candidates = c1.unionByName(c2).unionByName(c3).dropDuplicates(["leie_npi", "npi"])

# optionally sample for speed (uncomment to use)
if MAX_PER_BLOCK_SAMPLE is not None:
    candidates = candidates.sample(withReplacement=False, fraction=0.05, seed=42)

candidates = candidates.cache()
print("Candidates generated:", candidates.count())
candidates.show(10, truncate=120)


In [None]:
# Cell C: score candidates with simple heuristics and flag conservative/loose inferred matches
from pyspark.sql import functions as F
from pyspark.sql.functions import col, size, array_intersect

# similarity components:
# - org_token_overlap_frac = intersection / union of org tokens
# - name_last_exact (1/0)
# - first_initial_match (1/0)
# - last_token_overlap_frac

def safe_frac(intersect_col, tokens_a_col, tokens_b_col):
    # (size(intersect) / (size(a)+size(b)-size(intersect))) safely
    return (size(intersect_col) / 
            ( (size(tokens_a_col) + size(tokens_b_col) - size(intersect_col)).cast("double") ).cast("double") )

cand = candidates.withColumn(
    "org_intersect", array_intersect(col("leie_org_tokens"), col("npi_org_tokens"))
).withColumn(
    "org_union_size",
    (size(col("leie_org_tokens")) + size(col("npi_org_tokens")) - size(col("org_intersect")))
).withColumn(
    "org_token_overlap_frac",
    F.when(col("org_union_size") > 0, size(col("org_intersect")) / col("org_union_size")).otherwise(0.0)
).withColumn(
    "last_token_intersect", array_intersect(col("leie_last_tokens"), col("npi_last_tokens"))
).withColumn(
    "last_token_overlap_frac",
    F.when( (size(col("leie_last_tokens")) + size(col("npi_last_tokens")) - size(col("last_token_intersect"))) > 0,
           size(col("last_token_intersect")) / (size(col("leie_last_tokens")) + size(col("npi_last_tokens")) - size(col("last_token_intersect")) )
         ).otherwise(0.0)
).withColumn(
    "last_exact",
    F.when(col("leie_last_norm") == col("npi_last_norm"), 1).otherwise(0)
).withColumn(
    "first_initial_match",
    F.when( (F.length(F.coalesce(col("leie_first_norm"), F.lit(""))) > 0) &
            (F.substring(col("leie_first_norm"),1,1) == F.substring(col("npi_first_norm"),1,1)),
            1).otherwise(0)
)

# composite score (weights chosen conservatively)
cand = cand.withColumn(
    "score",
    0.45 * col("org_token_overlap_frac") + 
    0.30 * col("last_token_overlap_frac") + 
    0.15 * col("last_exact") + 
    0.10 * col("first_initial_match")
)

# thresholds (tune these)
CONSERVE_THRESH = 0.92   # very high precision: likely same provider
LOOSE_THRESH = 0.80      # more recall, may need SME review

cand = cand.withColumn(
    "match_conservative", F.when(col("score") >= CONSERVE_THRESH, 1).otherwise(0)
).withColumn(
    "match_loose", F.when(col("score") >= LOOSE_THRESH, 1).otherwise(0)
)

# Select columns of interest
candidates_scored = cand.select(
    "leie_npi", "npi", "score", "org_token_overlap_frac", "last_token_overlap_frac",
    "last_exact", "first_initial_match", "match_conservative", "match_loose"
).cache()

print("Scored candidate count:", candidates_scored.count())
candidates_scored.orderBy(col("score").desc()).show(20, truncate=120)


In [None]:
# Cell D: aggregate candidate matches to build augmented NPI-level labels
from pyspark.sql import functions as F
from pyspark.sql.functions import col, max as spark_max

# explicit gold (from leie_by_npi)
explicit = leie_by_npi.select(col("npi").alias("npi"), (F.when(col("exclusion_count") > 0, 1).otherwise(0)).alias("is_explicit"))

# inferred matches: if any candidate for NPI had match_conservative or match_loose
inferred = (
    candidates_scored
    .groupBy("npi")
    .agg(
        spark_max(col("match_conservative")).alias("any_conservative_match"),
        spark_max(col("match_loose")).alias("any_loose_match"),
        F.max(col("score")).alias("max_score")
    )
)

# merge explicit + inferred into a single table keyed by npi
# NOTE: Some NPIs may be in LEIE explicit-only; some inferred-only (no explicit). We'll keep provenance.
# first ensure explicit table has all explicit NPIs (npi column already string in your workflow)
leie_label_augmented = (
    explicit.alias("E")
    .join(inferred.alias("I"), on="npi", how="full_outer")
    .na.fill({"is_explicit": 0, "any_conservative_match": 0, "any_loose_match": 0, "max_score": 0.0})
    .withColumn("is_fraud", F.when(col("is_explicit") == 1, 1).otherwise( F.when(col("any_conservative_match") == 1, 1).otherwise(0) ))
    .withColumn("label_is_inferred_conservative", col("any_conservative_match").cast("int"))
    .withColumn("label_is_inferred_loose", col("any_loose_match").cast("int"))
    .select("npi", "is_explicit", "is_fraud", "label_is_inferred_conservative", "label_is_inferred_loose", "max_score")
)

# persist and save to parquet for later use
leie_label_augmented = leie_label_augmented.cache()
print("Augmented label rows:", leie_label_augmented.count())

# write to disk (so downstream code can read it if session restarts)
aug_path = "curated/leie_label_augmented.parquet"
leie_label_augmented.coalesce(1).write.mode("overwrite").parquet(aug_path)
print("Wrote augmented labels to:", aug_path)


In [None]:
# Cell E: export the top N inferred-only candidates for quick SME review
from pyspark.sql import functions as F

# examples: inferred-only (not explicit in LEIE), sorted by score desc
inferred_only = (
    leie_label_augmented.filter((col("is_explicit") == 0) & (col("max_score") > 0))
    .join(candidates_scored, on="npi", how="left")   # bring candidate score back (may duplicate but fine for review)
    .filter(col("match_conservative") == 1)
    .orderBy(col("max_score").desc())
)

sample_for_review = inferred_only.limit(500)  # keep small for export/review
print("Inferred-only (conservative) sample count:", sample_for_review.count())
sample_for_review.show(20, truncate=200)

# Save sample for SME as parquet & csv
rev_path = "curated/leie_inferred_candidates_sample.parquet"
sample_for_review.coalesce(1).write.mode("overwrite").parquet(rev_path)
print("Saved SME sample to:", rev_path)


### Working on joining the three datasets

In [62]:
# NPPES slice
wrangled_npi_tosave = wrangled_npi_df.select(
    "npi","state_abbr","zip5","entity_type_code","primary_taxonomy",
    "npi_age_days","is_active","is_organization_subpart","is_sole_proprietor"
)


# PUF slice
provider_agg_tosave = provider_agg.select(
    "npi","total_services","total_beneficiaries","total_bene_day_services",
    "w_avg_submitted_charge","w_avg_allowed","w_avg_payment",
    "charge_allowed_ratio","payment_allowed_ratio",
    "num_unique_procedures","stddev_submitted_charge",
    "frac_drug_services","frac_missing_zip",
    "services_per_bene","bene_days_per_bene"
)

# LEIE slice
leie_by_npi_tosave = leie_by_npi.select(
    "npi","exclusion_count","first_excldate","most_recent_excldate",
    "is_currently_excluded","last_waiverdate","last_reindate"
)


In [63]:
WRITE_OPTS = {
    "compression": "snappy",
    "maxRecordsPerFile": "750000",                 # split giant outputs into manageable files
    "parquet.block.size": str(64 * 1024 * 1024),   # 64MB row groups (lighter than default)
    "parquet.page.size":  str(1 * 1024 * 1024),
}

def write_small(df, path, parts):
    (df.coalesce(parts)                               # 1–3 writers is plenty on a laptop
       .write.mode("overwrite")
       .options(**WRITE_OPTS)
       .parquet(path))

write_small(wrangled_npi_tosave, "curated/npi_nodes.parquet",        parts=2)
write_small(provider_agg_tosave, "curated/puf_provider_agg.parquet", parts=2)
write_small(leie_by_npi_tosave,  "curated/leie_by_npi.parquet",      parts=1)


                                                                                

In [64]:
# quick sanity check on the saved files
print("NPI files:", len(dbutils.fs.ls("curated/npi_nodes.parquet")) if 'dbutils' in globals() else "written")
print("PUF files:", "written")
print("LEIE files:", "written")


NPI files: written
PUF files: written
LEIE files: written


In [None]:
# ==== FIXED BLOCK (with AS_OF_STR fallback) ====
import builtins as py
from pyspark.sql import functions as F
from pyspark.sql.functions import col, lit
from pyspark import StorageLevel

# ensure AS_OF_STR exists (use existing if kernel still has it, else fallback)
AS_OF_STR = globals().get("AS_OF_STR", "2023-12-31")
print("Using AS_OF_STR =", AS_OF_STR)

# tune local shuffle partitions for dev
spark.conf.set("spark.sql.shuffle.partitions", "8")

# --- 0) Build leie_label (prefer augmented/inferred labels if available) ---
if "leie_by_npi" not in globals():
    raise RuntimeError("leie_by_npi not found — please run the LEIE aggregation cell before this block.")

# Try in-memory augmented labels first, then on-disk, otherwise fall back to explicit-only
leie_label = None

if "leie_label_augmented" in globals():
    la = leie_label_augmented
    leie_label = (
        la
        .withColumn("is_fraud", col("is_fraud").cast("int"))
        # prefer conservative inferred flag if present, else loose, else 0
        .withColumn("label_is_inferred",
                    coalesce(col("label_is_inferred_conservative"), col("label_is_inferred_loose"), lit(0)).cast("int"))
        .select("npi", "is_fraud", "label_is_inferred")
        .filter(col("npi").isNotNull())
        .distinct()
    )
    print("Using in-memory leie_label_augmented")

else:
    # try to load augmented parquet written earlier (if you executed Cell D / E)
    try:
        la = spark.read.parquet("curated/leie_label_augmented.parquet")
        leie_label = (
            la
            .withColumn("is_fraud", col("is_fraud").cast("int"))
            .withColumn("label_is_inferred",
                        coalesce(col("label_is_inferred_conservative"), col("label_is_inferred_loose"), lit(0)).cast("int"))
            .select("npi", "is_fraud", "label_is_inferred")
            .filter(col("npi").isNotNull())
            .distinct()
        )
        print("Loaded leie_label_augmented.parquet from disk and using augmented labels")
    except Exception:
        # final fallback: explicit-only gold from leie_by_npi
        leie_label = (
            leie_by_npi
            .select(
                col("npi"),
                F.when(col("exclusion_count") > 0, 1).otherwise(0).alias("is_fraud")
            )
            .distinct()
            .withColumn("label_is_inferred", lit(0))
            .filter(col("npi").isNotNull())
        )
        print("Augmented labels not found; using explicit-only leie_by_npi label")


# cheap approx size check for leie_label (to decide whether to broadcast)
try:
    approx_leie = int(list(leie_label.rdd.countApprox(2000).values())[0])
except Exception:
    approx_leie = None
print("Approx LEIE rows:", approx_leie)

# --- 1) Compact NPPES slice (project early) ---
cols_npi_feat = [
    "npi","state_abbr","zip5","entity_type_code","primary_taxonomy",
    "npi_age_days","is_active","is_organization_subpart","is_sole_proprietor"
]
if "wrangled_npi_df" not in globals():
    raise RuntimeError("wrangled_npi_df not found — re-run NPPES wrangle cells.")
npi_feat = wrangled_npi_df.select(*[c for c in cols_npi_feat if c in wrangled_npi_df.columns]).distinct()

# --- 2) Project provider_agg columns you actually created earlier ---
needed_cols = [
    "npi",
    "total_services","total_beneficiaries","total_bene_day_services",
    "w_avg_submitted_charge","w_avg_allowed","w_avg_payment",
    "charge_allowed_ratio","payment_allowed_ratio",
    "num_unique_procedures","stddev_submitted_charge",
    "frac_drug_services","frac_missing_zip",
    "services_per_bene","bene_days_per_bene"
]
if "provider_agg" not in globals():
    raise RuntimeError("provider_agg not found — re-run the aggregation cell.")
needed_cols = [c for c in needed_cols if c in provider_agg.columns]  # keep only existing ones
provider_agg_small = provider_agg.select(*needed_cols)

# --- 3) Decide broadcast strategy (safe: broadcast if LEIE small) ---
BROADCAST_THRESHOLD = 2_000_000  # safe rule-of-thumb for local dev
do_broadcast_leie = (approx_leie is not None) and (approx_leie <= BROADCAST_THRESHOLD)
print("Will broadcast LEIE lookup:", do_broadcast_leie)

# Build base merged DF:
base = provider_agg_small.alias("p").join(
    npi_feat.select("npi", "primary_taxonomy", "state_abbr").alias("n"),
    on="npi",
    how="left"
)

if do_broadcast_leie:
    merged = base.join(F.broadcast(leie_label.alias("l")), on="npi", how="left")
else:
    # write leie_label to parquet and read it back to avoid an attempted broadcast of a large DF
    tmp_lp = "/tmp/leie_label_for_join.parquet"
    leie_label.write.mode("overwrite").parquet(tmp_lp)
    leie_label_on_disk = spark.read.parquet(tmp_lp).select("npi","is_fraud","label_is_inferred").distinct()
    merged = base.join(leie_label_on_disk.alias("l"), on="npi", how="left")

# fill missing fraud label with 0
merged = merged.withColumn("is_fraud", F.coalesce(F.col("is_fraud"), F.lit(0)).cast("int"))

# --- 4) Persist safely (spill to disk if memory is tight) ---
merged = merged.repartition(8, "npi")
merged.persist(StorageLevel.MEMORY_AND_DISK)

# LIGHT SANITY CHECKS (cheap)
print("Merged preview:")
merged.select("npi","is_fraud").limit(10).show(truncate=False)
print("Preview positives (limit):")
merged.filter(col("is_fraud") == 1).select("npi").limit(20).show(truncate=False)

# --- 5) Downsample negatives up to TARGET while keeping all positives ---
TARGET = 1_000_000

# exact positives count (cheap because LEIE is small)
pos_df = merged.filter(col("is_fraud") == 1).select("npi").distinct()
pos_ct = pos_df.count()
print("Positive (distinct NPIs) count:", pos_ct)

# estimate total providers via a very small sample instead of full count
SAMPLE_F = 0.001  # 0.1% sample for local dev; increase if estimates are noisy
print("Using SAMPLE_F:", SAMPLE_F)
sample_rdd = provider_agg_small.sample(fraction=SAMPLE_F, seed=42).rdd

# Use countApprox (timeout in ms) to keep this cheap
try:
    approx_res = sample_rdd.countApprox(2000)
    sample_est_val = int(list(approx_res.values())[0])
except Exception:
    # fallback to safer (but heavier) exact count on the tiny sample RDD
    sample_est_val = sample_rdd.count()

est_total = int(sample_est_val / SAMPLE_F) if SAMPLE_F > 0 else provider_agg_small.count()
print("Estimated total providers (approx):", est_total)

# compute estimated negatives safely using Python builtins to avoid PySpark shadowing
est_neg = py.max(est_total - pos_ct, 1)
keep_neg = py.max(TARGET - pos_ct, 0)
frac_neg = py.min(1.0, float(keep_neg) / float(est_neg)) if est_neg > 0 else 0.0
print("est_neg:", est_neg, "keep_neg:", keep_neg, "frac_neg:", frac_neg)

if frac_neg <= 0:
    final_df = merged.filter(col("is_fraud") == 1)   # keep only positives if target already reached
else:
    neg_sample = merged.filter(col("is_fraud") == 0).sample(withReplacement=False, fraction=frac_neg, seed=42)
    # bring positives back (join ensures we preserve full rows)
    pos_full = pos_df.join(merged, on="npi", how="inner")
    final_df = pos_full.unionByName(neg_sample.select(merged.columns))

# slight materialization to warm cache (cheap)
_ = final_df.limit(20).toPandas()

# --- 6) Drop leakage columns if any (excl*) ---
leak_cols = [c for c in final_df.columns if c.startswith("excl") or c.endswith("_excldate")]
if leak_cols:
    final_df = final_df.drop(*leak_cols)

# --- 7) Write out parquet (local-dev friendly); partition by state_abbr if available ---
WRITE_OPTS = {
    "compression": "snappy",
    "maxRecordsPerFile": "750000",
    "parquet.block.size": str(64 * 1024 * 1024),
    "parquet.page.size":  str(1  * 1024 * 1024),
}

outpath = f"curated/training/providers_merged_asof_{AS_OF_STR}_ever.parquet"
try:
    if "state_abbr" in final_df.columns:
        final_df.repartition(8).write.mode("overwrite").partitionBy("state_abbr").options(**WRITE_OPTS).parquet(outpath)
    else:
        final_df.repartition(8).write.mode("overwrite").options(**WRITE_OPTS).parquet(outpath)
    print("Saved merged dataset to:", outpath)
except Exception as e:
    print("Failed to write parquet:", str(e))

# final sanity
print("Positives count in final (preview):")
final_df.select("is_fraud").groupBy("is_fraud").count().show()
# ==== END FIXED BLOCK ====


Using AS_OF_STR = 2023-12-31


[Stage 137:>                                                        (0 + 1) / 1]

Approx LEIE rows: None
Will broadcast LEIE lookup: False


                                                                                

Merged preview:


                                                                                

+----------+--------+
|npi       |is_fraud|
+----------+--------+
|1003000936|0       |
|1003002254|0       |
|1003002742|0       |
|1003002890|0       |
|1003003856|0       |
|1003006107|0       |
|1003006198|0       |
|1003006396|0       |
|1003006602|0       |
|1003006768|0       |
+----------+--------+

Preview positives (limit):
+----------+
|npi       |
+----------+
|1003242314|
|1134216633|
|1235576398|
|1255665261|
|1326046616|
|1437198074|
|1568690014|
|1649201153|
|1700426467|
|1891854873|
|1912999087|
|1922059831|
|1942217526|
|1942375837|
|1962745356|
|1003312018|
|1073977302|
|1083673180|
|1093856973|
|1275580748|
+----------+

Positive (distinct NPIs) count: 120
Using SAMPLE_F: 0.001


                                                                                

Estimated total providers (approx): 1237000
est_neg: 1236880 keep_neg: 999880 frac_neg: 0.8083888493629131


                                                                                

Saved merged dataset to: curated/training/providers_merged_asof_2023-12-31_ever.parquet
Positives count in final (preview):


                                                                                

+--------+------+
|is_fraud| count|
+--------+------+
|       1|   120|
|       0|950227|
+--------+------+



25/08/15 19:30:46 WARN HeartbeatReceiver: Removing executor driver with no recent heartbeats: 897373 ms exceeds timeout 120000 ms
25/08/15 19:30:46 WARN SparkContext: Killing executors is not supported by current scheduler.
25/08/15 19:30:47 ERROR Inbox: Ignoring error
org.apache.spark.SparkException: Exception thrown in awaitResult: 
	at org.apache.spark.util.SparkThreadUtils$.awaitResult(SparkThreadUtils.scala:53)
	at org.apache.spark.util.ThreadUtils$.awaitResult(ThreadUtils.scala:342)
	at org.apache.spark.rpc.RpcTimeout.awaitResult(RpcTimeout.scala:75)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRefByURI(RpcEnv.scala:102)
	at org.apache.spark.rpc.RpcEnv.setupEndpointRef(RpcEnv.scala:110)
	at org.apache.spark.util.RpcUtils$.makeDriverRef(RpcUtils.scala:36)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.driverEndpoint$lzycompute(BlockManagerMasterEndpoint.scala:132)
	at org.apache.spark.storage.BlockManagerMasterEndpoint.org$apache$spark$storage$BlockManagerMasterEndpoint$$

In [66]:
# Cell: find probable name-like columns in the full NPI file
import re
possible = [c for c in full_npi_df.columns if re.search(r'first|last|name|lbn|organization|org|legal|business', c, re.I)]
print("Possible NPI name columns (sample):")
for c in possible:
    print(" -", c)
print("\nTotal candidate name-like columns:", len(possible))


Possible NPI name columns (sample):
 - Provider Organization Name (Legal Business Name)
 - Provider Last Name (Legal Name)
 - Provider First Name
 - Provider Middle Name
 - Provider Name Prefix Text
 - Provider Name Suffix Text
 - Provider Other Organization Name
 - Provider Other Organization Name Type Code
 - Provider Other Last Name
 - Provider Other First Name
 - Provider Other Middle Name
 - Provider Other Name Prefix Text
 - Provider Other Name Suffix Text
 - Provider Other Last Name Type Code
 - Provider First Line Business Mailing Address
 - Provider Second Line Business Mailing Address
 - Provider Business Mailing Address City Name
 - Provider Business Mailing Address State Name
 - Provider Business Mailing Address Postal Code
 - Provider Business Mailing Address Country Code (If outside U.S.)
 - Provider Business Mailing Address Telephone Number
 - Provider Business Mailing Address Fax Number
 - Provider First Line Business Practice Location Address
 - Provider Second Line Bu