# Getting to know PySpark

In this chapter, you'll learn how Spark manages data and how you can read and write tables from Python.

## Preparing the environment

### Importing libraries

In [1]:
import pandas as pd
import numpy as np

from pprint import pprint

from pyspark.sql.types import (_parse_datatype_string, StructType, StructField,
                               DoubleType, IntegerType, StringType)
from pyspark.sql import SparkSession

### Connect to Spark

In [2]:
spark = SparkSession.builder.getOrCreate()

# Render DataFrames as tables automatically in notebooks (eager evaluation)
spark.conf.set('spark.sql.repl.eagerEval.enabled', True)

### Reading the data

In [3]:
schema_str = "year int, month int, day int, dep_time int, dep_delay int, arr_time int, " + \
             "arr_delay int, carrier string, tailnum string, flight int, origin string, " + \
             "dest string, air_time int, distance int, hour int, minute int"
customSchema = _parse_datatype_string(schema_str)  # DDL string -> StructType
flights = spark.read.csv('data-sources/flights_small.csv', header=True, schema=customSchema)
flights.createOrReplaceTempView("flights")
flights.printSchema()
flights.limit(2)

root
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- dep_time: integer (nullable = true)
 |-- dep_delay: integer (nullable = true)
 |-- arr_time: integer (nullable = true)
 |-- arr_delay: integer (nullable = true)
 |-- carrier: string (nullable = true)
 |-- tailnum: string (nullable = true)
 |-- flight: integer (nullable = true)
 |-- origin: string (nullable = true)
 |-- dest: string (nullable = true)
 |-- air_time: integer (nullable = true)
 |-- distance: integer (nullable = true)
 |-- hour: integer (nullable = true)
 |-- minute: integer (nullable = true)



year,month,day,dep_time,dep_delay,arr_time,arr_delay,carrier,tailnum,flight,origin,dest,air_time,distance,hour,minute
2014,12,8,658,-7,935,-5,VX,N846VA,1780,SEA,LAX,132,954,6,58
2014,1,22,1040,5,1505,5,AS,N559AS,851,SEA,HNL,360,2677,10,40


In [4]:
schema_str = "faa string, name string, lat double, lon double, alt int, tz int, dst string"
customSchema = _parse_datatype_string(schema_str)
airports = spark.read.schema(customSchema).csv('data-sources/airports.csv', header=True)
airports.createOrReplaceTempView("airports")
airports.printSchema()
airports.limit(2)

root
 |-- faa: string (nullable = true)
 |-- name: string (nullable = true)
 |-- lat: double (nullable = true)
 |-- lon: double (nullable = true)
 |-- alt: integer (nullable = true)
 |-- tz: integer (nullable = true)
 |-- dst: string (nullable = true)



faa,name,lat,lon,alt,tz,dst
04G,Lansdowne Airport,41.1304722,-80.6195833,1044,-5,A
06A,Moton Field Munic...,32.4605722,-85.6800278,264,-5,A


In [5]:
customSchema = StructType([
    StructField("tailnum", StringType()),
    StructField("year", IntegerType()),
    StructField("type", StringType()),
    StructField("manufacturer", StringType()),
    StructField("model", StringType()),
    StructField("engines", IntegerType()),
    StructField("seats", IntegerType()),
    StructField("speed", DoubleType()),
    StructField("engine", StringType())
])
planes = (spark.read.schema(customSchema)
                    .format("csv")
                    .option("header", "true")
                    .load('data-sources/planes.csv'))
planes.createOrReplaceTempView("planes")
planes.printSchema()
planes.limit(2)

root
 |-- tailnum: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- type: string (nullable = true)
 |-- manufacturer: string (nullable = true)
 |-- model: string (nullable = true)
 |-- engines: integer (nullable = true)
 |-- seats: integer (nullable = true)
 |-- speed: double (nullable = true)
 |-- engine: string (nullable = true)



tailnum,year,type,manufacturer,model,engines,seats,speed,engine
N102UW,1998,Fixed wing multi ...,AIRBUS INDUSTRIE,A320-214,2,182,,Turbo-fan
N103US,1999,Fixed wing multi ...,AIRBUS INDUSTRIE,A320-214,2,182,,Turbo-fan


In [6]:
spark.catalog.listTables()

[Table(name='airports', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='planes', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]

## Ex. 1 - Examining The SparkContext

In this exercise you'll get familiar with the SparkContext.

You'll probably notice that code takes longer to run than you might expect. This is because Spark is heavyweight software: it takes more time to start up than you might be used to. Even simple computations can take longer than expected, because Spark's under-the-hood optimizations are designed for complicated operations on big datasets. That means that for simple or small problems, Spark may actually perform worse than other solutions!

**Instructions:**

1. Call `print()` on `sc` to verify there's a `SparkContext` in your environment.
2. `print()` `sc.version` to see what version of `Spark` is running on your cluster.

In [7]:
sc = spark.sparkContext

print(sc)
print(sc.version)

<SparkContext master=local[*] appName=pyspark-shell>
3.5.1


## Ex. 2 - Creating a SparkSession

We've already created a SparkSession for you called `spark`, but what if you're not sure there already is one? Creating multiple SparkSessions and SparkContexts can cause issues (only one SparkContext can be active at a time), so it's best practice to use the `SparkSession.builder.getOrCreate()` method. This returns an existing SparkSession if there's already one in the environment, or creates a new one if necessary!

**Instructions:**

1. Import `SparkSession` from `pyspark.sql` (already done in the importing libraries section).
2. Make a new SparkSession called `my_spark` using `SparkSession.builder.getOrCreate()`.
3. Print `my_spark` to the console to verify it's a SparkSession.

In [8]:
# Create my_spark
my_spark = SparkSession.builder.getOrCreate()

# Print my_spark
print(my_spark)

<pyspark.sql.session.SparkSession object at 0x000001ED5AB37810>


In [9]:
my_spark == spark

True

## Ex. 3 - Viewing tables

Once you've created a SparkSession, you can start poking around to see what data is in your cluster!

Your SparkSession has an attribute called `catalog`, which gives you access to the metadata about the tables and views available in your session. This attribute has a few methods for extracting different pieces of information.

One of the most useful is the `.listTables()` method, which returns a list of `Table` objects, one for each table or view in the current database, including temporary views.

**Instructions:**
1. See what tables are in your cluster by calling `spark.catalog.listTables()` and printing the result!

In [10]:
# Print the tables in the catalog
pprint(spark.catalog.listTables())

[Table(name='airports', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='planes', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]


## Ex. 4 - Are you query-ious?

One of the advantages of the DataFrame interface is that you can run SQL queries on the tables in your Spark cluster. 

**Instructions:**
1. Use the `.sql()` method to get the first 10 rows of the `flights` table and save the result to `flights10`. The variable `query` contains the appropriate SQL query.
2. Use the DataFrame method `.show()` to print `flights10`.

In [11]:
# Don't change this query
query = "SELECT * FROM flights LIMIT 10"

# Get the first 10 rows of flights
flights10 = spark.sql(query)

# Show the results
flights10.show()

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
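|2014|    1| 15|    1037|        7|    1352|        2|     WN| N646SW|    48|   PDX| DEN|     121|     991|  10|    37|
|2014|    7|  2|     847|       42|    1041|       51|     WN| N422WN|  1520|   PDX| OAK|      90|     543|   8|    47|
|2014|    5| 12|    1655|       -5|    1842|      -18|     VX| N361VA|   755|   SEA| SFO|      98|     679|  16|    55|
|2014|    4| 19|    1236|       -4|    1508|       -7|     AS| N309AS|   490|   SEA| SAN|     135|    1050|  12|    36|
|2014|   11| 19|    1812|       -3|    2352|       -4|     AS| N564AS|    26|   SEA| ORD|     198|    1721|  18|    12|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+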

In [12]:
flights10

year,month,day,dep_time,dep_delay,arr_time,arr_delay,carrier,tailnum,flight,origin,dest,air_time,distance,hour,minute
2014,12,8,658,-7,935,-5,VX,N846VA,1780,SEA,LAX,132,954,6,58
2014,1,22,1040,5,1505,5,AS,N559AS,851,SEA,HNL,360,2677,10,40
2014,3,9,1443,-2,1652,2,VX,N847VA,755,SEA,SFO,111,679,14,43
2014,4,9,1705,45,1839,34,WN,N360SW,344,PDX,SJC,83,569,17,5
2014,3,9,754,-1,1015,1,AS,N612AS,522,SEA,BUR,127,937,7,54
2014,1,15,1037,7,1352,2,WN,N646SW,48,PDX,DEN,121,991,10,37
2014,7,2,847,42,1041,51,WN,N422WN,1520,PDX,OAK,90,543,8,47
2014,5,12,1655,-5,1842,-18,VX,N361VA,755,SEA,SFO,98,679,16,55
2014,4,19,1236,-4,1508,-7,AS,N309AS,490,SEA,SAN,135,1050,12,36
2014,11,19,1812,-3,2352,-4,AS,N564AS,26,SEA,ORD,198,1721,18,12


## Ex. 5 - Pandafy a Spark DataFrame

Suppose you've run a query on your huge dataset and aggregated it down to something a little more manageable. Sometimes it makes sense to then take that table and work with it locally using a tool like `pandas`. `Spark DataFrames` make that easy with the `.toPandas()` method. Calling this method on a `Spark DataFrame` returns the corresponding `pandas DataFrame`. Because `.toPandas()` collects every row into the driver's memory, use it only on results small enough to fit there.

This time the query counts the number of flights to each airport from SEA and PDX.

**Instructions:**
1. Run the query using the `.sql()` method. Save the result in `flight_counts`.
2. Use the `.toPandas()` method on `flight_counts` to create a pandas DataFrame called `pd_counts`.
3. Print the `.head()` of `pd_counts` to the console.

In [13]:
# Don't change this query
query = "SELECT origin, dest, COUNT(*) as N FROM flights GROUP BY origin, dest"

# Run the query
flight_counts = spark.sql(query)

# Convert the results to a pandas DataFrame
pd_counts = flight_counts.toPandas()

# Print the head of pd_counts
pd_counts.head()

,origin,dest,N
0,SEA,RNO,8
1,SEA,DTW,98
2,SEA,CLE,2
3,SEA,LAX,450
4,PDX,SEA,144


## Ex. 6 - Put some Spark in your data

In this exercise, we are going to put a pandas DataFrame into a Spark cluster! The `SparkSession` class has a method for this. The `.createDataFrame()` method takes a `pandas DataFrame` and returns a `Spark DataFrame`.

The resulting Spark DataFrame is not registered in the SparkSession's catalog. This means that you can use all the Spark DataFrame methods on it, but you can't refer to it by name from other contexts.

For example, a SQL query (using the `.sql()` method) that references your DataFrame will throw an error. To access the data in this way, you have to save it as a temporary table.

You can do this using the `.createTempView()` Spark DataFrame method, which takes as its only argument the name of the temporary table you'd like to register. This method registers the DataFrame as a table in the catalog, but as this table is temporary, it can only be accessed from the specific SparkSession used to create the Spark DataFrame.

There is also the method `.createOrReplaceTempView()`. This safely creates a new temporary table if nothing was there before, or replaces the existing one if a table with that name was already defined (`.createTempView()`, by contrast, fails when the name is already taken). You'll use this method to avoid running into problems with duplicate tables.

Check out the diagram to see all the different ways your Spark data structures interact with each other.

![spark_figure](images/spark_figure.png)

**Instructions:**

1. The code to create a pandas DataFrame of random numbers has already been provided and saved under `pd_temp`.
2. Create a `Spark DataFrame` called `spark_temp` by calling the Spark method `.createDataFrame()` with `pd_temp` as the argument.
3. Examine the list of tables in your Spark cluster and verify that the new DataFrame is not present. Remember you can use `spark.catalog.listTables()` to do so.
4. Register the `spark_temp` DataFrame you just created as a temporary table using the `.createOrReplaceTempView()` method. The temporary table should be named `"temp"`. Remember that the table name is passed as the method's only argument!
5. Examine the list of tables again.

In [14]:
# Create pd_temp
pd_temp = pd.DataFrame(np.random.random(10))

# Create spark_temp from pd_temp
spark_temp = spark.createDataFrame(pd_temp)

# Examine the tables in the catalog
print('Tables catalog:')
pprint(spark.catalog.listTables())

# Add spark_temp to the catalog
print('\nTables catalog after adding temp:')
spark_temp.createOrReplaceTempView('temp')

# Examine the tables in the catalog again
pprint(spark.catalog.listTables())

Tables catalog:
[Table(name='airports', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='planes', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]

Tables catalog after adding temp:
[Table(name='airports', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='planes', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='temp', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]


## Ex. 7 - Dropping the middle man

Now you know how to put data into Spark via pandas, but you're probably wondering why deal with pandas at all? Wouldn't it be easier to just read a text file straight into Spark? Of course it would!

Luckily, your SparkSession has a `.read` attribute which has several methods for reading different data sources into Spark DataFrames. Using these you can create a DataFrame from a `.csv` file just like with regular pandas DataFrames!

The variable `file_path` is a string with the path to the file `airports.csv`. This file contains information about different airports all over the world.

**Instructions:**
1. Use the `.read.csv()` method to create a Spark DataFrame called `airports`.
2. The first argument is `file_path`.
3. Pass the argument `header=True` so that Spark knows to take the column names from the first line of the file.
4. Print out this DataFrame by calling `.show()`.

In [15]:
# Don't change this file path
file_path = "data-sources/airports.csv"

# Read in the airports data
airports = spark.read.csv(file_path, header=True)

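# Show the data (the commented-out .show() is replaced by eager rendering below)
# airports.show()

# Print the shape as (rows, columns), then render the DataFrame
print((airports.count(), len(airports.columns)))
airports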

(1397, 7)


faa,name,lat,lon,alt,tz,dst
04G,Lansdowne Airport,41.1304722,-80.6195833,1044,-5,A
06A,Moton Field Munic...,32.4605722,-85.6800278,264,-5,A
06C,Schaumburg Regional,41.9893408,-88.1012428,801,-6,A
06N,Randall Airport,41.431912,-74.3915611,523,-5,A
09J,Jekyll Island Air...,31.0744722,-81.4277778,11,-4,A
0A9,Elizabethton Muni...,36.3712222,-82.1734167,1593,-4,A
0G6,Williams County A...,41.4673056,-84.5067778,730,-5,A
0G7,Finger Lakes Regi...,42.8835647,-76.7812318,492,-5,A
0P2,Shoestring Aviati...,39.7948244,-76.6471914,1000,-5,U
0S9,Jefferson County ...,48.0538086,-122.8106436,108,-8,A


## Close

In [16]:
spark.stop()