In [1]:
from pathlib import Path

from pyspark import StorageLevel
from pyspark.ml.fpm import FPGrowth
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, round as spark_round, size, sort_array, udf
from pyspark.sql.types import ArrayType, StringType

from tag_recommender.utils.text import to_snake_case_boosted

In [2]:
# Keep the Spark settings together, then apply them to the builder in order.
_spark_settings = {
    "spark.executor.memory": "8g",
    "spark.driver.memory": "8g",
    "spark.executor.memoryOverhead": "2g",
    "spark.sql.shuffle.partitions": "500",
    "spark.driver.maxResultSize": "4g",
}
_session_builder = SparkSession.builder.appName("FrequentPatternsSpark")
for _setting_name, _setting_value in _spark_settings.items():
    _session_builder = _session_builder.config(_setting_name, _setting_value)
spark = _session_builder.getOrCreate()

24/10/15 00:43:58 WARN Utils: Your hostname, Georges-MacBook-Pro-2.local resolves to a loopback address: 127.0.0.1; using 192.168.66.77 instead (on interface en0)
24/10/15 00:43:58 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/10/15 00:43:58 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


In [3]:
# Load the training split from Parquet (adjust the file path as necessary)
file_path = "../data/processed/train.parquet"
# The train dataset already contains the root_tags as an array form.
# The tags column is also in array form.
# `header` / `inferSchema` are CSV reader options; Parquet files carry their own
# schema, so they are not passed to the Parquet reader.
df = spark.read.parquet(file_path)

                                                                                

In [4]:
# Print the column names, types and nullability of the loaded DataFrame
df.printSchema()

root
 |-- type: string (nullable = true)
 |-- lang: string (nullable = true)
 |-- is_reblog: double (nullable = true)
 |-- tags: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- root_tags: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- root_tags_count: long (nullable = true)
 |-- tags_count: long (nullable = true)
 |-- lang_type: string (nullable = true)
 |-- type_bucket: string (nullable = true)
 |-- root_tags_popularity: long (nullable = true)
 |-- tags_popularity: long (nullable = true)
 |-- total_popularity: long (nullable = true)
 |-- tags_count_bucket: string (nullable = true)
 |-- root_tags_count_bucket: string (nullable = true)



In [5]:
# Show the first 10 rows of the DataFrame (long cell values are truncated)
df.show(truncate=True, n=10)

+--------+-----+---------+--------------------+--------------------+---------------+----------+---------+-----------+--------------------+---------------+----------------+-----------------+----------------------+
|    type| lang|is_reblog|                tags|           root_tags|root_tags_count|tags_count|lang_type|type_bucket|root_tags_popularity|tags_popularity|total_popularity|tags_count_bucket|root_tags_count_bucket|
+--------+-----+---------+--------------------+--------------------+---------------+----------+---------+-----------+--------------------+---------------+----------------+-----------------+----------------------+
|   photo|fr_FR|      0.0|[one piece, One p...|                  []|              0|         7|    other|      photo|                   0|           4489|            4489|              5-9|                   0-0|
|photoset|en_US|      0.0|[the outsiders, v...|                  []|              0|         4|       en|      other|                   0|          

In [6]:
# Register UDF that normalizes every tag of an array column
@udf(ArrayType(StringType()))
def preprocess_tags(tags):
    """Normalize each tag of one row with ``to_snake_case_boosted``.

    Args:
        tags: List of raw tag strings for the row, or ``None`` when the
            array column is null for that row.

    Returns:
        A list holding the normalized form of every tag, in the original
        order; an empty list when ``tags`` is ``None``.
    """
    if tags is None:
        # The source array columns are nullable; iterating over None would
        # raise a TypeError inside the Python worker and fail the whole job.
        return []
    return [to_snake_case_boosted(tag) for tag in tags]


# Remove empty strings from the arrays
@udf(ArrayType(StringType()))
def remove_empty(arr):
    """Drop falsy entries (empty strings, nulls) from an array of tags.

    Args:
        arr: List of tag strings for one row, or ``None`` when the array
            column is null for that row.

    Returns:
        A list with the truthy entries of ``arr`` in their original order;
        an empty list when ``arr`` is ``None``.
    """
    if arr is None:
        # A null array has nothing to keep; returning [] avoids a TypeError.
        return []
    return [item for item in arr if item]


# UDF that deduplicates the tags of each array
@udf(ArrayType(StringType()))
def remove_duplicates(arr):
    """Return the distinct entries of ``arr`` as an ascending sorted list."""
    distinct_tags = list(set(arr))
    distinct_tags.sort()
    return distinct_tags

In [7]:
# Run the preprocess_tags UDF over 'root_tags' and 'tags', storing the results
# in the new 'root_tags_array' and 'tags_array' columns
df = df.withColumn(
    "root_tags_array", preprocess_tags(col("root_tags"))
).withColumn("tags_array", preprocess_tags(col("tags")))

In [8]:
# Drop empty-string entries so that arrays like [''] become []
df = df.withColumn(
    "root_tags_array", remove_empty(col("root_tags_array"))
).withColumn("tags_array", remove_empty(col("tags_array")))

In [9]:
# Preview 5 rows, now including the root_tags_array and tags_array columns
df.show(truncate=True, n=5)

+--------+-----+---------+--------------------+--------------------+---------------+----------+---------+-----------+--------------------+---------------+----------------+-----------------+----------------------+--------------------+--------------------+
|    type| lang|is_reblog|                tags|           root_tags|root_tags_count|tags_count|lang_type|type_bucket|root_tags_popularity|tags_popularity|total_popularity|tags_count_bucket|root_tags_count_bucket|     root_tags_array|          tags_array|
+--------+-----+---------+--------------------+--------------------+---------------+----------+---------+-----------+--------------------+---------------+----------------+-----------------+----------------------+--------------------+--------------------+
|   photo|fr_FR|      0.0|[one piece, One p...|                  []|              0|         7|    other|      photo|                   0|           4489|            4489|              5-9|                   0-0|                  []|[o

In [10]:
# Stack both normalized columns into a single `tag_arrays` column.

# `union` appends the rows of the second DataFrame to those of the first;
# duplicate rows are kept (no automatic deduplication).
df_unified = df.select(col("root_tags_array").alias("tag_arrays")).union(
    df.select(col("tags_array").alias("tag_arrays"))
)

In [11]:
# Show the first 10 rows of the unified dataset (long arrays are truncated)
df_unified.show(truncate=True, n=10)

+--------------------+
|          tag_arrays|
+--------------------+
|                  []|
|                  []|
|                  []|
|[tw_food, the_tin...|
|[sad, sadness, sa...|
|                  []|
|[black_panther, b...|
|[digital_art, otp...|
|     [maps, sky, us]|
|[filmedit, moviee...|
+--------------------+
only showing top 10 rows



In [12]:
# Confirm that the unified DataFrame has a single array<string> column
df_unified.printSchema()

root
 |-- tag_arrays: array (nullable = true)
 |    |-- element: string (containsNull = true)



In [13]:
# Keep only the rows whose tag_arrays holds at least one tag
df_unified = df_unified.where(size(col("tag_arrays")) > 0)

In [14]:
# Preview 10 rows after the empty arrays have been filtered out
df_unified.show(truncate=True, n=10)

+--------------------+
|          tag_arrays|
+--------------------+
|[tw_food, the_tin...|
|[sad, sadness, sa...|
|[black_panther, b...|
|[digital_art, otp...|
|     [maps, sky, us]|
|[filmedit, moviee...|
|[stray_kids, han_...|
|[ramblings, i_rea...|
|[meme, memes, fre...|
|[and_i_have_no_se...|
+--------------------+
only showing top 10 rows



In [15]:
# Remove duplicate tags that were created during normalization.
# remove_duplicates also returns each array sorted in ascending order.
df_unified = df_unified.withColumn("tag_arrays", remove_duplicates(col("tag_arrays")))

In [16]:
# Show the first 10 rows of the deduplicated dataset without truncating the arrays
df_unified.show(truncate=False, n=10)

+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|tag_arrays                                                                                                                                                                                                                                                                                                                  |
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|[the_tiny_dragon, the_tiny_dragon_draws, t

In [17]:
# Count the rows of df_unified; each row is one basket of tags for FPGrowth
n_baskets = df_unified.count()

# Print the result
print(f"Number of rows in df_unified: {n_baskets}")



Number of rows in df_unified: 2332818


                                                                                

In [18]:
# Repartition the dataset into 500 partitions
# (the same number as spark.sql.shuffle.partitions in the session config)
df_unified = df_unified.repartition(500)

In [19]:
# Persist with MEMORY_AND_DISK.
# The return value is the last expression of the cell, so its repr is displayed.
# NOTE(review): the FPGrowth fit later in this notebook still logs
# "Input data is not cached" - confirm whether this persist benefits the fit.
df_unified.persist(StorageLevel.MEMORY_AND_DISK)

DataFrame[tag_arrays: array<string>]

In [20]:
support = 200  # minimum absolute number of baskets an itemset must appear in
if n_baskets <= 0:
    # Fail with a clear message instead of an opaque ZeroDivisionError below
    raise ValueError("n_baskets must be positive to derive a relative min support")
# FPGrowth takes the support threshold as a fraction of all baskets,
# so convert the absolute count into a relative one.
min_support = support / n_baskets  # Adjust based on dataset size
min_confidence = 0.5  # Adjust based on desired association rules

# min_support is a fraction in [0, 1]; label it as such and also print it as a
# properly scaled percentage.
print(
    f"Min support threshold: {support}. "
    f"As a fraction: {min_support} ({min_support:.4%})"
)

Min support threshold: 200. As a percentage: 8.573322050841514e-05


In [21]:
# Configure FP-Growth on the unified tag baskets and fit it
fp_growth = FPGrowth(
    minConfidence=min_confidence,
    minSupport=min_support,
    itemsCol="tag_arrays",
)
model = fp_growth.fit(df_unified)

24/10/15 00:47:38 WARN FPGrowth: Input data is not cached.                      
                                                                                

In [23]:
# Extract frequent itemsets (columns: items, freq) and preview 10 of them
frequent_itemsets = model.freqItemsets
frequent_itemsets.show(truncate=False, n=10)

+------------------------------------------------+----+
|items                                           |freq|
+------------------------------------------------+----+
|[parahumans]                                    |254 |
|[original_writing]                              |281 |
|[140123]                                        |575 |
|[140123, louis_tomlinson]                       |502 |
|[melontrack]                                    |740 |
|[melontrack, stray_kids]                        |590 |
|[melontrack, userbeepls]                        |443 |
|[melontrack, userbeepls, stray_kids]            |405 |
|[melontrack, userbeepls, staysource]            |392 |
|[melontrack, userbeepls, staysource, stray_kids]|382 |
+------------------------------------------------+----+
only showing top 10 rows



In [24]:
# Generate association rules
# (columns: antecedent, consequent, confidence, lift, support)
association_rules = model.associationRules
association_rules.show(truncate=False)



+--------------------------------------------------------------------------------------------------------------+---------------------+------------------+------------------+--------------------+
|antecedent                                                                                                    |consequent           |confidence        |lift              |support             |
+--------------------------------------------------------------------------------------------------------------+---------------------+------------------+------------------+--------------------+
|[whimsigoth_room, whimsical, whimsigoth, witchcore, whimsigothic, decor, vintage]                             |[witchy]             |1.0               |3323.102564102564 |8.701921881604136E-5|
|[whimsigoth_room, whimsical, whimsigoth, witchcore, whimsigothic, decor, vintage]                             |[witchy_aesthetic]   |1.0               |3702.885714285714 |8.701921881604136E-5|
|[whimsigoth_room, whimsical, 

24/10/15 00:50:31 WARN Executor: Managed memory leak detected; size = 36841566 bytes, task 0.0 in stage 34.0 (TID 4069)
                                                                                

In [25]:
# Order the rules from highest to lowest confidence and preview the top 10
sorted_rules = association_rules.orderBy("confidence", ascending=False)
sorted_rules.show(n=10, truncate=False)



+----------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------+----------+------------------+---------------------+
|antecedent                                                                                                                                                            |consequent          |confidence|lift              |support              |
+----------------------------------------------------------------------------------------------------------------------------------------------------------------------+--------------------+----------+------------------+---------------------+
|[whimsigothaesthetic, whimsigoth_fashion, whimsy, witchy_things, whimsical, witch_blog, whimsigoth, witchyvibes, witchcore, witchy, decor, witch, witchcraft, vintage]|[whimsigoth_room]   |1.0       |7854.60606060606  |8.616188661095722E-5 |
|[thelastofusedit, hboedit, tlou

                                                                                

In [26]:
# Keep only the rules whose lift is at least 1.0 and preview 10 of them
filtered_rules = sorted_rules.where(col("lift") >= 1.0)
filtered_rules.show(n=10, truncate=False)



+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------+----------+------------------+---------------------+
|antecedent                                                                                                                                                                                         |consequent                                |confidence|lift              |support              |
+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+------------------------------------------+----------+------------------+---------------------+
|[whimsigoth_room, whimsical, whimsigoth, witchcore, whimsigothic, decor, vintage]                                       

                                                                                

In [27]:
# Count the number of association rules with lift >= 1.0
# (the bare `n_rules` on the last line displays the value as the cell output)
n_rules = filtered_rules.count()
n_rules

                                                                                

63685196

In [28]:
# Calculate the actual support (frequency) of the association rules.
# `support` is a relative value (a double), so multiplying it back by the number
# of baskets is floating-point arithmetic and can land a hair off the integer
# count (e.g. 216.99999999999997). Rounding restores the whole number while
# keeping the column a double.
filtered_rules = filtered_rules.withColumn(
    "support_count",
    spark_round(col("support") * n_baskets),
)

In [29]:
# Show the schema, which now includes the support_count column
filtered_rules.printSchema()

root
 |-- antecedent: array (nullable = false)
 |    |-- element: string (containsNull = true)
 |-- consequent: array (nullable = false)
 |    |-- element: string (containsNull = true)
 |-- confidence: double (nullable = false)
 |-- lift: double (nullable = true)
 |-- support: double (nullable = false)
 |-- support_count: double (nullable = false)



In [30]:
# Record how many items sit on each side of every rule
filtered_rules = filtered_rules.withColumn("antecedent_size", size(col("antecedent")))
filtered_rules = filtered_rules.withColumn("consequent_size", size(col("consequent")))

In [31]:
# Group by the antecedent size and count the number of rules.
# No ordering is applied, so the groups are displayed in arbitrary order.
filtered_rules.groupBy("antecedent_size").count().show()



+---------------+--------+
|antecedent_size|   count|
+---------------+--------+
|             12| 6167576|
|              1|    8539|
|              6| 3500952|
|              3|  212276|
|              2|   57985|
|             18|    5149|
|              4|  645087|
|              8| 8919069|
|             13| 3471496|
|             11| 8982193|
|             17|   34560|
|              7| 6137228|
|             14| 1584090|
|              9|10767799|
|             10|10800256|
|             16|  163115|
|              5| 1650957|
|             15|  576368|
|             19|     480|
|             20|      21|
+---------------+--------+



                                                                                

In [32]:
# Discard every rule whose antecedent has more than 4 items
filtered_rules = filtered_rules.where(col("antecedent_size") <= 4)

In [33]:
# Count the number of rules left after the antecedent-size filter
filtered_rules.count()

                                                                                

923887

In [35]:
# Sort the antecedents and consequents arrays alphabetically for better readability.
# The built-in `sort_array` (ascending by default) runs inside Spark, so the rows
# are not serialized to a Python worker just to call `sorted` on them.
filtered_rules = filtered_rules.withColumn(
    "antecedent", sort_array(col("antecedent"))
).withColumn("consequent", sort_array(col("consequent")))

In [36]:
# Show 10 rows of the updated DataFrame without truncating the arrays
filtered_rules.show(truncate=False, n=10)



+-------------------------------------------------------------+--------------------+----------+------------------+--------------------+-------------+---------------+---------------+
|antecedent                                                   |consequent          |confidence|lift              |support             |support_count|antecedent_size|consequent_size|
+-------------------------------------------------------------+--------------------+----------+------------------+--------------------+-------------+---------------+---------------+
|[byaurore, dailystrangerthings, stranger_things, tuserrachel]|[strangerthingsedit]|1.0       |701.1776375112713 |9.302054425163043E-5|217.0        |4              |1              |
|[byaurore, dailystrangerthings, stranger_things, tuserrachel]|[userbbelcher]      |1.0       |194.32053311120364|9.302054425163043E-5|217.0        |4              |1              |
|[byaurore, dailystrangerthings, stranger_things, tuserrachel]|[noalook]           |1.0   

                                                                                

In [37]:
# Save the association rules to a CSV file
output_path = "../artifacts/models/association_rules_spark.csv"
# pandas does not create missing directories, so make sure the target folder
# exists before writing (no-op when it is already there).
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# toPandas() collects every remaining rule onto the driver before writing.
filtered_rules.toPandas().to_csv(output_path, index=False)

                                                                                

In [38]:
# Stop the Spark session now that the rules have been written to disk
spark.stop()