### Instructions:
Optimize the query plan

Suppose we want to compose a query that returns, for each question, the number of answers it received in each month.

The query below does this in a suboptimal way. Try to rewrite it so that Spark produces a more efficient plan.
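To pin down what the query computes, here is a minimal pure-Python sketch on a few made-up rows. The toy data and the helper name `answers_per_month` are hypothetical and not part of the dataset. It counts answers per `(question_id, month)` and attaches each question's `creation_date` and `title`, mirroring the inner join used below, so questions without answers are dropped. Like Spark's `month()`, it uses only the month number, so answers from January of different years fall into the same group.

```python
from collections import Counter
from datetime import datetime

# Hypothetical toy rows shaped like answersDF: (question_id, answer_id, creation_date)
answers = [
    (1, 10, datetime(2015, 1, 3)),
    (1, 11, datetime(2015, 1, 20)),
    (1, 12, datetime(2015, 2, 1)),
    (2, 13, datetime(2015, 1, 5)),
    (2, 14, datetime(2016, 1, 9)),  # different year, same month number
]

# Hypothetical toy rows shaped like questionsDF: question_id -> (creation_date, title)
questions = {
    1: (datetime(2015, 1, 1), 'Frost bubbles'),
    2: (datetime(2015, 1, 2), 'Abstract spaces'),
    3: (datetime(2015, 1, 4), 'Unanswered question'),
}


def answers_per_month(answers, questions):
    """Return (question_id, creation_date, title, month, cnt) rows sorted by question and month."""
    # Count answers per (question_id, month number of the answer's creation date)
    counts = Counter((qid, created.month) for qid, _, created in answers)
    rows = []
    for (qid, m), cnt in counts.items():
        if qid in questions:  # inner join: keep only questions that exist
            q_created, title = questions[qid]
            rows.append((qid, q_created, title, m, cnt))
    return sorted(rows, key=lambda r: (r[0], r[3]))


for row in answers_per_month(answers, questions):
    print(row)
```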


In [2]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import count, month

# Build (or reuse) a local SparkSession that uses all available cores
spark = (
    SparkSession.builder
    .master('local[*]')
    .appName('Optimize Spark Query')
    .getOrCreate()
)

In [3]:
spark

### Create answersDF DataFrame

In [4]:
# Path to the answers data in Parquet format (resolved against the default filesystem)
answers_input_path = 'data/answers'
answersDF = spark.read.parquet(answers_input_path)
answersDF.show()
answersDF.count()  # number of rows in the DataFrame

+-----------+---------+--------------------+--------+-------+-----+
|question_id|answer_id|       creation_date|comments|user_id|score|
+-----------+---------+--------------------+--------+-------+-----+
|     226592|   226595|2015-12-30 07:46:...|       3|  82798|    2|
|     388057|   388062|2018-02-23 02:52:...|       8|    520|   21|
|     293286|   293305|2016-11-18 05:35:...|       0|  47472|    2|
|     442499|   442503|2018-11-22 14:34:...|       0| 137289|    0|
|     293009|   293031|2016-11-16 21:36:...|       0|  83721|    0|
|     395532|   395537|2018-03-25 13:51:...|       0|   1325|    0|
|     329826|   329843|2017-04-29 23:42:...|       4|    520|    1|
|     294710|   295061|2016-11-27 09:29:...|       2| 114696|    2|
|     291910|   291917|2016-11-10 18:56:...|       0| 114696|    2|
|     372382|   372394|2017-12-04 10:17:...|       0| 172328|    0|
|     178387|   178394|2015-04-26 01:31:...|       6|  62726|    0|
|     393947|   393948|2018-03-18 06:22:...|    

### Show number of partitions for answersDF

In [5]:
answersDF.rdd.getNumPartitions()

4

### Create questionsDF DataFrame

In [6]:
# Path to the questions data in Parquet format (resolved against the default filesystem)
questions_input_path = 'data/questions'
questionsDF = spark.read.parquet(questions_input_path)
questionsDF.show()
questionsDF.count()  # number of rows in the DataFrame

+-----------+--------------------+--------------------+--------------------+------------------+--------+-------+-----+
|question_id|                tags|       creation_date|               title|accepted_answer_id|comments|user_id|views|
+-----------+--------------------+--------------------+--------------------+------------------+--------+-------+-----+
|     382738|[optics, waves, f...|2018-01-28 15:22:...|What is the pseud...|            382772|       0|  76347|   32|
|     370717|[field-theory, de...|2017-11-25 17:09:...|What is the defin...|              null|       1|  75085|   82|
|     339944|[general-relativi...|2017-06-18 04:32:...|Could gravitation...|              null|      13| 116137|  333|
|     233852|[homework-and-exe...|2016-02-05 05:19:...|When does travell...|              null|       9|  95831|  185|
|     294165|[quantum-mechanic...|2016-11-22 19:39:...|Time-dependent qu...|              null|       1| 118807|   56|
|     173819|[homework-and-exe...|2015-04-02 23:

### Show number of partitions for questionsDF

In [7]:
questionsDF.rdd.getNumPartitions()

4

### The original query to be optimized:

In [8]:
'''
Answers aggregation

Here we count the answers per question per month
'''

# Aggregate first: count answers per (question_id, month)
answers_month = (
    answersDF
    .withColumn('month', month('creation_date'))
    .groupBy('question_id', 'month')
    .agg(count('*').alias('cnt'))
)

# Then join the aggregated counts back to the questions
resultDF = (
    questionsDF
    .join(answers_month, 'question_id')
    .select('question_id', 'creation_date', 'title', 'month', 'cnt')
)

resultDF.orderBy('question_id', 'month').show()
resultDF.count()

resultDF.explain()  # print the physical plan

+-----------+--------------------+--------------------+-----+---+
|question_id|       creation_date|               title|month|cnt|
+-----------+--------------------+--------------------+-----+---+
|     155989|2015-01-01 09:59:...|Frost bubble form...|    1|  1|
|     155989|2015-01-01 09:59:...|Frost bubble form...|    2|  1|
|     155990|2015-01-01 10:51:...|The abstract spac...|    1|  2|
|     155992|2015-01-01 11:44:...|centrifugal force...|    1|  1|
|     155993|2015-01-01 11:56:...|How can I estimat...|    1|  1|
|     155995|2015-01-01 13:16:...|Why should a solu...|    1|  3|
|     155996|2015-01-01 14:06:...|Why do we assume ...|    1|  2|
|     155996|2015-01-01 14:06:...|Why do we assume ...|    2|  1|
|     155996|2015-01-01 14:06:...|Why do we assume ...|   11|  1|
|     155997|2015-01-01 14:26:...|Why do square sha...|    1|  3|
|     155999|2015-01-01 15:01:...|Diagonalizability...|    1|  1|
|     156008|2015-01-01 16:48:...|Capturing a light...|    1|  2|
|     1560

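### Spark defaults to 200 partitions after a shuffle
Shuffles triggered by DataFrame aggregations and joins produce `spark.sql.shuffle.partitions` partitions, which defaults to 200. Here it is the `groupBy` aggregation that shuffles the answers, as becomes apparent when calling `.getNumPartitions()` on the newly created DataFrame `answers_month`: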

In [9]:
answers_month.rdd.getNumPartitions()

200
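The 200 comes from configuration, not from the data. The pure-Python model below illustrates this under simplifying assumptions: `SHUFFLE_PARTITIONS` stands in for `spark.sql.shuffle.partitions`, the helper `hash_partition` is hypothetical, Python's built-in `hash()` replaces the Murmur3 hash Spark applies to the key columns, and the rows are made up.

```python
# Assumption: stands in for spark.sql.shuffle.partitions, whose default is 200
SHUFFLE_PARTITIONS = 200


def hash_partition(rows, key, num_partitions=SHUFFLE_PARTITIONS):
    """Simplified shuffle model: send each row to partition hash(key) % num_partitions.

    Spark hashes the key columns with Murmur3; Python's hash() is used here
    purely for illustration.
    """
    partitions = [[] for _ in range(num_partitions)]
    for row in rows:
        partitions[hash(key(row)) % num_partitions].append(row)
    return partitions


# Hypothetical (question_id, month, cnt) rows, like those produced by the groupBy
rows = [(qid, m, 1) for qid in range(155989, 156009) for m in (1, 2)]
parts = hash_partition(rows, key=lambda r: (r[0], r[1]))

# The partition count comes from the setting, not from the input size:
# 40 rows still yield 200 partitions, most of them empty.
print(len(parts), sum(1 for p in parts if p))
```

In PySpark the real setting can be read with `spark.conf.get('spark.sql.shuffle.partitions')` and lowered with `spark.conf.set('spark.sql.shuffle.partitions', 4)` before running the query, which avoids creating the 200 partitions in the first place rather than merging them afterwards. With adaptive query execution enabled, Spark may also coalesce small shuffle partitions on its own.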

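### To optimize the query, we reduce the number of partitions of `answers_month` to 4 with `.coalesce()`, matching the number of partitions of the original DataFrames.
Unlike `.repartition()`, `.coalesce()` merges existing partitions without an extra shuffle, so the aggregated result is processed by 4 tasks instead of 200 small ones. It does not remove any of the shuffles already in the plan, which can be checked by comparing the two `.explain()` outputs.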

In [10]:
'''
Task: see the query plan of the previous result and rewrite the query to optimize it
'''

# Same aggregation as before: count answers per (question_id, month)
answers_month = (
    answersDF
    .withColumn('month', month('creation_date'))
    .groupBy('question_id', 'month')
    .agg(count('*').alias('cnt'))
)

# Merge the 200 shuffle partitions into 4 (coalesce avoids a full shuffle)
answers_month = answers_month.coalesce(4)

resultDF = (
    questionsDF
    .join(answers_month, 'question_id')
    .select('question_id', 'creation_date', 'title', 'month', 'cnt')
)

resultDF.orderBy('question_id', 'month').show()
resultDF.count()

resultDF.explain()  # print the physical plan


+-----------+--------------------+--------------------+-----+---+
|question_id|       creation_date|               title|month|cnt|
+-----------+--------------------+--------------------+-----+---+
|     155989|2015-01-01 09:59:...|Frost bubble form...|    1|  1|
|     155989|2015-01-01 09:59:...|Frost bubble form...|    2|  1|
|     155990|2015-01-01 10:51:...|The abstract spac...|    1|  2|
|     155992|2015-01-01 11:44:...|centrifugal force...|    1|  1|
|     155993|2015-01-01 11:56:...|How can I estimat...|    1|  1|
|     155995|2015-01-01 13:16:...|Why should a solu...|    1|  3|
|     155996|2015-01-01 14:06:...|Why do we assume ...|    1|  2|
|     155996|2015-01-01 14:06:...|Why do we assume ...|    2|  1|
|     155996|2015-01-01 14:06:...|Why do we assume ...|   11|  1|
|     155997|2015-01-01 14:26:...|Why do square sha...|    1|  3|
|     155999|2015-01-01 15:01:...|Diagonalizability...|    1|  1|
|     156008|2015-01-01 16:48:...|Capturing a light...|    1|  2|
|     1560
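The result is unchanged, as expected, because `.coalesce()` only concatenates existing partitions instead of re-hashing rows. The pure-Python model below sketches this under stated simplifications: the `coalesce` helper is hypothetical, it merges contiguous ranges of partitions (Spark's real coalescer may also group partitions by data locality), Python's `hash()` stands in for Spark's hash function, and the rows are made up. It also shows a consequence: the coalesced layout is not partitioned by `question_id`, so the join still has to redistribute `answers_month` (through an exchange, or a broadcast if the planner picks a broadcast join), which should be visible in the `.explain()` output above.

```python
def coalesce(partitions, num_partitions):
    """Simplified model of coalesce(n): merge contiguous ranges of existing
    partitions into n groups. Rows are never re-hashed, so no shuffle is needed."""
    n = len(partitions)
    merged = []
    for i in range(num_partitions):
        start, end = i * n // num_partitions, (i + 1) * n // num_partitions
        merged.append([row for part in partitions[start:end] for row in part])
    return merged


# Hypothetical aggregated rows (question_id, month, cnt), hash-partitioned into
# 200 partitions by (question_id, month) as a stand-in for the groupBy shuffle
rows = [(qid, m, 1) for qid in range(100) for m in (1, 2, 3)]
parts200 = [[] for _ in range(200)]
for r in rows:
    parts200[hash((r[0], r[1])) % 200].append(r)

parts4 = coalesce(parts200, 4)
print([len(p) for p in parts4])

# A hash-partitioned join on question_id with 4 partitions would expect each row
# in partition hash(question_id) % 4. Coalescing does not produce that layout,
# so the rows must be redistributed again before the join.
misplaced = sum(1 for i, p in enumerate(parts4) for r in p if hash(r[0]) % 4 != i)
print(misplaced)
```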

### The number of partitions has been reduced to 4

In [11]:
answers_month.rdd.getNumPartitions()

4
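An alternative rewrite, not the one taken above, joins first and aggregates afterwards, grouping by `question_id`, `creation_date`, `title` and `month`. In PySpark it would read roughly `questionsDF.join(answersDF.withColumn('month', month('creation_date')).select('question_id', 'month'), 'question_id').groupBy('question_id', 'creation_date', 'title', 'month').agg(count('*').alias('cnt'))`; selecting only `question_id` and `month` from the answers avoids a clash between the two `creation_date` columns. Because the join already partitions the data by `question_id`, Spark may be able to reuse that partitioning for the aggregation and skip one exchange, which should be confirmed with `.explain()`. The pure-Python sketch below, on hypothetical toy rows and with hypothetical helper names, only checks that both orderings return the same rows, assuming `question_id` is unique in the questions table.

```python
from collections import Counter

# Hypothetical toy data: answers as (question_id, month); questions as
# question_id -> (creation_date, title). Assumes question_id is unique.
answers = [(1, 1), (1, 1), (1, 2), (2, 1), (4, 3)]
questions = {
    1: ('2015-01-01', 'Frost bubbles'),
    2: ('2015-01-02', 'Abstract spaces'),
    3: ('2015-01-04', 'Unanswered'),
}


def aggregate_then_join(answers, questions):
    """Original order: count per (question_id, month), then inner-join to questions."""
    counts = Counter(answers)
    return sorted(
        (qid, *questions[qid], m, cnt)
        for (qid, m), cnt in counts.items()
        if qid in questions
    )


def join_then_aggregate(answers, questions):
    """Alternative order: inner-join each answer to its question, then count
    per (question_id, creation_date, title, month)."""
    joined = [(qid, *questions[qid], m) for qid, m in answers if qid in questions]
    return sorted((*key, cnt) for key, cnt in Counter(joined).items())


print(aggregate_then_join(answers, questions) == join_then_aggregate(answers, questions))
```

If `question_id` were not unique in the questions table, the two orderings would diverge (the join-first version would multiply the counts), so the equivalence depends on that assumption.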