In [35]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path

In [36]:
# The data directory sits next to the current working directory's parent.
# Joining with pathlib's `/` operator (instead of concatenating strings)
# avoids hard-coding the path separator.
base = Path().resolve().parent
file_path = base / 'data' / 'medicare_data_cleaned.parquet'

# Load the parquet file into a DataFrame (read_parquet accepts path objects)
df = pd.read_parquet(file_path)

In [37]:
# Columns removed from the 2023 subset
unused_cols = ['Rndrng_Prvdr_St', 'Rndrng_Prvdr_State_FIPS', 'Tot_Dschrgs',
               'Avg_Submtd_Cvrd_Chrg', 'Avg_Tot_Pymt_Amt', 'year']

# Keep only the rows whose year is 2023, then drop the columns listed above
df_2023 = df.loc[df['year'] == 2023].drop(columns=unused_cols)

In [38]:
# Preview the first five rows of df_2023 (head() defaults to n=5)
df_2023.head()

Unnamed: 0,Rndrng_Prvdr_CCN,Rndrng_Prvdr_Org_Name,Rndrng_Prvdr_City,Rndrng_Prvdr_Zip5,Rndrng_Prvdr_State_Abrvtn,DRG_Cd,Avg_Mdcr_Pymt_Amt,RUCA_category
2000,10001,Southeast Health Medical Center,Dothan,36301.0,AL,3,115544.14286,metro_high_commute
2001,10001,Southeast Health Medical Center,Dothan,36301.0,AL,23,35261.807692,metro_high_commute
2002,10001,Southeast Health Medical Center,Dothan,36301.0,AL,24,25048.916667,metro_high_commute
2003,10001,Southeast Health Medical Center,Dothan,36301.0,AL,25,32438.625,metro_high_commute
2004,10001,Southeast Health Medical Center,Dothan,36301.0,AL,38,9579.363636,metro_high_commute


In [39]:
# A DRG code keeps its own level only if it appears in more than this many rows
MIN_DRG_ROWS = 25

# Count each DRG code once and reuse the result for the threshold filter
drg_counts = df_2023['DRG_Cd'].value_counts()
common_drgs = drg_counts[drg_counts > MIN_DRG_ROWS].index

# Vectorised grouping: codes in `common_drgs` are kept, every other value
# becomes 'Other'. Casting to object first lets the string label sit next to
# the codes whatever dtype DRG_Cd is stored as.
df_2023['drg_grouped'] = df_2023['DRG_Cd'].astype(object).where(
    df_2023['DRG_Cd'].isin(common_drgs), 'Other')

In [41]:
cat_cols = ['Rndrng_Prvdr_CCN', 'DRG_Cd', 'Rndrng_Prvdr_Zip5', 'Rndrng_Prvdr_Org_Name',
            'Rndrng_Prvdr_City', 'Rndrng_Prvdr_State_Abrvtn', 'RUCA_category', 'drg_grouped']

# Cast all listed columns to pandas' categorical dtype in a single astype call.
# This only changes the dtype; the values themselves are not turned into
# numeric codes.
df_2023 = df_2023.astype({col: 'category' for col in cat_cols})

In [42]:
# Print df_2023's column dtypes, non-null counts and memory usage
df_2023.info()

<class 'pandas.core.frame.DataFrame'>
Index: 146427 entries, 2000 to 439168
Data columns (total 9 columns):
 #   Column                     Non-Null Count   Dtype   
---  ------                     --------------   -----   
 0   Rndrng_Prvdr_CCN           146427 non-null  category
 1   Rndrng_Prvdr_Org_Name      146427 non-null  category
 2   Rndrng_Prvdr_City          146427 non-null  category
 3   Rndrng_Prvdr_Zip5          146427 non-null  category
 4   Rndrng_Prvdr_State_Abrvtn  146427 non-null  category
 5   DRG_Cd                     146427 non-null  category
 6   Avg_Mdcr_Pymt_Amt          146427 non-null  float64 
 7   RUCA_category              146427 non-null  category
 8   drg_grouped                146427 non-null  category
dtypes: category(8), float64(1)
memory usage: 4.6 MB


In [43]:
# Move the target column to the front; the remaining columns keep their order
target = 'Avg_Mdcr_Pymt_Amt'
cols = [target, *(name for name in df_2023.columns if name != target)]
df_2023 = df_2023.loc[:, cols]
df_2023.head()

Unnamed: 0,Avg_Mdcr_Pymt_Amt,Rndrng_Prvdr_CCN,Rndrng_Prvdr_Org_Name,Rndrng_Prvdr_City,Rndrng_Prvdr_Zip5,Rndrng_Prvdr_State_Abrvtn,DRG_Cd,RUCA_category,drg_grouped
2000,115544.14286,10001,Southeast Health Medical Center,Dothan,36301.0,AL,3,metro_high_commute,3
2001,35261.807692,10001,Southeast Health Medical Center,Dothan,36301.0,AL,23,metro_high_commute,23
2002,25048.916667,10001,Southeast Health Medical Center,Dothan,36301.0,AL,24,metro_high_commute,24
2003,32438.625,10001,Southeast Health Medical Center,Dothan,36301.0,AL,25,metro_high_commute,25
2004,9579.363636,10001,Southeast Health Medical Center,Dothan,36301.0,AL,38,metro_high_commute,38


In [None]:
# Draw a random subsample of df_2023
df_sampled = df_2023.sample(
    frac=0.05,        # keep 5% of the rows
    random_state=42,  # fixed seed so the same rows are drawn on every run
)
df_sampled

Unnamed: 0,Avg_Mdcr_Pymt_Amt,Rndrng_Prvdr_CCN,Rndrng_Prvdr_Org_Name,Rndrng_Prvdr_City,Rndrng_Prvdr_Zip5,Rndrng_Prvdr_State_Abrvtn,DRG_Cd,RUCA_category,drg_grouped
2980,8826.280000,010039,Huntsville Hospital,Huntsville,35801.0,AL,602,metro_core,602
110234,3652.666667,100315,Viera Hospital,Melbourne,32940.0,FL,948,metro_core,948
110588,21010.800000,110001,Hamilton Medical Center,Dalton,30720.0,GA,521,metro_core,521
185626,26276.583333,210043,Umd Baltimore Washington Medical Center,Glen Burnie,21061.0,MD,035,metro_core,035
257278,8260.636364,310050,Saint Clare's Hospital,Denville,7834.0,NJ,074,metro_core,074
...,...,...,...,...,...,...,...,...,...
368937,5235.333333,440059,Cookeville Regional Medical Center,Cookeville,38501.0,TN,065,micro_core,065
218913,5654.418605,240057,Abbott Northwestern Hospital,Minneapolis,55407.0,MN,683,metro_core,683
368799,5661.757576,440050,Greeneville Community Hospital,Greeneville,37745.0,TN,872,micro_core,872
347420,5743.490196,390195,Lankenau Medical Center,Wynnewood,19096.0,PA,690,metro_core,690


In [45]:
# NOTE(review): consider moving this import into the notebook's import cell
from sklearn.model_selection import train_test_split

# Hold-out fraction applied at both split steps, and the seed that keeps the
# splits reproducible
HOLDOUT_FRAC = 0.15
SPLIT_SEED = 42

# Split into train/val/test: first carve off the test set, then carve the
# validation set out of the remainder (so val is HOLDOUT_FRAC of the
# remaining rows, not of the whole sample)
train_val, test = train_test_split(
    df_sampled, test_size=HOLDOUT_FRAC, random_state=SPLIT_SEED)
train, val = train_test_split(
    train_val, test_size=HOLDOUT_FRAC, random_state=SPLIT_SEED)

# The three splits together must account for every sampled row
assert len(train) + len(val) + len(test) == len(df_sampled)

# Save to CSV for SageMaker (no headers, no index)
# NOTE(review): categorical columns are written as their raw labels — confirm
# the downstream algorithm accepts non-numeric features
for split_name, split_df in {'train': train, 'val': val, 'test': test}.items():
    split_df.to_csv(f"{split_name}.csv", index=False, header=False)

In [47]:
# Preview the first five rows of the training split
train.head()

Unnamed: 0,Avg_Mdcr_Pymt_Amt,Rndrng_Prvdr_CCN,Rndrng_Prvdr_Org_Name,Rndrng_Prvdr_City,Rndrng_Prvdr_Zip5,Rndrng_Prvdr_State_Abrvtn,DRG_Cd,RUCA_category,drg_grouped
251154,8803.666667,310011,"Cape Regional Medical Center, Inc",Cape May Court House,8210.0,NJ,193,metro_core,193
347384,13428.642857,390195,Lankenau Medical Center,Wynnewood,19096.0,PA,331,metro_core,331
287764,12055.657895,330332,St Joseph Hospital,Bethpage,11714.0,NY,698,metro_core,698
395783,21941.461538,450686,University Medical Center,Lubbock,79415.0,TX,246,metro_core,246
227332,18942.066667,250138,Merit Health River Oaks,Flowood,39232.0,MS,246,metro_core,246
