diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index efa8459d3cdba..2ae6e9c26d86b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -62,13 +62,14 @@ class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) ex @Experimental class FPGrowth private ( private var minSupport: Double, - private var numPartitions: Int) extends Logging with Serializable { + private var numPartitions: Int, + private var mineSequences: Boolean) extends Logging with Serializable { /** * Constructs a default instance with default parameters {minSupport: `0.3`, numPartitions: same - * as the input data}. + * as the input data, mineSequences: `false`}. */ - def this() = this(0.3, -1) + def this() = this(0.3, -1, false) /** * Sets the minimal support level (default: `0.3`). @@ -86,6 +87,14 @@ class FPGrowth private ( this } + /** + * Indicates whether to mine item-sets or item-sequences (default: false, mine item-sets). + */ + def setMineSequences(value: Boolean): this.type = { + this.mineSequences = value + this + } + /** * Computes an FP-Growth model that contains frequent itemsets. * @param data input data set, each element contains a transaction @@ -171,9 +180,12 @@ class FPGrowth private ( itemToRank: Map[Item, Int], partitioner: Partitioner): mutable.Map[Int, Array[Int]] = { val output = mutable.Map.empty[Int, Array[Int]] - // Filter the basket by frequent items pattern and sort their ranks. + // Filter the basket by frequent items pattern val filtered = transaction.flatMap(itemToRank.get) - ju.Arrays.sort(filtered) + if (!this.mineSequences) { // Ignore ordering if not mining sequences + ju.Arrays.sort(filtered) + } + // Generate conditional transactions val n = filtered.length var i = n - 1 while (i >= 0) { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 66ae3543ecc4e..bb1d91114aa23 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -38,6 +38,7 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) + .setMineSequences(false) .run(rdd) assert(model6.freqItemsets.count() === 0) @@ -61,12 +62,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) + .setMineSequences(false) .run(rdd) assert(model2.freqItemsets.count() === 54) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) + .setMineSequences(false) .run(rdd) assert(model1.freqItemsets.count() === 625) } @@ -88,12 +91,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model6 = fpg .setMinSupport(0.9) .setNumPartitions(1) + .setMineSequences(false) .run(rdd) assert(model6.freqItemsets.count() === 0) val model3 = fpg .setMinSupport(0.5) .setNumPartitions(2) + .setMineSequences(false) .run(rdd) assert(model3.freqItemsets.first().items.getClass === Array(1).getClass, "frequent itemsets should use primitive arrays") @@ -109,12 +114,14 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { val model2 = fpg .setMinSupport(0.3) .setNumPartitions(4) + .setMineSequences(false) .run(rdd) assert(model2.freqItemsets.count() === 15) val model1 = fpg .setMinSupport(0.1) .setNumPartitions(8) + .setMineSequences(false) .run(rdd) assert(model1.freqItemsets.count() === 65) }