Permalink
Browse files

Merge remote-tracking branch 'tixxit/2.10.0' into topic/selection-merge

  • Loading branch information...
2 parents 146efe2 + 30a59e8 commit 0297fd753fbeeaeff0de49e6e1965c5d00351221 @non committed Oct 17, 2012
View
139 benchmark/src/main/scala/spire/benchmark/SelectionBenchmarks.scala
@@ -0,0 +1,139 @@
+package spire.benchmark
+
+import scala.{specialized => spec}
+import scala.annotation.tailrec
+import scala.reflect.ClassTag
+
+import scala.util.Random
+import Random._
+
+import spire.algebra._
+import spire.math._
+
+import com.google.caliper.Runner
+import com.google.caliper.SimpleBenchmark
+import com.google.caliper.Param
+
+
+object SelectionBenchmarks extends MyRunner { // Runner entry point paired with the SelectionBenchmarks class; MyRunner is defined outside this file.
+ val cls = classOf[SelectionBenchmarks] // Benchmark class exposed to the runner (presumably what MyRunner executes — confirm in MyRunner).
+}
+
+/**
+ * Caliper benchmark that times Spire's quick-select and linear-select when
+ * asked for the middle element of an array holding 2^pow values.
+ *
+ * Only the array matching `typ` is populated by `setUp`; the others are null.
+ */
+class SelectionBenchmarks extends MyBenchmark with BenchmarkData {
+
+  // Orders complex numbers by real part first and by imaginary part second.
+  implicit val lexicographic: Order[Complex[Double]] = new Order[Complex[Double]] {
+    def eqv(a: Complex[Double], b: Complex[Double]) = a == b
+
+    def compare(a: Complex[Double], b: Complex[Double]): Int =
+      if (a.real < b.real) -1
+      else if (a.real > b.real) 1
+      else if (a.imag < b.imag) -1
+      else if (a.imag > b.imag) 1
+      else 0
+  }
+
+  // Array size exponent: each run uses 2^pow elements.
+  @Param(Array("3", "4", "6", "9", "13", "18"))
+  var pow: Int = 0
+
+  // Element type under test.
+  @Param(Array("int")) //, "long", "float", "double", "complex"))
+  var typ: String = null
+
+  // Initial arrangement of the input data.
+  //@Param(Array("random", "sorted", "reversed"))
+  @Param(Array("random"))
+  var layout: String = null
+
+  var is: Array[Int] = null
+  var js: Array[Long] = null
+  var fs: Array[Float] = null
+  var ds: Array[Double] = null
+  var cs: Array[Complex[Double]] = null
+  var cs2: Array[FakeComplex[Double]] = null
+
+  /**
+   * Builds an array of `size` values drawn from `init` and arranges it
+   * according to `layout`: left untouched for "random", ascending for
+   * "sorted", and sorted-then-reversed for any other value.
+   */
+  def mkarray[A:ClassTag:Order](size:Int)(init: => A): Array[A] = {
+    val data = Array.ofDim[A](size)
+    var filled = 0
+    while (filled < size) {
+      data(filled) = init
+      filled += 1
+    }
+
+    if (layout == "random") {
+      data
+    } else {
+      spire.math.Sorting.sort(data)
+      if (layout == "sorted") data else data.reverse
+    }
+  }
+
+  /** Regenerates the input for the selected `typ` and clears every other array. */
+  override protected def setUp(): Unit = {
+    val size = fun.pow(2, pow).toInt
+
+    is = null
+    js = null
+    fs = null
+    ds = null
+    cs = null
+    cs2 = null
+
+    typ match {
+      case "int" => is = mkarray(size)(nextInt)
+      case "long" => js = mkarray(size)(nextLong)
+      case "float" => fs = mkarray(size)(nextFloat)
+      case "double" => ds = mkarray(size)(nextDouble)
+      case "complex" =>
+        cs = mkarray(size)(Complex(nextDouble, nextDouble))
+        cs2 = cs.map(c => new FakeComplex(c.real, c.imag))
+      case _ => ()
+    }
+  }
+
+  // Each repetition works on a fresh clone, because selection reorders the
+  // array in place and would otherwise see already-arranged data.
+  def timeSpireQuickSelect(reps:Int) = run(reps) {
+    typ match {
+      case "int" =>
+        val arr = is.clone
+        spire.math.Selection.quickSelect(arr, arr.length / 2)
+        arr.length
+      case "long" =>
+        val arr = js.clone
+        spire.math.Selection.quickSelect(arr, arr.length / 2)
+        arr.length
+      case "float" =>
+        val arr = fs.clone
+        spire.math.Selection.quickSelect(arr, arr.length / 2)
+        arr.length
+      case "double" =>
+        val arr = ds.clone
+        spire.math.Selection.quickSelect(arr, arr.length / 2)
+        arr.length
+      case "complex" =>
+        val arr = cs.clone
+        spire.math.Selection.quickSelect(arr, arr.length / 2)
+        arr.length
+      case _ => ()
+    }
+  }
+
+  // Same shape as timeSpireQuickSelect, but exercising linearSelect.
+  def timeSpireLinearSelect(reps:Int) = run(reps) {
+    typ match {
+      case "int" =>
+        val arr = is.clone
+        spire.math.Selection.linearSelect(arr, arr.length / 2)
+        arr.length
+      case "long" =>
+        val arr = js.clone
+        spire.math.Selection.linearSelect(arr, arr.length / 2)
+        arr.length
+      case "float" =>
+        val arr = fs.clone
+        spire.math.Selection.linearSelect(arr, arr.length / 2)
+        arr.length
+      case "double" =>
+        val arr = ds.clone
+        spire.math.Selection.linearSelect(arr, arr.length / 2)
+        arr.length
+      case "complex" =>
+        val arr = cs.clone
+        spire.math.Selection.linearSelect(arr, arr.length / 2)
+        arr.length
+      case _ => ()
+    }
+  }
+}
+
+object Mo5Benchmarks extends MyRunner { // Runner entry point paired with the Mo5Benchmarks class; MyRunner is defined outside this file.
+ val cls = classOf[Mo5Benchmarks] // Benchmark class exposed to the runner (presumably what MyRunner executes — confirm in MyRunner).
+}
+
+/**
+ * Benchmark comparing the two median-of-5 implementations by running each
+ * over consecutive groups of five elements of the same random Int array.
+ */
+class Mo5Benchmarks extends MyBenchmark with BenchmarkData {
+  val mo5_hb = new HighBranchingMedianOf5 { }
+  val mo5_m = new MutatingMedianOf5 { }
+
+  var as: Array[Int] = null
+
+  // Number of elements in the benchmark input.
+  val len = 5000000
+
+  /** Fills the input with `len` freshly drawn random Ints. */
+  override protected def setUp(): Unit = {
+    as = Array.fill[Int](len)(nextInt)
+  }
+
+  // Both timers clone the input first, since mo5 swaps elements in place.
+  def timeHBMo5(reps:Int) = run(reps) {
+    val scratch = as.clone()
+    val lastStart = len - 5 // last offset that still has five elements after it
+    var start = 0
+    while (start <= lastStart) {
+      mo5_hb.mo5(scratch, start, 1)
+      start += 5
+    }
+    scratch.length
+  }
+
+  def timeMMo5(reps:Int) = run(reps) {
+    val scratch = as.clone()
+    val lastStart = len - 5 // last offset that still has five elements after it
+    var start = 0
+    while (start <= lastStart) {
+      mo5_m.mo5(scratch, start, 1)
+      start += 5
+    }
+    scratch.length
+  }
+}
+
View
357 src/main/scala/spire/math/Selection.scala
@@ -0,0 +1,357 @@
+package spire.math
+
+import scala.reflect.ClassTag
+
+import scala.{specialized => spec}
+
+import scala.annotation.tailrec
+
+
+trait Select { // Common interface for the in-place selection algorithms in this file.
+ def select[@spec A: Order: ClassTag](data: Array[A], k: Int): Unit // Reorders `data` in place with respect to position `k`; returns nothing.
+}
+
+/**
+ * Given a function for finding approximate medians, this will create an exact
+ * median finder.
+ *
+ * The helpers below work on a "strided slice" of an array: the elements at
+ * indices `left`, `left + stride`, `left + 2 * stride`, ... that are strictly
+ * below `right`.
+ */
+trait SelectLike extends Select {
+
+  /**
+   * Chooses the pivot value used to partition the strided slice
+   * `[left, right)`. Implementations may reorder elements of `data`.
+   *
+   * `select` assumes the returned value occurs in the slice, because it reads
+   * the pivot back from the partition point afterwards.
+   */
+  def approxMedian[@spec A: Order](data: Array[A], left: Int, right: Int, stride: Int): A
+
+  /**
+   * Puts the k-th element of data, according to some Order, in the k-th
+   * position. All values before k are less than or equal to data(k) and all
+   * values above k are greater than or equal to data(k).
+   *
+   * This is an in-place algorithm and is not stable and it WILL mess up the
+   * order of equal elements.
+   */
+  final def select[@spec A: Order: ClassTag](data: Array[A], k: Int): Unit = {
+    select(data, 0, data.length, 1, k)
+  }
+
+  /**
+   * Insertion sort over the strided slice `[left, right)`.
+   *
+   * Copy of InsertSort.sort, but with a stride.
+   */
+  final def sort[@spec A](data: Array[A], left: Int, right: Int, stride: Int)(implicit o: Order[A]): Unit = {
+    var i = left
+    while (i < right) {
+      val item = data(i)
+      var hole = i
+      // Shift larger elements one stride to the right until item fits.
+      while (hole > left && o.gt(data(hole - stride), item)) {
+        data(hole) = data(hole - stride)
+        hole -= stride
+      }
+      data(hole) = item
+      i += stride
+    }
+  }
+
+  /**
+   * Selection on the strided slice `[left, right)`; `k` is an array index
+   * lying on that slice's stride lattice.
+   *
+   * Slices shorter than 10 elements are simply sorted. Larger slices are
+   * partitioned around `approxMedian` and the search continues on whichever
+   * side contains `k`, unless `k` already falls inside the run of elements
+   * equal to the pivot.
+   */
+  @tailrec
+  protected final def select[@spec A: Order](data: Array[A], left: Int, right: Int, stride: Int, k: Int): Unit = {
+    val length = (right - left + stride - 1) / stride
+    if (length < 10) {
+      sort(data, left, right, stride)
+
+    } else {
+      val c = partition(data, left, right, stride)(approxMedian(data, left, right, stride))
+      // The pivot-equal run starts at c (the partition point), not at left, so
+      // that is where its length has to be measured.
+      val span = equalSpan(data, c, stride)
+      // span counts elements, so the run covers span * stride index positions.
+      val equalEnd = c + span * stride
+
+      if (c <= k && k < equalEnd) {
+        // k lies inside the run of pivot-equal elements: already in place.
+      } else if (k < c) {
+        select(data, left, c, stride, k)
+      } else {
+        select(data, equalEnd, right, stride, k)
+      }
+    }
+  }
+
+  /**
+   * Counts how many consecutive elements (stepping by `stride`, starting at
+   * `offset`) are equal to `data(offset)`. The count includes `data(offset)`
+   * itself, so the result is at least 1. Scanning stops at the end of the
+   * array.
+   */
+  final def equalSpan[@spec A](data: Array[A], offset: Int, stride: Int)(implicit o: Order[A]): Int = {
+    val m = data(offset)
+    var i = offset + stride
+    var len = 1
+    while (i < data.length && o.eqv(m, data(i))) {
+      i += stride
+      len += 1
+    }
+    len
+  }
+
+  /**
+   * Three-way partition of the strided slice `[left, right)` around `m`.
+   *
+   * Afterwards the slice holds the elements less than `m`, then those equal
+   * to `m`, then those greater than `m`. Returns the index of the first
+   * element equal to `m`.
+   */
+  final def partition[@spec A](data: Array[A], left: Int, right: Int, stride: Int)(m: A)(implicit o: Order[A]): Int = {
+    // While scanning, the layout is: equal | less | greater | unseen.
+    var i = left // Iterator.
+    var j = left // Pointer to first element > m.
+    var k = left // Pointer to end of equal elements.
+    var t = m
+
+    while (i < right) {
+      val cmp = o.compare(data(i), m)
+      if (cmp < 0) {
+        t = data(i); data(i) = data(j); data(j) = t
+        j += stride
+      } else if (cmp == 0) {
+        // Rotate: the equal element goes to k, the displaced "less" element
+        // to j, and the displaced "greater" element to i.
+        t = data(i); data(i) = data(j); data(j) = data(k); data(k) = t
+        k += stride
+        j += stride
+      }
+
+      i += stride
+    }
+
+    // Move the equal elements from the front to sit between less and greater.
+    while (k > left) {
+      j -= stride
+      k -= stride
+      t = data(j)
+      data(j) = data(k)
+      data(k) = t
+    }
+
+    j
+  }
+}
+
+trait MutatingMedianOf5 { // Median-of-5 that tracks the five candidates by index and swaps the indices rather than the elements.
+ final def mo5[@spec A](data: Array[A], offset: Int, stride: Int)(implicit o: Order[A]) { // Swaps the median of data(offset + n * stride), n = 0..4, into data(offset).
+ var i0 = offset
+ var i1 = offset + 1 * stride
+ var i2 = offset + 2 * stride
+ var i3 = offset + 3 * stride
+ var i4 = offset + 4 * stride
+ var t = i0 // Scratch slot for swapping indices.
+
+ if (o.gt(data(i3), data(i4))) { t = i3; i3 = i4; i4 = t } // Now data(i3) <= data(i4).
+ if (o.gt(data(i1), data(i2))) { t = i1; i1 = i2; i2 = t } // Now data(i1) <= data(i2).
+ val i = if (o.lt(data(i4), data(i2))) {
+ // Ignore 2: it is >= three other candidates (1, 3 and 4). 3 <= 4 still holds.
+ if (o.lt(data(i1), data(i0))) { t = i0; i0 = i1; i1 = t } // Now data(i0) <= data(i1).
+ if (o.lt(data(i4), data(i1))) {
+ // Ignore 1 (largest of the remaining four). 3 <= 4, so the answer is the larger of 0 and 4.
+ if (o.lt(data(i4), data(i0))) i0 else i4
+ } else {
+ // Ignore 4 (largest of the remaining four). 0 <= 1, so the answer is the larger of 1 and 3.
+ if (o.lt(data(i3), data(i1))) i1 else i3
+ }
+ } else {
+ // Ignore 4: it is >= three other candidates (1, 2 and 3). 1 <= 2 still holds.
+ if (o.lt(data(i3), data(i0))) { t = i0; i0 = i3; i3 = t } // Now data(i0) <= data(i3).
+ if (o.lt(data(i3), data(i2))) {
+ // Ignore 2 (largest of the remaining four). 0 <= 3, so the answer is the larger of 1 and 3.
+ if (o.lt(data(i3), data(i1))) i1 else i3
+ } else {
+ // Ignore 3 (largest of the remaining four). 1 <= 2, so the answer is the larger of 0 and 2.
+ if (o.lt(data(i2), data(i0))) i0 else i2
+ }
+ }
+
+ val m = data(i) // i is the index of the median.
+ data(i) = data(offset) // Uses `offset`, not i0, because i0 may have been swapped above.
+ data(offset) = m
+ }
+}
+
+trait HighBranchingMedianOf5 { // Median-of-5 written as a fully unrolled comparison tree over five values read once up front.
+ // mo5 swaps the median of data(offset + n * stride), n = 0..4, into data(offset).
+ // Benchmarks show that this is slightly faster than the version above (MutatingMedianOf5).
+ // In the branch comments below, i1..i5 name the five values ai1..ai5 and "Drop iN" marks a value ruled out as the median.
+ final def mo5[@spec A](data: Array[A], offset: Int, stride: Int)(implicit o: Order[A]) {
+ val ai1 = data(offset)
+ val ai2 = data(offset + stride)
+ val ai3 = data(offset + 2 * stride)
+ val ai4 = data(offset + 3 * stride)
+ val ai5 = data(offset + 4 * stride)
+
+ val i = if (o.lt(ai1, ai2)) { // i1 < i2
+ if (o.lt(ai3, ai4)) { // i1 < i2, i3 < i4
+ if (o.lt(ai2, ai4)) { // Drop i4
+ if (o.lt(ai3, ai5)) { // i1 < i2, i3 < i5
+ if (o.lt(ai2, ai5)) { // Drop i5
+ if (o.lt(ai2, ai3)) (offset + 2 * stride) else (offset + 1 * stride)
+ } else { // Drop i2
+ if (o.lt(ai1, ai5)) (offset + 4 * stride) else (offset + 0 * stride)
+ }
+ } else { // i1 < i2, i5 < i3
+ if (o.lt(ai2, ai3)) { // Drop i3
+ if (o.lt(ai2, ai5)) (offset + 4 * stride) else (offset + 1 * stride)
+ } else { // Drop i2
+ if (o.lt(ai1, ai3)) (offset + 2 * stride) else (offset + 0 * stride)
+ }
+ }
+ } else { // Drop i2
+ if (o.lt(ai1, ai5)) { // i1 < i5, i3 < i4
+ if (o.lt(ai5, ai4)) { // Drop i4
+ if (o.lt(ai5, ai3)) (offset + 2 * stride) else (offset + 4 * stride)
+ } else { // Drop i5
+ if (o.lt(ai1, ai4)) (offset + 3 * stride) else (offset + 0 * stride)
+ }
+ } else { // i5 < i1, i3 < i4
+ if (o.lt(ai1, ai4)) { // Drop i4
+ if (o.lt(ai1, ai3)) (offset + 2 * stride) else (offset + 0 * stride)
+ } else { // Drop i1
+ if (o.lt(ai5, ai4)) (offset + 3 * stride) else (offset + 4 * stride)
+ }
+ }
+ }
+ } else { // i1 < i2, i4 < i3
+ if (o.lt(ai2, ai3)) { // Drop i3
+ if (o.lt(ai4, ai5)) { // i1 < i2, i4 < i5
+ if (o.lt(ai2, ai5)) { // Drop i5
+ if (o.lt(ai2, ai4)) (offset + 3 * stride) else (offset + 1 * stride)
+ } else { // Drop i2
+ if (o.lt(ai1, ai5)) (offset + 4 * stride) else (offset + 0 * stride)
+ }
+ } else { // i1 < i2, i5 < i4
+ if (o.lt(ai2, ai4)) { // Drop i4
+ if (o.lt(ai2, ai5)) (offset + 4 * stride) else (offset + 1 * stride)
+ } else { // Drop i2
+ if (o.lt(ai1, ai4)) (offset + 3 * stride) else (offset + 0 * stride)
+ }
+ }
+ } else { // Drop i2
+ if (o.lt(ai1, ai5)) { // i1 < i5, i4 < i3
+ if (o.lt(ai5, ai3)) { // Drop i3
+ if (o.lt(ai5, ai4)) (offset + 3 * stride) else (offset + 4 * stride)
+ } else { // Drop i5
+ if (o.lt(ai1, ai3)) (offset + 2 * stride) else (offset + 0 * stride)
+ }
+ } else { // i5 < i1, i4 < i3
+ if (o.lt(ai1, ai3)) { // Drop i3
+ if (o.lt(ai1, ai4)) (offset + 3 * stride) else (offset + 0 * stride)
+ } else { // Drop i1
+ if (o.lt(ai5, ai3)) (offset + 2 * stride) else (offset + 4 * stride)
+ }
+ }
+ }
+ }
+ } else { // i2 < i1
+ if (o.lt(ai3, ai4)) { // i2 < i1, i3 < i4
+ if (o.lt(ai1, ai4)) { // Drop i4
+ if (o.lt(ai3, ai5)) { // i2 < i1, i3 < i5
+ if (o.lt(ai1, ai5)) { // Drop i5
+ if (o.lt(ai1, ai3)) (offset + 2 * stride) else (offset + 0 * stride)
+ } else { // Drop i1
+ if (o.lt(ai2, ai5)) (offset + 4 * stride) else (offset + 1 * stride)
+ }
+ } else { // i2 < i1, i5 < i3
+ if (o.lt(ai1, ai3)) { // Drop i3
+ if (o.lt(ai1, ai5)) (offset + 4 * stride) else (offset + 0 * stride)
+ } else { // Drop i1
+ if (o.lt(ai2, ai3)) (offset + 2 * stride) else (offset + 1 * stride)
+ }
+ }
+ } else { // Drop i1
+ if (o.lt(ai2, ai5)) { // i2 < i5, i3 < i4
+ if (o.lt(ai5, ai4)) { // Drop i4
+ if (o.lt(ai5, ai3)) (offset + 2 * stride) else (offset + 4 * stride)
+ } else { // Drop i5
+ if (o.lt(ai2, ai4)) (offset + 3 * stride) else (offset + 1 * stride)
+ }
+ } else { // i5 < i2, i3 < i4
+ if (o.lt(ai2, ai4)) { // Drop i4
+ if (o.lt(ai2, ai3)) (offset + 2 * stride) else (offset + 1 * stride)
+ } else { // Drop i2
+ if (o.lt(ai5, ai4)) (offset + 3 * stride) else (offset + 4 * stride)
+ }
+ }
+ }
+ } else { // i2 < i1, i4 < i3
+ if (o.lt(ai1, ai3)) { // Drop i3
+ if (o.lt(ai4, ai5)) { // i2 < i1, i4 < i5
+ if (o.lt(ai1, ai5)) { // Drop i5
+ if (o.lt(ai1, ai4)) (offset + 3 * stride) else (offset + 0 * stride)
+ } else { // Drop i1
+ if (o.lt(ai2, ai5)) (offset + 4 * stride) else (offset + 1 * stride)
+ }
+ } else { // i2 < i1, i5 < i4
+ if (o.lt(ai1, ai4)) { // Drop i4
+ if (o.lt(ai1, ai5)) (offset + 4 * stride) else (offset + 0 * stride)
+ } else { // Drop i1
+ if (o.lt(ai2, ai4)) (offset + 3 * stride) else (offset + 1 * stride)
+ }
+ }
+ } else { // Drop i1
+ if (o.lt(ai2, ai5)) { // i2 < i5, i4 < i3
+ if (o.lt(ai5, ai3)) { // Drop i3
+ if (o.lt(ai5, ai4)) (offset + 3 * stride) else (offset + 4 * stride)
+ } else { // Drop i5
+ if (o.lt(ai2, ai3)) (offset + 2 * stride) else (offset + 1 * stride)
+ }
+ } else { // i5 < i2, i4 < i3
+ if (o.lt(ai2, ai3)) { // Drop i3
+ if (o.lt(ai2, ai4)) (offset + 3 * stride) else (offset + 1 * stride)
+ } else { // Drop i2
+ if (o.lt(ai5, ai3)) (offset + 2 * stride) else (offset + 4 * stride)
+ }
+ }
+ }
+ }
+ }
+
+ val m = data(i) // i is the array index of the median chosen by the tree above.
+ data(i) = data(offset) // Swap the median into the first slot of the group.
+ data(offset) = m
+ }
+}
+
+object LinearSelect extends SelectLike with HighBranchingMedianOf5 {
+
+  // We need to guarantee linear time complexity, so we have to get down and
+  // actually find a pivot w/ a good constant fraction of the data points on
+  // one side. This makes this quite a bit slower in the general case (though
+  // not terribly so), but doesn't suffer from bad worst-case behaviour.
+
+  /**
+   * Median-of-medians pivot: moves the median of every complete group of
+   * five (strided) elements to the front of its group, then selects the
+   * middle one of those group leaders and returns it.
+   */
+  final def approxMedian[@spec A: Order](data: Array[A], left: Int, right: Int, stride: Int): A = {
+    val groupStride = 5 * stride
+    val groupSpan = 4 * stride // distance from a group's first to its last element
+
+    var groupStart = left
+    while (groupStart + groupSpan < right) {
+      mo5(data, groupStart, stride)
+      groupStart += groupStride
+    }
+
+    // The group leaders form a slice of their own with a five times larger stride.
+    val groups = (right - left + groupStride - 1) / groupStride
+    val mid = left + ((groups - 1) / 2) * groupStride
+    select(data, left, right, groupStride, mid)
+    data(mid)
+  }
+}
+
+object QuickSelect extends SelectLike with HighBranchingMedianOf5 {
+
+  // For large arrays, the partitioning dominates the runtime. Choosing a good
+  // pivot, quickly is essential. So, we have 3 cases, getting slightly smarter
+  // about our pivot as the array grows.
+
+  /**
+   * Picks a pivot and leaves it in `data(left)`:
+   *  - fewer than 5 elements: whatever is already at `left`;
+   *  - 5 to 124 elements: the median of five evenly spaced samples;
+   *  - 125 or more: the median of five such medians (25 samples in total).
+   */
+  final def approxMedian[@spec A: Order](data: Array[A], left: Int, right: Int, stride: Int): A = {
+    val length = (right - left + stride - 1) / stride
+
+    if (length >= 5) {
+      val coarse = stride * (length / 5)
+
+      if (length >= 125) {
+        val fine = stride * (length / 25)
+        // First reduce each of the five coarse sample points to a local median.
+        var group = 0
+        while (group < 5) {
+          mo5(data, left + group * coarse, fine)
+          group += 1
+        }
+      }
+
+      mo5(data, left, coarse)
+    }
+
+    data(left)
+  }
+}
+
+/**
+ * Entry points for in-place selection. Each method reorders `data` so that
+ * position `k` holds the element that would be there if `data` were sorted.
+ */
+object Selection {
+
+  /** Default selection; delegates to `quickSelect`. */
+  final def select[@spec A: Order: ClassTag](data: Array[A], k: Int): Unit = {
+    quickSelect(data, k)
+  }
+
+  /** Selection backed by `LinearSelect`. */
+  final def linearSelect[@spec A: Order: ClassTag](data: Array[A], k: Int): Unit = {
+    LinearSelect.select(data, k)
+  }
+
+  /** Selection backed by `QuickSelect`. */
+  final def quickSelect[@spec A: Order: ClassTag](data: Array[A], k: Int): Unit = {
+    QuickSelect.select(data, k)
+  }
+}
View
63 src/test/scala/spire/math/SelectionTest.scala
@@ -0,0 +1,63 @@
+package spire.math
+
+import scala.{specialized => spec}
+import scala.reflect.ClassTag
+
+import org.scalatest.FunSuite
+import org.scalatest.prop.Checkers
+
+
+/**
+ * Shared test cases for any `Select` implementation; concrete suites only
+ * have to provide `selector`.
+ */
+trait SelectTest extends FunSuite /* with Checkers */ {
+  def selector: Select
+
+  // Convenience wrapper so the tests read as plain `select(array, k)`.
+  final def select[@spec A: Order: ClassTag](data: Array[A], k: Int): Unit = {
+    selector.select(data, k)
+  }
+
+  // Returns a randomly permuted copy; the argument itself is left untouched.
+  def shuffle[A: ClassTag](as: Array[A]): Array[A] = {
+    val permuted = scala.util.Random.shuffle(as.toList)
+    permuted.toArray
+  }
+
+  test("selection in 0-length array") {
+    // Shouldn't throw an exception.
+    val empty = new Array[Int](0)
+    select(empty, 0)
+  }
+
+  test("select in 1-length array") {
+    val single = Array(1)
+    select(single, 0)
+    assert(single(0) === 1)
+  }
+
+  test("select from multiple equal elements") {
+    val as = Array(0, 0, 1, 1, 2, 2)
+    val expected = Array(0, 0, 1, 1, 2, 2)
+    for (k <- 0 until 6) {
+      select(as, k)
+      assert(as(k) === expected(k))
+    }
+  }
+
+  test("arbitrary selection") {
+    for (len <- 1 to 10) {
+      val as = Array.range(0, len)
+
+      // Try every position five times, each on a fresh random permutation.
+      for (i <- 0 until len; _ <- 1 to 5) {
+        val bs = shuffle(as)
+        val orig = bs.clone()
+        select(bs, i)
+        assert(bs(i) === i, "Select %d on %s failed." format (i, orig.mkString("[ ", ", ", " ]")))
+      }
+    }
+  }
+}
+
+class LinearSelectTest extends SelectTest { // Runs the shared SelectTest cases against LinearSelect.
+ val selector = LinearSelect
+}
+
+class QuickSelectTest extends SelectTest { // Runs the shared SelectTest cases against QuickSelect.
+ val selector = QuickSelect
+}
+

0 comments on commit 0297fd7

Please sign in to comment.