In [0]:
# Load the source table as a Spark DataFrame.
SOURCE_TABLE = "default.instacart"
instacart_df = spark.table(SOURCE_TABLE)

In [0]:
from pyspark.sql.functions import col

# Keep only the rows where both identifiers are populated.
# NOTE(review): the combined null-check count run in this notebook returns 0
# for exactly this condition, so on this table the resulting frame is empty.
has_user = col("user_id").isNotNull()
has_product = col("product_id").isNotNull()
filtered_df = instacart_df.where(has_user & has_product)

In [0]:
from pyspark.sql.functions import count

# One output row per (user_id, product_id) pair; `interaction` is the number
# of input rows that share that pair.
pair_keys = ["user_id", "product_id"]
rows_per_pair = count("*").alias("interaction")
interaction_df = filtered_df.groupBy(*pair_keys).agg(rows_per_pair)


In [0]:
# Print the column names, types and nullability of the aggregated frame.
interaction_df.printSchema()

root
 |-- user_id: long (nullable = true)
 |-- product_id: long (nullable = true)
 |-- interaction: long (nullable = false)



In [0]:
# Preview at most 10 aggregated rows as a sanity check.
interaction_df.show(10)

+-------+----------+-----------+
|user_id|product_id|interaction|
+-------+----------+-----------+
+-------+----------+-----------+



In [0]:
# Total number of rows in the source table.
instacart_df.count()

35905260

In [0]:
# Number of rows that carry a user_id (`where` is an alias of `filter`).
instacart_df.where("user_id IS NOT NULL").count()

3421083

In [0]:
# Number of rows that carry a product_id (`where` is an alias of `filter`).
instacart_df.where("product_id IS NOT NULL").count()

32484177

In [0]:
# Number of rows that carry a user_id and a product_id at the same time.
both_ids_present = "user_id IS NOT NULL AND product_id IS NOT NULL"
instacart_df.where(both_ids_present).count()

0

In [0]:
# Preview (all columns) of rows that carry a user_id and a product_id at the
# same time.
both_ids_present = "user_id IS NOT NULL AND product_id IS NOT NULL"
instacart_df.where(both_ids_present).show(10)

+----------+------------+--------+-------------+--------+-------+--------+------------+---------+-----------------+----------------------+-----------------+---------+
|product_id|product_name|aisle_id|department_id|order_id|user_id|eval_set|order_number|order_dow|order_hour_of_day|days_since_prior_order|add_to_cart_order|reordered|
+----------+------------+--------+-------------+--------+-------+--------+------------+---------+-----------------+----------------------+-----------------+---------+
+----------+------------+--------+-------------+--------+-------+--------+------------+---------+-----------------+----------------------+-----------------+---------+



In [0]:
from pyspark.sql.functions import col

# Order -> user lookup: take the rows that have a user_id and reduce them to
# the unique (order_id, user_id) pairs.
rows_with_user = instacart_df.where(col("user_id").isNotNull())
orders_df = rows_with_user.select("order_id", "user_id").dropDuplicates()

In [0]:
# Order -> product lines: take the rows that have a product_id and keep only
# the order_id and product_id columns. Duplicate rows are kept.
rows_with_product = instacart_df.where(col("product_id").isNotNull())
order_products_df = rows_with_product.select("order_id", "product_id")


In [0]:
# Attach a user_id to every order line by matching on order_id. The inner
# join keeps only order_ids that appear on both sides.
user_product_df = orders_df.join(order_products_df, "order_id", "inner")

In [0]:
from pyspark.sql.functions import count

# One output row per (user_id, product_id) pair; `interaction` is the number
# of joined rows that share that pair.
pair_keys = ["user_id", "product_id"]
rows_per_pair = count("*").alias("interaction")
interaction_df = user_product_df.groupBy(*pair_keys).agg(rows_per_pair)

In [0]:
# Preview 10 (user_id, product_id) pairs, then report how many pairs exist
# (the frame holds one row per pair).
interaction_df.show(10)
interaction_df.count()

+-------+----------+-----------+
|user_id|product_id|interaction|
+-------+----------+-----------+
| 115481|     47049|         24|
|  88362|     36316|          5|
| 133810|     12409|          5|
| 166655|     12144|         11|
|  26183|     33731|          2|
| 202833|     33198|          7|
|  91030|     11790|          3|
| 196148|     35948|          1|
| 190318|     25020|          1|
|  28585|     21938|          4|
+-------+----------+-----------+
only showing top 10 rows


13307953

In [0]:
from pyspark.ml.feature import StringIndexer

In [0]:
# Map every user_id to a numeric index stored in `user_index`.
# handleInvalid="skip" makes transform() drop rows whose value is null or was
# not seen during fit(); because the same DataFrame is used for fit() and
# transform() below, no value seen during fit() is skipped.
user_indexer = StringIndexer(
    inputCol="user_id",
    outputCol="user_index",
    handleInvalid="skip"
)

# Keep the fitted model instead of discarding it: its `labels` list maps a
# user_index back to the original user_id, which is needed to interpret any
# result expressed in index space once the id column has been dropped.
user_indexer_model = user_indexer.fit(interaction_df)
user_indexed_df = user_indexer_model.transform(interaction_df)


In [0]:
# Map every product_id to a numeric index stored in `product_index`.
# handleInvalid="skip" makes transform() drop rows whose value is null or was
# not seen during fit(); because the same DataFrame is used for fit() and
# transform() below, no value seen during fit() is skipped.
product_indexer = StringIndexer(
    inputCol="product_id",
    outputCol="product_index",
    handleInvalid="skip"
)

# Keep the fitted model instead of discarding it: its `labels` list maps a
# product_index back to the original product_id, which is needed to interpret
# any result expressed in index space once the id column has been dropped.
product_indexer_model = product_indexer.fit(user_indexed_df)
indexed_df = product_indexer_model.transform(user_indexed_df)


In [0]:
# Keep only the two index columns and the interaction count.
als_columns = ["user_index", "product_index", "interaction"]
final_als_df = indexed_df.select(*als_columns)

In [0]:
from pyspark.sql.functions import col

# StringIndexer writes its indices as doubles; convert both index columns to
# integers, replacing each column in place so the column order is unchanged.
for index_column in ("user_index", "product_index"):
    final_als_df = final_als_df.withColumn(
        index_column, col(index_column).cast("int")
    )

In [0]:
# Confirm the index columns are now integers and preview 10 rows.
final_als_df.printSchema()
final_als_df.show(10)

root
 |-- user_index: integer (nullable = true)
 |-- product_index: integer (nullable = true)
 |-- interaction: long (nullable = false)

+----------+-------------+-----------+
|user_index|product_index|interaction|
+----------+-------------+-----------+
|      5041|         6446|         24|
|     55927|          750|          5|
|     17857|          274|          5|
|      1960|          918|         11|
|     46524|           51|          2|
|       532|          101|          7|
|      6178|         2511|          3|
|      4911|         1013|          1|
|     73876|         4402|          1|
|    118870|           37|          4|
+----------+-------------+-----------+
only showing top 10 rows


In [0]:
# Row count of the final frame, for comparison with the earlier
# interaction_df count.
final_als_df.count()

13307953

In [0]:
# Persist the final frame as a managed table, replacing any existing table of
# the same name.
final_als_df.write.saveAsTable("default.als_interactions", mode="overwrite")

In [0]:
# List the tables in the `default` database to confirm the write succeeded.
spark.sql("SHOW TABLES IN default").show()

+--------+----------------+-----------+
|database|       tableName|isTemporary|
+--------+----------------+-----------+
| default|als_interactions|      false|
| default|       instacart|      false|
+--------+----------------+-----------+

