forked from flyteorg/flyteplugins
/
client.go
177 lines (138 loc) · 5.43 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
/*
* Copyright (c) 2018 Lyft. All rights reserved.
*/
// Package awsbatch deals with the communication with AWS Batch and adapts its APIs to the flyte-plugin model.
package awsbatch
import (
"context"
"fmt"
definition2 "github.com/lyft/flyteplugins/go/tasks/plugins/array/awsbatch/definition"
"github.com/lyft/flyteplugins/go/tasks/aws"
"github.com/lyft/flytestdlib/utils"
"github.com/lyft/flytestdlib/logger"
a "github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/service/batch"
)
//go:generate mockery -all -case=underscore
// Client is the AWS Batch client interface used by the plugin.
type Client interface {
	// SubmitJob submits a new job to AWS Batch and returns the ID of the submitted job. Note that submitted jobs
	// will not have status populated.
	SubmitJob(ctx context.Context, input *batch.SubmitJobInput) (jobID string, err error)

	// TerminateJob attempts to terminate a job, recording the given reason. If the job hasn't started yet, it'll
	// just get deleted.
	TerminateJob(ctx context.Context, jobID JobID, reason string) error

	// GetJobDetailsBatch retrieves the details of the given jobs from AWS Batch.
	GetJobDetailsBatch(ctx context.Context, ids []JobID) ([]*batch.JobDetail, error)

	// RegisterJobDefinition registers a new Job Definition with AWS Batch provided a name, image and role, and
	// returns the ARN of the registered definition.
	RegisterJobDefinition(ctx context.Context, name, image, role string) (arn string, err error)

	// GetRegion gets the single region this client interacts with.
	GetRegion() string

	// GetAccountID gets the account ID this client was created with.
	GetAccountID() string
}
// BatchServiceClient is an interface on top of the native AWS Batch client to allow for mocking and alternative
// implementations. NewBatchClient passes the client returned by batch.New as a BatchServiceClient; any other
// implementation can be supplied through NewCustomBatchClient.
type BatchServiceClient interface {
	// SubmitJobWithContext submits a job described by input.
	SubmitJobWithContext(ctx a.Context, input *batch.SubmitJobInput, opts ...request.Option) (*batch.SubmitJobOutput, error)
	// TerminateJobWithContext terminates the job identified in input.
	TerminateJobWithContext(ctx a.Context, input *batch.TerminateJobInput, opts ...request.Option) (*batch.TerminateJobOutput, error)
	// DescribeJobsWithContext describes the jobs listed in input.
	DescribeJobsWithContext(ctx a.Context, input *batch.DescribeJobsInput, opts ...request.Option) (*batch.DescribeJobsOutput, error)
	// RegisterJobDefinitionWithContext registers the job definition described by input.
	RegisterJobDefinitionWithContext(ctx a.Context, input *batch.RegisterJobDefinitionInput, opts ...request.Option) (*batch.RegisterJobDefinitionOutput, error)
}
// client is the Client implementation in this file. It forwards calls to a BatchServiceClient and waits on a rate
// limiter before the submit, terminate and describe calls.
type client struct {
	// Batch is the underlying service client every AWS Batch call is sent through.
	Batch BatchServiceClient
	// getRateLimiter is waited on before job details are fetched (GetJobDetailsBatch).
	getRateLimiter utils.RateLimiter
	// defaultRateLimiter is waited on before jobs are submitted or terminated (SubmitJob, TerminateJob).
	defaultRateLimiter utils.RateLimiter
	// region is the value returned by GetRegion.
	region string
	// accountId is the value returned by GetAccountID.
	accountId string
}
// GetRegion returns the region this client was constructed with.
func (b client) GetRegion() string {
	return b.region
}
// GetAccountID returns the account ID this client was constructed with.
func (b client) GetAccountID() string {
	return b.accountId
}
// RegisterJobDefinition registers a new container job definition with the given name, image and job role, and
// returns the ARN of the registered definition. There is no deduping on AWS side (even for the same name).
//
// An error is returned if the underlying call fails, or if it succeeds but the response carries no ARN.
func (b *client) RegisterJobDefinition(ctx context.Context, name, image, role string) (arn definition2.JobDefinitionArn, err error) {
	logger.Infof(ctx, "Registering job definition with name [%v], image [%v], role [%v]", name, image, role)

	res, err := b.Batch.RegisterJobDefinitionWithContext(ctx, &batch.RegisterJobDefinitionInput{
		Type:              refStr(batch.JobDefinitionTypeContainer),
		JobDefinitionName: refStr(name),
		ContainerProperties: &batch.ContainerProperties{
			Image:      refStr(image),
			JobRoleArn: refStr(role),

			// These will be overwritten on execution
			Vcpus:  refInt(1),
			Memory: refInt(100),
		},
	})
	if err != nil {
		return "", err
	}

	// Guard the dereference below: a response without an ARN would otherwise panic.
	if res == nil || res.JobDefinitionArn == nil {
		return "", fmt.Errorf("job definition [%v] registered with no ARN and no error returned", name)
	}

	return *res.JobDefinitionArn, nil
}
// SubmitJob submits a new job, as described by input, and returns the ID of the submitted job.
//
// A nil input is a no-op: no call is made and an empty job ID is returned with a nil error. Otherwise the call
// waits on the default rate limiter first. An error is returned if the wait or the submission fails, or if the
// submission succeeds but the response carries no job ID.
func (b *client) SubmitJob(ctx context.Context, input *batch.SubmitJobInput) (jobID string, err error) {
	if input == nil {
		return "", nil
	}

	if err := b.defaultRateLimiter.Wait(ctx); err != nil {
		return "", err
	}

	output, err := b.Batch.SubmitJobWithContext(ctx, input)
	if err != nil {
		return "", err
	}

	// Guard the dereference below: a response without a job ID would otherwise panic.
	if output == nil || output.JobId == nil {
		// JobName is a *string; dereference it so the message shows the name rather than a pointer address.
		jobName := a.StringValue(input.JobName)
		logger.Errorf(ctx, "Job submitted has no ID and no error is returned. This is an AWS-issue. Input [%v]", jobName)
		return "", fmt.Errorf("job submitted has no ID and no error is returned. This is an AWS-issue. Input [%v]", jobName)
	}

	return *output.JobId, nil
}
// TerminateJob asks the underlying service client to terminate the job with the given ID, passing along reason.
// It waits on the default rate limiter before issuing the call and returns the first error encountered.
func (b *client) TerminateJob(ctx context.Context, jobID JobID, reason string) error {
	err := b.defaultRateLimiter.Wait(ctx)
	if err != nil {
		return err
	}

	_, err = b.Batch.TerminateJobWithContext(ctx, &batch.TerminateJobInput{
		JobId:  refStr(jobID),
		Reason: refStr(reason),
	})

	return err
}
// GetJobDetailsBatch fetches the details of the given jobs with a single describe-jobs call. It waits on the get
// rate limiter before issuing the call and returns the jobs listed in the response.
func (b *client) GetJobDetailsBatch(ctx context.Context, jobIds []JobID) ([]*batch.JobDetail, error) {
	if waitErr := b.getRateLimiter.Wait(ctx); waitErr != nil {
		return nil, waitErr
	}

	// The request takes string pointers, one per requested job.
	jobRefs := make([]*string, len(jobIds))
	for i := range jobIds {
		jobRefs[i] = refStr(jobIds[i])
	}

	resp, err := b.Batch.DescribeJobsWithContext(ctx, &batch.DescribeJobsInput{Jobs: jobRefs})
	if err != nil {
		return nil, err
	}

	return resp.Jobs, nil
}
// NewBatchClient creates a Client backed by a batch service client built from awsClient's session and SDK config.
// The account ID is taken from awsClient's config and the region from the service client's signing region.
func NewBatchClient(awsClient aws.Client,
	getRateLimiter utils.RateLimiter,
	defaultRateLimiter utils.RateLimiter) Client {

	sdkClient := batch.New(awsClient.GetSession(), awsClient.GetSdkConfig())
	accountID := awsClient.GetConfig().AccountID
	region := sdkClient.SigningRegion

	return NewCustomBatchClient(sdkClient, accountID, region, getRateLimiter, defaultRateLimiter)
}
// NewCustomBatchClient creates a Client on top of the supplied BatchServiceClient, which allows a mock or an
// alternative implementation to stand in for the native AWS Batch client. The account ID, region and rate
// limiters are stored as given.
func NewCustomBatchClient(batchClient BatchServiceClient, accountId, region string,
	getRateLimiter utils.RateLimiter,
	defaultRateLimiter utils.RateLimiter) Client {

	c := new(client)
	c.Batch = batchClient
	c.getRateLimiter = getRateLimiter
	c.defaultRateLimiter = defaultRateLimiter
	c.region = region
	c.accountId = accountId

	return c
}