# Setup

In [None]:
!pip3 install gpy
!pip3 install git+https://github.com/BRML/climin
!pip3 install -U imbalanced-learn

In [2]:
from google.colab import drive
drive.mount('/content/gdrive/')

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).


# Finish data preprocessing
- Import mostly-preprocessed dataset (see load_vbac_data.py) 
- Check data looks right
- Add indicator columns for missing features
- Alter some features to make them more processable by the models


##  Load data

In [31]:
import pandas as pd
import os
import numpy as np
import torch
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import KFold

In [32]:
headers = ['FACILITY_RECODE', 'MOTHERS_AGE_RECODE', 'MARITAL_STATUS', 'MOTHERS_EDUCATION', 'PRIOR_BIRTHS_NOW_LIVING', 'PRIOR_BIRTHS_NOW_DEAD', 
           'PRIOR_OTHER_TERMINATIONS', 'LIVE_BIRTH_ORDER_RECODE', 'TOTAL_BIRTH_ORDER_RECODE', 'INTERVAL_SINCE_LAST_LIVE_BIRTH_RECODE', 
           'TRIMESTER_PRENATAL_CARE_BEGAN_RECODE', 'NUMBER_OF_PRENATAL_VISITS_RECODE', 'CIGARETTES_BEFORE_PREGNANCY_RECODE', 
           'CIGARETTES_FIRST_TRIMESTER_RECODE', 'CIGARETTES_SECOND_TRIMESTER_RECODE', 'CIGARETTES_THIRD_TRIMESTER_RECODE', 
           'MOTHERS_HEIGHT_IN_TOTAL_INCHES', 'MOTHERS_BMI_RECODE', 'PRE_PREGNANCY_WEIGHT_RECODE', 'DELIVERY_WEIGHT_RECODE', 'WEIGHT_GAIN', 
           'PRE_PREGNANCY_DIABETES', 'GESTATIONAL_DIABETES', 'PRE_PREGNANCY_HYPERTENSION', 'GESTATIONAL_HYPERTENSION', 
           'HYPERTENSION_ECLAMPSIA', 'PREVIOUS_PRETERM_BIRTH', 'PREVIOUS_CESAREAN', 'NUMBER_OF_PREVIOUS_CESAREANS', 'NO_INFECTIONS_REPORTED', 
           'INDUCTION_OF_LABOR', 'AUGMENTATION_OF_LABOR', 'CHORIOAMNIONITIS', 'ATTENDANT_AT_BIRTH', 'PAYMENT_SOURCE_FOR_DELIVERY', 
           'PLURALITY_RECODE', 'SEX_OF_INFANT', 'COMBINED_GESTATION_RECODE', 'BIRTH_WEIGHT_RECODE', 'TOL_ATTEMPTED', 
           'DELIVERY_METHOD_1', 'DELIVERY_METHOD_2']

result = ['successful_vbac']
headers.extend(result)

# Default values taken from UserGuide2019-508.pdf
list_of_cols_with_missing_vals_and_their_default_numb = [
        ('IN_HOSPITAL', 3),
        ('MARITAL_STATUS', 9),
        ('MOTHERS_EDUCATION', 9),
        ('PRIOR_BIRTHS_NOW_LIVING', 99),
        ('PRIOR_BIRTHS_NOW_DEAD', 99),
        ('PRIOR_OTHER_TERMINATIONS', 99),
        ('LIVE_BIRTH_ORDER_RECODE', 9),
        ('TOTAL_BIRTH_ORDER_RECODE', 9),
        ('INTERVAL_SINCE_LAST_LIVE_BIRTH_RECODE', 999),
        ('TRIMESTER_PRENATAL_CARE_BEGAN_RECODE', 5),
        ('NUMBER_OF_PRENATAL_VISITS_RECODE', 12),
        ('CIGARETTES_BEFORE_PREGNANCY_RECODE', 6),
        ('CIGARETTES_FIRST_TRIMESTER_RECODE', 6),
        ('CIGARETTES_SECOND_TRIMESTER_RECODE', 6),
        ('CIGARETTES_THIRD_TRIMESTER_RECODE', 6),
        ('MOTHERS_HEIGHT_IN_TOTAL_INCHES', 99),
        ('MOTHERS_BMI_RECODE', 9),
        ('PRE_PREGNANCY_WEIGHT_RECODE', 999),
        ('DELIVERY_WEIGHT_RECODE', 999),
        ('WEIGHT_GAIN', 99),
        ('NUMBER_OF_PREVIOUS_CESAREANS', 99),
        ('NO_INFECTIONS_REPORTED', 9),
        ('INDUCTION_OF_LABOR', -1),
        ('AUGMENTATION_OF_LABOR', -1),
        ('ATTENDANT_AT_BIRTH', 9),
        ('COMBINED_GESTATION_RECODE', 99),
        ('BIRTH_WEIGHT_RECODE', 12),
]

In [33]:
filename = '2019_vbac_data'
# Load data - edith MYPATH to contain the proper path to the dataset
MYPATH = 'gdrive/MyDrive/AA222'
data_path = os.getcwd() + f'/{MYPATH}/{filename}.csv'
birth_df = pd.read_csv(data_path, header=None, names=headers, index_col=False, skip_blank_lines=True, dtype=float)

In [34]:
X = pd.DataFrame(birth_df).iloc[:, :-1]
y = pd.DataFrame(birth_df).iloc[:, -1:] # 1 if successful VBAC, 0 if failed.

In [35]:
# Drop columns we don't want 
X = X.drop(columns=["CHORIOAMNIONITIS", "PAYMENT_SOURCE_FOR_DELIVERY"])

## Sanity check dataset

In [36]:
# Check that 100% of samples have had a prior cesarean
X["PREVIOUS_CESAREAN"].value_counts()

1.0    109126
Name: PREVIOUS_CESAREAN, dtype: int64

In [37]:
# Check that INTERVAL_SINCE_LAST_LIVE_BIRTH_RECODE does not have many 888s (first time births) 
print((X["INTERVAL_SINCE_LAST_LIVE_BIRTH_RECODE"] == 888.0).value_counts())

# If not too many, just make these 999 (unknown value) - probably a mistake inputting data, 
# as we saw above that 100% of samples had a previous cesarean. 
X["INTERVAL_SINCE_LAST_LIVE_BIRTH_RECODE"] = X["INTERVAL_SINCE_LAST_LIVE_BIRTH_RECODE"].apply(lambda x : 999.0 if x == 888.0 else x)

False    108908
True        218
Name: INTERVAL_SINCE_LAST_LIVE_BIRTH_RECODE, dtype: int64


In [38]:
# Make sure values are equivalent (data consistency/reliability reasons)
# For delivery_method_recode, 2 = VBAC and 4 = CBAC
# For delivery_method_recode_2, 1 = Vaginal and 2 = Cesarean

print(X["DELIVERY_METHOD_1"].value_counts())
print(X["DELIVERY_METHOD_2"].value_counts())

2.0    80289
4.0    28837
Name: DELIVERY_METHOD_1, dtype: int64
1.0    80289
2.0    28837
Name: DELIVERY_METHOD_2, dtype: int64


## Reform data to make more processable by models 

In [39]:
# Change facility_recode column to "in_hospital" for interpretability
X = X.rename(columns={"FACILITY_RECODE": "IN_HOSPITAL"})
X['IN_HOSPITAL'] = X['IN_HOSPITAL'].apply(lambda x: x if x==1 else 0)

In [40]:
# Add indicator columns to indicate if a feature is missing
# Set missing feature values to median 

for col, val in list_of_cols_with_missing_vals_and_their_default_numb:
    X[col+'_MISSING'] = X[col].apply(lambda x: 1 if x==val or x==-1 else 0)
    the_median = X[X[col]!=val][col].median()
    X[col] = X[col].apply(lambda x: the_median if x==val or x==-1 else x)

In [41]:
# Change marital status to 1 or 0 
X['MARITAL_STATUS'] = X['MARITAL_STATUS'].apply(lambda x: x if x == 1 else 0)

## Standardize dataset & delete unused columns

In [42]:
from sklearn import preprocessing
scaler = preprocessing.StandardScaler().fit(X)
X_scaled = pd.DataFrame(scaler.transform(X), columns=X.columns.values)

In [43]:
# Find empty columns
X_scaled.describe()

Unnamed: 0,IN_HOSPITAL,MOTHERS_AGE_RECODE,MARITAL_STATUS,MOTHERS_EDUCATION,PRIOR_BIRTHS_NOW_LIVING,PRIOR_BIRTHS_NOW_DEAD,PRIOR_OTHER_TERMINATIONS,LIVE_BIRTH_ORDER_RECODE,TOTAL_BIRTH_ORDER_RECODE,INTERVAL_SINCE_LAST_LIVE_BIRTH_RECODE,TRIMESTER_PRENATAL_CARE_BEGAN_RECODE,NUMBER_OF_PRENATAL_VISITS_RECODE,CIGARETTES_BEFORE_PREGNANCY_RECODE,CIGARETTES_FIRST_TRIMESTER_RECODE,CIGARETTES_SECOND_TRIMESTER_RECODE,CIGARETTES_THIRD_TRIMESTER_RECODE,MOTHERS_HEIGHT_IN_TOTAL_INCHES,MOTHERS_BMI_RECODE,PRE_PREGNANCY_WEIGHT_RECODE,DELIVERY_WEIGHT_RECODE,WEIGHT_GAIN,PRE_PREGNANCY_DIABETES,GESTATIONAL_DIABETES,PRE_PREGNANCY_HYPERTENSION,GESTATIONAL_HYPERTENSION,HYPERTENSION_ECLAMPSIA,PREVIOUS_PRETERM_BIRTH,PREVIOUS_CESAREAN,NUMBER_OF_PREVIOUS_CESAREANS,NO_INFECTIONS_REPORTED,INDUCTION_OF_LABOR,AUGMENTATION_OF_LABOR,ATTENDANT_AT_BIRTH,PLURALITY_RECODE,SEX_OF_INFANT,COMBINED_GESTATION_RECODE,BIRTH_WEIGHT_RECODE,TOL_ATTEMPTED,DELIVERY_METHOD_1,DELIVERY_METHOD_2,IN_HOSPITAL_MISSING,MARITAL_STATUS_MISSING,MOTHERS_EDUCATION_MISSING,PRIOR_BIRTHS_NOW_LIVING_MISSING,PRIOR_BIRTHS_NOW_DEAD_MISSING,PRIOR_OTHER_TERMINATIONS_MISSING,LIVE_BIRTH_ORDER_RECODE_MISSING,TOTAL_BIRTH_ORDER_RECODE_MISSING,INTERVAL_SINCE_LAST_LIVE_BIRTH_RECODE_MISSING,TRIMESTER_PRENATAL_CARE_BEGAN_RECODE_MISSING,NUMBER_OF_PRENATAL_VISITS_RECODE_MISSING,CIGARETTES_BEFORE_PREGNANCY_RECODE_MISSING,CIGARETTES_FIRST_TRIMESTER_RECODE_MISSING,CIGARETTES_SECOND_TRIMESTER_RECODE_MISSING,CIGARETTES_THIRD_TRIMESTER_RECODE_MISSING,MOTHERS_HEIGHT_IN_TOTAL_INCHES_MISSING,MOTHERS_BMI_RECODE_MISSING,PRE_PREGNANCY_WEIGHT_RECODE_MISSING,DELIVERY_WEIGHT_RECODE_MISSING,WEIGHT_GAIN_MISSING,NUMBER_OF_PREVIOUS_CESAREANS_MISSING,NO_INFECTIONS_REPORTED_MISSING,INDUCTION_OF_LABOR_MISSING,AUGMENTATION_OF_LABOR_MISSING,ATTENDANT_AT_BIRTH_MISSING,COMBINED_GESTATION_RECODE_MISSING,BIRTH_WEIGHT_RECODE_MISSING
count,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0,109126.0
mean,2.41315e-14,2.264939e-14,1.804939e-15,-4.590282e-15,9.200503e-15,1.98227e-15,-1.503709e-14,6.279719e-15,-7.58027e-15,6.425383e-16,1.344328e-14,-1.475439e-15,5.750471e-14,6.768001e-15,-2.628479e-14,-3.42004e-14,3.785009e-15,-3.246086e-15,-1.282122e-15,6.18143e-15,4.557235e-15,-6.622707e-15,6.250855e-15,7.248151000000001e-17,-4.49817e-15,-4.012963e-15,-3.852565e-14,0.0,-1.439385e-14,8.364722e-15,-3.198312e-14,8.730193e-15,-1.421809e-14,-2.394209e-14,4.565724e-16,-5.737308000000001e-17,3.824714e-15,-2.56136e-14,-2.56136e-14,-2.56136e-14,0.0,-4.039649e-13,-3.440561e-14,1.020674e-15,-1.057158e-14,9.511585e-16,2.91955e-15,-8.216552e-15,9.238789e-14,-1.499143e-14,1.70679e-14,-1.145405e-14,1.247236e-14,-1.181669e-14,-8.318734e-15,-5.024552e-15,-6.00456e-16,1.956915e-14,-1.818337e-14,-1.058516e-14,6.184666e-15,-1.252049e-14,5.587774e-15,5.587774e-15,4.123841e-15,1.1446e-15,3.030865e-15
std,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,0.0,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,0.0,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005,1.000005
min,-6.999215,-3.535614,-1.48828,-1.845953,-1.414333,-0.1278545,-0.5570391,-1.504695,-1.547539,-1.307769,-0.5065923,-2.766871,-0.2491192,-0.2217337,-0.2077425,-0.1997061,-11.75236,-1.737009,-2.107308,-2.199932,-1.921657,-0.1060656,-0.285453,-0.1550235,-0.2502652,-0.04555544,-0.3472214,0.0,-0.3270972,-6.118675,-0.5551422,-0.5836499,-0.4515723,-0.1308031,-0.983567,-3.999247,-4.885927,-0.5993038,-0.5993038,-0.5993038,0.0,-0.3247713,-0.1252319,-0.02840877,-0.04274239,-0.04504766,-0.04494542,-0.05639798,-0.2351934,-0.1651694,-0.1839046,-0.07478964,-0.07540666,-0.07503704,-0.1239043,-0.07503704,-0.1675257,-0.1574547,-0.1253074,-0.1936412,-0.04830129,-0.0501715,-0.01420009,-0.01420009,-0.02791994,-0.02758927,-0.03807843
25%,0.1428732,-0.7388512,-1.48828,-0.7259067,-0.7118656,-0.1278545,-0.5570391,-0.7602978,-0.9425716,-0.6617724,-0.5065923,-0.3120103,-0.2491192,-0.2217337,-0.2077425,-0.1997061,-0.6107354,-0.9023752,-0.7501833,-0.6966793,-0.5966313,-0.1060656,-0.285453,-0.1550235,-0.2502652,-0.04555544,-0.3472214,0.0,-0.3270972,0.1634341,-0.5551422,-0.5836499,-0.4515723,-0.1308031,-0.983567,-0.6142953,-0.8491491,-0.5993038,-0.5993038,-0.5993038,0.0,-0.3247713,-0.1252319,-0.02840877,-0.04274239,-0.04504766,-0.04494542,-0.05639798,-0.2351934,-0.1651694,-0.1839046,-0.07478964,-0.07540666,-0.07503704,-0.1239043,-0.07503704,-0.1675257,-0.1574547,-0.1253074,-0.1936412,-0.04830129,-0.0501715,-0.01420009,-0.01420009,-0.02791994,-0.02758927,-0.03807843
50%,0.1428732,0.1934031,0.6719165,-0.1658834,-0.00939833,-0.1278545,-0.5570391,-0.01590079,-0.3376036,-0.3094105,-0.5065923,0.178962,-0.2491192,-0.2217337,-0.2077425,-0.1997061,0.08561616,-0.06774143,-0.2320086,-0.1705409,-0.03872582,-0.1060656,-0.285453,-0.1550235,-0.2502652,-0.04555544,-0.3472214,0.0,-0.3270972,0.1634341,-0.5551422,-0.5836499,-0.4515723,-0.1308031,-0.983567,0.06269509,-0.04179345,-0.5993038,-0.5993038,-0.5993038,0.0,-0.3247713,-0.1252319,-0.02840877,-0.04274239,-0.04504766,-0.04494542,-0.05639798,-0.2351934,-0.1651694,-0.1839046,-0.07478964,-0.07540666,-0.07503704,-0.1239043,-0.07503704,-0.1675257,-0.1574547,-0.1253074,-0.1936412,-0.04830129,-0.0501715,-0.01420009,-0.01420009,-0.02791994,-0.02758927,-0.03807843
75%,0.1428732,1.125657,0.6719165,0.9541632,0.6930689,-0.1278545,0.3854312,0.7284963,0.2673643,0.3072229,-0.5065923,0.6699342,-0.2491192,-0.2217337,-0.2077425,-0.1997061,0.7819678,0.7668923,0.4835661,0.5309771,0.5889178,-0.1060656,-0.285453,-0.1550235,-0.2502652,-0.04555544,-0.3472214,0.0,-0.3270972,0.1634341,-0.5551422,1.713356,-0.4515723,-0.1308031,1.016708,0.7396855,0.7655622,1.668603,1.668603,1.668603,0.0,-0.3247713,-0.1252319,-0.02840877,-0.04274239,-0.04504766,-0.04494542,-0.05639798,-0.2351934,-0.1651694,-0.1839046,-0.07478964,-0.07540666,-0.07503704,-0.1239043,-0.07503704,-0.1675257,-0.1574547,-0.1253074,-0.1936412,-0.04830129,-0.0501715,-0.01420009,-0.01420009,-0.02791994,-0.02758927,-0.03807843
max,0.1428732,3.92242,0.6719165,2.07421,11.23008,48.03873,21.11978,3.706084,2.687236,7.413188,3.840647,2.142851,8.119734,9.905886,11.25106,12.04775,4.960077,2.43616,5.295189,5.316332,4.912685,9.428129,3.503204,6.450636,3.995761,21.95127,2.880007,0.0,17.93462,0.1634341,1.80134,1.713356,4.680513,15.12737,1.016708,2.093666,3.187629,1.668603,1.668603,1.668603,0.0,3.079089,7.985187,35.2004,23.39598,22.19871,22.24921,17.73113,4.25182,6.054389,5.437601,13.37084,13.26143,13.32675,8.070748,13.32675,5.969234,6.351032,7.980378,5.164191,20.70338,19.93163,70.4221,70.4221,35.81669,36.24598,26.26159


In [44]:
# We don't want these anymore as they indicate the result 
# Used for data sanity checks

X_scaled = X_scaled.drop(columns=['PREVIOUS_CESAREAN', 'TOL_ATTEMPTED', 
           'DELIVERY_METHOD_1', 'DELIVERY_METHOD_2'])

## Extract train/test datasets

In [61]:
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2)

## Balance the training dataset
We want to keep the test dataset unbalanced to accurately represent how real-world samples look.

In [62]:
import climin
from sklearn.metrics import accuracy_score
from imblearn.under_sampling import AllKNN
from imblearn.over_sampling import ADASYN
from sklearn.metrics import roc_auc_score, confusion_matrix

In [None]:
undersample, oversample = False, True

if undersample:
  rs = AllKNN()
  X_train, y_train = rus.fit_resample(X_train, y_train)
elif oversample:
  rs = ADASYN()
  X_train, y_train = rus.fit_resample(X_train, y_train)


In [None]:
# Check that dataset is properly balanced 
y_train["successful_vbac"].value_counts()

# Logistic Regression Model

## Train model (10-fold cross validation)

In [None]:
model = LogisticRegression(max_iter=1000)
scores = []
kfold = KFold(n_splits=10)

for i, (train, test) in enumerate(kfold.split(X_train, y_train)):
  model.fit(X_train.iloc[train,:], y_train.iloc[train,:].values.ravel())
  score = model.score(X_train.iloc[test,:], y_train.iloc[test,:].values.ravel())
  scores.append(score)

## Print results

In [50]:
print(scores)

[0.7631208526040619, 0.7699577719686306, 0.7703599436959582, 0.7584958777397949, 0.8009249949728534, 0.7808164086064749, 0.7775990347878544, 0.7759903478785442, 0.7609089081037603, 0.7609089081037603]


In [51]:
model.intercept_

array([-0.06855292])

In [52]:
y_test.mean()

successful_vbac    0.736461
dtype: float64

In [53]:
import statistics

print("Mean Train Accuracy: ", statistics.mean(scores))
print("Train Accuracy Stddev : ", statistics.stdev(scores))
# cm = confusion_matrix(y_test, model.predict(X_test)
# print("Confusion Matrix: ", cm)
print("Test Accuracy: ", model.score(X_test, y_test.values.ravel()))
print("Test AUC: ", roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]))

Mean Train Accuracy:  0.7719083048461693
Train Accuracy Stddev :  0.012804765072436174
Test Accuracy:  0.5877393933840374
Test AUC:  0.7409791886551127


In [54]:
sorted_indexes = np.argsort(np.abs(model.coef_))[0][::-1]
print("Features with highest +/- coefficient:\n")
for i, feature in enumerate(sorted_indexes[:10]):
  print(f'{i+1})      {X_train.columns[feature]}: {round(model.coef_[0][feature] , 3)}')
  print("")

Features with highest +/- coefficient:

1)      LIVE_BIRTH_ORDER_RECODE: 1.421

2)      ATTENDANT_AT_BIRTH: 1.091

3)      NUMBER_OF_PREVIOUS_CESAREANS: -0.968

4)      IN_HOSPITAL: -0.714

5)      TOTAL_BIRTH_ORDER_RECODE: 0.682

6)      PRIOR_BIRTHS_NOW_LIVING: -0.675

7)      MOTHERS_HEIGHT_IN_TOTAL_INCHES: 0.443

8)      PRIOR_OTHER_TERMINATIONS: -0.43

9)      CIGARETTES_THIRD_TRIMESTER_RECODE_MISSING: 0.399

10)      CIGARETTES_FIRST_TRIMESTER_RECODE_MISSING: -0.369



# Gradient Boosted Decision Tree

In [55]:
from sklearn.ensemble import GradientBoostingClassifier

gbc = GradientBoostingClassifier()

## Hyperparameter sweep 

Skip this if not changing the dataset.

In [56]:
n_estimators = [int(x) for x in np.linspace(50, 500, 5)]
max_depth = [int(x) for x in np.linspace(2, 25, 5)]
# Add the default as a possible value
max_depth.append(3)
max_depth.append(4)

max_features = ['auto', 'log2', .5, .75]
subsample = [.8,1.]
criterion = ['friedman_mse']
min_samples_split = [int(x) for x in np.linspace(2, 200, 6)]
min_impurity_decrease = [0.02, 0.05, 0.1]

# creating hyper param grid to search over
hyper_param_grid = {
    'n_estimators': n_estimators,
    'max_depth': max_depth,
    'max_features': max_features,
    'subsample': subsample,
    'criterion': criterion,
    'min_samples_split': min_samples_split,
    'min_impurity_decrease': min_impurity_decrease
  }


In [110]:
from sklearn.model_selection import RandomizedSearchCV

# Set to True if you have changed the dataset / would like to re-tune the GBDT
tune = True

if tune: 
  gbc_CV_tuner = RandomizedSearchCV(estimator = gbc, param_distributions = hyper_param_grid, scoring='f1',
                                n_iter = 50, cv = 6, verbose = 50,#, random_state = 100, 
                                n_jobs = 1, refit=False)

  gbc_CV_tuner.fit(X_train, np.ravel(y_train))

Fitting 6 folds for each of 50 candidates, totalling 300 fits
[CV 1/6; 1/50] START criterion=friedman_mse, max_depth=13, max_features=0.5, min_impurity_decrease=0.05, min_samples_split=81, n_estimators=162, subsample=0.8
[CV 1/6; 1/50] END criterion=friedman_mse, max_depth=13, max_features=0.5, min_impurity_decrease=0.05, min_samples_split=81, n_estimators=162, subsample=0.8;, score=0.841 total time=  39.3s
[CV 2/6; 1/50] START criterion=friedman_mse, max_depth=13, max_features=0.5, min_impurity_decrease=0.05, min_samples_split=81, n_estimators=162, subsample=0.8
[CV 2/6; 1/50] END criterion=friedman_mse, max_depth=13, max_features=0.5, min_impurity_decrease=0.05, min_samples_split=81, n_estimators=162, subsample=0.8;, score=0.842 total time=  39.2s
[CV 3/6; 1/50] START criterion=friedman_mse, max_depth=13, max_features=0.5, min_impurity_decrease=0.05, min_samples_split=81, n_estimators=162, subsample=0.8
[CV 3/6; 1/50] END criterion=friedman_mse, max_depth=13, max_features=0.5, min_im

In [111]:
# Then run this to find the winner:
if tune: 
  print(gbc_CV_tuner.best_params_)

{'subsample': 1.0, 'n_estimators': 162, 'min_samples_split': 120, 'min_impurity_decrease': 0.1, 'max_features': 0.5, 'max_depth': 3, 'criterion': 'friedman_mse'}


Input these values into the ```optimized_gbc``` instantiation below.

## Fit model 

In [57]:
# Model

# Optimized for imbalanced dataset:
# optimized_gbc = GradientBoostingClassifier(n_estimators=162,min_samples_split=120, min_impurity_decrease=0.1, max_features=0.5, max_depth=3, criterion='friedman_mse', subsample=1.0)
# GBC Accuracy: 0.756
# GBC F1: 0.85
# AUC: 0.59


# Optimized for balanced training data and imbalanced test data: 
optimized_gbc = GradientBoostingClassifier(n_estimators=275,min_samples_split=120, min_impurity_decrease=0.05, max_features=0.75, max_depth=25, criterion='friedman_mse', subsample=0.8)

optimized_gbc.fit(X_train, np.ravel(y_train))


GradientBoostingClassifier(max_depth=25, max_features=0.75,
                           min_impurity_decrease=0.05, min_samples_split=120,
                           n_estimators=275, subsample=0.8)

## Evaluate model

In [58]:
gbc_predictions = optimized_gbc.predict(X_test)
gbc_probs = optimized_gbc.predict_proba(X_test)  # Gives probability of output (prediction before rounding)

## Print results

In [59]:
from sklearn.metrics import accuracy_score, f1_score
print('GBC Accuracy: ' + str(round(accuracy_score(y_test, gbc_predictions), 3)))
print('GBC F1: ' + str(round(f1_score(y_test, gbc_predictions), 2)))
print('AUC: ' + str(round(roc_auc_score(y_test, gbc_predictions), 2)))

GBC Accuracy: 0.631
GBC F1: 0.69
AUC: 0.69


In [None]:
sorted_indexes = np.argsort(optimized_gbc.feature_importances_)[::-1]
print("Features with highest GBDT feature importances:\n")
for i, feature in enumerate(sorted_indexes[:10]):
  print(f'{i+1})      {X_train.columns[feature]}: {round(optimized_gbc.feature_importances_[feature], 3)}')
  print("")

# Save model

In [47]:
import pickle
filename = 'finalized_GBDT.sav'
pickle.dump(optimized_gbc, open(filename, 'wb'))

# Input user & get prediction

In [40]:
# User
user = [1, 6, 1, 7, 3, 0, 0, 3, 3, 4, 1, 5, 0, 0, 0, 0, 69, 2, 135, 181, 27, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 1, 7, 12, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]

gbc_success_prob = optimized_gbc.predict_proba(np.array(user).reshape(1, -1))[0][1]  # Gives probability of output (prediction before rounding)
print(gbc_probs)

[0.]
[[0.98818512 0.01181488]]
