In [1]:
from pyspark.sql import SparkSession

spark = SparkSession \
    .builder \
    .appName("Python Spark data cleaning and engineering") \
    .config("spark.some.config.option", "some-value") \
    .getOrCreate()

In [2]:
import plotly
plotly.__version__

'4.9.0'

In [3]:
file_location = "datasets/datasets_13996_18858_WA_Fn-UseC_-Telco-Customer-Churn.csv"
file_type = "csv"

# CSV options
infer_schema = "false"
first_row_is_header = "True"
delimiter = ","

# The applied options are for CSV files. For other file types, these will be ignored.
df = spark.read.format(file_type) \
  .option("inferSchema", infer_schema) \
  .option("header", first_row_is_header) \
  .option("sep", delimiter) \
  .option("nanValue", ' ') \
  .option("nullValue", ' ') \
  .load(file_location)

display(df)

DataFrame[customerID: string, gender: string, SeniorCitizen: string, Partner: string, Dependents: string, tenure: string, PhoneService: string, MultipleLines: string, InternetService: string, OnlineSecurity: string, OnlineBackup: string, DeviceProtection: string, TechSupport: string, StreamingTV: string, StreamingMovies: string, Contract: string, PaperlessBilling: string, PaymentMethod: string, MonthlyCharges: string, TotalCharges: string, Churn: string]

In [4]:
df.show()

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|   MultipleLines|InternetService|     OnlineSecurity|       OnlineBackup|   DeviceProtection|        TechSupport|        StreamingTV|    StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+
|7590-VHVEG|Female|            0|    Yes|        No|     1|  

In [5]:
from pyspark.sql.functions import isnan, when, count, col
df.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df.columns]).show()

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|Contract|PaperlessBilling|PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|         0|     0|            0|      0|         0|     0|           0|            0|              0|             0|           0|               0|          0|          0|              0|       0|               0| 

In [6]:
df.groupBy("Churn").count().show()

+-----+-----+
|Churn|count|
+-----+-----+
|   No| 5174|
|  Yes| 1869|
+-----+-----+



In [7]:
df.select("tenure").show()

+------+
|tenure|
+------+
|     1|
|    34|
|     2|
|    45|
|     2|
|     8|
|    22|
|    10|
|    28|
|    62|
|    13|
|    16|
|    58|
|    49|
|    25|
|    69|
|    52|
|    71|
|    10|
|    21|
+------+
only showing top 20 rows



In [8]:
df.select("tenure", "TotalCharges", "MonthlyCharges").describe().show()

+-------+------------------+------------------+------------------+
|summary|            tenure|      TotalCharges|    MonthlyCharges|
+-------+------------------+------------------+------------------+
|  count|              7043|              7032|              7043|
|   mean| 32.37114865824223|2283.3004408418697| 64.76169246059922|
| stddev|24.559481023094442| 2266.771361883145|30.090047097678482|
|    min|                 0|             100.2|               100|
|    max|                 9|             999.9|             99.95|
+-------+------------------+------------------+------------------+



In [9]:
df.stat.crosstab("SeniorCitizen", "InternetService").show()

+-----------------------------+----+-----------+----+
|SeniorCitizen_InternetService| DSL|Fiber optic|  No|
+-----------------------------+----+-----------+----+
|                            1| 259|        831|  52|
|                            0|2162|       2265|1474|
+-----------------------------+----+-----------+----+



In [10]:
df.columns

['customerID',
 'gender',
 'SeniorCitizen',
 'Partner',
 'Dependents',
 'tenure',
 'PhoneService',
 'MultipleLines',
 'InternetService',
 'OnlineSecurity',
 'OnlineBackup',
 'DeviceProtection',
 'TechSupport',
 'StreamingTV',
 'StreamingMovies',
 'Contract',
 'PaperlessBilling',
 'PaymentMethod',
 'MonthlyCharges',
 'TotalCharges',
 'Churn']

In [11]:
df.stat.freqItems(["PhoneService", "InternetService", "MultipleLines", "OnlineSecurity", 
                   "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies"], 0.6).collect()

[Row(PhoneService_freqItems=['Yes'], InternetService_freqItems=['Fiber optic'], MultipleLines_freqItems=['Yes'], OnlineSecurity_freqItems=['No'], OnlineBackup_freqItems=['Yes'], DeviceProtection_freqItems=['No'], TechSupport_freqItems=['No'], StreamingTV_freqItems=['Yes'], StreamingMovies_freqItems=['No'])]

In [12]:
df = df.withColumn("TotalCharges", df["TotalCharges"].cast("float"))
df = df.withColumn("MonthlyCharges", df["MonthlyCharges"].cast("float"))
df = df.withColumn("tenure", df["tenure"].cast("int"))

In [13]:
churn_df = df
(train_data, test_data) = churn_df.randomSplit([0.7, 0.3], 24)

print(f"Records for training: {str(train_data.count())}")
print(f"Records for testing: {str(test_data.count())}")

Records for training: 4942
Records for testing: 2101


In [14]:
from pyspark.ml import pipeline
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler 

In [15]:
cat_columns = ["gender", "SeniorCitizen", "Partner", "Dependents", "PhoneService", "MultipleLines", "InternetService", 
               "OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies",
              "Contract", "PaperlessBilling", "PaymentMethod"]

In [16]:
df = df.withColumn("TotalCharges", df["TotalCharges"].cast("float"))
df = df.withColumn("MonthlyCharges", df["MonthlyCharges"].cast("float"))


In [17]:
len(cat_columns)

16

In [18]:
stage = []
for i in cat_columns:
    string_indexer = StringIndexer(inputCol=i, outputCol=i + "index")
    print(string_indexer.getOutputCol())
    one_hot = OneHotEncoder(inputCols=[string_indexer.getOutputCol()],
                        outputCols=[i + "catVec"])
    stage += [string_indexer, one_hot]

genderindex
SeniorCitizenindex
Partnerindex
Dependentsindex
PhoneServiceindex
MultipleLinesindex
InternetServiceindex
OnlineSecurityindex
OnlineBackupindex
DeviceProtectionindex
TechSupportindex
StreamingTVindex
StreamingMoviesindex
Contractindex
PaperlessBillingindex
PaymentMethodindex


In [19]:
from pyspark.ml.feature import Imputer
imputer = Imputer(inputCols=["TotalCharges"], outputCols=["Out_TotalCharges"])
stage += [imputer]

In [20]:
label_idx = StringIndexer(inputCol="Churn", outputCol="label")
stage += [label_idx]

In [21]:
stage

[StringIndexer_a89b3accba7e,
 OneHotEncoder_fbfd737c41a9,
 StringIndexer_927d65535359,
 OneHotEncoder_1e4aef4b5415,
 StringIndexer_8c9dba0294fc,
 OneHotEncoder_c3a6628ac928,
 StringIndexer_fda9336a1461,
 OneHotEncoder_a129ef515604,
 StringIndexer_59f1a09127f0,
 OneHotEncoder_e5d85d73de43,
 StringIndexer_764880f064a6,
 OneHotEncoder_e98ca3123f3a,
 StringIndexer_d66c24be3835,
 OneHotEncoder_7e466340ac5d,
 StringIndexer_85025c255102,
 OneHotEncoder_ce188d9fa2e4,
 StringIndexer_b490a1773dc7,
 OneHotEncoder_5a7716a5c16c,
 StringIndexer_848c72a4bf6f,
 OneHotEncoder_564139078285,
 StringIndexer_09ff5eb99813,
 OneHotEncoder_f887ab1ff4f8,
 StringIndexer_8bf743934276,
 OneHotEncoder_3a87170a7fc2,
 StringIndexer_b117fbb1c2bf,
 OneHotEncoder_2cf69133ffa5,
 StringIndexer_93742915bebe,
 OneHotEncoder_b7d865625929,
 StringIndexer_faa462f39493,
 OneHotEncoder_0b564380877b,
 StringIndexer_399a91df7ae5,
 OneHotEncoder_a1cf63b8309e,
 Imputer_a731b973d26c,
 StringIndexer_fabdfdfaba8c]

In [22]:
temp = label_idx.fit(train_data).transform(train_data)
temp.show(1)

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|Contract|PaperlessBilling|PaymentMethod|MonthlyCharges|TotalCharges|Churn|label|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+-----+
|0002-ORFBO|Female|            0|    Yes|       Yes|     9|         Yes|           No|            DSL|            No|         Yes|              No|        Yes|        Yes|             No|One year|

In [23]:
df.select("MonthlyCharges").show()

+--------------+
|MonthlyCharges|
+--------------+
|         29.85|
|         56.95|
|         53.85|
|          42.3|
|          70.7|
|         99.65|
|          89.1|
|         29.75|
|         104.8|
|         56.15|
|         49.95|
|         18.95|
|        100.35|
|         103.7|
|         105.5|
|        113.25|
|         20.65|
|         106.7|
|          55.2|
|         90.05|
+--------------+
only showing top 20 rows



In [24]:
df.select("TotalCharges").show()

+------------+
|TotalCharges|
+------------+
|       29.85|
|      1889.5|
|      108.15|
|     1840.75|
|      151.65|
|       820.5|
|      1949.4|
|       301.9|
|     3046.05|
|     3487.95|
|      587.45|
|       326.8|
|      5681.1|
|      5036.3|
|     2686.05|
|     7895.15|
|     1022.95|
|     7382.25|
|      528.35|
|      1862.9|
+------------+
only showing top 20 rows



In [25]:
df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: string (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: float (nullable = true)
 |-- TotalCharges: float (nullable = true)
 |-- Churn: string (nullable = true)



In [26]:
df = df.withColumn("TotalCharges", df["TotalCharges"].cast("float"))

In [27]:
df = df.withColumn("MonthlyCharges", df["MonthlyCharges"].cast("float"))

In [28]:
df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: string (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: float (nullable = true)
 |-- TotalCharges: float (nullable = true)
 |-- Churn: string (nullable = true)



In [29]:
df.stat.corr("TotalCharges", "MonthlyCharges")

0.6511738319154315

In [30]:
from pyspark.ml.feature import QuantileDiscretizer
tenure_bin = QuantileDiscretizer(numBuckets=3, inputCol="tenure", outputCol="tenure_bin")
stage += [tenure_bin]

In [31]:
stage

[StringIndexer_a89b3accba7e,
 OneHotEncoder_fbfd737c41a9,
 StringIndexer_927d65535359,
 OneHotEncoder_1e4aef4b5415,
 StringIndexer_8c9dba0294fc,
 OneHotEncoder_c3a6628ac928,
 StringIndexer_fda9336a1461,
 OneHotEncoder_a129ef515604,
 StringIndexer_59f1a09127f0,
 OneHotEncoder_e5d85d73de43,
 StringIndexer_764880f064a6,
 OneHotEncoder_e98ca3123f3a,
 StringIndexer_d66c24be3835,
 OneHotEncoder_7e466340ac5d,
 StringIndexer_85025c255102,
 OneHotEncoder_ce188d9fa2e4,
 StringIndexer_b490a1773dc7,
 OneHotEncoder_5a7716a5c16c,
 StringIndexer_848c72a4bf6f,
 OneHotEncoder_564139078285,
 StringIndexer_09ff5eb99813,
 OneHotEncoder_f887ab1ff4f8,
 StringIndexer_8bf743934276,
 OneHotEncoder_3a87170a7fc2,
 StringIndexer_b117fbb1c2bf,
 OneHotEncoder_2cf69133ffa5,
 StringIndexer_93742915bebe,
 OneHotEncoder_b7d865625929,
 StringIndexer_faa462f39493,
 OneHotEncoder_0b564380877b,
 StringIndexer_399a91df7ae5,
 OneHotEncoder_a1cf63b8309e,
 Imputer_a731b973d26c,
 StringIndexer_fabdfdfaba8c,
 QuantileDiscretizer

In [32]:
numeric_cols = ["tenure_bin", "Out_TotalCharges", "MonthlyCharges"]
assembleInputs = assemblerInputs = [c + "catVec" for c in cat_columns] + numeric_cols
assembler = VectorAssembler(inputCols=assembleInputs, outputCol="feature")
stage += [assembler]

In [33]:
stage

[StringIndexer_a89b3accba7e,
 OneHotEncoder_fbfd737c41a9,
 StringIndexer_927d65535359,
 OneHotEncoder_1e4aef4b5415,
 StringIndexer_8c9dba0294fc,
 OneHotEncoder_c3a6628ac928,
 StringIndexer_fda9336a1461,
 OneHotEncoder_a129ef515604,
 StringIndexer_59f1a09127f0,
 OneHotEncoder_e5d85d73de43,
 StringIndexer_764880f064a6,
 OneHotEncoder_e98ca3123f3a,
 StringIndexer_d66c24be3835,
 OneHotEncoder_7e466340ac5d,
 StringIndexer_85025c255102,
 OneHotEncoder_ce188d9fa2e4,
 StringIndexer_b490a1773dc7,
 OneHotEncoder_5a7716a5c16c,
 StringIndexer_848c72a4bf6f,
 OneHotEncoder_564139078285,
 StringIndexer_09ff5eb99813,
 OneHotEncoder_f887ab1ff4f8,
 StringIndexer_8bf743934276,
 OneHotEncoder_3a87170a7fc2,
 StringIndexer_b117fbb1c2bf,
 OneHotEncoder_2cf69133ffa5,
 StringIndexer_93742915bebe,
 OneHotEncoder_b7d865625929,
 StringIndexer_faa462f39493,
 OneHotEncoder_0b564380877b,
 StringIndexer_399a91df7ae5,
 OneHotEncoder_a1cf63b8309e,
 Imputer_a731b973d26c,
 StringIndexer_fabdfdfaba8c,
 QuantileDiscretizer

In [34]:
from pyspark.ml import Pipeline

In [35]:
pipeline = Pipeline().setStages(stage)
pipelineModel = pipeline.fit(train_data)

In [36]:
train_df = pipelineModel.transform(train_data)
test_df = pipelineModel.transform(test_data)

In [37]:
len(train_df.columns)

57

In [38]:
train_df.head(1)

[Row(customerID='0002-ORFBO', gender='Female', SeniorCitizen='0', Partner='Yes', Dependents='Yes', tenure=9, PhoneService='Yes', MultipleLines='No', InternetService='DSL', OnlineSecurity='No', OnlineBackup='Yes', DeviceProtection='No', TechSupport='Yes', StreamingTV='Yes', StreamingMovies='No', Contract='One year', PaperlessBilling='Yes', PaymentMethod='Mailed check', MonthlyCharges=65.5999984741211, TotalCharges=593.2999877929688, Churn='No', genderindex=1.0, gendercatVec=SparseVector(1, {}), SeniorCitizenindex=0.0, SeniorCitizencatVec=SparseVector(1, {0: 1.0}), Partnerindex=1.0, PartnercatVec=SparseVector(1, {}), Dependentsindex=1.0, DependentscatVec=SparseVector(1, {}), PhoneServiceindex=0.0, PhoneServicecatVec=SparseVector(1, {0: 1.0}), MultipleLinesindex=0.0, MultipleLinescatVec=SparseVector(2, {0: 1.0}), InternetServiceindex=1.0, InternetServicecatVec=SparseVector(2, {1: 1.0}), OnlineSecurityindex=0.0, OnlineSecuritycatVec=SparseVector(2, {0: 1.0}), OnlineBackupindex=1.0, OnlineB

In [39]:
train_df.select("tenure_bin").show()

+----------+
|tenure_bin|
+----------+
|       0.0|
|       0.0|
|       0.0|
|       0.0|
|       0.0|
|       0.0|
|       2.0|
|       0.0|
|       2.0|
|       2.0|
|       0.0|
|       2.0|
|       1.0|
|       0.0|
|       1.0|
|       0.0|
|       0.0|
|       0.0|
|       2.0|
|       2.0|
+----------+
only showing top 20 rows



In [40]:
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression(labelCol="label", featuresCol="feature", maxIter=10)
lrModel = lr.fit(train_df)

In [41]:
print(f"coefficients: {str(lrModel.coefficients)}")
print(f"Intercept: {str(lrModel.intercept)}")

coefficients: [0.034280961208810704,-0.3312162846190503,-0.10403353386750051,0.12868949858748596,-0.6191319042742331,-0.24717797905720307,0.02685652018272491,0.5947672944599564,-0.43247345944430066,0.26034309970076586,-0.07873494210934841,0.17599225888273645,0.025094379273215246,0.13427108952998187,0.07038245400016194,0.2474620870056942,-0.06197276941501819,-0.03742781503932004,0.2449855258719482,0.012591165496948949,0.19287810202563474,0.7074105781395923,-0.852599383222226,0.2978156779228366,0.3370045996985892,-0.08658398635534728,0.010002029823934406,-0.7221850206119766,-0.00011665000400444835,0.005179321118112393]
Intercept: -1.1540815262190598


In [42]:
summary = lrModel.summary

In [43]:
print(f"accuracy: {summary.accuracy}, FP: {summary.weightedFalsePositiveRate}, TP: {summary.weightedTruePositiveRate}, fmeasure: {summary.weightedFMeasure()}, precision: {summary.weightedPrecision}, Recall: {summary.weightedRecall}")

accuracy: 0.8057466612707406, FP: 0.3625173099768732, TP: 0.8057466612707406, fmeasure: 0.7993759299320431, precision: 0.7971094940545889, Recall: 0.8057466612707406


In [44]:
display(lrModel, train_df, "ROC")

LogisticRegressionModel: uid=LogisticRegression_e206a36e2618, numClasses=2, numFeatures=30

DataFrame[customerID: string, gender: string, SeniorCitizen: string, Partner: string, Dependents: string, tenure: int, PhoneService: string, MultipleLines: string, InternetService: string, OnlineSecurity: string, OnlineBackup: string, DeviceProtection: string, TechSupport: string, StreamingTV: string, StreamingMovies: string, Contract: string, PaperlessBilling: string, PaymentMethod: string, MonthlyCharges: float, TotalCharges: float, Churn: string, genderindex: double, gendercatVec: vector, SeniorCitizenindex: double, SeniorCitizencatVec: vector, Partnerindex: double, PartnercatVec: vector, Dependentsindex: double, DependentscatVec: vector, PhoneServiceindex: double, PhoneServicecatVec: vector, MultipleLinesindex: double, MultipleLinescatVec: vector, InternetServiceindex: double, InternetServicecatVec: vector, OnlineSecurityindex: double, OnlineSecuritycatVec: vector, OnlineBackupindex: double, OnlineBackupcatVec: vector, DeviceProtectionindex: double, DeviceProtectioncatVec: vector, 

'ROC'

In [45]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
predictions = lrModel.transform(test_df)
evaluate_model = BinaryClassificationEvaluator(rawPredictionCol="prediction")

In [46]:
area_under_curve = evaluate_model.evaluate(predictions)
print(f"Area under curve: {area_under_curve}")
evaluate_model.getMetricName()

Area under curve: 0.7041164527040232


'areaUnderROC'

In [54]:
# from pyspark.mllib.evaluation import BinaryClassificationMetrics
# results = predictions.select(["prediction", "label"])
# results_collect = results.collect()
# results_list = [(float(i[0]), float(i[1])) for i in results_collect]
# prediction_and_label = spark.parallelize(results_list)
# metrics = BinaryClassificationMetrics(prediction_and_label)

# print(f"Area under PR: {metrics.areaUnderPR}, Area under ROC {metrics.areaUnderROC}")


In [55]:
test_df.groupBy("Churn").count().show()

+-----+-----+
|Churn|count|
+-----+-----+
|   No| 1534|
|  Yes|  567|
+-----+-----+



In [58]:
count = predictions.count()

In [59]:
count

2101

In [60]:
correct = results.filter(results.prediction == results.label ).count()

In [61]:
correct

1652

In [63]:
tp = results.filter(results.prediction == 1.0).filter(results.prediction == results.label).count()
tp

298

In [64]:
fp = results.filter(results.prediction == 1.0).filter(results.prediction != results.label).count()
fp

180

In [65]:
fn = results.filter(results.prediction == 0.0).filter(results.prediction != results.label).count()
fn 

269

In [66]:
tn = results.filter(results.prediction == 0.0).filter(results.prediction == results.label).count()
tn

1354

In [67]:
accuracy = (tp+tn)/count
accuracy

0.786292241789624

In [68]:
precision = tp/(tp+fp)
precision

0.6234309623430963

In [70]:
recall = tp/(tp+fn)
recall

0.5255731922398589

In [71]:
from pyspark.ml.tuning import ParamGridBuilder, CrossValidator

paramgrid = (ParamGridBuilder()
            .addGrid(lr.regParam, [0.01, 0.5, 2.0])
            .addGrid(lr.elasticNetParam, [0.0, 0.5, 1.0])
            .addGrid(lr.maxIter, [5, 10, 20])
            .build())

In [72]:
cv = CrossValidator(estimator=lr, estimatorParamMaps=paramgrid, evaluator=evaluate_model, numFolds=5)
cv_model = cv.fit(train_df)

In [73]:
predictions = cv_model.bestModel.transform(test_df)

In [74]:
evaluate_model.evaluate(predictions)

0.7007811188602149

In [75]:
cv_model.explainParams()

"estimator: estimator to be cross-validated (current: LogisticRegression_e206a36e2618)\nestimatorParamMaps: estimator param maps (current: [{Param(parent='LogisticRegression_e206a36e2618', name='regParam', doc='regularization parameter (>= 0).'): 0.01, Param(parent='LogisticRegression_e206a36e2618', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.'): 0.0, Param(parent='LogisticRegression_e206a36e2618', name='maxIter', doc='max number of iterations (>= 0).'): 5}, {Param(parent='LogisticRegression_e206a36e2618', name='regParam', doc='regularization parameter (>= 0).'): 0.01, Param(parent='LogisticRegression_e206a36e2618', name='elasticNetParam', doc='the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.'): 0.0, Param(parent='LogisticRegression_e206a36e2618', name='maxIter', doc='max number of iterations

In [78]:
from pyspark.ml.classification import RandomForestClassifier
rf = RandomForestClassifier(labelCol="label", featuresCol="feature").setImpurity("gini").setMaxDepth(6).setNumTrees(50).setFeatureSubsetStrategy("auto").setSeed(1010)

In [79]:
rf_model = rf.fit(train_df)

In [80]:
predictions = rf_model.transform(test_df)