In [1]:
sc

# Introducing MLlib

In [2]:
import pyspark.sql.types as typ
import pyspark.sql.functions as fn

In [3]:
# Relative path to the gzip-compressed CSV of birth records that is read
# into `df_birth` further down.
birthFilePath = './data/births_train.csv.gz'

In [4]:
# (column name, Spark SQL type) pairs used to build the explicit read schema
# for the births CSV. Presumably listed in the same order as the columns in
# the file -- confirm against the CSV header.
labels = [
    ('INFANT_ALIVE_AT_REPORT', typ.StringType()),
    ('BIRTH_YEAR', typ.IntegerType()),
    ('BIRTH_MONTH', typ.IntegerType()),
    ('BIRTH_PLACE', typ.StringType()),
    ('MOTHER_AGE_YEARS', typ.IntegerType()),
    ('MOTHER_RACE_6CODE', typ.StringType()),
    ('MOTHER_EDUCATION', typ.StringType()),
    ('FATHER_COMBINED_AGE', typ.IntegerType()),
    ('FATHER_EDUCATION', typ.StringType()),
    ('MONTH_PRECARE_RECODE', typ.StringType()),
    ('CIG_BEFORE', typ.IntegerType()),
    ('CIG_1_TRI', typ.IntegerType()),
    ('CIG_2_TRI', typ.IntegerType()),
    ('CIG_3_TRI', typ.IntegerType()),
    ('MOTHER_HEIGHT_IN', typ.IntegerType()),
    ('MOTHER_BMI_RECODE', typ.IntegerType()),
    ('MOTHER_PRE_WEIGHT', typ.IntegerType()),
    ('MOTHER_DELIVERY_WEIGHT', typ.IntegerType()),
    ('MOTHER_WEIGHT_GAIN', typ.IntegerType()),
    ('DIABETES_PRE', typ.StringType()),
    ('DIABETES_GEST', typ.StringType()),
    ('HYP_TENS_PRE', typ.StringType()),
    ('HYP_TENS_GEST', typ.StringType()),
    ('PREV_BIRTH_PRETERM', typ.StringType()),
    ('NO_RISK', typ.StringType()),
    ('NO_INFECTIONS_REPORTED', typ.StringType()),
    ('LABOR_IND', typ.StringType()),
    ('LABOR_AUGM', typ.StringType()),
    ('STEROIDS', typ.StringType()),
    ('ANTIBIOTICS', typ.StringType()),
    ('ANESTHESIA', typ.StringType()),
    ('DELIV_METHOD_RECODE_COMB', typ.StringType()),
    ('ATTENDANT_BIRTH', typ.StringType()),
    ('APGAR_5', typ.IntegerType()),
    ('APGAR_5_RECODE', typ.StringType()),
    ('APGAR_10', typ.IntegerType()),
    ('APGAR_10_RECODE', typ.StringType()),
    ('INFANT_SEX', typ.StringType()),
    ('OBSTETRIC_GESTATION_WEEKS', typ.IntegerType()),
    ('INFANT_WEIGHT_GRAMS', typ.IntegerType()),
    ('INFANT_ASSIST_VENTI', typ.StringType()),
    ('INFANT_ASSIST_VENTI_6HRS', typ.StringType()),
    ('INFANT_NICU_ADMISSION', typ.StringType()),
    ('INFANT_SURFACANT', typ.StringType()),
    ('INFANT_ANTIBIOTICS', typ.StringType()),
    ('INFANT_SEIZURES', typ.StringType()),
    ('INFANT_NO_ABNORMALITIES', typ.StringType()),
    ('INFANT_ANCEPHALY', typ.StringType()),
    ('INFANT_MENINGOMYELOCELE', typ.StringType()),
    ('INFANT_LIMB_REDUCTION', typ.StringType()),
    ('INFANT_DOWN_SYNDROME', typ.StringType()),
    ('INFANT_SUSPECTED_CHROMOSOMAL_DISORDER', typ.StringType()),
    ('INFANT_NO_CONGENITAL_ANOMALIES_CHECKED', typ.StringType()),
    ('INFANT_BREASTFED', typ.StringType())
]

# Explicit read schema: one StructField per (name, type) pair in `labels`,
# each declared with nullable=False.
schema = typ.StructType(
    [typ.StructField(name, dtype, False) for name, dtype in labels]
)

In [5]:
# Load the births CSV with the hand-written schema; the first line of the
# file is treated as a header row rather than data.
df_birth = (
    spark.read
    .schema(schema)
    .option('header', True)
    .csv(birthFilePath)
)

In [6]:
# Integer codes for Yes/No/Unknown answers: 'Y' maps to 1, while 'N' and 'U'
# both map to 0.
recode_dict = {'YNU': {'Y': 1, 'N': 0, 'U': 0}}

We will drop all of the features that relate to the infant and try to
predict the infant's chances of survival based only on the features related to
the mother, the father, and the place of birth:

In [7]:
# Columns kept for modelling, in the order they will appear in
# `births_trimmed`. The first entry is the survival outcome being predicted.
selected_features = [
    # outcome
    'INFANT_ALIVE_AT_REPORT',
    # place of birth and parents' ages
    'BIRTH_PLACE', 'MOTHER_AGE_YEARS', 'FATHER_COMBINED_AGE',
    # cigarette counts
    'CIG_BEFORE', 'CIG_1_TRI', 'CIG_2_TRI', 'CIG_3_TRI',
    # mother's height and weights
    'MOTHER_HEIGHT_IN', 'MOTHER_PRE_WEIGHT',
    'MOTHER_DELIVERY_WEIGHT', 'MOTHER_WEIGHT_GAIN',
    # remaining Yes/No/Unknown indicator columns
    'DIABETES_PRE', 'DIABETES_GEST',
    'HYP_TENS_PRE', 'HYP_TENS_GEST',
    'PREV_BIRTH_PRETERM',
]

# Project the full frame down to just the selected columns.
births_trimmed = df_birth.select(*selected_features)

Number of cigarettes smoked:
- 0 means the mother smoked no cigarettes before or during the pregnancy;
- a value between 1 and 97 is the actual number of cigarettes smoked;
- 98 means 98 or more.

The value 99 marks an unknown count; we will assume that an unknown count is 0 and recode accordingly.

In [8]:
# Recode the "unknown" sentinel (any value above 98, i.e. 99) in the
# cigarette-count columns to 0.
#
# The replacement must be an *integer* literal: with a string literal such as
# '0', Spark reconciles the two branch types of when/otherwise by widening to
# string, which silently turns these integer columns into string columns.
for cig_col in ['CIG_BEFORE', 'CIG_1_TRI', 'CIG_2_TRI', 'CIG_3_TRI']:
    births_trimmed = births_trimmed.withColumn(
        cig_col,
        fn.when(births_trimmed[cig_col] > 98, fn.lit(0))
          .otherwise(births_trimmed[cig_col])
    )

Now we will focus on correcting the Yes/No/Unknown features. First, we will
identify which columns these are with the following snippet:

In [9]:
# Identify the Yes/No/Unknown indicator columns: string-typed columns whose
# distinct non-null values are all drawn from {'Y', 'N', 'U'}.
#
# Each candidate column's distinct values are collected exactly once (one
# Spark job per column) and the whole set is checked, so the result does not
# depend on the arbitrary order in which distinct() happens to return rows.
YNU_VALUES = {'Y', 'N', 'U'}

cols = [(field.name, field.dataType) for field in births_trimmed.schema]
YNU_cols = []

for col_name, col_type in cols:
    if not isinstance(col_type, typ.StringType):
        continue

    distinct_values = {
        row[0]
        for row in births_trimmed.select(col_name).distinct().collect()
    }
    distinct_values.discard(None)  # missing values do not disqualify a column

    # Require at least one real value so an all-null column is not selected.
    if distinct_values and distinct_values <= YNU_VALUES:
        YNU_cols.append(col_name)

In [10]:
YNU_cols

['INFANT_ALIVE_AT_REPORT',
 'DIABETES_PRE',
 'DIABETES_GEST',
 'HYP_TENS_PRE',
 'HYP_TENS_GEST',
 'PREV_BIRTH_PRETERM']

In [11]:
# Binary-encode every Yes/No/Unknown column in place:
# 'Y' becomes 1, any other value becomes 0.
for ynu_col in YNU_cols:
    is_yes = births_trimmed[ynu_col] == 'Y'
    encoded = fn.when(is_yes, fn.lit(1)).otherwise(fn.lit(0))
    births_trimmed = births_trimmed.withColumn(ynu_col, encoded)

In [12]:
births_trimmed.show(5)

+----------------------+-----------+----------------+-------------------+----------+---------+---------+---------+----------------+-----------------+----------------------+------------------+------------+-------------+------------+-------------+------------------+
|INFANT_ALIVE_AT_REPORT|BIRTH_PLACE|MOTHER_AGE_YEARS|FATHER_COMBINED_AGE|CIG_BEFORE|CIG_1_TRI|CIG_2_TRI|CIG_3_TRI|MOTHER_HEIGHT_IN|MOTHER_PRE_WEIGHT|MOTHER_DELIVERY_WEIGHT|MOTHER_WEIGHT_GAIN|DIABETES_PRE|DIABETES_GEST|HYP_TENS_PRE|HYP_TENS_GEST|PREV_BIRTH_PRETERM|
+----------------------+-----------+----------------+-------------------+----------+---------+---------+---------+----------------+-----------------+----------------------+------------------+------------+-------------+------------+-------------+------------------+
|                     0|          1|              29|                 99|         0|        0|        0|        0|              99|              999|                   999|                99|           0| 

## Exploratory Data Analysis

Numeric variables:

In [15]:
# Columns treated as numeric in the exploratory analysis.
numeric_cols = [
    'MOTHER_AGE_YEARS',
    'FATHER_COMBINED_AGE',
    'CIG_BEFORE',
    'CIG_1_TRI',
    'CIG_2_TRI',
    'CIG_3_TRI',
    'MOTHER_HEIGHT_IN',
    'MOTHER_PRE_WEIGHT',
    'MOTHER_DELIVERY_WEIGHT',
    'MOTHER_WEIGHT_GAIN',
]

In [16]:
births_trimmed.select(numeric_cols).describe().show()

+-------+------------------+-------------------+-----------------+------------------+-----------------+------------------+------------------+------------------+----------------------+------------------+
|summary|  MOTHER_AGE_YEARS|FATHER_COMBINED_AGE|       CIG_BEFORE|         CIG_1_TRI|        CIG_2_TRI|         CIG_3_TRI|  MOTHER_HEIGHT_IN| MOTHER_PRE_WEIGHT|MOTHER_DELIVERY_WEIGHT|MOTHER_WEIGHT_GAIN|
+-------+------------------+-------------------+-----------------+------------------+-----------------+------------------+------------------+------------------+----------------------+------------------+
|  count|             45429|              45429|            45429|             45429|            45429|             45429|             45429|             45429|                 45429|             45429|
|   mean|28.298421713002707|  44.54975896453807|1.427986528428977|0.9057430275815008|0.702480794206344|0.5800259745977239|   65.120891941271|214.49840410310594|    223.62609786700125| 30.7

Categorical variables:

In [22]:
# Every column of `births_trimmed` that is not listed in `numeric_cols` is
# treated as categorical; the frame's column order is preserved.
categorical_cols = [
    name
    for name in births_trimmed.columns
    if name not in numeric_cols
]

In [23]:
categorical_cols

['INFANT_ALIVE_AT_REPORT',
 'BIRTH_PLACE',
 'DIABETES_PRE',
 'DIABETES_GEST',
 'HYP_TENS_PRE',
 'HYP_TENS_GEST',
 'PREV_BIRTH_PRETERM']