Skip to content

Commit

Permalink
update rabit
Browse files Browse the repository at this point in the history
  • Loading branch information
Chen Qin committed Oct 11, 2019
1 parent abef46a commit 97120f1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@ private[spark] trait RabitParams extends Params {
* rabit_reduce_buffer - buffer size to recv and run reduction
* rabit_bootstrap_cache - enable save allreduce cache before loadcheckpoint
* rabit_debug - enable more verbose rabit logging to stdout
* rabit_timeout - enable sidecar thread after rabit observed failures
* rabit_timeout_sec - wait interval before exit after rabit observed failures
* DMLC_WORKER_CONNECT_RETRY - number of retrys to tracker
*/
final val ringReduceMin = new IntParam(this, "rabit_reduce_ring_mincount",
Expand All @@ -50,13 +52,23 @@ private[spark] trait RabitParams extends Params {

final def getRabitDebug: Int = ${rabitDebug}

final def rabitTimeout: IntParam = new IntParam(this, "rabit_timeout",
"enable failure timeout sidecar threads", (timeout: Int) => timeout == 0 || timeout == 1)

final def getRabitTimeout: Int = ${rabitTimeout}

final def timeoutInterval: IntParam = new IntParam(this, "rabit_timeout_sec",
"timeout threshold after rabit observed failures", (interval: Int) => interval > 0)

final def getTimeoutInterval: Int = ${timeoutInterval}

final def connectRetry: IntParam = new IntParam(this, "DMLC_WORKER_CONNECT_RETRY",
"number of retry worker do before fail", ParamValidators.gtEq(1))

final def getConnectRetry: Int = ${connectRetry}

setDefault(ringReduceMin -> (32 << 10), reduceBuffer -> "256MB", bootstrapCache -> 0,
rabitDebug -> 0, connectRetry -> 5)
rabitDebug -> 0, connectRetry -> 5, rabitTimeout -> 0, timeoutInterval -> 1800)

def XGBoostToRabitParams(xgboostParams: Map[String, Any]): Unit = {
for ((paramName, paramValue) <- xgboostParams) {
Expand All @@ -71,10 +83,12 @@ private[spark] trait RabitParams extends Params {
}

def RabitParamsToXGBoost: Map[String, String] = Map(
"rabit_reduce_ring_mincount" -> getRingReduceMin.toString,
"rabit_reduce_buffer" -> getReduceBuffer.toString,
"rabit_bootstrap_cache" -> getBootstrapCache.toString,
"rabit_debug" -> getRabitDebug.toString,
"DMLC_WORKER_CONNECT_RETRY" -> getConnectRetry.toString
ringReduceMin.name -> getRingReduceMin.toString,
reduceBuffer.name -> getReduceBuffer.toString,
bootstrapCache.name -> getBootstrapCache.toString,
rabitDebug.name -> getRabitDebug.toString,
rabitTimeout.name -> getRabitTimeout.toString,
timeoutInterval.name -> getTimeoutInterval.toString,
connectRetry.name -> getConnectRetry.toString
)
}

0 comments on commit 97120f1

Please sign in to comment.