# Optimize the query plan I

Suppose we want to compose a query that returns, for each question, the number of answers it received in each month. The query below does this in a suboptimal way; try to rewrite it to achieve a more optimal plan.

In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, count, month, hour
)

import os

In [None]:
# Create the SparkSession for this exercise (or reuse one that is already running).
spark = SparkSession.builder.appName('Optimize I').getOrCreate()

In [None]:
# Display the version of the running Spark session (shown as this cell's output).
spark.version

In [None]:
base_path = os.getcwd()

# The notebook is expected to sit three directory levels below the project
# root. Climb up with os.path rather than splitting on '/' by hand, which is
# not portable to platforms that use a different path separator.
project_path = os.path.abspath(
    os.path.join(base_path, os.pardir, os.pardir, os.pardir)
)

# Input datasets live under <project_path>/data/. Pass each path component
# separately so os.path.join inserts the platform's separator.
answers_input_path = os.path.join(project_path, 'data', 'answers')

questions_input_path = os.path.join(project_path, 'data', 'questions')

In [None]:
# Disable automatic broadcast joins (a threshold of -1 turns them off) so that
# Spark plans a sort-merge join (SMJ). We assume both datasets are large in
# practice, in which case Spark would pick SMJ anyway.
spark.conf.set(key='spark.sql.autoBroadcastJoinThreshold', value=-1)

In [None]:
def _load_dataset(input_path):
    """Load the dataset stored at *input_path*.

    No format is specified, so Spark's configured default data source
    (``spark.sql.sources.default``) is used.
    """
    return spark.read.option('path', input_path).load()


answersDF = _load_dataset(answers_input_path)
questionsDF = _load_dataset(questions_input_path)

#### Answers aggregation:

Here we:
* get the number of answers per question per month

In [None]:
# Count answers per question per calendar month.
# NOTE(review): month() returns only the month of the year (1-12), so answers
# from the same month in different years are counted together. Confirm that
# this grouping is intended.
_answers_with_month = answersDF.withColumn('month', month(col('creation_date')))
answers_month = _answers_with_month.groupBy('question_id', 'month').agg(
    count('*').alias('cnt')
)

Here we join the original questions with the aggregation:

In [None]:
# Attach the monthly answer counts to each question (inner join on question_id).
resultDF = questionsDF.join(answers_month, on='question_id', how='inner')

#### Execute the query

Here we:

* run the query with the `noop` format, which executes the query (so its plan can be inspected) but does not write the output anywhere

In [None]:
# Trigger full execution of the query while discarding its output: the
# 'noop' sink writes nothing anywhere.
resultDF.write.format('noop').mode('overwrite').save()

# Task

Look at the query plan of the previous query and rewrite the query to optimize it.

Hints:
* use [repartition](https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.repartition.html#pyspark.sql.DataFrame.repartition) before the groupBy operation to reduce the number of shuffles
* if you repartition the data by `question_id`, the groupBy will not induce a shuffle, and neither will the corresponding branch of the join

In [None]:
# Hash-partition the answers by question_id up front. The groupBy below keys
# on (question_id, month), and partitioning on question_id alone already
# co-locates each group, so the aggregation needs no exchange of its own. Its
# output stays partitioned by question_id, which is also the join key, so
# this side of the later join should need no extra shuffle either. This
# assumes both join sides use the same partition count (both default to
# spark.sql.shuffle.partitions).
_answers_by_question = answersDF.repartition('question_id')
answers_month = (
    _answers_by_question
    .withColumn('month', month(col('creation_date')))
    .groupBy('question_id', 'month')
    .agg(count('*').alias('cnt'))
)

In [None]:
# Join the questions with the (pre-partitioned) monthly answer counts.
resultDF = questionsDF.join(answers_month, on='question_id', how='inner')

In [None]:
# Execute the optimized query without persisting anything ('noop' sink), so
# its plan can be compared with the first run.
resultDF.write.format('noop').mode('overwrite').save()

In [None]:
# Stop the SparkSession (and its underlying SparkContext) now that the exercise is done.
spark.stop()