Permalink
Browse files

An option for doing binomial+1 or epsilon-dropout from DART paper (#1922)

* An option for doing binomial+1 or epsilon-dropout from DART paper

* use callback-based discrete_distribution to make MSVC2013 happy
  • Loading branch information...
1 parent ce84af7 commit d23ea5ca7dd681c563da58a256d86461659c2de1 @khotilov khotilov committed with tqchen Jan 6, 2017
Showing with 29 additions and 6 deletions.
  1. +5 −2 doc/parameter.md
  2. +24 −4 src/gbm/gbtree.cc
View
@@ -110,11 +110,14 @@ Additional parameters for Dart Booster
- weight of new trees are 1 / (1 + learning_rate)
- dropped trees are scaled by a factor of 1 / (1 + learning_rate)
* rate_drop [default=0.0]
- - dropout rate.
+ - dropout rate (a fraction of previous trees to drop during the dropout).
- range: [0.0, 1.0]
+* one_drop [default=0]
+ - when this flag is enabled, at least one tree is always dropped during the dropout (allows Binomial-plus-one or epsilon-dropout from the original DART paper).
* skip_drop [default=0.0]
- - probability of skip dropout.
+ - Probability of skipping the dropout procedure during a boosting iteration.
- If a dropout is skipped, new trees are added in the same manner as gbtree.
+ - Note that non-zero skip_drop has higher priority than rate_drop or one_drop.
- range: [0.0, 1.0]
Parameters for Linear Booster
View
@@ -72,9 +72,11 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
int sample_type;
/*! \brief type of normalization algorithm */
int normalize_type;
- /*! \brief how many trees are dropped */
+ /*! \brief fraction of trees to drop during the dropout */
float rate_drop;
- /*! \brief whether to drop trees */
+ /*! \brief whether at least one tree should always be dropped during the dropout */
+ bool one_drop;
+ /*! \brief probability of skipping the dropout during an iteration */
float skip_drop;
/*! \brief learning step size for a time */
float learning_rate;
@@ -96,11 +98,14 @@ struct DartTrainParam : public dmlc::Parameter<DartTrainParam> {
DMLC_DECLARE_FIELD(rate_drop)
.set_range(0.0f, 1.0f)
.set_default(0.0f)
- .describe("Parameter of how many trees are dropped.");
+ .describe("Fraction of trees to drop during the dropout.");
+ DMLC_DECLARE_FIELD(one_drop)
+ .set_default(false)
+ .describe("Whether at least one tree should always be dropped during the dropout.");
DMLC_DECLARE_FIELD(skip_drop)
.set_range(0.0f, 1.0f)
.set_default(0.0f)
- .describe("Parameter of whether to drop trees.");
+ .describe("Probability of skipping the dropout during a boosting iteration.");
DMLC_DECLARE_FIELD(learning_rate)
.set_lower_bound(0.0f)
.set_default(0.3f)
@@ -658,12 +663,27 @@ class Dart : public GBTree {
idx_drop.push_back(i);
}
}
+ if (dparam.one_drop && idx_drop.empty() && !weight_drop.empty()) {
+ // the expression below is an ugly but MSVC2013-friendly equivalent of
+ // size_t i = std::discrete_distribution<size_t>(weight_drop.begin(),
+ // weight_drop.end())(rnd);
+ size_t i = std::discrete_distribution<size_t>(
+ weight_drop.size(), 0., static_cast<double>(weight_drop.size()),
+ [this](double x) -> double {
+ return weight_drop[static_cast<size_t>(x)];
+ })(rnd);
+ idx_drop.push_back(i);
+ }
} else {
for (size_t i = 0; i < weight_drop.size(); ++i) {
if (runif(rnd) < dparam.rate_drop) {
idx_drop.push_back(i);
}
}
+ if (dparam.one_drop && idx_drop.empty() && !weight_drop.empty()) {
+ size_t i = std::uniform_int_distribution<size_t>(0, weight_drop.size() - 1)(rnd);
+ idx_drop.push_back(i);
+ }
}
}
}

0 comments on commit d23ea5c

Please sign in to comment.