### Decision tree Estimator on merged_emp data 

- This Python notebook demonstrates creating an ML Pipeline to preprocess a dataset, train a Machine Learning model, and make predictions.

- Data: The dataset contains    merged_emp data  

- Goal: We want to learn to predict  deaths  from features such as patient data (comprising of Adult_smoking_percentage,	Adult_obesity_percentage and Excessive_drinking_percentage), various healthcare data like hospitals, physician, nurses, surgean specialities, etc., socio econonmic factors like life rank, quality of life, social-ecomonica factors rank, physical environment rank, etc. and other features like popluation, etc.  

- Approach: We will use Spark ML Pipelines, which help users piece together parts of a workflow such as feature processing and model training. We will also demonstrate model selection (a.k.a. hyperparameter tuning) using Cross Validation in order to fine-tune and improve our ML model.

In [2]:
%sh 
pip install mleap

In [3]:
from pyspark.sql.functions import log

In [4]:
# read data from the output of exploratory data analysis which is in parquet file format
df = spark.read.parquet('/databricks/driver/spark_merged_emp_modeling1.parquet')

In [5]:
display(df)

county,cases,deaths,POP_ESTIMATE_2018,All_Specialties_AAMC,Total_nurse_practitioners_2019,Total_physician_assistants_2019,Total_Hospitals_2019,Total_Primary_Care_Physicians_2019,Surgery_specialists_2019,Emergency_Medicine_specialists_2019,Total_Specialist_Physicians_2019,ICU_Beds,pop_fraction,Length_of_Life_rank,Quality_of_Life_rank,Health_Behaviors_rank,Clinical_Care_rank,Social_Economic_Factors_rank,Physical_Environment_rank,Adult_smoking_percentage,Adult_obesity_percentage,Excessive_drinking_percentage,Population_per_sq_mile,House_per_sq_mile,Unemployed_2019,Unemployment_rate_2019,Median_Household_Income_2018,Med_HH_Income_Percent_of_State_Total_2018
Lee,7.117022517792556,0.5833625014584063,17142.0,2.2945925621281065,0.4171807667133357,0.2063284414304048,0.0135716515575778,1.257836399486641,0.1581392439621981,0.1431907581962431,1.269047763971532,0.0,0.001685835,41.0,46.0,44.0,41.0,42.0,41.0,21.0,40.0,13.0,46.9,19.0,266.0,4.0,33199.0,63.29767965
Lexington,1.667615716261287,0.0847365709482361,295032.0,2.2945925621627485,0.4171807668320725,0.2063284414571978,0.013571651532715,1.2578363994414166,0.158139243980314,0.1431907582228368,1.2690477636324196,0.1864204560861194,0.029015011,9.0,10.0,11.0,6.0,3.0,7.0,16.0,32.0,19.0,375.4,163.0,3510.0,2.3,60627.0,115.5922897
Marion,1.5464415735043011,0.1288701311253584,31039.0,2.294592562260382,0.4171807667772801,0.2063284414446341,0.0135716515351654,1.2578363993685362,0.1581392439511582,0.1431907582074164,1.2690477637810496,0.322175327813396,0.00305254,40.0,45.0,41.0,44.0,43.0,19.0,20.0,44.0,14.0,67.6,30.6,564.0,4.3,34365.0,65.52079163
Marlboro,2.272899462080461,0.0,26398.0,2.294592562315327,0.4171807667247518,0.2063284414349571,0.0135716515266308,1.2578363993484354,0.1581392439578756,0.1431907582013789,1.2690477638457458,0.0,0.002596119,43.0,39.0,40.0,39.0,45.0,44.0,21.0,39.0,15.0,60.3,25.2,444.0,4.8,33534.0,63.93639536
McCormick,0.7438894792773645,0.106269925611052,9410.0,2.294592562167906,0.4171807667375132,0.206328441445271,0.0135716515409139,1.2578363995749204,0.1581392439957492,0.1431907582359192,1.2690477640807651,0.0,0.000925429,46.0,25.0,9.0,4.0,34.0,13.0,17.0,34.0,14.0,28.5,15.2,104.0,3.1,44683.0,85.19323533
Newberry,0.8047767393561787,0.0259605399792315,38520.0,2.294592562305296,0.4171807666147455,0.2063284414330218,0.0135716515316718,1.2578363995327102,0.1581392439771547,0.1431907582294911,1.269047763759086,0.181723779854621,0.003788261,10.0,21.0,23.0,21.0,15.0,16.0,18.0,36.0,16.0,59.5,28.4,514.0,2.7,48029.0,91.57276592
Oconee,0.4976140046443973,0.0,78374.0,2.2945925625845303,0.4171807667083471,0.2063284414474188,0.013571651529844,1.257836399443693,0.1581392440094929,0.1431907581595937,1.2690477637992192,0.1531120014290453,0.007707715,16.0,4.0,14.0,12.0,12.0,15.0,17.0,32.0,16.0,118.6,61.9,971.0,2.8,50529.0,96.33930104
Orangeburg,1.2193158027929234,0.0230059585432627,86934.0,2.294592562173603,0.4171807666735684,0.2063284414613384,0.0135716515287459,1.2578363988773091,0.1581392439091724,0.1431907581613637,1.2690477638208295,0.230059585432627,0.00854955,34.0,38.0,27.0,26.0,38.0,37.0,17.0,41.0,13.0,83.6,38.4,1655.0,4.8,37134.0,70.80020591
Pickens,0.6803428928179802,0.0080040340331527,124937.0,2.294592562651576,0.41718076670642,0.206328441454493,0.0135716515363743,1.257836399145169,0.1581392439389452,0.1431907582221439,1.2690477640730926,0.1200605104972906,0.012286967,12.0,8.0,6.0,7.0,9.0,33.0,17.0,32.0,18.0,240.2,103.2,1604.0,2.8,48794.0,93.03132567
Richland,2.814924163482691,0.142314075103238,414576.0,2.294592562280499,0.4171807666145652,0.2063284414437883,0.0135716515331326,1.2578363993574158,0.1581392439745667,0.1431907582204469,1.2690477637393385,0.5089537262166648,0.040771601,11.0,14.0,17.0,3.0,21.0,27.0,16.0,33.0,19.0,507.9,213.6,5579.0,2.8,52611.0,100.3088715


In [6]:
df_temp = df
from pyspark.sql.functions import col

#renaming column deaths in order to rearrange it to the end of the table.
df = df.withColumnRenamed('deaths', 'zlabel')

In [7]:
#sort columns according to column name
df = df.select(sorted(df.columns))

In [8]:
df.count()

In [9]:
df.dtypes

In [10]:
display(df)

Adult_obesity_percentage,Adult_smoking_percentage,All_Specialties_AAMC,Clinical_Care_rank,Emergency_Medicine_specialists_2019,Excessive_drinking_percentage,Health_Behaviors_rank,House_per_sq_mile,ICU_Beds,Length_of_Life_rank,Med_HH_Income_Percent_of_State_Total_2018,Median_Household_Income_2018,POP_ESTIMATE_2018,Physical_Environment_rank,Population_per_sq_mile,Quality_of_Life_rank,Social_Economic_Factors_rank,Surgery_specialists_2019,Total_Hospitals_2019,Total_Primary_Care_Physicians_2019,Total_Specialist_Physicians_2019,Total_nurse_practitioners_2019,Total_physician_assistants_2019,Unemployed_2019,Unemployment_rate_2019,cases,county,pop_fraction,zlabel
40.0,21.0,2.2945925621281065,41.0,0.1431907581962431,13.0,44.0,19.0,0.0,41.0,63.29767965,33199.0,17142.0,41.0,46.9,46.0,42.0,0.1581392439621981,0.0135716515575778,1.257836399486641,1.269047763971532,0.4171807667133357,0.2063284414304048,266.0,4.0,7.117022517792556,Lee,0.001685835,0.5833625014584063
32.0,16.0,2.2945925621627485,6.0,0.1431907582228368,19.0,11.0,163.0,0.1864204560861194,9.0,115.5922897,60627.0,295032.0,7.0,375.4,10.0,3.0,0.158139243980314,0.013571651532715,1.2578363994414166,1.2690477636324196,0.4171807668320725,0.2063284414571978,3510.0,2.3,1.667615716261287,Lexington,0.029015011,0.0847365709482361
44.0,20.0,2.294592562260382,44.0,0.1431907582074164,14.0,41.0,30.6,0.322175327813396,40.0,65.52079163,34365.0,31039.0,19.0,67.6,45.0,43.0,0.1581392439511582,0.0135716515351654,1.2578363993685362,1.2690477637810496,0.4171807667772801,0.2063284414446341,564.0,4.3,1.5464415735043011,Marion,0.00305254,0.1288701311253584
39.0,21.0,2.294592562315327,39.0,0.1431907582013789,15.0,40.0,25.2,0.0,43.0,63.93639536,33534.0,26398.0,44.0,60.3,39.0,45.0,0.1581392439578756,0.0135716515266308,1.2578363993484354,1.2690477638457458,0.4171807667247518,0.2063284414349571,444.0,4.8,2.272899462080461,Marlboro,0.002596119,0.0
34.0,17.0,2.294592562167906,4.0,0.1431907582359192,14.0,9.0,15.2,0.0,46.0,85.19323533,44683.0,9410.0,13.0,28.5,25.0,34.0,0.1581392439957492,0.0135716515409139,1.2578363995749204,1.2690477640807651,0.4171807667375132,0.206328441445271,104.0,3.1,0.7438894792773645,McCormick,0.000925429,0.106269925611052
36.0,18.0,2.294592562305296,21.0,0.1431907582294911,16.0,23.0,28.4,0.181723779854621,10.0,91.57276592,48029.0,38520.0,16.0,59.5,21.0,15.0,0.1581392439771547,0.0135716515316718,1.2578363995327102,1.269047763759086,0.4171807666147455,0.2063284414330218,514.0,2.7,0.8047767393561787,Newberry,0.003788261,0.0259605399792315
32.0,17.0,2.2945925625845303,12.0,0.1431907581595937,16.0,14.0,61.9,0.1531120014290453,16.0,96.33930104,50529.0,78374.0,15.0,118.6,4.0,12.0,0.1581392440094929,0.013571651529844,1.257836399443693,1.2690477637992192,0.4171807667083471,0.2063284414474188,971.0,2.8,0.4976140046443973,Oconee,0.007707715,0.0
41.0,17.0,2.294592562173603,26.0,0.1431907581613637,13.0,27.0,38.4,0.230059585432627,34.0,70.80020591,37134.0,86934.0,37.0,83.6,38.0,38.0,0.1581392439091724,0.0135716515287459,1.2578363988773091,1.2690477638208295,0.4171807666735684,0.2063284414613384,1655.0,4.8,1.2193158027929234,Orangeburg,0.00854955,0.0230059585432627
32.0,17.0,2.294592562651576,7.0,0.1431907582221439,18.0,6.0,103.2,0.1200605104972906,12.0,93.03132567,48794.0,124937.0,33.0,240.2,8.0,9.0,0.1581392439389452,0.0135716515363743,1.257836399145169,1.2690477640730926,0.41718076670642,0.206328441454493,1604.0,2.8,0.6803428928179802,Pickens,0.012286967,0.0080040340331527
33.0,16.0,2.294592562280499,3.0,0.1431907582204469,19.0,17.0,213.6,0.5089537262166648,11.0,100.3088715,52611.0,414576.0,27.0,507.9,14.0,21.0,0.1581392439745667,0.0135716515331326,1.2578363993574158,1.2690477637393385,0.4171807666145652,0.2063284414437883,5579.0,2.8,2.814924163482691,Richland,0.040771601,0.142314075103238


In [11]:
#rename column back to label
df = df.withColumnRenamed('zlabel', 'label')

In [12]:
display(df)

Adult_obesity_percentage,Adult_smoking_percentage,All_Specialties_AAMC,Clinical_Care_rank,Emergency_Medicine_specialists_2019,Excessive_drinking_percentage,Health_Behaviors_rank,House_per_sq_mile,ICU_Beds,Length_of_Life_rank,Med_HH_Income_Percent_of_State_Total_2018,Median_Household_Income_2018,POP_ESTIMATE_2018,Physical_Environment_rank,Population_per_sq_mile,Quality_of_Life_rank,Social_Economic_Factors_rank,Surgery_specialists_2019,Total_Hospitals_2019,Total_Primary_Care_Physicians_2019,Total_Specialist_Physicians_2019,Total_nurse_practitioners_2019,Total_physician_assistants_2019,Unemployed_2019,Unemployment_rate_2019,cases,county,pop_fraction,label
40.0,21.0,2.2945925621281065,41.0,0.1431907581962431,13.0,44.0,19.0,0.0,41.0,63.29767965,33199.0,17142.0,41.0,46.9,46.0,42.0,0.1581392439621981,0.0135716515575778,1.257836399486641,1.269047763971532,0.4171807667133357,0.2063284414304048,266.0,4.0,7.117022517792556,Lee,0.001685835,0.5833625014584063
32.0,16.0,2.2945925621627485,6.0,0.1431907582228368,19.0,11.0,163.0,0.1864204560861194,9.0,115.5922897,60627.0,295032.0,7.0,375.4,10.0,3.0,0.158139243980314,0.013571651532715,1.2578363994414166,1.2690477636324196,0.4171807668320725,0.2063284414571978,3510.0,2.3,1.667615716261287,Lexington,0.029015011,0.0847365709482361
44.0,20.0,2.294592562260382,44.0,0.1431907582074164,14.0,41.0,30.6,0.322175327813396,40.0,65.52079163,34365.0,31039.0,19.0,67.6,45.0,43.0,0.1581392439511582,0.0135716515351654,1.2578363993685362,1.2690477637810496,0.4171807667772801,0.2063284414446341,564.0,4.3,1.5464415735043011,Marion,0.00305254,0.1288701311253584
39.0,21.0,2.294592562315327,39.0,0.1431907582013789,15.0,40.0,25.2,0.0,43.0,63.93639536,33534.0,26398.0,44.0,60.3,39.0,45.0,0.1581392439578756,0.0135716515266308,1.2578363993484354,1.2690477638457458,0.4171807667247518,0.2063284414349571,444.0,4.8,2.272899462080461,Marlboro,0.002596119,0.0
34.0,17.0,2.294592562167906,4.0,0.1431907582359192,14.0,9.0,15.2,0.0,46.0,85.19323533,44683.0,9410.0,13.0,28.5,25.0,34.0,0.1581392439957492,0.0135716515409139,1.2578363995749204,1.2690477640807651,0.4171807667375132,0.206328441445271,104.0,3.1,0.7438894792773645,McCormick,0.000925429,0.106269925611052
36.0,18.0,2.294592562305296,21.0,0.1431907582294911,16.0,23.0,28.4,0.181723779854621,10.0,91.57276592,48029.0,38520.0,16.0,59.5,21.0,15.0,0.1581392439771547,0.0135716515316718,1.2578363995327102,1.269047763759086,0.4171807666147455,0.2063284414330218,514.0,2.7,0.8047767393561787,Newberry,0.003788261,0.0259605399792315
32.0,17.0,2.2945925625845303,12.0,0.1431907581595937,16.0,14.0,61.9,0.1531120014290453,16.0,96.33930104,50529.0,78374.0,15.0,118.6,4.0,12.0,0.1581392440094929,0.013571651529844,1.257836399443693,1.2690477637992192,0.4171807667083471,0.2063284414474188,971.0,2.8,0.4976140046443973,Oconee,0.007707715,0.0
41.0,17.0,2.294592562173603,26.0,0.1431907581613637,13.0,27.0,38.4,0.230059585432627,34.0,70.80020591,37134.0,86934.0,37.0,83.6,38.0,38.0,0.1581392439091724,0.0135716515287459,1.2578363988773091,1.2690477638208295,0.4171807666735684,0.2063284414613384,1655.0,4.8,1.2193158027929234,Orangeburg,0.00854955,0.0230059585432627
32.0,17.0,2.294592562651576,7.0,0.1431907582221439,18.0,6.0,103.2,0.1200605104972906,12.0,93.03132567,48794.0,124937.0,33.0,240.2,8.0,9.0,0.1581392439389452,0.0135716515363743,1.257836399145169,1.2690477640730926,0.41718076670642,0.206328441454493,1604.0,2.8,0.6803428928179802,Pickens,0.012286967,0.0080040340331527
33.0,16.0,2.294592562280499,3.0,0.1431907582204469,19.0,17.0,213.6,0.5089537262166648,11.0,100.3088715,52611.0,414576.0,27.0,507.9,14.0,21.0,0.1581392439745667,0.0135716515331326,1.2578363993574158,1.2690477637393385,0.4171807666145652,0.2063284414437883,5579.0,2.8,2.814924163482691,Richland,0.040771601,0.142314075103238


In [13]:
df.where( df['label'].isNull() ).count()

In [14]:
df.where( df['pop_fraction'].isNull() ).count()

In [15]:
#define a function to identify the count of nulls in the columns
def count_nulls(df):
    null_counts = []          #make an empty list to hold our results
    for col in df.dtypes:     #iterate through the column data types we saw above, e.g. ('C0', 'bigint')
        cname = col[0]        #splits out the column name, e.g. 'C0'    
        ctype = col[1]        #splits out the column type, e.g. 'bigint'
        if ctype != 'string': #skip processing string columns for efficiency (can't have nulls)
            nulls = df.where( df[cname].isNull() ).count()
            result = tuple([cname, nulls])  #new tuple, (column name, null count)
            null_counts.append(result)      #put the new tuple in our result list
    return null_counts

null_counts = count_nulls(df)

In [16]:
null_counts

In [17]:
display(df)

Adult_obesity_percentage,Adult_smoking_percentage,All_Specialties_AAMC,Clinical_Care_rank,Emergency_Medicine_specialists_2019,Excessive_drinking_percentage,Health_Behaviors_rank,House_per_sq_mile,ICU_Beds,Length_of_Life_rank,Med_HH_Income_Percent_of_State_Total_2018,Median_Household_Income_2018,POP_ESTIMATE_2018,Physical_Environment_rank,Population_per_sq_mile,Quality_of_Life_rank,Social_Economic_Factors_rank,Surgery_specialists_2019,Total_Hospitals_2019,Total_Primary_Care_Physicians_2019,Total_Specialist_Physicians_2019,Total_nurse_practitioners_2019,Total_physician_assistants_2019,Unemployed_2019,Unemployment_rate_2019,cases,county,pop_fraction,label
40.0,21.0,2.2945925621281065,41.0,0.1431907581962431,13.0,44.0,19.0,0.0,41.0,63.29767965,33199.0,17142.0,41.0,46.9,46.0,42.0,0.1581392439621981,0.0135716515575778,1.257836399486641,1.269047763971532,0.4171807667133357,0.2063284414304048,266.0,4.0,7.117022517792556,Lee,0.001685835,0.5833625014584063
32.0,16.0,2.2945925621627485,6.0,0.1431907582228368,19.0,11.0,163.0,0.1864204560861194,9.0,115.5922897,60627.0,295032.0,7.0,375.4,10.0,3.0,0.158139243980314,0.013571651532715,1.2578363994414166,1.2690477636324196,0.4171807668320725,0.2063284414571978,3510.0,2.3,1.667615716261287,Lexington,0.029015011,0.0847365709482361
44.0,20.0,2.294592562260382,44.0,0.1431907582074164,14.0,41.0,30.6,0.322175327813396,40.0,65.52079163,34365.0,31039.0,19.0,67.6,45.0,43.0,0.1581392439511582,0.0135716515351654,1.2578363993685362,1.2690477637810496,0.4171807667772801,0.2063284414446341,564.0,4.3,1.5464415735043011,Marion,0.00305254,0.1288701311253584
39.0,21.0,2.294592562315327,39.0,0.1431907582013789,15.0,40.0,25.2,0.0,43.0,63.93639536,33534.0,26398.0,44.0,60.3,39.0,45.0,0.1581392439578756,0.0135716515266308,1.2578363993484354,1.2690477638457458,0.4171807667247518,0.2063284414349571,444.0,4.8,2.272899462080461,Marlboro,0.002596119,0.0
34.0,17.0,2.294592562167906,4.0,0.1431907582359192,14.0,9.0,15.2,0.0,46.0,85.19323533,44683.0,9410.0,13.0,28.5,25.0,34.0,0.1581392439957492,0.0135716515409139,1.2578363995749204,1.2690477640807651,0.4171807667375132,0.206328441445271,104.0,3.1,0.7438894792773645,McCormick,0.000925429,0.106269925611052
36.0,18.0,2.294592562305296,21.0,0.1431907582294911,16.0,23.0,28.4,0.181723779854621,10.0,91.57276592,48029.0,38520.0,16.0,59.5,21.0,15.0,0.1581392439771547,0.0135716515316718,1.2578363995327102,1.269047763759086,0.4171807666147455,0.2063284414330218,514.0,2.7,0.8047767393561787,Newberry,0.003788261,0.0259605399792315
32.0,17.0,2.2945925625845303,12.0,0.1431907581595937,16.0,14.0,61.9,0.1531120014290453,16.0,96.33930104,50529.0,78374.0,15.0,118.6,4.0,12.0,0.1581392440094929,0.013571651529844,1.257836399443693,1.2690477637992192,0.4171807667083471,0.2063284414474188,971.0,2.8,0.4976140046443973,Oconee,0.007707715,0.0
41.0,17.0,2.294592562173603,26.0,0.1431907581613637,13.0,27.0,38.4,0.230059585432627,34.0,70.80020591,37134.0,86934.0,37.0,83.6,38.0,38.0,0.1581392439091724,0.0135716515287459,1.2578363988773091,1.2690477638208295,0.4171807666735684,0.2063284414613384,1655.0,4.8,1.2193158027929234,Orangeburg,0.00854955,0.0230059585432627
32.0,17.0,2.294592562651576,7.0,0.1431907582221439,18.0,6.0,103.2,0.1200605104972906,12.0,93.03132567,48794.0,124937.0,33.0,240.2,8.0,9.0,0.1581392439389452,0.0135716515363743,1.257836399145169,1.2690477640730926,0.41718076670642,0.206328441454493,1604.0,2.8,0.6803428928179802,Pickens,0.012286967,0.0080040340331527
33.0,16.0,2.294592562280499,3.0,0.1431907582204469,19.0,17.0,213.6,0.5089537262166648,11.0,100.3088715,52611.0,414576.0,27.0,507.9,14.0,21.0,0.1581392439745667,0.0135716515331326,1.2578363993574158,1.2690477637393385,0.4171807666145652,0.2063284414437883,5579.0,2.8,2.814924163482691,Richland,0.040771601,0.142314075103238


In [18]:

df = df.drop("pop_fraction")
df_temp=df_temp.drop("pop_fraction")

In [19]:
df_described = df.describe()
df_described.show()

In [20]:
from pyspark.sql.functions import skewness, kurtosis
from pyspark.sql.functions import var_pop, var_samp, stddev, stddev_pop, sumDistinct, ntile
df.select(skewness('label')).show()


In [21]:
from pyspark.sql import Row

columns = df_described.columns  #list of column names
funcs   = [skewness, kurtosis]  #list of functions we want to include (imported earlier)
fnames  = ['skew', 'kurtosis']  #a list of strings describing the functions in the same order

def new_item(func, column):
    """
    This function takes in an aggregation function and a column name, then applies the aggregation to the
    column, collects it and returns a value.  The value is in string format despite being a number, 
    because that matches the output of describe.
    """
    return str(df.select(func(column)).collect()[0][0])

new_data = []
for func, fname in zip(funcs, fnames):
    row_dict = {'summary':fname}  #each row object begins with an entry for "summary"
    for column in columns[1:]:
        row_dict[column] = new_item(func, column)
    new_data.append(Row(**row_dict))  #using ** tells Python to unpack the entries of the dictionary
    
print(new_data)

In [22]:
df_described.collect()

In [23]:
new_describe = sc.parallelize(new_data).toDF()           #turns the results from our loop into a dataframe
new_describe = new_describe.select(df_described.columns) #forces the columns into the same order

expanded_describe = df_described.unionAll(new_describe)  #merges the new stats with the original describe
expanded_describe.show()


In [24]:
label = df[['label']].collect()


In [25]:
print(label[:5])


In [26]:
df.columns


In [27]:
df.dtypes

In [28]:
#drop county since it is string and is not a numerical feature
df = df.drop("county")
df_temp=df_temp.drop("county")


In [29]:
df = df.dropna()
df_temp=df_temp.dropna()

In [30]:
display(df)

Adult_obesity_percentage,Adult_smoking_percentage,All_Specialties_AAMC,Clinical_Care_rank,Emergency_Medicine_specialists_2019,Excessive_drinking_percentage,Health_Behaviors_rank,House_per_sq_mile,ICU_Beds,Length_of_Life_rank,Med_HH_Income_Percent_of_State_Total_2018,Median_Household_Income_2018,POP_ESTIMATE_2018,Physical_Environment_rank,Population_per_sq_mile,Quality_of_Life_rank,Social_Economic_Factors_rank,Surgery_specialists_2019,Total_Hospitals_2019,Total_Primary_Care_Physicians_2019,Total_Specialist_Physicians_2019,Total_nurse_practitioners_2019,Total_physician_assistants_2019,Unemployed_2019,Unemployment_rate_2019,cases,label
40.0,21.0,2.2945925621281065,41.0,0.1431907581962431,13.0,44.0,19.0,0.0,41.0,63.29767965,33199.0,17142.0,41.0,46.9,46.0,42.0,0.1581392439621981,0.0135716515575778,1.257836399486641,1.269047763971532,0.4171807667133357,0.2063284414304048,266.0,4.0,7.117022517792556,0.5833625014584063
32.0,16.0,2.2945925621627485,6.0,0.1431907582228368,19.0,11.0,163.0,0.1864204560861194,9.0,115.5922897,60627.0,295032.0,7.0,375.4,10.0,3.0,0.158139243980314,0.013571651532715,1.2578363994414166,1.2690477636324196,0.4171807668320725,0.2063284414571978,3510.0,2.3,1.667615716261287,0.0847365709482361
44.0,20.0,2.294592562260382,44.0,0.1431907582074164,14.0,41.0,30.6,0.322175327813396,40.0,65.52079163,34365.0,31039.0,19.0,67.6,45.0,43.0,0.1581392439511582,0.0135716515351654,1.2578363993685362,1.2690477637810496,0.4171807667772801,0.2063284414446341,564.0,4.3,1.5464415735043011,0.1288701311253584
39.0,21.0,2.294592562315327,39.0,0.1431907582013789,15.0,40.0,25.2,0.0,43.0,63.93639536,33534.0,26398.0,44.0,60.3,39.0,45.0,0.1581392439578756,0.0135716515266308,1.2578363993484354,1.2690477638457458,0.4171807667247518,0.2063284414349571,444.0,4.8,2.272899462080461,0.0
34.0,17.0,2.294592562167906,4.0,0.1431907582359192,14.0,9.0,15.2,0.0,46.0,85.19323533,44683.0,9410.0,13.0,28.5,25.0,34.0,0.1581392439957492,0.0135716515409139,1.2578363995749204,1.2690477640807651,0.4171807667375132,0.206328441445271,104.0,3.1,0.7438894792773645,0.106269925611052
36.0,18.0,2.294592562305296,21.0,0.1431907582294911,16.0,23.0,28.4,0.181723779854621,10.0,91.57276592,48029.0,38520.0,16.0,59.5,21.0,15.0,0.1581392439771547,0.0135716515316718,1.2578363995327102,1.269047763759086,0.4171807666147455,0.2063284414330218,514.0,2.7,0.8047767393561787,0.0259605399792315
32.0,17.0,2.2945925625845303,12.0,0.1431907581595937,16.0,14.0,61.9,0.1531120014290453,16.0,96.33930104,50529.0,78374.0,15.0,118.6,4.0,12.0,0.1581392440094929,0.013571651529844,1.257836399443693,1.2690477637992192,0.4171807667083471,0.2063284414474188,971.0,2.8,0.4976140046443973,0.0
41.0,17.0,2.294592562173603,26.0,0.1431907581613637,13.0,27.0,38.4,0.230059585432627,34.0,70.80020591,37134.0,86934.0,37.0,83.6,38.0,38.0,0.1581392439091724,0.0135716515287459,1.2578363988773091,1.2690477638208295,0.4171807666735684,0.2063284414613384,1655.0,4.8,1.2193158027929234,0.0230059585432627
32.0,17.0,2.294592562651576,7.0,0.1431907582221439,18.0,6.0,103.2,0.1200605104972906,12.0,93.03132567,48794.0,124937.0,33.0,240.2,8.0,9.0,0.1581392439389452,0.0135716515363743,1.257836399145169,1.2690477640730926,0.41718076670642,0.206328441454493,1604.0,2.8,0.6803428928179802,0.0080040340331527
33.0,16.0,2.294592562280499,3.0,0.1431907582204469,19.0,17.0,213.6,0.5089537262166648,11.0,100.3088715,52611.0,414576.0,27.0,507.9,14.0,21.0,0.1581392439745667,0.0135716515331326,1.2578363993574158,1.2690477637393385,0.4171807666145652,0.2063284414437883,5579.0,2.8,2.814924163482691,0.142314075103238


In [31]:
df.count()

In [32]:
df_temp.count()

In [33]:
from pyspark.ml.linalg import Vectors
from pyspark.ml import Pipeline
from pyspark.ml.regression import GBTRegressor, GeneralizedLinearRegression, AFTSurvivalRegression
from pyspark.ml.feature import VectorIndexer

from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql.types import DoubleType
 
from pyspark.ml.feature import StringIndexer, VectorAssembler, OneHotEncoder  

In [34]:
valuableColumns = list(df.columns)

In [35]:
valuableColumns

In [36]:
from pyspark.ml import Pipeline
from pyspark.ml.regression import DecisionTreeRegressor
from pyspark.ml.feature import VectorIndexer
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.sql import Row
from pyspark.ml.linalg import Vectors

In [37]:
df=df.rdd.map(lambda x:(Vectors.dense(x[0:-1]), x[-1])).toDF(["features", "label"])
df.show()

In [38]:
# Automatically identify categorical features, and index them.
# We specify maxCategories so features with > 4 distinct values are treated as continuous.
featureIndexer = VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(df)

In [39]:
# Split the data into training and test sets (30% held out for testing)
(trainingData, testData) = df.randomSplit([0.7, 0.3])

In [40]:
# Train a DecisionTree model.
dt = DecisionTreeRegressor(featuresCol="indexedFeatures")

In [41]:
# Chain indexer and tree in a Pipeline
pipeline = Pipeline(stages=[featureIndexer, dt])

In [42]:
# Train model.  This also runs the indexer.
model = pipeline.fit(trainingData)


In [43]:
# Make predictions.
predictions = model.transform(testData)

# Select example rows to display.
predictions.select("prediction", "label", "features").show(5)

In [44]:
display(predictions)

features,label,indexedFeatures,prediction
"List(1, 26, List(), List(22.0, 12.0, 2.2483920462417006, 3.0, 0.11354670752909125, 25.0, 18.0, 445.6, 0.23543675520103016, 5.0, 125.7669218, 76255.0, 1248743.0, 215.0, 1034.4, 9.0, 30.0, 0.12114203811352696, 0.01822182511533598, 1.069757014053332, 1.1810390584772046, 0.34492556138452823, 0.20357576314742104, 19047.0, 2.6, 1.8322425030610783))",0.0576579808655584,"List(1, 26, List(), List(22.0, 12.0, 2.2483920462417006, 3.0, 0.11354670752909125, 25.0, 18.0, 445.6, 0.23543675520103016, 5.0, 125.7669218, 76255.0, 1248743.0, 215.0, 1034.4, 9.0, 30.0, 0.12114203811352696, 0.01822182511533598, 1.069757014053332, 1.1810390584772046, 0.34492556138452823, 0.20357576314742104, 19047.0, 2.6, 1.8322425030610783))",0.0937689851019587
"List(1, 26, List(), List(22.0, 15.0, 2.530720043240164, 1.0, 0.12289494400787127, 16.0, 1.0, 117.6, 0.12083079804426723, 1.0, 221.3841042, 115930.0, 231729.0, 39.0, 314.4, 1.0, 1.0, 0.17326414583414246, 0.016986680963539304, 1.3110763499605143, 1.492316850286326, 0.8490386278799805, 0.17607064096422978, 3158.0, 2.4, 1.9548696969304662))",0.0431538564443811,"List(1, 26, List(), List(22.0, 15.0, 2.530720043240164, 1.0, 0.12289494400787127, 16.0, 1.0, 117.6, 0.12083079804426723, 1.0, 221.3841042, 115930.0, 231729.0, 39.0, 314.4, 1.0, 1.0, 0.17326414583414246, 0.016986680963539304, 1.3110763499605143, 1.492316850286326, 0.8490386278799805, 0.17607064096422978, 3158.0, 2.4, 1.9548696969304662))",0.0937689851019587
"List(1, 26, List(), List(22.0, 16.0, 2.248392046076314, 234.0, 0.11354670770338372, 15.0, 74.0, 1.1, 0.0, 109.0, 63.17456129, 38304.0, 1389.0, 2.0, 1.7, 189.0, 220.0, 0.12114203815694746, 0.01822182505399568, 1.0697570136789059, 1.1810390583153347, 0.3449255615550756, 0.20357576313894887, 23.0, 4.2, 2.8797696184305255))",0.7199424046076314,"List(1, 26, List(), List(22.0, 16.0, 2.248392046076314, 234.0, 0.11354670770338372, 15.0, 74.0, 1.1, 0.0, 109.0, 63.17456129, 38304.0, 1389.0, 2.0, 1.7, 189.0, 220.0, 0.12114203815694746, 0.01822182505399568, 1.0697570136789059, 1.1810390583153347, 0.3449255615550756, 0.20357576313894887, 23.0, 4.2, 2.8797696184305255))",0.4894762604013705
"List(1, 26, List(), List(22.0, 23.0, 2.530720043588812, 90.0, 0.12289494406102433, 17.0, 34.0, 29.5, 0.0, 53.0, 94.10686323, 49280.0, 11012.0, 21.0, 68.9, 35.0, 25.0, 0.17326414584090083, 0.016986680984380675, 1.311076350345078, 1.4923168507083182, 0.8490386276788958, 0.17607064093715946, 168.0, 3.1, 125.40864511442064))",0.2724300762804214,"List(1, 26, List(), List(22.0, 23.0, 2.530720043588812, 90.0, 0.12289494406102433, 17.0, 34.0, 29.5, 0.0, 53.0, 94.10686323, 49280.0, 11012.0, 21.0, 68.9, 35.0, 25.0, 0.17326414584090083, 0.016986680984380675, 1.311076350345078, 1.4923168507083182, 0.8490386276788958, 0.17607064093715946, 168.0, 3.1, 125.40864511442064))",0.2221527891999234
"List(1, 26, List(), List(24.0, 14.0, 2.4041213511070945, 16.0, 0.06574212065070041, 18.0, 23.0, 2.3, 0.0, 31.0, 87.87188163, 50123.0, 2213.0, 40.0, 4.2, 17.0, 8.0, 0.1394186357885224, 0.06460863624039766, 1.2275640845006779, 1.0858784789877993, 0.5361383303208315, 0.4579278761861726, 37.0, 3.1, 0.4518752824220515))",0.0,"List(1, 26, List(), List(24.0, 14.0, 2.4041213511070945, 16.0, 0.06574212065070041, 18.0, 23.0, 2.3, 0.0, 31.0, 87.87188163, 50123.0, 2213.0, 40.0, 4.2, 17.0, 8.0, 0.1394186357885224, 0.06460863624039766, 1.2275640845006779, 1.0858784789877993, 0.5361383303208315, 0.4579278761861726, 37.0, 3.1, 0.4518752824220515))",0.0105825245376147
"List(1, 26, List(), List(24.0, 15.0, 2.2483920458429676, 224.0, 0.11354670755001509, 20.0, 68.0, 4.4, 0.0, 31.0, 82.20906452, 49845.0, 9947.0, 58.0, 11.1, 51.0, 29.0, 0.12114203810194027, 0.018221825072886295, 1.0697570141751283, 1.1810390590127677, 0.34492556137528907, 0.20357576314466674, 117.0, 2.7, 4.021312958681009))",0.1005328239670252,"List(1, 26, List(), List(24.0, 15.0, 2.2483920458429676, 224.0, 0.11354670755001509, 20.0, 68.0, 4.4, 0.0, 31.0, 82.20906452, 49845.0, 9947.0, 58.0, 11.1, 51.0, 29.0, 0.12114203810194027, 0.018221825072886295, 1.0697570141751283, 1.1810390590127677, 0.34492556137528907, 0.20357576314466674, 117.0, 2.7, 4.021312958681009))",0.4894762604013705
"List(1, 26, List(), List(25.0, 13.0, 2.2483920452958106, 9.0, 0.1135467075328497, 22.0, 16.0, 291.6, 0.09428866766620415, 3.0, 145.7712099, 88384.0, 859064.0, 159.0, 754.3, 12.0, 5.0, 0.12114203807865305, 0.01822182511431046, 1.0697570139128167, 1.1810390587895663, 0.3449255614249928, 0.20357576315617926, 14619.0, 3.0, 1.1512529916280976))",0.0291014406377173,"List(1, 26, List(), List(25.0, 13.0, 2.2483920452958106, 9.0, 0.1135467075328497, 22.0, 16.0, 291.6, 0.09428866766620415, 3.0, 145.7712099, 88384.0, 859064.0, 159.0, 754.3, 12.0, 5.0, 0.12114203807865305, 0.01822182511431046, 1.0697570139128167, 1.1810390587895663, 0.3449255614249928, 0.20357576315617926, 14619.0, 3.0, 1.1512529916280976))",0.0384579348152506
"List(1, 26, List(), List(25.0, 13.0, 2.248392046307885, 241.0, 0.11354670755110553, 19.0, 21.0, 2.1, 0.0, 51.0, 102.2347935, 61987.0, 4794.0, 53.0, 5.6, 52.0, 62.0, 0.12114203817271588, 0.01822182519816437, 1.0697570139758033, 1.1810390586149355, 0.34492556132665836, 0.2035757632457238, 63.0, 4.0, 0.4171881518564873))",0.0,"List(1, 26, List(), List(25.0, 13.0, 2.248392046307885, 241.0, 0.11354670755110553, 19.0, 21.0, 2.1, 0.0, 51.0, 102.2347935, 61987.0, 4794.0, 53.0, 5.6, 52.0, 62.0, 0.12114203817271588, 0.01822182519816437, 1.0697570139758033, 1.1810390586149355, 0.34492556132665836, 0.2035757632457238, 63.0, 4.0, 0.4171881518564873))",0.0105825245376147
"List(1, 26, List(), List(25.0, 15.0, 2.2483920459895974, 213.0, 0.11354670763755818, 18.0, 73.0, 2.4, 0.0, 237.0, 65.5198575, 39726.0, 3653.0, 65.0, 4.4, 124.0, 85.0, 0.12114203805091707, 0.01822182507528059, 1.0697570139611279, 1.1810390585819874, 0.34492556145633724, 0.20357576320832194, 47.0, 3.3, 0.2737476047084588))",0.0,"List(1, 26, List(), List(25.0, 15.0, 2.2483920459895974, 213.0, 0.11354670763755818, 18.0, 73.0, 2.4, 0.0, 237.0, 65.5198575, 39726.0, 3653.0, 65.0, 4.4, 124.0, 85.0, 0.12114203805091707, 0.01822182507528059, 1.0697570139611279, 1.1810390585819874, 0.34492556145633724, 0.20357576320832194, 47.0, 3.3, 0.2737476047084588))",0.0105825245376147
"List(1, 26, List(), List(25.0, 18.0, 2.24839204583921, 215.0, 0.11354670768688294, 15.0, 123.0, 1.8, 0.0, 243.0, 70.71018604, 42873.0, 2836.0, 84.0, 4.0, 230.0, 202.0, 0.12114203808180535, 0.018221825105782793, 1.0697570137517631, 1.1810390585331454, 0.3449255613540198, 0.20357576304654443, 38.0, 3.4, 0.3526093088857546))",0.0,"List(1, 26, List(), List(25.0, 18.0, 2.24839204583921, 215.0, 0.11354670768688294, 15.0, 123.0, 1.8, 0.0, 243.0, 70.71018604, 42873.0, 2836.0, 84.0, 4.0, 230.0, 202.0, 0.12114203808180535, 0.018221825105782793, 1.0697570137517631, 1.1810390585331454, 0.3449255613540198, 0.20357576304654443, 38.0, 3.4, 0.3526093088857546))",0.0105825245376147


In [45]:
predictions.dtypes

In [46]:


pandas_df_temp = df_temp.toPandas()
pandas_predictions = predictions.toPandas()


In [47]:
pandas_predictions['prediction_de_nom'] = (pandas_predictions['prediction']*pandas_df_temp['POP_ESTIMATE_2018'])/1000


In [48]:
pandas_predictions['label_de_nom'] = (pandas_predictions['label']*pandas_df_temp['POP_ESTIMATE_2018'])/1000

In [49]:
pandas_predictions['prediction_de_nom'] = pandas_predictions['prediction_de_nom'].apply(lambda x: round(x))

In [50]:
pandas_predictions['label_de_nom'] = pandas_predictions['label_de_nom'].apply(lambda x: round(x))

In [51]:
display(pandas_predictions)

features,label,indexedFeatures,prediction,prediction_de_nom,label_de_nom
"List(1, 26, List(), List(22.0, 12.0, 2.2483920462417006, 3.0, 0.11354670752909125, 25.0, 18.0, 445.6, 0.23543675520103016, 5.0, 125.7669218, 76255.0, 1248743.0, 215.0, 1034.4, 9.0, 30.0, 0.12114203811352696, 0.01822182511533598, 1.069757014053332, 1.1810390584772046, 0.34492556138452823, 0.20357576314742104, 19047.0, 2.6, 1.8322425030610783))",0.0576579808655584,"List(1, 26, List(), List(22.0, 12.0, 2.2483920462417006, 3.0, 0.11354670752909125, 25.0, 18.0, 445.6, 0.23543675520103016, 5.0, 125.7669218, 76255.0, 1248743.0, 215.0, 1034.4, 9.0, 30.0, 0.12114203811352696, 0.01822182511533598, 1.069757014053332, 1.1810390584772046, 0.34492556138452823, 0.20357576314742104, 19047.0, 2.6, 1.8322425030610783))",0.0937689851019587,2,1
"List(1, 26, List(), List(22.0, 15.0, 2.530720043240164, 1.0, 0.12289494400787127, 16.0, 1.0, 117.6, 0.12083079804426723, 1.0, 221.3841042, 115930.0, 231729.0, 39.0, 314.4, 1.0, 1.0, 0.17326414583414246, 0.016986680963539304, 1.3110763499605143, 1.492316850286326, 0.8490386278799805, 0.17607064096422978, 3158.0, 2.4, 1.9548696969304662))",0.0431538564443811,"List(1, 26, List(), List(22.0, 15.0, 2.530720043240164, 1.0, 0.12289494400787127, 16.0, 1.0, 117.6, 0.12083079804426723, 1.0, 221.3841042, 115930.0, 231729.0, 39.0, 314.4, 1.0, 1.0, 0.17326414583414246, 0.016986680963539304, 1.3110763499605143, 1.492316850286326, 0.8490386278799805, 0.17607064096422978, 3158.0, 2.4, 1.9548696969304662))",0.0937689851019587,28,13
"List(1, 26, List(), List(22.0, 16.0, 2.248392046076314, 234.0, 0.11354670770338372, 15.0, 74.0, 1.1, 0.0, 109.0, 63.17456129, 38304.0, 1389.0, 2.0, 1.7, 189.0, 220.0, 0.12114203815694746, 0.01822182505399568, 1.0697570136789059, 1.1810390583153347, 0.3449255615550756, 0.20357576313894887, 23.0, 4.2, 2.8797696184305255))",0.7199424046076314,"List(1, 26, List(), List(22.0, 16.0, 2.248392046076314, 234.0, 0.11354670770338372, 15.0, 74.0, 1.1, 0.0, 109.0, 63.17456129, 38304.0, 1389.0, 2.0, 1.7, 189.0, 220.0, 0.12114203815694746, 0.01822182505399568, 1.0697570136789059, 1.1810390583153347, 0.3449255615550756, 0.20357576313894887, 23.0, 4.2, 2.8797696184305255))",0.4894762604013705,15,22
"List(1, 26, List(), List(22.0, 23.0, 2.530720043588812, 90.0, 0.12289494406102433, 17.0, 34.0, 29.5, 0.0, 53.0, 94.10686323, 49280.0, 11012.0, 21.0, 68.9, 35.0, 25.0, 0.17326414584090083, 0.016986680984380675, 1.311076350345078, 1.4923168507083182, 0.8490386276788958, 0.17607064093715946, 168.0, 3.1, 125.40864511442064))",0.2724300762804214,"List(1, 26, List(), List(22.0, 23.0, 2.530720043588812, 90.0, 0.12289494406102433, 17.0, 34.0, 29.5, 0.0, 53.0, 94.10686323, 49280.0, 11012.0, 21.0, 68.9, 35.0, 25.0, 0.17326414584090083, 0.016986680984380675, 1.311076350345078, 1.4923168507083182, 0.8490386276788958, 0.17607064093715946, 168.0, 3.1, 125.40864511442064))",0.2221527891999234,6,7
"List(1, 26, List(), List(24.0, 14.0, 2.4041213511070945, 16.0, 0.06574212065070041, 18.0, 23.0, 2.3, 0.0, 31.0, 87.87188163, 50123.0, 2213.0, 40.0, 4.2, 17.0, 8.0, 0.1394186357885224, 0.06460863624039766, 1.2275640845006779, 1.0858784789877993, 0.5361383303208315, 0.4579278761861726, 37.0, 3.1, 0.4518752824220515))",0.0,"List(1, 26, List(), List(24.0, 14.0, 2.4041213511070945, 16.0, 0.06574212065070041, 18.0, 23.0, 2.3, 0.0, 31.0, 87.87188163, 50123.0, 2213.0, 40.0, 4.2, 17.0, 8.0, 0.1394186357885224, 0.06460863624039766, 1.2275640845006779, 1.0858784789877993, 0.5361383303208315, 0.4579278761861726, 37.0, 3.1, 0.4518752824220515))",0.0105825245376147,0,0
"List(1, 26, List(), List(24.0, 15.0, 2.2483920458429676, 224.0, 0.11354670755001509, 20.0, 68.0, 4.4, 0.0, 31.0, 82.20906452, 49845.0, 9947.0, 58.0, 11.1, 51.0, 29.0, 0.12114203810194027, 0.018221825072886295, 1.0697570141751283, 1.1810390590127677, 0.34492556137528907, 0.20357576314466674, 117.0, 2.7, 4.021312958681009))",0.1005328239670252,"List(1, 26, List(), List(24.0, 15.0, 2.2483920458429676, 224.0, 0.11354670755001509, 20.0, 68.0, 4.4, 0.0, 31.0, 82.20906452, 49845.0, 9947.0, 58.0, 11.1, 51.0, 29.0, 0.12114203810194027, 0.018221825072886295, 1.0697570141751283, 1.1810390590127677, 0.34492556137528907, 0.20357576314466674, 117.0, 2.7, 4.021312958681009))",0.4894762604013705,19,4
"List(1, 26, List(), List(25.0, 13.0, 2.2483920452958106, 9.0, 0.1135467075328497, 22.0, 16.0, 291.6, 0.09428866766620415, 3.0, 145.7712099, 88384.0, 859064.0, 159.0, 754.3, 12.0, 5.0, 0.12114203807865305, 0.01822182511431046, 1.0697570139128167, 1.1810390587895663, 0.3449255614249928, 0.20357576315617926, 14619.0, 3.0, 1.1512529916280976))",0.0291014406377173,"List(1, 26, List(), List(25.0, 13.0, 2.2483920452958106, 9.0, 0.1135467075328497, 22.0, 16.0, 291.6, 0.09428866766620415, 3.0, 145.7712099, 88384.0, 859064.0, 159.0, 754.3, 12.0, 5.0, 0.12114203807865305, 0.01822182511431046, 1.0697570139128167, 1.1810390587895663, 0.3449255614249928, 0.20357576315617926, 14619.0, 3.0, 1.1512529916280976))",0.0384579348152506,3,2
"List(1, 26, List(), List(25.0, 13.0, 2.248392046307885, 241.0, 0.11354670755110553, 19.0, 21.0, 2.1, 0.0, 51.0, 102.2347935, 61987.0, 4794.0, 53.0, 5.6, 52.0, 62.0, 0.12114203817271588, 0.01822182519816437, 1.0697570139758033, 1.1810390586149355, 0.34492556132665836, 0.2035757632457238, 63.0, 4.0, 0.4171881518564873))",0.0,"List(1, 26, List(), List(25.0, 13.0, 2.248392046307885, 241.0, 0.11354670755110553, 19.0, 21.0, 2.1, 0.0, 51.0, 102.2347935, 61987.0, 4794.0, 53.0, 5.6, 52.0, 62.0, 0.12114203817271588, 0.01822182519816437, 1.0697570139758033, 1.1810390586149355, 0.34492556132665836, 0.2035757632457238, 63.0, 4.0, 0.4171881518564873))",0.0105825245376147,1,0
"List(1, 26, List(), List(25.0, 15.0, 2.2483920459895974, 213.0, 0.11354670763755818, 18.0, 73.0, 2.4, 0.0, 237.0, 65.5198575, 39726.0, 3653.0, 65.0, 4.4, 124.0, 85.0, 0.12114203805091707, 0.01822182507528059, 1.0697570139611279, 1.1810390585819874, 0.34492556145633724, 0.20357576320832194, 47.0, 3.3, 0.2737476047084588))",0.0,"List(1, 26, List(), List(25.0, 15.0, 2.2483920459895974, 213.0, 0.11354670763755818, 18.0, 73.0, 2.4, 0.0, 237.0, 65.5198575, 39726.0, 3653.0, 65.0, 4.4, 124.0, 85.0, 0.12114203805091707, 0.01822182507528059, 1.0697570139611279, 1.1810390585819874, 0.34492556145633724, 0.20357576320832194, 47.0, 3.3, 0.2737476047084588))",0.0105825245376147,1,0
"List(1, 26, List(), List(25.0, 18.0, 2.24839204583921, 215.0, 0.11354670768688294, 15.0, 123.0, 1.8, 0.0, 243.0, 70.71018604, 42873.0, 2836.0, 84.0, 4.0, 230.0, 202.0, 0.12114203808180535, 0.018221825105782793, 1.0697570137517631, 1.1810390585331454, 0.3449255613540198, 0.20357576304654443, 38.0, 3.4, 0.3526093088857546))",0.0,"List(1, 26, List(), List(25.0, 18.0, 2.24839204583921, 215.0, 0.11354670768688294, 15.0, 123.0, 1.8, 0.0, 243.0, 70.71018604, 42873.0, 2836.0, 84.0, 4.0, 230.0, 202.0, 0.12114203808180535, 0.018221825105782793, 1.0697570137517631, 1.1810390585331454, 0.3449255613540198, 0.20357576304654443, 38.0, 3.4, 0.3526093088857546))",0.0105825245376147,4,0


In [52]:
# Select (prediction, true label) and compute test error
evaluator = RegressionEvaluator(
    labelCol="label", predictionCol="prediction", metricName="rmse")
rmse = evaluator.evaluate(predictions)
print("Root Mean Squared Error (RMSE) on test data = %g" % rmse)

In [53]:
treeModel = model.stages[1]
# summary only
print(treeModel)

In [54]:
display(treeModel)

treeNode
"{""index"":31,""featureType"":""continuous"",""prediction"":null,""threshold"":4.354532061735112,""categories"":null,""feature"":25,""overflow"":false}"
"{""index"":15,""featureType"":""continuous"",""prediction"":null,""threshold"":1.7087879475080885,""categories"":null,""feature"":25,""overflow"":false}"
"{""index"":7,""featureType"":""continuous"",""prediction"":null,""threshold"":0.7818929879908421,""categories"":null,""feature"":25,""overflow"":false}"
"{""index"":3,""featureType"":""continuous"",""prediction"":null,""threshold"":0.2637752502671016,""categories"":null,""feature"":25,""overflow"":false}"
"{""index"":1,""featureType"":""continuous"",""prediction"":null,""threshold"":32778.0,""categories"":null,""feature"":12,""overflow"":false}"
"{""index"":0,""featureType"":null,""prediction"":0.0012319038986450127,""threshold"":null,""categories"":null,""feature"":null,""overflow"":false}"
"{""index"":2,""featureType"":null,""prediction"":0.007280124193144118,""threshold"":null,""categories"":null,""feature"":null,""overflow"":false}"
"{""index"":5,""featureType"":""continuous"",""prediction"":null,""threshold"":4.05,""categories"":null,""feature"":24,""overflow"":false}"
"{""index"":4,""featureType"":null,""prediction"":0.010582524537614709,""threshold"":null,""categories"":null,""feature"":null,""overflow"":false}"
"{""index"":6,""featureType"":null,""prediction"":0.020503345212901206,""threshold"":null,""categories"":null,""feature"":null,""overflow"":false}"


In [55]:
#print the nodes of the decision tree model
print(treeModel.toDebugString)

In [56]:
def extract_feature_imp(feature_imp, dataset, features_col): 
  """Affiche pour chaque features d'un modèle son nom correspondant et les trie dans l'ordre d'importance""" 
  list_extract = [] 
  for i in dataset.schema[features_col].metadata["ml_attr"]["attrs"]: 
    list_extract = list_extract + dataset.schema[features_col].metadata["ml_attr"]["attrs"][i] 
    varlist = pd.DataFrame(list_extract) 
    varlist['score'] = varlist['idx'].apply(lambda x: feature_imp[x]) * 100 
    return varlist.sort_values('score', ascending=False)

In [57]:
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

In [58]:
grid = ParamGridBuilder() \
  .addGrid(dt.maxDepth, [2, 3, 4, 5, 6, 7, 8]) \
  .addGrid(dt.maxBins, [2, 4, 8]) \
  .build()

In [59]:
cv = CrossValidator(estimator=pipeline, evaluator=evaluator, estimatorParamMaps=grid, numFolds=3)

In [60]:
# Explicitly create a new run.
# This allows this cell to be run multiple times.
# If you omit mlflow.start_run(), then this cell could run once,
# but a second run would hit conflicts when attempting to overwrite the first run.
import mlflow
import mlflow
import mlflow.mleap
import pyspark
#import pyspark.ml.mleap.SparkUtil 
#import mlflow.mleap.SparkUtil 
import mlflow.mleap
with mlflow.start_run():
  cvModel = cv.fit(trainingData)
  mlflow.set_tag('owner_team', 'UX Data Science') # Logs user-defined tags
  test_metric = evaluator.evaluate(cvModel.transform(testData))
  mlflow.log_metric('testData_' + evaluator.getMetricName(), test_metric) # Logs additional metrics
  mlflow.mleap.log_model(spark_model=cvModel.bestModel, sample_input=testData, artifact_path='dbfs:/databricks/mlflow/2835302286394144') # Logs the best model via mleap
