# Creating Hash-Based Partitions on an S3 Dataset

Partition pruning can significantly speed up queries on an S3 dataset.
For datasets that have no natural partition field, and especially for key lookups such as querying by user ID or order ID, partitioning the data by hashing the key field into a fixed number of buckets can significantly speed up queries. Because the bucket is derived from the key itself, this helps equality lookups on that key; range filters on the key still have to read every bucket.

In this notebook we use Spark's `crc32` function to hash a field and then take the hash modulo n to assign each row to one of n buckets.
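
To see what this computes outside Spark, here is a minimal Python sketch using the standard library's `zlib.crc32`. It assumes, like the Spark expression used below, that the key is first turned into its decimal string and encoded as UTF-8. `zlib.crc32` computes the standard CRC-32 checksum, which as far as we know is the same one Spark's and Athena's `crc32` return, but it is worth spot-checking against a few rows. The helper name `hash_bucket` is hypothetical.

```python
import zlib


def hash_bucket(key, num_buckets=20):
    """Hypothetical helper: map a key to a bucket in [0, num_buckets).

    Mirrors mod(crc32(<key as string>), num_buckets).
    """
    # Stringify the key and encode it as UTF-8 bytes before hashing
    data = str(key).encode("utf-8")
    # zlib.crc32 returns an unsigned 32-bit integer in Python 3,
    # so the modulo result is always non-negative
    return zlib.crc32(data) % num_buckets


for order_id in (100, 115855, 115856):
    print(order_id, hash_bucket(order_id))
```

The same key always lands in the same bucket, which is what lets a query recompute the bucket from the key it is looking for.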

Let's look at a sample orders dataset.

In [1]:
```python
df = spark.read.parquet("s3://<databucket>/xyzzypq/")
```

In [2]:
```python
df.show(3)
```

```
+------------------+-----------------------+-------+-----------+-------------------+--------+----------+--------+---------+-------+------------------+------------------+----------+
|          DISCOUNT|LAST_MODIFIED_TIMESTAMP|LINE_ID|LINE_NUMBER|         ORDER_DATE|ORDER_ID|PRODUCT_ID|QUANTITY|SHIP_MODE|SITE_ID|       SUPPLY_COST|               TAX|UNIT_PRICE|
+------------------+-----------------------+-------+-----------+-------------------+--------+----------+--------+---------+-------+------------------+------------------+----------+
|505.15211456522854|    2013-09-20T22:40:04|      1|          1|2013-09-20T00:00:00|  115855|       410|      35|  ONE-DAY|    386|389.05923561572445| 625.2594318622095|     618.0|
|189.40738665133208|    2013-09-20T22:40:04|      2|          2|2013-09-20T00:00:00|  115855|        28|      76|  ONE-DAY|    386| 289.2103998597069|  298.229409855354|     498.0|
|160.04322240073517|    2013-09-20T22:40:04|      3|          3|2013-09-20T00:00:00|  115855|  
```

Next, we add a PARTITION_KEY column by hashing ORDER_ID with `crc32` and taking the result modulo 20, which yields 20 buckets.

In [3]:
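```python
from pyspark.sql.functions import expr

# concat(ORDER_ID) turns the numeric ID into a string; crc32 hashes it; mod 20 gives bucket 0..19
df = df.withColumn("PARTITION_KEY", expr("mod(crc32(concat(ORDER_ID)), 20)"))
df.createOrReplaceTempView("order_lines_v")
df.show(3)
```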

```
+------------------+-----------------------+-------+-----------+-------------------+--------+----------+--------+---------+-------+------------------+------------------+----------+-------------+
|          DISCOUNT|LAST_MODIFIED_TIMESTAMP|LINE_ID|LINE_NUMBER|         ORDER_DATE|ORDER_ID|PRODUCT_ID|QUANTITY|SHIP_MODE|SITE_ID|       SUPPLY_COST|               TAX|UNIT_PRICE|PARTITION_KEY|
+------------------+-----------------------+-------+-----------+-------------------+--------+----------+--------+---------+-------+------------------+------------------+----------+-------------+
|505.15211456522854|    2013-09-20T22:40:04|      1|          1|2013-09-20T00:00:00|  115855|       410|      35|  ONE-DAY|    386|389.05923561572445| 625.2594318622095|     618.0|           19|
|189.40738665133208|    2013-09-20T22:40:04|      2|          2|2013-09-20T00:00:00|  115855|        28|      76|  ONE-DAY|    386| 289.2103998597069|  298.229409855354|     498.0|           19|
|160.04322240073517|    2
```

In [4]:
```python
spark.sql("SELECT PARTITION_KEY, count(1) AS RECORD_COUNT FROM order_lines_v GROUP BY PARTITION_KEY ORDER BY 1").show(20)
```

```
+-------------+------------+
|PARTITION_KEY|RECORD_COUNT|
+-------------+------------+
|            0|      171761|
|            1|      172459|
|            2|      173135|
|            3|      171889|
|            4|      171900|
|            5|      171185|
|            6|      172586|
|            7|      172765|
|            8|      172008|
|            9|      172573|
|           10|      172168|
|           11|      170935|
|           12|      173715|
|           13|      172674|
|           14|      173131|
|           15|      173176|
|           16|      174665|
|           17|      173062|
|           18|      172022|
|           19|      172939|
+-------------+------------+
```
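
To put a number on "uniform", the short sketch below recomputes the spread of the bucket counts. The counts are copied from the output above; what spread you would accept for your own data is a judgment call.

```python
# Record counts per PARTITION_KEY (buckets 0..19), copied from the query output
counts = [
    171761, 172459, 173135, 171889, 171900, 171185, 172586,
    172765, 172008, 172573, 172168, 170935, 173715, 172674,
    173131, 173176, 174665, 173062, 172022, 172939,
]

total = sum(counts)
mean = total / len(counts)
# Largest relative deviation of any single bucket from the mean bucket size
max_rel_dev = max(abs(c - mean) for c in counts) / mean

print(f"total rows: {total}")
print(f"mean per bucket: {mean:.1f}")
print(f"max deviation from mean: {max_rel_dev:.2%}")
print(f"largest/smallest bucket: {max(counts) / min(counts):.4f}")
```

With these counts (3,450,748 rows in total, 172,537.4 per bucket on average), the largest bucket (16) is about 1.2% above the mean and the smallest (11) about 0.9% below it.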

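Now that the rows are spread evenly across 20 buckets, let's write the data partitioned by the PARTITION_KEY column.
We also sort the rows within each partition by ORDER_ID. This keeps the Parquet min/max statistics for ORDER_ID narrow, so query engines can skip row groups that cannot contain the requested key.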

In [5]:
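```python
partitionColumns = ["PARTITION_KEY"]
# Group each bucket's rows into one task, sort them by ORDER_ID, and write one PARTITION_KEY=<n>/ directory per bucket
df.repartition(*partitionColumns).sortWithinPartitions("ORDER_ID") \
    .write.mode("overwrite").partitionBy(*partitionColumns).parquet("s3://<databucket>/xyzzypq-p1/")
```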
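
Because `partitionBy` writes Hive-style directories (`PARTITION_KEY=0/` through `PARTITION_KEY=19/`), a lookup for a single ORDER_ID only needs the files under one of those 20 prefixes. The sketch below shows which prefix a given key maps to, assuming the same CRC-32-modulo-20 expression and the placeholder output path used above; `partition_prefix` is a hypothetical helper, not part of Spark or any AWS library.

```python
import zlib

NUM_BUCKETS = 20
# Placeholder output location from the write above
BASE_PATH = "s3://<databucket>/xyzzypq-p1/"


def partition_prefix(order_id, num_buckets=NUM_BUCKETS, base_path=BASE_PATH):
    """Hypothetical helper: return the partition directory holding a given ORDER_ID."""
    # Same derivation as mod(crc32(concat(ORDER_ID)), 20) in the Spark cell above
    bucket = zlib.crc32(str(order_id).encode("utf-8")) % num_buckets
    return f"{base_path}PARTITION_KEY={bucket}/"


print(partition_prefix(115855))
```

With evenly sized buckets, a point lookup therefore reads roughly 1/20 of the data. A range predicate on ORDER_ID cannot be mapped to a single prefix, so it still touches every bucket.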

## Querying the dataset using Amazon Athena

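Amazon Athena also provides a `crc32` function, so a query can derive the partition value from the ORDER_ID field. Athena's `crc32` takes `varbinary`, so the ID is first cast to `varchar` and converted with `to_utf8`; the result is cast back to `varchar` because the partition column is stored as a string. The queries below assume that the tables `tempdb.xyzzypq` (original data) and `tempdb.xyzzypq_p1` (partitioned copy) are already defined in the catalog.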
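
An alternative, when the application issuing the query already knows the ORDER_ID, is to compute the bucket client-side and pass it to Athena as a literal, so that partition pruning does not depend on the engine inferring the partition value. This is a sketch under the same assumptions as above (CRC-32 of the decimal string, 20 buckets, partition column stored as a string); `build_lookup_sql` is hypothetical, and the key is coerced to `int` so that only a number is ever formatted into the SQL text.

```python
import zlib


def build_lookup_sql(order_id, table='"tempdb"."xyzzypq_p1"', num_buckets=20):
    """Hypothetical helper: build an Athena query with a precomputed partition filter."""
    # Coerce to int so only a numeric key is formatted into the SQL string
    order_id = int(order_id)
    # Same bucket as cast(mod(crc32(to_utf8(cast(order_id as varchar))), 20) as varchar)
    bucket = zlib.crc32(str(order_id).encode("utf-8")) % num_buckets
    # Partition values are strings, so compare against a quoted literal
    return (
        f"SELECT * FROM {table} "
        f"WHERE partition_key = '{bucket}' AND order_id = {order_id}"
    )


print(build_lookup_sql(100))
```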

In [6]:
```python
%%local

import time

import boto3
import pandas as pd

region = 'us-east-1'
defaultdb = "tempdb"
default_output = "s3://<athena_query_results_bucket>/"

# Execute an Athena SQL statement and poll until it reaches a terminal state
def executeQuery(query, database=defaultdb, s3_output=default_output, poll=0.5):
    athena = boto3.client('athena', region_name=region)
    response = athena.start_query_execution(
        QueryString=query,
        QueryExecutionContext={'Database': database},
        ResultConfiguration={'OutputLocation': s3_output},
    )
    queryExecutionId = response['QueryExecutionId']
    print('Execution ID: ' + queryExecutionId)

    state = 'QUEUED'
    while state in ('QUEUED', 'RUNNING'):
        response = athena.get_query_execution(QueryExecutionId=queryExecutionId)
        state = response['QueryExecution']['Status']['State']
        print(state)
        if state in ('QUEUED', 'RUNNING'):
            time.sleep(poll)
        elif state in ('FAILED', 'CANCELLED'):
            # Raise instead of returning a response that has no result file
            reason = response['QueryExecution']['Status'].get('StateChangeReason', '')
            raise RuntimeError(f'Query {state}: {reason}')
    return response

# Read the CSV result file Athena wrote to S3 into a pandas DataFrame
# (pandas needs the s3fs package installed to read s3:// paths)
def read_from_athena(sql):
    response = executeQuery(sql)
    return pd.read_csv(response['QueryExecution']['ResultConfiguration']['OutputLocation'])
```


In [7]:
```python
%%time

# Baseline: the unpartitioned table, so there are no partitions to prune
sql = """SELECT * FROM "tempdb"."xyzzypq" WHERE order_id = 100"""

response = executeQuery(sql)
```

```
Execution ID: 8a1419d3-b644-497f-8525-5b1f24325163
QUEUED
RUNNING
RUNNING
RUNNING
RUNNING
RUNNING
SUCCEEDED
CPU times: user 178 ms, sys: 8.01 ms, total: 186 ms
Wall time: 5.33 s
```


In [8]:
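```python
%%time

# Partitioned table: the inner filter derives partition_key from order_id, so with
# order_id = 100 Athena can restrict the scan to a single PARTITION_KEY directory
sql = """SELECT * FROM
(SELECT * FROM "tempdb"."xyzzypq_p1"
 WHERE partition_key = cast(mod(crc32(to_utf8(cast(order_id AS varchar))), 20) AS varchar))
WHERE order_id = 100"""

response = executeQuery(sql)
```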

```
Execution ID: 82b76930-09b4-47dc-be15-0b79bba9e125
QUEUED
RUNNING
RUNNING
SUCCEEDED
CPU times: user 39.7 ms, sys: 1.07 ms, total: 40.8 ms
Wall time: 2.27 s
```
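
The wall times above also include client overhead and the 0.5 s polling interval of `executeQuery`. For a more direct comparison, the `get_query_execution` response that `executeQuery` returns carries a `Statistics` section with, among other fields, `EngineExecutionTimeInMillis` and `DataScannedInBytes`. A minimal sketch that pulls those two out (the helper name `query_stats` is hypothetical):

```python
def query_stats(response):
    """Hypothetical helper: extract engine time and data scanned from an Athena response."""
    # Statistics may be missing if the query never ran, so fall back to None
    stats = response["QueryExecution"].get("Statistics", {})
    return {
        "engine_ms": stats.get("EngineExecutionTimeInMillis"),
        "scanned_mib": (
            stats["DataScannedInBytes"] / 1024 ** 2
            if "DataScannedInBytes" in stats
            else None
        ),
    }
```

Calling `query_stats(response)` after each of the two queries would show whether the partitioned query also scans less data, which is what Athena bills on.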


## Conclusion

**In this test, the partitioned query's wall-clock time dropped from 5.33 s to 2.27 s, a reduction of roughly 57%.** This is a single run that also includes client-side polling, so expect some variation between runs. The inner subquery can be saved as an Amazon Athena view, so users filter on ORDER_ID alone and the partition_key predicate is applied automatically.
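
A minimal sketch of such a view, assuming the partitioned table is `"tempdb"."xyzzypq_p1"` as in the query above; the view name `xyzzypq_by_order` and the helper `create_lookup_view_sql` are hypothetical. It relies on Athena pushing an outer `order_id = ...` filter into the view so the partition value can be derived, which is what the subquery test above exercises; checking the data scanned for a query against the view is a sensible way to confirm pruning still happens. As with the subquery, only equality lookups on ORDER_ID benefit.

```python
def create_lookup_view_sql(view='"tempdb"."xyzzypq_by_order"',
                           table='"tempdb"."xyzzypq_p1"',
                           num_buckets=20):
    """Hypothetical helper: return Athena DDL for a view that derives partition_key from order_id."""
    # The WHERE clause is true for every correctly bucketed row; its purpose is to
    # let the engine turn an outer `order_id = <literal>` filter into a partition filter
    return (
        f"CREATE OR REPLACE VIEW {view} AS\n"
        f"SELECT * FROM {table}\n"
        f"WHERE partition_key = cast(mod(crc32(to_utf8(cast(order_id AS varchar))), "
        f"{num_buckets}) AS varchar)"
    )


print(create_lookup_view_sql())
```

The statement could be run with the notebook's `executeQuery`, after which users would simply query `SELECT * FROM "tempdb"."xyzzypq_by_order" WHERE order_id = 100`.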