# 2. Manipulating Data

This series is based on the DataCamp course [Introduction to PySpark](https://app.datacamp.com/learn/courses/introduction-to-pyspark). The course has the following chapters:

1. Basics: Getting to know PySpark
2. **Manipulating data**: The current notebook.
3. Getting started with machine learning pipelines

This notebook deals with methods for Spark SQL dataframe manipulation. Many of these methods have an equivalent SQL operator. Check my [SQL guide](https://github.com/mxagar/sql_guide) if you need a refresher.

**Table of Contents**:

- [Setup: Create Session + Upload Data](#Setup:-Create-Session-+-Upload-Data)
- [2.1 Creating, Renaming, and Casting Columns](#2.1-Creating,-Renaming,-and-Casting-Columns)
- [2.2 SQL in a Nutshell](#2.2-SQL-in-a-Nutshell)
- [2.3 Filtering Data: WHERE - filter()](#2.3-Filtering-Data:-WHERE---filter())
- [2.4 Selecting Columns: SELECT - select(), selectExpr(), alias()](#2.4-Selecting-Columns:-SELECT---select(),-selectExpr(),-alias())
- [2.5 Grouping and Aggregating: GROUP BY, MIN, MAX, COUNT, SUM, AVG, AGG](#2.5-Grouping-and-Aggregating:-GROUP-BY,-MIN,-MAX,-COUNT,-SUM,-AVG,-AGG)
    - [Aggregation Functions](#Aggregation-Functions)
- [2.6 Joining: JOIN - join()](#2.6-Joining:-JOIN---join())

## Setup: Create Session + Upload Data

In [1]:
import findspark
findspark.init()

In [2]:
# Import SparkSession from pyspark.sql
from pyspark.sql import SparkSession

# Create or get a (new) SparkSession: session
session = SparkSession.builder.getOrCreate()

# Print session: our SparkSession
print(session)

23/04/21 13:31:08 WARN Utils: Your hostname, kasiopeia.local resolves to a loopback address: 127.0.0.1; using 192.168.1.34 instead (on interface en0)
23/04/21 13:31:08 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/04/21 13:31:09 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


<pyspark.sql.session.SparkSession object at 0x7fde48026e50>




In [3]:
# Load and register flights dataframe
flights = session.read.csv("../data/flights_small.csv", header=True, inferSchema=True)
flights.createOrReplaceTempView("flights")

# Load and register airports dataframe
airports = session.read.csv("../data/airports.csv", header=True, inferSchema=True)
airports.createOrReplaceTempView("airports")

# Load and register planes dataframe
planes = session.read.csv("../data/planes.csv", header=True, inferSchema=True)
planes.createOrReplaceTempView("planes")

print(session.catalog.listTables())

                                                                                

[Table(name='airports', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True), Table(name='flights', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True), Table(name='planes', catalog=None, namespace=[], description=None, tableType='TEMPORARY', isTemporary=True)]


## 2.1 Creating, Renaming, and Casting Columns

Once we have a Spark SQL dataframe, we can add a new column to it with `withColumn()`:

```python
df = df.withColumn("newCol", df.oldCol + 1)
```

However, Spark SQL dataframes are immutable, i.e., `withColumn()` does not modify `df` in place; it returns a new dataframe, which we assign back to `df`. Notes:

- `df.colName` is a `Column` object, which comes up often. We can also convert a column name string into a `Column` with `pyspark.sql.functions.col`.
- `withColumn()` returns the **entire table/dataframe** with a new column. If we want to change the content of an existing column, we write `"oldCol"` instead of `"newCol"` in the first argument. We can also use it to copy a column under a new name, although the original column is kept. The second argument **must** be a `Column` object, created as `df.colName` or `col("colName")`.

It might happen that we need to cast the type of a column; to check the types we use `printSchema()`, and to cast we use the `cast()` method of a `Column`:

```python
from pyspark.sql.functions import col

# Print schema with types
df.printSchema()

# Cast from string to double: a new table is created
# Both lines are equivalent: col("air_time") and df.air_time refer to the same column
df = df.withColumn("air_time", col("air_time").cast("double"))
df = df.withColumn("air_time", df.air_time.cast("double"))

# Copy column "air_time" into a new column "flight_duration"
# BUT, the old column is still there because it has another name;
# we can drop it using .select(), as shown below
df = df.withColumn("flight_duration", df.air_time)

# Alternative to the copy above: rename the column in one call;
# this method takes two column name strings (existing name, new name)
# AND replaces the previous column
df = df.withColumnRenamed("air_time", "flight_duration")
```

In [4]:
# Create/get the DataFrame flights
flights = session.table("flights")

In [5]:
# Show the head
flights.show(5)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|
|2014|    3|  9|     754|       -1|    1015|        1|     AS| N612AS|   522|   SEA| BUR|     127|     937|   7|    54|
+----+-----+---+--------+---------+-----

In [6]:
# Add a new column: duration_hrs
# General syntax: df = df.withColumn("newCol", df.oldCol + 1)
# A new dataframe is returned! That's because dataframes and their columns are immutable
# To modify a column: df = df.withColumn("col", df.col + 1)
# BUT: in reality, we create a new dataframe with the modified column
flights = flights.withColumn("duration_hrs", flights.air_time/60)

In [7]:
# Show the head
flights.show(5)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|      duration_hrs|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------------+
|2014|   12|  8|     658|       -7|     935|       -5|     VX| N846VA|  1780|   SEA| LAX|     132|     954|   6|    58|               2.2|
|2014|    1| 22|    1040|        5|    1505|        5|     AS| N559AS|   851|   SEA| HNL|     360|    2677|  10|    40|               6.0|
|2014|    3|  9|    1443|       -2|    1652|        2|     VX| N847VA|   755|   SEA| SFO|     111|     679|  14|    43|              1.85|
|2014|    4|  9|    1705|       45|    1839|       34|     WN| N360SW|   344|   PDX| SJC|      83|     569|  17|     5|1.3833333333333333|
|2014|    3|  9|     754|  

In [8]:
flights.printSchema()

root
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- dep_time: string (nullable = true)
 |-- dep_delay: string (nullable = true)
 |-- arr_time: string (nullable = true)
 |-- arr_delay: string (nullable = true)
 |-- carrier: string (nullable = true)
 |-- tailnum: string (nullable = true)
 |-- flight: integer (nullable = true)
 |-- origin: string (nullable = true)
 |-- dest: string (nullable = true)
 |-- air_time: string (nullable = true)
 |-- distance: integer (nullable = true)
 |-- hour: string (nullable = true)
 |-- minute: string (nullable = true)
 |-- duration_hrs: double (nullable = true)



In [9]:
from pyspark.sql.functions import col

# Convert air_time and dep_delay (strings) to double to use math operations on them
flights = flights.withColumn("air_time", col("air_time").cast("double"))
flights = flights.withColumn("dep_delay", col("dep_delay").cast("double"))

In [10]:
# Copy the column under a new name; the old column is kept
flights = flights.withColumn("flight_duration", flights.air_time)

In [26]:
# Another way to rename column names
# This option replaces the old column
# flights = flights.withColumnRenamed("flight_duration", "air_time")

In [12]:
flights.printSchema()

root
 |-- year: integer (nullable = true)
 |-- month: integer (nullable = true)
 |-- day: integer (nullable = true)
 |-- dep_time: string (nullable = true)
 |-- dep_delay: double (nullable = true)
 |-- arr_time: string (nullable = true)
 |-- arr_delay: string (nullable = true)
 |-- carrier: string (nullable = true)
 |-- tailnum: string (nullable = true)
 |-- flight: integer (nullable = true)
 |-- origin: string (nullable = true)
 |-- dest: string (nullable = true)
 |-- air_time: double (nullable = true)
 |-- distance: integer (nullable = true)
 |-- hour: string (nullable = true)
 |-- minute: string (nullable = true)
 |-- duration_hrs: double (nullable = true)
 |-- flight_duration: double (nullable = true)



## 2.2 SQL in a Nutshell

Many Spark SQL Dataframe methods have an equivalent SQL operation.

Most common SQL operators: `SELECT`, `FROM`, `WHERE`, `AS`, `GROUP BY`, `COUNT()`, `AVG()`, etc.

```sql
-- Get all the contents from the table my_table: we get a table
SELECT * FROM my_table;

-- Get specified columns and compute a new column value: we get a table
SELECT origin, dest, air_time / 60 FROM flights;

-- Filter according to value in column: we get a table
SELECT * FROM students
WHERE grade = 'A';

-- Get the destination and tail number of flights that last more than 10 hours
SELECT dest, tailnum FROM flights WHERE air_time > 600;

-- Group by: group by category values and apply an aggregation function for each group
-- In this case: number of flights for each unique origin
SELECT origin, COUNT(*) FROM flights
GROUP BY origin;

-- Group by all unique combinations of origin and dest columns
SELECT origin, dest, COUNT(*) FROM flights
GROUP BY origin, dest;

-- Group by unique origin-carrier combinations and for each
-- compute average air time in hrs
SELECT origin, carrier, AVG(air_time) / 60 FROM flights
GROUP BY origin, carrier;

-- Flight duration in hrs, new column name
SELECT air_time / 60 AS duration_hrs
FROM flights;
```

Also, recall we can combine tables in SQL using the `JOIN` operator:

```sql
-- INNER JOIN: note it is symmetrical, we can interchange TableA and B
SELECT * FROM TableA
INNER JOIN TableB
ON TableA.col_match = TableB.col_match;

-- FULL OUTER JOIN
SELECT * FROM TableA
FULL OUTER JOIN TableB
ON TableA.col_match = TableB.col_match;

-- LEFT OUTER JOIN
-- Left table: TableA; Right table: TableB
SELECT * FROM TableA
LEFT OUTER JOIN TableB
ON TableA.col_match = TableB.col_match;

-- RIGHT OUTER JOIN
-- Left table: TableA; Right table: TableB
SELECT * FROM TableA
RIGHT OUTER JOIN TableB
ON TableA.col_match = TableB.col_match;
```

## 2.3 Filtering Data: WHERE - filter()

The `filter()` method is equivalent to `WHERE`. We can pass either the string we would write after `WHERE` or a Boolean `Column` expression built with the usual Python comparison operators:

```sql
SELECT * FROM flights WHERE air_time > 120
```

```python
flights.filter("air_time > 120").show()
flights.filter(flights.air_time > 120).show()
```

In [13]:
# Filter flights by passing a string
long_flights1 = flights.filter("distance > 1000")

# Filter flights by passing a column of boolean values
long_flights2 = flights.filter(flights.distance > 1000)

# Print the data to check they're equal
long_flights1.show(2)
long_flights2.show(2)

+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------+---------------+
|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|dest|air_time|distance|hour|minute|duration_hrs|flight_duration|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------+---------------+
|2014|    1| 22|    1040|      5.0|    1505|        5|     AS| N559AS|   851|   SEA| HNL|   360.0|    2677|  10|    40|         6.0|          360.0|
|2014|    4| 19|    1236|     -4.0|    1508|       -7|     AS| N309AS|   490|   SEA| SAN|   135.0|    1050|  12|    36|        2.25|          135.0|
+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+----+--------+--------+----+------+------------+---------------+
only showing top 2 rows

+----+-----+---+--------+---------+--------+---------+-------+-------+------+----

## 2.4 Selecting Columns: SELECT - select(), selectExpr(), alias()

The `select()` method is equivalent to the `SELECT` SQL operator: it can take several comma-separated column names or `Column` objects (`df.col`) and returns a table with them. In contrast, the `withColumn()` method returns the entire table. Therefore, if we want to drop unnecessary columns, we can use `select()`.

If we want to perform a more sophisticated selection, as in SQL, we can use `selectExpr()` and pass comma-separated SQL strings; if we want to change the name of the selected/transformed column, we can use `alias()`, equivalent to `AS` in SQL.

In [14]:
# Select the first set of columns
selected1 = flights.select("tailnum", "origin", "dest")

# Select the second set of columns
temp = flights.select(flights.origin, flights.dest, flights.carrier)

# Define first filter
filterA = flights.origin == "SEA"

# Define second filter
filterB = flights.dest == "PDX"

# Filter the data, first by filterA then by filterB
selected2 = temp.filter(filterA).filter(filterB)

In [15]:
# Define avg_speed
# We define a new column object
avg_speed = (flights.distance/(flights.air_time/60)).alias("avg_speed")

# Select the correct columns
# We can pass comma separated strings or column objects to select();
# each column is a comma-separated element.
speed1 = flights.select("origin", "dest", "tailnum", avg_speed)

# Create the same table using a SQL expression
# We can pass comma separated SQL strings to selectExpr(); each column operation is
# a comma-separated element.
speed2 = flights.selectExpr("origin", "dest", "tailnum", "distance/(air_time/60) as avg_speed")

## 2.5 Grouping and Aggregating: GROUP BY, MIN, MAX, COUNT, SUM, AVG, AGG

As in SQL, we can create a `pyspark.sql.GroupedData` object with `groupBy()` and then apply aggregation methods like `min()`, `max()`, `count()`, `sum()`, `avg()`, or the generic `agg()`. Note that we can use `groupBy()` in two ways:

- If we pass one or more column names to `groupBy()`, e.g., `groupBy("col")`, the table is grouped by the unique values of the passed column(s); then, we apply an aggregation function to each group. This is equivalent to `GROUP BY` in SQL.
- If we don't pass any column name to `groupBy()`, the whole table is treated as a single group. This might not seem useful, but it is in practice, because it lets us apply aggregation methods to all rows at once; e.g., `df.min("col")` is not possible, so we write `df.groupBy().min("col")` instead.

In [27]:
# Group all flights by destination and, for each group,
# pick the minimum distance
flights.groupBy("dest").min("distance").show(5)

+----+-------------+
|dest|min(distance)|
+----+-------------+
| MSY|         2086|
| GEG|          224|
| BUR|          817|
| SNA|          859|
| EUG|          106|
+----+-------------+
only showing top 5 rows



In [17]:
# Find the shortest flight from PDX in terms of distance
# Note that in this case we don't pass a column to groupBy, but
# chain an aggregation method on the column, which is applied to all rows
flights.filter(flights.origin == "PDX").groupBy().min("distance").show()

+-------------+
|min(distance)|
+-------------+
|          106|
+-------------+



In [18]:
# Find the longest flight from SEA in terms of air time
# Note that in this case we don't pass a column to groupBy, but
# chain an aggregation method on the column, which is applied to all rows
flights.filter(flights.origin == "SEA").groupBy().max("air_time").show()

+-------------+
|max(air_time)|
+-------------+
|        409.0|
+-------------+



In [19]:
# Average duration of Delta flights from SEA
# Note that in this case we don't pass a column to groupBy, but
# chain an aggregation method on the column, which is applied to all rows
flights.filter(flights.carrier == "DL").filter(flights.origin == "SEA").groupBy().avg("air_time").show()

# Total hours in the air
# Note that in this case we don't pass a column to groupBy, but
# chain an aggregation method on the column, which is applied to all rows
flights.withColumn("duration_hrs", flights.air_time/60).groupBy().sum("duration_hrs").show()

+------------------+
|     avg(air_time)|
+------------------+
|188.20689655172413|
+------------------+

+------------------+
| sum(duration_hrs)|
+------------------+
|25289.600000000126|
+------------------+



In [20]:
# Group by tailnum
by_plane = flights.groupBy("tailnum")

# Number of flights each plane made
by_plane.count().show(5)

# Group by origin
by_origin = flights.groupBy("origin")

# Average duration of flights from PDX and SEA
by_origin.avg("air_time").show()

+-------+-----+
|tailnum|count|
+-------+-----+
| N442AS|   38|
| N102UW|    2|
| N36472|    4|
| N38451|    4|
| N73283|    4|
+-------+-----+
only showing top 5 rows

+------+------------------+
|origin|     avg(air_time)|
+------+------------------+
|   SEA| 160.4361496051259|
|   PDX|137.11543248288737|
+------+------------------+



### Aggregation Functions

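The module `pyspark.sql.functions` contains many functions that operate on `Column` objects. The aggregation functions among them (e.g., `avg()`, `stddev()`, `count()`, `first()`, `max()`, `min()`, `sum()`) can be passed to the generic `agg()` method, which is applied after any `groupBy()`. Some commonly used functions from the module, aggregations and others: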

- `abs()`: Computes the absolute value of a column.
- `avg()`: Computes the average of a column.
- `stddev()`: Computes the standard deviation of a column.
- `col()`: Returns a column based on the given column name.
- `concat()`: Concatenates multiple columns together.
- `count()`: Counts the number of non-null values in a column.
- `date_format()`: Formats a date or timestamp column based on a specified format string.
- `dayofmonth()`: Extracts the day of the month from a date or timestamp column.
- `explode()`: Transforms an array column into multiple rows.
- `first()`: Returns the first value of a column in a group.
- `lit()`: Creates a column with a literal value.
- `max()`: Computes the maximum value of a column.
- `min()`: Computes the minimum value of a column.
- `month()`: Extracts the month from a date or timestamp column.
- `round()`: Rounds a column to a specified number of decimal places.
- `split()`: Splits a string column based on a delimiter.
- `sum()`: Computes the sum of a column.
- `to_date()`: Converts a string column to a date column.
- `to_timestamp()`: Converts a string column to a timestamp column.
- `udf()`: Defines a user-defined function that can be used in PySpark.

... and many more.

In [21]:
# Import pyspark.sql.functions as F
import pyspark.sql.functions as F

# Group by month and dest
by_month_dest = flights.groupBy("month", "dest")

# Average departure delay by month and destination
by_month_dest.avg("dep_delay").show(5)

# Standard deviation of departure delay
by_month_dest.agg(F.stddev("dep_delay")).show(5)

+-----+----+------------------+
|month|dest|    avg(dep_delay)|
+-----+----+------------------+
|    4| PHX|1.6833333333333333|
|    1| RDM|            -1.625|
|    5| ONT|3.5555555555555554|
|    7| OMA|              -6.5|
|    8| MDW|              7.45|
+-----+----+------------------+
only showing top 5 rows

+-----+----+----------------------+
|month|dest|stddev_samp(dep_delay)|
+-----+----+----------------------+
|    4| PHX|    15.003380033491737|
|    1| RDM|     8.830749846821778|
|    5| ONT|    18.895178691342874|
|    7| OMA|    2.1213203435596424|
|    8| MDW|    14.467659032985843|
+-----+----+----------------------+
only showing top 5 rows



## 2.6 Joining: JOIN - join()

Joining Spark SQL dataframes is very similar to joining in SQL: we combine tables by matching the values in their key columns.

```python
df_joined = df_left.join(other=df_right,
                         on="key_col",
                         how="left_outer")

# Possible values for how:
#    "inner" (default), "outer", "left_outer", "right_outer", "leftsemi", "leftanti", and "cross"
# Notes:
# - join() takes only other, on and how; unlike pandas' merge(), it has no suffixes argument,
#    so overlapping non-key columns should be renamed beforehand (e.g., with withColumnRenamed())
# - To broadcast the smaller DataFrame to all nodes in the cluster and speed up the join,
#    wrap it with pyspark.sql.functions.broadcast(): df_left.join(broadcast(df_right), ...)

```

In [22]:
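# Examine the data
# Note: show() prints the table itself and returns None, so wrapping it in print() also prints None
print(airports.show(5))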

+---+--------------------+----------+-----------+----+---+---+
|faa|                name|       lat|        lon| alt| tz|dst|
+---+--------------------+----------+-----------+----+---+---+
|04G|   Lansdowne Airport|41.1304722|-80.6195833|1044| -5|  A|
|06A|Moton Field Munic...|32.4605722|-85.6800278| 264| -5|  A|
|06C| Schaumburg Regional|41.9893408|-88.1012428| 801| -6|  A|
|06N|     Randall Airport| 41.431912|-74.3915611| 523| -5|  A|
|09J|Jekyll Island Air...|31.0744722|-81.4277778|  11| -4|  A|
+---+--------------------+----------+-----------+----+---+---+
only showing top 5 rows

None


In [23]:
# Rename the faa column to be dest
airports = airports.withColumnRenamed("faa", "dest")

In [24]:
# Examine the data
print(airports.show(5))

+----+--------------------+----------+-----------+----+---+---+
|dest|                name|       lat|        lon| alt| tz|dst|
+----+--------------------+----------+-----------+----+---+---+
| 04G|   Lansdowne Airport|41.1304722|-80.6195833|1044| -5|  A|
| 06A|Moton Field Munic...|32.4605722|-85.6800278| 264| -5|  A|
| 06C| Schaumburg Regional|41.9893408|-88.1012428| 801| -6|  A|
| 06N|     Randall Airport| 41.431912|-74.3915611| 523| -5|  A|
| 09J|Jekyll Island Air...|31.0744722|-81.4277778|  11| -4|  A|
+----+--------------------+----------+-----------+----+---+---+
only showing top 5 rows

None


In [25]:
# Join the DataFrames
flights_with_airports = flights.join(airports,
                                     on="dest",
                                     how="left_outer")

# Examine the new DataFrame
print(flights_with_airports.show(5))

+----+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+--------+--------+----+------+------------------+---------------+--------------------+---------+-----------+---+---+---+
|dest|year|month|day|dep_time|dep_delay|arr_time|arr_delay|carrier|tailnum|flight|origin|air_time|distance|hour|minute|      duration_hrs|flight_duration|                name|      lat|        lon|alt| tz|dst|
+----+----+-----+---+--------+---------+--------+---------+-------+-------+------+------+--------+--------+----+------+------------------+---------------+--------------------+---------+-----------+---+---+---+
| LAX|2014|   12|  8|     658|     -7.0|     935|       -5|     VX| N846VA|  1780|   SEA|   132.0|     954|   6|    58|               2.2|          132.0|    Los Angeles Intl|33.942536|-118.408075|126| -8|  A|
| HNL|2014|    1| 22|    1040|      5.0|    1505|        5|     AS| N559AS|   851|   SEA|   360.0|    2677|  10|    40|               6.0|          360.0|      