In [1]:
# this workbook demonstrates how to use scikit-learn with PySpark to do single-variable binary
# classification (a depth-1 decision tree on 2007 life expectancy)
# 
# it's here as an example of how to switch from SparkML to scikit-learn if you'd like to do your 
# data pipeline and formatting in Spark, but switch to pandas and scikit-learn for
# building, fitting, and interpreting your model
# 
# this can be a good approach when the heavy lifting that benefits from the cluster is mainly
# limited to data preparation, but your ML model itself is relatively small and fits neatly into
# memory and/or runs on a single node, and you'd prefer to use pandas and scikit-learn. 
#

In [2]:
import numpy as np
from sklearn import tree

from pyspark import SparkConf, SparkContext
from pyspark.sql import SQLContext

from pyspark.sql.session import SparkSession

In [3]:
# Create a local Spark context, or reuse the active one when this cell is re-run.
# A second SparkContext('local') call in the same kernel raises
# "Cannot run multiple SparkContexts at once"; getOrCreate avoids that.
sc = SparkContext.getOrCreate(SparkConf().setMaster('local'))
spark = SparkSession(sc)

# SQLContext is kept because later cells read through `sqlContext`;
# on Spark 3.x it is deprecated in favour of the SparkSession (`spark`).
sqlContext = SQLContext(sc)

In [4]:
# Read the Gapminder CSV: first row is the header, comma-separated,
# and Spark infers each column's type from the data.
data = sqlContext.read.csv(
    "gapminder_all_binary.csv",
    header=True,
    inferSchema=True,
    sep=",",
)

In [5]:
# Uncomment to inspect the column types Spark inferred from the CSV:
#data.printSchema()

In [6]:
# Collect the Spark DataFrame to the driver as a pandas DataFrame; every row is
# pulled into local memory, so this assumes the prepared data is small.
# The guard keeps the cell re-runnable: after the first run `data` is already a
# pandas DataFrame, which has no .toPandas() method.
if hasattr(data, "toPandas"):
    data = data.toPandas()

In [7]:
# Preview only the first rows; the frame is wide (one gdp/lifeExp/pop column per year).
data.head(10)

Unnamed: 0,continent,country,gdpPercap_1952,gdpPercap_1957,gdpPercap_1962,gdpPercap_1967,gdpPercap_1972,gdpPercap_1977,gdpPercap_1982,gdpPercap_1987,...,pop_1967,pop_1972,pop_1977,pop_1982,pop_1987,pop_1992,pop_1997,pop_2002,pop_2007,Over_65
0,Africa,Algeria,2449.008185,3013.976023,2550.816880,3246.991771,4182.663766,4910.416756,5745.160213,5681.358539,...,12760499,14760787,17152804,20033753,23254956,26298373,29072015,31287142,33333216,1
1,Africa,Angola,3520.610273,3827.940465,4269.276742,5522.776375,5473.288005,3008.647355,2756.953672,2430.208311,...,5247469,5894858,6162675,7016384,7874230,8735988,9875024,10866106,12420476,0
2,Africa,Benin,1062.752200,959.601080,949.499064,1035.831411,1085.796879,1029.161251,1277.897616,1225.856010,...,2427334,2761407,3168267,3641603,4243788,4981671,6066080,7026113,8078314,0
3,Africa,Botswana,851.241141,918.232535,983.653976,1214.709294,2263.611114,3214.857818,4551.142150,6205.883850,...,553541,619351,781472,970347,1151184,1342614,1536536,1630347,1639131,0
4,Africa,Burkina Faso,543.255241,617.183465,722.512021,794.826560,854.735976,743.387037,807.198586,912.063142,...,5127935,5433886,5889574,6634596,7586551,8878303,10352843,12251209,14326203,0
5,Africa,Burundi,339.296459,379.564628,355.203227,412.977514,464.099504,556.103265,559.603231,621.818819,...,3330989,3529983,3834415,4580410,5126023,5809236,6121610,7021078,8390505,0
6,Africa,Cameroon,1172.667655,1313.048099,1399.607441,1508.453148,1684.146528,1783.432873,2367.983282,2602.664206,...,6335506,7021028,7959865,9250831,10780667,12467171,14195809,15929988,17696293,0
7,Africa,Central African Republic,1071.310713,1190.844328,1193.068753,1136.056615,1070.013275,1109.374338,956.752991,844.876350,...,1733638,1927260,2167533,2476971,2840009,3265124,3696513,4048013,4369038,0
8,Africa,Chad,1178.665927,1308.495577,1389.817618,1196.810565,1104.103987,1133.984950,797.908101,952.386129,...,3495967,3899068,4388260,4875118,5498955,6429417,7562011,8835739,10238807,0
9,Africa,Comoros,1102.990936,1211.148548,1406.648278,1876.029643,1937.577675,1172.603047,1267.100083,1315.980812,...,217378,250027,304739,348643,395114,454429,527982,614382,710960,1


In [8]:
# Labels: the binary Over_65 column
y = data["Over_65"]

# Single feature: 2007 life expectancy. Kept as a one-column DataFrame so that
# scikit-learn receives a 2-D X.
feature_cols = ['lifeExp_2007']
X = data.loc[:, feature_cols]

# Learn the decision boundary with a depth-1 tree (a single split).
# random_state fixes the estimator's internal randomness so re-runs are reproducible.
clf = tree.DecisionTreeClassifier(max_depth=1, random_state=0)
clf = clf.fit(X, y)

# If y holds only one class the tree is a single leaf with no split, and
# tree_.threshold[0] would be a placeholder value rather than a learned boundary.
if clf.tree_.node_count < 2:
    raise ValueError("Over_65 has only one class; the tree made no split, so there is no threshold to report.")

# Root-node split: rows with lifeExp_2007 <= threshold go to the left child.
threshold = clf.tree_.threshold[0]

print("Learned 2007 life expectancy threshold for life expectency over 65: " + str(threshold))

Learned 2007 life expectancy threshold for life expectency over 65: 64.92499923706055


In [9]:
# Predictions: per-row class probabilities from the fitted decision tree

In [10]:
# One row per sample, one probability column per entry of clf.classes_.
predictions = clf.predict_proba(X=X)

In [11]:
import pandas as pd

In [12]:
# Pair each row's 2007 life expectancy with the predicted probability that Over_65 == 1.
# predict_proba orders its columns by clf.classes_, so look the positive class up
# instead of assuming it sits in column 1.
positive_col = list(clf.classes_).index(1)
prediction_table = pd.DataFrame({
    "lifeExp_2007": data['lifeExp_2007'],
    "Over_65_Pred": predictions[:, positive_col],
})
prediction_table.head(10)

Unnamed: 0,Over_65_Pred,lifeExp_2007
0,0.988889,72.301
1,0.019231,42.731
2,0.019231,56.728
3,0.019231,50.728
4,0.019231,52.295
5,0.019231,49.580
6,0.019231,50.430
7,0.019231,44.741
8,0.019231,50.651
9,0.988889,65.152
