# PrefixSpan

频繁子序列挖掘

数据集：[Groceries dataset](https://www.kaggle.com/datasets/heeraldedhia/groceries-dataset)

In [1]:
from itertools import groupby
from pyspark.sql import (
    SparkSession,
    Row,
    functions as F
)
from pyspark.sql.types import ArrayType, IntegerType, BooleanType
from pyspark.ml.fpm import PrefixSpan

import utils

In [2]:
CSV_PATH = './data'
CSV_FILE = 'Groceries_dataset.csv'

## 1. 一个简单的例子

数据说明：

- `[[1, 2], [3]]`：1 和 2 被认为是一起发生，3 随后发生
- `[[1, 2, 3]]`：1, 2, 3 按顺序发生

In [3]:
# 创建 SparkSession
spark = SparkSession.builder \
    .appName("App") \
    .getOrCreate()

# 将日志级别设为 WARN
spark.sparkContext.setLogLevel("WARN")

sc = spark.sparkContext
spark_df = sc.parallelize([Row(sequence=[[1, 2], [3]]),
                           Row(sequence=[[1, 2, 3, 2, 1, 2]]),
                           Row(sequence=[[1, 2], [5]]),
                           Row(sequence=[[6]])]).toDF()

prefixSpan = PrefixSpan(minSupport=0.5,
                        maxPatternLength=5,
                        maxLocalProjDBSize=32000000)

# Find frequent sequential patterns.
prefixSpan.findFrequentSequentialPatterns(spark_df).show()

Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
24/09/05 16:48:30 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
24/09/05 16:48:30 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
24/09/05 16:48:37 WARN PrefixSpan: Input data is not cached.                    
                                                                                

+--------+----+
|sequence|freq|
+--------+----+
|   [[2]]|   3|
|   [[3]]|   2|
|   [[1]]|   3|
|[[1, 2]]|   3|
+--------+----+



## 2. 子序列挖掘

### 2.1 商品编码

从 CSV 读入数据，并将商品名称 `itemDescription` 转换为商品编码 `itemCode`。

In [4]:
# 从 CSV 中读取数据，存成 Pandas DataFrame
abs_path = utils.gen_abspath(CSV_PATH, CSV_FILE)
df = utils.read_csv(abs_path)

# 将 Pandas DataFrame 加载到 Spark
spark_df = spark.createDataFrame(df)
spark_df.show(10)

24/09/05 16:48:41 WARN GarbageCollectionMetrics: To enable non-built-in garbage collector(s) List(G1 Concurrent GC), users should configure it(them) to spark.eventLog.gcMetrics.youngGenerationGarbageCollectors or spark.eventLog.gcMetrics.oldGenerationGarbageCollectors


+-------------+----------+----------------+
|Member_number|      Date| itemDescription|
+-------------+----------+----------------+
|         1808|21-07-2015|  tropical fruit|
|         2552|05-01-2015|      whole milk|
|         2300|19-09-2015|       pip fruit|
|         1187|12-12-2015|other vegetables|
|         3037|01-02-2015|      whole milk|
|         4941|14-02-2015|      rolls/buns|
|         4501|08-05-2015|other vegetables|
|         3803|23-12-2015|      pot plants|
|         2762|20-03-2015|      whole milk|
|         4119|12-02-2015|  tropical fruit|
+-------------+----------+----------------+
only showing top 10 rows



In [5]:
# 获取 Spark DataFrame 中的 item Description 字段
items = [row.itemDescription for row in spark_df.collect()]
cv = utils.Convert(items)

# 将函数注册为 Spark UDF，返回值类型设为 Integer
encoder_udf = F.udf(cv.encoder, IntegerType())

# 在 DataFrame 上应用此函数
spark_df = spark_df.withColumn("itemCode", encoder_udf(F.col("itemDescription")))

# 将 Date 转换为 yyyy-MM-dd 格式
spark_df = spark_df \
    .withColumn("Date", F.date_format(F.to_date(F.col("Date"), "dd-MM-yyyy"),"yyyy-MM-dd"))

spark_df.show(10)

[Stage 18:>                                                         (0 + 1) / 1]

+-------------+----------+----------------+--------+
|Member_number|      Date| itemDescription|itemCode|
+-------------+----------+----------------+--------+
|         1808|2015-07-21|  tropical fruit|       6|
|         2552|2015-01-05|      whole milk|       0|
|         2300|2015-09-19|       pip fruit|      11|
|         1187|2015-12-12|other vegetables|       1|
|         3037|2015-02-01|      whole milk|       0|
|         4941|2015-02-14|      rolls/buns|       2|
|         4501|2015-05-08|other vegetables|       1|
|         3803|2015-12-23|      pot plants|      72|
|         2762|2015-03-20|      whole milk|       0|
|         4119|2015-02-12|  tropical fruit|       6|
+-------------+----------+----------------+--------+
only showing top 10 rows



                                                                                

### 2.2 生成商品序列

用 `Member_number` 分组，按 `Data` 从小到大的顺序，对 `itemCode` 排序，做成列表。

In [6]:
# 将 Date 和 itemCode 用 `,` 拼接为 item_rec
# 按 Member_number 分组，对 item_rec 去重，并用 `;` 拼接
items_df = spark_df.filter((F.length(F.col("Date")) > 1) & (F.col("itemCode").isNotNull())) \
    .withColumn("item_rec", F.concat_ws(",", F.col("Date"), F.col("itemCode"))) \
    .groupBy("Member_number").agg(F.concat_ws(";", F.collect_set("item_rec")) \
    .alias("item_list"))

items_df.show(3, truncate=False)

[Stage 19:>                                                         (0 + 8) / 8]

+-------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|Member_number|item_list                                                                                                                                                                      |
+-------------+-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+
|1000         |2015-03-15,0;2015-05-27,3;2015-03-15,8;2015-07-24,48;2015-11-25,56;2015-07-24,13;2014-06-24,10;2014-06-24,0;2015-11-25,8;2015-05-27,67;2015-03-15,4;2015-03-15,65;2014-06-24,38|
|1001         |2015-01-20,15;2015-04-14,29;2014-12-12,0;2015-05-02,17;2014-02-07,0;2015-01-20,3;2014-02-07,8;2015-01-20,17;2014-02-07,2;2015-05-02,24;2014-12-12,3;2015-04-14,23              |
|1002         |2015-04-26,6;2014-04-26,0

                                                                                

In [7]:
def sorted_items(items: str,
                 main_delimiter: str = ';',
                 minor_delimiter: str = ','):
    if len(items) == 0:
        return []

    item_list = []
    for kv in items.split(main_delimiter):
        kv_list = kv.split(minor_delimiter)
        if len(kv_list) == 2:
            item_list.append(kv_list)

    sorted_list = sorted(item_list, key=lambda e: e[0])
    sorted_items = [int(e[1]) for e in sorted_list]
    return [[key for key, _ in groupby(sorted_items)]]

sorted_items('2015-03-15,0;2015-05-27,3;2015-03-16,0;2014-03-16,1')

[[1, 0, 3]]

In [8]:
# 将函数注册为 Spark UDF，返回值类型设为 Integer
sort_udf = F.udf(sorted_items, ArrayType(ArrayType(IntegerType())))

# 在 DataFrame 上应用此函数
sorted_items_df = items_df.withColumn("item_list", sort_udf(F.col("item_list"))) \
    .withColumnRenamed("item_list", "sequence")

sorted_items_df.show(10, truncate=False)



+-------------+-----------------------------------------------------------------------------+
|Member_number|sequence                                                                     |
+-------------+-----------------------------------------------------------------------------+
|1000         |[[10, 0, 38, 0, 8, 4, 65, 3, 67, 48, 13, 56, 8]]                             |
|1001         |[[0, 8, 2, 0, 3, 15, 3, 17, 29, 23, 17, 24]]                                 |
|1002         |[[1, 27, 0, 21, 6, 41, 47, 42]]                                              |
|1003         |[[121, 45, 2, 5, 68, 2, 8]]                                                  |
|1004         |[[11, 0, 6, 89, 90, 61, 31, 2, 76, 13, 1, 12, 5, 0, 10, 0, 69, 2, 31, 1, 56]]|
|1005         |[[2, 15, 25]]                                                                |
|1006         |[[0, 105, 7, 2, 113, 12, 140, 17, 31, 0, 14, 28, 64, 2]]                     |
|1008         |[[5, 59, 4, 87, 20, 6, 3, 96, 32, 34, 104]]  

                                                                                

### 2.3 计算频繁子序列 

In [9]:
def freq2support(min_freq, spark_df):
    support = min_freq / spark_df.count()
    return support

In [10]:
min_freq = 50
minSupport = freq2support(min_freq=min_freq, spark_df=sorted_items_df)
print(f'minSupport: {minSupport:.4f}')

prefixSpan = PrefixSpan(minSupport=minSupport,
                        maxPatternLength=5,
                        maxLocalProjDBSize=32000000)

# Find frequent sequential patterns.
pattern_df = prefixSpan.findFrequentSequentialPatterns(sorted_items_df)

pattern_df.show(10)

                                                                                

minSupport: 0.0128


24/09/05 16:48:54 WARN PrefixSpan: Input data is not cached.        (2 + 6) / 8]
                                                                                

+--------+----+
|sequence|freq|
+--------+----+
|  [[59]]| 172|
|  [[42]]| 253|
| [[102]]|  60|
|  [[79]]|  89|
|  [[82]]|  89|
|  [[55]]| 205|
|  [[85]]|  80|
|  [[47]]| 228|
|  [[15]]| 603|
|  [[74]]| 101|
+--------+----+
only showing top 10 rows



筛选大于特定长度的子序列

In [11]:
# 长度至少为 3 的子序列
filtered_patterns = pattern_df.filter(F.col("sequence").getItem(0).isNotNull() \
                                      & (F.col("sequence").getItem(0).getItem(2).isNotNull()) )
filtered_patterns.show(10)

+-------------+----+
|     sequence|freq|
+-------------+----+
|[[0, 20, 22]]|  50|
|[[0, 18, 24]]|  53|
|[[0, 17, 18]]|  54|
|[[0, 16, 20]]|  52|
|[[0, 16, 18]]|  52|
|[[0, 15, 23]]|  57|
|[[0, 15, 19]]|  56|
|[[0, 15, 17]]|  57|
|[[0, 15, 16]]|  68|
|[[0, 14, 25]]|  53|
+-------------+----+
only showing top 10 rows



### 2.4 用子序列回溯原始数据

In [12]:
subsequence = filtered_patterns.limit(1) \
    .select(F.col("sequence").getItem(0)) \
    .withColumnRenamed("sequence[0]", "seq") \
    .collect()[0].seq

print(f'subsequence: {subsequence}')

subsequence: [0, 20, 22]


In [13]:
# 连续匹配
def is_continuous_subsequence(sequence, subseq):
    seq_len = len(sequence)
    sub_len = len(subseq)
    for i in range(seq_len - sub_len + 1):
        if sequence[i:i+sub_len] == subseq:
            return True
    return False

# 非连续匹配
def is_subsequence(sequence, subseq):
    it = iter(sequence)
    return all(elem in it for elem in subseq)

# 注册 UDF
is_subsequence_udf = F.udf(lambda seq: is_subsequence(seq, subsequence), BooleanType())

# 应用 UDF 来筛选 DataFrame
sorted_items_df = sorted_items_df.withColumn("seq", F.col("sequence").getItem(0))
result_df = sorted_items_df.filter(is_subsequence_udf(F.col("seq")))
result_df.select(F.col("Member_number"), F.col("sequence")).show(3, truncate=False)



+-------------+----------------------------------------------------------+
|Member_number|sequence                                                  |
+-------------+----------------------------------------------------------+
|1825         |[[40, 36, 0, 38, 101, 3, 20, 14, 1, 28, 22, 50, 70]]      |
|1916         |[[0, 9, 20, 4, 24, 11, 22]]                               |
|2389         |[[34, 42, 27, 7, 36, 0, 86, 12, 50, 20, 22, 4, 62, 2, 69]]|
+-------------+----------------------------------------------------------+
only showing top 3 rows



                                                                                

In [14]:
# 子序列解码
[cv[e] for e in subsequence]

['whole milk', 'domestic eggs', 'fruit/vegetable juice']