# ML with PySpark

The primary machine learning API of PySpark is contained in `pyspark.ml`. It is based on PySpark DataFrames. The older RDD-based API, `pyspark.mllib`, is in maintenance mode: it is mainly provided for legacy purposes and might become deprecated at some point.

In [None]:
# IPython magics: reload edited modules automatically before each execution
# and render matplotlib figures inline in the notebook
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys
# make modules in the parent directory importable
# (presumably where the local `utils` module lives -- verify against the repo layout)
sys.path.append('..')

In [None]:
import matplotlib.pyplot as plt
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
from pyspark.ml.feature import (
    StringIndexer,
    OneHotEncoder,
    VectorAssembler
)
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator

from utils import download_textfile

## Start session

In [None]:
# Create the Spark session for this notebook, or reuse one that already exists
spark = (
    SparkSession.builder
    .appName('Logistic Regression')
    .getOrCreate()
)

## Import data

In [None]:
# download CSV file (Titanic data set)
url = 'https://web.stanford.edu/class/archive/cs/cs109/cs109.1166/stuff/titanic.csv'

# `download_textfile` is imported from the local `utils` module (not shown here);
# presumably it fetches `url` and writes the text to `save_path` -- verify in utils
download_textfile(url, save_path='../data/titanic.csv')

In [None]:
# Load the CSV into a DataFrame: the first line supplies the column names
# and the column types are inferred from the data
df = spark.read.csv(
    '../data/titanic.csv',
    header=True,
    inferSchema=True,
)

# Report the size of the loaded table
print(f'The data has {df.count()} rows and {len(df.columns)} columns')

In [None]:
# show summary
df.show(10)  # print the first 10 rows
df.printSchema()  # print column names and data types

In [None]:
# drop rows that contain a null/NaN value in any column
df = df.na.drop() # original note: there are no NaNs in this data

In [None]:
# Keep only the label column ('Survived') and the columns used as model inputs
data_columns = [
    'Survived',
    'Pclass',
    'Sex',
    'Age',
    'Siblings/Spouses Aboard',
    'Parents/Children Aboard',
    'Fare',
]

# NOTE: per the original author's note, dropping the 'Name' column would be
# an alternative way to arrive at this selection
df = df.select(*data_columns)

## Create pipeline

In [None]:
# Encode the categorical 'Sex' column in two steps:
# string -> category index ('SexIdx') -> one-hot vector ('SexVec')
sex_indexer = StringIndexer(
    inputCol='Sex',
    outputCol='SexIdx',
)
sex_encoder = OneHotEncoder(
    inputCol='SexIdx',
    outputCol='SexVec',
)

# Model inputs: the encoded 'SexVec' together with the remaining columns
feature_columns = [
    'Pclass',
    'SexVec',
    'Age',
    'Siblings/Spouses Aboard',
    'Parents/Children Aboard',
    'Fare',
]

# Combine all input columns into the single vector column 'Features'
assembler = VectorAssembler(
    inputCols=feature_columns,
    outputCol='Features',
)

In [None]:
# create model: logistic regression that reads the assembled 'Features'
# vector column and learns to predict the 'Survived' label
lr = LogisticRegression(
    featuresCol='Features',
    labelCol='Survived',
    maxIter=10,  # maximum number of optimizer iterations
    regParam=0.1,  # regularization strength
)

In [None]:
# create pipeline: the stages run in list order, each one consuming
# columns produced by the stages before it
stages = [
    sex_indexer, # estimator: 'Sex' -> 'SexIdx'
    sex_encoder, # estimator: 'SexIdx' -> 'SexVec'
    assembler, # transformer: feature columns -> 'Features'
    lr # estimator: fitted on 'Features' against 'Survived'
]

pipeline = Pipeline(stages=stages) # contains transformers and estimators

## Fit model

In [None]:
# split data: random split into roughly 80% train / 20% test rows;
# the fixed seed is there for repeatability
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)

In [None]:
# fit model: fit the pipeline stages on the training split;
# the result is the fitted pipeline used for predictions below
model = pipeline.fit(train_df)

In [None]:
# get summary
# The last pipeline stage is the fitted logistic regression model; its
# `summary` exposes metrics computed on the training data.
train_summary = model.stages[-1].summary

train_acc = train_summary.accuracy
train_roc_auc = train_summary.areaUnderROC

# Curve points as pandas DataFrames, used by the plotting cells
train_pr = train_summary.pr.toPandas()
train_roc = train_summary.roc.toPandas()

print(f'Train acc.: {train_acc}')
# Report the training ROC AUC as well (it was computed but never shown),
# formatted like the test metrics at the end of the notebook
print(f'Train ROC AUC: {train_roc_auc:.4f}')

In [None]:
# Precision-recall curve on the training data
fig, ax = plt.subplots(figsize=(5, 3.5))
ax.plot(train_pr['recall'], train_pr['precision'])

# Title, axis labels and the unit-square limits in one call
ax.set(
    title='Train PR curve',
    xlabel='recall',
    ylabel='precision',
    xlim=(0, 1),
    ylim=(0, 1),
)

# Light gray grid, drawn behind the curve
ax.set_axisbelow(True)
ax.grid(visible=True, which='both', color='lightgray', linestyle='-')

fig.tight_layout()

In [None]:
# ROC curve on the training data
fig, ax = plt.subplots(figsize=(5, 3.5))
ax.plot(train_roc['FPR'], train_roc['TPR'])

# Title, axis labels and the unit-square limits in one call
ax.set(
    title='Train ROC curve',
    xlabel='FPR',
    ylabel='TPR',
    xlim=(0, 1),
    ylim=(0, 1),
)

# Light gray grid, drawn behind the curve
ax.set_axisbelow(True)
ax.grid(visible=True, which='both', color='lightgray', linestyle='-')

fig.tight_layout()

## Test model

In [None]:
# make test predictions: apply the fitted pipeline to the held-out test split
pred_df = model.transform(test_df)

In [None]:
# show summary
pred_df.show(10)  # print the first 10 prediction rows
pred_df.printSchema()  # print column names and data types of the predictions

In [None]:
# Two evaluators that differ only in the metric they report; both compare
# the 'rawPrediction' column with the 'Survived' label

# area under the ROC curve
roc_auc_metric = BinaryClassificationEvaluator(
    metricName='areaUnderROC',
    labelCol='Survived',
    rawPredictionCol='rawPrediction',
)

# area under the precision-recall curve
pr_auc_metric = BinaryClassificationEvaluator(
    metricName='areaUnderPR',
    labelCol='Survived',
    rawPredictionCol='rawPrediction',
)

In [None]:
# calculate PR and ROC AUC on the test predictions
pr_auc = pr_auc_metric.evaluate(pred_df)
roc_auc = roc_auc_metric.evaluate(pred_df)

# report both metrics with four decimal places
print(f'Test PR AUC: {pr_auc:.4f}')
print(f'Test ROC AUC: {roc_auc:.4f}')

## Stop session

In [None]:
# stop session: shut down the Spark session once the analysis is finished
spark.stop()