<a href="https://colab.research.google.com/github/martin-fabbri/colab-notebooks/blob/master/spark/pyspark_ml.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# PySpark ML

## Setup PySpark instance

To run Spark in Colab, we first need to install its dependencies in the Colab environment: Apache Spark 2.4.5 with Hadoop 2.7, Java 8, and findspark, which locates the Spark installation on the system.

In [1]:
#@title ### Setup PySpark instance
#@markdown To run Spark in Colab, we first need to install its dependencies in the Colab environment: Apache Spark 2.4.5 with Hadoop 2.7, Java 8, and findspark, which locates the Spark installation on the system.

#@markdown **Upon successful completion of this cell, a ``SparkSession`` named ``spark`` will be available to interact with the service.**

#@markdown Creating multiple ``SparkSession`` or ``SparkContext`` objects can 
#@markdown cause issues. If you need a reference to the existing session, 
#@markdown use ``SparkSession.builder.getOrCreate()``.


!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q https://downloads.apache.org/spark/spark-2.4.5/spark-2.4.5-bin-hadoop2.7.tgz
!tar xf spark-2.4.5-bin-hadoop2.7.tgz
!pip install -q findspark

import os
import findspark
# environment variables
os.environ['JAVA_HOME'] = '/usr/lib/jvm/java-8-openjdk-amd64'
os.environ['SPARK_HOME'] = 'spark-2.4.5-bin-hadoop2.7'
# make the pyspark package under SPARK_HOME importable
findspark.init()
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()
spark

## Download dataset

In [2]:
!wget -q https://raw.githubusercontent.com/martin-fabbri/colab-notebooks/master/data/flights_small.csv
!wget -q https://raw.githubusercontent.com/martin-fabbri/colab-notebooks/master/data/airports.csv
!wget -q https://raw.githubusercontent.com/martin-fabbri/colab-notebooks/master/data/planes.csv
!ls *.csv

airports.csv  flights_small.csv  planes.csv


## Load the datasets into the ``flights`` and ``planes`` tables

In [0]:
dataset = spark.read.csv('flights_small.csv', inferSchema=True, header=True)
dataset.write.saveAsTable('flights')

In [0]:
dataset = spark.read.csv('planes.csv', inferSchema=True, header=True)
dataset.write.saveAsTable('planes')

``SparkSession`` has an attribute called ``catalog``; its ``listTables()`` method lists all the tables in the cluster's catalog.

In [5]:
spark.catalog.listTables()

[Table(name='flights', database='default', description=None, tableType='MANAGED', isTemporary=False),
 Table(name='planes', database='default', description=None, tableType='MANAGED', isTemporary=False)]

## Perform Spark queries

In [6]:
query = """
  FROM planes
  SELECT *
  LIMIT 10
"""

planes = spark.sql(query)
planes.show()

+-------+----+--------------------+----------------+--------+-------+-----+-----+---------+
|tailnum|year|                type|    manufacturer|   model|engines|seats|speed|   engine|
+-------+----+--------------------+----------------+--------+-------+-----+-----+---------+
| N102UW|1998|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N103US|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N104UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N105UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N107US|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N108UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N109UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N110UW|1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA

In [7]:
query = """
  FROM flights
  SELECT *
  LIMIT 10
"""

flights = spark.sql(query)
flights.show()

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
|2014|    1| 15|    1037|        7|    1

## Join the DataFrames

Rename the `year` column on planes to avoid conflict while joining.
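As a rough illustration of how such conflicts can be spotted before joining, the sketch below compares the two column lists (copied from the outputs shown earlier) and reports every name that appears on both sides but is not a join key. The helper `conflicting_columns` is hypothetical and is not part of PySpark.

```python
# Column names copied from the `flights` and `planes` outputs shown earlier.
flights_cols = ['year', 'month', 'day', 'dep_time', 'dep_delay', 'arr_time',
                'arr_delay', 'carrier', 'tailnum', 'flight', 'origin', 'dest',
                'air_time', 'distance', 'hour', 'minute']
planes_cols = ['tailnum', 'year', 'type', 'manufacturer', 'model',
               'engines', 'seats', 'speed', 'engine']


def conflicting_columns(left_cols, right_cols, join_keys):
    """Return names present on both sides that are not join keys.

    Hypothetical helper for illustration only; not a PySpark function.
    """
    shared = set(left_cols) & set(right_cols)
    return sorted(shared - set(join_keys))


conflicts = conflicting_columns(flights_cols, planes_cols, join_keys=['tailnum'])
print(conflicts)
```

With live DataFrames, the same check could be fed from `flights.columns` and `planes.columns`.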

In [8]:
planes = planes.withColumnRenamed('year', 'plane_year')
planes.show()

+-------+----------+--------------------+----------------+--------+-------+-----+-----+---------+
|tailnum|plane_year|                type|    manufacturer|   model|engines|seats|speed|   engine|
+-------+----------+--------------------+----------------+--------+-------+-----+-----+---------+
| N102UW|      1998|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N103US|      1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N104UW|      1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N105UW|      1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N107US|      1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N108UW|      1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N109UW|      1999|Fixed wing multi ...|AIRBUS INDUSTRIE|A320-214|      2|  182|   NA|Turbo-fan|
| N110UW|      1999|

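Join the DataFrames on `tailnum`. Because `planes` holds only the 10 rows returned by the `LIMIT 10` query above, flights whose `tailnum` is not among them get null values in the plane columns.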

In [9]:
model_data = flights.join(planes, on=['tailnum'], how='leftouter')
model_data.show()

+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+
|tailnum|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|flight|origin|dest|air_time|distance|hour|minute|plane_year|type|manufacturer|model|engines|seats|speed|engine|
+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+
| N846VA|2014|   12|  8|     658|       -7|     935|       -5|     VX|  1780|   SEA| LAX|     132|     954|   6|    58|      null|null|        null| null|   null| null| null|  null|
| N559AS|2014|    1| 22|    1040|        5|    1505|        5|     AS|   851|   SEA| HNL|     360|    2677|  10|    40|      null|null|        null| null|   null| null| null|  null|
| N847VA|2014|    3|  9|    1443|       -2|    1652|        2|     VX|   755|   SEA| SFO| 

## Transform required columns to numeric type

In [10]:
# dataframe = dataframe.withColumn("col", dataframe.col.cast("new_type"))
model_data = model_data.withColumn('arr_delay', model_data.arr_delay.cast('integer'))
model_data = model_data.withColumn('air_time', model_data.air_time.cast('integer'))
model_data = model_data.withColumn('month', model_data.month.cast('integer'))
model_data = model_data.withColumn('plane_year', model_data.plane_year.cast('integer'))
model_data.show()

+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+
|tailnum|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|flight|origin|dest|air_time|distance|hour|minute|plane_year|type|manufacturer|model|engines|seats|speed|engine|
+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+
| N846VA|2014|   12|  8|     658|       -7|     935|       -5|     VX|  1780|   SEA| LAX|     132|     954|   6|    58|      null|null|        null| null|   null| null| null|  null|
| N559AS|2014|    1| 22|    1040|        5|    1505|        5|     AS|   851|   SEA| HNL|     360|    2677|  10|    40|      null|null|        null| null|   null| null| null|  null|
| N847VA|2014|    3|  9|    1443|       -2|    1652|        2|     VX|   755|   SEA| SFO| 

## Create a new column

In [11]:
# plane age in years at the time of the flight: flight year minus manufacture year
model_data = model_data.withColumn('plane_age', model_data.year - model_data.plane_year)
model_data.show()

+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+---------+
|tailnum|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|flight|origin|dest|air_time|distance|hour|minute|plane_year|type|manufacturer|model|engines|seats|speed|engine|plane_age|
+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+---------+
| N846VA|2014|   12|  8|     658|       -7|     935|       -5|     VX|  1780|   SEA| LAX|     132|     954|   6|    58|      null|null|        null| null|   null| null| null|  null|     null|
| N559AS|2014|    1| 22|    1040|        5|    1505|        5|     AS|   851|   SEA| HNL|     360|    2677|  10|    40|      null|null|        null| null|   null| null| null|  null|     null|
| N847VA|2014|    3|  9|    1443|       

In [12]:
query = """
  FROM planes
  SELECT *
  WHERE tailnum = 'N646SW'
  LIMIT 10
"""

ahq = spark.sql(query)
ahq.show()

+-------+----+--------------------+------------+-------+-------+-----+-----+---------+
|tailnum|year|                type|manufacturer|  model|engines|seats|speed|   engine|
+-------+----+--------------------+------------+-------+-------+-----+-----+---------+
| N646SW|1997|Fixed wing multi ...|      BOEING|737-3H4|      2|  149|   NA|Turbo-fan|
+-------+----+--------------------+------------+-------+-------+-----+-----+---------+



In [13]:
query = """
  SELECT flights.*, planes.*
  FROM flights
  INNER JOIN planes
  ON planes.tailnum = flights.tailnum
"""

ahq = spark.sql(query)
ahq.show()

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+-------+----+--------------------+--------------+-----------+-------+-----+-----+---------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|tailnum|year|                type|  manufacturer|      model|engines|seats|speed|   engine|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+-------+----+--------------------+--------------+-----------+-------+-----+-----+---------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58| N846VA|2011|Fixed wing multi ...|        AIRBUS|   A320-214|      2|  182|   NA|Turbo-fan|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40| N559AS|2006|Fixed wing multi ...|   

In [14]:
model_data = model_data.withColumn('is_late', model_data.arr_delay > 0)
model_data = model_data.withColumn('label', model_data.is_late.cast('integer'))
model_data.show()

+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+---------+-------+-----+
|tailnum|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|flight|origin|dest|air_time|distance|hour|minute|plane_year|type|manufacturer|model|engines|seats|speed|engine|plane_age|is_late|label|
+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+---------+-------+-----+
| N846VA|2014|   12|  8|     658|       -7|     935|       -5|     VX|  1780|   SEA| LAX|     132|     954|   6|    58|      null|null|        null| null|   null| null| null|  null|     null|  false|    0|
| N559AS|2014|    1| 22|    1040|        5|    1505|        5|     AS|   851|   SEA| HNL|     360|    2677|  10|    40|      null|null|        null| null|   null| null| null|  

In [15]:
model_data = model_data.filter('arr_delay is not NULL and dep_delay is not NULL and air_time is not NULL and plane_year is not NULL')
model_data.show()

+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+---------+-------+-----+
|tailnum|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|flight|origin|dest|air_time|distance|hour|minute|plane_year|type|manufacturer|model|engines|seats|speed|engine|plane_age|is_late|label|
+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+---------+-------+-----+
+-------+----+-----+---+--------+---------+--------+---------+-------+------+------+----+--------+--------+----+------+----------+----+------------+-----+-------+-----+-----+------+---------+-------+-----+



## One hot encode categorical data

In [0]:
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoder

carr_indexer = StringIndexer(inputCol='carrier', outputCol='carrier_index')
carr_encoder = OneHotEncoder(inputCol='carrier_index', outputCol='carrier_fact')

In [0]:
dest_indexer = StringIndexer(inputCol='dest', outputCol='dest_index')
dest_encoder = OneHotEncoder(inputCol='dest_index', outputCol='dest_fact')
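To make the two steps concrete, here is a minimal plain-Python sketch of what the encoder stage does with an index produced by the indexer. It assumes the default `dropLast=True` behaviour of Spark's `OneHotEncoder`, under which the last category is represented by an all-zero vector. The function `one_hot` and the carrier indices are hypothetical, and Spark itself returns sparse vectors rather than lists.

```python
def one_hot(index, num_categories, drop_last=True):
    """Encode a category index as a list of 0.0/1.0 values.

    Hypothetical illustration; with drop_last=True the vector has
    num_categories - 1 slots and the last category maps to all zeros.
    """
    size = num_categories - 1 if drop_last else num_categories
    vec = [0.0] * size
    if index < size:
        vec[index] = 1.0
    return vec


# Hypothetical indices, as an indexer might assign them to three carriers.
carrier_to_index = {'AS': 0, 'WN': 1, 'VX': 2}

encoded = {
    carrier: one_hot(idx, len(carrier_to_index))
    for carrier, idx in carrier_to_index.items()
}
for carrier, vec in encoded.items():
    print(carrier, vec)
```

Dropping one category keeps the encoded columns from summing to a constant, which matters for linear models that include an intercept.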

## Assemble a vector

In [0]:
from pyspark.ml.feature import VectorAssembler
vec_assembler = VectorAssembler(inputCols=['month', 'air_time', 'carrier_fact', 'dest_fact', 'plane_age'], outputCol='features')
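Conceptually, the assembler concatenates the listed columns, scalars and vectors alike, into one flat feature vector per row. The sketch below mimics that in plain Python. The `assemble` helper is hypothetical, the one-hot values in the sample row are made up for illustration, and the scalar values come from the N846VA flight shown earlier (month 12, air time 132, flight year 2014 and plane year 2011).

```python
def assemble(row, input_cols):
    """Concatenate scalar and list-valued columns into one flat list.

    Hypothetical stand-in for the idea behind VectorAssembler; not Spark code.
    """
    features = []
    for col in input_cols:
        value = row[col]
        if isinstance(value, (list, tuple)):
            features.extend(float(v) for v in value)
        else:
            features.append(float(value))
    return features


input_cols = ['month', 'air_time', 'carrier_fact', 'dest_fact', 'plane_age']
row = {
    'month': 12,
    'air_time': 132,
    'carrier_fact': [0.0, 0.0],    # assumed one-hot output, illustration only
    'dest_fact': [1.0, 0.0, 0.0],  # assumed one-hot output, illustration only
    'plane_age': 2014 - 2011,
}

features = assemble(row, input_cols)
print(features)
```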

In [0]:
from pyspark.ml import Pipeline

flights_pipe = Pipeline(stages=[dest_indexer, dest_encoder, carr_indexer, 
                                carr_encoder, vec_assembler])

## Test & Train Split

In Spark it's important to make sure you split the data **after** all the transformations. This is because operations like `StringIndexer` don't always
produce the same index, even when given the same list of strings.
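The effect can be sketched without Spark. By default `StringIndexer` orders labels by descending frequency, so the index a label receives depends on the data the indexer was fitted on. The plain-Python stand-in below imitates that ordering. The function `fit_string_index`, the alphabetical tie-break, and the two small destination lists are assumptions made for this illustration.

```python
from collections import Counter


def fit_string_index(values):
    """Map each label to an index, most frequent label first.

    Hypothetical stand-in for a frequency-ordered indexer; ties are broken
    alphabetically here, which is an assumption of this sketch.
    """
    counts = Counter(values)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: idx for idx, label in enumerate(ordered)}


# Two made-up subsets with the same labels but different frequencies.
train_dest = ['LAX', 'LAX', 'SFO', 'HNL']
test_dest = ['SFO', 'SFO', 'LAX', 'HNL']

train_index = fit_string_index(train_dest)
test_index = fit_string_index(test_dest)

print(train_index)
print(test_index)
```

Fitting separately on each subset gives `LAX` a different index in each, so the encoded features would no longer mean the same thing across the two sets.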

In [0]:
# piped_data = flights_pipe.fit(model_data).transform(model_data)

In [0]:
# training, test = piped_data.randomSplit([.6, .4])