# Complex Data Types

In this notebook you will continue to improve the text categorization query implemented in the `Text categorization` and `User Defined Functions` notebooks. For each question, find out which category occurs most often in the text. Consider only questions with at least one occurrence.

In [None]:
"""
Example output:
+-----------+--------+---------+
|question_id|category|frequency|
+-----------+--------+---------+
|   59611343|    java|        2| # java occurs 2 times in the question text
|   21038752|  python|        5| # python occurs 5 times in the question text
|   44381369|     sql|        5|
+-----------+--------+---------+
"""

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, udf, lit, array, struct, reverse, array_sort, expr
from pyspark.sql.types import IntegerType

import os
import re

In [None]:
spark = (
    SparkSession
    .builder
    .appName('UDFs I')
    .getOrCreate()
)

In [None]:
base_path = os.getcwd()

project_path = '/'.join(base_path.split('/')[:-3])  # three directories above the current working directory

questions_input_path = os.path.join(project_path, 'output/questions-transformed')
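The string slicing above drops the last three path components, i.e. it moves three directories up from the notebook's working directory. A sketch of the same computation with `pathlib`, assuming a POSIX-style path; `project_root` and the example path are hypothetical:

```python
from pathlib import PurePosixPath

def project_root(base_path: str) -> str:
    """Hypothetical helper: three directories above base_path,
    equivalent to '/'.join(base_path.split('/')[:-3]) for absolute POSIX paths."""
    # parents[0] is the direct parent, parents[2] is three levels up
    return str(PurePosixPath(base_path).parents[2])

print(project_root('/home/user/project/notebooks/spark/chapter'))  # /home/user/project
```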

In [None]:
questionsDF = (
    spark
    .read
    .option('path', questions_input_path)
    .load()
)

In [None]:
# This is what we implemented in `User Defined Functions` notebook.

categories = ['java', 'sql', 'python', 'spark']

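@udf(IntegerType())
def count_occurences(message, category):
    # a null body would make re.findall raise a TypeError, so count it as 0 occurrences
    if message is None:
        return 0
    return len(re.findall(r"{}".format(category), message, re.IGNORECASE))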

def get_c(df):
    for category in categories:
        df = df.withColumn(category, count_occurences(col('body'), lit(category)))
    return df

result = get_c(questionsDF.select('question_id', 'body'))
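For intuition, here is the same counting logic in plain Python, outside Spark. It is only an illustration: `count_occurrences_py`, `count_all` and the sample text are hypothetical, and the `re.escape` call is our addition. Note that the pattern matches substrings, so `JavaScript` also counts as an occurrence of `java` and `PySpark` as one of `spark`.

```python
import re

CATEGORIES = ['java', 'sql', 'python', 'spark']

def count_occurrences_py(message, category):
    """Hypothetical plain-Python counterpart of the count_occurences UDF."""
    if message is None:
        return 0
    # re.escape guards against categories that contain regex metacharacters
    return len(re.findall(re.escape(category), message, re.IGNORECASE))

def count_all(message):
    """Map each category to its number of (case-insensitive, substring) matches."""
    return {c: count_occurrences_py(message, c) for c in CATEGORIES}

sample = "How do I call Java from PySpark? My Spark job uses JavaScript too."
print(count_all(sample))  # {'java': 2, 'sql': 0, 'python': 0, 'spark': 2}
```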

Now we will take it one step further:

In [None]:
result.show()

### Find the most relevant category

* The result now contains the number of occurrences for each category.
* For each question, find the category with the most occurrences.

Hint
* For each question, create an array of structs where each struct has two subfields:
    * frequency (number of occurrences)
    * category_name
* Use a for-loop over the `categories` list to create the array
* Sort the array in descending order (put the `frequency` subfield in the first position of the struct, so that it becomes the primary sort key)
* Access the subfields of the first element
* Functions you will need:
    * [array_sort](http://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.functions.array_sort.html#pyspark.sql.functions.array_sort)
    * [reverse](http://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.functions.reverse.html#pyspark.sql.functions.reverse)
    * [struct](http://spark.apache.org/docs/latest/api/python/reference/api/pyspark.sql.functions.struct.html#pyspark.sql.functions.struct)
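The steps in the hint can be mimicked in plain Python to check the logic before writing the Spark version. This is a sketch under one assumption: Python tuples, like Spark structs, compare element by element, so a `(frequency, category_name)` pair plays the role of the struct. `most_frequent_category` is a hypothetical helper, not part of the notebook.

```python
from typing import Dict, List, Optional, Tuple

def most_frequent_category(counts: Dict[str, int]) -> Optional[Tuple[str, int]]:
    """Mimic array(struct(frequency, category_name)) -> array_sort -> reverse -> [0]."""
    # one (frequency, category_name) pair per category, like the array of structs
    pairs: List[Tuple[int, str]] = [(freq, name) for name, freq in counts.items()]
    # ascending sort followed by reversal, as with array_sort + reverse
    ordered = list(reversed(sorted(pairs)))
    frequency, category = ordered[0]
    # keep only questions with at least one occurrence
    return (category, frequency) if frequency > 0 else None

print(most_frequent_category({'java': 2, 'sql': 0, 'python': 5, 'spark': 1}))  # ('python', 5)
print(most_frequent_category({'java': 0, 'sql': 0, 'python': 0, 'spark': 0}))  # None
```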

In [None]:
# Using a for-loop, create the list of expressions that we will pass as an argument to the array function
# Use the struct function: struct(a.alias(...), b.alias(...)), where a is the frequency column and b is the category name (as a literal)
# The aliases are important here, because the fields must have the same names
# in every element of the array

s = []
for c in categories:
    s.append(struct(col(c).alias('frequency'), lit(c).alias('category_name')))

In [None]:
# Create the array using the array function
# Sort the array in descending order (array_sort + reverse) and take the first element

(
    result
    .withColumn('categories', array(s))
    .withColumn('categories', reverse(array_sort('categories')))
    .select(
        'question_id',
        col('categories.category_name')[0].alias('category'),
        col('categories.frequency')[0].alias('frequency')
    )
    .filter(col('frequency') > 0)
).show(n=10)

In [None]:
# or equivalently you can first access the first element of the array and then select the category and frequency:

(
    result
    .withColumn('categories', array(s))
    .withColumn('categories', reverse(array_sort('categories')))
    .select(
        'question_id',
        col('categories')[0].alias('category') # get the first struct
    )
    .select(
        'question_id',
        col('category.category_name').alias('category'), # access the fields of the struct
        'category.frequency'
    )
    .filter(col('frequency') > 0)
).show(n=10)

#### Note
* When you sort an array of structs, the order of the subfields matters.
* With the default `array_sort`, we had no control over ties: when two categories had the same frequency, the second subfield (`category_name`) decided the order, not a rule we chose.
* We also relied on an implementation detail: structs are compared field by field, starting with the first subfield.
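To see what the default sort does with ties, the plain-Python analogue below applies the same "ascending sort, reverse, take the first element" steps to `(frequency, category_name)` tuples. Assuming structs are compared field by field (the detail the note above relies on), a tie on `frequency` falls through to `category_name`, so the alphabetically last category wins after `reverse`. With these four names, `sql` happens to sort last, so it already wins every tie, but only by coincidence of spelling. `top_after_sort_and_reverse` is a hypothetical helper.

```python
def top_after_sort_and_reverse(pairs):
    """Plain-Python analogue of array_sort + reverse + [0] on (frequency, name) pairs."""
    return list(reversed(sorted(pairs)))[0]

# java and python tie: the name decides, and 'python' > 'java'
print(top_after_sort_and_reverse([(3, 'java'), (3, 'python'), (1, 'spark')]))  # (3, 'python')
# spark and sql tie: 'sql' > 'spark' because 'q' > 'p'
print(top_after_sort_and_reverse([(2, 'spark'), (2, 'sql'), (0, 'java')]))  # (2, 'sql')
```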

#### Customize the sort

If the frequency is the same for two categories, prefer the `sql` category.

Hint
* Use [array_sort](https://spark.apache.org/docs/latest/api/sql/index.html#array_sort) as a SQL expression inside `expr`; the SQL version accepts a custom comparator function
* For a specific example, see my [article](https://towardsdatascience.com/did-you-know-this-in-spark-sql-a7398bfcc41e)
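The comparator follows the usual convention: it returns a negative number if the left element should come first, a positive number if it should come after, and 0 otherwise. A plain-Python analogue using `functools.cmp_to_key` makes the rule easy to test. It checks for `sql` on both sides, so the comparator stays consistent (`compare(a, b) == -compare(b, a)`). The dicts stand in for the structs, and `compare_categories` is a hypothetical name.

```python
from functools import cmp_to_key

def compare_categories(left, right):
    """Descending by frequency; on a tie, 'sql' goes first whichever side it is on."""
    if left['frequency'] < right['frequency']:
        return 1
    if left['frequency'] > right['frequency']:
        return -1
    left_is_sql = left['category_name'] == 'sql'
    right_is_sql = right['category_name'] == 'sql'
    if left_is_sql and not right_is_sql:
        return -1
    if right_is_sql and not left_is_sql:
        return 1
    return 0

categories = [
    {'frequency': 2, 'category_name': 'java'},
    {'frequency': 2, 'category_name': 'sql'},
    {'frequency': 1, 'category_name': 'spark'},
]
ordered = sorted(categories, key=cmp_to_key(compare_categories))
print([c['category_name'] for c in ordered])  # ['sql', 'java', 'spark']
```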

In [None]:
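# The comparator has to be consistent: 'sql' must win the tie whether it is the left or the right element
(
    result
    .withColumn('categories', array(s))
    .withColumn('categories', expr(
        """array_sort(categories, (left, right) -> case when left.frequency < right.frequency then 1
            when left.frequency > right.frequency then -1
            when left.frequency == right.frequency and left.category_name == 'sql' then -1
            when left.frequency == right.frequency and right.category_name == 'sql' then 1
            else 0 end)"""
    ))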
    .select(
        'question_id',
        col('categories.category_name')[0].alias('category'),
        col('categories.frequency')[0].alias('frequency')
    )
    .filter(col('frequency') > 0)
).show(n=10)

In [None]:
spark.stop()