/
taskfinder.go
283 lines (267 loc) · 7.81 KB
/
taskfinder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
package esu
import (
"fmt"
"sort"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go/service/ecs"
)
// TaskFinder provides a wrapper around the AWS-SDK for locating ECS tasks.
// Every lookup it performs is scoped to the single cluster it was built with.
type TaskFinder struct {
	// cluster is sent as the Cluster parameter on every ECS request.
	cluster string
	// ecs serves ListServices, ListTasks, DescribeTasks and
	// DescribeContainerInstances calls.
	ecs *ecs.ECS
	// ec2 resolves container instances to their EC2 instances (DescribeInstances).
	ec2 *ec2.EC2
}
// NewTaskFinder builds a TaskFinder bound to cluster, creating its ECS and
// EC2 clients from sess. The finder adds no synchronization of its own, so
// concurrent use is exactly as safe as the underlying AWS SDK clients.
func NewTaskFinder(sess *session.Session, cluster string) *TaskFinder {
	f := new(TaskFinder)
	f.cluster = cluster
	f.ecs = ecs.New(sess)
	f.ec2 = ec2.New(sess)
	return f
}
// Services returns the ARNs of all services active on the cluster. It follows
// ListServices pagination until ECS stops returning a NextToken. SDK errors
// are annotated with the failing operation, like every other call here.
func (f *TaskFinder) Services() ([]string, error) {
	services := []string{}
	var nextToken *string
	for {
		resp, err := f.ecs.ListServices(&ecs.ListServicesInput{
			Cluster: aws.String(f.cluster),
			// 100 is the API maximum page size; the default of 10 costs
			// extra round trips on clusters with many services.
			MaxResults: aws.Int64(100),
			NextToken:  nextToken,
		})
		if err != nil {
			return nil, propagate(err, "ecs list services")
		}
		for _, arn := range resp.ServiceArns {
			// Guard the dereference; a nil entry carries no usable ARN.
			if arn != nil {
				services = append(services, *arn)
			}
		}
		nextToken = resp.NextToken
		if nextToken == nil {
			return services, nil
		}
	}
}
// Tasks returns information about a service's running tasks, including tasks
// that are still in the process of stopping, sorted first by public DNS name
// and then port. A service with no tasks yields an empty, non-nil slice.
//
// Instance details (EC2 ID, DNS names, IP addresses) are filled in only when
// the task's container instance could be resolved. A port lookup failure
// aborts the whole call; the returned error wraps the underlying cause and
// adds the cluster, service and task ARN for context.
func (f *TaskFinder) Tasks(service string) ([]TaskInfo, error) {
	taskArns, err := f.fetchTasks(service)
	if err != nil {
		return nil, err
	}
	if len(taskArns) == 0 {
		return []TaskInfo{}, nil
	}
	tasks, err := f.describeTasks(taskArns)
	if err != nil {
		return nil, err
	}
	instances, err := f.locateTasks(tasks)
	if err != nil {
		return nil, err
	}
	infos := make([]TaskInfo, 0, len(tasks))
	for _, t := range tasks {
		port, err := f.getPortForTask(t, service)
		if err != nil {
			// realString: avoid a nil-pointer panic while building the error.
			return nil, fmt.Errorf("%w, cluster=%s, service=%s, task=%s", err, f.cluster, service, realString(t.TaskArn))
		}
		info := TaskInfo{
			// NOTE(review): behavior of ParseARN("") for a missing ARN is
			// defined elsewhere; previously this line panicked instead.
			TaskDefinition: ParseARN(realString(t.TaskDefinitionArn)).ShortName(),
			DesiredStatus:  ECSTaskStatus(realString(t.DesiredStatus)),
			LastStatus:     ECSTaskStatus(realString(t.LastStatus)),
			StartedAt:      realTime(t.StartedAt),
			Port:           port,
		}
		if t.ContainerInstanceArn != nil {
			if in, ok := instances[*t.ContainerInstanceArn]; ok {
				info.EC2InstanceID = realString(in.InstanceId)
				info.PublicDNSName = realString(in.PublicDnsName)
				info.PrivateDNSName = realString(in.PrivateDnsName)
				info.PublicIPAddress = realString(in.PublicIpAddress)
				info.PrivateIPAddress = realString(in.PrivateIpAddress)
			}
		}
		infos = append(infos, info)
	}
	sort.Sort(taskInfoList(infos))
	return infos, nil
}
// getPortForTask returns the host port of the container that represents the
// service within task t.
//
// A single-container task uses its only container. For a multi-container
// task, the container whose name equals service is used. For example, if the
// "foobaz" service runs an application container and a "mysql" container,
// the application container must be named "foobaz" for this library. If
// several ports are mapped, the first binding's host port is returned.
//
// A result of 0 with a nil error means no host port is bound yet: pending
// tasks don't yet have network bindings. An error is returned when the task
// has no containers, or when no container in a multi-container task matches.
func (f *TaskFinder) getPortForTask(t *ecs.Task, service string) (int, error) {
	var c *ecs.Container
	switch len(t.Containers) {
	case 0:
		return 0, fmt.Errorf("no containers configured")
	case 1:
		c = t.Containers[0]
	default:
		for _, cc := range t.Containers {
			// realString tolerates a nil Name instead of panicking.
			if realString(cc.Name) == service {
				c = cc
				break
			}
		}
		if c == nil {
			return 0, fmt.Errorf("ambiguous, multi-container task, one container should match service name")
		}
	}
	// len() of a nil slice is 0, so no separate nil check is needed.
	if len(c.NetworkBindings) == 0 {
		// Pending tasks don't yet have network bindings.
		return 0, nil
	}
	// Take the first port binding; treat a missing host port like "not bound yet".
	first := c.NetworkBindings[0]
	if first == nil || first.HostPort == nil {
		return 0, nil
	}
	return int(*first.HostPort), nil
}
func (f *TaskFinder) locateTasks(tasks []*ecs.Task) (map[string]*ec2.Instance, error) {
if len(tasks) == 0 {
return map[string]*ec2.Instance{}, nil
}
ciArns := make([]*string, len(tasks))
for i, task := range tasks {
ciArns[i] = task.ContainerInstanceArn
}
resp, err := f.ecs.DescribeContainerInstances(&ecs.DescribeContainerInstancesInput{
ContainerInstances: ciArns,
Cluster: aws.String(f.cluster),
})
if err != nil {
return nil, propagate(err, "ecs describe container instances")
}
if len(resp.Failures) != 0 {
// TODO: This only shows first error.
return nil, fmt.Errorf("describe container failure on %s: %s", *resp.Failures[0].Arn, *resp.Failures[0].Reason)
}
ec2Ids := make([]*string, len(resp.ContainerInstances))
for i, ci := range resp.ContainerInstances {
ec2Ids[i] = ci.Ec2InstanceId
}
instances, err := f.locateInstances(ec2Ids)
if err != nil {
return nil, err
}
rv := map[string]*ec2.Instance{}
for _, ci := range resp.ContainerInstances {
for _, i := range instances {
if *i.InstanceId == *ci.Ec2InstanceId {
rv[*ci.ContainerInstanceArn] = i
}
}
}
return rv, nil
}
// locateInstances describes the given EC2 instances. It returns every
// instance from every reservation in the response.
//
// An empty ec2Ids short-circuits to an empty result without calling EC2,
// because DescribeInstances with no instance IDs describes all instances.
func (f *TaskFinder) locateInstances(ec2Ids []*string) ([]*ec2.Instance, error) {
	if len(ec2Ids) == 0 {
		return []*ec2.Instance{}, nil
	}
	resp, err := f.ec2.DescribeInstances(&ec2.DescribeInstancesInput{
		DryRun:      aws.Bool(false),
		InstanceIds: ec2Ids,
	})
	if err != nil {
		return nil, propagate(err, "ec2 describe instances")
	}
	instances := []*ec2.Instance{}
	for _, r := range resp.Reservations {
		// A reservation groups the instances launched by one request, so it
		// can hold several requested instances (or, defensively, none).
		// Keep them all rather than indexing [0].
		instances = append(instances, r.Instances...)
	}
	return instances, nil
}
// describeTasks fetches full task descriptions for tasksArns.
//
// DescribeTasks accepts at most 100 tasks per call. A flapping service can
// accumulate many stopped tasks, so the ARNs are sent in batches of 100.
// Tasks whose last status is missing or equal to ECSTaskStatusStopped are
// dropped; tasks still in the process of stopping are kept. Any per-task
// failure reported by ECS aborts the call.
func (f *TaskFinder) describeTasks(tasksArns []*string) ([]*ecs.Task, error) {
	if len(tasksArns) == 0 {
		return []*ecs.Task{}, nil
	}
	var tasks []*ecs.Task
	// "batch" rather than "chunk" so the package-level chunk func isn't shadowed.
	for _, batch := range chunk(tasksArns, 100) {
		resp, err := f.ecs.DescribeTasks(&ecs.DescribeTasksInput{
			Tasks:   batch,
			Cluster: aws.String(f.cluster),
		})
		if err != nil {
			return nil, propagate(err, "ecs describe tasks")
		}
		if len(resp.Failures) != 0 {
			// TODO: This only shows the first error.
			return nil, fmt.Errorf("describe task failure on %s: %s", realString(resp.Failures[0].Arn), realString(resp.Failures[0].Reason))
		}
		// Filter out stopped tasks, we still return tasks in the process of stopping.
		for _, t := range resp.Tasks {
			if t.LastStatus != nil && ECSTaskStatus(*t.LastStatus) != ECSTaskStatusStopped {
				tasks = append(tasks, t)
			}
		}
	}
	return tasks, nil
}
// fetchTasks returns the ARNs of the service's tasks whose desired status is
// ECSTaskStatusRunning or ECSTaskStatusStopped.
//
// ListTasks filters on desired status, not current status. Querying the
// stopped status as well therefore surfaces tasks that are still running but
// are in the process of stopping. The two lists are fetched separately, and a
// task's desired status can change between the calls, so the result is
// de-duplicated (first occurrence wins, order preserved). Nil ARNs are
// dropped.
func (f *TaskFinder) fetchTasks(service string) ([]*string, error) {
	running, err := f.fetchTasksWithStatus(service, ECSTaskStatusRunning)
	if err != nil {
		return nil, err
	}
	stopping, err := f.fetchTasksWithStatus(service, ECSTaskStatusStopped)
	if err != nil {
		return nil, err
	}
	total := len(running) + len(stopping)
	seen := make(map[string]bool, total)
	tasks := make([]*string, 0, total)
	for _, group := range [][]*string{running, stopping} {
		for _, arn := range group {
			if arn == nil || seen[*arn] {
				continue
			}
			seen[*arn] = true
			tasks = append(tasks, arn)
		}
	}
	return tasks, nil
}
// fetchTasksWithStatus lists the ARNs of every task in service whose desired
// status equals desiredStatus. It requests ListTasks pages until ECS stops
// returning a NextToken. The result is never nil, even when no tasks match.
func (f *TaskFinder) fetchTasksWithStatus(service string, desiredStatus ECSTaskStatus) ([]*string, error) {
	arns := []*string{}
	var token *string
	for done := false; !done; {
		page, err := f.ecs.ListTasks(&ecs.ListTasksInput{
			Cluster:       aws.String(f.cluster),
			ServiceName:   aws.String(service),
			DesiredStatus: aws.String(string(desiredStatus)),
			NextToken:     token,
		})
		if err != nil {
			return nil, propagate(err, "ecs list tasks")
		}
		arns = append(arns, page.TaskArns...)
		token = page.NextToken
		done = token == nil
	}
	return arns, nil
}
// realString dereferences s, mapping a nil pointer to the empty string.
func realString(s *string) string {
	var v string
	if s != nil {
		v = *s
	}
	return v
}
func realTime(t *time.Time) time.Time {
if t == nil {
return time.Time{}
}
return *t
}
// propagate annotates err with msg, producing the message "msg: <err>". The
// original error is wrapped with %w rather than flattened with %s, so callers
// can still inspect it with errors.Is / errors.As.
func propagate(err error, msg string) error {
	return fmt.Errorf("%s: %w", msg, err)
}
// chunk splits tasks into consecutive, order-preserving sub-slices of at most
// count elements; the last chunk may be shorter. Empty input yields nil.
//
// A non-positive count returns all of tasks as a single chunk; previously
// such a count looped forever. Each chunk is capped with a three-index slice,
// so appending to one chunk reallocates instead of overwriting the first
// element of the next chunk.
func chunk(tasks []*string, count int) [][]*string {
	if len(tasks) == 0 {
		return nil
	}
	if count <= 0 {
		count = len(tasks)
	}
	chunked := make([][]*string, 0, (len(tasks)+count-1)/count)
	for start := 0; start < len(tasks); start += count {
		end := start + count
		if end > len(tasks) {
			end = len(tasks)
		}
		chunked = append(chunked, tasks[start:end:end])
	}
	return chunked
}