Skip to content

Commit

Permalink
Add Aws Batch native retry on spot reclaim
Browse files Browse the repository at this point in the history
  • Loading branch information
pditommaso committed Apr 1, 2022
1 parent 0044351 commit cd95e29
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 63 deletions.
1 change: 1 addition & 0 deletions docs/config.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ volumes One or more container mounts. Mounts can be specifie
delayBetweenAttempts Delay between download attempts from S3 (default `10 sec`).
maxParallelTransfers Max parallel upload/download transfer operations *per job* (default: ``4``).
maxTransferAttempts Max number of downloads attempts from S3 (default: `1`).
maxSpotAttempts Max number of execution attempts of a job interrupted by a EC2 spot reclaim event (default: ``5``, requires ``22.04.0`` or later)
=========================== ================


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,10 @@

package nextflow.cloud.aws.batch

import static AwsContainerOptionsMapper.*
import static nextflow.cloud.aws.batch.AwsContainerOptionsMapper.*

import java.nio.file.Path
import java.nio.file.Paths
import java.util.regex.Pattern

import com.amazonaws.services.batch.AWSBatch
import com.amazonaws.services.batch.model.AWSBatchException
Expand All @@ -33,6 +32,7 @@ import com.amazonaws.services.batch.model.DescribeJobDefinitionsRequest
import com.amazonaws.services.batch.model.DescribeJobDefinitionsResult
import com.amazonaws.services.batch.model.DescribeJobsRequest
import com.amazonaws.services.batch.model.DescribeJobsResult
import com.amazonaws.services.batch.model.EvaluateOnExit
import com.amazonaws.services.batch.model.Host
import com.amazonaws.services.batch.model.JobDefinition
import com.amazonaws.services.batch.model.JobDefinitionType
Expand All @@ -53,14 +53,12 @@ import groovy.transform.CompileStatic
import groovy.util.logging.Slf4j
import nextflow.cloud.types.CloudMachineInfo
import nextflow.container.ContainerNameValidator
import nextflow.exception.NodeTerminationException
import nextflow.exception.ProcessSubmitException
import nextflow.exception.ProcessUnrecoverableException
import nextflow.executor.BashWrapperBuilder
import nextflow.executor.res.AcceleratorResource
import nextflow.processor.BatchContext
import nextflow.processor.BatchHandler
import nextflow.processor.ErrorStrategy
import nextflow.processor.TaskBean
import nextflow.processor.TaskHandler
import nextflow.processor.TaskRun
Expand All @@ -74,8 +72,6 @@ import nextflow.util.CacheHelper
@Slf4j
class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,JobDetail> {

private static Pattern TERMINATED = ~/^Host EC2 .* terminated.*/

private final Path exitFile

private final Path wrapperFile
Expand Down Expand Up @@ -108,8 +104,6 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job

private Map<String,String> environment

private boolean batchNativeRetry

final static private Map<String,String> jobDefinitions = [:]

/**
Expand Down Expand Up @@ -256,23 +250,15 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
final job = describeJob(jobId)
final done = job?.status in ['SUCCEEDED', 'FAILED']
if( done ) {
if( !batchNativeRetry && TERMINATED.matcher(job.statusReason).find() ) {
// kee track of the node termination error
task.error = new NodeTerminationException(job.statusReason)
// mark the task as ABORTED since thr failure is caused by a node failure
task.aborted = true
// finalize the task
task.exitStatus = readExitFile()
task.stdout = outputFile
if( job?.status == 'FAILED' ) {
task.error = new ProcessUnrecoverableException(errReason(job))
task.stderr = executor.getJobOutputStream(jobId) ?: errorFile
}
else {
// finalize the task
task.exitStatus = readExitFile()
task.stdout = outputFile
if( job?.status == 'FAILED' ) {
task.error = new ProcessUnrecoverableException(errReason(job))
task.stderr = executor.getJobOutputStream(jobId) ?: errorFile
}
else {
task.stderr = errorFile
}
task.stderr = errorFile
}
status = TaskStatus.COMPLETED
return true
Expand Down Expand Up @@ -620,6 +606,10 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
return ['bash','-o','pipefail','-c', cmd.toString() ]
}

protected maxSpotAttempts() {
return executor.awsOptions.maxSpotAttempts
}

/**
* Create a new Batch job request for the given NF {@link TaskRun}
*
Expand All @@ -636,19 +626,16 @@ class AwsBatchTaskHandler extends TaskHandler implements BatchHandler<String,Job
result.setJobQueue(getJobQueue(task))
result.setJobDefinition(getJobDefinition(task))

// -- NF uses `maxRetries` *only* if `retry` error strategy is specified
// otherwise delegates the the retry to AWS Batch
// -- NOTE: make sure the `errorStrategy` is a static value before invoking `getMaxRetries` and `getErrorStrategy`
// when the errorStrategy is closure (ie. dynamic evaluated) value, the `task.config.getMaxRetries() && task.config.getErrorStrategy()`
// condition should not be evaluated because otherwise the closure value is cached using the wrong task.attempt and task.exitStatus values.
// -- use of `config.getRawValue('errorStrategy')` instead of `config.getErrorStrategy()` to prevent the resolution
// of values dynamic values i.e. closures
final strategy = task.config.getRawValue('errorStrategy')
final canCheck = strategy == null || strategy instanceof CharSequence
if( canCheck && task.config.getMaxRetries() && task.config.getErrorStrategy() != ErrorStrategy.RETRY ) {
def retry = new RetryStrategy().withAttempts( task.config.getMaxRetries()+1 )
/*
* retry on spot reclaim
* https://aws.amazon.com/blogs/compute/introducing-retry-strategies-for-aws-batch/
*/
final attempts = maxSpotAttempts()
if( attempts>0 ) {
final retry = new RetryStrategy()
.withAttempts( attempts )
.withEvaluateOnExit( new EvaluateOnExit().withOnReason('Host EC2*').withAction('RETRY') )
result.setRetryStrategy(retry)
this.batchNativeRetry = true
}

// set task timeout
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ class AwsOptions implements CloudTransferOptions {

public static final int DEFAULT_AWS_MAX_ATTEMPTS = 5

public static final int DEFAULT_MAX_SPOT_ATTEMPTS = 5

private Map<String,String> env = System.getenv()

String cliPath
Expand All @@ -61,6 +63,8 @@ class AwsOptions implements CloudTransferOptions {

String retryMode

int maxSpotAttempts

volatile Boolean fetchInstanceType

/**
Expand Down Expand Up @@ -93,6 +97,7 @@ class AwsOptions implements CloudTransferOptions {
maxParallelTransfers = session.config.navigate('aws.batch.maxParallelTransfers', MAX_TRANSFER) as int
maxTransferAttempts = session.config.navigate('aws.batch.maxTransferAttempts', defaultMaxTransferAttempts()) as int
delayBetweenAttempts = session.config.navigate('aws.batch.delayBetweenAttempts', DEFAULT_DELAY_BETWEEN_ATTEMPTS) as Duration
maxSpotAttempts = session.config.navigate('aws.batch.maxSpotAttempts', DEFAULT_MAX_SPOT_ATTEMPTS) as int
region = session.config.navigate('aws.region') as String
volumes = makeVols(session.config.navigate('aws.batch.volumes'))
jobRole = session.config.navigate('aws.batch.jobRole')
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import com.amazonaws.services.batch.model.DescribeJobDefinitionsRequest
import com.amazonaws.services.batch.model.DescribeJobDefinitionsResult
import com.amazonaws.services.batch.model.DescribeJobsRequest
import com.amazonaws.services.batch.model.DescribeJobsResult
import com.amazonaws.services.batch.model.EvaluateOnExit
import com.amazonaws.services.batch.model.JobDefinition
import com.amazonaws.services.batch.model.JobDetail
import com.amazonaws.services.batch.model.KeyValuePair
Expand All @@ -36,7 +37,6 @@ import com.amazonaws.services.batch.model.SubmitJobResult
import com.amazonaws.services.batch.model.TerminateJobRequest
import nextflow.cloud.types.CloudMachineInfo
import nextflow.cloud.types.PriceModel
import nextflow.exception.NodeTerminationException
import nextflow.exception.ProcessUnrecoverableException
import nextflow.executor.Executor
import nextflow.processor.BatchContext
Expand Down Expand Up @@ -84,6 +84,7 @@ class AwsBatchTaskHandlerTest extends Specification {
when:
def req = handler.newSubmitRequest(task)
then:
1 * handler.maxSpotAttempts() >> 5
1 * handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') }
1 * handler.getJobQueue(task) >> 'queue1'
1 * handler.getJobDefinition(task) >> 'job-def:1'
Expand All @@ -98,11 +99,12 @@ class AwsBatchTaskHandlerTest extends Specification {
req.getContainerOverrides().getResourceRequirements().find { it.type=='MEMORY'}.getValue() == '8192'
req.getContainerOverrides().getEnvironment() == [VAR_FOO, VAR_BAR]
req.getContainerOverrides().getCommand() == ['bash', '-o','pipefail','-c', "trap \"{ ret=\$?; /bin/aws s3 cp --only-show-errors .command.log s3://bucket/test/.command.log||true; exit \$ret; }\" EXIT; /bin/aws s3 cp --only-show-errors s3://bucket/test/.command.run - | bash 2>&1 | tee .command.log".toString()]
req.getRetryStrategy() == null // <-- retry is managed by NF, hence this must be null
req.getRetryStrategy() == new RetryStrategy().withAttempts(5).withEvaluateOnExit( new EvaluateOnExit().withAction('RETRY').withOnReason('Host EC2*') )

when:
req = handler.newSubmitRequest(task)
then:
1 * handler.maxSpotAttempts() >> 0
1 * handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws', region: 'eu-west-1') }
1 * handler.getJobQueue(task) >> 'queue1'
1 * handler.getJobDefinition(task) >> 'job-def:1'
Expand Down Expand Up @@ -135,6 +137,7 @@ class AwsBatchTaskHandlerTest extends Specification {
then:
handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws', region: 'eu-west-1') }
and:
1 * handler.maxSpotAttempts() >> 0
1 * handler.getJobQueue(task) >> 'queue1'
1 * handler.getJobDefinition(task) >> 'job-def:1'
and:
Expand All @@ -160,6 +163,7 @@ class AwsBatchTaskHandlerTest extends Specification {
task.getConfig() >> new TaskConfig()
handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') }
and:
1 * handler.maxSpotAttempts() >> 0
1 * handler.getJobQueue(task) >> 'queue1'
1 * handler.getJobDefinition(task) >> 'job-def:1'
and:
Expand All @@ -176,6 +180,7 @@ class AwsBatchTaskHandlerTest extends Specification {
task.getConfig() >> new TaskConfig(time: '5 sec')
handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') }
and:
1 * handler.maxSpotAttempts() >> 0
1 * handler.getJobQueue(task) >> 'queue2'
1 * handler.getJobDefinition(task) >> 'job-def:2'
and:
Expand All @@ -193,6 +198,7 @@ class AwsBatchTaskHandlerTest extends Specification {
task.getConfig() >> new TaskConfig(time: '1 hour')
handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws') }
and:
1 * handler.maxSpotAttempts() >> 0
1 * handler.getJobQueue(task) >> 'queue3'
1 * handler.getJobDefinition(task) >> 'job-def:3'
and:
Expand Down Expand Up @@ -221,6 +227,7 @@ class AwsBatchTaskHandlerTest extends Specification {
then:
handler.getAwsOptions() >> { new AwsOptions(cliPath: '/bin/aws', retryMode: 'adaptive', maxTransferAttempts: 10) }
and:
1 * handler.maxSpotAttempts() >> 3
1 * handler.getJobQueue(task) >> 'queue1'
1 * handler.getJobDefinition(task) >> 'job-def:1'
1 * handler.wrapperFile >> Paths.get('/bucket/test/.command.run')
Expand All @@ -230,7 +237,7 @@ class AwsBatchTaskHandlerTest extends Specification {
req.getJobQueue() == 'queue1'
req.getJobDefinition() == 'job-def:1'
// no error `retry` error strategy is defined by NF, use `maxRetries` to se Batch attempts
req.getRetryStrategy() == new RetryStrategy().withAttempts(3)
req.getRetryStrategy() == new RetryStrategy().withAttempts(3).withEvaluateOnExit( new EvaluateOnExit().withAction('RETRY').withOnReason('Host EC2*') )
req.getContainerOverrides().getEnvironment() == [VAR_RETRY_MODE, VAR_MAX_ATTEMPTS, VAR_METADATA_ATTEMPTS]
}

Expand Down Expand Up @@ -727,29 +734,4 @@ class AwsBatchTaskHandlerTest extends Specification {
trace.machineInfo.priceModel == PriceModel.spot
}

def 'should check spot termination' () {
given:
def JOB_ID = 'job-2'
def client = Mock(AWSBatch)
def task = new TaskRun()
def handler = Spy(AwsBatchTaskHandler)
handler.client = client
handler.jobId = JOB_ID
handler.task = task
and:
handler.isRunning() >> true
handler.describeJob(JOB_ID) >> Mock(JobDetail) {
getStatus() >> 'FAILED'
getStatusReason() >> "Host EC2 (instance i-0e2d5c2edc932b4e8) terminated."
}

when:
def done = handler.checkIfCompleted()
then:
task.aborted
task.error instanceof NodeTerminationException
and:
done == true

}
}

0 comments on commit cd95e29

Please sign in to comment.