![](images/09_05.jpg)

# 1. Đọc dữ liệu

In [1]:
from pyspark.sql import SparkSession

In [2]:
# Start (or reuse) the Spark session that every later cell goes through.
spark = (
    SparkSession.builder
    .appName("instacart")
    .getOrCreate()
)

In [3]:
# Load the order/product line items. The first row supplies the column names
# and the column types are inferred from the file contents.
data = spark.read.csv(
    "../../local_data/instacart_2017_05_01/order_products__train.csv",
    header=True,
    inferSchema=True,
)

In [4]:
# Number of line items (one row per product added to an order).
data.count()

1384617

In [5]:
# Peek at the first five rows to check the columns and inferred values.
data.show(5)

+--------+----------+-----------------+---------+
|order_id|product_id|add_to_cart_order|reordered|
+--------+----------+-----------------+---------+
|       1|     49302|                1|        1|
|       1|     11109|                2|        1|
|       1|     10246|                3|        0|
|       1|     49683|                4|        0|
|       1|     43633|                5|        1|
+--------+----------+-----------------+---------+
only showing top 5 rows



# 2. Chuẩn bị dữ liệu

In [6]:
from pyspark.sql.functions import collect_list, col, count, collect_set

In [7]:
# Register the line items as a session-scoped view so the SQL cells can query them by name.
data.createOrReplaceTempView('order_products_train')

In [8]:
# Distinct product ids that appear in at least one order line.
products = (
    spark.table('order_products_train')
    .select('product_id')
    .distinct()
)

In [9]:
# How many different products occur in the training orders.
products.count()

39123

In [10]:
# Every column of the registered view, i.e. the same rows that `data` holds.
raw_data = spark.table('order_products_train')

In [11]:
# One row per order: `items` is the de-duplicated set of product ids in that
# basket, which is the transaction format FPGrowth is fitted on below.
baskets = (
    raw_data
    .groupBy('order_id')
    .agg(collect_set('product_id').alias('items'))
)

In [12]:
# Register the id-based baskets as the session-scoped view `baskets`.
baskets.createOrReplaceTempView('baskets')

In [13]:
# Show five baskets without truncating the item arrays.
baskets.show(n=5, truncate=False)

+--------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|order_id|items                                                                                                                                                                                                               |
+--------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|1342    |[30827, 3798, 14966, 21137, 46129, 33081, 13176, 7862]                                                                                                                                                              |
|1591    |[48246, 44116, 24852, 5194, 9130, 48823, 46473, 40310, 32520, 22105, 16900, 27681, 4103, 44008

# 3. Build model

In [14]:
from pyspark.ml.fpm import FPGrowth

In [15]:
# FP-Growth estimator over the `items` column. Both thresholds are fractions:
# minSupport filters frequent itemsets, minConfidence filters generated rules.
fp_growth = FPGrowth(
    itemsCol='items',
    minSupport=0.003,
    minConfidence=0.003,
)

In [16]:
# Mine frequent itemsets from the id-based baskets.
model = fp_growth.fit(baskets)

# 4. Hiển thị các mẫu phổ biến

In [17]:
# Frequent itemsets with their occurrence counts (`freq`); shows the first 20 rows.
model.freqItemsets.show()

+--------------------+-----+
|               items| freq|
+--------------------+-----+
|             [13629]|  772|
|              [5194]|  475|
|             [24852]|18726|
|             [13176]|15480|
|             [35921]|  769|
|             [20345]|  473|
|             [21137]|10894|
|      [21137, 13176]| 3074|
|      [21137, 24852]| 2174|
|             [23165]|  764|
|             [13380]|  473|
|              [7969]|  472|
|             [21903]| 9784|
|      [21903, 21137]| 1639|
|[21903, 21137, 13...|  587|
|      [21903, 13176]| 2236|
|      [21903, 24852]| 2000|
|             [32478]|  763|
|             [47626]| 8135|
|      [47626, 21137]| 1017|
+--------------------+-----+
only showing top 20 rows



# 5. Các luật kết hợp

In [18]:
# Apply the fitted model to every basket; the result keeps `order_id` and `items`
# and gains a `prediction` array column.
# NOTE(review): the variable name suggests "most popular item", but `prediction`
# holds rule-based suggestions per basket — confirm the intended meaning.
most_popular_item_in_a_basket = model.transform(baskets)

In [19]:
# Preview baskets next to their predicted items (arrays are truncated in the output).
most_popular_item_in_a_basket.show()

+--------+--------------------+--------------------+
|order_id|               items|          prediction|
+--------+--------------------+--------------------+
|    1342|[30827, 3798, 149...|[21903, 47626, 47...|
|    1591|[48246, 44116, 24...|[21137, 21903, 47...|
|    4519|             [29270]|                  []|
|    4935|             [45190]|                  []|
|    6357|[33731, 14669, 43...|[21137, 21903, 47...|
|   10362|[28522, 43789, 12...|[21137, 47626, 47...|
|   19204|[45255, 37285, 48...|                  []|
|   29601|[2716, 48057, 219...|[21137, 21903, 47...|
|   31035|[40723, 8174, 131...|[21137, 21903, 47...|
|   40011|[27292, 35213, 21...|[21137, 13176, 24...|
|   46266|[38558, 48642, 13...|[47626, 47766, 47...|
|   51607|[41390, 42752, 17...|                  []|
|   58797|[30827, 8803, 326...|[21137, 21903, 47...|
|   61793|[26348, 6184, 433...|[21137, 16797, 39...|
|   67089|[47766, 29388, 21...|[47626, 21137, 47...|
|   70863|[34791, 2618, 173...|      [13176, 2

<hr>

# Sử dụng `product_name` thay vì `product_id` như bên trên

# 1. Đọc dữ liệu

In [20]:
# Load the product catalogue (id, name, aisle, department) with a header row
# and inferred column types.
product_data = spark.read.csv(
    "../../local_data/instacart_2017_05_01/products.csv",
    header=True,
    inferSchema=True,
)

In [21]:
# Peek at the first three catalogue rows.
product_data.show(3)

+----------+--------------------+--------+-------------+
|product_id|        product_name|aisle_id|department_id|
+----------+--------------------+--------+-------------+
|         1|Chocolate Sandwic...|      61|           19|
|         2|    All-Seasons Salt|     104|           13|
|         3|Robust Golden Uns...|      94|            7|
+----------+--------------------+--------+-------------+
only showing top 3 rows



In [22]:
# Register the catalogue as the session-scoped view `products` for the join below.
product_data.createOrReplaceTempView('products')

In [23]:
# Attach the product name to every order line. The join predicate lives in the
# ON clause so the join is explicit and does not read as a Cartesian product
# followed by a filter.
raw_data_1 = spark.sql("""
    SELECT p.product_name, o.order_id
    FROM products p
    INNER JOIN order_products_train o
        ON o.product_id = p.product_id
""")

In [24]:
# One row per order: `items` is the de-duplicated set of product names in the basket.
baskets_1 = (
    raw_data_1
    .groupBy('order_id')
    .agg(collect_set('product_name').alias('items'))
)
# Preview the first five name-based baskets.
baskets_1.show(n=5)

+--------+--------------------+
|order_id|               items|
+--------+--------------------+
|    1342|[Raw Shrimp, Seed...|
|    1591|[Cracked Wheat, S...|
|    4519|[Beet Apple Carro...|
|    4935|             [Vodka]|
|    6357|[Globe Eggplant, ...|
+--------+--------------------+
only showing top 5 rows



In [25]:
# Register the name-based baskets as the session-scoped view `baskets`.
# NOTE(review): this reuses the view name registered earlier for the id-based
# baskets, so any SQL on `baskets` run after this cell sees product names instead
# of product ids. Consider a distinct view name if both are needed.
baskets_1.createOrReplaceTempView('baskets')

# 2. Build model

In [26]:
# Same FP-Growth configuration as the id-based run, applied to product names.
fp_growth_1 = FPGrowth(
    itemsCol='items',
    minSupport=0.003,
    minConfidence=0.003,
)

In [27]:
# Mine frequent itemsets from the name-based baskets.
model_1 = fp_growth_1.fit(baskets_1)

In [28]:
# First five frequent itemsets, now expressed as product names.
model_1.freqItemsets.show(5)

+--------------------+-----+
|               items| freq|
+--------------------+-----+
|[Organic Tomato B...|  772|
|[Organic Spinach ...|  475|
|            [Banana]|18726|
|[Bag of Organic B...|15480|
|[Organic Large Gr...|  769|
+--------------------+-----+
only showing top 5 rows



In [29]:
# Apply the name-based model to every basket; adds a `prediction` array column
# next to `order_id` and `items`.
most_popular_item_in_a_basket_1 = model_1.transform(baskets_1)

In [30]:
# Pull three rows to the driver to read the full, untruncated arrays.
most_popular_item_in_a_basket_1.head(3)

[Row(order_id=1342, items=['Raw Shrimp', 'Seedless Cucumbers', 'Versatile Stain Remover', 'Organic Strawberries', 'Organic Mandarins', 'Chicken Apple Sausage', 'Pink Lady Apples', 'Bag of Organic Bananas'], prediction=['Organic Baby Spinach', 'Large Lemon', 'Organic Avocado', 'Organic Hass Avocado', 'Strawberries', 'Limes', 'Organic Raspberries', 'Organic Blueberries', 'Organic Whole Milk', 'Organic Cucumber', 'Organic Zucchini', 'Organic Yellow Onion', 'Organic Garlic', 'Seedless Red Grapes', 'Asparagus', 'Organic Grape Tomatoes', 'Organic Red Onion', 'Organic Baby Carrots', 'Honeycrisp Apple', 'Organic Cilantro', 'Organic Lemon', 'Sparkling Water Grapefruit', 'Raspberries', 'Organic Fuji Apple', 'Small Hass Avocado', 'Organic Baby Arugula', 'Organic Large Extra Fancy Fuji Apple', 'Original Hummus', 'Organic Blackberries', 'Organic Gala Apples', 'Fresh Cauliflower', 'Organic Half & Half', 'Michigan Organic Kale', 'Organic Small Bunch Celery', 'Organic Garnet Sweet Potato (Yam)', 'Orga

In [31]:
# Confirm the column types: `items` and `prediction` are arrays of strings.
most_popular_item_in_a_basket_1.printSchema()

root
 |-- order_id: integer (nullable = true)
 |-- items: array (nullable = false)
 |    |-- element: string (containsNull = false)
 |-- prediction: array (nullable = true)
 |    |-- element: string (containsNull = false)



In [32]:
# Register the predictions as the global temporary view `popular_items`.
# NOTE(review): every other view in this notebook is session-scoped
# (createOrReplaceTempView). A global temp view must be queried as
# `global_temp.popular_items` — confirm that the global scope is intentional.
most_popular_item_in_a_basket_1.createOrReplaceGlobalTempView('popular_items')

In [33]:
from pyspark.sql.types import StringType

In [34]:
# Flatten the `items` array into its string form so the frame can be written as CSV.
# NOTE(review): only `order_id` and `items` are kept; the `prediction` column is
# not carried over — confirm that this is intended.
items_as_text = col('items').cast('string')
df_cast = most_popular_item_in_a_basket_1.select('order_id', items_as_text)

In [35]:
# Verify that `items` is now a plain string column.
df_cast.printSchema()

root
 |-- order_id: integer (nullable = true)
 |-- items: string (nullable = false)



In [36]:
# Inspect three rows of the flattened frame before exporting it.
df_cast.head(3)

[Row(order_id=1342, items='[Raw Shrimp, Seedless Cucumbers, Versatile Stain Remover, Organic Strawberries, Organic Mandarins, Chicken Apple Sausage, Pink Lady Apples, Bag of Organic Bananas]'),
 Row(order_id=1591, items='[Cracked Wheat, Strawberry Rhubarb Yoghurt, Organic Bunny Fruit Snacks Berry Patch, Goodness Grapeness Organic Juice Drink, Honey Graham Snacks, Spinach, Granny Smith Apples, Oven Roasted Turkey Breast, Pure Vanilla Extract, Chewy 25% Low Sugar Chocolate Chip Granola, Banana, Original Turkey Burgers Smoke Flavor Added, Twisted Tropical Tango Organic Juice Drink, Navel Oranges, Lower Sugar Instant Oatmeal  Variety, Ultra Thin Sliced Provolone Cheese, Natural Vanilla Ice Cream, Cinnamon Multigrain Cereal, Garlic, Goldfish Pretzel Baked Snack Crackers, Original Whole Grain Chips, Medium Scarlet Raspberries, Lemon Yogurt, Original Patties (100965) 12 Oz Breakfast, Nutty Bars, Strawberry Banana Smoothie, Green Machine Juice Smoothie, Coconut Dreams Cookies, Buttermilk Waffl

In [37]:
# Export the flattened baskets as CSV. Spark writes a directory of part files at
# this path rather than a single file. mode='overwrite' keeps the cell re-runnable:
# with the default save mode the write raises once the path already exists.
df_cast.write.csv('./data/mostPopularItemInABasket.csv', mode='overwrite')