In [1]:
import numpy as np
import pandas as pd 
import os
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

from pyspark.sql.functions import round
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import MulticlassClassificationEvaluator, BinaryClassificationEvaluator

In [2]:
# Start (or reuse) a Spark session running locally on all available cores
spark = (
    SparkSession.builder
    .master('local[*]')
    .appName('ML with PySpark')
    .getOrCreate()
)
# Report the Spark version backing this session
print(spark.version)

3.3.1


In [3]:
# Load the flights CSV: comma-separated, first row is the header, column types
# are inferred, and the literal string 'NA' is read as a missing value
flights_df = (
    spark.read
    .options(sep=',', header=True, inferSchema=True, nullValue='NA')
    .csv('flights-larger.csv')
)

In [4]:
# Count the rows, then report the total
record_count = flights_df.count()
print("The data contain %d records." % record_count)
# Preview the first five rows
flights_df.show(n=5)

The data contain 275000 records.
+---+---+---+-------+------+---+----+------+--------+-----+
|mon|dom|dow|carrier|flight|org|mile|depart|duration|delay|
+---+---+---+-------+------+---+----+------+--------+-----+
| 10| 10|  1|     OO|  5836|ORD| 157|  8.18|      51|   27|
|  1|  4|  1|     OO|  5866|ORD| 466|  15.5|     102| null|
| 11| 22|  1|     OO|  6016|ORD| 738|  7.17|     127|  -19|
|  2| 14|  5|     B6|   199|JFK|2248| 21.17|     365|   60|
|  5| 25|  3|     WN|  1675|SJC| 386| 12.92|      85|   22|
+---+---+---+-------+------+---+----+------+--------+-----+
only showing top 5 rows



In [5]:
# List each column name together with its Spark data type
column_types = flights_df.dtypes
print(column_types)

[('mon', 'int'), ('dom', 'int'), ('dow', 'int'), ('carrier', 'string'), ('flight', 'int'), ('org', 'string'), ('mile', 'int'), ('depart', 'double'), ('duration', 'int'), ('delay', 'int')]


Data preparation for training our ML model
Data preparation includes:
Data Cleaning

removing an uninformative column and
removing rows that have missing values
Column/Data manipulation

We will consider a flight to be "delayed" when it arrives 15 minutes or more after its scheduled time (this complies with the FAA's definition of a delayed flight).
Based on this definition, we will create new boolean column 'label' stating if a flight was delayed or not
Convert columns that hold categorical data(carrier & org) into indexed numerical values
Assembling columns

The final stage consists of consolidating all predictor columns into a single one

In [6]:
# Drop the 'flight' column, then discard the rows that contain missing values
flights_df = flights_df.drop('flight').dropna()

# Number of records left after cleaning
print(flights_df.count())

258289


In [7]:
# Add a 'km' column (miles * 1.60934, rounded to 0 decimals) and drop 'mile'.
# NOTE: `round` here is the pyspark.sql.functions version imported at the top
# of the notebook, not Python's builtin.
km_column = round(flights_df['mile'] * 1.60934, 0)
flights_km = flights_df.withColumn('km', km_column).drop('mile')

# Add an integer 'label' column: 1 when delay is 15 or more, otherwise 0
delayed_flag = (flights_km['delay'] >= 15).cast('integer')
flights_km = flights_km.withColumn('label', delayed_flag)

# Preview the first five rows
flights_km.show(n=5)

+---+---+---+-------+---+------+--------+-----+------+-----+
|mon|dom|dow|carrier|org|depart|duration|delay|    km|label|
+---+---+---+-------+---+------+--------+-----+------+-----+
| 10| 10|  1|     OO|ORD|  8.18|      51|   27| 253.0|    1|
| 11| 22|  1|     OO|ORD|  7.17|     127|  -19|1188.0|    0|
|  2| 14|  5|     B6|JFK| 21.17|     365|   60|3618.0|    1|
|  5| 25|  3|     WN|SJC| 12.92|      85|   22| 621.0|    1|
|  3| 28|  1|     B6|LGA| 13.33|     182|   70|1732.0|    1|
+---+---+---+-------+---+------+--------+-----+------+-----+
only showing top 5 rows



In [8]:
# Replace each categorical string column with a numeric index column.
# A StringIndexer is fitted per column to learn its categories and then applied,
# adding the index as a new column alongside the original one.
flights_indexed = flights_km
for source_col, index_col in (('carrier', 'carrier_idx'), ('org', 'org_idx')):
    indexer = StringIndexer(inputCol=source_col, outputCol=index_col)
    flights_indexed = indexer.fit(flights_indexed).transform(flights_indexed)

# Preview the first five rows, including the two new index columns
flights_indexed.show(n=5)

+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+
|mon|dom|dow|carrier|org|depart|duration|delay|    km|label|carrier_idx|org_idx|
+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+
| 10| 10|  1|     OO|ORD|  8.18|      51|   27| 253.0|    1|        2.0|    0.0|
| 11| 22|  1|     OO|ORD|  7.17|     127|  -19|1188.0|    0|        2.0|    0.0|
|  2| 14|  5|     B6|JFK| 21.17|     365|   60|3618.0|    1|        4.0|    2.0|
|  5| 25|  3|     WN|SJC| 12.92|      85|   22| 621.0|    1|        3.0|    5.0|
|  3| 28|  1|     B6|LGA| 13.33|     182|   70|1732.0|    1|        4.0|    3.0|
+---+---+---+-------+---+------+--------+-----+------+-----+-----------+-------+
only showing top 5 rows



In [9]:
# Predictor columns, in the order they will appear inside the feature vector
feature_columns = [
    'mon', 'dom', 'dow',
    'carrier_idx', 'org_idx',
    'km', 'depart', 'duration',
]
# Assembler that packs the predictors into a single 'features' vector column
assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')
# Append the 'features' column to the indexed data
flights_assembled = assembler.transform(flights_indexed)
# Show the assembled vectors next to the raw delay, without truncating values
flights_assembled.select('features', 'delay').show(n=5, truncate=False)

+-----------------------------------------+-----+
|features                                 |delay|
+-----------------------------------------+-----+
|[10.0,10.0,1.0,2.0,0.0,253.0,8.18,51.0]  |27   |
|[11.0,22.0,1.0,2.0,0.0,1188.0,7.17,127.0]|-19  |
|[2.0,14.0,5.0,4.0,2.0,3618.0,21.17,365.0]|60   |
|[5.0,25.0,3.0,3.0,5.0,621.0,12.92,85.0]  |22   |
|[3.0,28.0,1.0,4.0,3.0,1732.0,13.33,182.0]|70   |
+-----------------------------------------+-----+
only showing top 5 rows



In [None]:

Decision Trees

In [10]:
# Randomly partition the assembled data 80:20 into training and testing sets,
# using a fixed seed
flights_train, flights_test = flights_assembled.randomSplit(weights=[0.8, 0.2], seed=42)

# Sanity check: the training set should hold roughly 80% of all records
train_rows = flights_train.count()
total_rows = flights_assembled.count()
training_ratio = train_rows / total_rows
print(training_ratio)

0.7998753334443202


In [11]:
# Build a decision tree classifier; the column names are spelled out here but
# match the library defaults ('features' as input, 'label' as target)
tree = DecisionTreeClassifier(featuresCol='features', labelCol='label')
# Fit the classifier on the training set
tree_model = tree.fit(flights_train)
# Score the held-out test set; this adds prediction and probability columns
prediction = tree_model.transform(flights_test)
# Compare true labels with predicted classes and class probabilities
prediction.select('label', 'prediction', 'probability').show(n=5, truncate=False)

+-----+----------+---------------------------------------+
|label|prediction|probability                            |
+-----+----------+---------------------------------------+
|0    |1.0       |[0.3174931129476584,0.6825068870523416]|
|1    |0.0       |[0.6366622864651774,0.3633377135348226]|
|1    |0.0       |[0.6366622864651774,0.3633377135348226]|
|1    |1.0       |[0.3174931129476584,0.6825068870523416]|
|1    |1.0       |[0.3174931129476584,0.6825068870523416]|
+-----+----------+---------------------------------------+
only showing top 5 rows



Evaluate the model

A confusion matrix gives a useful breakdown of predictions versus known values. It has four cells, which represent the counts of:

True Negatives (TN) — prediction is negative & label is negative

True Positives (TP) — prediction is positive & label is positive

False Negatives (FN) — prediction is negative & label is positive

False Positives (FP) — prediction is positive & label is negative

Using these four measures, we can then calculate the accuracy of the model as follows:

Accuracy=(TN+TP)/(TN+TP+FN+FP)

In [12]:
# Build the confusion matrix with a single aggregation.
# The aggregated result is cached so that displaying it and collecting it share
# one computation, instead of re-running the prediction pipeline once per cell
# of the matrix as separate filter().count() calls would.
confusion = prediction.groupBy('label', 'prediction').count().cache()
confusion.show()

# Map (label, prediction) -> count. 'label' is 0/1 (a boolean cast to integer)
# and 'prediction' is 0.0/1.0, so both are normalised to int for the lookup.
cell_counts = {
    (int(row['label']), int(row['prediction'])): row['count']
    for row in confusion.collect()
}

# Elements of the confusion matrix; a combination that never occurs in the
# test set is simply absent from the aggregation, hence the default of 0
TN = cell_counts.get((0, 0), 0)  # label 0, predicted 0
TP = cell_counts.get((1, 1), 0)  # label 1, predicted 1
FN = cell_counts.get((1, 0), 0)  # label 1, predicted 0
FP = cell_counts.get((0, 1), 0)  # label 0, predicted 1

# Accuracy measures the proportion of correct predictions.
# An empty test set yields NaN rather than a ZeroDivisionError.
total = TN + TP + FN + FP
accuracy = (TN + TP) / total if total else float('nan')
print(accuracy)

+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|    1|       0.0| 9313|
|    0|       0.0|16071|
|    1|       1.0|16702|
|    0|       1.0| 9604|
+-----+----------+-----+

0.6340297929967111
