In [3]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler,StringIndexer,OneHotEncoder
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Train a decision-tree classifier on the iris dataset, measure its accuracy on
# a held-out split, and persist the fitted model when the accuracy is high enough.

spark = SparkSession.builder.getOrCreate()

# Read the CSV with a header row and let Spark infer the column types.
irisDF = (
    spark.read
    .option('header', 'true')
    .option('inferSchema', 'true')
    .csv('datasets/iris-dataset.txt')
)

# Encode the string-valued 'class' column as a numeric 'label' column.
# The fitted indexer model is kept in `indexerModel`.
indexer = StringIndexer(inputCol='class', outputCol='label')
indexerModel = indexer.fit(irisDF)
irisDF = indexerModel.transform(irisDF)

# Pack the first four columns into a single 'features' vector column.
# NOTE(review): presumably these are the four sepal/petal measurement columns;
# confirm against the dataset header.
vec = VectorAssembler(inputCols=irisDF.columns[0:4], outputCol='features')
irisDF = vec.transform(irisDF)
irisDF = irisDF.select('features', 'label')

# 75/25 train/test split with a fixed seed.
trainDF, testDF = irisDF.randomSplit([0.75, 0.25], seed=1234)

# DecisionTreeClassifier reads the 'features' and 'label' columns by default.
classifier = DecisionTreeClassifier()
model = classifier.fit(trainDF)  # training data is used to create the model
resultDF = model.transform(testDF)  # adds prediction columns for the test rows

# The evaluator compares 'label' with 'prediction' (its default columns).
eva = MulticlassClassificationEvaluator(metricName='accuracy')

result = eva.evaluate(resultDF)
print("Accuracy : ",result)
if result > 0.96:
    # A plain model.save() raises when the target path already exists, which
    # breaks every re-run of this cell after the first successful save.
    # Overwrite the previous copy explicitly instead.
    # If the threshold is not met nothing is written, so a 'dtModel' directory
    # from an earlier run stays in place.
    model.write().overwrite().save('dtModel')
    print("Model saved.")

Accuracy :  0.9772727272727273
Model saved.


In [4]:
from pyspark.ml.classification import DecisionTreeClassificationModel 
# Reload the decision tree that was persisted under 'dtModel'; this rebinds
# `model` to the loaded copy.
model = DecisionTreeClassificationModel.load('dtModel')

# Read the rows to score (header row, inferred column types) and assemble them
# into a 'features' vector with `vec`, the assembler defined in an earlier cell.
newDF = vec.transform(
    spark.read
    .option('header', 'true')
    .option('inferSchema', 'true')
    .csv('datasets/new-transactions.txt')
)

# Score every row and display the result.
# NOTE(review): the variable and file names mention transactions, but the
# displayed columns are iris measurements; the names are kept unchanged in
# case later cells refer to them.
newTransactionsFroudDF = model.transform(newDF)
newTransactionsFroudDF.show()

+------------+-----------+------------+-----------+-----------------+--------------+-------------+----------+
|sepal-length|sepal-width|petal-length|petal-width|         features| rawPrediction|  probability|prediction|
+------------+-----------+------------+-----------+-----------------+--------------+-------------+----------+
|         5.1|        3.5|         1.4|        0.2|[5.1,3.5,1.4,0.2]|[34.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|         6.3|        2.5|         4.9|        1.5|[6.3,2.5,4.9,1.5]|[0.0,32.0,0.0]|[0.0,1.0,0.0]|       1.0|
|         6.1|        2.8|         4.7|        1.2|[6.1,2.8,4.7,1.2]|[0.0,32.0,0.0]|[0.0,1.0,0.0]|       1.0|
|         6.9|        3.1|         5.1|        2.3|[6.9,3.1,5.1,2.3]|[0.0,0.0,33.0]|[0.0,0.0,1.0]|       2.0|
|         4.9|        3.0|         1.4|        0.2|[4.9,3.0,1.4,0.2]|[34.0,0.0,0.0]|[1.0,0.0,0.0]|       0.0|
|         5.9|        3.2|         4.8|        1.8|[5.9,3.2,4.8,1.8]| [0.0,1.0,0.0]|[0.0,1.0,0.0]|       1.0|
|         