In [2]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# PySpark

`PySpark` lets you work with a Spark cluster using Python. To begin, you need to connect to a Spark cluster. For the purpose of this walkthrough we won't connect to a real cluster; all calculations will run locally.

The procedures for working with data, doing calculations and so on are analogous to a real cluster.

## Connecting to a cluster

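To connect to a cluster, you create an instance of `SparkContext`. Its constructor accepts several optional parameters and configuration settings that we won't discuss here. In the cell below we use `SparkContext.getOrCreate()`, which returns the context that is already running if there is one and creates a new one otherwise.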

In [3]:
import pyspark

sc = pyspark.SparkContext.getOrCreate()

# Explore your cluster, work with the SparkSession

Spark's core data structure is the Resilient Distributed Dataset (RDD). This is a low-level object that lets Spark work its magic by splitting data across multiple nodes in the cluster. The Spark DataFrame is an abstraction over RDDs that was designed to behave a lot like a SQL table. DataFrames are easier to understand than RDDs and are also better optimized for complicated operations.

## The Spark session
To start working with Spark tables (DataFrames), you first have to create a `SparkSession` object from your `SparkContext`. You can think of the SparkContext as your connection to the cluster and the SparkSession as your interface with that connection.

In [4]:
spark = pyspark.sql.SparkSession.builder.getOrCreate()

## List existing tables

The `SparkSession` has an attribute called `catalog`, which describes all the data inside the cluster. This attribute has a few methods for extracting different pieces of information.

The `.listTables()` method returns a list of `Table` objects, one per table in your cluster. Each object carries the table's name along with some metadata.

In [5]:
# This import and call to create_temp_tables will be explained below
import helpers

helpers.create_temp_tables(spark)

spark.catalog.listTables()

[Table(name='merchants', database=None, description=None, tableType='TEMPORARY', isTemporary=True),
 Table(name='products', database=None, description=None, tableType='TEMPORARY', isTemporary=True)]

## Querying tables

One of the advantages of the DataFrame interface is that you can run SQL queries on the tables in your Spark cluster.

As you saw in the previous cell, one of the tables in your cluster is the merchants table. This table contains one row for each merchant in our system.

Running a query on this table is as easy as using the `.sql()` method on your `SparkSession`. This method takes a string containing the query and returns a `DataFrame` with the results:

In [6]:
merchants = spark.sql('SELECT * FROM merchants')
print('Data type of sql() return object is a DataFrame: {}'.format(type(merchants)))
merchants.show()

Data type of sql() return object is a DataFrame: <class 'pyspark.sql.dataframe.DataFrame'>
+-----------+---------+--------------------+---------+
|merchant_id|     name|               email|month_cpv|
+-----------+---------+--------------------+---------+
|          1|  Jon Doe|    jon@nonsense.com|    239.0|
|          2| Jane Doe|   jane@nonsense.com|    354.6|
|          3|No Profit|noprofit@nonsense...|      0.0|
|          4|  Mr Rich|   rich@nonsense.com|  12345.0|
+-----------+---------+--------------------+---------+



And the same with the products table:

In [7]:
products = spark.sql('SELECT * FROM products')
products.show()

+----------+------------+-----+-----------+
|product_id|product_name|price|merchant_id|
+----------+------------+-----+-----------+
|         1|       paper|   10|          1|
|         2|       glass|  100|          1|
|         3|       watch| 1900|          4|
|         4|       phone| 7000|          4|
|         5|       mouse|  200|          4|
|         6|       shirt|  300|          2|
|         7|      jacket|  900|          2|
|         8|        beer|   39|          3|
|         9|      coffee|   32|          3|
|        10|   ice cream|   30|          3|
+----------+------------+-----+-----------+



You can run most SQL queries against the tables registered in Spark; each one returns its result as a `DataFrame`:

In [8]:
low_cpv = spark.sql('SELECT * FROM merchants WHERE month_cpv < 300')
low_cpv.show()

+-----------+---------+--------------------+---------+
|merchant_id|     name|               email|month_cpv|
+-----------+---------+--------------------+---------+
|          1|  Jon Doe|    jon@nonsense.com|    239.0|
|          3|No Profit|noprofit@nonsense...|      0.0|
+-----------+---------+--------------------+---------+



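An important note is that you can "pandify" your Spark DataFrames. This gives you all the power of pandas DataFrames. Keep in mind that `.toPandas()` collects every row into the memory of the machine running your Python code, so it is best suited to small results: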

In [9]:
merchants_pd = merchants.toPandas()
merchants_pd.head()

   merchant_id       name                  email  month_cpv
0            1    Jon Doe       jon@nonsense.com      239.0
1            2   Jane Doe      jane@nonsense.com      354.6
2            3  No Profit  noprofit@nonsense.com        0.0
3            4    Mr Rich      rich@nonsense.com    12345.0


## Create new data

The `SparkSession` method `.createDataFrame()` takes a pandas DataFrame and returns a Spark DataFrame.

The output of this method is stored locally, not in the SparkSession catalog. This means that you can use all the Spark DataFrame methods on it, but you can't access the data in other contexts.

To access the data in this way, you have to save it as a temporary table.

You can do this using the `.createTempView()` Spark DataFrame method, which takes as its only argument the name of the temporary table you'd like to register. This method registers the DataFrame as a table in the catalog, but as this table is temporary, it can only be accessed from the specific SparkSession used to create the Spark DataFrame.

There is also the method `.createOrReplaceTempView()`. This safely creates a new temporary table if nothing was there before, or updates an existing table if one was already defined.

This is how we generated the merchants and products tables. The only difference is that we created the tables from CSV files instead of pandas DataFrames.

Take a look at the code for `create_temp_tables` in `helpers.py`:

In [10]:
import inspect

print(inspect.getsource(helpers.create_temp_tables))

def create_temp_tables(spark):
    """ Create some temporal tables for data manipulation

    Params:
        spark - Spark session
    """
    merchants = spark.read.csv('data/merchants.csv', header=True, inferSchema=True)
    products = spark.read.csv('data/products.csv', header=True, inferSchema=True)
    merchants.createOrReplaceTempView('merchants')
    products.createOrReplaceTempView('products')



# More data manipulation

## Create new columns

In Spark you can add a column using the `.withColumn()` method, which takes two arguments: first, a string with the name of your new column, and second, the new column itself.

The new column must be an object of class `Column`. You can create one by extracting a column from your DataFrame using `df.colName` and, if needed, applying operations to it.

Updating a Spark DataFrame is different than working in pandas because the Spark DataFrame is immutable. This means that it can't be changed, and so columns can't be updated in place.

Thus, all these methods return a new DataFrame. To overwrite the original DataFrame, reassign the result of the method call, like so:

    df = df.withColumn("newCol", df.oldCol + 1)

The above code creates a DataFrame with the same columns as df plus a new column, newCol, where every entry is equal to the corresponding entry from oldCol, plus one.

Let's try and add an extra column to our products table:

In [11]:
new_products = products.withColumn('10_discount', products.price*0.9)
new_products.show()

+----------+------------+-----+-----------+-----------+
|product_id|product_name|price|merchant_id|10_discount|
+----------+------------+-----+-----------+-----------+
|         1|       paper|   10|          1|        9.0|
|         2|       glass|  100|          1|       90.0|
|         3|       watch| 1900|          4|     1710.0|
|         4|       phone| 7000|          4|     6300.0|
|         5|       mouse|  200|          4|      180.0|
|         6|       shirt|  300|          2|      270.0|
|         7|      jacket|  900|          2|      810.0|
|         8|        beer|   39|          3|       35.1|
|         9|      coffee|   32|          3|       28.8|
|        10|   ice cream|   30|          3|       27.0|
+----------+------------+-----+-----------+-----------+



## Selecting, filtering, aggregating, grouping, joining... SQL all the things!

Instead of using the Spark session to run SQL queries, you can also use the built-in methods of Spark DataFrames.

### Selecting

In [12]:
# Syntax 1, with strings
names = merchants.select('name', 'email', 'month_cpv')
names.show()

# Syntax 2, with DataFrame column names
emails = merchants.select(merchants.name, merchants.email)
emails.show()

+---------+--------------------+---------+
|     name|               email|month_cpv|
+---------+--------------------+---------+
|  Jon Doe|    jon@nonsense.com|    239.0|
| Jane Doe|   jane@nonsense.com|    354.6|
|No Profit|noprofit@nonsense...|      0.0|
|  Mr Rich|   rich@nonsense.com|  12345.0|
+---------+--------------------+---------+

+---------+--------------------+
|     name|               email|
+---------+--------------------+
|  Jon Doe|    jon@nonsense.com|
| Jane Doe|   jane@nonsense.com|
|No Profit|noprofit@nonsense...|
|  Mr Rich|   rich@nonsense.com|
+---------+--------------------+



You can also use `selectExpr` to select new columns defined as SQL expressions over existing columns:

In [13]:
products.selectExpr('product_name', 'price', 'price*0.9 as discounted').show()

+------------+-----+----------+
|product_name|price|discounted|
+------------+-----+----------+
|       paper|   10|       9.0|
|       glass|  100|      90.0|
|       watch| 1900|    1710.0|
|       phone| 7000|    6300.0|
|       mouse|  200|     180.0|
|       shirt|  300|     270.0|
|      jacket|  900|     810.0|
|        beer|   39|      35.1|
|      coffee|   32|      28.8|
|   ice cream|   30|      27.0|
+------------+-----+----------+



Or you can build a column expression, give it a name with `.alias()`, and pass it to `.select()`:

In [14]:
discount = (products.price*0.9).alias('discounted')
products.select('product_name', 'price', discount).show()

+------------+-----+----------+
|product_name|price|discounted|
+------------+-----+----------+
|       paper|   10|       9.0|
|       glass|  100|      90.0|
|       watch| 1900|    1710.0|
|       phone| 7000|    6300.0|
|       mouse|  200|     180.0|
|       shirt|  300|     270.0|
|      jacket|  900|     810.0|
|        beer|   39|      35.1|
|      coffee|   32|      28.8|
|   ice cream|   30|      27.0|
+------------+-----+----------+



### Filtering
You can also apply filters to your data, and chain them as well:

In [15]:
filter_by_j = merchants.name.contains('J')
merchants.filter(filter_by_j).show()

+-----------+--------+-----------------+---------+
|merchant_id|    name|            email|month_cpv|
+-----------+--------+-----------------+---------+
|          1| Jon Doe| jon@nonsense.com|    239.0|
|          2|Jane Doe|jane@nonsense.com|    354.6|
+-----------+--------+-----------------+---------+



In [16]:
filter_low_cpv = merchants.month_cpv < 300
merchants.filter(filter_by_j).filter(filter_low_cpv).show()

+-----------+-------+----------------+---------+
|merchant_id|   name|           email|month_cpv|
+-----------+-------+----------------+---------+
|          1|Jon Doe|jon@nonsense.com|    239.0|
+-----------+-------+----------------+---------+



### Aggregating

The common aggregation methods, such as `.min()`, `.max()` and `.count()`, are `GroupedData` methods. A `GroupedData` object is created by calling the `.groupBy()` DataFrame method; called with no arguments, it treats the whole DataFrame as a single group. For example, to find the lowest product price, you can do:

In [17]:
products.groupBy().min('price').show()

+----------+
|min(price)|
+----------+
|        10|
+----------+



In [18]:
# Average price of the products sold by merchant 1
products.filter(products.merchant_id == 1).groupBy().avg('price').show()

+----------+
|avg(price)|
+----------+
|      55.0|
+----------+



### Grouping
You can also pass column names to the `groupBy` method to group your data and run aggregations on each group:

In [19]:
# Average product price per merchant
products.groupBy('merchant_id').avg('price').show()

+-----------+------------------+
|merchant_id|        avg(price)|
+-----------+------------------+
|          1|              55.0|
|          3|33.666666666666664|
|          4|3033.3333333333335|
|          2|             600.0|
+-----------+------------------+



In addition to the GroupedData methods, there is also the `.agg()` method. This method lets you pass an aggregate column expression that uses any of the aggregate functions from the `pyspark.sql.functions` submodule.

This submodule contains many useful functions for computing things like standard deviations.

In [20]:
import pyspark.sql.functions as F

# Standard deviation of product prices by merchant
products.groupBy('merchant_id').agg(F.stddev('price')).show()

+-----------+------------------+
|merchant_id|stddev_samp(price)|
+-----------+------------------+
|          1| 63.63961030678928|
|          3| 4.725815626252609|
|          4|3538.8321990924255|
|          2|424.26406871192853|
+-----------+------------------+



### Joining

Finally, you can also join tables in Spark, and it's easy!

Joins are performed using the DataFrame method `.join()`, which takes three arguments. The first is the DataFrame you want to join to the one the method is called on. The second, `on`, is the name of the key column(s) as a string. **When `on` is given as a string, the key column(s) must have the same name in both tables**. The third, `how`, specifies the kind of join to perform.

In [21]:
merchants.join(products, on='merchant_id', how='leftouter').show()

+-----------+---------+--------------------+---------+----------+------------+-----+
|merchant_id|     name|               email|month_cpv|product_id|product_name|price|
+-----------+---------+--------------------+---------+----------+------------+-----+
|          1|  Jon Doe|    jon@nonsense.com|    239.0|         2|       glass|  100|
|          1|  Jon Doe|    jon@nonsense.com|    239.0|         1|       paper|   10|
|          2| Jane Doe|   jane@nonsense.com|    354.6|         7|      jacket|  900|
|          2| Jane Doe|   jane@nonsense.com|    354.6|         6|       shirt|  300|
|          3|No Profit|noprofit@nonsense...|      0.0|        10|   ice cream|   30|
|          3|No Profit|noprofit@nonsense...|      0.0|         9|      coffee|   32|
|          3|No Profit|noprofit@nonsense...|      0.0|         8|        beer|   39|
|          4|  Mr Rich|   rich@nonsense.com|  12345.0|         5|       mouse|  200|
|          4|  Mr Rich|   rich@nonsense.com|  12345.0|         4|