Permalink
Browse files

Added sorting by key for pair RDDs

  • Loading branch information...
1 parent 98f008b commit e93f6226658cc18fd29995e72f73c0d9246682d3 Antonio committed Feb 11, 2012
@@ -359,6 +359,29 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](
// Returns the erased runtime class of the value type V, read from V's implicit ClassManifest.
def getValueClass() = implicitly[ClassManifest[V]].erasure
}
+ /**
+  * Extra operations for RDDs of (key, value) pairs whose key type can be viewed as Ordered.
+  */
+ class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
+     self: RDD[(K, V)])
+   extends Logging
+   with Serializable {
+
+   /**
+    * Sorts the pairs of `self` by key.
+    *
+    * `self` is first repartitioned with a RangePartitioner that has as many partitions as
+    * `self` currently has splits; the repartitioned RDD is then wrapped in a SortedRDD.
+    *
+    * @param ascending direction passed to both the RangePartitioner and the SortedRDD
+    * @return the resulting SortedRDD
+    */
+   def sortByKey(ascending: Boolean = true): RDD[(K,V)] = {
+     val byRange = new RangePartitioner(self.splits.size, self, ascending)
+     val shuffled = self.partitionBy(byRange)
+     new SortedRDD(shuffled, ascending)
+   }
+ }
+
+ /**
+  * RDD that yields the records of each split of `prev` ordered by key.
+  *
+  * Splits and partitioner are taken unchanged from `prev`, on which this RDD has a
+  * one-to-one dependency; only the order of records inside each split differs.
+  *
+  * @param prev      parent RDD of (key, value) pairs
+  * @param ascending true to order keys smallest first, false for largest first
+  */
+ class SortedRDD[K <% Ordered[K], V](prev: RDD[(K, V)], ascending: Boolean)
+   extends RDD[(K, V)](prev.context) {
+
+   override def splits = prev.splits
+   override val partitioner = prev.partitioner
+   override val dependencies = List(new OneToOneDependency(prev))
+
+   // Materializes the whole split as a list so that it can be sorted by key.
+   override def compute(split: Split) = {
+     val records = prev.iterator(split).toList
+     val ordered =
+       if (ascending) records.sortWith((a, b) => a._1 < b._1)
+       else records.sortWith((a, b) => a._1 > b._1)
+     ordered.iterator
+   }
+ }
+
class MappedValuesRDD[K, V, U](prev: RDD[(K, V)], f: V => U) extends RDD[(K, U)](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
@@ -23,4 +23,41 @@ class HashPartitioner(partitions: Int) extends Partitioner {
case _ =>
false
}
-}
+}
+
+/**
+ * Partitioner that assigns keys to partitions by key range, using boundaries estimated from
+ * a sample of `rdd`. Partition 0 receives the keys that come first in the requested sort
+ * direction, the last partition those that come last.
+ *
+ * Constructing an instance runs `rdd.count()` and then samples and collects `rdd`.
+ *
+ * @param partitions number of partitions; must be positive
+ * @param rdd        pair RDD whose keys are sampled to choose the range boundaries
+ * @param ascending  true if partition indices follow ascending key order, false for descending
+ */
+class RangePartitioner[K <% Ordered[K],V](
+    partitions: Int,
+    rdd: RDD[(K,V)],
+    private val ascending: Boolean = true)
+  extends Partitioner {
+
+  require(partitions > 0, "Number of partitions must be positive but was " + partitions)
+
+  def numPartitions = partitions
+
+  val rddSize = rdd.count()
+  // Aim for about ten sampled keys per partition.
+  val maxSampleSize = partitions*10.0
+  // When rddSize is 0 the quotient is infinite, so the fraction is capped at 1.0.
+  val frac = math.min(1.0, maxSampleSize / rddSize)
+  // Sampled keys, sorted in the requested direction.
+  val rddSample = rdd.sample(true, frac, 1).collect.toList
+    .map(_._1)
+    .sortWith((x, y) => if (ascending) x < y else x > y)
+  // Average number of sampled keys per partition (floating-point division, not truncated).
+  val bucketSize:Float = rddSample.size.toFloat / partitions
+  // Boundaries taken at evenly spaced positions of the sorted sample: (partitions - 1) of
+  // them, or none when the sample is empty. Like the sample, they are sorted in the
+  // requested direction.
+  val rangeBounds: List[K] = {
+    val sample = rddSample.toIndexedSeq
+    if (sample.isEmpty) Nil
+    else (1 until partitions).map(i => sample((i.toLong * sample.size / partitions).toInt)).toList
+  }
+
+  /**
+   * Returns the index of the partition whose key range contains `key`.
+   * A null key is assigned to partition 0.
+   */
+  def getPartition(key: Any): Int = key match {
+    case null => 0
+    case other =>
+      // K is erased at runtime, so a type pattern could not check it; cast explicitly.
+      val k = other.asInstanceOf[K]
+      // The bounds are sorted in the sort direction, so the number of bounds the key lies
+      // strictly beyond is exactly the index of its partition.
+      rangeBounds.count(bound => if (ascending) k > bound else k < bound)
+  }
+
+  // Two range partitioners are interchangeable only if they route every key identically,
+  // i.e. they agree on partition count, direction and boundaries.
+  override def equals(other: Any): Boolean = other match {
+    case r: RangePartitioner[_, _] =>
+      r.numPartitions == numPartitions &&
+        r.ascending == ascending &&
+        r.rangeBounds == rangeBounds
+    case _ => false
+  }
+
+  // Kept consistent with equals.
+  override def hashCode: Int =
+    41 * (41 * numPartitions + rangeBounds.hashCode) + (if (ascending) 1 else 0)
+}
+
@@ -349,10 +349,13 @@ object SparkContext {
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
// Implicitly wraps an RDD of (K, V) pairs in PairRDDFunctions; requires ClassManifests for K and V.
implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
-
+
// Implicitly wraps an RDD of (K, V) pairs in SequenceFileRDDFunctions when both K and V
// can be viewed as Writable and have ClassManifests.
implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
+ // Implicitly wraps an RDD of (K, V) pairs in OrderedRDDFunctions when K can be viewed as
+ // Ordered[K], e.g. so that sortByKey can be called on the RDD directly.
+ implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
+     rdd: RDD[(K, V)]): OrderedRDDFunctions[K, V] =
+   new OrderedRDDFunctions[K, V](rdd)
+
// Implicit conversions to common Writable types, for saveAsSequenceFile
// Implicitly converts an Int to an IntWritable wrapping the same value.
implicit def intToIntWritable(i: Int) = new IntWritable(i)

0 comments on commit e93f622

Please sign in to comment.