Skip to content

Commit

Permalink
Switch from Executors.newFixedThreadPool to Executors.newCachedThread…
Browse files Browse the repository at this point in the history
…Pool to reuse threads in restarts. Always call shutdown on this thread pool to avoid an infinite hang at the end of the program execution. Fix #2151

ad54743f58bf228be3395a4f50bdb4bc8b05c042
  • Loading branch information
andrey-khropov committed Feb 15, 2024
1 parent 612cfcc commit 2cdb8a2
Showing 1 changed file with 50 additions and 45 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -223,62 +223,67 @@ trait CatBoostPredictorTrait[
master.savedPoolsFuture
)

breakable {
// retry training if network connection issues were the reason of failure
while (true) {
val trainingDriver : TrainingDriver = new TrainingDriver(
listeningPort = getOrDefault(trainingDriverListeningPort),
workerCount = partitionCount,
startMasterCallback = master.trainCallback,
connectTimeout = connectTimeoutValue,
workerInitializationTimeout = workerInitializationTimeoutValue
)

try {
val listeningPort = trainingDriver.getListeningPort
this.logInfo(s"fit. TrainingDriver listening port = ${listeningPort}")
val ecsPool = Executors.newCachedThreadPool()
try {
breakable {
// retry training if network connection issues were the reason of failure
while (true) {
val trainingDriver : TrainingDriver = new TrainingDriver(
listeningPort = getOrDefault(trainingDriverListeningPort),
workerCount = partitionCount,
startMasterCallback = master.trainCallback,
connectTimeout = connectTimeoutValue,
workerInitializationTimeout = workerInitializationTimeoutValue
)

this.logInfo(s"fit. Training started")
try {
val listeningPort = trainingDriver.getListeningPort
this.logInfo(s"fit. TrainingDriver listening port = ${listeningPort}")

val ecs = new ExecutorCompletionService[Unit](Executors.newFixedThreadPool(2))
this.logInfo(s"fit. Training started")

val trainingDriverFuture = ecs.submit(trainingDriver, ())
val ecs = new ExecutorCompletionService[Unit](ecsPool)

val workersFuture = ecs.submit(
new Runnable {
def run = {
workers.run(listeningPort)
}
},
()
)
val trainingDriverFuture = ecs.submit(trainingDriver, ())

var catboostWorkersConnectionLost = false
try {
impl.Helpers.waitForTwoFutures(ecs, trainingDriverFuture, "master", workersFuture, "workers")
break
} catch {
case e : java.util.concurrent.ExecutionException => {
e.getCause match {
case connectionLostException : CatBoostWorkersConnectionLostException => {
catboostWorkersConnectionLost = true
val workersFuture = ecs.submit(
new Runnable {
def run = {
workers.run(listeningPort)
}
},
()
)

var catboostWorkersConnectionLost = false
try {
impl.Helpers.waitForTwoFutures(ecs, trainingDriverFuture, "master", workersFuture, "workers")
break
} catch {
case e : java.util.concurrent.ExecutionException => {
e.getCause match {
case connectionLostException : CatBoostWorkersConnectionLostException => {
catboostWorkersConnectionLost = true
}
case _ => throw e
}
case _ => throw e
}
}
if (workers.workerFailureCount >= workerMaxFailuresValue) {
throw new CatBoostError(s"CatBoost workers failed at least $workerMaxFailuresValue times")
}
if (catboostWorkersConnectionLost) {
log.info(s"CatBoost master: communication with some of the workers has been lost. Retry training")
} else {
break
}
} finally {
trainingDriver.close(tryToShutdownWorkers=true, waitToShutdownWorkers=false)
}
if (workers.workerFailureCount >= workerMaxFailuresValue) {
throw new CatBoostError(s"CatBoost workers failed at least $workerMaxFailuresValue times")
}
if (catboostWorkersConnectionLost) {
log.info(s"CatBoost master: communication with some of the workers has been lost. Retry training")
} else {
break
}
} finally {
trainingDriver.close(tryToShutdownWorkers=true, waitToShutdownWorkers=false)
}
}
} finally {
ecsPool.shutdown()
}
this.logInfo(s"fit. Training finished")

Expand Down

0 comments on commit 2cdb8a2

Please sign in to comment.