In [1]:
# Number of synthetic baskets to generate (used as the upper bound of spark.range
# when the raw baskets are created). Lower it for a quick test dataset, raise it
# for a larger one.
number_of_baskets = 10000

In [3]:
from pyspark.sql.types import *
from pyspark.sql import SparkSession
# Build the notebook's SparkSession (or reuse one that already exists) with Hive
# support enabled; later cells persist intermediate results with saveAsTable.
spark = (
    SparkSession.builder
    .enableHiveSupport()
    .getOrCreate()
)
# Last expression of the cell: displays the session summary.
spark

At this point (after the Spark session has been created), the Spark UI should be available at http://127.0.0.1:4040/ (if port 4040 is already in use, Spark binds to the next free port: 4041, 4042, and so on).

In [5]:
import csv
from decimal import Decimal
################################################################################
#Helper shared by the four distribution files read below. Each file has four
# columns: a key, the tally of baskets for that key, that tally as a fraction
# of the whole, and the cumulative fraction.
################################################################################
def _load_distribution(path, key_field, tally_name, key_parser=str):
    """Read a four-column distribution CSV and convert it to a Spark DataFrame.

    Every non-blank row is parsed as data, so the file must not contain a
    header row (a header would fail the int/Decimal conversions).

    Args:
        path: Path of the CSV file to read.
        key_field: StructField describing the first column (name and type).
        tally_name: Column name to give the second (integer tally) column.
        key_parser: Callable applied to the first CSV field (str or int).

    Returns:
        A (rows, dataframe) tuple: the parsed rows as a list of
        [key, tally, fraction, cumulative_fraction] lists, and the DataFrame
        built from them.

    Raises:
        ValueError / decimal.InvalidOperation: if a field cannot be converted.
        IndexError: if a non-blank row has fewer than four fields.
    """
    # newline="" is what the csv module expects so that quoted fields
    # containing line breaks are parsed correctly.
    with open(path, newline="") as f:
        rows = [
            [key_parser(row[0]), int(row[1]), Decimal(row[2]), Decimal(row[3])]
            for row in csv.reader(f)
            if row  # csv.reader yields [] for blank lines; skip them
        ]
    distribution_schema = StructType([
        key_field,
        StructField(tally_name, IntegerType(), True),
        StructField("fraction_of_baskets", DecimalType(20, 19), True),
        StructField("cumulative_fraction_of_baskets", DecimalType(20, 19), True)
    ])
    return rows, spark.createDataFrame(rows, distribution_schema)

################################################################################
#Read file distribution_of_number_of_products_in_a_basket.csv that
# provides for each number N the tally of baskets that contain N products, that
# tally as a fraction of the whole, and the cumulative fraction
################################################################################
distribution_of_number_of_products_in_a_basket, products_in_basket_df = _load_distribution(
    "distribution_of_number_of_products_in_a_basket.csv",
    StructField("basket_size", IntegerType(), True),
    "products_tally",
    key_parser=int,
)
#Uncomment the following line to view products_in_basket_df
#products_in_basket_df.toPandas()

################################################################################
#Read file baskets_per_customer.csv that
# provides the tally of baskets purchased by that customer, that
# tally as a fraction of the whole, and the cumulative fraction
################################################################################
baskets_per_customer, baskets_per_customer_df = _load_distribution(
    "baskets_per_customer.csv",
    StructField("customer", StringType(), True),
    "baskets_tally",
)
#baskets_per_customer_df.limit(10).toPandas()

################################################################################
#Read file baskets_per_product.csv that
# provides the tally of baskets containing that product, that
# tally as a fraction of the whole, and the cumulative fraction
################################################################################
baskets_per_product, baskets_per_product_df = _load_distribution(
    "baskets_per_product.csv",
    StructField("product", StringType(), True),
    "baskets_tally",
)
#baskets_per_product_df.limit(10).toPandas()

################################################################################
#Read file baskets_per_store.csv that
# provides the tally of baskets bought at each store, that
# tally as a fraction of the whole, and the cumulative fraction
################################################################################
baskets_per_store, baskets_per_store_df = _load_distribution(
    "baskets_per_store.csv",
    StructField("store", StringType(), True),
    "baskets_tally",
)
#baskets_per_store_df.limit(10).toPandas()

In [6]:
from pyspark.sql.functions import rand
# One row per basket: `id` in [0, number_of_baskets) plus a uniform random value
# in [0, 1) that later cells compare against cumulative-fraction boundaries.
basket_ids = spark.range(0, number_of_baskets)
baskets_df = basket_ids.withColumn('rand', rand())
# Persist the draw as a table; later cells read it back via spark.table('raw_baskets').
# NOTE(review): presumably this is what keeps `rand` stable across cells, since an
# unseeded rand() column would otherwise be re-evaluated per action - confirm.
(
    baskets_df.write
    .format('parquet')
    .mode('overwrite')
    .saveAsTable('raw_baskets')
)

In [28]:
from pyspark.sql import Window
from pyspark.sql.functions import lag, col, lit, floor, min, max
# The lower bound of each row is the cumulative fraction of the *previous* row, so
# the window must be ordered by the cumulative fraction itself. Ordering by
# baskets_tally is not safe: rows with equal tallies have no defined relative
# order, so lag() could pick up the wrong neighbour and yield overlapping or
# gapped [lower_bound, upper_bound) intervals.
# No partitionBy: Spark evaluates the window on a single partition, which is
# acceptable only because these are small lookup tables.
window = Window.orderBy('cumulative_fraction_of_baskets')


def _with_fraction_bounds(distribution_df, key_column, ordering=window):
    """Turn a cumulative distribution into half-open [lower_bound, upper_bound) ranges.

    Args:
        distribution_df: DataFrame with `key_column` and
            `cumulative_fraction_of_baskets` columns.
        key_column: Name of the column identifying each row (e.g. 'store').
        ordering: Window spec defining which row is the "previous" one.

    Returns:
        DataFrame with columns (key_column, lower_bound, upper_bound), where
        lower_bound is the previous row's cumulative fraction (0 for the first
        row) and upper_bound is the row's own cumulative fraction.
    """
    return distribution_df.select(
        key_column,
        lag('cumulative_fraction_of_baskets', 1, 0).over(ordering)
            .alias('lower_bound'),
        col('cumulative_fraction_of_baskets').alias('upper_bound')
    )


baskets_per_store_fraction_boundary_df = _with_fraction_bounds(
    baskets_per_store_df, 'store')
baskets_per_customer_fraction_boundary_df = _with_fraction_bounds(
    baskets_per_customer_df, 'customer')

In [51]:
# Assign a store to every basket: a basket belongs to the store whose
# [lower_bound, upper_bound) range contains the basket's random value.
baskets_df = spark.table('raw_baskets')
store_bounds = baskets_per_store_fraction_boundary_df
rand_in_store_range = (
    (baskets_df['rand'] >= store_bounds['lower_bound'])
    & (baskets_df['rand'] < store_bounds['upper_bound'])
)
# Range (non-equi) join; the store lookup is hinted for broadcast.
baskets_df = baskets_df.join(store_bounds.hint('broadcast'), rand_in_store_range)
# All joined columns (including lower_bound / upper_bound) are written out.
(
    baskets_df.write
    .format('parquet')
    .mode('overwrite')
    .saveAsTable('baskets_with_stores')
)

In [None]:
from pyspark.sql.functions import explode, sequence

# Joining solely using
#  (baskets_df.rand >= baskets_per_store_fraction_boundary_df.lower_bound) &
#  (baskets_df.rand < baskets_per_store_fraction_boundary_df.upper_bound)
# (i.e. a non-equi-join) would cause the optimiser to choose a nested loop join
# which is wildly inefficient.
# By 'bucketizing' the rows on both sides of the join an equi-join can be used and
#  the optimiser can choose a more effective join algorithm, likely a hash join
store_join_buckets = 5000 #Fairly arbitrary number. Test different values to find the optimum.
bucket_size = 1.0 / store_join_buckets

# Bucket of each basket: index of the bucket_size-wide slice that contains rand.
baskets_df = spark.table('raw_baskets').withColumn(
    'store_join_bucket', floor(col('rand') / lit(bucket_size)))

# A store's [lower_bound, upper_bound) range may cover many buckets, not just the
# two that contain its end points. Emit one row per bucket from the bucket of
# lower_bound to the bucket of upper_bound inclusive; producing rows only for the
# two end buckets would leave baskets that fall in an interior bucket without any
# matching store row, silently dropping them from the join.
# Selecting the three base columns first (and de-duplicating) drops any
# join_bucket column left by a previous run, so re-running this cell is safe.
baskets_per_store_fraction_boundary_df = (
    baskets_per_store_fraction_boundary_df
    .select('store', 'lower_bound', 'upper_bound')
    .distinct()
    .withColumn(
        'join_bucket',
        explode(sequence(
            floor(col('lower_bound') / lit(bucket_size)),
            floor(col('upper_bound') / lit(bucket_size))
        ))
    )
)

# Equi-join on the bucket, then keep only the row whose range really contains rand.
baskets_df = baskets_df.join(baskets_per_store_fraction_boundary_df.hint('broadcast'),
                (
                    (baskets_df.store_join_bucket == baskets_per_store_fraction_boundary_df.join_bucket) &
                    (baskets_df.rand >= baskets_per_store_fraction_boundary_df.lower_bound) &
                    (baskets_df.rand < baskets_per_store_fraction_boundary_df.upper_bound)
                )
).select('id', 'rand', 'store')

In [None]:
# Action that forces the join defined on baskets_df to execute and displays the
# resulting row count.
# NOTE(review): presumably this should equal number_of_baskets if every basket
# matched exactly one store - confirm.
baskets_df.count()

In [30]:
# Inspect, for each join bucket, the smallest lower bound and largest upper bound
# of the store ranges assigned to it (min/max here are the Spark SQL aggregates
# imported from pyspark.sql.functions, not the Python builtins).
bucket_bounds = (
    baskets_per_store_fraction_boundary_df
    .groupBy('join_bucket')
    .agg(min(col('lower_bound')), max(col('upper_bound')))
    .orderBy('join_bucket')
)
bucket_bounds.toPandas()

Unnamed: 0,join_bucket,min(lower_bound),max(upper_bound)
0,0,0.0,0.4984134837423228
1,1,0.4984134837423228,0.9999999999999996


In [None]:
# Assign a customer to every basket that already has a store: a basket belongs to
# the customer whose [lower_bound, upper_bound) range contains the basket's rand.
# NOTE(review): the same `rand` value that selected the store is reused here, so
# the store and customer choices are not independent - confirm this is intended.
# NOTE(review): unlike the store join, this range join carries no broadcast hint.
baskets_df = spark.table('baskets_with_stores')
customer_bounds = baskets_per_customer_fraction_boundary_df
rand_in_customer_range = (
    (baskets_df['rand'] >= customer_bounds['lower_bound'])
    & (baskets_df['rand'] < customer_bounds['upper_bound'])
)
baskets_df = (
    baskets_df
    .join(customer_bounds, rand_in_customer_range)
    .select(col('id').alias('basket_id'), 'store', 'customer')
)
# Action: runs the join and displays the number of resulting rows.
baskets_df.count()

In [None]:
################################################################################
#Calculate lower & upper cumulative fraction boundary for each basket size
# products_in_basket_df exposes the columns basket_size, products_tally,
# fraction_of_baskets and cumulative_fraction_of_baskets, so those are the names
# that must be referenced here.
# The window is ordered by the cumulative fraction so that lag() always returns
# the boundary immediately below the current one.
################################################################################
window = Window.orderBy('cumulative_fraction_of_baskets')
basket_size_fraction_boundary_df = products_in_basket_df.select(
    'basket_size',
    lag('cumulative_fraction_of_baskets', 1, 0).over(window)
        .alias('lower_bound'),
    col('cumulative_fraction_of_baskets').alias('upper_bound')
)
################################################################################
#Choose a basket size for each basket based on basket size distribution
# The baskets are re-read from raw_baskets because that table carries the `id`
# and `rand` columns this join needs; the baskets_df left by the customer join
# only has basket_id, store and customer.
################################################################################
baskets_df = spark.table('raw_baskets')
baskets_df = baskets_df.join(
    basket_size_fraction_boundary_df.hint('broadcast'),
    (
        (baskets_df.rand >= basket_size_fraction_boundary_df.lower_bound) &
        (baskets_df.rand < basket_size_fraction_boundary_df.upper_bound)
    )
).select(col('id').alias('basket_id'), col('basket_size'))
#Uncomment the next line to see how many baskets there are for each basket size. It should be
# roughly proportionally equivalent to the original distribution
#%time baskets_df.groupBy('basket_size').count().orderBy('basket_size').toPandas()