# Partition

https://kontext.tech/column/spark/296/data-partitioning-in-spark-pyspark-in-depth-walkthrough  
https://kontext.tech/column/spark/299/data-partitioning-functions-in-spark-pyspark-explained  
https://mungingdata.com/apache-spark/partitionby/

In [1]:
from pyspark.sql import SparkSession

In [2]:
MAX_NUM_CORES = 10

In [3]:
spark = SparkSession.builder \
    .master("spark://IMCHLT276:7077") \
    .config("spark.sql.autoBroadcastJoinThreshold", -1) \
    .config("spark.executor.memory", "2g") \
    .config("spark.executor.cores", "2") \
    .config("spark.cores.max", f"{MAX_NUM_CORES}") \
    .config("spark.local.dir", "/opt/tmp/spark-temp/") \
    .appName("DataSkewness") \
    .getOrCreate()

sc = spark.sparkContext

In [4]:
spark

**Test how partition size affects the output file numbers**

**Test 1** : Number of partition is equal to the cores

In [21]:
df = spark.range(100000)
df.show()

+---+
| id|
+---+
|  0|
|  1|
|  2|
|  3|
|  4|
|  5|
|  6|
|  7|
|  8|
|  9|
| 10|
| 11|
| 12|
| 13|
| 14|
| 15|
| 16|
| 17|
| 18|
| 19|
+---+
only showing top 20 rows



In [11]:
df.rdd.getNumPartitions() == MAX_NUM_CORES

True

In [9]:
! rm -rf /tmp/df_tes/
df.write.parquet("/tmp/df_tes/")
!ls /tmp/df_tes/

**Test 2** : Repartition will affect the number ouput files

In [23]:
df = spark.range(100000)
df = df.repartition(20)
df.rdd.getNumPartitions()

20

In [14]:
! rm -rf /tmp/df_tes/
df.write.parquet("/tmp/df_tes/")
!ls /tmp/df_tes/

_SUCCESS
part-00000-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00001-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00002-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00003-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00004-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00005-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00006-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00007-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00008-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00009-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00010-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00011-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00012-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00013-1876d46c-f683-4ebd-81a7-e99f246916a8-c000.snappy.parquet
part-00014-1876d46c-f683-4ebd-81a7-e99f

**Test 3** : Repartition to 1 and see waht happens?

In [50]:
df = spark.range(10000000)
df = df.repartition(1)
df.rdd.getNumPartitions()

1

In [52]:
! rm -rf /tmp/df_tes/
df.write.parquet("/tmp/df_tes/")
!ls /tmp/df_tes/ -alh

total 39M
drwxrwxr-x  2 mageswarand mageswarand 4.0K Feb 18 13:26 .
drwxrwxrwt 36 root        root        4.0K Feb 18 13:26 ..
-rw-r--r--  1 mageswarand mageswarand    8 Feb 18 13:26 ._SUCCESS.crc
-rw-r--r--  1 mageswarand mageswarand 306K Feb 18 13:26 .part-00000-40d5a79d-70bf-4762-bddb-22a29d64db6e-c000.snappy.parquet.crc
-rw-r--r--  1 mageswarand mageswarand    0 Feb 18 13:26 _SUCCESS
-rw-r--r--  1 mageswarand mageswarand  39M Feb 18 13:26 part-00000-40d5a79d-70bf-4762-bddb-22a29d64db6e-c000.snappy.parquet


**Test 4** : coalesce

In [27]:
df = spark.range(100000)
df = df.coalesce(1)
df.rdd.getNumPartitions()

1

In [28]:
! rm -rf /tmp/df_tes/
df.write.parquet("/tmp/df_tes/")
!ls /tmp/df_tes/ -alh

_SUCCESS  part-00000-2b4e54a3-e11e-4a5a-b53d-96f0e42d9c51-c000.snappy.parquet


**Test 5** : Add a text column and repartition to 1 and see waht happens? Size on local disk doesn't matter. On HDFS this may change

In [6]:
import string, random
import pyspark.sql.functions as F
from pyspark.sql.types import *

letters = string.ascii_lowercase
letters_upper = string.ascii_uppercase

for _i in range(0, 10):
    letters += letters

for _i in range(0, 10):
    letters += letters_upper

print("Number of chars to choose from", len(letters))
sample_string = random.sample(letters, 500)
# print("sample_string", ''.join(sample_string))

def random_string(stringLength=200):
    """Generate a random string of fixed length """
    return ''.join(random.sample(letters, stringLength))

random_string_udf = F.udf(random_string,StringType())

Number of chars to choose from 26884


In [7]:
df = spark.range(1000000)
df = df.withColumn("data", random_string_udf())

In [8]:
df = df.repartition(1, F.col("data"))
df = df.select("data")

In [9]:
df.rdd.getNumPartitions()

1

In [10]:
%time
! rm -rf /tmp/df_tes/
df.write.parquet("/tmp/df_tes/")
!ls /tmp/df_tes/ -alh

CPU times: user 6 µs, sys: 4 µs, total: 10 µs
Wall time: 18.1 µs
total 197M
drwxrwxr-x  2 mageswarand mageswarand 4.0K Feb 18 14:32 .
drwxrwxrwt 36 root        root        4.0K Feb 18 14:32 ..
-rw-r--r--  1 mageswarand mageswarand    8 Feb 18 14:32 ._SUCCESS.crc
-rw-r--r--  1 mageswarand mageswarand 1.6M Feb 18 14:32 .part-00000-90cd6cf3-8eb1-445d-a85e-c9a464f4a094-c000.snappy.parquet.crc
-rw-r--r--  1 mageswarand mageswarand    0 Feb 18 14:32 _SUCCESS
-rw-r--r--  1 mageswarand mageswarand 195M Feb 18 14:32 part-00000-90cd6cf3-8eb1-445d-a85e-c9a464f4a094-c000.snappy.parquet


In [17]:
from pyspark.sql.functions import spark_partition_id

df.groupBy(spark_partition_id()).count().show()

+--------------------+-------+
|SPARK_PARTITION_ID()|  count|
+--------------------+-------+
|                   0|1000000|
+--------------------+-------+



**Test 6** : Read back the stored DF with 1 partition and see how many partitions are there? Equals to number of cores

In [20]:
df = spark.read.parquet("/tmp/df_tes/")
df.rdd.getNumPartitions()

10

**Test 7** Store as many paritions and read it back

In [21]:
df = spark.range(1000000)
df = df.withColumn("data", random_string_udf())
df = df.repartition(32, F.col("data"))
df = df.select("data")

In [22]:
%time
! rm -rf /tmp/df_tes/
df.write.parquet("/tmp/df_tes/")
!ls /tmp/df_tes/ -alh

CPU times: user 3 µs, sys: 2 µs, total: 5 µs
Wall time: 8.82 µs
total 197M
drwxrwxr-x  2 mageswarand mageswarand  12K Feb 18 15:38 .
drwxrwxrwt 36 root        root        4.0K Feb 18 15:38 ..
-rw-r--r--  1 mageswarand mageswarand    8 Feb 18 15:38 ._SUCCESS.crc
-rw-r--r--  1 mageswarand mageswarand  49K Feb 18 15:38 .part-00000-87fc8316-4176-4e93-aa56-2b347163f81a-c000.snappy.parquet.crc
-rw-r--r--  1 mageswarand mageswarand  49K Feb 18 15:38 .part-00001-87fc8316-4176-4e93-aa56-2b347163f81a-c000.snappy.parquet.crc
-rw-r--r--  1 mageswarand mageswarand  49K Feb 18 15:38 .part-00002-87fc8316-4176-4e93-aa56-2b347163f81a-c000.snappy.parquet.crc
-rw-r--r--  1 mageswarand mageswarand  49K Feb 18 15:38 .part-00003-87fc8316-4176-4e93-aa56-2b347163f81a-c000.snappy.parquet.crc
-rw-r--r--  1 mageswarand mageswarand  49K Feb 18 15:38 .part-00004-87fc8316-4176-4e93-aa56-2b347163f81a-c000.snappy.parquet.crc
-rw-r--r--  1 mageswarand mageswarand  49K Feb 18 15:38 .part-00005-87fc8316-4176-4e93-aa56-2

In [23]:
df = spark.read.parquet("/tmp/df_tes/")
df.rdd.getNumPartitions()

11

In [26]:
df.groupBy(spark_partition_id()).count().show()

+--------------------+-----+
|SPARK_PARTITION_ID()|count|
+--------------------+-----+
|                   1|94124|
|                   6|93649|
|                   3|93967|
|                   5|93696|
|                   9|93054|
|                   4|93792|
|                   8|93361|
|                   7|93577|
|                  10|61894|
|                   2|94036|
|                   0|94850|
+--------------------+-----+



**Test 8** : Less number of records and more partitions? 

Spark will try to evenly distribute the data to each partitions. If the total partition number is greater than the actual record count (or RDD size), some partitions will be empty.

In [57]:
df = spark.range(10)
df = df.withColumn("data", random_string_udf())
df = df.repartition(100, F.col("data"))
df = df.select("data")

In [58]:
%time
! rm -rf /tmp/df_tes/
df.write.parquet("/tmp/df_tes/")
!ls /tmp/df_tes/

CPU times: user 6 µs, sys: 3 µs, total: 9 µs
Wall time: 16.2 µs
_SUCCESS
part-00000-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet
part-00022-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet
part-00024-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet
part-00026-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet
part-00032-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet
part-00036-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet
part-00037-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet
part-00050-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet
part-00058-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet
part-00068-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet
part-00070-f52592db-cf7e-40b5-b5e4-623bcb46f17e-c000.snappy.parquet


In [59]:
res = df.groupBy(spark_partition_id()).agg(F.count("data").alias("id")).orderBy("id")

In [60]:
res.show(1000)

+--------------------+---+
|SPARK_PARTITION_ID()| id|
+--------------------+---+
|                  92|  1|
|                  68|  1|
|                  71|  1|
|                  20|  1|
|                  23|  1|
|                  69|  1|
|                  61|  1|
|                  35|  1|
|                   1|  1|
|                  96|  1|
+--------------------+---+



In [61]:
res.count()

10

**Test 9** Default column repartition? Equals to 200

In [70]:
df = spark.range(10000)
df = df.withColumn("data", random_string_udf())
df = df.repartition(F.col("data"))

In [71]:
%time
! rm -rf /tmp/df_tes/
df.write.parquet("/tmp/df_tes/")
!ls /tmp/df_tes/ | wc -l

CPU times: user 6 µs, sys: 3 µs, total: 9 µs
Wall time: 16 µs
201


In [72]:
df.groupBy(spark_partition_id()).count().count()

200

**Test 10** Muli column parition and write partition keys

In [74]:
from datetime import date, timedelta

In [75]:
start_date = date(2019, 1, 1)
data = []
for i in range(0, 50):
    data.append({"Country": "CN", "Date": start_date +
                 timedelta(days=i), "Amount": 10+i})
    data.append({"Country": "AU", "Date": start_date +
                 timedelta(days=i), "Amount": 10+i})

schema = StructType([StructField('Country', StringType(), nullable=False),
                     StructField('Date', DateType(), nullable=False),
                     StructField('Amount', IntegerType(), nullable=False)])

df = spark.createDataFrame(data, schema=schema)
df.show()
print(df.rdd.getNumPartitions())

+-------+----------+------+
|Country|      Date|Amount|
+-------+----------+------+
|     CN|2019-01-01|    10|
|     AU|2019-01-01|    10|
|     CN|2019-01-02|    11|
|     AU|2019-01-02|    11|
|     CN|2019-01-03|    12|
|     AU|2019-01-03|    12|
|     CN|2019-01-04|    13|
|     AU|2019-01-04|    13|
|     CN|2019-01-05|    14|
|     AU|2019-01-05|    14|
|     CN|2019-01-06|    15|
|     AU|2019-01-06|    15|
|     CN|2019-01-07|    16|
|     AU|2019-01-07|    16|
|     CN|2019-01-08|    17|
|     AU|2019-01-08|    17|
|     CN|2019-01-09|    18|
|     AU|2019-01-09|    18|
|     CN|2019-01-10|    19|
|     AU|2019-01-10|    19|
+-------+----------+------+
only showing top 20 rows

10


In [77]:
df = df.withColumn("Year", F.year("Date")).withColumn("Month", F.month("Date")).withColumn("Day", F.dayofmonth("Date"))
df = df.repartition("Year", "Month", "Day", "Country")
print(df.rdd.getNumPartitions())
df.show()

200
+-------+----------+------+----+-----+---+
|Country|      Date|Amount|Year|Month|Day|
+-------+----------+------+----+-----+---+
|     AU|2019-01-21|    30|2019|    1| 21|
|     CN|2019-01-29|    38|2019|    1| 29|
|     AU|2019-01-19|    28|2019|    1| 19|
|     AU|2019-02-07|    47|2019|    2|  7|
|     AU|2019-02-02|    42|2019|    2|  2|
|     AU|2019-02-05|    45|2019|    2|  5|
|     AU|2019-02-08|    48|2019|    2|  8|
|     CN|2019-01-27|    36|2019|    1| 27|
|     CN|2019-01-21|    30|2019|    1| 21|
|     CN|2019-01-25|    34|2019|    1| 25|
|     AU|2019-01-11|    20|2019|    1| 11|
|     CN|2019-02-06|    46|2019|    2|  6|
|     CN|2019-02-19|    59|2019|    2| 19|
|     CN|2019-01-19|    28|2019|    1| 19|
|     AU|2019-02-03|    43|2019|    2|  3|
|     AU|2019-02-09|    49|2019|    2|  9|
|     CN|2019-01-14|    23|2019|    1| 14|
|     AU|2019-01-16|    25|2019|    1| 16|
|     CN|2019-02-16|    56|2019|    2| 16|
|     AU|2019-01-10|    19|2019|    1| 10|
+------

In [78]:
df.write.partitionBy("Year", "Month", "Day", "Country").mode("overwrite").csv("/tmp/df_tes/", header=True)

In [81]:
!tree /tmp/df_tes/

[01;34m/tmp/df_tes/[00m
├── [01;34mYear=2019[00m
│   ├── [01;34mMonth=1[00m
│   │   ├── [01;34mDay=1[00m
│   │   │   ├── [01;34mCountry=AU[00m
│   │   │   │   └── part-00151-9275ac39-b2ca-4aa9-b66d-7290e54ff769.c000.csv
│   │   │   └── [01;34mCountry=CN[00m
│   │   │       └── part-00172-9275ac39-b2ca-4aa9-b66d-7290e54ff769.c000.csv
│   │   ├── [01;34mDay=10[00m
│   │   │   ├── [01;34mCountry=AU[00m
│   │   │   │   └── part-00037-9275ac39-b2ca-4aa9-b66d-7290e54ff769.c000.csv
│   │   │   └── [01;34mCountry=CN[00m
│   │   │       └── part-00112-9275ac39-b2ca-4aa9-b66d-7290e54ff769.c000.csv
│   │   ├── [01;34mDay=11[00m
│   │   │   ├── [01;34mCountry=AU[00m
│   │   │   │   └── part-00026-9275ac39-b2ca-4aa9-b66d-7290e54ff769.c000.csv
│   │   │   └── [01;34mCountry=CN[00m
│   │   │       └── part-00111-9275ac39-b2ca-4aa9-b66d-7290e54ff769.c000.csv
│   │   ├── [01;34mDay=12[00m
│   │   │   ├── [01;34mCountry=AU[00m
│   │   │   │   └── part-00060-9275ac39-b2ca-4aa9

**Read from partitioned data**

Now let’s read the data from the partitioned files with the these criteria:

    Year= 2019
    Month=2
    Day=1
    Country=CN

In [84]:
df = spark.read.csv("/tmp/df_tes/Year=2019/Month=2/Day=1/Country=CN")
print(df.rdd.getNumPartitions()) # only one becaise there is only one record
df.show()

1
+----------+------+
|       _c0|   _c1|
+----------+------+
|      Date|Amount|
|2019-02-01|    41|
+----------+------+



Similarly, we can also query all the data for the second month:

In [86]:
df = spark.read.csv("/tmp/df_tes/Year=2019/Month=2")
print(df.rdd.getNumPartitions())
df.show()

10
+----------+------+---+-------+
|       _c0|   _c1|Day|Country|
+----------+------+---+-------+
|      Date|Amount|  3|     CN|
|2019-02-03|    43|  3|     CN|
|      Date|Amount| 10|     CN|
|2019-02-10|    50| 10|     CN|
|      Date|Amount| 13|     CN|
|2019-02-13|    53| 13|     CN|
|      Date|Amount| 16|     AU|
|2019-02-16|    56| 16|     AU|
|      Date|Amount| 15|     CN|
|2019-02-15|    55| 15|     CN|
|      Date|Amount| 16|     CN|
|2019-02-16|    56| 16|     CN|
|      Date|Amount| 17|     CN|
|2019-02-17|    57| 17|     CN|
|      Date|Amount| 10|     AU|
|2019-02-10|    50| 10|     AU|
|      Date|Amount|  5|     AU|
|2019-02-05|    45|  5|     AU|
|      Date|Amount| 15|     AU|
|2019-02-15|    55| 15|     AU|
+----------+------+---+-------+
only showing top 20 rows



**Use wildcards for partition discovery**

In [88]:
df = spark.read.option("basePath", "/tmp/df_tes/").csv("/tmp/df_tes/Year=*/Month=*/Day=*/Country=CN")
print(df.rdd.getNumPartitions())
df.show()

10
+----------+------+----+-----+---+-------+
|       _c0|   _c1|Year|Month|Day|Country|
+----------+------+----+-----+---+-------+
|      Date|Amount|2019|    2|  3|     CN|
|2019-02-03|    43|2019|    2|  3|     CN|
|      Date|Amount|2019|    1| 17|     CN|
|2019-01-17|    26|2019|    1| 17|     CN|
|      Date|Amount|2019|    2| 10|     CN|
|2019-02-10|    50|2019|    2| 10|     CN|
|      Date|Amount|2019|    1|  3|     CN|
|2019-01-03|    12|2019|    1|  3|     CN|
|      Date|Amount|2019|    1| 24|     CN|
|2019-01-24|    33|2019|    1| 24|     CN|
|      Date|Amount|2019|    2| 13|     CN|
|2019-02-13|    53|2019|    2| 13|     CN|
|      Date|Amount|2019|    1| 25|     CN|
|2019-01-25|    34|2019|    1| 25|     CN|
|      Date|Amount|2019|    1|  1|     CN|
|2019-01-01|    10|2019|    1|  1|     CN|
|      Date|Amount|2019|    1| 21|     CN|
|2019-01-21|    30|2019|    1| 21|     CN|
|      Date|Amount|2019|    2| 15|     CN|
|2019-02-15|    55|2019|    2| 15|     CN|
+-------

We can use wildcards in any part of the path for partition discovery. For example, the following code looks data for month 2 of Country AU:

In [89]:
df = spark.read.option("basePath", "/tmp/df_tes/").csv("/tmp/df_tes/Year=*/Month=2/Day=*/Country=AU")
print(df.rdd.getNumPartitions())
df.show()

10
+----------+------+----+-----+---+-------+
|       _c0|   _c1|Year|Month|Day|Country|
+----------+------+----+-----+---+-------+
|      Date|Amount|2019|    2| 16|     AU|
|2019-02-16|    56|2019|    2| 16|     AU|
|      Date|Amount|2019|    2| 10|     AU|
|2019-02-10|    50|2019|    2| 10|     AU|
|      Date|Amount|2019|    2|  5|     AU|
|2019-02-05|    45|2019|    2|  5|     AU|
|      Date|Amount|2019|    2| 15|     AU|
|2019-02-15|    55|2019|    2| 15|     AU|
|      Date|Amount|2019|    2| 12|     AU|
|2019-02-12|    52|2019|    2| 12|     AU|
|      Date|Amount|2019|    2|  1|     AU|
|2019-02-01|    41|2019|    2|  1|     AU|
|      Date|Amount|2019|    2|  8|     AU|
|2019-02-08|    48|2019|    2|  8|     AU|
|      Date|Amount|2019|    2|  6|     AU|
|2019-02-06|    46|2019|    2|  6|     AU|
|      Date|Amount|2019|    2| 14|     AU|
|2019-02-14|    54|2019|    2| 14|     AU|
|      Date|Amount|2019|    2| 13|     AU|
|2019-02-13|    53|2019|    2| 13|     AU|
+-------

## Data Partitioning Functions  

In [91]:
from pyspark.rdd import portable_hash
from pyspark import Row

In [95]:
# Populate sample data
countries = ("CN", "AU", "US")
data = []
for i in range(1, 13):
    data.append({"ID": i, "Country": countries[i % 3],  "Amount": 10+i})

df = spark.createDataFrame(data)
df.show()

+------+-------+---+
|Amount|Country| ID|
+------+-------+---+
|    11|     AU|  1|
|    12|     US|  2|
|    13|     CN|  3|
|    14|     AU|  4|
|    15|     US|  5|
|    16|     CN|  6|
|    17|     AU|  7|
|    18|     US|  8|
|    19|     CN|  9|
|    20|     AU| 10|
|    21|     US| 11|
|    22|     CN| 12|
+------+-------+---+



In [100]:
def print_partitions(df):
    numPartitions = df.rdd.getNumPartitions()
    print("Total partitions: {}\n".format(numPartitions))
    print("Partitioner: {}\n".format(df.rdd.partitioner))
    df.explain()
    print("\n")
    parts = df.rdd.glom().collect()
    i = 0
    j = 0
    for p in parts:
        print("\nPartition {}:".format(i))
        for r in p:
            print("Row {}:{}".format(j, r))
            j = j+1
        i = i+1

In [101]:
print_partitions(df)

Total partitions: 10

Partitioner: None

== Physical Plan ==
Scan ExistingRDD[Amount#744L,Country#745,ID#746L]



Partition 0:
Row 0:Row(Amount=11, Country='AU', ID=1)

Partition 1:
Row 1:Row(Amount=12, Country='US', ID=2)

Partition 2:
Row 2:Row(Amount=13, Country='CN', ID=3)

Partition 3:
Row 3:Row(Amount=14, Country='AU', ID=4)

Partition 4:
Row 4:Row(Amount=15, Country='US', ID=5)
Row 5:Row(Amount=16, Country='CN', ID=6)

Partition 5:
Row 6:Row(Amount=17, Country='AU', ID=7)

Partition 6:
Row 7:Row(Amount=18, Country='US', ID=8)

Partition 7:
Row 8:Row(Amount=19, Country='CN', ID=9)

Partition 8:
Row 9:Row(Amount=20, Country='AU', ID=10)

Partition 9:
Row 10:Row(Amount=21, Country='US', ID=11)
Row 11:Row(Amount=22, Country='CN', ID=12)


In [103]:
df = df.repartition(3, "Country")

In [104]:
print_partitions(df)

Total partitions: 3

Partitioner: None

== Physical Plan ==
Exchange hashpartitioning(Country#745, 3)
+- Scan ExistingRDD[Amount#744L,Country#745,ID#746L]



Partition 0:

Partition 1:
Row 0:Row(Amount=12, Country='US', ID=2)
Row 1:Row(Amount=13, Country='CN', ID=3)
Row 2:Row(Amount=18, Country='US', ID=8)
Row 3:Row(Amount=15, Country='US', ID=5)
Row 4:Row(Amount=16, Country='CN', ID=6)
Row 5:Row(Amount=19, Country='CN', ID=9)
Row 6:Row(Amount=21, Country='US', ID=11)
Row 7:Row(Amount=22, Country='CN', ID=12)

Partition 2:
Row 8:Row(Amount=11, Country='AU', ID=1)
Row 9:Row(Amount=17, Country='AU', ID=7)
Row 10:Row(Amount=14, Country='AU', ID=4)
Row 11:Row(Amount=20, Country='AU', ID=10)


**You may expect that each partition includes data for each Country but that is not the case. Why? Because repartition function by default uses hash partitioning. For different country code, it may be allocated into the same partition number.**
We can verify this by using the following code to calculate the hash

In [116]:
udf_portable_hash = F.udf(lambda str: portable_hash(str))
df = df.withColumn("Hash#", udf_portable_hash(df.Country))
df = df.withColumn("Partition#", df["Hash#"] % 3)
df.show()

+------+-------+---+--------------------+----------+
|Amount|Country| ID|               Hash#|Partition#|
+------+-------+---+--------------------+----------+
|    13|     CN|  3|-7458853143580063552|      -1.0|
|    19|     CN|  9|-7458853143580063552|      -1.0|
|    12|     US|  2|-8328537658613580243|      -1.0|
|    18|     US|  8|-8328537658613580243|      -1.0|
|    15|     US|  5|-8328537658613580243|      -1.0|
|    16|     CN|  6|-7458853143580063552|      -1.0|
|    21|     US| 11|-8328537658613580243|      -1.0|
|    22|     CN| 12|-7458853143580063552|      -1.0|
|    14|     AU|  4| 6593628092971972691|       0.0|
|    20|     AU| 10| 6593628092971972691|       0.0|
|    11|     AU|  1| 6593628092971972691|       0.0|
|    17|     AU|  7| 6593628092971972691|       0.0|
+------+-------+---+--------------------+----------+



The output shows that each country’s data is now located in the same partition:

In [117]:
countries = ("CN", "AU", "US")
def country_partitioning(k):
    return countries.index(k)
    
udf_country_hash = F.udf(lambda str: country_partitioning(str))

In [119]:
numPartitions = 3
# df = df.partitionBy(numPartitions, country_partitioning)
df = df.withColumn("Hash#", udf_country_hash(df['Country']))
df = df.withColumn("Partition#", df["Hash#"] % numPartitions)
df.orderBy('Country').show()

+------+-------+---+-----+----------+
|Amount|Country| ID|Hash#|Partition#|
+------+-------+---+-----+----------+
|    11|     AU|  1|    1|       1.0|
|    17|     AU|  7|    1|       1.0|
|    14|     AU|  4|    1|       1.0|
|    20|     AU| 10|    1|       1.0|
|    16|     CN|  6|    0|       0.0|
|    22|     CN| 12|    0|       0.0|
|    13|     CN|  3|    0|       0.0|
|    19|     CN|  9|    0|       0.0|
|    15|     US|  5|    2|       2.0|
|    18|     US|  8|    2|       2.0|
|    21|     US| 11|    2|       2.0|
|    12|     US|  2|    2|       2.0|
+------+-------+---+-----+----------+



In [120]:
print_partitions(df)

Total partitions: 3

Partitioner: None

== Physical Plan ==
*(1) Project [Amount#744L, Country#745, ID#746L, pythonUDF1#894 AS Hash##863, (cast(pythonUDF1#894 as double) % 3.0) AS Partition##869]
+- BatchEvalPython [<lambda>(Country#745), <lambda>(Country#745)], [Amount#744L, Country#745, ID#746L, pythonUDF0#893, pythonUDF1#894]
   +- Exchange hashpartitioning(Country#745, 3)
      +- Scan ExistingRDD[Amount#744L,Country#745,ID#746L]



Partition 0:

Partition 1:
Row 0:Row(Amount=12, Country='US', ID=2, Hash#='2', Partition#=2.0)
Row 1:Row(Amount=18, Country='US', ID=8, Hash#='2', Partition#=2.0)
Row 2:Row(Amount=13, Country='CN', ID=3, Hash#='0', Partition#=0.0)
Row 3:Row(Amount=19, Country='CN', ID=9, Hash#='0', Partition#=0.0)
Row 4:Row(Amount=15, Country='US', ID=5, Hash#='2', Partition#=2.0)
Row 5:Row(Amount=16, Country='CN', ID=6, Hash#='0', Partition#=0.0)
Row 6:Row(Amount=21, Country='US', ID=11, Hash#='2', Partition#=2.0)
Row 7:Row(Amount=22, Country='CN', ID=12, Hash#='0', Pa

In [122]:
print_partitions(df.repartition(3, "Partition#"))

Total partitions: 3

Partitioner: None

== Physical Plan ==
Exchange hashpartitioning(Partition##869, 3)
+- *(1) Project [Amount#744L, Country#745, ID#746L, pythonUDF1#898 AS Hash##863, (cast(pythonUDF1#898 as double) % 3.0) AS Partition##869]
   +- BatchEvalPython [<lambda>(Country#745), <lambda>(Country#745)], [Amount#744L, Country#745, ID#746L, pythonUDF0#897, pythonUDF1#898]
      +- Exchange hashpartitioning(Country#745, 3)
         +- Scan ExistingRDD[Amount#744L,Country#745,ID#746L]



Partition 0:
Row 0:Row(Amount=15, Country='US', ID=5, Hash#='2', Partition#=2.0)
Row 1:Row(Amount=21, Country='US', ID=11, Hash#='2', Partition#=2.0)
Row 2:Row(Amount=12, Country='US', ID=2, Hash#='2', Partition#=2.0)
Row 3:Row(Amount=18, Country='US', ID=8, Hash#='2', Partition#=2.0)

Partition 1:
Row 4:Row(Amount=13, Country='CN', ID=3, Hash#='0', Partition#=0.0)
Row 5:Row(Amount=19, Country='CN', ID=9, Hash#='0', Partition#=0.0)
Row 6:Row(Amount=16, Country='CN', ID=6, Hash#='0', Partition#=0.0