# Spark Dataframes

- look like pandas dataframes
- share some of the same methods and syntax
- but they are two separate types of objects

**Create Spark Session**

`pyspark.sql.SparkSession.builder.getOrCreate()`

In [1]:
import pyspark

# method to create spark session
spark = pyspark.sql.SparkSession.builder.getOrCreate()

## Create Dataframes

### From a Pandas Dataframe

First, we create the pandas dataframe. As a reminder, there are multiple ways to create a pandas dataframe. 

1. From a **dictionary-like object**, where we provide the values by columns. 

In [2]:
import pandas as pd
import numpy as np

# Create pandas dataframe by columns using dictionary-like object

pd_df = pd.DataFrame({'col1': ['r1c1', 'r2c1', 'r3c1'],
                      'col2': ['r1c2', 'r2c2', 'r3c2'],
                      'col3': ['r1c3', 'r2c3', 'r3c3']
                      },
                     index = [1, 2, 3])
pd_df

|   | col1 | col2 | col3 |
|---|---|---|---|
| 1 | r1c1 | r1c2 | r1c3 |
| 2 | r2c1 | r2c2 | r2c3 |
| 3 | r3c1 | r3c2 | r3c3 |


2. From an **array-like object**, where we provide values by rows. 

In [3]:
# create pandas dataframe by rows
pd_df = pd.DataFrame([['r1c1', 'r1c2', 'r1c3'], 
                      ['r2c1', 'r2c2', 'r2c3'], 
                      ['r3c1', 'r3c2', 'r3c3']
                      ], 
                     index = [1, 2, 3], 
                     columns = ['col1', 'col2', 'col3'])

pd_df

|   | col1 | col2 | col3 |
|---|---|---|---|
| 1 | r1c1 | r1c2 | r1c3 |
| 2 | r2c1 | r2c2 | r2c3 |
| 3 | r3c1 | r3c2 | r3c3 |


Next, **create the Spark dataframe** from the pandas dataframe using `spark.createDataFrame()`. `spark` refers to the session we created at the beginning. 

In [4]:
sp_df = spark.createDataFrame(pd_df)
sp_df

DataFrame[col1: string, col2: string, col3: string]

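**View the spark dataframe**

- Spark is lazy: displaying `sp_df` above only printed its schema, not its data.
- `.show()` computes and prints the rows (the first 20 by default; pass a number, e.g. `.show(5)`, to change this).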

In [5]:
sp_df.show()

+----+----+----+
|col1|col2|col3|
+----+----+----+
|r1c1|r1c2|r1c3|
|r2c1|r2c2|r2c3|
|r3c1|r3c2|r3c3|
+----+----+----+



### Read from Files

Comparing pandas and spark

In [6]:
pd_v_spark = pd.DataFrame([['pd.read_csv("myfile.csv")', 
                            'spark.read.load("myfile.csv", format = "csv", sep = ",")'], 
                           ['pd.read_json("myfile.json")', 
                            'spark.read.load("myfile.json", format = "json") OR spark.read.json("myfile.json")']], 
                          index = ['csv', 'json'], 
                          columns = ['pandas', 'spark'])

# to display and see all text in dataframe
pd.set_option('display.max_colwidth', 10000)


pd_v_spark

|   | pandas | spark |
|---|---|---|
| csv | `pd.read_csv("myfile.csv")` | `spark.read.load("myfile.csv", format = "csv", sep = ",")` |
| json | `pd.read_json("myfile.json")` | `spark.read.load("myfile.json", format = "json")` OR `spark.read.json("myfile.json")` |


### Summarize dataframes

Comparing pandas and spark

In [7]:
# DataFrame.append was removed in pandas 2.0, so pd.concat is used to add rows
pd_v_spark = pd.concat([pd_v_spark,
                        pd.DataFrame([['pd_df.head()', 'sp_df.show(), .head(), .take()'],
                                      ['pd_df.head(1)', 'sp_df.first()'],
                                      ['pd_df.describe()', 'sp_df.describe()'],
                                      ['pd_df.columns', 'sp_df.columns'],
                                      ['len(pd_df)', 'sp_df.count()'],
                                      ['len(pd_df.drop_duplicates())', 'sp_df.distinct().count()'],
                                      ['pd_df.info()', 'sp_df.printSchema()']
                                     ],
                                     index = ['1st n rows', '1st row', 'summary statistics',
                                              'column names', '# rows', '# distinct rows',
                                              'df schema info'],
                                     columns = ['pandas', 'spark'])])

In [8]:
pd_v_spark

|   | pandas | spark |
|---|---|---|
| csv | `pd.read_csv("myfile.csv")` | `spark.read.load("myfile.csv", format = "csv", sep = ",")` |
| json | `pd.read_json("myfile.json")` | `spark.read.load("myfile.json", format = "json")` OR `spark.read.json("myfile.json")` |
| 1st n rows | `pd_df.head()` | `sp_df.show(), .head(), .take()` |
| 1st row | `pd_df.head(1)` | `sp_df.first()` |
| summary statistics | `pd_df.describe()` | `sp_df.describe()` |
| column names | `pd_df.columns` | `sp_df.columns` |
| # rows | `len(pd_df)` | `sp_df.count()` |
| # distinct rows | `len(pd_df.drop_duplicates())` | `sp_df.distinct().count()` |
| df schema info | `pd_df.info()` | `sp_df.printSchema()` |


Let's use a dataset with more realistic looking data to explore...

In [9]:
from pydataset import data

mpg_pd = data("mpg")
mpg_pd.head(5)

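|   | manufacturer | model | displ | year | cyl | trans | drv | cty | hwy | fl | class |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 1 | audi | a4 | 1.8 | 1999 | 4 | auto(l5) | f | 18 | 29 | p | compact |
| 2 | audi | a4 | 1.8 | 1999 | 4 | manual(m5) | f | 21 | 29 | p | compact |
| 3 | audi | a4 | 2.0 | 2008 | 4 | manual(m6) | f | 20 | 31 | p | compact |
| 4 | audi | a4 | 2.0 | 2008 | 4 | auto(av) | f | 21 | 30 | p | compact |
| 5 | audi | a4 | 2.8 | 1999 | 6 | auto(l5) | f | 16 | 26 | p | compact |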


Convert the pandas dataframe to a Spark dataframe:

In [12]:
mpg = spark.createDataFrame(data("mpg"))
mpg.show(5)

+------------+-----+-----+----+---+----------+---+---+---+---+-------+
|manufacturer|model|displ|year|cyl|     trans|drv|cty|hwy| fl|  class|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
|        audi|   a4|  1.8|1999|  4|  auto(l5)|  f| 18| 29|  p|compact|
|        audi|   a4|  1.8|1999|  4|manual(m5)|  f| 21| 29|  p|compact|
|        audi|   a4|  2.0|2008|  4|manual(m6)|  f| 20| 31|  p|compact|
|        audi|   a4|  2.0|2008|  4|  auto(av)|  f| 21| 30|  p|compact|
|        audi|   a4|  2.8|1999|  6|  auto(l5)|  f| 16| 26|  p|compact|
+------------+-----+-----+----+---+----------+---+---+---+---+-------+
only showing top 5 rows



## Columns

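Pandas series vs. Spark column objects

A pandas Series holds the actual values of a column. A Spark Column object also represents a vertical slice of a dataframe, but it does not contain the data itself: it is an expression that refers to a column (or to a computation on columns). 

You use Column objects to reference columns and to describe functions applied to them; the values are only computed when the expression is used in a dataframe operation such as `.select` and an action such as `.show()` runs. 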

In [13]:
# pandas series
mpg_pd.model

1          a4
2          a4
3          a4
4          a4
5          a4
        ...  
230    passat
231    passat
232    passat
233    passat
234    passat
Name: model, Length: 234, dtype: object

In [14]:
# spark column object
mpg.model

Column<b'model'>

**Select columns**

Comparing pandas and spark

In [15]:
pd_v_spark = pd.concat([pd_v_spark,
                        pd.DataFrame([['pd_df[["col1", "col2"]]',
                                       'sp_df.select(sp_df.col1, sp_df.col2)']
                                     ],
                                     index = ['select columns'],
                                     columns = ['pandas', 'spark'])])
pd_v_spark

|   | pandas | spark |
|---|---|---|
| csv | `pd.read_csv("myfile.csv")` | `spark.read.load("myfile.csv", format = "csv", sep = ",")` |
| json | `pd.read_json("myfile.json")` | `spark.read.load("myfile.json", format = "json")` OR `spark.read.json("myfile.json")` |
| 1st n rows | `pd_df.head()` | `sp_df.show(), .head(), .take()` |
| 1st row | `pd_df.head(1)` | `sp_df.first()` |
| summary statistics | `pd_df.describe()` | `sp_df.describe()` |
| column names | `pd_df.columns` | `sp_df.columns` |
| # rows | `len(pd_df)` | `sp_df.count()` |
| # distinct rows | `len(pd_df.drop_duplicates())` | `sp_df.distinct().count()` |
| df schema info | `pd_df.info()` | `sp_df.printSchema()` |
| select columns | `pd_df[["col1", "col2"]]` | `sp_df.select(sp_df.col1, sp_df.col2)` |


Select columns hwy, cty, and model

In [16]:
mpg.select(mpg.hwy, mpg.cty, mpg.model)

DataFrame[hwy: bigint, cty: bigint, model: string]

Because Spark is lazy, `.select` only returns a new dataframe (shown above as its schema). To see the data, call `.show()`:

In [17]:
mpg.select(mpg.hwy, mpg.cty, mpg.model).show(5)

+---+---+-----+
|hwy|cty|model|
+---+---+-----+
| 29| 18|   a4|
| 29| 21|   a4|
| 31| 20|   a4|
| 30| 21|   a4|
| 26| 16|   a4|
+---+---+-----+
only showing top 5 rows



Column objects support arithmetic operators; the result is another Column object.

In [18]:
mpg.hwy + 1

Column<b'(hwy + 1)'>

In [19]:
mpg.select(mpg.hwy, mpg.hwy+1).show(3)

+---+---------+
|hwy|(hwy + 1)|
+---+---------+
| 29|       30|
| 29|       30|
| 31|       32|
+---+---------+
only showing top 3 rows



Once we have a column object, we can use the .alias method to rename it. 

In [20]:
mpg.select(mpg.hwy.alias('highway_mileage'), 
           (mpg.hwy+1).alias('highway_mileage_plus1')).show(3)

+---------------+---------------------+
|highway_mileage|highway_mileage_plus1|
+---------------+---------------------+
|             29|                   30|
|             29|                   30|
|             31|                   32|
+---------------+---------------------+
only showing top 3 rows



We can also store column objects in variables and reference them

In [21]:
col1 = mpg.hwy.alias('highway_mileage')
col2 = (mpg.hwy/2).alias('highway_mileage_halved')

mpg.select(col1, col2).show(3)

+---------------+----------------------+
|highway_mileage|highway_mileage_halved|
+---------------+----------------------+
|             29|                  14.5|
|             29|                  14.5|
|             31|                  15.5|
+---------------+----------------------+
only showing top 3 rows



In addition to the syntax we've seen above, we can create columns with the `col` and `expr` functions from the `pyspark.sql.functions` module.

**col**

In [24]:
from pyspark.sql.functions import col, expr

col("hwy")

Column<b'hwy'>

In [25]:
mpg.hwy

Column<b'hwy'>

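The Column object produced by the `col` function is equivalent to the `mpg.hwy` Column object we saw before. The difference is that `col("hwy")` is not tied to a particular dataframe; the name is resolved when the column is used.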

We can create `avg_mileage` using the col function to produce pyspark Column objects and using the arithmetic operators to combine them.

In [26]:
avg_col = (col("hwy") + col("cty")) / 2
avg_col

Column<b'((hwy + cty) / 2)'>

In [27]:
mpg.select(
    col("hwy").alias("hwy_mileage"), 
    mpg.cty.alias("cty_mileage"),
    avg_col.alias("avg_mileage")
).show(3)

+-----------+-----------+-----------+
|hwy_mileage|cty_mileage|avg_mileage|
+-----------+-----------+-----------+
|         29|         18|       23.5|
|         29|         21|       25.0|
|         31|         20|       25.5|
+-----------+-----------+-----------+
only showing top 3 rows



**expr**

- Does everything col does and more
- Returns the same type of column object
- But also allows us to express manipulations to the column within the string that defines the column.
- Which syntax to use is merely a style choice.

In [28]:
mpg.select(
    expr("hwy"), # the same as `col`
    expr("hwy + 1"), # arithmetic expression col("hwy") + 1
    expr("hwy AS hwy_mileage"), # using alias col("hwy").alias("hwy_mileage")
    expr("hwy + 1 AS hwy_incremented"), # a combo of the 2 above. 
).show(3)

+---+---------+-----------+---------------+
|hwy|(hwy + 1)|hwy_mileage|hwy_incremented|
+---+---------+-----------+---------------+
| 29|       30|         29|             30|
| 29|       30|         29|             30|
| 31|       32|         31|             32|
+---+---------+-----------+---------------+
only showing top 3 rows



In [29]:
mpg.select(
    mpg.hwy.alias("highway"),
    col("hwy").alias("highway"),
    expr("hwy").alias("highway"),
    expr("hwy AS highway")
).show(3)

+-------+-------+-------+-------+
|highway|highway|highway|highway|
+-------+-------+-------+-------+
|     29|     29|     29|     29|
|     29|     29|     29|     29|
|     31|     31|     31|     31|
+-------+-------+-------+-------+
only showing top 3 rows



## Spark SQL

- Spark SQL allows us to write SQL queries against our spark dataframes.  
- We'll first "register" the dataframe as a temporary view with `sp_df.createOrReplaceTempView('view_name')`. Below, `mpg` is registered as `mpg_view`.  

In [30]:
mpg.createOrReplaceTempView('mpg_view')

- Now we can write a SQL query against the `mpg_view` view.  

In [31]:
spark.sql(
    """
    SELECT hwy, cty, (hwy + cty)/2 AS avg
    FROM mpg_view
    """
).show(3)

+---+---+----+
|hwy|cty| avg|
+---+---+----+
| 29| 18|23.5|
| 29| 21|25.0|
| 31| 20|25.5|
+---+---+----+
only showing top 3 rows



- The resulting value is another dataframe. 
- To see the values, we have to call an action such as `.show()`, as in the cell above.

**Note:** All of these ways of creating and manipulating dataframes (Column objects, `expr` strings, and SQL queries) generally perform the same. Spark's optimizer compiles equivalent queries written in any of these styles into the same execution plan that runs on the JVM, so which one to use really is just a style choice.

## Type Casting

**View column datatypes** using `dtypes` or `printSchema()`

In [32]:
mpg.dtypes

[('manufacturer', 'string'),
 ('model', 'string'),
 ('displ', 'double'),
 ('year', 'bigint'),
 ('cyl', 'bigint'),
 ('trans', 'string'),
 ('drv', 'string'),
 ('cty', 'bigint'),
 ('hwy', 'bigint'),
 ('fl', 'string'),
 ('class', 'string')]

In [33]:
mpg.printSchema()

root
 |-- manufacturer: string (nullable = true)
 |-- model: string (nullable = true)
 |-- displ: double (nullable = true)
 |-- year: long (nullable = true)
 |-- cyl: long (nullable = true)
 |-- trans: string (nullable = true)
 |-- drv: string (nullable = true)
 |-- cty: long (nullable = true)
 |-- hwy: long (nullable = true)
 |-- fl: string (nullable = true)
 |-- class: string (nullable = true)



To **convert** from one type to another use the `.cast` method on a column.

In [34]:
mpg.select(mpg.hwy.cast("string")).printSchema()

root
 |-- hwy: string (nullable = true)



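If a value cannot be converted, it is replaced with null. This is the behavior with ANSI mode off (`spark.sql.ansi.enabled=false`, the default in Spark 3.x); with ANSI mode on, which is the default starting with Spark 4.0, an invalid cast raises an error instead. 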

In [35]:
mpg.select(mpg.model, mpg.model.cast("int")).show(3)

+-----+-----+
|model|model|
+-----+-----+
|   a4| null|
|   a4| null|
|   a4| null|
+-----+-----+
only showing top 3 rows



## Basic Built-in Functions

There are many other functions beyond `col` and `expr` within the `pyspark.sql.functions` module for operating on pyspark dataframe columns.

- `concat`: to concatenate strings  
- `sum`: to sum a group  
- `avg`: to take the average of a group  
- `min`: to find the minimum  
- `max`: to find the maximum  

**Note**: importing the `sum`, `min`, and `max` functions (and `round`) directly will shadow Python's built-in functions of the same name. For example, `sum([1, 2, 3])` will then raise an error, because `sum` now refers to the pyspark function, which works with pyspark dataframe columns, instead of the built-in function, which works with lists of numbers.

In [38]:
# Note: The pyspark avg and mean functions are aliases of each other

from pyspark.sql.functions import round, concat, sum, min, max, count, avg, mean

It is very common to see something like:  

`import pyspark.sql.functions as F`

which imports the whole `pyspark.sql.functions` module under the short name `F`, so the functions are used as `F.sum`, `F.col`, and so on without shadowing Python's built-ins.

Try out some functions: 

In [39]:
mpg.select(
    (sum(mpg.hwy) / count(mpg.hwy)).alias("avg1"),
    round(avg(mpg.hwy), 2).alias("avg2"),
    mean(mpg.hwy).alias("avg3"),
    min(mpg.hwy),
    max(mpg.hwy)
).show(3)

+-----------------+-----+-----------------+--------+--------+
|             avg1| avg2|             avg3|min(hwy)|max(hwy)|
+-----------------+-----+-----------------+--------+--------+
|23.44017094017094|23.44|23.44017094017094|      12|      44|
+-----------------+-----+-----------------+--------+--------+



In [40]:
mpg.select(concat(mpg.manufacturer, mpg.model)).show(3)

+---------------------------+
|concat(manufacturer, model)|
+---------------------------+
|                     audia4|
|                     audia4|
|                     audia4|
+---------------------------+
only showing top 3 rows



In order to use a string literal as part of our select, we'll need to use the `lit` function, otherwise spark will try to resolve our string as a column.

In [41]:
from pyspark.sql.functions import lit
mpg.select(concat(mpg.cyl, lit(" cylinders")).alias("cylinders")).show(3)

+-----------+
|  cylinders|
+-----------+
|4 cylinders|
|4 cylinders|
|4 cylinders|
+-----------+
only showing top 3 rows



## String Manipulation PySpark Functions

In order to demonstrate these functions we'll create a dataframe with some text data.

In [42]:
from pyspark.sql.functions import regexp_extract, regexp_replace

In [43]:
textdf = spark.createDataFrame(
    pd.DataFrame(
        {
            "address": [
                "600 Navarro St ste 600, San Antonio, TX 78205",
                "3130 Broadway St, San Antonio, TX 78209",
                "303 Pearl Pkwy, San Antonio, TX 78215",
                "1255 SW Loop 410, San Antonio, TX 78227",
            ]
        }
    )
)

textdf.show(truncate=False)

+---------------------------------------------+
|address                                      |
+---------------------------------------------+
|600 Navarro St ste 600, San Antonio, TX 78205|
|3130 Broadway St, San Antonio, TX 78209      |
|303 Pearl Pkwy, San Antonio, TX 78215        |
|1255 SW Loop 410, San Antonio, TX 78227      |
+---------------------------------------------+



`regexp_extract`: specify a regular expression with at least one capture group, and create a new column based on the contents of a capture group.


- first argument: the name of the string column to extract from.  
- second argument: the regular expression itself.  
- last argument: specifies which capture group we want to use. If, for example, our regular expression had 2 capture groups in it and we wanted the contents of the 2nd group, we would specify a 2 here.


In [46]:
textdf.select("address", 
              regexp_extract("address", 
                             r"^(\d+)", 1).alias("street_no"),
              regexp_extract("address", 
                             r"^\d+\s([\w\s]+?),", 1).alias("street")
             ).show(truncate=False)

+---------------------------------------------+---------+------------------+
|address                                      |street_no|street            |
+---------------------------------------------+---------+------------------+
|600 Navarro St ste 600, San Antonio, TX 78205|600      |Navarro St ste 600|
|3130 Broadway St, San Antonio, TX 78209      |3130     |Broadway St       |
|303 Pearl Pkwy, San Antonio, TX 78215        |303      |Pearl Pkwy        |
|1255 SW Loop 410, San Antonio, TX 78227      |1255     |SW Loop 410       |
+---------------------------------------------+---------+------------------+
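For comparison with pandas, the same capture groups can be extracted with `Series.str.extract`. A sketch using the four addresses above copied into a plain pandas Series. One difference to keep in mind: when a pattern does not match, pandas returns NaN, whereas Spark's `regexp_extract` returns an empty string. (Spark uses Java regular expressions and pandas uses Python's `re`; these simple patterns behave the same in both.)

```python
import pandas as pd

addresses = pd.Series([
    "600 Navarro St ste 600, San Antonio, TX 78205",
    "3130 Broadway St, San Antonio, TX 78209",
    "303 Pearl Pkwy, San Antonio, TX 78215",
    "1255 SW Loop 410, San Antonio, TX 78227",
])

# Same patterns as the Spark example; expand=False returns a Series for one group
street_no = addresses.str.extract(r"^(\d+)", expand=False)
street = addresses.str.extract(r"^\d+\s([\w\s]+?),", expand=False)

print(street_no.tolist())  # ['600', '3130', '303', '1255']
print(street.tolist())     # ['Navarro St ste 600', 'Broadway St', 'Pearl Pkwy', 'SW Loop 410']
```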



`regexp_replace` lets us make substitutions based on a regular expression.

Below, we obtain just the city, state, and zip code of the address by replacing everything up to the first comma with an empty string.

In [47]:
textdf.select("address",
             regexp_replace("address", 
                           r"^.*?,\s*", "").alias("city_state_zip")
             ).show(truncate=False)

+---------------------------------------------+---------------------+
|address                                      |city_state_zip       |
+---------------------------------------------+---------------------+
|600 Navarro St ste 600, San Antonio, TX 78205|San Antonio, TX 78205|
|3130 Broadway St, San Antonio, TX 78209      |San Antonio, TX 78209|
|303 Pearl Pkwy, San Antonio, TX 78215        |San Antonio, TX 78215|
|1255 SW Loop 410, San Antonio, TX 78227      |San Antonio, TX 78227|
+---------------------------------------------+---------------------+
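For comparison with pandas, a sketch of the same substitution using `Series.str.replace` with `regex=True`, applied to two of the addresses above copied into a standalone Series:

```python
import pandas as pd

addresses = pd.Series([
    "600 Navarro St ste 600, San Antonio, TX 78205",
    "303 Pearl Pkwy, San Antonio, TX 78215",
])

# regex=True makes str.replace treat the pattern as a regular expression;
# the lazy .*? stops at the first comma, like in the Spark version
city_state_zip = addresses.str.replace(r"^.*?,\s*", "", regex=True)
print(city_state_zip.tolist())  # ['San Antonio, TX 78205', 'San Antonio, TX 78215']
```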



## Conditional Subsetting and Filtering of Dataframes

`.filter` and `.where` both allow us to select a subset of the rows of our dataframe.

In [48]:
pd_v_spark = pd.concat([pd_v_spark,
                        pd.DataFrame([['pd_df[pd_df.c1 > 0]', 'sp_df.filter(sp_df.c1 > 0), sp_df.where(sp_df.c1 > 0)'],
                                     ],
                                     index = ['conditional filtering'],
                                     columns = ['pandas', 'spark'])])
pd_v_spark

|   | pandas | spark |
|---|---|---|
| csv | `pd.read_csv("myfile.csv")` | `spark.read.load("myfile.csv", format = "csv", sep = ",")` |
| json | `pd.read_json("myfile.json")` | `spark.read.load("myfile.json", format = "json")` OR `spark.read.json("myfile.json")` |
| 1st n rows | `pd_df.head()` | `sp_df.show(), .head(), .take()` |
| 1st row | `pd_df.head(1)` | `sp_df.first()` |
| summary statistics | `pd_df.describe()` | `sp_df.describe()` |
| column names | `pd_df.columns` | `sp_df.columns` |
| # rows | `len(pd_df)` | `sp_df.count()` |
| # distinct rows | `len(pd_df.drop_duplicates())` | `sp_df.distinct().count()` |
| df schema info | `pd_df.info()` | `sp_df.printSchema()` |
| select columns | `pd_df[["col1", "col2"]]` | `sp_df.select(sp_df.col1, sp_df.col2)` |
| conditional filtering | `pd_df[pd_df.c1 > 0]` | `sp_df.filter(sp_df.c1 > 0), sp_df.where(sp_df.c1 > 0)` |


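Use `filter()` and `where()`. Note the `mpg["class"]` syntax: `class` is a reserved word in Python, so `mpg.class` would be a syntax error.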

In [49]:
mpg.filter(mpg.cyl==4).where(mpg["class"]=="subcompact").show(3)

+------------+-----+-----+----+---+----------+---+---+---+---+----------+
|manufacturer|model|displ|year|cyl|     trans|drv|cty|hwy| fl|     class|
+------------+-----+-----+----+---+----------+---+---+---+---+----------+
|       honda|civic|  1.6|1999|  4|manual(m5)|  f| 28| 33|  r|subcompact|
|       honda|civic|  1.6|1999|  4|  auto(l4)|  f| 24| 32|  r|subcompact|
|       honda|civic|  1.6|1999|  4|manual(m5)|  f| 25| 32|  r|subcompact|
+------------+-----+-----+----+---+----------+---+---+---+---+----------+
only showing top 3 rows



## Conditional Assigning of Values

Spark's `when` plays the same role as `IF` in Excel, `CASE ... WHEN` in SQL, and `numpy.where` in Python.

- Specify a condition, and a value to produce if that condition is true

In [50]:
pd_v_spark = pd.concat([pd_v_spark,
                        pd.DataFrame([['np.where(pd_df.c1 > 0, "positive", None)',
                                       'sp_df.select(sp_df.c1, when(sp_df.c1 > 0, "positive").alias("number_sign"))'],
                                     ],
                                     index = ['conditional assigning'],
                                     columns = ['pandas', 'spark'])])
pd_v_spark

|   | pandas | spark |
|---|---|---|
| csv | `pd.read_csv("myfile.csv")` | `spark.read.load("myfile.csv", format = "csv", sep = ",")` |
| json | `pd.read_json("myfile.json")` | `spark.read.load("myfile.json", format = "json")` OR `spark.read.json("myfile.json")` |
| 1st n rows | `pd_df.head()` | `sp_df.show(), .head(), .take()` |
| 1st row | `pd_df.head(1)` | `sp_df.first()` |
| summary statistics | `pd_df.describe()` | `sp_df.describe()` |
| column names | `pd_df.columns` | `sp_df.columns` |
| # rows | `len(pd_df)` | `sp_df.count()` |
| # distinct rows | `len(pd_df.drop_duplicates())` | `sp_df.distinct().count()` |
| df schema info | `pd_df.info()` | `sp_df.printSchema()` |
| select columns | `pd_df[["col1", "col2"]]` | `sp_df.select(sp_df.col1, sp_df.col2)` |
| conditional filtering | `pd_df[pd_df.c1 > 0]` | `sp_df.filter(sp_df.c1 > 0), sp_df.where(sp_df.c1 > 0)` |
| conditional assigning | `np.where(pd_df.c1 > 0, "positive", None)` | `sp_df.select(sp_df.c1, when(sp_df.c1 > 0, "positive").alias("number_sign"))` |


In [51]:
from pyspark.sql.functions import when

- If the condition we specified is false, null will be produced.   
- Use the `.otherwise` method to specify a value to use if our condition is false  

In [52]:
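pd_v_spark = pd.concat([pd_v_spark,
                        pd.DataFrame([['np.where(pd_df.c1 > 0, "pos", "neg")',
                                       'sp_df.select(sp_df.c1, when(sp_df.c1 > 0, "pos").otherwise("neg").alias("number_sign"))'],
                                     ],
                                     index = ['conditional assigning with else'],
                                     columns = ['pandas', 'spark'])])
pd_v_spark

|   | pandas | spark |
|---|---|---|
| csv | `pd.read_csv("myfile.csv")` | `spark.read.load("myfile.csv", format = "csv", sep = ",")` |
| json | `pd.read_json("myfile.json")` | `spark.read.load("myfile.json", format = "json")` OR `spark.read.json("myfile.json")` |
| 1st n rows | `pd_df.head()` | `sp_df.show(), .head(), .take()` |
| 1st row | `pd_df.head(1)` | `sp_df.first()` |
| summary statistics | `pd_df.describe()` | `sp_df.describe()` |
| column names | `pd_df.columns` | `sp_df.columns` |
| # rows | `len(pd_df)` | `sp_df.count()` |
| # distinct rows | `len(pd_df.drop_duplicates())` | `sp_df.distinct().count()` |
| df schema info | `pd_df.info()` | `sp_df.printSchema()` |
| select columns | `pd_df[["col1", "col2"]]` | `sp_df.select(sp_df.col1, sp_df.col2)` |
| conditional filtering | `pd_df[pd_df.c1 > 0]` | `sp_df.filter(sp_df.c1 > 0), sp_df.where(sp_df.c1 > 0)` |
| conditional assigning | `np.where(pd_df.c1 > 0, "positive", None)` | `sp_df.select(sp_df.c1, when(sp_df.c1 > 0, "positive").alias("number_sign"))` |
| conditional assigning with else | `np.where(pd_df.c1 > 0, "pos", "neg")` | `sp_df.select(sp_df.c1, when(sp_df.c1 > 0, "pos").otherwise("neg").alias("number_sign"))` |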


Use `.when()` with `.otherwise()`

In [54]:
mpg.select(mpg.hwy,
          when(mpg.hwy > 25, "good_mileage")
          .otherwise("bad_mileage")
          .alias("mpg_class")).show(10)

+---+------------+
|hwy|   mpg_class|
+---+------------+
| 29|good_mileage|
| 29|good_mileage|
| 31|good_mileage|
| 30|good_mileage|
| 26|good_mileage|
| 26|good_mileage|
| 27|good_mileage|
| 26|good_mileage|
| 25| bad_mileage|
| 28|good_mileage|
+---+------------+
only showing top 10 rows



In [56]:
mpg.select(
    mpg.displ,
    (when(mpg.displ < 2, "small")
     .when(mpg.displ < 3, "medium")
     .otherwise("large")
     .alias("engine_size")
    )).show(10)

+-----+-----------+
|displ|engine_size|
+-----+-----------+
|  1.8|      small|
|  1.8|      small|
|  2.0|     medium|
|  2.0|     medium|
|  2.8|     medium|
|  2.8|     medium|
|  3.1|      large|
|  1.8|      small|
|  1.8|      small|
|  2.0|     medium|
+-----+-----------+
only showing top 10 rows



- To specify multiple conditions, we can chain `.when` calls.   
- The value of the first condition that is met is the one used.  
- If none of the conditions are met, the value specified in `.otherwise` will be used (or null if you don't provide an `.otherwise`).  

Notice here that a car with a displ of 1.8 matches both conditions we specified, but small is produced because it is associated with the first matching condition. Any value from 2 up to (but not including) 3 produces medium, and anything 3 or larger produces large.
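For comparison with pandas/NumPy, chained `when` calls map most closely to `np.select`, which also picks the value of the first condition that is true and falls back to a default (the counterpart of `.otherwise`). A sketch using a few made-up `displ` values:

```python
import numpy as np
import pandas as pd

displ = pd.Series([1.8, 2.0, 2.8, 3.1])

# Conditions are checked in order; the first True one wins
engine_size = np.select(
    [displ < 2, displ < 3],
    ["small", "medium"],
    default="large",
)
print(engine_size.tolist())  # ['small', 'medium', 'medium', 'large']
```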

## Sorting and Ordering

- Sort the rows by one or more columns with two methods: `.sort` and `.orderBy`. 
- `.sort` and `.orderBy` are aliases of each other and do the exact same thing. 
- Both take Column objects or strings that are column names.
- By default, values are sorted in ascending order.    

In [57]:
mpg.sort(mpg.hwy).show(8)

+------------+-------------------+-----+----+---+----------+---+---+---+---+------+
|manufacturer|              model|displ|year|cyl|     trans|drv|cty|hwy| fl| class|
+------------+-------------------+-----+----+---+----------+---+---+---+---+------+
|       dodge|        durango 4wd|  4.7|2008|  8|  auto(l5)|  4|  9| 12|  e|   suv|
|       dodge|ram 1500 pickup 4wd|  4.7|2008|  8|  auto(l5)|  4|  9| 12|  e|pickup|
|       dodge|ram 1500 pickup 4wd|  4.7|2008|  8|manual(m6)|  4|  9| 12|  e|pickup|
|        jeep| grand cherokee 4wd|  4.7|2008|  8|  auto(l5)|  4|  9| 12|  e|   suv|
|       dodge|  dakota pickup 4wd|  4.7|2008|  8|  auto(l5)|  4|  9| 12|  e|pickup|
|   chevrolet|    k1500 tahoe 4wd|  5.3|2008|  8|  auto(l4)|  4| 11| 14|  e|   suv|
|        jeep| grand cherokee 4wd|  6.1|2008|  8|  auto(l5)|  4| 11| 14|  p|   suv|
|  land rover|        range rover|  4.0|1999|  8|  auto(l4)|  4| 11| 15|  p|   suv|
+------------+-------------------+-----+----+---+----------+---+---+---+---+------+
only showing top 8 rows

- To sort in descending order, we can use the `.desc` method on any Column object, or the `desc` function from `pyspark.sql.functions`

In [58]:
from pyspark.sql.functions import asc, desc

mpg.sort(mpg.hwy.desc())
# is the same as...
mpg.sort(col("hwy").desc())
# is the same as ..
mpg.sort(desc("hwy")).show(5)

+------------+----------+-----+----+---+----------+---+---+---+---+----------+
|manufacturer|     model|displ|year|cyl|     trans|drv|cty|hwy| fl|     class|
+------------+----------+-----+----+---+----------+---+---+---+---+----------+
|  volkswagen|new beetle|  1.9|1999|  4|manual(m5)|  f| 35| 44|  d|subcompact|
|  volkswagen|     jetta|  1.9|1999|  4|manual(m5)|  f| 33| 44|  d|   compact|
|  volkswagen|new beetle|  1.9|1999|  4|  auto(l4)|  f| 29| 41|  d|subcompact|
|      toyota|   corolla|  1.8|2008|  4|manual(m5)|  f| 28| 37|  r|   compact|
|       honda|     civic|  1.8|2008|  4|  auto(l5)|  f| 24| 36|  c|subcompact|
+------------+----------+-----+----+---+----------+---+---+---+---+----------+
only showing top 5 rows



- To specify sorting by multiple columns, we provide each column as a separate argument to `.sort`.  

In the example below: 

1. Reverse alphabetically by the vehicle's class   
2. By the number of cylinders from lowest to highest  
3. By the vehicle's highway mileage, from greatest to smallest  

In [64]:
mpg.select(mpg['class'], 
           mpg.cyl, 
           mpg.hwy).sort(desc("class"), mpg.cyl.asc(), col("hwy").desc()).show()

+-----+---+---+
|class|cyl|hwy|
+-----+---+---+
|  suv|  4| 27|
|  suv|  4| 26|
|  suv|  4| 25|
|  suv|  4| 25|
|  suv|  4| 24|
|  suv|  4| 23|
|  suv|  4| 20|
|  suv|  4| 20|
|  suv|  6| 22|
|  suv|  6| 20|
|  suv|  6| 20|
|  suv|  6| 20|
|  suv|  6| 19|
|  suv|  6| 19|
|  suv|  6| 19|
|  suv|  6| 19|
|  suv|  6| 19|
|  suv|  6| 17|
|  suv|  6| 17|
|  suv|  6| 17|
+-----+---+---+
only showing top 20 rows



In [65]:
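pd_v_spark = pd.concat([pd_v_spark,
                        pd.DataFrame([['pd_df.sort_values(by=["c1"])',
                                       'sp_df.sort(sp_df.c1)'],
                                      ['pd_df.sort_values(by=["c1","c2"])',
                                       'sp_df.sort(sp_df.c1, sp_df.c2)'],
                                      ['pd_df.sort_values(by=["c1","c2"], ascending=[False, True])',
                                       'sp_df.sort(sp_df.c1.desc(), sp_df.c2)'],
                                      ['pd_df.sort_values(by=["c1","c2"], ascending=False)',
                                       'sp_df.sort(desc("c1"), desc("c2")) OR sp_df.sort(col("c1").desc(), col("c2").desc())']
                                     ],
                                     index = ['sort 1 col asc', 'sort 2+ cols asc', 'sort 2+ cols desc/asc', 'sort 2+ cols desc'],
                                     columns = ['pandas', 'spark'])])
pd_v_spark

|   | pandas | spark |
|---|---|---|
| csv | `pd.read_csv("myfile.csv")` | `spark.read.load("myfile.csv", format = "csv", sep = ",")` |
| json | `pd.read_json("myfile.json")` | `spark.read.load("myfile.json", format = "json")` OR `spark.read.json("myfile.json")` |
| 1st n rows | `pd_df.head()` | `sp_df.show(), .head(), .take()` |
| 1st row | `pd_df.head(1)` | `sp_df.first()` |
| summary statistics | `pd_df.describe()` | `sp_df.describe()` |
| column names | `pd_df.columns` | `sp_df.columns` |
| # rows | `len(pd_df)` | `sp_df.count()` |
| # distinct rows | `len(pd_df.drop_duplicates())` | `sp_df.distinct().count()` |
| df schema info | `pd_df.info()` | `sp_df.printSchema()` |
| select columns | `pd_df[["col1", "col2"]]` | `sp_df.select(sp_df.col1, sp_df.col2)` |
| conditional filtering | `pd_df[pd_df.c1 > 0]` | `sp_df.filter(sp_df.c1 > 0), sp_df.where(sp_df.c1 > 0)` |
| conditional assigning | `np.where(pd_df.c1 > 0, "positive", None)` | `sp_df.select(sp_df.c1, when(sp_df.c1 > 0, "positive").alias("number_sign"))` |
| conditional assigning with else | `np.where(pd_df.c1 > 0, "pos", "neg")` | `sp_df.select(sp_df.c1, when(sp_df.c1 > 0, "pos").otherwise("neg").alias("number_sign"))` |
| sort 1 col asc | `pd_df.sort_values(by=["c1"])` | `sp_df.sort(sp_df.c1)` |
| sort 2+ cols asc | `pd_df.sort_values(by=["c1","c2"])` | `sp_df.sort(sp_df.c1, sp_df.c2)` |
| sort 2+ cols desc/asc | `pd_df.sort_values(by=["c1","c2"], ascending=[False, True])` | `sp_df.sort(sp_df.c1.desc(), sp_df.c2)` |
| sort 2+ cols desc | `pd_df.sort_values(by=["c1","c2"], ascending=False)` | `sp_df.sort(desc("c1"), desc("c2"))` OR `sp_df.sort(col("c1").desc(), col("c2").desc())` |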


## Grouping and Aggregating

- To aggregate our data by group, use the `.groupBy` method.  
- Like with .select and .sort, we can pass either Column objects or strings that are column names to .groupBy.  
- All of the expressions below are equivalent.

In [66]:
mpg.groupBy(mpg.cyl)
mpg.groupBy(col('cyl'))
mpg.groupBy('cyl')

<pyspark.sql.group.GroupedData at 0x7fcbe3a126d8>

- Once the data is grouped, specify an aggregation.    
- We can use one of the aggregate functions we imported earlier, along with a column  

In [68]:
mpg.groupBy(mpg.cyl).agg(avg(mpg.cty), avg(mpg.hwy)).show(5)

+---+------------------+-----------------+
|cyl|          avg(cty)|         avg(hwy)|
+---+------------------+-----------------+
|  6| 16.21518987341772|22.82278481012658|
|  5|              20.5|            28.75|
|  8|12.571428571428571|17.62857142857143|
|  4|21.012345679012345|28.80246913580247|
+---+------------------+-----------------+



- To group by multiple columns, pass each of the columns as a separate argument to .groupBy.   
- This is different from pandas, where we would need to pass a list.  

In [69]:
mpg.groupBy("cyl", "class").agg(avg(mpg.cty), avg(mpg.hwy)).show()

+---+----------+------------------+------------------+
|cyl|     class|          avg(cty)|          avg(hwy)|
+---+----------+------------------+------------------+
|  5|   compact|              21.0|              29.0|
|  5|subcompact|              20.0|              28.5|
|  6|subcompact|              17.0|24.714285714285715|
|  6|    pickup|              14.5|              17.9|
|  4|subcompact|22.857142857142858| 30.80952380952381|
|  8|       suv|12.131578947368421|16.789473684210527|
|  8|    pickup|              11.8|              15.8|
|  8|   midsize|              16.0|              24.0|
|  4|   midsize|              20.5|           29.1875|
|  8|   2seater|              15.4|              24.8|
|  6|   compact|16.923076923076923|25.307692307692307|
|  6|   minivan|              15.6|              22.2|
|  4|   compact|            21.375|          29.46875|
|  8|subcompact|              14.8|              21.6|
|  6|   midsize|17.782608695652176| 26.26086956521739|
|  4|   mi

- In addition to `.groupBy`, we can use `.rollup`, which performs the same aggregations but also adds a row for the overall total.  
- Below, the null value in the cyl column marks the row that holds the total count.  

In [70]:
mpg.rollup("cyl").count().sort("cyl").show()

+----+-----+
| cyl|count|
+----+-----+
|null|  234|
|   4|   81|
|   5|    4|
|   6|   79|
|   8|   70|
+----+-----+



- Use `.rollup` to compute an average for each group along with the overall average.
- The row where cyl is null holds the overall average highway mileage.

In [71]:
mpg.rollup("cyl").agg(expr("avg(hwy)")).sort("cyl").show()

# these are the same...

mpg.rollup("cyl").agg(avg(mpg.hwy)).sort("cyl").show()

+----+-----------------+
| cyl|         avg(hwy)|
+----+-----------------+
|null|23.44017094017094|
|   4|28.80246913580247|
|   5|            28.75|
|   6|22.82278481012658|
|   8|17.62857142857143|
+----+-----------------+



- You can roll up by multiple columns.  
- Where cyl is null (and class is null), you see the overall average.  
- Where cyl = n and class is null, you have the average across all classes for cars with n cylinders.  

In [72]:
mpg.rollup("cyl", "class").mean("hwy").sort(col("cyl"), col("class")).show()

+----+----------+------------------+
| cyl|     class|          avg(hwy)|
+----+----------+------------------+
|null|      null| 23.44017094017094|
|   4|      null| 28.80246913580247|
|   4|   compact|          29.46875|
|   4|   midsize|           29.1875|
|   4|   minivan|              24.0|
|   4|    pickup|20.666666666666668|
|   4|subcompact| 30.80952380952381|
|   4|       suv|             23.75|
|   5|      null|             28.75|
|   5|   compact|              29.0|
|   5|subcompact|              28.5|
|   6|      null| 22.82278481012658|
|   6|   compact|25.307692307692307|
|   6|   midsize| 26.26086956521739|
|   6|   minivan|              22.2|
|   6|    pickup|              17.9|
|   6|subcompact|24.714285714285715|
|   6|       suv|              18.5|
|   8|      null| 17.62857142857143|
|   8|   2seater|              24.8|
+----+----------+------------------+
only showing top 20 rows



## Crosstabs and Pivot Tables  

- Another way to aggregate is with `.crosstab`.    
- Like the pandas `pd.crosstab` function, it counts the number of rows for each combination of unique values from the two columns passed in.    
- `.crosstab` only does counts.  
- For a different aggregation, use `.pivot`.  

In [79]:
mpg.crosstab("class", "cyl").show()

+----------+---+---+---+---+
| class_cyl|  4|  5|  6|  8|
+----------+---+---+---+---+
|   midsize| 16|  0| 23|  2|
|subcompact| 21|  2|  7|  5|
|   2seater|  0|  0|  0|  5|
|    pickup|  3|  0| 10| 20|
|   minivan|  1|  0| 10|  0|
|       suv|  8|  0| 16| 38|
|   compact| 32|  2| 13|  0|
+----------+---+---+---+---+
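The pandas counterpart is the top-level function `pd.crosstab`, not a DataFrame method. Here is a minimal sketch on hypothetical data. Like spark's `.crosstab`, it reports 0 for combinations that never occur:

```python
import pandas as pd

# Hypothetical stand-in for the mpg data
cars = pd.DataFrame(
    {"class": ["compact", "compact", "suv", "suv", "suv"], "cyl": [4, 6, 6, 8, 8]}
)

# Rows come from the first argument, columns from the second
counts = pd.crosstab(cars["class"], cars["cyl"])
print(counts)
```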



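`.pivot` is called on grouped data. Grouping by class and pivoting on cyl with `.count()` reproduces the crosstab above, except that combinations with no rows show up as null instead of 0. Replacing `.count()` with `.mean("hwy")` gives the average highway mileage for each combination of car class and number of cylinders.  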

In [80]:
mpg.groupBy("class").pivot("cyl").count().show()

+----------+----+----+----+----+
|     class|   4|   5|   6|   8|
+----------+----+----+----+----+
|subcompact|  21|   2|   7|   5|
|   compact|  32|   2|  13|null|
|   minivan|   1|null|  10|null|
|       suv|   8|null|  16|  38|
|   midsize|  16|null|  23|   2|
|    pickup|   3|null|  10|  20|
|   2seater|null|null|null|   5|
+----------+----+----+----+----+



In [74]:
mpg.groupBy("class").pivot("cyl").mean("hwy").sort(col("class")).show()

+----------+------------------+----+------------------+------------------+
|     class|                 4|   5|                 6|                 8|
+----------+------------------+----+------------------+------------------+
|   2seater|              null|null|              null|              24.8|
|   compact|          29.46875|29.0|25.307692307692307|              null|
|   midsize|           29.1875|null| 26.26086956521739|              24.0|
|   minivan|              24.0|null|              22.2|              null|
|    pickup|20.666666666666668|null|              17.9|              15.8|
|subcompact| 30.80952380952381|28.5|24.714285714285715|              21.6|
|       suv|             23.75|null|              18.5|16.789473684210527|
+----------+------------------+----+------------------+------------------+



You can see how this is a reshape of the following: 

In [84]:
mpg.groupBy("class", "cyl").agg(round(mean("hwy"), 2)).sort(col("class")).show()

+----------+---+------------------+
|     class|cyl|round(avg(hwy), 2)|
+----------+---+------------------+
|   2seater|  8|              24.8|
|   compact|  4|             29.47|
|   compact|  5|              29.0|
|   compact|  6|             25.31|
|   midsize|  8|              24.0|
|   midsize|  4|             29.19|
|   midsize|  6|             26.26|
|   minivan|  4|              24.0|
|   minivan|  6|              22.2|
|    pickup|  8|              15.8|
|    pickup|  4|             20.67|
|    pickup|  6|              17.9|
|subcompact|  5|              28.5|
|subcompact|  8|              21.6|
|subcompact|  6|             24.71|
|subcompact|  4|             30.81|
|       suv|  4|             23.75|
|       suv|  8|             16.79|
|       suv|  6|              18.5|
+----------+---+------------------+



You can see from above:   
- The unique values from the column we group by will be the rows in the resulting dataframe.  
- The unique values from the column we pivot on will become the columns.  
- The values in each cell will be equal to the aggregation we specified over the group of values defined by the intersection of the rows and the columns.  

## Handling Missing Data  

Let's take a look at how spark handles missing data. First we'll create a dataframe that has a few missing values:  

In [85]:
df = spark.createDataFrame(
    pd.DataFrame(
        {"x": [1, 2, np.nan, 4, 5, np.nan], "y": [np.nan, 0, 0, 3, 1, np.nan]}
    )
)
df.show()

+---+---+
|  x|  y|
+---+---+
|1.0|NaN|
|2.0|0.0|
|NaN|0.0|
|4.0|3.0|
|5.0|1.0|
|NaN|NaN|
+---+---+



Spark provides two main ways to deal with missing values:

- `.fill`: replace missing values with a specified value  
- `.drop`: drop rows containing missing values  

Both methods are accessed through the `.na` property. We'll look at some examples below:  

In [86]:
df.na.drop().show()

+---+---+
|  x|  y|
+---+---+
|2.0|0.0|
|4.0|3.0|
|5.0|1.0|
+---+---+



In [87]:
df.na.fill(0).show()

+---+---+
|  x|  y|
+---+---+
|1.0|0.0|
|2.0|0.0|
|0.0|0.0|
|4.0|3.0|
|5.0|1.0|
|0.0|0.0|
+---+---+



For both methods, we can restrict the fill or drop to specific columns with the `subset` keyword argument:

In [88]:
df.na.fill(0, subset="x").na.fill(-1, subset="y").show()

+---+----+
|  x|   y|
+---+----+
|1.0|-1.0|
|2.0| 0.0|
|0.0| 0.0|
|4.0| 3.0|
|5.0| 1.0|
|0.0|-1.0|
+---+----+



Notice that above, the missing values in the x column were filled with 0, while the missing values in the y column were filled with -1.

In the example below, the rows with an na value for the y column will be dropped, but the rows with na values for only the x column will remain.  

In [89]:
df.na.drop(subset="y").show()

+---+---+
|  x|  y|
+---+---+
|2.0|0.0|
|NaN|0.0|
|4.0|3.0|
|5.0|1.0|
+---+---+



## DataFrame Transformations

The `.explain` method prints the physical plan: the steps spark will run to compute our dataframe.

In [90]:
mpg.explain()

== Physical Plan ==
*(1) Scan ExistingRDD[manufacturer#108,model#109,displ#110,year#111L,cyl#112L,trans#113,drv#114,cty#115L,hwy#116L,fl#117,class#118]




For our basic example, we see that there is only a single step.

In [91]:
mpg.select(mpg.cyl, mpg.hwy).explain()

== Physical Plan ==
*(1) Project [cyl#112L, hwy#116L]
+- *(1) Scan ExistingRDD[manufacturer#108,model#109,displ#110,year#111L,cyl#112L,trans#113,drv#114,cty#115L,hwy#116L,fl#117,class#118]




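Selecting columns adds a Project on top of the Scan. Both operators share the `*(1)` prefix, which marks operators that spark compiles together into one piece of generated code, so this still runs as a single step.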

In [92]:
mpg.filter(mpg.cyl == 6).explain()

== Physical Plan ==
*(1) Filter (isnotnull(cyl#112L) AND (cyl#112L = 6))
+- *(1) Scan ExistingRDD[manufacturer#108,model#109,displ#110,year#111L,cyl#112L,trans#113,drv#114,cty#115L,hwy#116L,fl#117,class#118]




Notice that our filter is also a single step.

Without reading ahead, do you think the execution plan for the two dataframes below will be the same or not?

In [93]:
mpg.select("cyl", "hwy").filter(expr("cyl == 6")).explain()

== Physical Plan ==
*(1) Project [cyl#112L, hwy#116L]
+- *(1) Filter (isnotnull(cyl#112L) AND (cyl#112L = 6))
   +- *(1) Scan ExistingRDD[manufacturer#108,model#109,displ#110,year#111L,cyl#112L,trans#113,drv#114,cty#115L,hwy#116L,fl#117,class#118]




In [94]:
mpg.filter(expr("cyl == 6")).select("cyl", "hwy").explain()

== Physical Plan ==
*(1) Project [cyl#112L, hwy#116L]
+- *(1) Filter (isnotnull(cyl#112L) AND (cyl#112L = 6))
   +- *(1) Scan ExistingRDD[manufacturer#108,model#109,displ#110,year#111L,cyl#112L,trans#113,drv#114,cty#115L,hwy#116L,fl#117,class#118]




Notice that even though we specified the transformations (`.select` and `.filter`) in a different order, `.explain` shows the same plan. This is because spark's query optimizer rewrites the plan before running it. Here, both versions end up with the filter running before the projection.

In [95]:
mpg.selectExpr("cyl + 3*16/4 + 19 AS unused", "hwy").select("hwy").explain()

== Physical Plan ==
*(1) Project [hwy#116L]
+- *(1) Scan ExistingRDD[manufacturer#108,model#109,displ#110,year#111L,cyl#112L,trans#113,drv#114,cty#115L,hwy#116L,fl#117,class#118]




Notice that we have two separate select statements, but spark condenses them into a single Project. It realizes that it doesn't need to compute the arithmetic from the first select, since that value is never used later on.

In [96]:
mpg.select(min(mpg.cyl)).explain()

== Physical Plan ==
*(2) HashAggregate(keys=[], functions=[min(cyl#112L)])
+- Exchange SinglePartition, true, [id=#1246]
   +- *(1) HashAggregate(keys=[], functions=[partial_min(cyl#112L)])
      +- *(1) Project [cyl#112L]
         +- *(1) Scan ExistingRDD[manufacturer#108,model#109,displ#110,year#111L,cyl#112L,trans#113,drv#114,cty#115L,hwy#116L,fl#117,class#118]




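Notice that the execution plan is now much more complicated. In the previous steps, each transformation could be applied to each row individually. To calculate a minimum, spark has to look at all the rows in the dataset. Each partition first computes a partial minimum (`partial_min`). The Exchange step then moves those partial results into a single partition, and a final HashAggregate combines them.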

In [97]:
mpg.groupBy(mpg.cyl).agg(min(mpg.hwy), max(mpg.hwy)).explain()

== Physical Plan ==
*(2) HashAggregate(keys=[cyl#112L], functions=[min(hwy#116L), max(hwy#116L)])
+- Exchange hashpartitioning(cyl#112L, 200), true, [id=#1267]
   +- *(1) HashAggregate(keys=[cyl#112L], functions=[partial_min(hwy#116L), partial_max(hwy#116L)])
      +- *(1) Project [cyl#112L, hwy#116L]
         +- *(1) Scan ExistingRDD[manufacturer#108,model#109,displ#110,year#111L,cyl#112L,trans#113,drv#114,cty#115L,hwy#116L,fl#117,class#118]




In [98]:
(
    mpg.select(col("cyl"), expr("(cty+hwy)/2 AS avg_mpg"))
    .filter(expr('class == "compact"'))
    .groupBy("cyl")
    .agg(min("avg_mpg"), avg("avg_mpg"), max("avg_mpg"))
    .explain()
)

== Physical Plan ==
*(2) HashAggregate(keys=[cyl#112L], functions=[min(avg_mpg#2968), avg(avg_mpg#2968), max(avg_mpg#2968)])
+- Exchange hashpartitioning(cyl#112L, 200), true, [id=#1292]
   +- *(1) HashAggregate(keys=[cyl#112L], functions=[partial_min(avg_mpg#2968), partial_avg(avg_mpg#2968), partial_max(avg_mpg#2968)])
      +- *(1) Project [cyl#112L, (cast((cty#115L + hwy#116L) as double) / 2.0) AS avg_mpg#2968]
         +- *(1) Filter (isnotnull(class#118) AND (class#118 = compact))
            +- *(1) Scan ExistingRDD[manufacturer#108,model#109,displ#110,year#111L,cyl#112L,trans#113,drv#114,cty#115L,hwy#116L,fl#117,class#118]




## More Dataframe Manipulation Examples

Let's take a look at some more examples of working with spark dataframes. For these examples, we'll be working with a dataset of observations of the weather in Seattle.

In [99]:
from vega_datasets import data

weather = data.seattle_weather().assign(date=lambda df: df.date.astype(str))
weather = spark.createDataFrame(weather)
weather.show(6)

+----------+-------------+--------+--------+----+-------+
|      date|precipitation|temp_max|temp_min|wind|weather|
+----------+-------------+--------+--------+----+-------+
|2012-01-01|          0.0|    12.8|     5.0| 4.7|drizzle|
|2012-01-02|         10.9|    10.6|     2.8| 4.5|   rain|
|2012-01-03|          0.8|    11.7|     7.2| 2.3|   rain|
|2012-01-04|         20.3|    12.2|     5.6| 4.7|   rain|
|2012-01-05|          1.3|     8.9|     2.8| 6.1|   rain|
|2012-01-06|          2.5|     4.4|     2.2| 2.2|   rain|
+----------+-------------+--------+--------+----+-------+
only showing top 6 rows



Let's first find the dates where the data starts and stops:

In [100]:
min_date, max_date = weather.select(min("date"), max("date")).first()
min_date, max_date


('2012-01-01', '2015-12-31')

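- `.select` selects the minimum date and the maximum date. 
- `.first` returns the first row of the result. That row holds two values, so it can be unpacked into the min_date and max_date variables.  

Next, combine the temp_max and temp_min columns into a single column, temp_avg. Note that the expression below rounds the sum before dividing by 2, so temp_avg is a rounded sum halved rather than an exact average (e.g. 12.8 and 5.0 give 9.0 rather than 8.9).  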

In [101]:
weather = weather.withColumn(
    "temp_avg", expr("ROUND(temp_min + temp_max) / 2")
).drop("temp_max", "temp_min")
weather.show(6)

+----------+-------------+----+-------+--------+
|      date|precipitation|wind|weather|temp_avg|
+----------+-------------+----+-------+--------+
|2012-01-01|          0.0| 4.7|drizzle|     9.0|
|2012-01-02|         10.9| 4.5|   rain|     6.5|
|2012-01-03|          0.8| 2.3|   rain|     9.5|
|2012-01-04|         20.3| 4.7|   rain|     9.0|
|2012-01-05|          1.3| 6.1|   rain|     6.0|
|2012-01-06|          2.5| 2.2|   rain|     3.5|
+----------+-------------+----+-------+--------+
only showing top 6 rows



Now we will calculate the total rainfall for each calendar month. The data spans 2012 through 2015, so each month's total is summed across all four years. We'll first create a month column with the `month` function, then group by the month, and finally aggregate by taking the sum of the precipitation.


In [102]:
from pyspark.sql.functions import month, year, quarter

(
    weather.withColumn("month", month("date"))
    .groupBy("month")
    .agg(sum("precipitation").alias("total_rainfall"))
    .sort("month")
    .show()
)

+-----+------------------+
|month|    total_rainfall|
+-----+------------------+
|    1|465.99999999999994|
|    2|             422.0|
|    3|             606.2|
|    4|             375.4|
|    5|             207.5|
|    6|             132.9|
|    7|              48.2|
|    8|             163.7|
|    9|235.49999999999997|
|   10|             503.4|
|   11|             642.5|
|   12| 622.7000000000002|
+-----+------------------+



Let's now take a look at the average temperature for each type of weather in December 2013:

In [103]:
(
    weather.filter(month("date") == 12)
    .filter(year("date") == 2013)
    .groupBy("weather")
    .agg(mean("temp_avg"))
    .show()
)

+-------+-----------------+
|weather|    avg(temp_avg)|
+-------+-----------------+
|    fog|7.555555555555555|
|    sun|2.977272727272727|
+-------+-----------------+



Here we first have a couple of `.filter` calls to restrict our data to December 2013. We then group by the weather column and, lastly, aggregate by taking the average of our temp_avg column. Together, the group by and agg calculate the average temperature for each unique value of the weather column.

Let's now find out how many days had freezing temperatures (a temp_avg of 0 or below) in each month of 2013.


In [107]:
(
    weather.filter(year("date") == 2013)
    .withColumn("freezing_temps", (weather.temp_avg <= 0).cast("int"))
    .withColumn("month", month("date"))
    .groupBy("month")
    .agg(sum("freezing_temps").alias("no_days_with_freezing_temps"))
    .sort("month")
).show()

+-----+---------------------------+
|month|no_days_with_freezing_temps|
+-----+---------------------------+
|    1|                          3|
|    2|                          0|
|    3|                          0|
|    4|                          0|
|    5|                          0|
|    6|                          0|
|    7|                          0|
|    8|                          0|
|    9|                          0|
|   10|                          0|
|   11|                          0|
|   12|                          5|
+-----+---------------------------+




## Joins

Like pandas and SQL, spark has functionality that lets us combine two tabular datasets, known as a join.

We'll start by creating some data that we can join together:


In [108]:
users = spark.createDataFrame(
    pd.DataFrame(
        {
            "id": [1, 2, 3, 4, 5, 6],
            "name": ["bob", "joe", "sally", "adam", "jane", "mike"],
            "role_id": [1, 2, 3, 3, np.nan, np.nan],
        }
    )
)
roles = spark.createDataFrame(
    pd.DataFrame(
        {
            "id": [1, 2, 3, 4],
            "name": ["admin", "author", "reviewer", "commenter"],
        }
    )
)
print("--- users ---")
users.show()
print("--- roles ---")
roles.show()

--- users ---
+---+-----+-------+
| id| name|role_id|
+---+-----+-------+
|  1|  bob|    1.0|
|  2|  joe|    2.0|
|  3|sally|    3.0|
|  4| adam|    3.0|
|  5| jane|    NaN|
|  6| mike|    NaN|
+---+-----+-------+

--- roles ---
+---+---------+
| id|     name|
+---+---------+
|  1|    admin|
|  2|   author|
|  3| reviewer|
|  4|commenter|
+---+---------+



- To join two dataframes together, we'll need to call the `.join` method on one of them and supply the other as an argument.  
- In addition, we'll need to supply the condition on which we are joining.   
- In this case, we are joining where the role_id column on the users table is equal to the id column on the roles table.  
- By default, spark performs an inner join.

In [109]:
users.join(roles, on=users.role_id == roles.id).show()

+---+-----+-------+---+--------+
| id| name|role_id| id|    name|
+---+-----+-------+---+--------+
|  1|  bob|    1.0|  1|   admin|
|  3|sally|    3.0|  3|reviewer|
|  4| adam|    3.0|  3|reviewer|
|  2|  joe|    2.0|  2|  author|
+---+-----+-------+---+--------+



In [110]:
users.join(roles, on=users.role_id == roles.id, how="left").show()

+---+-----+-------+----+--------+
| id| name|role_id|  id|    name|
+---+-----+-------+----+--------+
|  5| jane|    NaN|null|    null|
|  6| mike|    NaN|null|    null|
|  1|  bob|    1.0|   1|   admin|
|  3|sally|    3.0|   3|reviewer|
|  4| adam|    3.0|   3|reviewer|
|  2|  joe|    2.0|   2|  author|
+---+-----+-------+----+--------+



Notice that the result has two id columns (and two name columns), one from each dataframe. There are several ways we could deal with this:

- alias each dataframe and explicitly select columns after joining (this could also be done with spark SQL).  
- rename the duplicated columns before joining.  
- drop the duplicated columns after joining (e.g. `.drop(roles.id)`).  

## Visualization (or Lack Thereof)

Spark does not provide a way to visualize its dataframes. To visualize data from spark, use the `.toPandas` method to convert a spark dataframe to a pandas dataframe, then visualize as you normally would. Keep in mind that `.toPandas` collects all of the data into memory on the driver, so it is best used on small or already-aggregated results.

In [111]:
users.toPandas()

Unnamed: 0,id,name,role_id
0,1,bob,1.0
1,2,joe,2.0
2,3,sally,3.0
3,4,adam,3.0
4,5,jane,
5,6,mike,
