In [1]:
# Read every column of `train_csv` into the `train` DataFrame and print its
# column names and types.
train = sqlContext.sql(
    "select * from train_csv"
)
train.printSchema()

In [2]:
# Read every column of `test_csv` into the `test` DataFrame and print its
# column names and types.
test = sqlContext.sql(
    "select * from test_csv"
)
test.printSchema()

In [3]:
# Remove the 'Names' and 'Onboard_date' columns from the training frame, then
# print summary statistics for the columns that remain.
train = train.drop('Names').drop('Onboard_date')
train.describe().show()

In [4]:
# Drop the same columns from the test frame that are dropped from the training
# frame ('Names' and 'Onboard_date') so both go through identical
# preprocessing. DataFrame.drop is a no-op for column names that are not in
# the schema, so this is safe even if the test table has no 'Onboard_date'.
test = test.drop('Names', 'Onboard_date')
# Print summary statistics for the remaining columns.
test.describe().show()

In [5]:
from pyspark.ml.feature import VectorAssembler
# Combine the five predictor columns into a single vector column named
# 'features', which is what the classifiers below consume.
assembler = VectorAssembler(
    inputCols=[
        'Age',
        'Total_Purchase',
        'Account_Manager',
        'Years',
        'Num_Sites',
    ],
    outputCol='features',
)

In [6]:
# Add the assembled 'features' column to the training frame, then keep only
# that column and the 'churn' label and preview the result.
output = assembler.transform(train)
train_churn = output.select(['features', 'churn'])
train_churn.show()

In [7]:
# Add the assembled 'features' column to the test frame, then keep only that
# column and the 'churn' label and preview the result.
# NOTE: this rebinds `output`, which previously held the assembled training frame.
output = assembler.transform(test)
test_churn = output.select(['features', 'churn'])
test_churn.show()

In [8]:
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier, RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

In [9]:
# One estimator per candidate algorithm, each configured to predict the
# 'churn' label column.
# A fixed seed is passed to the tree-based estimators so that repeated runs
# (e.g. Restart Kernel -> Run All) build the same models and report the same
# accuracies. LogisticRegression takes no seed parameter.
SEED = 42
lrc = LogisticRegression(labelCol='churn')
dtc = DecisionTreeClassifier(labelCol='churn', seed=SEED)
rfc = RandomForestClassifier(labelCol='churn', numTrees=100, seed=SEED)
gbt = GBTClassifier(labelCol='churn', seed=SEED)

In [10]:
# Fit every estimator on the assembled training data. The generator runs the
# fits in the listed order and the results are unpacked into one fitted model
# per algorithm.
lrc_model, dtc_model, rfc_model, gbt_model = (
    estimator.fit(train_churn) for estimator in (lrc, dtc, rfc, gbt)
)

In [11]:
# Run each fitted model over the assembled test data, producing one
# predictions DataFrame per algorithm (same order as the models).
lrc_predictions, dtc_predictions, rfc_predictions, gbt_predictions = (
    fitted.transform(test_churn)
    for fitted in (lrc_model, dtc_model, rfc_model, gbt_model)
)

In [12]:
# A single evaluator that scores the 'prediction' column against the 'churn'
# label using the accuracy metric; it is reused for every model.
acc_evaluator = MulticlassClassificationEvaluator(
    labelCol="churn",
    predictionCol="prediction",
    metricName="accuracy",
)
# Evaluate the four prediction frames in order and unpack one accuracy each.
lrc_acc, dtc_acc, rfc_acc, gbt_acc = (
    acc_evaluator.evaluate(scored)
    for scored in (
        lrc_predictions,
        dtc_predictions,
        rfc_predictions,
        gbt_predictions,
    )
)

In [13]:
# Report each model's accuracy as a percentage with two decimals, one line per
# model, in the same order as the models were trained.
print('\n'.join(
    template.format(accuracy * 100)
    for template, accuracy in (
        ('accuracy of logistic reg: {0:2.2f}%', lrc_acc),
        ('accuracy of decision tree: {0:2.2f}%', dtc_acc),
        ('accuracy of random forest: {0:2.2f}%', rfc_acc),
        ('accuracy of GBT: {0:2.2f}%', gbt_acc),
    )
))