# Optimize the query plan I

Suppose we want to compose a query that returns, for each question, the number of answers to that question in each month. The query below does this in a suboptimal way; try to rewrite it to achieve a more optimal plan.

In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import (
    col, count, month
)

import os

In [None]:
# Obtain the Spark session for this notebook (an existing one is reused if
# present), registered under the application name 'Optimize I'.
session_builder = SparkSession.builder.appName('Optimize I')
spark = session_builder.getOrCreate()

In [None]:
# Resolve the project root as the directory three levels above the current
# working directory, then derive both input locations from it.
base_path = os.getcwd()

# Walk up with os.path instead of splitting on '/': this honours the platform
# separator and never collapses to an empty string (which would silently turn
# the input paths into relative ones) when the working directory is fewer
# than three levels deep.
project_path = os.path.normpath(
    os.path.join(base_path, os.pardir, os.pardir, os.pardir)
)

# Raw answers data set.
answers_input_path = os.path.join(project_path, 'data/answers')

# Questions data set, read from the project's output directory.
questions_input_path = os.path.join(project_path, 'output/questions-transformed')

In [None]:
# Load both inputs as DataFrames. No format or schema is specified, so the
# session's default data source is used (presumably Parquet -- confirm
# against the files on disk).
answers_reader = spark.read.option('path', answers_input_path)
answersDF = answers_reader.load()

questions_reader = spark.read.option('path', questions_input_path)
questionsDF = questions_reader.load()

#### Answers aggregation:

Here we:
* get the number of answers per question per month

In [None]:
# Baseline aggregation (the deliberately suboptimal version of the exercise):
# derive the calendar month of each answer, then count the answers for every
# (question_id, month) pair.
answers_with_month = answersDF.withColumn('month', month(col('creation_date')))
grouped_answers = answers_with_month.groupBy('question_id', 'month')
answers_month = grouped_answers.agg(count('*').alias('cnt'))

Here we join the original questions with the aggregation:

In [None]:
# Attach the monthly answer counts to their questions. join() is called
# without `how`, i.e. the default (inner) join, so questions with no answers
# do not appear in the result.
output_columns = ['question_id', 'creation_date', 'title', 'month', 'cnt']
questions_with_counts = questionsDF.join(answers_month, on='question_id')
resultDF = questions_with_counts.select(*output_columns)

In [None]:
# Print the leading rows of the baseline result, sorted by question and then
# by month. This is the action that triggers execution of the plan.
(
    resultDF
    .orderBy(col('question_id'), col('month'))
    .show()
)

# Task

Inspect the query plan of the previous result and rewrite the query to optimize it.

Hint:
* Use `repartition` before the `groupBy` operation to reduce the number of shuffles.
* If you repartition the data by `question_id`, the `groupBy` will not induce a shuffle, and neither will the corresponding branch of the join.

In [None]:
# Optimized aggregation: repartition the answers by question_id up front.
# Per the task hint, the grouping (whose keys include question_id) and the
# answers side of the later join on question_id can then reuse this
# partitioning instead of each triggering a shuffle of its own.
answers_by_question = answersDF.repartition('question_id')  # eliminates one shuffle
answers_with_month = answers_by_question.withColumn(
    'month', month(col('creation_date'))
)
grouped_answers = answers_with_month.groupBy('question_id', 'month')
answers_month = grouped_answers.agg(count('*').alias('cnt'))

In [None]:
# Same join and projection as the baseline query, now fed by the
# repartitioned aggregation. join() is called without `how`, i.e. the default
# (inner) join on question_id.
joined_with_counts = questionsDF.join(answers_month, on='question_id')
resultDF = joined_with_counts.select(
    *['question_id', 'creation_date', 'title', 'month', 'cnt']
)

In [None]:
# Print the leading rows of the optimized result, sorted by question and then
# by month, for comparison with the baseline output.
(
    resultDF
    .orderBy(col('question_id'), col('month'))
    .show()
)

In [None]:
# Stop the Spark session now that the notebook's work is done.
spark.stop()