Skip to content

Commit

Permalink
PUBDEV-1847: More code cleanup / refactoring of GBM/DRF.
Browse files Browse the repository at this point in the history
  • Loading branch information
arnocandel committed Sep 25, 2015
1 parent aa07b01 commit 2534dc6
Show file tree
Hide file tree
Showing 6 changed files with 154 additions and 147 deletions.
17 changes: 9 additions & 8 deletions h2o-algos/src/main/java/hex/tree/DTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -255,25 +255,26 @@ public DHistogram[] split(int way, char nbins, char nbins_cats, double min_rows,
// with other histograms) in a single pass over the data. Does not contain
// any split-decision.
public static class UndecidedNode extends Node {
public transient DHistogram[] _hs;
public transient DHistogram[] _hs; //(up to) one histogram per column
public final int _scoreCols[]; // A list of columns to score; could be null for all
public UndecidedNode( DTree tree, int pid, DHistogram[] hs ) {
super(tree,pid);
assert hs.length==tree._ncols;
_scoreCols = scoreCols(_hs=hs);
_hs = hs; //these histograms have no bins yet (just constructed)
_scoreCols = scoreCols();
}

// Pick a random selection of columns to compute best score.
// Can return null for 'all columns'.
public int[] scoreCols( DHistogram[] hs ) {
public int[] scoreCols() {
DTree tree = _tree;
if (tree._mtrys == hs.length) return null;
int[] cols = new int[hs.length];
if (tree._mtrys == _hs.length) return null;
int[] cols = new int[_hs.length];
int len=0;
// Gather all active columns to choose from.
for( int i=0; i<hs.length; i++ ) {
if( hs[i]==null ) continue; // Ignore not-tracked cols
assert hs[i]._min < hs[i]._maxEx && hs[i].nbins() > 1 : "broken histo range "+hs[i];
for( int i=0; i<_hs.length; i++ ) {
if( _hs[i]==null ) continue; // Ignore not-tracked cols
assert _hs[i]._min < _hs[i]._maxEx && _hs[i].nbins() > 1 : "broken histo range "+_hs[i];
cols[len++] = i; // Gather active column
}
int choices = len; // Number of columns I can choose from
Expand Down
12 changes: 9 additions & 3 deletions h2o-algos/src/main/java/hex/tree/ScoreBuildHistogram.java
Original file line number Diff line number Diff line change
Expand Up @@ -167,12 +167,18 @@ private void accum_subset(Chunk chks[], Chunk wrks, Chunk weight, int nnids[]) {
if( nid >= 0 ) { // row already predicts perfectly or OOB
double w = weight.atd(row);
if (w == 0) continue;
double resp = wrks.atd(row);
assert !Double.isNaN(wrks.atd(row)); // Already marked as sampled-away
DHistogram nhs[] = _hcs[nid];
int sCols[] = _tree.undecided(nid+_leaf)._scoreCols; // Columns to score (null, or a list of selected cols)
//FIXME/TODO: sum into local variables, do atomic increment once at the end, similar to accum_all
for( int col : sCols ) { // For tracked cols
nhs[col].incr((float) chks[col].atd(row), wrks.atd(row), w); // Histogram row/col
if (sCols == null) {
for(int col=0; col<nhs.length; ++col ) { //all columns
if (nhs[col]!=null)
nhs[col].incr((float) chks[col].atd(row), resp, w); // Histogram row/col
}
} else {
for( int col : sCols )
nhs[col].incr((float) chks[col].atd(row), resp, w); // Histogram row/col
}
}
}
Expand Down
61 changes: 55 additions & 6 deletions h2o-algos/src/main/java/hex/tree/SharedTree.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import hex.*;
import hex.genmodel.GenModel;
import hex.tree.drf.DRF;
import jsr166y.CountedCompleter;
import org.joda.time.format.DateTimeFormat;
import org.joda.time.format.DateTimeFormatter;
Expand Down Expand Up @@ -45,6 +44,8 @@ public abstract class SharedTree<M extends SharedTreeModel<M,P,O>, P extends Sha
// Sum of variable empirical improvement in squared-error. The value is not scaled.
private transient float[/*nfeatures*/] _improvPerVar;

protected Random _rand;

public boolean isSupervised(){return true;}

Key _response_key;
Expand Down Expand Up @@ -222,14 +223,20 @@ abstract protected class Driver extends H2OCountedCompleter<Driver> {
for( int i=0; i<_nclass; i++ )
_train.add("NIDs_"+domain[i], _response.makeCon(_model._output._distribution==null ? 0 : (_model._output._distribution[i]==0?-1:0)));

// Append number of trees participating in on-the-fly scoring
_train.add("OUT_BAG_TREES", _response.makeZero());

// Tag out rows missing the response column
new ExcludeNAResponse().doAll(_train);

// Variable importance: squared-error-improvement-per-variable-per-split
_improvPerVar = new float[_ncols];
_rand = createRNG(_parms._seed);

initializeModelSpecifics();
resumeFromCheckpoint();
scoreAndBuildTrees(doOOBScoring());

// Sub-class tree-model-builder specific build code
buildModel();
done(); // Job done!
} catch( Throwable t ) {
Job thisJob = DKV.getGet(_key);
Expand All @@ -255,7 +262,49 @@ abstract protected class Driver extends H2OCountedCompleter<Driver> {

// Abstract classes implemented by the tree builders
abstract protected M makeModel( Key modelKey, P parms, double mse_train, double mse_valid );
abstract protected void buildModel();
abstract protected boolean doOOBScoring();
abstract protected void buildNextKTrees();
abstract protected void initializeModelSpecifics();

// Common methods for all tree builders

/**
 * Restore the workspace from a previous model (checkpoint).
 * No-op unless the model parameters reference a checkpointed model; otherwise
 * replays the checkpoint's out-of-bag bookkeeping over the training frame and
 * fast-forwards the shared RNG so that newly built trees are deterministic
 * with respect to a run that had built all trees in one session.
 */
protected final void resumeFromCheckpoint() {
if( !_parms.hasCheckpoint() ) return;
// Reconstruct the working tree state from the checkpoint
Timer t = new Timer();
// Number of trees already present in the checkpointed model's parameters.
int ntreesFromCheckpoint = ((SharedTreeModel.SharedTreeParameters) _parms._checkpoint.<SharedTreeModel>get()._parms)._ntrees;
// Recompute OOB statistics from the checkpoint's existing trees over _train.
// NOTE(review): presumably this refills the per-row "Tree"/OUT_BAG_TREES work
// columns from _model._output._treeKeys — confirm against OOBScorer.
new OOBScorer(_ncols, _nclass, numSpecialCols(), _parms._sample_rate,_model._output._treeKeys).doAll(_train, _parms._build_tree_one_node);
// Advance _rand past the draws the checkpointed trees consumed, so the next
// tree gets the same random stream it would have had without checkpointing.
for (int i = 0; i < ntreesFromCheckpoint; i++) _rand.nextLong(); //for determinism
Log.info("Reconstructing OOB stats from checkpoint took " + t);
}

/**
 * Build more trees, as specified by the model parameters.
 * Alternates scoring and tree building: before each of the {@code _ntrees}
 * iterations the current model is scored (and saved), then one round of K
 * trees is built via {@link #buildNextKTrees}. Stops early when the training
 * R^2 reaches {@code _parms._r2_stopping} or the job is cancelled.
 * @param oob Whether or not Out-Of-Bag scoring should be performed
 */
protected final void scoreAndBuildTrees(boolean oob) {
for( int tid=0; tid< _ntrees; tid++) {
// During first iteration model contains 0 trees, then 1-tree, ...
// No need to score a checkpoint with no extra trees added
if( tid!=0 || !_parms.hasCheckpoint() ) { // do not make initial scoring if model already exist
double training_r2 = doScoringAndSaveModel(false, oob, _parms._build_tree_one_node);
if( training_r2 >= _parms._r2_stopping ) {
// One final scoring pass before giving up; further splits would only chase round-off.
doScoringAndSaveModel(true, oob, _parms._build_tree_one_node);
return; // Stop when approaching round-off error
}
}
Timer kb_timer = new Timer();
buildNextKTrees();
Log.info((tid + 1) + ". tree was built in " + kb_timer.toString());
update(1); // advance job progress by one tree
if( !isRunning() ) return; // If canceled during building, do not bulkscore
}
// Final scoring (skip if job was cancelled)
doScoringAndSaveModel(true, oob, _parms._build_tree_one_node);
}

/** Performs deep clone of given model.
*
Expand All @@ -271,7 +320,7 @@ protected M getModelDeepClone(M model) {
Key[][] treeKeys = newModel._output._treeKeys;
for (int i = 0; i < treeKeys.length; i++) {
for (int j = 0; j < treeKeys[i].length; j++) {
if (treeKeys[i][j] == null) continue;;
if (treeKeys[i][j] == null) continue;
CompressedTree ct = DKV.get(treeKeys[i][j]).get();
CompressedTree newCt = IcedUtils.clone(ct, CompressedTree.makeTreeKey(i, j), true);
treeKeys[i][j] = newCt._key;
Expand Down Expand Up @@ -459,7 +508,7 @@ class ExcludeNAResponse extends MRTask<ExcludeNAResponse> {

// --------------------------------------------------------------------------
transient long _timeLastScoreStart, _timeLastScoreEnd, _firstScore;
protected double doScoringAndSaveModel(boolean finalScoring, boolean oob, boolean build_tree_one_node ) {
protected final double doScoringAndSaveModel(boolean finalScoring, boolean oob, boolean build_tree_one_node ) {
double training_r2 = Double.NaN; // Training R^2 value, if computed
long now = System.currentTimeMillis();
if( _firstScore == 0 ) _firstScore=now;
Expand Down
99 changes: 25 additions & 74 deletions h2o-algos/src/main/java/hex/tree/drf/DRF.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class DRF extends SharedTree<hex.tree.drf.DRFModel, hex.tree.drf.DRFModel
@Override public BuilderVisibility builderVisibility() { return BuilderVisibility.Stable; };

// Called from an http request
public DRF( hex.tree.drf.DRFModel.DRFParameters parms) { super("DRF",parms); init(false); }
public DRF( hex.tree.drf.DRFModel.DRFParameters parms) { super("DRF", parms); init(false); }

@Override public DRFV3 schema() { return new DRFV3(); }

Expand Down Expand Up @@ -72,9 +72,6 @@ public class DRF extends SharedTree<hex.tree.drf.DRFModel, hex.tree.drf.DRFModel
if (_nclass == 1) _parms._distribution = Distribution.Family.gaussian;
if (_nclass >= 2) _parms._distribution = Distribution.Family.multinomial;
}
if (expensive) {
_initialPrediction = isClassifier() ? 0 : getInitialValue();
}
if (_parms._sample_rate == 1f && _valid == null)
error("_sample_rate", "Sample rate is 100% and no validation dataset is specified. There are no OOB data to compute out-of-bag error estimation!");
if (hasOffsetCol())
Expand All @@ -84,28 +81,9 @@ public class DRF extends SharedTree<hex.tree.drf.DRFModel, hex.tree.drf.DRFModel
}
}

/** Fill work columns:
* - classification: set 1 in the corresponding wrk col according to row response
* - regression: copy response into work column (there is only 1 work column)
*/
private class SetWrkTask extends MRTask<SetWrkTask> {
@Override public void map( Chunk chks[] ) {
Chunk cy = chk_resp(chks);
for( int i=0; i<cy._len; i++ ) {
if( cy.isNA(i) ) continue;
if (isClassifier()) {
int cls = (int)cy.at8(i);
chk_work(chks,cls).set(i,1L);
} else {
float pred = (float) cy.atd(i);
chk_work(chks,0).set(i,pred);
}
}
}
}

// ----------------------
private class DRFDriver extends Driver {
@Override protected boolean doOOBScoring() { return true; }

// --- Private data handled only on master node
// Classification or Regression:
Expand All @@ -131,76 +109,49 @@ private void initTreeMeasurements() {
}
}

@Override protected void buildModel() {
@Override protected void initializeModelSpecifics() {
_mtry = (_parms._mtries==-1) ? // classification: mtry=sqrt(_ncols), regression: mtry=_ncols/3
( isClassifier() ? Math.max((int)Math.sqrt(_ncols),1) : Math.max(_ncols/3,1)) : _parms._mtries;
// How many trees was in already in provided checkpointed model
int ntreesFromCheckpoint = _parms.hasCheckpoint() ?
((SharedTreeModel.SharedTreeParameters) _parms._checkpoint.<SharedTreeModel>get()._parms)._ntrees : 0;

if (!(1 <= _mtry && _mtry <= _ncols)) throw new IllegalArgumentException("Computed mtry should be in interval <1,"+_ncols+"> but it is " + _mtry);
_initialPrediction = isClassifier() ? 0 : getInitialValue();
// Initialize TreeVotes for classification, MSE arrays for regression
initTreeMeasurements();
// Append number of trees participating in on-the-fly scoring
_train.add("OUT_BAG_TREES", _response.makeZero());
// Prepare working columns
new SetWrkTask().doAll(_train);
// If there was a check point recompute tree_<_> and oob columns based on predictions from previous trees
// but only if OOB validation is requested.
if (_parms.hasCheckpoint()) {
Timer t = new Timer();
// Compute oob votes for each output level
new OOBScorer(_ncols, _nclass, numSpecialCols(), _parms._sample_rate, _model._output._treeKeys).doAll(_train);
Log.info("Reconstructing oob stats from checkpointed model took " + t);
}

// The RNG used to pick split columns
Random rand = createRNG(_parms._seed);
// To be deterministic get random numbers for previous trees and
// put random generator to the same state
for (int i = 0; i < ntreesFromCheckpoint; i++) rand.nextLong();

int tid;

// Prepare tree statistics
// Build trees until we hit the limit
for( tid=0; tid < _ntrees; tid++) { // Building tid-tree
if (tid!=0 || !_parms.hasCheckpoint()) { // do not make initial scoring if model already exist
double training_r2 = doScoringAndSaveModel(false, true, _parms._build_tree_one_node);
if( training_r2 >= _parms._r2_stopping ) {
doScoringAndSaveModel(true, true, _parms._build_tree_one_node);
return; // Stop when approaching round-off error
/** Fill work columns:
* - classification: set 1 in the corresponding wrk col according to row response
* - regression: copy response into work column (there is only 1 work column)
*/
new MRTask() {
@Override public void map(Chunk chks[]) {
Chunk cy = chk_resp(chks);
for (int i = 0; i < cy._len; i++) {
if (cy.isNA(i)) continue;
if (isClassifier()) {
int cls = (int) cy.at8(i);
chk_work(chks, cls).set(i, 1L);
} else {
float pred = (float) cy.atd(i);
chk_work(chks, 0).set(i, pred);
}
}
}
// At each iteration build K trees (K = nclass = response column domain size)

// TODO: parallelize more? build more than k trees at each time, we need to care about temporary data
// Idea: launch more DRF at once.
Timer kb_timer = new Timer();
buildNextKTrees(rand);
Log.info((tid+1) + ". tree was built " + kb_timer.toString());
DRF.this.update(1);
if( !isRunning() ) return; // If canceled during building, do not bulkscore

}
doScoringAndSaveModel(true, true, _parms._build_tree_one_node);
}.doAll(_train);
}



// --------------------------------------------------------------------------
// Build the next random k-trees representing tid-th tree
private void buildNextKTrees(Random rand) {
@Override protected void buildNextKTrees() {
// We're going to build K (nclass) trees - each focused on correcting
// errors for a single class.
final DTree[] ktrees = new DTree[_nclass];

// Define a "working set" of leaf splits, from leafs[i] to tree._len for each tree i
int[] leafs = new int[_nclass];

growTrees(ktrees, leafs, rand);
// Assign rows to nodes - fill the "NIDs" column(s)
growTrees(ktrees, leafs, _rand);

// Move rows into the final leaf rows
// Move rows into the final leaf rows - fill "Tree" and OUT_BAG_TREES columns and zap the NIDs column
CollectPreds cp = new CollectPreds(ktrees,leafs,_model.defaultThreshold()).doAll(_train,_parms._build_tree_one_node);

if (isClassifier()) asVotes(_treeMeasuresOnOOB).append(cp.rightVotes, cp.allRows); // Track right votes over OOB rows for this tree
Expand Down
Loading

0 comments on commit 2534dc6

Please sign in to comment.