In [0]:
from pyspark.sql import SparkSession
# Reuse the active Spark session if there is one; otherwise start a new
# session registered under the application name 'tree_ca'.
spark = (
    SparkSession.builder
    .appName('tree_ca')
    .getOrCreate()
)

In [0]:
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier, DecisionTreeClassifier
from pyspark.ml.feature import VectorAssembler, VectorIndexer, OneHotEncoder, StringIndexer
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.sql.functions import year, month, dayofmonth

In [0]:
# Read the registered `college_csv` table through the SparkSession created at
# the top of the notebook. The previous `sqlContext` name is never defined in
# this notebook (presumably it was injected by the hosting environment), so
# the cell failed on a fresh kernel anywhere that global is not provided.
df = spark.sql("SELECT * FROM college_csv")

# Print per-column summary statistics as a quick sanity check on the load.
df.describe().show()

In [0]:
# Display the column names of the loaded table.
df.columns

In [0]:
# Assembler that packs the listed input columns, in this order, into a single
# vector column named 'features' (the default input column of the classifiers
# fitted further down).
ass = VectorAssembler(
    inputCols=[
        'Apps', 'Accept', 'Enroll',
        'Top10perc', 'Top25perc',
        'F_Undergrad', 'P_Undergrad',
        'Outstate', 'Room_Board', 'Books', 'Personal',
        'PhD', 'Terminal',
        'S_F_Ratio', 'perc_alumni',
        'Expend', 'Grad_Rate',
    ],
    outputCol='features',
)

In [0]:
# Apply the assembler: `op` is `df` plus the assembled 'features' vector column.
op = ass.transform(df)

In [0]:
# Indexer that maps the string column 'Private' to a numeric label column
# 'PrivateIndex', which the classifiers use as their label.
indxr = StringIndexer(
    inputCol='Private',
    outputCol='PrivateIndex',
)

In [0]:
# Learn the string -> index mapping from `op`, then apply it to the same
# DataFrame to append the 'PrivateIndex' column.
opFixed = (
    indxr
    .fit(op)
    .transform(op)
)

In [0]:
# Print the schema to confirm 'features' and 'PrivateIndex' were added.
opFixed.printSchema()

In [0]:
# Keep only the two columns the models need: the feature vector and the label.
model_columns = ['features', 'PrivateIndex']
fdf = opFixed.select(*model_columns)

# Preview the modelling DataFrame.
fdf.show()

In [0]:
# 70/30 train/test split. A fixed seed is supplied so the split, and every
# metric computed from it below, is the same on each Restart & Run All rather
# than changing with every execution.
train, test = fdf.randomSplit([.7, .3], seed=42)

In [0]:
# Three tree-based estimators, all trained against the 'PrivateIndex' label.
# Only the random forest overrides a hyperparameter (100 trees); everything
# else is left at the library defaults.
dtc = DecisionTreeClassifier(
    labelCol='PrivateIndex',
)
rfc = RandomForestClassifier(
    labelCol='PrivateIndex',
    numTrees=100,
)
gbt = GBTClassifier(
    labelCol='PrivateIndex',
)

In [0]:
# Fit each estimator on the training split, in the order tree -> forest -> GBT.
dtcModel, rfcModel, gbtModel = (
    estimator.fit(train) for estimator in (dtc, rfc, gbt)
)


In [0]:
# Score the held-out test split with each fitted model.
dtcPreds, rfcPreds, gbtPreds = (
    model.transform(test) for model in (dtcModel, rfcModel, gbtModel)
)


In [0]:
# Binary evaluator against the 'PrivateIndex' label. The metric and the
# raw-prediction column are left at the evaluator's defaults.
binEval = BinaryClassificationEvaluator(
    labelCol='PrivateIndex',
)

In [0]:
# Decision tree: model name, then its binary-evaluator score on the next line.
print('DTC', binEval.evaluate(dtcPreds), sep='\n')

In [0]:
# Random forest: model name, then its binary-evaluator score on the next line.
print('RFC', binEval.evaluate(rfcPreds), sep='\n')

In [0]:
# Evaluator for the gradient-boosted trees model.
#
# This used to score the hard 0/1 'prediction' column, which collapses the ROC
# curve to a single operating point and produces a number that is not
# comparable with the decision-tree / random-forest scores, which are computed
# from continuous raw scores. GBT models emit a 'rawPrediction' column on
# Spark 2.2+, so score that instead.
# NOTE(review): on Spark < 2.2 the GBT output has no 'rawPrediction' column;
# fall back to rawPredictionCol='prediction' there.
binEval2 = BinaryClassificationEvaluator(
    labelCol='PrivateIndex',
    rawPredictionCol='rawPrediction',
)

In [0]:
# Gradient-boosted trees: model name, then the score from `binEval2` on the next line.
print('GBT', binEval2.evaluate(gbtPreds), sep='\n')

In [0]:
# Accuracy evaluator: compares the models' predictions against the
# 'PrivateIndex' label.
acc = MulticlassClassificationEvaluator(
    labelCol='PrivateIndex',
    metricName='accuracy',
)

In [0]:
# Test-set accuracy for each model, in the order tree -> forest -> GBT.
dtcAcc, rfcAcc, gbtAcc = (
    acc.evaluate(preds) for preds in (dtcPreds, rfcPreds, gbtPreds)
)


In [0]:
# Decision tree accuracy on the test split.
dtcAcc

In [0]:
# Random forest accuracy on the test split.
rfcAcc

In [0]:
# Gradient-boosted trees accuracy on the test split.
gbtAcc