<a href="https://colab.research.google.com/github/gopinathmoorthy-DS/Spark-ML/blob/main/ML_with_spark.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# install java
!apt-get install openjdk-8-jdk-headless -qq > /dev/null

# install spark (change the version number if needed)
!wget -q https://archive.apache.org/dist/spark/spark-3.0.0/spark-3.0.0-bin-hadoop3.2.tgz

# unzip the spark file to the current folder
!tar xf spark-3.0.0-bin-hadoop3.2.tgz

# set your spark folder to your system path environment. 
import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.0.0-bin-hadoop3.2"


# install findspark using pip
!pip install -q findspark

## install pyspark
!pip install pyspark

Collecting pyspark
[?25l  Downloading https://files.pythonhosted.org/packages/f0/26/198fc8c0b98580f617cb03cb298c6056587b8f0447e20fa40c5b634ced77/pyspark-3.0.1.tar.gz (204.2MB)
[K     |████████████████████████████████| 204.2MB 61kB/s 
[?25hCollecting py4j==0.10.9
[?25l  Downloading https://files.pythonhosted.org/packages/9e/b6/6a4fb90cd235dc8e265a6a2067f2a2c99f0d91787f06aca4bcf7c23f3f80/py4j-0.10.9-py2.py3-none-any.whl (198kB)
[K     |████████████████████████████████| 204kB 44.6MB/s 
[?25hBuilding wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.0.1-py2.py3-none-any.whl size=204612243 sha256=e9130ca0aa206a6633dbd9c83d5019e9fa85a81fc7f2533f925d4f1b6ba925a6
  Stored in directory: /root/.cache/pip/wheels/5e/bd/07/031766ca628adec8435bb40f0bd83bb676ce65ff4007f8e73f
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9 pyspark-3.0.1


In [None]:
## Loading files from google drive
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
## Creating spark session
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
sc= spark.sparkContext

In [None]:
## Importing libraries
from pyspark.sql import functions as F
from pyspark.ml import Pipeline
from pyspark.ml.feature import OneHotEncoder,StringIndexer,VectorAssembler,Imputer,QuantileDiscretizer
from pyspark.ml.classification import LogisticRegression

In [None]:
## function to read data in different formats
def read_data(file_path,file_type,inferschema="true",first_row_header="true",delimiter=",",nanvalue=' ',nullvalue=' '):
  df=spark.read.format(file_type)\
     .option("inferSchema",inferschema)\
     .option("header",first_row_header)\
     .option("sep",delimiter)\
     .option("nanValue",nanvalue)\
     .option("nullValue",nullvalue)\
     .load(file_path) 
  return df

In [None]:
## Finding missing columns in percentage/count
def missing_count(df,percentage=False,cols=None):
  if percentage is False:
    if cols is None:
      df_agg = df.agg(*[F.count(F.when(F.isnull(c), c)).alias(c) for c in df.columns])
    else:
      df_agg = df.agg(*[F.count(F.when(F.isnull(c), c)).alias(c) for c in cols])
  else:
    df_size=df.count()
    if cols is None:
      df_agg = df.agg(*[F.round(((F.count(F.when(F.isnull(c), c))/df_size)*100),2).alias(c) for c in df.columns])
    else:
      df_agg = df.agg(*[F.round(((F.count(F.when(F.isnull(c), c))/df_size)*100),2).alias(c) for c in cols])
  return df_agg

In [None]:
## File location and type
file_location='/content/drive/My Drive/wa-fnusec-telcocustomerchurn/WA_Fn-UseC_-Telco-Customer-Churn.csv'
file_type='csv'

In [None]:
## Reading data
customer_df=read_data(file_location,file_type)

In [None]:
customer_df.show(5)

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|   MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+--------------+------------+-----+
|7590-VHVEG|Female|            0|    Yes|        No|     1|          No|No phone service|            DSL|            No|         Yes|              No|         No|    

In [None]:
customer_df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: double (nullable = true)
 |-- Churn: string (nullable = true)



In [None]:
missing_count(customer_df).show()

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|Contract|PaperlessBilling|PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|         0|     0|            0|      0|         0|     0|           0|            0|              0|             0|           0|               0|          0|          0|              0|       0|               0| 

In [None]:
## churn and not churn users
customer_df.groupBy('Churn').count().show()

+-----+-----+
|Churn|count|
+-----+-----+
|   No| 5174|
|  Yes| 1869|
+-----+-----+



In [None]:
## EDA numeric columns
customer_df.select('tenure','TotalCharges','MonthlyCharges').describe().show()

+-------+------------------+------------------+------------------+
|summary|            tenure|      TotalCharges|    MonthlyCharges|
+-------+------------------+------------------+------------------+
|  count|              7043|              7032|              7043|
|   mean| 32.37114865824223|2283.3004408418697| 64.76169246059922|
| stddev|24.559481023094442| 2266.771361883145|30.090047097678482|
|    min|                 0|              18.8|             18.25|
|    max|                72|            8684.8|            118.75|
+-------+------------------+------------------+------------------+



In [None]:
## Gender wise churn
customer_df.groupBy('gender','Churn').count().show()

+------+-----+-----+
|gender|Churn|count|
+------+-----+-----+
|  Male|   No| 2625|
|  Male|  Yes|  930|
|Female|   No| 2549|
|Female|  Yes|  939|
+------+-----+-----+



In [None]:
## Split data into train and test
(train_data,test_data)=customer_df.randomSplit([0.7,0.3],24)

In [None]:
## Catergorical columns
catColumns=['gender','SeniorCitizen','Partner','Dependents','PhoneService','MultipleLines','InternetService','OnlineSecurity',
 'OnlineBackup','DeviceProtection','TechSupport','StreamingTV','StreamingMovies','Contract',
 'PaperlessBilling','PaymentMethod']

In [None]:
## setting different stages for catergorical column indexing
stages=[]
for catCol in catColumns:
  stringIndexer=StringIndexer(inputCol=catCol,outputCol=catCol+"Index")
  encoder=OneHotEncoder(inputCols=[stringIndexer.getOutputCol()],outputCols=[catCol+"catVec"])
  stages+=[stringIndexer,encoder]

In [None]:
stages

[StringIndexer_5595ae2e2b67,
 OneHotEncoder_b7e3182ab1ae,
 StringIndexer_3478750a1900,
 OneHotEncoder_85550c08544f,
 StringIndexer_29faa38db638,
 OneHotEncoder_6e7814bb2163,
 StringIndexer_ef57268e408f,
 OneHotEncoder_dc750f1abba5,
 StringIndexer_f61a440c248b,
 OneHotEncoder_210e90fcf582,
 StringIndexer_3bfd3d787d50,
 OneHotEncoder_40c0c0e6ff48,
 StringIndexer_d2739a21a07a,
 OneHotEncoder_de5e7e06ce30,
 StringIndexer_211bc5331312,
 OneHotEncoder_5aa9e27e3c89,
 StringIndexer_6142645124e3,
 OneHotEncoder_4e24e152607b,
 StringIndexer_5a18ba84cd20,
 OneHotEncoder_d83fc8bb7f28,
 StringIndexer_28ca45568974,
 OneHotEncoder_25e0fac4b4dd,
 StringIndexer_2570da0153b9,
 OneHotEncoder_a828e58f02dd,
 StringIndexer_09454fb12bc9,
 OneHotEncoder_7eeaa733ea8a,
 StringIndexer_ccf28dd9b1ab,
 OneHotEncoder_d53ab22b5195,
 StringIndexer_eb346c83a282,
 OneHotEncoder_a79c9a8f85ff,
 StringIndexer_96711aafb79b,
 OneHotEncoder_ab8cb71cab9e]

In [None]:
## Imputing Missing in Total Charges
imputer=Imputer(inputCols=["TotalCharges"],outputCols=["Out_TotalCharges"])
stages+=[imputer]

In [None]:
label_Idx=StringIndexer(inputCol="Churn",outputCol="label")
stages+=[label_Idx]

In [None]:
temp=label_Idx.fit(train_data).transform(train_data)

In [None]:
temp.show()

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|   MultipleLines|InternetService|     OnlineSecurity|       OnlineBackup|   DeviceProtection|        TechSupport|        StreamingTV|    StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|label|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+-----+
|0002-ORFBO|Female|            0|    Yes|  

In [None]:
customer_df.stat.corr("MonthlyCharges","TotalCharges")

0.6511738315787816

In [None]:
customer_df.groupBy("tenure","Churn").count().show()

+------+-----+-----+
|tenure|Churn|count|
+------+-----+-----+
|    46|  Yes|   12|
|    27|   No|   59|
|    15|  Yes|   37|
|    60|   No|   70|
|    71|  Yes|    6|
|    33|   No|   50|
|     3|  Yes|   94|
|    14|  Yes|   24|
|    46|   No|   62|
|    47|   No|   54|
|     8|   No|   81|
|    60|  Yes|    6|
|    57|  Yes|    8|
|    52|  Yes|    8|
|    14|   No|   52|
|    41|   No|   56|
|    39|  Yes|   14|
|    19|  Yes|   19|
|    31|   No|   49|
|    66|  Yes|   13|
+------+-----+-----+
only showing top 20 rows



In [None]:
## Bucketing tenure of the customer
tenure_bin=QuantileDiscretizer(numBuckets=3,inputCol="tenure",outputCol="tenure_bin")
stages+=[tenure_bin]

In [None]:
stages

[StringIndexer_5595ae2e2b67,
 OneHotEncoder_b7e3182ab1ae,
 StringIndexer_3478750a1900,
 OneHotEncoder_85550c08544f,
 StringIndexer_29faa38db638,
 OneHotEncoder_6e7814bb2163,
 StringIndexer_ef57268e408f,
 OneHotEncoder_dc750f1abba5,
 StringIndexer_f61a440c248b,
 OneHotEncoder_210e90fcf582,
 StringIndexer_3bfd3d787d50,
 OneHotEncoder_40c0c0e6ff48,
 StringIndexer_d2739a21a07a,
 OneHotEncoder_de5e7e06ce30,
 StringIndexer_211bc5331312,
 OneHotEncoder_5aa9e27e3c89,
 StringIndexer_6142645124e3,
 OneHotEncoder_4e24e152607b,
 StringIndexer_5a18ba84cd20,
 OneHotEncoder_d83fc8bb7f28,
 StringIndexer_28ca45568974,
 OneHotEncoder_25e0fac4b4dd,
 StringIndexer_2570da0153b9,
 OneHotEncoder_a828e58f02dd,
 StringIndexer_09454fb12bc9,
 OneHotEncoder_7eeaa733ea8a,
 StringIndexer_ccf28dd9b1ab,
 OneHotEncoder_d53ab22b5195,
 StringIndexer_eb346c83a282,
 OneHotEncoder_a79c9a8f85ff,
 StringIndexer_96711aafb79b,
 OneHotEncoder_ab8cb71cab9e,
 Imputer_6b72daa84a4c,
 StringIndexer_67c85fb92e11,
 QuantileDiscretizer

In [None]:
## Adding all the columns into vector assembler
numericCols=["tenure_bin","Out_TotalCharges","MonthlyCharges"]
assembleInputs=[c+"catVec" for c in catColumns]+numericCols
assembler=VectorAssembler(inputCols=assembleInputs,outputCol="features")
stages+=[assembler]

In [None]:
## adding different stages into Pipeline
pipeline=Pipeline().setStages(stages)
pipelineModel=pipeline.fit(train_data)

In [None]:
trainprepDF=pipelineModel.transform(train_data)
testprepDF=pipelineModel.transform(test_data)

In [None]:
trainprepDF.head(1)

[Row(customerID='0002-ORFBO', gender='Female', SeniorCitizen=0, Partner='Yes', Dependents='Yes', tenure=9, PhoneService='Yes', MultipleLines='No', InternetService='DSL', OnlineSecurity='No', OnlineBackup='Yes', DeviceProtection='No', TechSupport='Yes', StreamingTV='Yes', StreamingMovies='No', Contract='One year', PaperlessBilling='Yes', PaymentMethod='Mailed check', MonthlyCharges=65.6, TotalCharges=593.3, Churn='No', genderIndex=1.0, gendercatVec=SparseVector(1, {}), SeniorCitizenIndex=0.0, SeniorCitizencatVec=SparseVector(1, {0: 1.0}), PartnerIndex=1.0, PartnercatVec=SparseVector(1, {}), DependentsIndex=1.0, DependentscatVec=SparseVector(1, {}), PhoneServiceIndex=0.0, PhoneServicecatVec=SparseVector(1, {0: 1.0}), MultipleLinesIndex=0.0, MultipleLinescatVec=SparseVector(2, {0: 1.0}), InternetServiceIndex=1.0, InternetServicecatVec=SparseVector(2, {1: 1.0}), OnlineSecurityIndex=0.0, OnlineSecuritycatVec=SparseVector(2, {0: 1.0}), OnlineBackupIndex=1.0, OnlineBackupcatVec=SparseVector(2

In [None]:
customer_df.select('PaymentMethod').distinct().show()

+--------------------+
|       PaymentMethod|
+--------------------+
|Credit card (auto...|
|        Mailed check|
|Bank transfer (au...|
|    Electronic check|
+--------------------+



In [None]:
## Create initial Logistic Regression model
lr=LogisticRegression(labelCol="label",featuresCol="features",maxIter=10)

## Train model with Training Data
lrModel=lr.fit(trainprepDF)

In [None]:
print("Coefficients: "+str(lrModel.coefficients))
print("Intercept: "+str(lrModel.intercept))

Coefficients: [0.0342809611746217,-0.331216284496287,-0.10403353480847842,0.12868949953119735,-0.6191319066991583,-0.2471779782628914,0.02685651848830146,0.5947672912779578,-0.43247345849078295,0.2603430987936111,-0.0787349437852818,0.17599225878190355,0.02509437687639419,0.13427108880327826,0.07038245227252787,0.24746208612235363,-0.06197277109519446,-0.03742781487967394,0.24498552330495715,0.012591165389668446,0.1928780997457295,0.707410578947039,-0.852599384482354,0.2978156771216554,0.3370045996232975,-0.08658398567782667,0.010002030098458728,-0.722185024128052,-0.00011665000139783336,0.005179321234768732]
Intercept: -1.1540815271668787


In [None]:
summary=lrModel.summary

In [None]:
## Evaluation metrics
accuracy=summary.accuracy
falsePositiveRate=summary.weightedFalsePositiveRate
truePositiveRate=summary.weightedTruePositiveRate
fMeasure=summary.weightedFMeasure()
precision=summary.weightedPrecision
recall=summary.weightedRecall

In [None]:
print("Accuracy: %s\nFPR:%s\nTPR:%s\nF-Measure:%s\nPrecision:%s\nRecall:%s\nAreaUnderROC:%s"
      %(accuracy,falsePositiveRate,truePositiveRate,fMeasure,precision,recall,summary.areaUnderROC))

Accuracy: 0.8057466612707406
FPR:0.3625173099768732
TPR:0.8057466612707406
F-Measure:0.7993759299320431
Precision:0.7971094940545889
Recall:0.8057466612707406
AreaUnderROC:0.8498426976249556


In [None]:
## ROC 
from pyspark.ml.evaluation import BinaryClassificationEvaluator
predictions=lrModel.transform(testprepDF)
evaluatorLR=BinaryClassificationEvaluator(rawPredictionCol="prediction")
area_under_curve=evaluatorLR.evaluate(predictions)

## default evaluation in ROC
print("areaUnderROC= %g" % area_under_curve)

evaluatorLR.getMetricName()

areaUnderROC= 0.704116


'areaUnderROC'

In [None]:
from pyspark.mllib.evaluation import BinaryClassificationMetrics
results=predictions.select(['prediction','label'])

In [None]:
## Prepare score-label set
results_collect=results.collect()
results_list=[(float(i[0]),float(i[1])) for i in results_collect]
predictionAndLabels=sc.parallelize(results_list)

metrics=BinaryClassificationMetrics(predictionAndLabels)

## Area under precision-recall curve
print("Area under PR=%s"% metrics.areaUnderPR)

## Area under ROC curve
print("Area under ROC=%s"% metrics.areaUnderROC)

predictions.show(1)

Area under PR=0.5395619163792253
Area under ROC=0.7041164527040232
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+--------------------+--------------+------------+-----+-----------+-------------+------------------+-------------------+------------+-------------+---------------+----------------+-----------------+------------------+------------------+-------------------+--------------------+---------------------+-------------------+--------------------+-----------------+------------------+---------------------+----------------------+----------------+-----------------+----------------+-----------------+--------------------+---------------------+-------------+--------------+---------------------+----------------------+------------------+-------------------+----------------+-----+----------+--------------------+-----------------

In [None]:
## evaluation metrics for test data
count=predictions.count()
correct=results.filter(results.prediction==results.label).count()
wrong=results.filter(results.prediction!=results.label).count()
tp=results.filter(results.prediction==1.0).filter(results.prediction==results.label).count()
fp=results.filter(results.prediction==1.0).filter(results.prediction!=results.label).count()
fn=results.filter(results.prediction==0.0).filter(results.prediction!=results.label).count()
tn=results.filter(results.prediction==0.0).filter(results.prediction==results.label).count()

accuracy=(tp+tn)/count
precision=tp/(tp+fp)
recall=tp/(tp+fn)

print("Correct: %s\nWrong:%s\ntp:%s\nfp:%s\nfn:%s\ntn:%s\nAccuracy:%s\nPrecision:%s\nRecall:%s"
      %(correct,wrong,tp,fp,fn,tn,accuracy,precision,recall))

Correct: 1652
Wrong:449
tp:298
fp:180
fn:269
tn:1354
Accuracy:0.786292241789624
Precision:0.6234309623430963
Recall:0.5255731922398589


In [None]:
from pyspark.ml.tuning import ParamGridBuilder,CrossValidator

## Create a ParamGrid for Cross Validation
paramGrid=(ParamGridBuilder()
          .addGrid(lr.regParam,[0.01,0.5,2.0])
          .addGrid(lr.elasticNetParam,[0.0,0.5,1.0])
          .addGrid(lr.maxIter,[5,10,20])
          .build())

In [None]:
cv=CrossValidator(estimator=lr,estimatorParamMaps=paramGrid,evaluator=evaluatorLR,numFolds=5)

## Run Cross validations
cvModel=cv.fit(trainprepDF)

In [None]:
predictions=cvModel.bestModel.transform(testprepDF)

In [None]:
evaluatorLR.evaluate(predictions)

0.7058421804184516

In [None]:
results=predictions.select(['prediction','label'])

count=predictions.count()
correct=results.filter(results.prediction==results.label).count()
wrong=results.filter(results.prediction!=results.label).count()
tp=results.filter(results.prediction==1.0).filter(results.prediction==results.label).count()
fp=results.filter(results.prediction==1.0).filter(results.prediction!=results.label).count()
fn=results.filter(results.prediction==0.0).filter(results.prediction!=results.label).count()
tn=results.filter(results.prediction==0.0).filter(results.prediction==results.label).count()

accuracy=(tp+tn)/count
precision=tp/(tp+fp)
recall=tp/(tp+fn)

print("Correct: %s\nWrong:%s\ntp:%s\nfp:%s\nfn:%s\ntn:%s\nAccuracy:%s\nPrecision:%s\nRecall:%s"
      %(correct,wrong,tp,fp,fn,tn,accuracy,precision,recall))

Correct: 1659
Wrong:442
tp:297
fp:172
fn:270
tn:1362
Accuracy:0.7896239885768681
Precision:0.6332622601279317
Recall:0.5238095238095238


In [None]:
cvModel.explainParams()

"estimator: estimator to be cross-validated (current: LogisticRegression_9099ab21d606)\nestimatorParamMaps: estimator param maps (current: [{Param(parent='LogisticRegression_9099ab21d606', name='regParam', doc='regularization parameter (>= 0).'): 0.01, Param(parent='LogisticRegression_9099ab21d606', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.'): 0.0, Param(parent='LogisticRegression_9099ab21d606', name='maxIter', doc='max number of iterations (>= 0).'): 5}, {Param(parent='LogisticRegression_9099ab21d606', name='regParam', doc='regularization parameter (>= 0).'): 0.01, Param(parent='LogisticRegression_9099ab21d606', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.'): 0.0, Param(parent='LogisticRegression_9099ab21d606', name='maxIter', doc='max number of iterations

In [None]:
## Random forest classier
from pyspark.ml.classification import RandomForestClassifier

rf=RandomForestClassifier(labelCol="label",featuresCol="features").setImpurity("gini").setMaxDepth(6).setNumTrees(50)\
                                                                  .setFeatureSubsetStrategy("auto").setSeed(1010)
rfModel=rf.fit(trainprepDF)

In [None]:
predictions=rfModel.transform(testprepDF)

In [None]:
results=predictions.select(['prediction','label'])

count=predictions.count()
correct=results.filter(results.prediction==results.label).count()
wrong=results.filter(results.prediction!=results.label).count()
tp=results.filter(results.prediction==1.0).filter(results.prediction==results.label).count()
fp=results.filter(results.prediction==1.0).filter(results.prediction!=results.label).count()
fn=results.filter(results.prediction==0.0).filter(results.prediction!=results.label).count()
tn=results.filter(results.prediction==0.0).filter(results.prediction==results.label).count()

accuracy=(tp+tn)/count
precision=tp/(tp+fp)
recall=tp/(tp+fn)

print("Correct: %s\nWrong:%s\ntp:%s\nfp:%s\nfn:%s\ntn:%s\nAccuracy:%s\nPrecision:%s\nRecall:%s"
      %(correct,wrong,tp,fp,fn,tn,accuracy,precision,recall))

Correct: 1647
Wrong:454
tp:233
fp:120
fn:334
tn:1414
Accuracy:0.7839124226558781
Precision:0.660056657223796
Recall:0.4109347442680776
