In [1]:
import pyspark  
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.classification import LogisticRegressionWithSGD
from pyspark.mllib.tree import DecisionTree

In [2]:
sc = pyspark.SparkContext()

In [3]:
raw_rdd = sc.textFile("datasets/titanic.csv")

In [4]:
raw_rdd.count()

1317

In [5]:
raw_rdd.take(5)

['"","class","age","sex","survived"',
 '"1","1st class","adults","man","yes"',
 '"2","1st class","adults","man","yes"',
 '"3","1st class","adults","man","yes"',
 '"4","1st class","adults","man","yes"']

In [6]:
header = raw_rdd.first()
data_rdd = raw_rdd.filter(lambda line: line != header)

In [7]:
data_rdd.takeSample(False, 5, 0)

['"159","1st class","adults","man","no"',
 '"256","1st class","adults","women","yes"',
 '"1204","3rd class","adults","women","no"',
 '"758","3rd class","adults","man","no"',
 '"730","3rd class","adults","man","no"']

In [10]:
def row_to_labeled_point(line):
    '''
    Builds a LabelPoint consisting of:
    
    survival (truth): 0=no, 1=yes
    ticket class: 0=1st class, 1=2nd class, 2=3rd class
    age group: 0=child, 1=adults
    gender: 0=man, 1=woman
    '''
    passenger_id, klass, age, sex, survived = [segs.strip('"') for segs in line.split(',')]
    klass = int(klass[0]) - 1
    
    if (age not in ['adults', 'child'] or 
        sex not in ['man', 'women'] or
        survived not in ['yes', 'no']):
        raise RuntimeError('unknown value')
    
    features = [
        klass,
        (1 if age == 'adults' else 0),
        (1 if sex == 'women' else 0)
    ]
    return LabeledPoint(1 if survived == 'yes' else 0, features)

In [11]:
labeled_points_rdd = data_rdd.map(row_to_labeled_point)

In [12]:
labeled_points_rdd.takeSample(False, 5, 0)

[LabeledPoint(0.0, [0.0,1.0,0.0]),
 LabeledPoint(1.0, [0.0,1.0,1.0]),
 LabeledPoint(0.0, [2.0,1.0,1.0]),
 LabeledPoint(0.0, [2.0,1.0,0.0]),
 LabeledPoint(0.0, [2.0,1.0,0.0])]

In [13]:
training_rdd, test_rdd = labeled_points_rdd.randomSplit([0.7, 0.3], seed = 0)

In [14]:
training_count = training_rdd.count()
test_count = test_rdd.count()

In [15]:
training_count, test_count

(914, 402)

In [16]:
model = DecisionTree.trainClassifier(training_rdd, 
                                     numClasses=2, 
                                     categoricalFeaturesInfo={
                                        0: 3,
                                        1: 2,
                                        2: 2
                                     })

In [17]:
predictions_rdd = model.predict(test_rdd.map(lambda x: x.features))

In [18]:
truth_and_predictions_rdd = test_rdd.map(lambda lp: lp.label).zip(predictions_rdd)

In [19]:
accuracy = truth_and_predictions_rdd.filter(lambda v_p: v_p[0] == v_p[1]).count() / float(test_count)
print('Accuracy =', accuracy)
print(model.toDebugString())

Accuracy = 0.7985074626865671
DecisionTreeModel classifier of depth 4 with 21 nodes
  If (feature 2 in {0.0})
   If (feature 1 in {0.0})
    If (feature 0 in {0.0,1.0})
     Predict: 1.0
    Else (feature 0 not in {0.0,1.0})
     Predict: 0.0
   Else (feature 1 not in {0.0})
    If (feature 0 in {1.0})
     Predict: 0.0
    Else (feature 0 not in {1.0})
     If (feature 0 in {0.0})
      Predict: 0.0
     Else (feature 0 not in {0.0})
      Predict: 0.0
  Else (feature 2 not in {0.0})
   If (feature 0 in {2.0})
    If (feature 1 in {0.0})
     Predict: 0.0
    Else (feature 1 not in {0.0})
     Predict: 0.0
   Else (feature 0 not in {2.0})
    If (feature 0 in {1.0})
     If (feature 1 in {0.0})
      Predict: 1.0
     Else (feature 1 not in {0.0})
      Predict: 1.0
    Else (feature 0 not in {1.0})
     If (feature 1 in {0.0})
      Predict: 1.0
     Else (feature 1 not in {0.0})
      Predict: 1.0



In [20]:
model = LogisticRegressionWithSGD.train(training_rdd)

In [21]:
predictions_rdd = model.predict(test_rdd.map(lambda x: x.features))

In [22]:
labels_and_predictions_rdd = test_rdd.map(lambda lp: lp.label).zip(predictions_rdd)

In [23]:
accuracy = labels_and_predictions_rdd.filter(lambda v_p: v_p[0] == v_p[1]).count() / float(test_count)
print('Accuracy =', accuracy)

Accuracy = 0.7860696517412935
