Permalink
Browse files

Refactors and adds bin reservoirs for better performance

  • Loading branch information...
ashenfad committed Feb 18, 2013
1 parent 56c647f commit 3adfd7319618589907bb1761d8ebbe29da6c7cc2
View
@@ -404,16 +404,33 @@ the histogram can suffer if the `:freeze` parameter is too small.
```clojure
examples> (time (reduce insert! (create) ex/normal-data))
-"Elapsed time: 391.857 msecs"
+"Elapsed time: 333.5 msecs"
examples> (time (reduce insert! (create :freeze 1024) ex/normal-data))
-"Elapsed time: 99.92 msecs"
+"Elapsed time: 166.9 msecs"
```
# Performance
-Insert time scales `log(n)` with respect to the number of bins in the
-histogram.
+There are two implementations of bin reservoirs (which support the
+`insert!` and `merge!` functions). Either of the two implementations,
+`:tree` and `:array`, can be explicitly selected with the `:reservoir`
+parameter. The `:tree` option is useful for histograms with many bins
+as the insert time scales at `O(log n)` with respect to the # of
+bins. The `:array` option is good for small number of bins since
+inserts are `O(n)` but there's a smaller overhead. If `:reservoir` is
+left unspecified then `:array` is used for histograms with <= 256 bins
+and `:tree` is used for anything larger.
+```clojure
+examples> (time (reduce insert! (create :bins 16 :reservoir :tree)
+ ex/normal-data))
+"Elapsed time: 554.478 msecs"
+examples> (time (reduce insert! (create :bins 16 :reservoir :array)
+ ex/normal-data))
+"Elapsed time: 183.532 msecs"
+```
+
+Insert times using reservoir defaults:
![timing chart]
(https://docs.google.com/spreadsheet/oimg?key=0Ah2oAcudnjP4dG1CLUluRS1rcHVqU05DQ2Z4UVZnbmc&oid=2&zx=mppmmoe214jm)
@@ -7,9 +7,14 @@
Target SimpleTarget NumericTarget
ArrayCategoricalTarget GroupTarget
MapCategoricalTarget SumResult
- MixedInsertException)
+ MixedInsertException
+ Histogram$BinReservoirType)
(java.util HashMap ArrayList)))
+(def ^:private clj-to-reservoir-types
+ {:array Histogram$BinReservoirType/array
+ :tree Histogram$BinReservoirType/tree})
+
(def ^:private clj-to-java-types
{:none Histogram$TargetType/none
:numeric Histogram$TargetType/numeric
@@ -30,11 +35,14 @@
:group-types - A sequence of types (:numeric or :categorical) that
describing a group target.
:freeze - After this # of inserts, bin locations will 'freeze',
- improving the performance of future inserts."
- [& {:keys [bins gap-weighted? categories group-types freeze]
+ improving the performance of future inserts.
+ :reservoir - Selects the bin reservoir type (:array or :tree).
+ Defaults to :array for <= 256 bins, otherwise :tree."
+ [& {:keys [bins gap-weighted? categories group-types freeze reservoir]
:or {bins 64 gap-weighted? false}}]
- (let [group-types (seq (map clj-to-java-types group-types))]
- (Histogram. bins gap-weighted? categories group-types freeze)))
+ (let [group-types (seq (map clj-to-java-types group-types))
+ reservoir (clj-to-reservoir-types reservoir)]
+ (Histogram. bins gap-weighted? categories group-types freeze reservoir)))
(defn histogram?
"Returns true if the input is a histogram."
@@ -0,0 +1,141 @@
+/**
+ * Copyright 2013 BigML
+ * Licensed under the Apache License, Version 2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+package com.bigml.histogram;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+
+/**
+ * This class implements bin operations (insertions, merges, etc) for a histogram.
+ * This implementation is best for histograms with a small (<=256) number of bins.
+ * It uses an ArrayList to give O(N) insert performance with regard to the number
+ * of bins in the histogram. For histograms with more bins, the TreeBinReservoir
+ * class offers faster insert performance.
+ */
+public class ArrayBinReservoir <T extends Target> extends BinReservoir<T> {
+
+ public ArrayBinReservoir(int maxBins, boolean weightGaps, Long freezeThreshold) {
+ super(maxBins, weightGaps, freezeThreshold);
+ _bins = new ArrayList<Bin<T>>();
+ }
+
+ @Override
+ public void insert(Bin<T> bin) {
+ addTotalCount(bin);
+ int index = Collections.binarySearch(_bins, bin);
+ if (index >= 0) {
+ _bins.get(index).sumUpdate(bin);
+ } else {
+ if (isFrozen()) {
+ int prevIndex = Math.abs(index) - 2;
+ int nextIndex = prevIndex + 1;
+ double prevDist = (prevIndex >= 0) ?
+ bin.getMean() - _bins.get(prevIndex).getMean() : Double.MAX_VALUE;
+ double nextDist = (nextIndex < _bins.size()) ?
+ _bins.get(nextIndex).getMean() - bin.getMean() : Double.MAX_VALUE;
+ if (prevDist < nextDist) {
+ _bins.get(prevIndex).sumUpdate(bin);
+ } else {
+ _bins.get(nextIndex).sumUpdate(bin);
+ }
+ } else {
+ _bins.add(Math.abs(index) - 1, bin);
+ }
+ }
+ }
+
+ @Override
+ public Bin<T> first() {
+ return _bins.get(0);
+ }
+
+ @Override
+ public Bin<T> last() {
+ return _bins.get(_bins.size() - 1);
+ }
+
+ @Override
+ public Bin<T> get(double p) {
+ int index = Collections.binarySearch(_bins, new Bin(p, 0, null));
+ if (index >= 0) {
+ return _bins.get(index);
+ } else {
+ return null;
+ }
+ }
+
+ @Override
+ public Bin<T> floor(double p) {
+ int index = Collections.binarySearch(_bins, new Bin(p, 0, null));
+ if (index >= 0) {
+ return _bins.get(index);
+ } else {
+ index = Math.abs(index) - 2;
+ return (index >= 0) ? _bins.get(index) : null;
+ }
+ }
+
+ @Override
+ public Bin<T> ceiling(double p) {
+ int index = Collections.binarySearch(_bins, new Bin(p, 0, null));
+ if (index >= 0) {
+ return _bins.get(index);
+ } else {
+ index = Math.abs(index) - 1;
+ return (index < _bins.size()) ? _bins.get(index) : null;
+ }
+ }
+
+ @Override
+ public Bin<T> lower(double p) {
+ int index = Collections.binarySearch(_bins, new Bin(p, 0, null));
+ if (index >= 0) {
+ index--;
+ return (index >= 0) ? _bins.get(index) : null;
+ } else {
+ index = Math.abs(index) - 2;
+ return (index >= 0) ? _bins.get(index) : null;
+ }
+ }
+
+ @Override
+ public Bin<T> higher(double p) {
+ int index = Collections.binarySearch(_bins, new Bin(p, 0, null));
+ if (index >= 0) {
+ index++;
+ return (index < _bins.size()) ? _bins.get(index) : null;
+ } else {
+ index = Math.abs(index) - 1;
+ return (index < _bins.size()) ? _bins.get(index) : null;
+ }
+ }
+
+ @Override
+ public Collection<Bin<T>> getBins() {
+ return _bins;
+ }
+
+ @Override
+ public void merge() {
+ while (_bins.size() > getMaxBins()) {
+ int minGapIndex = -1;
+ double minGap = Double.MAX_VALUE;
+ for (int i = 0; i < _bins.size() - 1; i++) {
+ double gap = gapWeight(_bins.get(i), _bins.get(i + 1));
+ if (minGap > gap) {
+ minGap = gap;
+ minGapIndex = i;
+ }
+ }
+ Bin<T> prev = _bins.get(minGapIndex);
+ Bin<T> next = _bins.remove(minGapIndex + 1);
+ _bins.set(minGapIndex, prev.combine(next));
+ }
+ }
+
+ private ArrayList<Bin<T>> _bins;
+}
@@ -8,7 +8,7 @@
import java.text.DecimalFormat;
import org.json.simple.JSONArray;
-public class Bin<T extends Target> {
+public class Bin<T extends Target> implements Comparable<Bin> {
public Bin(double mean, double count, T target) {
/* Hack to avoid Java's negative zero */
@@ -100,4 +100,7 @@ public int hashCode() {
private final double _mean;
private double _count;
+ public int compareTo(Bin o) {
+ return Double.compare(getMean(), o.getMean());
+ }
}
@@ -0,0 +1,64 @@
+/**
+ * Copyright 2013 BigML
+ * Licensed under the Apache License, Version 2.0
+ * http://www.apache.org/licenses/LICENSE-2.0
+ */
+package com.bigml.histogram;
+
+import java.util.Collection;
+
+public abstract class BinReservoir<T extends Target> {
+ public BinReservoir(int maxBins, boolean weightGaps, Long freezeThreshold) {
+ _maxBins = maxBins;
+ _weightGaps = weightGaps;
+ _freezeThreshold = freezeThreshold;
+ _totalCount = 0;
+ }
+
+ public int getMaxBins() {
+ return _maxBins;
+ }
+
+ public boolean isWeightGaps() {
+ return _weightGaps;
+ }
+
+ public Long getFreezeThreshold() {
+ return _freezeThreshold;
+ }
+
+ public boolean isFrozen() {
+ return _freezeThreshold != null && _totalCount > _freezeThreshold;
+ }
+ public long getTotalCount() {
+ return _totalCount;
+ }
+
+ public void addTotalCount(Bin<T> bin) {
+ _totalCount += bin.getCount();
+ }
+
+ public abstract void insert(Bin<T> bin);
+ public abstract Bin<T> first();
+ public abstract Bin<T> last();
+ public abstract Bin<T> get(double p);
+ public abstract Bin<T> floor(double p);
+ public abstract Bin<T> ceiling(double p);
+ public abstract Bin<T> higher(double p);
+ public abstract Bin<T> lower(double p);
+ public abstract Collection<Bin<T>> getBins();
+ public abstract void merge();
+
+ protected double gapWeight(Bin<T> prev, Bin<T> next) {
+ double diff = next.getMean() - prev.getMean();
+ if (isWeightGaps()) {
+ diff *= Math.log(Math.E + Math.min(prev.getCount(), next.getCount()));
+ }
+ return diff;
+ }
+
+ private final int _maxBins;
+ private final boolean _weightGaps;
+ private final Long _freezeThreshold;
+ private long _totalCount;
+}
Oops, something went wrong.

0 comments on commit 3adfd73

Please sign in to comment.