## Source: [Partitioning in Apache Spark](https://medium.com/parrot-prediction/partitioning-in-apache-spark-8134ad840b0)

### A good summary of Spark's default partitioning behavior can be found [here](https://techmagie.wordpress.com/2015/12/19/understanding-spark-partitioning/)

In [None]:
from pyspark import SparkContext

In [None]:
# Sample data shared by the cells below: the ten integers 0 through 9.
nums = range(10)
print(nums)

### Using a single core

In [None]:
# Distribute `nums` with the "local" master and report how it was partitioned.
with SparkContext("local") as sc:
    rdd = sc.parallelize(nums)

    partition_count = rdd.getNumPartitions()
    print("Number of partitions: {}".format(partition_count))

    current_partitioner = rdd.partitioner
    print("Partitioner: {}".format(current_partitioner))

    # glom() turns every partition into a list, so the layout becomes visible.
    partition_layout = rdd.glom().collect()
    print("Partitions structure: {}".format(partition_layout))

### Using 2 cores

In [None]:
# Same experiment with the "local[2]" master; also print the context's
# default parallelism next to the resulting partition count.
with SparkContext("local[2]") as sc:
    rdd = sc.parallelize(nums)

    default_parallelism = sc.defaultParallelism
    print("Default parallelism: {}".format(default_parallelism))

    partition_count = rdd.getNumPartitions()
    print("Number of partitions: {}".format(partition_count))

    current_partitioner = rdd.partitioner
    print("Partitioner: {}".format(current_partitioner))

    # glom() turns every partition into a list, so the layout becomes visible.
    partition_layout = rdd.glom().collect()
    print("Partitions structure: {}".format(partition_layout))

### Using a single core, but manually specifying the number of partitions

In [None]:
# Ask for 15 partitions explicitly, which is more than the 10 elements in
# `nums`, and print the resulting layout.
with SparkContext("local") as sc:
    requested_partitions = 15
    rdd = sc.parallelize(nums, requested_partitions)

    partition_count = rdd.getNumPartitions()
    print("Number of partitions: {}".format(partition_count))

    current_partitioner = rdd.partitioner
    print("Partitioner: {}".format(current_partitioner))

    partition_layout = rdd.glom().collect()
    print("Partitions structure: {}".format(partition_layout))

### Using 2 cores and specifying a partitioner

In [None]:
# Turn every number into a (key, value) pair keyed by itself, then spread the
# pairs over two partitions with partitionBy()'s default partition function.
with SparkContext("local[2]") as sc:
    keyed = sc.parallelize(nums).map(lambda value: (value, value))
    rdd = keyed.partitionBy(2).persist()

    partition_count = rdd.getNumPartitions()
    print("Number of partitions: {}".format(partition_count))

    current_partitioner = rdd.partitioner
    print("Partitioner: {}".format(current_partitioner))

    partition_layout = rdd.glom().collect()
    print("Partitions structure: {}".format(partition_layout))

# Rule used to place a pair: partition = partitionFunc(key) % num_partitions

In [None]:
import os
# NOTE(review): presumably needed because portable_hash checks that
# PYTHONHASHSEED is present in the environment — confirm for the installed
# PySpark version.
os.environ['PYTHONHASHSEED'] = '42'

from pyspark.rdd import portable_hash

# Reproduce by hand the partition each element of `nums` is assigned to.
num_partitions = 2
for el in nums:
    hashed = portable_hash(el)
    target_partition = hashed % num_partitions
    print("Element: [{}]: {} % {} = partition {}".format(el, hashed, num_partitions, target_partition))

### Another example

In [None]:
# Sample sales records: who paid, how much, and in which country.
transactions = [
    dict(name='Bob', amount=100, country='United Kingdom'),
    dict(name='James', amount=15, country='United Kingdom'),
    dict(name='Marek', amount=51, country='Poland'),
    dict(name='Johannes', amount=200, country='Germany'),
    dict(name='Paul', amount=75, country='Poland'),
]

### Using a partitioner by country

In [None]:
# Dummy implementation assuring that data for each country is in one partition
def country_partitioner(country):
    """Return a stable, non-negative integer derived from a country name.

    Intended as the partition function passed to ``partitionBy``; the caller
    takes the result modulo the number of partitions.

    The built-in ``hash()`` is deliberately avoided: for ``str`` it is salted
    per interpreter process unless PYTHONHASHSEED is fixed at start-up, so
    two processes can disagree on the value for the same country. CRC-32
    depends only on the bytes of the name.

    :param country: country name as a ``str``
    :return: CRC-32 of the UTF-8 encoded name, in the range 0..2**32-1
    """
    # Function-scope import keeps this cell self-contained.
    import zlib

    # The mask guarantees an unsigned result, as the zlib docs recommend.
    return zlib.crc32(country.encode("utf-8")) & 0xFFFFFFFF

# Validate results: print the partition index each country maps to.
# NOTE(review): 5 partitions are assumed here, while the partitionBy() calls
# in the following cells use 4 and 3 — confirm which count is intended.
num_partitions = 5
for country_name in ("Poland", "Germany", "United Kingdom"):
    print(country_partitioner(country_name) % num_partitions)

In [None]:
# Key every transaction by its country and distribute the pairs over four
# partitions using the custom partition function.
with SparkContext("local[2]") as sc:
    keyed_by_country = sc.parallelize(transactions).map(
        lambda record: (record['country'], record)
    )
    rdd = keyed_by_country.partitionBy(4, country_partitioner)

    partition_count = rdd.getNumPartitions()
    print("Number of partitions: {}".format(partition_count))

    current_partitioner = rdd.partitioner
    print("Partitioner: {}".format(current_partitioner))

    partition_layout = rdd.glom().collect()
    print("Partitions structure: {}".format(partition_layout))

In [None]:
# Computes the sum of sales for one partition.
# Note that the argument is an iterator over the partition's records.
def sum_sales(iterator):
    """Yield a single value: the total 'amount' of all records in ``iterator``.

    Each item is expected to be a (key, transaction) pair whose second
    element is a mapping with an 'amount' entry. An empty iterator yields 0.
    """
    total = 0
    for pair in iterator:
        total += pair[1]['amount']
    yield total
    
# Partition the transactions by country into three partitions, show the
# layout, then reduce every partition to its total sales amount.
with SparkContext("local[2]") as sc:
    keyed_by_country = sc.parallelize(transactions).map(
        lambda record: (record['country'], record)
    )
    by_country = keyed_by_country.partitionBy(3, country_partitioner)

    partition_layout = by_country.glom().collect()
    print("Partitions structure: {}".format(partition_layout))

    # mapPartitions() calls sum_sales once per partition, giving one total each.
    per_partition_totals = by_country.mapPartitions(sum_sales)
    sum_amounts = per_partition_totals.collect()

    print("Total sales for each partition: {}".format(sum_amounts))

### Using partitions in DataFrames

In [None]:
from pyspark.sql import SparkSession, Row
# Build a DataFrame from the sample transactions and compare its partitioning
# before and after repartitioning by the "country" column.
# NOTE(review): spark.sql.shuffle.partitions=50 presumably determines the
# partition count after repartition("country") — confirm for the Spark
# version in use.
with SparkSession.builder \
        .master("local[2]") \
        .config("spark.sql.shuffle.partitions", 50) \
        .getOrCreate() as spark:

    # Turn every transaction dict into a Row so a schema can be inferred.
    rdd = spark.sparkContext \
        .parallelize(transactions) \
        .map(lambda x: Row(**x))

    df = spark.createDataFrame(rdd)

    print("Number of partitions: {}".format(df.rdd.getNumPartitions()))
    # Inspect the DataFrame's own RDD, like the two neighbouring lines do.
    # The original printed the partitioner of the source `rdd` instead.
    print("Partitioner: {}".format(df.rdd.partitioner))
    print("Partitions structure: {}".format(df.rdd.glom().collect()))

    # Repartition by column
    df2 = df.repartition("country")

    print("\nAfter 'repartition()'")
    print("Number of partitions: {}".format(df2.rdd.getNumPartitions()))
    print("Partitioner: {}".format(df2.rdd.partitioner))
    print("Partitions structure: {}".format(df2.rdd.glom().collect()))

### Repartitioning

In [None]:
# Create a one-column DataFrame from `nums` and show its partition layout
# before and after an explicit repartition(4).
with (
    SparkSession.builder
    .master("local[2]")
    .getOrCreate()
) as spark:

    # Wrap every number in a Row so it can become a DataFrame record.
    nums_rdd = spark.sparkContext.parallelize(nums).map(lambda value: Row(value))

    nums_df = spark.createDataFrame(nums_rdd, ['num'])

    count_before = nums_df.rdd.getNumPartitions()
    print("Number of partitions: {}".format(count_before))
    layout_before = nums_df.rdd.glom().collect()
    print("Partitions structure: {}".format(layout_before))

    nums_df = nums_df.repartition(4)

    count_after = nums_df.rdd.getNumPartitions()
    print("Number of partitions: {}".format(count_after))
    layout_after = nums_df.rdd.glom().collect()
    print("Partitions structure: {}".format(layout_after))

### Vanishing Partition

In [None]:
# Partition (key, value) pairs, then transform them with map() and compare
# the partitioner reported before and after the transformation.
with SparkContext("local[2]") as sc:
    keyed = sc.parallelize(nums).map(lambda value: (value, value))
    rdd = keyed.partitionBy(2).persist()

    partition_count = rdd.getNumPartitions()
    print("Number of partitions: {}".format(partition_count))
    current_partitioner = rdd.partitioner
    print("Partitioner: {}".format(current_partitioner))
    partition_layout = rdd.glom().collect()
    print("Partitions structure: {}".format(partition_layout))

    # Transform with `map()`: keys stay the same, values become key * 2
    rdd2 = rdd.map(lambda pair: (pair[0], pair[0] * 2))

    mapped_count = rdd2.getNumPartitions()
    print("Number of partitions: {}".format(mapped_count))
    mapped_partitioner = rdd2.partitioner
    print("Partitioner: {}".format(mapped_partitioner))  # The partitioner is gone after map()
    mapped_layout = rdd2.glom().collect()
    print("Partitions structure: {}".format(mapped_layout))

### Preserving Partition

In [None]:
# Same setup as the previous cell, but the values are transformed with
# mapValues(), which leaves the keys untouched.
with SparkContext("local[2]") as sc:
    keyed = sc.parallelize(nums).map(lambda value: (value, value))
    rdd = keyed.partitionBy(2).persist()

    partition_count = rdd.getNumPartitions()
    print("Number of partitions: {}".format(partition_count))
    current_partitioner = rdd.partitioner
    print("Partitioner: {}".format(current_partitioner))
    partition_layout = rdd.glom().collect()
    print("Partitions structure: {}".format(partition_layout))

    # Use `mapValues()` instead of `map()`
    rdd2 = rdd.mapValues(lambda value: value * 2)

    mapped_count = rdd2.getNumPartitions()
    print("Number of partitions: {}".format(mapped_count))
    mapped_partitioner = rdd2.partitioner
    print("Partitioner: {}".format(mapped_partitioner))  # The partitioner is still there
    mapped_layout = rdd2.glom().collect()
    print("Partitions structure: {}".format(mapped_layout))