## Task I

* run the query below
  * in the query we want to count the number of questions for each user and then join the result with the users dataset
* inspect the query plan and find out what is suboptimal
* try to optimize it

In [None]:
import pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count

from pyspark.sql import Window
from pyspark.sql.types import IntegerType

import os

In [None]:
spark = (
    SparkSession
    .builder
    .appName('Optimize IV')
    .enableHiveSupport()
    .getOrCreate()
)

In [None]:
questions = spark.table('questionsA')

users = spark.table('usersB')

In [None]:
# Disable automatic broadcast joins so Spark does not broadcast either side of the join
spark.conf.set('spark.sql.autoBroadcastJoinThreshold', -1)

In [None]:
num_questions = (
    questions
    .groupBy('user_id')
    .agg(
        count('*').alias('n')
    )
)

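# The noop format executes the whole query but writes nothing,
# which is convenient for inspecting the plan in the Spark UI
(
    users.join(num_questions, 'user_id')
    .write
    .mode('overwrite')
    .format('noop')
    .save()
)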

See the query plan

* What is suboptimal?
* What can we do about it?

The users table is bucketed by `user_id`, so we should be able to join it with the questions without a shuffle in the users branch. That is indeed what happens. The other branch, however, contains two shuffles, although one should be sufficient: the aggregation needs a shuffle, but the join shouldn't require another one.
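To make the shuffle count above concrete, here is a minimal sketch that counts hash-partitioned `Exchange` operators in the text of a physical plan. The helper `count_shuffles` and the abbreviated `sample_plan` are hypothetical and written only for illustration; a real plan from this exercise will contain more detail and different attribute ids.

```python
import re


def count_shuffles(plan_text: str) -> int:
    """Count hash-partitioned Exchange operators in a physical plan printout."""
    # \b keeps ReusedExchange and BroadcastExchange from being counted
    return len(re.findall(r"\bExchange hashpartitioning\(", plan_text))


# Hypothetical, abbreviated plan text (not copied from a real run)
sample_plan = """
SortMergeJoin [user_id#10L], [cast(user_id#20 as bigint)], Inner
:- Sort [user_id#10L ASC NULLS FIRST], false, 0
:  +- FileScan parquet usersB[user_id#10L] Bucketed: true
+- Sort [cast(user_id#20 as bigint) ASC NULLS FIRST], false, 0
   +- Exchange hashpartitioning(cast(user_id#20 as bigint), 200)
      +- HashAggregate(keys=[user_id#20], functions=[count(1)])
         +- Exchange hashpartitioning(user_id#20, 200)
            +- HashAggregate(keys=[user_id#20], functions=[partial_count(1)])
               +- FileScan parquet questionsA[user_id#20]
"""

print(count_shuffles(sample_plan))
```

In the sample, both `Exchange` operators sit in the questions branch, and the upper one partitions by the cast expression rather than by the plain column. For a real DataFrame, the plan text can be viewed with `explain()`, which prints it to standard output.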

The problem is that the join column `user_id` has different data types in the two tables: it is `long` in users but `int` in questions, which you can see by calling `printSchema()`. We can cast the column to `long` before the aggregation, so Spark doesn't need to shuffle the data again just because of the cast before the join.

In [None]:
questions.printSchema()
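Besides reading the schema by eye, the type mismatch can be checked programmatically. The sketch below compares join-key types given as `(name, type)` pairs, the form returned by `DataFrame.dtypes`, where `long` is shown as `bigint`. The helper `mismatched_key_types` and the sample columns `display_name` and `question_id` are hypothetical; only `user_id` and its two types come from the exercise.

```python
def mismatched_key_types(left_dtypes, right_dtypes, keys):
    """Return {key: (left_type, right_type)} for join keys whose types differ."""
    left, right = dict(left_dtypes), dict(right_dtypes)
    return {k: (left[k], right[k]) for k in keys if left[k] != right[k]}


# Hypothetical dtypes in the (name, type) form returned by DataFrame.dtypes
users_dtypes = [("user_id", "bigint"), ("display_name", "string")]
questions_dtypes = [("question_id", "bigint"), ("user_id", "int")]

print(mismatched_key_types(users_dtypes, questions_dtypes, ["user_id"]))
# → {'user_id': ('bigint', 'int')}
```

With the real tables, the call would be `mismatched_key_types(users.dtypes, questions.dtypes, ['user_id'])`; an empty result means the key types already match.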

In [None]:
num_questions = (
    questions
    # Cast user_id to long before grouping so it matches the type in users
    .withColumn('user_id', col('user_id').cast('long'))
    .groupBy('user_id')
    .agg(
        count('*').alias('n')
    )
)

In [None]:
(
    users.join(num_questions, 'user_id')
    .write
    .mode('overwrite')
    .format('noop')
    .save()
)

In [None]:
spark.stop()