In [1]:
from pyspark.sql.types import StructType,StructField,TimestampType,StringType,LongType

# Define the common schema
answer_schema = StructType([\
                    StructField("question_id", LongType(),True),\
                    StructField("element", StringType(),True),\
                    StructField("creation_date", TimestampType(),True),\
                    StructField("title", StringType(),True),\
                    StructField("accepted_answer_id", LongType(),True),\
                    StructField("comments", StringType(),True),\
                    StructField("user_id", LongType(),True),\
                    StructField("views", LongType(),True)\
                          ])

In [2]:
'''
Optimize the query plan

Suppose we want to compose query in which we get for each question also the number of answers to this question for each month. See the query below which does that in a suboptimal way and try to rewrite it to achieve a more optimal plan.
'''

import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count, month

import os
import time

spark = SparkSession.builder.appName('Optimize I').getOrCreate()

base_path = os.getcwd()

project_path = ('/').join(base_path.split('/')[0:-3]) 

answers_input_path = os.path.join(project_path, 'data/answers')

questions_input_path = os.path.join(project_path, 'data/questions')

answersDF = spark.read.option("header", "true").option('path', answers_input_path).load()

questionsDF = spark.read.option('path', questions_input_path).schema(answer_schema).load()


In [3]:
'''
Answers aggregation

Here we : get number of answers per question per month
'''

start_time = time.time()

answers_month = answersDF.withColumn('month', month('creation_date'))

answers_month = answers_month.groupBy('question_id', 'month').agg(count('*').alias('cnt'))

resultDF = questionsDF.join(answers_month, 'question_id').select('question_id', 'creation_date', 'title', 'month', 'cnt')

resultDF.orderBy('question_id', 'month').show()

'''
Task:

see the query plan of the previous result and rewrite the query to optimize it
'''
print("--- %s seconds ---" % (time.time() - start_time))

+-----------+--------------------+--------------------+-----+---+
|question_id|       creation_date|               title|month|cnt|
+-----------+--------------------+--------------------+-----+---+
|     155989|2014-12-31 17:59:...|Frost bubble form...|    2|  1|
|     155989|2014-12-31 17:59:...|Frost bubble form...|   12|  1|
|     155990|2014-12-31 18:51:...|The abstract spac...|    1|  1|
|     155990|2014-12-31 18:51:...|The abstract spac...|   12|  1|
|     155992|2014-12-31 19:44:...|centrifugal force...|   12|  1|
|     155993|2014-12-31 19:56:...|How can I estimat...|    1|  1|
|     155995|2014-12-31 21:16:...|Why should a solu...|    1|  3|
|     155996|2014-12-31 22:06:...|Why do we assume ...|    1|  2|
|     155996|2014-12-31 22:06:...|Why do we assume ...|    2|  1|
|     155996|2014-12-31 22:06:...|Why do we assume ...|   11|  1|
|     155997|2014-12-31 22:26:...|Why do square sha...|    1|  3|
|     155999|2014-12-31 23:01:...|Diagonalizability...|    1|  1|
|     1560

In [4]:
resultDF.explain()

== Physical Plan ==
*(3) Project [question_id#12L, creation_date#14, title#15, month#28, cnt#44L]
+- *(3) BroadcastHashJoin [question_id#12L], [question_id#0L], Inner, BuildRight
   :- *(3) Project [question_id#12L, creation_date#14, title#15]
   :  +- *(3) Filter isnotnull(question_id#12L)
   :     +- *(3) ColumnarToRow
   :        +- FileScan parquet [question_id#12L,creation_date#14,title#15] Batched: true, DataFilters: [isnotnull(question_id#12L)], Format: Parquet, Location: InMemoryFileIndex[file:/C:/SpringBoard-DataEngineer/Spark-OptimizationProject/data/questions], PartitionFilters: [], PushedFilters: [IsNotNull(question_id)], ReadSchema: struct<question_id:bigint,creation_date:timestamp,title:string>
   +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, true])), [id=#135]
      +- *(2) HashAggregate(keys=[question_id#0L, month#28], functions=[count(1)])
         +- Exchange hashpartitioning(question_id#0L, month#28, 200), true, [id=#131]
            +- *(1) 

In [5]:
start_time = time.time()

spark.conf.set("spark.sql.adaptive.enabled","true")

answers_month = answersDF.withColumn('month', month('creation_date'))

# Repartition by Month to avoid shuffling

answers_month = answers_month.repartition(col("month"))

answers_month = answers_month.groupBy('question_id', 'month').agg(count('*').alias('cnt'))

resultDF = questionsDF.join(answers_month, 'question_id').select('question_id', 'creation_date', 'title', 'month', 'cnt')

resultDF.orderBy('question_id', 'month').show()

print("--- %s seconds ---" % (time.time() - start_time))

+-----------+--------------------+--------------------+-----+---+
|question_id|       creation_date|               title|month|cnt|
+-----------+--------------------+--------------------+-----+---+
|     155989|2014-12-31 17:59:...|Frost bubble form...|    2|  1|
|     155989|2014-12-31 17:59:...|Frost bubble form...|   12|  1|
|     155990|2014-12-31 18:51:...|The abstract spac...|    1|  1|
|     155990|2014-12-31 18:51:...|The abstract spac...|   12|  1|
|     155992|2014-12-31 19:44:...|centrifugal force...|   12|  1|
|     155993|2014-12-31 19:56:...|How can I estimat...|    1|  1|
|     155995|2014-12-31 21:16:...|Why should a solu...|    1|  3|
|     155996|2014-12-31 22:06:...|Why do we assume ...|    1|  2|
|     155996|2014-12-31 22:06:...|Why do we assume ...|    2|  1|
|     155996|2014-12-31 22:06:...|Why do we assume ...|   11|  1|
|     155997|2014-12-31 22:26:...|Why do square sha...|    1|  3|
|     155999|2014-12-31 23:01:...|Diagonalizability...|    1|  1|
|     1560

In [6]:
resultDF.explain()

== Physical Plan ==
AdaptiveSparkPlan isFinalPlan=false
+- Project [question_id#12L, creation_date#14, title#15, month#86, cnt#102L]
   +- BroadcastHashJoin [question_id#12L], [question_id#0L], Inner, BuildRight
      :- Project [question_id#12L, creation_date#14, title#15]
      :  +- Filter isnotnull(question_id#12L)
      :     +- FileScan parquet [question_id#12L,creation_date#14,title#15] Batched: true, DataFilters: [isnotnull(question_id#12L)], Format: Parquet, Location: InMemoryFileIndex[file:/C:/SpringBoard-DataEngineer/Spark-OptimizationProject/data/questions], PartitionFilters: [], PushedFilters: [IsNotNull(question_id)], ReadSchema: struct<question_id:bigint,creation_date:timestamp,title:string>
      +- BroadcastExchange HashedRelationBroadcastMode(List(input[0, bigint, true])), [id=#313]
         +- HashAggregate(keys=[question_id#0L, month#86], functions=[count(1)])
            +- HashAggregate(keys=[question_id#0L, month#86], functions=[partial_count(1)])
               