Skip to content

Commit

Permalink
Merge pull request #4590 from h2oai/michalk_fix-na-splits
Browse files Browse the repository at this point in the history
PUBDEV-7517: GBM can fail when tree node has just one single value and NAs
  • Loading branch information
Michal Kurka committed May 8, 2020
2 parents a900d14 + 595ef4d commit e7336ff
Show file tree
Hide file tree
Showing 8 changed files with 80 additions and 41 deletions.
41 changes: 26 additions & 15 deletions h2o-algos/src/main/java/hex/tree/DHistogram.java
Expand Up @@ -51,6 +51,8 @@ public final class DHistogram extends Iced {
public char _nbin; // Bin count (excluding NA bucket)
public double _step; // Linear interpolation step per bin
public final double _min, _maxEx; // Conservative Min/Max over whole collection. _maxEx is Exclusive.
public final boolean _initNA; // Does the initial histogram have any NAs?
// Needed to correctly count actual number of bins of the initial histogram.
public final double _pred1; // We calculate what would be the SE for a possible fallback predictions _pred1
public final double _pred2; // and _pred2. Currently used for min-max bounds in monotonic GBMs.

Expand Down Expand Up @@ -191,7 +193,7 @@ public StepOutOfRangeException(String name, double step, int xbins, double maxEx
super("column=" + name + " leads to invalid histogram(check numeric range) -> [max=" + maxEx + ", min = " + min + "], step= " + step + ", xbin= " + xbins);
}
}
public DHistogram(String name, final int nbins, int nbins_cats, byte isInt, double min, double maxEx,
public DHistogram(String name, final int nbins, int nbins_cats, byte isInt, double min, double maxEx, boolean initNA,
double minSplitImprovement, SharedTreeModel.SharedTreeParameters.HistogramType histogramType, long seed, Key globalQuantilesKey,
Constraints cs) {
assert nbins >= 1;
Expand All @@ -215,10 +217,11 @@ public DHistogram(String name, final int nbins, int nbins_cats, byte isInt, doub
}
_isInt = isInt;
_name = name;
_min=min;
_maxEx=maxEx; // Set Exclusive max
_min = min;
_maxEx = maxEx; // Set Exclusive max
_min2 = Double.MAX_VALUE; // Set min/max to outer bounds
_maxIn= -Double.MAX_VALUE;
_initNA = initNA;
_minSplitImprovement = minSplitImprovement;
_histoType = histogramType;
_seed = seed;
Expand Down Expand Up @@ -282,17 +285,20 @@ public double binAt( int b ) {
// Number of regular bins (_nbin); per the field's declaration this count excludes the NA bucket.
public int nbins() { return _nbin; }
// actual number of bins (possibly including NA bin)
public int actNBins() {
return nbins() + (hasNABin() ? 1: 0);
return nbins() + (hasNABin() ? 1 : 0);
}
public double bins(int b) { return w(b); }

private boolean hasNABin() {
return _vals != null && wNA() > 0;
public boolean hasNABin() {
if (_vals == null)
return _initNA; // we are in the initial histogram (and didn't see the data yet)
else
return wNA() > 0;
}

// Big allocation of arrays
// Convenience overload: delegates to init(double[]) with null, i.e. without a caller-supplied values array.
public void init() { init(null);}
public void init(double [] vals) {
public void init(final double[] vals) {
assert _vals == null;
if (_histoType==SharedTreeModel.SharedTreeParameters.HistogramType.Random) {
// every node makes the same split points
Expand Down Expand Up @@ -344,9 +350,10 @@ else if (_histoType== SharedTreeModel.SharedTreeParameters.HistogramType.Quantil
}
}
}
//otherwise AUTO/UniformAdaptive
assert(_nbin>0);
_vals = vals == null?MemoryManager.malloc8d(_vals_dim*_nbin+_vals_dim):vals;
// otherwise AUTO/UniformAdaptive
_vals = vals == null ? MemoryManager.malloc8d(_vals_dim * _nbin + _vals_dim) : vals;
// this always holds: _vals != null
assert _nbin > 0;
}

// Add one row to a bin found via simple linear interpolation.
Expand Down Expand Up @@ -403,9 +410,11 @@ public static DHistogram[] initialHist(Frame fr, int ncols, int nbins, DHistogra
final double maxIn = v.isCategorical() ? v.domain().length-1 : Math.min(v.max(), Double.MAX_VALUE); // inclusive vector max
final double maxEx = v.isCategorical() ? v.domain().length : find_maxEx(maxIn,v.isInt()?1:0); // smallest exclusive max
final long vlen = v.length();
final long nacnt = v.naCnt();
try {
hs[c] = v.naCnt() == vlen || v.isConst(true) ?
null : make(fr._names[c], nbins, (byte) (v.isCategorical() ? 2 : (v.isInt() ? 1 : 0)), minIn, maxEx, seed, parms, globalQuantilesKey[c], cs);
byte type = (byte) (v.isCategorical() ? 2 : (v.isInt() ? 1 : 0));
hs[c] = nacnt == vlen || v.isConst(true) ?
null : make(fr._names[c], nbins, type, minIn, maxEx, nacnt > 0, seed, parms, globalQuantilesKey[c], cs);
} catch(StepOutOfRangeException e) {
hs[c] = null;
Log.warn("Column " + fr._names[c] + " with min = " + v.min() + ", max = " + v.max() + " has step out of range (" + e.getMessage() + ") and is ignored.");
Expand All @@ -415,14 +424,16 @@ public static DHistogram[] initialHist(Frame fr, int ncols, int nbins, DHistogra
return hs;
}

public static DHistogram make(String name, final int nbins, byte isInt, double min, double maxEx, long seed, SharedTreeModel.SharedTreeParameters parms, Key globalQuantilesKey, Constraints cs) {
return new DHistogram(name,nbins, parms._nbins_cats, isInt, min, maxEx, parms._min_split_improvement, parms._histogram_type, seed, globalQuantilesKey, cs);
public static DHistogram make(String name, final int nbins, byte isInt, double min, double maxEx, boolean hasNAs,
long seed, SharedTreeModel.SharedTreeParameters parms, Key globalQuantilesKey, Constraints cs) {
return new DHistogram(name, nbins, parms._nbins_cats, isInt, min, maxEx, hasNAs,
parms._min_split_improvement, parms._histogram_type, seed, globalQuantilesKey, cs);
}

// Pretty-print a histogram
@Override public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(_name).append(":").append(_min).append("-").append(_maxEx).append(" step=" + (1 / _step) + " nbins=" + nbins() + " isInt=" + _isInt);
sb.append(_name).append(":").append(_min).append("-").append(_maxEx).append(" step=" + (1 / _step) + " nbins=" + nbins() + " actNBins=" + actNBins() + " isInt=" + _isInt);
if( _vals != null ) {
for(int b = 0; b< _nbin; b++ ) {
sb.append(String.format("\ncnt=%f, [%f - %f], mean/var=", w(b),_min+b/_step,_min+(b+1)/_step));
Expand Down
13 changes: 9 additions & 4 deletions h2o-algos/src/main/java/hex/tree/DTree.java
Expand Up @@ -286,8 +286,12 @@ public DHistogram[] nextLevelHistos(DHistogram currentHistos[], int way, double
if( h._isInt > 0 && !(min+1 < maxEx ) )
continue; // This column will not split again
assert min < maxEx && adj_nbins > 1 : ""+min+"<"+maxEx+" nbins="+adj_nbins;

nhists[j] = DHistogram.make(h._name, adj_nbins, h._isInt, min, maxEx, h._seed*0xDECAF+(way+1), parms, h._globalQuantilesKey, cs);

// only count NAs if we have any going our way (note: NAvsREST doesn't build a histo for the NA direction)
final boolean hasNAs = (_nasplit == DHistogram.NASplitDir.NALeft && way == 0 ||
_nasplit == DHistogram.NASplitDir.NARight && way == 1) && h.hasNABin();

nhists[j] = DHistogram.make(h._name, adj_nbins, h._isInt, min, maxEx, hasNAs,h._seed*0xDECAF+(way+1), parms, h._globalQuantilesKey, cs);
cnt++; // At least some chance of splitting
}
return cnt == 0 ? null : nhists;
Expand Down Expand Up @@ -328,7 +332,8 @@ public UndecidedNode( DTree tree, int pid, DHistogram[] hs, Constraints cs ) {
// Can return null for 'all columns'.
public int[] scoreCols() {
DTree tree = _tree;
if (tree.actual_mtries() == _hs.length && tree._mtrys_per_tree == _hs.length) return null;
if (tree.actual_mtries() == _hs.length && tree._mtrys_per_tree == _hs.length)
return null;

// per-tree pre-selected columns
int[] activeCols = tree._cols;
Expand All @@ -342,7 +347,7 @@ public int[] scoreCols() {
int idx = activeCols[i];
assert(idx == i || tree._mtrys_per_tree < _hs.length);
if( _hs[idx]==null ) continue; // Ignore not-tracked cols
assert _hs[idx]._min < _hs[idx]._maxEx && _hs[idx].nbins() > 1 : "broken histo range "+_hs[idx];
assert _hs[idx]._min < _hs[idx]._maxEx && _hs[idx].actNBins() > 1 : "broken histo range "+_hs[idx];
cols[len++] = idx; // Gather active column
}
// Log.info("These columns can be split: " + Arrays.toString(Arrays.copyOfRange(cols, 0, len)));
Expand Down
4 changes: 2 additions & 2 deletions h2o-algos/src/test/java/hex/tree/DHistogramTest.java
Expand Up @@ -37,7 +37,7 @@ public void initCachesZeroPosition() {
DKV.put(hq);
Scope.track_generic(hq);

DHistogram histo = new DHistogram("test", 20, 1024, (byte) 1, -1, 2, -0.001,
DHistogram histo = new DHistogram("test", 20, 1024, (byte) 1, -1, 2, false, -0.001,
SharedTreeModel.SharedTreeParameters.HistogramType.QuantilesGlobal, 42L, hq._key, null);
histo.init();

Expand All @@ -59,7 +59,7 @@ public void findBinForNegativeZero() {
DKV.put(hq);
Scope.track_generic(hq);

DHistogram histo = new DHistogram("test", 20, 1024, (byte) 1, -1, 2, -0.001,
DHistogram histo = new DHistogram("test", 20, 1024, (byte) 1, -1, 2, false, -0.001,
SharedTreeModel.SharedTreeParameters.HistogramType.QuantilesGlobal, 42L, hq._key, null);
histo.init();

Expand Down
6 changes: 3 additions & 3 deletions h2o-algos/src/test/java/hex/tree/DTreeTest.java
Expand Up @@ -16,7 +16,7 @@ public void testFindBestSplitPoint_pubdev6495() {
double[] ys = new double[]{0, 1, 0, 1 };
int[] rows = new int[] {0, 1, 2, 3 };

DHistogram hs = new DHistogram("test_hs", 2, 2, (byte) 0, 0, 2, 0.01,
DHistogram hs = new DHistogram("test_hs", 2, 2, (byte) 0, 0, 2, true, 0.01,
SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive, 123, null, null);
hs.init();
hs.updateHisto(ws, null, cs, ys,rows, rows.length, 0);
Expand All @@ -30,7 +30,7 @@ public void testFindBestSplitPoint_pubdev6495() {
assertNull(s2); // not enough improvement => no split

// 3. allow negative (!!!) split improvement, min_rows = #NAs + 1
DHistogram hsN = new DHistogram("test_hs", 2, 2, (byte) 0, 0, 2, -9,
DHistogram hsN = new DHistogram("test_hs", 2, 2, (byte) 0, 0, 2, true, -9,
SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive, 123, null, null);
hsN.init();
hsN.updateHisto(ws, null, cs, ys,rows, rows.length, 0);
Expand Down Expand Up @@ -73,7 +73,7 @@ private static DHistogram makeHisto(int nbins, double min_pred, double max_pred)
.withNewConstraint(0, 0, max_pred);
assertEquals(min_pred, c._min, 0);
assertEquals(max_pred, c._max, 0);
return new DHistogram("test_hs", nbins, 2, (byte) 0, 0, 10, 0.01,
return new DHistogram("test_hs", nbins, 2, (byte) 0, 0, 10, false, 0.01,
SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive, 123, null, c);
}

Expand Down
8 changes: 4 additions & 4 deletions h2o-algos/src/test/java/hex/tree/HistogramTest.java
Expand Up @@ -126,7 +126,7 @@ public void compute2() {
}
Key k = Key.make();
DKV.put(new DHistogram.HistoQuantiles(k,splitPts));
DHistogram hist = new DHistogram("myhisto",nbins,nbins_cats,isInt,min,maxEx,0,histoType,seed,k,null);
DHistogram hist = new DHistogram("myhisto",nbins,nbins_cats,isInt,min,maxEx,false,0,histoType,seed,k,null);
hist.init();
int N=10000000;
int bin=-1;
Expand Down Expand Up @@ -161,7 +161,7 @@ public void compute2() {
double maxEx = 6.900000000000001;
long seed = 1234;
SharedTreeModel.SharedTreeParameters.HistogramType histoType = SharedTreeModel.SharedTreeParameters.HistogramType.UniformAdaptive;
DHistogram hist = new DHistogram("myhisto", nbins, nbins_cats, isInt, min, maxEx, 0, histoType, seed, null, null);
DHistogram hist = new DHistogram("myhisto", nbins, nbins_cats, isInt, min, maxEx,false, 0, histoType, seed, null, null);
hist.init();
assert(hist.binAt(0)==min);
assert(hist.binAt(nbins-1)<maxEx);
Expand All @@ -177,7 +177,7 @@ public void compute2() {
double maxEx = 6.900000000000001;
long seed = 1234;
SharedTreeModel.SharedTreeParameters.HistogramType histoType = SharedTreeModel.SharedTreeParameters.HistogramType.Random;
DHistogram hist = new DHistogram("myhisto", nbins, nbins_cats, isInt, min, maxEx, 0, histoType, seed, null, null);
DHistogram hist = new DHistogram("myhisto", nbins, nbins_cats, isInt, min, maxEx, false,0, histoType, seed, null, null);
hist.init();
assert(hist.binAt(0)==min);
assert(hist.binAt(nbins-1)<maxEx);
Expand All @@ -196,7 +196,7 @@ public void compute2() {
double[] splitPts = new double[]{1,1.5,2,2.5,3,4,5,6.1,6.2,6.3,6.7,6.8,6.85};
Key k = Key.make();
DKV.put(new DHistogram.HistoQuantiles(k,splitPts));
DHistogram hist = new DHistogram("myhisto",nbins,nbins_cats,isInt,min,maxEx,0,histoType,seed,k,null);
DHistogram hist = new DHistogram("myhisto",nbins,nbins_cats,isInt,min,maxEx,false,0,histoType,seed,k,null);
hist.init();
assert(hist.binAt(0)==min);
assert(hist.binAt(nbins-1)<maxEx);
Expand Down
39 changes: 31 additions & 8 deletions h2o-algos/src/test/java/hex/tree/gbm/SharedTreeTest.java
Expand Up @@ -17,6 +17,7 @@
import java.util.Arrays;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertTrue;

@RunWith(Parameterized.class)
Expand All @@ -39,38 +40,55 @@ public static Iterable<SharedTreeModel.SharedTreeParameters> data() {
}

@Parameterized.Parameter
public SharedTreeModel.SharedTreeParameters parms;
public SharedTreeModel.SharedTreeParameters _parms;

@Test
public void testNAPredictor_cat() {
checkNAPredictor(new TestFrameBuilder()
checkNAPredictor(twoVecFrameBuilder()
.withVecTypes(Vec.T_CAT, Vec.T_CAT)
.withDataForCol(0, ar(null, "V", null, "V", null, "V"))
);
}

@Test
public void testNAPredictor_num() {
checkNAPredictor(new TestFrameBuilder()
checkNAPredictor(twoVecFrameBuilder()
.withVecTypes(Vec.T_NUM, Vec.T_CAT)
.withDataForCol(0, ard(Double.NaN, 1, Double.NaN, 1, Double.NaN, 1))
);
}

/**
 * Regression test for PUBDEV-7517: two numeric feature columns carrying the same data
 * (NAs alternating with a single value, 0) and a categorical response, trained with
 * per-tree column sampling enabled.
 */
@Test
public void testNAPredictor_PUBDEV7517() {
  // Describe the training frame first; the parameters are cloned afterwards, as in the direct call form.
  final TestFrameBuilder frameBuilder = new TestFrameBuilder()
      .withColNames("F1", "F2", "Response")
      .withVecTypes(Vec.T_NUM, Vec.T_NUM, Vec.T_CAT)
      .withDataForCol(0, ard(Double.NaN, 0, Double.NaN, 0, Double.NaN, 0))
      .withDataForCol(1, ard(Double.NaN, 0, Double.NaN, 0, Double.NaN, 0)) // copy of the first one
      .withDataForCol(2, ar("A", "B", "A", "B", "A", "B"));

  // Work on a clone of the parameterized settings and adjust only the column sampling rate.
  final SharedTreeModel.SharedTreeParameters sampledParms =
      (SharedTreeModel.SharedTreeParameters) _parms.clone();
  // this will trigger a code path that actually evaluates the initial histograms
  sampledParms._col_sample_rate_per_tree = 0.5;

  checkNAPredictor(frameBuilder, sampledParms);
}

// Runs the NA-predictor check against a clone of the parameterized settings,
// so the two-argument variant receives its own parameters instance.
private void checkNAPredictor(TestFrameBuilder fb) {
  final SharedTreeModel.SharedTreeParameters parmsCopy =
      (SharedTreeModel.SharedTreeParameters) _parms.clone();
  checkNAPredictor(fb, parmsCopy);
}

private void checkNAPredictor(TestFrameBuilder fb, SharedTreeModel.SharedTreeParameters parms) {
Scope.enter();
try {
final Frame frame = fb
.withColNames("F", "Response")
.withDataForCol(1, ar("A", "B", "A", "B", "A", "B"))
.build();
Frame frame = fb.build();

assertNotSame(parms, _parms); // make sure we are mutating a clone
parms._train = frame._key;
parms._valid = frame._key; // we don't do sampling in DRF, metrics will be NA
parms._response_column = "Response";
parms._ntrees = 1;
parms._ignore_const_cols = true; // default but to make sure and illustrate the point
parms._min_rows = 1;
parms._seed = 42;

SharedTreeModel model = (SharedTreeModel) ModelBuilder.make(parms).trainModel().get();
Scope.track_generic(model);
Expand All @@ -79,7 +97,7 @@ private void checkNAPredictor(TestFrameBuilder fb) {
assertEquals(0, model.classification_error(), 0);

// Check that we predict perfectly
Frame test = Scope.track(frame.subframe(new String[]{"F"}));
Frame test = Scope.track(frame.subframe(model._output.features()));
Frame scored = Scope.track(model.score(test));
assertCatVecEquals(frame.vec("Response"), scored.vec("predict"));

Expand All @@ -92,5 +110,10 @@ private void checkNAPredictor(TestFrameBuilder fb) {
}
}

// Builder pre-populated with the column names ("F", "Response") and the response data in column 1;
// callers still have to supply the vec types and the data for column 0.
private TestFrameBuilder twoVecFrameBuilder() {
  final TestFrameBuilder withNames = new TestFrameBuilder().withColNames("F", "Response");
  return withNames.withDataForCol(1, ar("A", "B", "A", "B", "A", "B"));
}

}
2 changes: 1 addition & 1 deletion h2o-automl/src/main/java/ai/h2o/automl/FrameMetadata.java
Expand Up @@ -474,7 +474,7 @@ public MetaPass1(int idx, FrameMetadata fm) {
int xbins = (char) ((long) v.max() - (long) v.min());

if(!(_colMeta._ignored) && !(_colMeta._v.isBad()) && xbins > 0) {
_colMeta._histo = MetaCollector.DynamicHisto.makeDHistogram(colname, nbins, nbins, (byte) (v.isCategorical() ? 2 : (v.isInt() ? 1 : 0)), v.min(), v.max());
_colMeta._histo = MetaCollector.DynamicHisto.makeDHistogram(colname, nbins, nbins, (byte) (v.isCategorical() ? 2 : (v.isInt() ? 1 : 0)), v.min(), v.max(), v.naCnt() > 0);
}

// Skewness & Kurtosis
Expand Down
Expand Up @@ -98,9 +98,9 @@ public final static class DynamicHisto extends MRTask<DynamicHisto> {
public double[] _ssqs; // different from _h._ssqs
public DynamicHisto(DHistogram h) { _h=h; }
DynamicHisto(String name, final int nbins, int nbins_cats, byte isInt,
double min, double max) {
double min, double max, boolean hasNAs) {
if(!(Double.isNaN(min)) && !(Double.isNaN(max))) { //If both are NaN then we don't need a histogram
_h = makeDHistogram(name, nbins, nbins_cats, isInt, min, max);
_h = makeDHistogram(name, nbins, nbins_cats, isInt, min, max, hasNAs);
}else{
Log.info("Ignoring all NaN column -> "+ name);
}
Expand All @@ -112,7 +112,7 @@ private static class SharedTreeParameters extends SharedTreeModel.SharedTreePara
public String javaName() { return "this.is.unused"; }
}
public static DHistogram makeDHistogram(String name, int nbins, int nbins_cats, byte isInt,
double min, double max) {
double min, double max, boolean hasNAs) {
final double minIn = Math.max(min,-Double.MAX_VALUE); // inclusive vector min
final double maxIn = Math.min(max, Double.MAX_VALUE); // inclusive vector max
final double maxEx = DHistogram.find_maxEx(maxIn,isInt==1?1:0); // smallest exclusive max
Expand All @@ -122,7 +122,7 @@ public static DHistogram makeDHistogram(String name, int nbins, int nbins_cats,
parms._nbins = nbins;
parms._nbins_cats = nbins_cats;

return DHistogram.make(name, nbins, isInt, minIn, maxEx, 0, parms, null, null);
return DHistogram.make(name, nbins, isInt, minIn, maxEx, hasNAs, 0, parms, null, null);
}
public double binAt(int b) { return _h.binAt(b); }

Expand Down

0 comments on commit e7336ff

Please sign in to comment.