# Introduction to PySpark

Outline:
* [Overview](#Overview)
  + SparkContext, RDD and SparkSession
  + Spark SQL and DataFrames
      + Spark DataFrames
      + SQL Queries
  + Working with pandas
* [Querying and Manipulating Data](#Querying-and-Manipulating-Data)
  + Selecting Data
  + Filtering Data
  + Grouping and Aggregation
  + Joining Data
  + Creating Columns

In [2]:
# !pip install pyspark

## Overview

Spark is a fast and general cluster computing system for Big Data. It provides APIs in Python as well as R, Scala, and Java. 

Spark comes with the following built-in libraries:

* SQL and DataFrames
* Spark Streaming
* MLlib (machine learning)
* GraphX (graph)

### SparkContext, RDD and SparkSession

#### SparkContext

We can connect to a Spark cluster from PySpark by creating an instance of the `SparkContext` class. `SparkContext` is the main entry point for Spark functionality. Thus, the first thing a Spark program must do is to create a SparkContext object, which tells Spark how to access a cluster.

In [3]:
from pyspark import SparkContext

# Create an instance of the SparkContext class
sc = SparkContext(appName="myApp")

# sc = SparkContext.getOrCreate()

# Verify SparkContext
print(sc)

# Print Spark version
print(sc.version)

# Print master
print(sc.master)

# sc.stop()

<SparkContext master=local[*] appName=myApp>
2.4.3
local[*]


The class constructor takes a few optional arguments that allow us to specify the attributes of the cluster we're connecting to. An object holding all these attributes can also be created with the `SparkConf()` constructor. `SparkConf()` contains information about our Spark application. 

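```python
from pyspark import SparkConf, SparkContext

conf = SparkConf().setAppName("myApp").setMaster("local[*]")
sc = SparkContext(conf=conf)  # stop any running SparkContext first with sc.stop()
```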

#### RDD

Spark's core data structure is the Resilient Distributed Dataset (RDD).

- Resilient: It is fault tolerant. Spark records each RDD's lineage graph (a DAG of the transformations that produced it), so lost partitions can be recomputed after a node failure.
- Distributed: The data of an RDD resides on multiple nodes.
- Dataset: Records of data that you will work with.

An RDD represents a dataset partitioned across multiple nodes so that it can be operated on in parallel. In other words, the RDD is Apache Spark's main fault-tolerant abstraction and its fundamental data structure: an immutable, distributed collection of objects.

RDDs are hard to work with directly, so we'll be using the Spark `DataFrame` abstraction built on top of RDDs.

#### SparkSession

To start working with Spark DataFrames, we first have to create a `SparkSession` object from the `SparkContext`. We can think of the SparkContext as our _connection_ to the cluster and the SparkSession as our _interface_ with that connection.

Creating multiple `SparkSession`s and `SparkContext`s can cause issues, so it's best practice to use the `SparkSession.builder.getOrCreate()` method. This returns an existing `SparkSession` if there's already one in the environment, or creates a new one if necessary.

In [4]:
# Import SparkSession from pyspark.sql
from pyspark.sql import SparkSession

# Create a SparkSession called spark (or reuse the existing one)
spark = SparkSession.builder.getOrCreate()

# Print spark to verify it's a SparkSession
print(spark)

<pyspark.sql.session.SparkSession object at 0x109d24ac8>


### Spark SQL and DataFrames

**Spark SQL** is a Spark module for structured data processing. Unlike the basic Spark RDD API, the interfaces provided by Spark SQL give Spark more information about the structure of both the data and the computation being performed. There are several ways to interact with Spark SQL, including SQL and the Dataset API.

When running SQL from within Python, the results are returned as a DataFrame.

A **Dataset** combines the benefits of RDDs (strong typing, the ability to use powerful lambda functions) with the benefits of Spark SQL's optimized execution engine. The Dataset API is available in Scala and Java. Python does not support the Dataset API, but due to Python's dynamic nature, many of its benefits are already available (e.g. we can naturally access a row's field by name with `row.columnName`).

A **DataFrame** is a Dataset organized into named columns. It is conceptually equivalent to a table in a relational database or a data frame in R/Python, but with richer optimizations under the hood. The DataFrame API is available in Scala, Java, Python, and R. [Source](https://spark.apache.org/docs/latest/sql-programming-guide.html)

#### Spark DataFrames

`SparkSession` has a `.read` attribute with several methods for reading different data sources into Spark DataFrames. Using these, we can read in a `.csv` file and create a DataFrame much as we would with pandas.

In [5]:
# Path to the file
file_path = "/Users/stb/Documents/GitHub/PySpark/data/flights.csv"

# Read in the data
flights = spark.read.csv(file_path, header=True, inferSchema=True)

# Print the type of 'flights'
print(type(flights))

# Show the data (first 5 rows)
flights.show(n=5)

<class 'pyspark.sql.dataframe.DataFrame'>
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
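|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
only showing top 5 rows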

In [6]:
# Print the schema of DataFrame
flights.printSchema()

root
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- dep_time: string (nullable = true)
 |-- dep_delay: string (nullable = true)
 |-- arr_time: string (nullable = true)
 |-- arr_delay: string (nullable = true)
 |-- carrier: string (nullable = true)
 |-- tailnum: string (nullable = true)
 |-- flight: integer (nullable = true)
 |-- origin: string (nullable = true)
 |-- dest: string (nullable = true)
 |-- air_time: string (nullable = true)
 |-- distance: integer (nullable = true)
 |-- hour: string (nullable = true)
 |-- minute: string (nullable = true)



In [7]:
# Print the tables in the catalog
spark.catalog.listTables()

[]

In [8]:
# Create a temporary table
flights.createTempView("flights")

In [9]:
# Print the tables in the catalog
spark.catalog.listTables()

[Table(name='flights', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]

#### SQL Queries

One of the advantages of the DataFrame interface is that we can run SQL queries on the tables registered in the catalog with the SparkSession's `.sql()` method, which returns the result as a Spark DataFrame.

In [10]:
# Create a query
query = "FROM flights SELECT * LIMIT 5"

# Get the first 5 rows of flights
flights5 = spark.sql(query)

# Show the results
flights5.show()

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+

### Working with `pandas`

We may want to work with the table locally using `pandas`. Spark DataFrames make that easy with the `.toPandas()` method, which returns the corresponding pandas DataFrame. Because `.toPandas()` collects all of the rows to the driver, it should only be used on results small enough to fit in local memory.

In [11]:
# Create a query
query = "SELECT origin, dest, COUNT(*) as N FROM flights GROUP BY origin, dest"

# Run the query
flight_counts = spark.sql(query)

# Convert the results to a pandas DataFrame
pd_counts = flight_counts.toPandas()

# Print the head of pd_counts
pd_counts.head()

  origin dest    N
0    SEA  RNO    8
1    SEA  DTW   98
2    SEA  CLE    2
3    SEA  LAX  450
4    PDX  SEA  144


We may also want to put a `pandas` DataFrame into a Spark cluster. The SparkSession's `.createDataFrame()` method takes a pandas DataFrame and returns a Spark DataFrame.

In [12]:
import pandas as pd
import numpy as np

# Create pd_temp
pd_temp = pd.DataFrame(np.random.random(10))

# Create spark_temp from pd_temp
spark_temp = spark.createDataFrame(pd_temp)

# Examine the tables in the catalog (and verify that the new DataFrame is not present)
spark.catalog.listTables()


[Table(name='flights', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]

Note that the resulting Spark DataFrame is not registered in the SparkSession's catalog. This means that we can use all the Spark DataFrame methods on it, but we can't refer to it by name in other contexts. For example, a SQL query passed to `.sql()` that references the DataFrame will throw an error. To access the data this way, we have to register it as a temporary table, which we can do with the `.createTempView()` Spark DataFrame method. This method registers the DataFrame as a table in the catalog, but because the table is temporary, it can only be accessed from the SparkSession used to create the Spark DataFrame.

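There is also the `.createOrReplaceTempView()` method. Whereas `.createTempView()` throws an error if a view with the given name already exists, `.createOrReplaceTempView()` safely creates a new temporary table if nothing was there before, or replaces the existing table if one was already defined.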

In [13]:
# Add spark_temp to the catalog
spark_temp.createOrReplaceTempView("temp")

# Examine the tables in the catalog again
spark.catalog.listTables()

[Table(name='flights', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='temp', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]

## Querying and Manipulating Data

We use the `pyspark.sql` module, whose DataFrame API runs queries through Spark SQL's optimized execution engine, for querying and manipulating data.

### Selecting Data

The Spark variant of SQL's `SELECT` clause is the `.select()` method. This method takes one argument per column. Each argument can be either the column name as a string or a column object (using the `df.colName` syntax).

In [20]:
# Select the first set of columns
selected1 = flights.select("tailnum", "origin", "dest")

# Show the first set of columns
selected1.show(3)

# Select the same columns using column objects
temp = flights.select(flights.tailnum, flights.origin, flights.dest)

# Define first filter
filterA = flights.origin == "SEA"

# Define second filter
filterB = flights.dest == "PDX"

# Filter the data, first by filterA then by filterB
selected2 = temp.filter(filterA).filter(filterB)

# Show the filtered records
selected2.show(3)

+-------+------+----+
|tailnum|origin|dest|
+-------+------+----+
| N846VA|   SEA| LAX|
| N559AS|   SEA| HNL|
| N847VA|   SEA| SFO|
+-------+------+----+
only showing top 3 rows

+-------+------+----+
|tailnum|origin|dest|
+-------+------+----+
| N810SK|   SEA| PDX|
| N822SK|   SEA| PDX|
| N586SW|   SEA| PDX|
+-------+------+----+
only showing top 3 rows



We can also use the `.select()` method to perform column-wise operations, as we do in SQL. Again, there are two options: we can either use the `df.colName` notation, or use the Spark DataFrame method `.selectExpr()`, which takes SQL expressions as strings. Note that the `.alias()` method is equivalent to SQL's `AS` keyword.

In [22]:
# Define avg_speed
avg_speed = (flights.distance/(flights.air_time/60)).alias("avg_speed")

# Select the correct columns
speed1 = flights.select("origin", "dest", "tailnum", avg_speed)

# Show the result
speed1.show(3)

# Create the same table using a SQL expression
speed2 = flights.selectExpr("origin", "dest", "tailnum", "distance/(air_time/60) as avg_speed")

# Show the result
speed2.show(3)

+------+----+-------+------------------+
|origin|dest|tailnum|         avg_speed|
+------+----+-------+------------------+
|   SEA| LAX| N846VA| 433.6363636363636|
|   SEA| HNL| N559AS| 446.1666666666667|
|   SEA| SFO| N847VA|367.02702702702703|
+------+----+-------+------------------+
only showing top 3 rows

+------+----+-------+------------------+
|origin|dest|tailnum|         avg_speed|
+------+----+-------+------------------+
|   SEA| LAX| N846VA| 433.6363636363636|
|   SEA| HNL| N559AS| 446.1666666666667|
|   SEA| SFO| N847VA|367.02702702702703|
+------+----+-------+------------------+
only showing top 3 rows



### Filtering Data

The Spark counterpart of SQL's `WHERE` clause is the `.filter()` method, which takes either a Spark Column of boolean (True/False) values or the `WHERE` clause of a SQL expression as a _string_.

We'll now use the `.filter()` method, in two ways, to find all the flights that flew more than 1000 miles.

In [14]:
# Filter flights with a SQL string
long_flights1 = flights.filter("distance > 1000")

# Filter flights with a boolean column
long_flights2 = flights.filter(flights.distance > 1000)

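# Examine the data to check they're equal
long_flights1.show(3)
long_flights2.show(3)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    4| 19|    1236|       -4|    1508|       -7|     AS| N309AS|   490|   SEA| SAN|     135|    1050|  12|    36|
|2014|   11| 19|    1812|       -3|    2352|       -4|     AS| N564AS|    26|   SEA| ORD|     198|    1721|  18|    12|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
only showing top 3 rows

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    4| 19|    1236|       -4|    1508|       -7|     AS| N309AS|   490|   SEA| SAN|     135|    1050|  12|    36|
|2014|   11| 19|    1812|       -3|    2352|       -4|     AS| N564AS|    26|   SEA| ORD|     198|    1721|  18|    12|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
only showing top 3 rows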

### Grouping and Aggregation

#### Aggregation: `.groupBy()` method with no arguments

In [27]:
# Find the shortest flight from PDX in terms of distance
flights.filter(flights.origin == "PDX").groupBy().min("distance").show()

# Find the longest flight from PDX in terms of distance
flights.filter(flights.origin == "PDX").groupBy().max("distance").show()

+-------------+
|min(distance)|
+-------------+
|          106|
+-------------+

+-------------+
|max(distance)|
+-------------+
|         2631|
+-------------+



In [None]:
# Average air time of Delta flights departing SEA
flights.filter(flights.carrier == "DL").filter(flights.origin == "SEA").groupBy().avg("air_time").show()

# Total hours in the air
flights.withColumn("duration_hrs", flights.air_time/60).groupBy().sum("duration_hrs").show()

#### Grouping and Aggregation: `.groupBy()` method with arguments

In [34]:
# Group by tailnum
by_plane = flights.groupBy("tailnum")

# Number of flights each plane made
by_plane.count().show(5)

# Group by origin
by_origin = flights.groupBy("origin")

# Average distance of flights from PDX and SEA
by_origin.avg("distance").show()

+-------+-----+
|tailnum|count|
+-------+-----+
| N442AS|   38|
| N102UW|    2|
| N36472|    4|
| N38451|    4|
| N73283|    4|
+-------+-----+
only showing top 5 rows

+------+------------------+
|origin|     avg(distance)|
+------+------------------+
|   SEA|1276.5170269469943|
|   PDX|1065.9026494146642|
+------+------------------+



In addition to the `GroupedData` methods, there is also the `.agg()` method, which lets us pass an aggregate column expression that uses any of the aggregate functions from the `pyspark.sql.functions` submodule.

This submodule contains many useful functions for computing things like _standard deviations_. Each of these aggregate functions takes a column name (or a Column expression) and, when passed to `.agg()` on a `GroupedData` object, is computed once per group.

In [38]:
# Import pyspark.sql.functions as F
import pyspark.sql.functions as F

# Group by month and dest
by_month_dest = flights.groupBy("month", "dest")

# Average distance by month and destination
by_month_dest.avg("distance").show(5)

# Standard deviation
by_month_dest.agg(F.stddev("distance")).show(5)

+-----+----+------------------+
|month|dest|     avg(distance)|
+-----+----+------------------+
|    4| PHX|1074.3333333333333|
|    1| RDM|             116.0|
|    5| ONT| 903.5555555555555|
|    7| OMA|            1368.0|
|    8| MDW|            1738.4|
+-----+----+------------------+
only showing top 5 rows

+-----+----+---------------------+
|month|dest|stddev_samp(distance)|
+-----+----+---------------------+
|    4| PHX|    46.58750347706978|
|    1| RDM|                  0.0|
|    5| ONT|    62.19146064997812|
|    7| OMA|                  0.0|
|    8| MDW|    8.462922227669292|
+-----+----+---------------------+
only showing top 5 rows



### Joining Data

A join combines two different tables along a column that they share. This column is called the key. In PySpark, joins are performed using the DataFrame method `.join()`, which takes three arguments. The first is the second DataFrame that you want to join with the first one. The second argument, `on`, is the name of the key column as a string (or a list of names for multiple keys); when joining by name, the key column(s) must have the same name in each table. The third argument, `how`, specifies the kind of join to perform, such as `"inner"` (the default) or `"leftouter"`.

In [None]:
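# Examine the data (`airports` is assumed to be a Spark DataFrame of airports
# with an `faa` code column, loaded earlier in the same way as `flights`)
airports.show()

# Rename the faa column so it matches the key column in flights
airports = airports.withColumnRenamed("faa", "dest")

# Join the DataFrames (a left outer join keeps every flight)
flights_with_airports = flights.join(airports, "dest", "leftouter")

# Examine the data again
flights_with_airports.show()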

### Creating Columns

We can perform column-wise operations using the `.withColumn()` method, which takes two arguments: first, the name of the new column as a string, and second, a Column expression that defines the new column.

Since a Spark DataFrame is _immutable_, it can't be changed. Thus columns can't be updated in place; instead, `.withColumn()` returns a new DataFrame, which we assign back to the original variable name to overwrite it.

The difference between `.select()` and `.withColumn()` methods is that `.select()` returns only the columns we specify, while `.withColumn()` returns all the columns of the DataFrame in addition to the one we defined.

In [29]:
# Show the head
flights.show(2)

# Add duration_hrs
flights = flights.withColumn("duration_hrs", flights.air_time/60)

# Show the head of the new DataFrame
flights.show(2)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
only showing top 2 rows

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|duration_hrs|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|         2.2|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|         6.0|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------+
only showing top 2 rows

----

References: 
- https://spark.apache.org/
- https://datacamp.com/