### Spark notebook ###

This notebook only works in a Jupyter Notebook or JupyterLab session running on the cluster master node in the cloud.

Follow the instructions on the computing resources page to start a cluster and open this notebook.

**Steps**

1. Connect to the Windows server using Windows App.
2. Connect to Kubernetes.
3. Start Jupyter and open this notebook from Jupyter in order to connect to Spark.

In [1]:
# Run this cell to import pyspark and to define start_spark() and stop_spark()

import findspark

findspark.init()

import getpass
import pandas
import pyspark
import random
import re

from IPython.display import display, HTML
from pyspark import SparkContext
from pyspark.sql import SparkSession


# Constants used to interact with Azure Blob Storage using the hdfs command or Spark

username = re.sub('@.*', '', getpass.getuser())


# Functions used below

def dict_to_html(d):
    """Convert a Python dictionary into a two column table for display.
    """

    html = []

    html.append('<table width="100%" style="width:100%; font-family: monospace;">')
    for k, v in d.items():
        html.append(f'<tr><td style="text-align:left;">{k}</td><td>{v}</td></tr>')
    html.append('</table>')

    return ''.join(html)


def show_as_html(df, n=20):
    """Leverage existing pandas jupyter integration to show a spark dataframe as html.
    
    Args:
        df (pyspark.sql.DataFrame): dataframe to show
        n (int): number of rows to show (default: 20)
    """

    display(df.limit(n).toPandas())

    
def display_spark():
    """Display the status of the active Spark session if one is currently running.
    """
    
    if 'spark' in globals() and 'sc' in globals():

        name = sc.getConf().get("spark.app.name")

        html = [
            f'<p><b>Spark</b></p>',
            f'<p>The spark session is <b><span style="color:green">active</span></b>, look for <code>{name}</code> under the running applications section in the Spark UI.</p>',
            f'<ul>',
            f'<li><a href="http://localhost:{sc.uiWebUrl.split(":")[-1]}" target="_blank">Spark Application UI</a></li>',
            f'</ul>',
            f'<p><b>Config</b></p>',
            dict_to_html(dict(sc.getConf().getAll())),
            f'<p><b>Notes</b></p>',
            f'<ul>',
            f'<li>The spark session <code>spark</code> and spark context <code>sc</code> global variables have been defined by <code>start_spark()</code>.</li>',
            f'<li>Please run <code>stop_spark()</code> before closing the notebook or restarting the kernel or kill <code>{name}</code> by hand using the link in the Spark UI.</li>',
            f'</ul>',
        ]
        display(HTML(''.join(html)))
        
    else:
        
        html = [
            f'<p><b>Spark</b></p>',
            f'<p>The spark session is <b><span style="color:red">stopped</span></b>, confirm that <code>{username} (notebook)</code> is under the completed applications section in the Spark UI.</p>',
            f'<ul>',
            f'<li><a href="http://mathmadslinux2p.canterbury.ac.nz:8080/" target="_blank">Spark UI</a></li>',
            f'</ul>',
        ]
        display(HTML(''.join(html)))


# Functions to start and stop spark

def start_spark(executor_instances=2, executor_cores=1, worker_memory=1, master_memory=1):
    """Start a new Spark session and define globals for SparkSession (spark) and SparkContext (sc).
    
    Args:
        executor_instances (int): number of executors (default: 2)
        executor_cores (int): number of cores per executor (default: 1)
        worker_memory (float): memory per executor in GB (default: 1)
        master_memory (float): driver memory in GB (default: 1)
    """

    global spark
    global sc

    cores = executor_instances * executor_cores
    partitions = cores * 4
    port = 4000 + random.randint(1, 999)

    spark = (
        SparkSession.builder
        .config("spark.driver.extraJavaOptions", f"-Dderby.system.home=/tmp/{username}/spark/")
        .config("spark.dynamicAllocation.enabled", "false")
        .config("spark.executor.instances", str(executor_instances))
        .config("spark.executor.cores", str(executor_cores))
        .config("spark.cores.max", str(cores))
        .config("spark.driver.memory", f'{master_memory}g')
        .config("spark.executor.memory", f'{worker_memory}g')
        .config("spark.driver.maxResultSize", "0")
        .config("spark.sql.shuffle.partitions", str(partitions))
        .config("spark.kubernetes.container.image", "madsregistry001.azurecr.io/hadoop-spark:v3.3.5-openjdk-8")
        .config("spark.kubernetes.container.image.pullPolicy", "IfNotPresent")
        .config("spark.kubernetes.memoryOverheadFactor", "0.3")
        .config("spark.memory.fraction", "0.1")
        .config("spark.app.name", f"{username} (notebook)")
        .getOrCreate()
    )
    sc = SparkContext.getOrCreate()
    
    display_spark()

    
def stop_spark():
    """Stop the active Spark session and delete globals for SparkSession (spark) and SparkContext (sc).
    """

    global spark
    global sc

    if 'spark' in globals() and 'sc' in globals():

        spark.stop()

        del spark
        del sc

    display_spark()


# Make css changes to improve spark output readability

html = [
    '<style>',
    'pre { white-space: pre !important; }',
    'table.dataframe td { white-space: nowrap !important; }',
    'table.dataframe thead th:first-child, table.dataframe tbody th { display: none; }',
    '</style>',
]
display(HTML(''.join(html)))

### Credit Card Fraud ###

The credit card fraud dataset is relatively simple, with numeric features that have been anonymised using PCA. There is, however, significant class imbalance, with only 492 examples of fraud out of 284,807 transactions in total. This requires careful handling and makes it difficult to evaluate the performance of the model.
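
To see why accuracy alone is misleading with this much imbalance, consider a hypothetical classifier that labels every transaction as legitimate. The sketch below uses only the two counts quoted above; the variable names are illustrative and nothing here depends on Spark.

```python
# Counts quoted in the text above
n_total = 284_807
n_fraud = 492
n_legit = n_total - n_fraud

# A hypothetical "always legitimate" classifier never predicts fraud,
# so every fraud case is a false negative and every legitimate case a true negative
TP, FP = 0, 0
FN, TN = n_fraud, n_legit

accuracy = (TP + TN) / n_total
recall = TP / (TP + FN)

print(f"accuracy: {accuracy:.6f}")
print(f"recall:   {recall:.6f}")
```

The accuracy is very high even though no fraud is detected, which is why precision, recall, and AUROC are reported alongside accuracy later in the notebook.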

**Sections**

- [Data](#Data)
- [Data processing](#Data-processing)
- [Modeling](#Modeling)

**Key points**

- You can use `sc.getConf()` to calculate the ideal number of partitions for the resources you have allocated.
- We can write functions to extract snippets of code that we want to use more than once and we can customize their behaviour with keyword arguments.
  - `with_custom_prediction(pred, threshold)`
  - `show_class_balance(data, name)`
  - `show_metrics(pred, threshold)`

In [2]:
# Run this cell to start a spark session in this notebook

start_spark(executor_instances=2, executor_cores=1, worker_memory=4, master_memory=1)

25/09/18 10:05:54 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).


0,1
spark.dynamicAllocation.enabled,false
spark.fs.azure.sas.uco-user.madsstorage002.blob.core.windows.net,"""sp=racwdl&st=2024-09-19T08:00:18Z&se=2025-09-19T16:00:18Z&spr=https&sv=2022-11-02&sr=c&sig=qtg6fCdoFz6k3EJLw7dA8D3D8wN0neAYw8yG4z4Lw2o%3D"""
spark.kubernetes.driver.pod.name,spark-master-driver
spark.kubernetes.namespace,dew59
spark.kubernetes.executor.podNamePrefix,dew59-notebook-1713659959b67d74
spark.fs.azure.sas.campus-user.madsstorage002.blob.core.windows.net,"""sp=racwdl&st=2024-09-19T08:03:31Z&se=2025-09-19T16:03:31Z&spr=https&sv=2022-11-02&sr=c&sig=kMP%2BsBsRzdVVR8rrg%2BNbDhkRBNs6Q98kYY695XMRFDU%3D"""
spark.kubernetes.container.image.pullPolicy,IfNotPresent
spark.driver.memory,1g
spark.driver.extraJavaOptions,-Djava.net.preferIPv6Addresses=false -XX:+IgnoreUnrecognizedVMOptions --add-opens=java.base/java.lang=ALL-UNNAMED --add-opens=java.base/java.lang.invoke=ALL-UNNAMED --add-opens=java.base/java.lang.reflect=ALL-UNNAMED --add-opens=java.base/java.io=ALL-UNNAMED --add-opens=java.base/java.net=ALL-UNNAMED --add-opens=java.base/java.nio=ALL-UNNAMED --add-opens=java.base/java.util=ALL-UNNAMED --add-opens=java.base/java.util.concurrent=ALL-UNNAMED --add-opens=java.base/java.util.concurrent.atomic=ALL-UNNAMED --add-opens=java.base/jdk.internal.ref=ALL-UNNAMED --add-opens=java.base/sun.nio.ch=ALL-UNNAMED --add-opens=java.base/sun.nio.cs=ALL-UNNAMED --add-opens=java.base/sun.security.action=ALL-UNNAMED --add-opens=java.base/sun.util.calendar=ALL-UNNAMED --add-opens=java.security.jgss/sun.security.krb5=ALL-UNNAMED -Djdk.reflect.useDirectMethodHandle=false -Dderby.system.home=/tmp/dew59/spark/
spark.executor.instances,2


In [3]:
# Spark imports

from pyspark.sql import Row, DataFrame, Window, functions as F
from pyspark.sql.types import *

from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.stat import Correlation
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder

In [4]:
# Other imports to be used locally

import datetime

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

np.set_printoptions(edgeitems=5, threshold=100, precision=4)

In [5]:
# Function to apply a blue-white-red colour gradient to a DataFrame
def color_gradient(val, vmin=-1, vmax=1):
    """Return background colour for a cell based on value, using a blue-white-red gradient.
    Blue for vmax (default 1), white for the midpoint (default 0), red for vmin (default -1)."""
    from matplotlib import colors
    norm = colors.Normalize(vmin=vmin, vmax=vmax)
    cmap = colors.LinearSegmentedColormap.from_list('', ['red', 'white', 'blue'])
    rgb = cmap(norm(val))[:3]
    return f'background-color: rgb({int(rgb[0]*255)}, {int(rgb[1]*255)}, {int(rgb[2]*255)})'

In [6]:
# Helper functions

def show_class_balance(data, name="data", labelCol="label"):
    """Helper function to show class balance based on label.
    
    Note that this function does not return anything.

    Args:
        data (pyspark.sql.DataFrame): dataframe with label
        name (str): name to print above metrics for readability 
        labelCol (str): label column name
    """

    total = data.count()
    counts = data.groupBy(labelCol).count().toPandas()
    counts["ratio"] = counts["count"] / total

    print(f'Class balance [{name}]')
    print(f'')
    print(f'total:   {total}')
    print(f'counts:')
    print(counts)
    print(f'')

    
def with_custom_prediction(
    pred,
    threshold,
    probabilityCol="probability",
    customPredictionCol="customPrediction",
):
    """Helper function to select a custom prediction column for a custom classification threshold.
    
    Args:
        pred (pyspark.sql.DataFrame): dataframe with column for probability
        threshold (float): classification threshold
        probabilityCol (str): probability column name
        customPredictionCol (str): new custom prediction column name
    
    Returns:
        pred (pyspark.sql.DataFrame): dataframe with new column for custom prediction
    """

    classification_udf = F.udf(lambda x: int(x[1] > threshold), IntegerType())
    
    return pred.withColumn(customPredictionCol, classification_udf(F.col(probabilityCol)))


def show_metrics(
    pred,
    name="data",
    threshold=0.5,
    labelCol="label",
    predictionCol="prediction",
    rawPredictionCol="rawPrediction",
    probabilityCol="probability",
):
    """Helper function to evaluate and show metrics based on a custom classification threshold.
    
    Note that this function does not return anything.
    
    Args:
        pred (pyspark.sql.DataFrame): dataframe with column for probability
        name (str): name to print above metrics for readability
        threshold (float): classification threshold (default: 0.5)
        labelCol (str): label column name
        predictionCol (str): prediction column name
        rawPredictionCol (str): raw prediction column name
        probabilityCol (str): probability column name
    """

    if threshold != 0.5:

        predictionCol = "customPrediction"
        pred = with_custom_prediction(pred, threshold, probabilityCol=probabilityCol, customPredictionCol=predictionCol)

    total = pred.count()

    nP_actual = pred.filter((F.col(labelCol) == 1)).count()
    nN_actual = pred.filter((F.col(labelCol) == 0)).count()

    nP = pred.filter((F.col(predictionCol) == 1)).count()
    nN = pred.filter((F.col(predictionCol) == 0)).count()
    TP = pred.filter((F.col(predictionCol) == 1) & (F.col(labelCol) == 1)).count()
    FP = pred.filter((F.col(predictionCol) == 1) & (F.col(labelCol) == 0)).count()
    FN = pred.filter((F.col(predictionCol) == 0) & (F.col(labelCol) == 1)).count()
    TN = pred.filter((F.col(predictionCol) == 0) & (F.col(labelCol) == 0)).count()

    if TP + FP > 0:
        precision = TP / (TP + FP)
    else:
        precision = 0
        
    # Guard against a split that contains no positive examples
    if TP + FN > 0:
        recall = TP / (TP + FN)
    else:
        recall = 0

    accuracy = (TP + TN) / total

    binary_evaluator = BinaryClassificationEvaluator(
        rawPredictionCol=rawPredictionCol,
        labelCol=labelCol,
        metricName='areaUnderROC',
    )
    auroc = binary_evaluator.evaluate(pred)

    print(f'Metrics [{name}]')
    print(f'')
    print(f'threshold: {threshold}')
    print(f'')
    print(f'total:     {total}')
    print(f'')
    print(f'nP actual: {nP_actual}')
    print(f'nN actual: {nN_actual}')
    print(f'')
    print(f'nP:        {nP}')
    print(f'nN:        {nN}')
    print(f'')
    print(f'TP         {TP}')
    print(f'FP         {FP}')
    print(f'FN         {FN}')
    print(f'TN         {TN}')
    print(f'')
    print(f'precision: {precision:.8f}')
    print(f'recall:    {recall:.8f}')
    print(f'accuracy:  {accuracy:.8f}')
    print(f'')
    print(f'auroc:     {auroc:.8f}')


def expand(x, s=0.05, d=0):
    """Expand a two element array about its center point by a relative scale or a fixed offset.
    Args:
        x (list|np.array): two element array
        s (float): relative scale to expand array based on its width x[1] - x[0]
        d (float): fixed offset to expand array
    Returns:
        x (np.array): expanded two element array
    """
    
    x = np.array(x)
    d = d + s * (x[1] - x[0])
    
    return x + np.array([-d, d])
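
The thresholding rule inside `with_custom_prediction` can be tried without Spark. The sketch below assumes each probability is a pair `[P(label=0), P(label=1)]`, mirroring the layout of the Spark probability vector. The `classify` helper and the probabilities are made up for illustration.

```python
def classify(probability, threshold=0.5):
    """Return 1 if the positive class probability exceeds the threshold, else 0."""
    # probability[1] is the probability of the positive class (label 1)
    return int(probability[1] > threshold)


# Made-up probability pairs, for illustration only
probabilities = [[0.95, 0.05], [0.70, 0.30], [0.40, 0.60]]

for threshold in (0.5, 0.2):
    labels = [classify(p, threshold) for p in probabilities]
    print(f"threshold={threshold}: {labels}")
```

Lowering the threshold labels more rows as positive, which generally raises recall at the cost of precision.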

In [7]:
# Determine ideal number of partitions

conf = sc.getConf()

N = int(conf.get("spark.executor.instances"))
M = int(conf.get("spark.executor.cores"))
partitions = 4 * N * M

print(f'ideal # partitions = {partitions}')

ideal # partitions = 8


### Data ###

The credit card fraud dataset is stored in HDFS.

We will load the dataset and then use `VectorAssembler` to combine the separate feature columns into the single vector column that is expected by most of the classes in the machine learning library.

**Key points**

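- The data is gzip compressed but Spark will automatically decompress it as it is loaded. A gzip file cannot be split, so it is read into a single partition, which is why the data is repartitioned after loading.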
- The data has a header so we can infer schema conveniently.
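
As a local, Spark-free analogue of the two points above, the sketch below writes a tiny gzip compressed CSV and reads it back using only the Python standard library. The file name and values are made up. Unlike `inferSchema=True` in Spark, `csv.DictReader` returns every value as a string.

```python
import csv
import gzip
import os
import tempfile

# Write a tiny gzip compressed CSV with a header row (made-up values)
path = os.path.join(tempfile.mkdtemp(), "example.csv.gz")
with gzip.open(path, "wt", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["V1", "Amount", "Class"])
    writer.writerow([-1.48, 20.0, 0])
    writer.writerow([0.60, 1.0, 1])

# Read it back: gzip.open decompresses transparently in text mode
# and DictReader takes the column names from the header row
with gzip.open(path, "rt", newline="") as f:
    rows = list(csv.DictReader(f))

print(list(rows[0].keys()))
print(rows[0])
```

Schema inference in Spark is convenient but requires an extra pass over the data, so an explicit schema can be worth supplying for larger files.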

In [8]:
# Load data from HDFS

fraud = (
    spark.read.csv("hdfs:///data/fraud/manipulated.csv.gz", header=True, inferSchema=True)
    .repartition(partitions)
    .cache()
)

fraud.printSchema()
show_as_html(fraud)

25/09/18 10:06:16 WARN SparkStringUtils: Truncated the string representation of a plan since it was too large. This behavior can be adjusted by setting 'spark.sql.debug.maxToStringFields'.


root
 |-- Time: double (nullable = true)
 |-- V1: double (nullable = true)
 |-- V2: double (nullable = true)
 |-- V3: double (nullable = true)
 |-- V4: double (nullable = true)
 |-- V5: double (nullable = true)
 |-- V6: double (nullable = true)
 |-- V7: double (nullable = true)
 |-- V8: double (nullable = true)
 |-- V9: double (nullable = true)
 |-- V10: double (nullable = true)
 |-- V11: double (nullable = true)
 |-- V12: double (nullable = true)
 |-- V13: double (nullable = true)
 |-- V14: double (nullable = true)
 |-- V15: double (nullable = true)
 |-- V16: double (nullable = true)
 |-- V17: double (nullable = true)
 |-- V18: double (nullable = true)
 |-- V19: double (nullable = true)
 |-- V20: double (nullable = true)
 |-- V21: double (nullable = true)
 |-- V22: double (nullable = true)
 |-- V23: double (nullable = true)
 |-- V24: double (nullable = true)
 |-- V25: double (nullable = true)
 |-- V26: double (nullable = true)
 |-- V27: double (nullable = true)
 |-- V28: double (nulla

                                                                                

Unnamed: 0,Time,V1,V2,V3,V4,V5,V6,V7,V8,V9,...,V23,V24,V25,V26,V27,V28,V29,V30,Amount,Class
0,52859.0,-1.481699,-1.246688,0.128296,1.398033,3.271913,0.349771,3.013195,-1.219669,1.628425,...,4.087976,6.306455,-1.358532,-1.490907,0.878895,2.755216,-0.930104,0.442916,20.0,0
1,3474.0,0.60493,1.345334,1.885777,3.223215,-2.435636,-2.085326,4.057805,1.719032,1.984682,...,-4.362017,0.50826,-4.342881,0.236705,2.024898,1.59345,-1.178279,0.554425,1.0,0
2,33914.0,-1.40558,0.572848,0.224496,4.821386,0.931036,-2.1564,-1.624538,1.192495,0.992681,...,1.286916,3.258592,-4.837623,-0.783147,2.386024,2.329652,-1.239618,0.682728,52.4,0
3,55143.0,-3.064371,4.088536,0.81515,1.622903,3.642811,0.346418,-1.56172,5.526182,6.920875,...,-3.518951,-1.517298,-0.147714,-2.997555,2.487232,1.216808,-1.139944,0.718347,8.99,0
4,124904.0,-2.250919,0.883931,-1.851606,2.226903,4.399949,-1.72394,-2.760515,4.942099,-2.654509,...,3.191902,11.383269,-0.418173,-1.748567,-2.285626,1.425441,-0.980045,0.296008,2.0,0
5,115476.0,-0.933414,-1.127887,1.220867,5.57201,-1.917753,-0.52221,-3.383634,-2.131774,-1.85756,...,2.216314,5.972034,-2.400616,-2.626058,-4.936418,-0.936139,-1.445127,0.965554,17.95,0
6,60655.0,-1.109151,0.634344,0.340107,0.381706,2.389911,1.715712,2.798189,-1.884216,2.018426,...,3.577823,4.248832,0.313004,-2.069879,2.015703,0.292201,-1.284644,0.857579,4.99,0
7,51615.0,-0.770658,0.645011,1.093084,4.592722,-0.223543,-4.783671,-4.286958,-2.670793,-5.914593,...,4.85064,-0.632296,-3.664068,-0.630959,-2.891812,2.338466,-0.964926,0.459016,98.0,0
8,61992.0,0.919797,5.255682,0.831634,5.170233,1.260635,-4.529845,-2.105848,-4.87982,-3.71663,...,2.382781,4.312173,-2.971982,-1.094132,-1.190674,4.251221,-0.787285,-0.172302,19.95,0
9,167774.0,4.976433,0.693264,4.578195,-4.176534,-2.055746,-1.532248,-3.990775,2.375791,-2.436445,...,3.391831,8.203964,-2.924866,-3.119336,-3.754262,2.3479,-1.220631,0.642744,38.94,0


In [9]:
# Select what we need

assembler = VectorAssembler(
    inputCols=[col for col in fraud.columns if col.startswith("V")],
    outputCol="features"
)

data = assembler.transform(fraud)
data = data.select(
    F.col('features'),
    F.col('Class').alias('label'),
)

data.printSchema()
show_as_html(data)

root
 |-- features: vector (nullable = true)
 |-- label: integer (nullable = true)



                                                                                

Unnamed: 0,features,label
0,"[-1.4816990841017386, -1.2466882989312165, 0.1...",0
1,"[0.604930131762184, 1.3453338402369985, 1.8857...",0
2,"[-1.4055799050769973, 0.5728480731599195, 0.22...",0
3,"[-3.064371219791205, 4.0885357000030815, 0.815...",0
4,"[-2.2509187458604036, 0.8839309476906481, -1.8...",0
5,"[-0.9334139448496037, -1.1278871006329023, 1.2...",0
6,"[-1.109151444168311, 0.6343437730955894, 0.340...",0
7,"[-0.7706581155469284, 0.6450112264105162, 1.09...",0
8,"[0.9197965376970174, 5.255681864247269, 0.8316...",0
9,"[4.9764332233179145, 0.6932636017656277, 4.578...",0


### Data processing ###

The data is well structured but we should verify our assumptions and explore the relationships between the variables before fitting a model.

**Key points**

- The original features from 2013 were generated by PCA, but they have been manipulated here to introduce random noise and correlations.
- We can take advantage of some `pandas` functionality to present the descriptive statistics from `.describe()` in a more readable way.
- We can use the `display` and `HTML` functions from `IPython.display` to customize how the correlations are displayed.

In [10]:
# Compute descriptive statistics

statistics = (
    fraud
    .describe()
    .toPandas()
    .set_index("summary")
    .rename_axis(None)
    .T
    .reset_index()
)

display(statistics)

                                                                                

Unnamed: 0,index,count,mean,stddev,min,max
0,Time,284807,94813.85957508069,47488.14595456623,0.0,172792.0
1,V1,284807,-0.4966252866877263,2.3714602902649955,-90.31036760269627,166.78800475963453
2,V2,284807,1.1830151841182168,3.150791192182841,-53.76288656188865,29.71726732098913
3,V3,284807,1.000482656844799,1.577373351168701,-30.72593841714288,70.62452677273802
4,V4,284807,1.80763596418206,4.481625613251927,-81.53762437458782,103.76386278009876
5,V5,284807,-0.1703513868971145,3.4560545766348145,-144.60446179645302,20.94242148040474
6,V6,284807,-1.343781395341357,3.778507904618,-60.396172830381296,113.9857925057676
7,V7,284807,-0.6089904520512494,5.862578219751249,-199.33061338105705,148.07994125031107
8,V8,284807,1.4290944209704708,3.304246059803501,-173.22178540130156,32.2823287781014
9,V9,284807,0.4607341981527921,3.7383482980775753,-112.75901543979754,57.291111325946


In [11]:
# Compute class frequency

show_as_html(data.groupby('label').count())

Unnamed: 0,label,count
0,1,492
1,0,284315


In [12]:
# Explore correlations in case there are features that are highly correlated

correlations = Correlation.corr(data, 'features', 'pearson')

show_as_html(correlations)

                                                                                

Unnamed: 0,pearson(features)
0,"DenseMatrix([[ 1.0000e+00, 1.8011e-01, 4.727..."


In [13]:
# Collect correlations locally and convert to numpy array

correlations_local = correlations.collect()[0][0].toArray()

display(pd.DataFrame(correlations_local))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,20,21,22,23,24,25,26,27,28,29
0,1.0,0.180114,0.472703,-0.005646,-0.438756,-0.053963,-0.196499,-0.265054,-0.185566,-0.176487,...,0.179457,0.080131,0.030442,0.115921,0.02878,-3e-06,-0.241397,-0.082162,-0.001619,0.001707
1,0.180114,1.0,0.000131,0.255457,0.103999,-0.397789,-0.087665,0.0896,0.091379,0.166596,...,-0.19201,0.118492,-0.285377,0.001977,-0.018142,-0.04332,-0.0252,0.097324,0.001275,-0.001868
2,0.472703,0.000131,1.0,0.123138,-0.524002,0.200265,-0.108416,0.000199,0.167162,-0.369185,...,-0.038272,0.00036,-0.108682,-0.404195,-0.030157,0.000352,0.106034,0.000967,0.000474,-0.000256
3,-0.005646,0.255457,0.123138,1.0,-0.229535,0.100972,-0.043979,-0.31748,0.256921,-0.307768,...,-0.269359,0.203203,-0.296478,-0.329738,0.162679,-0.04457,0.302312,-0.190907,0.001448,0.000233
4,-0.438756,0.103999,-0.524002,-0.229535,1.0,-0.136316,0.126119,0.259644,0.098811,0.168075,...,0.086608,0.166438,0.090124,0.294079,-0.001993,-0.114198,-0.026213,0.091906,0.000965,-0.001012
5,-0.053963,-0.397789,0.200265,0.100972,-0.136316,1.0,0.088335,-0.017608,0.375811,-0.254253,...,-0.149429,0.514011,0.115878,-0.239125,0.474068,0.224309,-0.024505,-0.359918,0.001758,-0.000806
6,-0.196499,-0.087665,-0.108416,-0.043979,0.126119,0.088335,1.0,-0.116391,0.311079,-0.182431,...,0.027621,0.219516,0.109531,-0.104118,0.274941,-0.166594,0.160847,-0.15698,0.000564,0.002092
7,-0.265054,0.0896,0.000199,-0.31748,0.259644,-0.017608,-0.116391,1.0,0.157847,0.080288,...,0.240087,-0.072218,-0.070841,-0.069691,-0.045443,0.044315,0.12982,-0.163594,-0.001702,-0.000179
8,-0.185566,0.091379,0.167162,0.256921,0.098811,0.375811,0.311079,0.157847,1.0,-0.177333,...,-0.061656,0.357348,-0.351954,-0.262966,0.340635,-0.0379,0.536768,-0.314038,0.001151,-0.000174
9,-0.176487,0.166596,-0.369185,-0.307768,0.168075,-0.254253,-0.182431,0.080288,-0.177333,1.0,...,-0.116557,-0.157504,-0.004253,0.171775,-0.094034,-0.072619,-0.414382,0.228018,-0.000333,-0.001129


In [14]:
# Round correlations and display with blue-white-red colour gradient
df_corr = pd.DataFrame(correlations_local)
styler = df_corr.style.applymap(color_gradient).format("{:.1f}")
display(styler)

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29
0,1.0,0.2,0.5,-0.0,-0.4,-0.1,-0.2,-0.3,-0.2,-0.2,0.0,0.1,0.4,0.0,-0.2,0.1,0.0,0.2,-0.2,0.0,0.2,0.1,0.0,0.1,0.0,-0.0,-0.2,-0.1,-0.0,0.0
1,0.2,1.0,0.0,0.3,0.1,-0.4,-0.1,0.1,0.1,0.2,0.5,0.0,0.3,0.3,0.0,0.1,0.0,0.3,0.1,0.1,-0.2,0.1,-0.3,0.0,-0.0,-0.0,-0.0,0.1,0.0,-0.0
2,0.5,0.0,1.0,0.1,-0.5,0.2,-0.1,0.0,0.2,-0.4,0.0,0.0,0.3,0.0,0.2,-0.0,0.0,0.4,0.0,-0.3,-0.0,0.0,-0.1,-0.4,-0.0,0.0,0.1,0.0,0.0,-0.0
3,-0.0,0.3,0.1,1.0,-0.2,0.1,-0.0,-0.3,0.3,-0.3,-0.2,-0.4,0.1,-0.1,0.4,-0.1,-0.4,0.0,-0.1,0.0,-0.3,0.2,-0.3,-0.3,0.2,-0.0,0.3,-0.2,0.0,0.0
4,-0.4,0.1,-0.5,-0.2,1.0,-0.1,0.1,0.3,0.1,0.2,0.2,0.5,-0.2,0.4,-0.2,0.2,0.2,-0.2,0.3,0.3,0.1,0.2,0.1,0.3,-0.0,-0.1,-0.0,0.1,0.0,-0.0
5,-0.1,-0.4,0.2,0.1,-0.1,1.0,0.1,-0.0,0.4,-0.3,-0.0,0.0,-0.1,0.1,-0.0,-0.4,0.2,0.1,-0.1,0.1,-0.1,0.5,0.1,-0.2,0.5,0.2,-0.0,-0.4,0.0,-0.0
6,-0.2,-0.1,-0.1,-0.0,0.1,0.1,1.0,-0.1,0.3,-0.2,0.1,0.0,-0.2,0.1,0.1,-0.2,-0.0,-0.5,0.1,-0.1,0.0,0.2,0.1,-0.1,0.3,-0.2,0.2,-0.2,0.0,0.0
7,-0.3,0.1,0.0,-0.3,0.3,-0.0,-0.1,1.0,0.2,0.1,0.4,0.2,-0.2,0.1,0.0,-0.1,0.4,0.1,0.0,0.0,0.2,-0.1,-0.1,-0.1,-0.0,0.0,0.1,-0.2,-0.0,-0.0
8,-0.2,0.1,0.2,0.3,0.1,0.4,0.3,0.2,1.0,-0.2,0.1,-0.0,-0.2,0.1,0.2,-0.1,-0.0,0.3,-0.0,0.0,-0.1,0.4,-0.4,-0.3,0.3,-0.0,0.5,-0.3,0.0,-0.0
9,-0.2,0.2,-0.4,-0.3,0.2,-0.3,-0.2,0.1,-0.2,1.0,0.3,-0.0,0.1,0.3,-0.3,0.2,0.2,0.1,0.1,-0.1,-0.1,-0.2,-0.0,0.2,-0.1,-0.1,-0.4,0.2,-0.0,-0.0
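
Scanning a 30 by 30 table by eye is error prone, so it can help to list the strongest pairs directly. The helper below is a hypothetical sketch that takes a square correlation matrix as nested lists (for example `correlations_local.tolist()`) and returns the off-diagonal pairs whose absolute correlation is at least a chosen threshold. The example matrix is made up.

```python
def strongly_correlated_pairs(matrix, threshold=0.5):
    """Return (i, j, r) for each pair i < j where |r| is at least threshold."""
    pairs = []
    n = len(matrix)
    for i in range(n):
        # Upper triangle only: the matrix is symmetric and the diagonal is always 1
        for j in range(i + 1, n):
            r = matrix[i][j]
            if abs(r) >= threshold:
                pairs.append((i, j, r))
    # Strongest relationships first
    return sorted(pairs, key=lambda pair: abs(pair[2]), reverse=True)


# Small made-up correlation matrix, for illustration only
example = [
    [1.0, 0.2, -0.6],
    [0.2, 1.0, 0.5],
    [-0.6, 0.5, 1.0],
]

print(strongly_correlated_pairs(example, threshold=0.5))
```

The indices are zero based and follow the order of `inputCols` passed to `VectorAssembler`, so index 0 should correspond to the first feature column.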


### Modeling ###

We can easily train and evaluate a baseline model using the `LogisticRegression` class and the helper function `show_metrics`.

**Key points**

- We will explore different splitting and training strategies in the other examples in this module.

In [15]:
# Split into test and training

training, test = data.randomSplit([0.8, 0.2])
training.cache()
test.cache()

show_class_balance(data, "data")
show_class_balance(training, "training")
show_class_balance(test, "test")

Class balance [data]

total:   284807
counts:
   label   count     ratio
0      1     492  0.001727
1      0  284315  0.998273



                                                                                

Class balance [training]

total:   228027
counts:
   label   count     ratio
0      1     387  0.001697
1      0  227640  0.998303





Class balance [test]

total:   56780
counts:
   label  count     ratio
0      1    105  0.001849
1      0  56675  0.998151



                                                                                

In [16]:
# Train and evaluate metrics

lr = LogisticRegression(featuresCol='features', labelCol='label')
lr_model = lr.fit(training)

pred = lr_model.transform(test)
pred.cache()

show_metrics(pred)

25/09/18 10:06:41 WARN InstanceBuilder: Failed to load implementation from:dev.ludovic.netlib.blas.JNIBLAS
                                                                                

Metrics [data]

threshold: 0.5

total:     56780

nP actual: 105
nN actual: 56675

nP:        65
nN:        56715

TP         59
FP         6
FN         46
TN         56669

precision: 0.90769231
recall:    0.56190476
accuracy:  0.99908418

auroc:     0.96362560
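
The reported metrics can be checked by hand from the confusion matrix counts in the output above. The sketch below is plain Python; `safe_ratio` is a hypothetical helper that avoids division by zero, and the F1 score is an extra metric that `show_metrics` does not print.

```python
def safe_ratio(numerator, denominator):
    """Return numerator / denominator, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator > 0 else 0.0


# Confusion matrix counts copied from the output above
TP, FP, FN, TN = 59, 6, 46, 56669

precision = safe_ratio(TP, TP + FP)
recall = safe_ratio(TP, TP + FN)
accuracy = safe_ratio(TP + TN, TP + FP + FN + TN)
f1 = safe_ratio(2 * TP, 2 * TP + FP + FN)  # harmonic mean of precision and recall

print(f"precision: {precision:.8f}")
print(f"recall:    {recall:.8f}")
print(f"accuracy:  {accuracy:.8f}")
print(f"f1:        {f1:.8f}")
```

Accuracy is dominated by the 56,669 true negatives, while recall shows that 46 of the 105 fraud cases in the test set were missed at the default threshold.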


### Stop Spark ###

In [17]:
# Run this cell before closing the notebook or kill your spark application by hand using the link in the Spark UI
stop_spark()

25/09/18 10:06:59 WARN ExecutorPodsWatchSnapshotSource: Kubernetes client has been closed.
