/
aws.go
132 lines (123 loc) · 3.67 KB
/
aws.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
package aws
import (
"fmt"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/awserr"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/elbv2"
"github.com/aws/aws-sdk-go/service/shield"
"github.com/aws/aws-sdk-go/service/sts"
log "github.com/sirupsen/logrus"
)
// EnableAWSShield enables AWS Shield protection for the load balancer behind
// each entry of ingressList. Each entry is matched against load balancer DNS
// names using a session built from the default AWS configuration chain.
//
// Network load balancers are protected through their Elastic IP allocations;
// every other load balancer type is protected directly by ARN. Processing stops
// at the first error, which is returned; entries handled before the failure
// keep their protection (nothing is rolled back).
func EnableAWSShield(ingressList []string) error {
	// NewSession instead of session.Must: a broken AWS configuration is reported
	// to the caller as an error rather than panicking inside library code.
	awsSession, err := session.NewSession()
	if err != nil {
		return fmt.Errorf("creating AWS session: %w", err)
	}
	for _, ingress := range ingressList {
		lb, err := getLoadbalancer(awsSession, ingress)
		if err != nil {
			return err
		}
		if aws.StringValue(lb.Type) == elbv2.LoadBalancerTypeEnumNetwork {
			// NLBs are not protected by their own ARN here; their EIPs are
			// protected instead.
			err = shieldEnableNLB(awsSession, lb)
		} else {
			err = shieldEnableLB(awsSession, lb)
		}
		if err != nil {
			return err
		}
	}
	return nil
}
// getLoadbalancer returns the ELBv2 load balancer whose DNS name equals ingress.
//
// Every page of DescribeLoadBalancers is searched. The API is paginated, so
// looking only at the first response would miss load balancers in larger
// accounts. Paging stops as soon as a match is found. API errors are logged and
// returned unchanged. If no load balancer matches, a "did not find matching
// loadbalancer" error is returned.
func getLoadbalancer(awsSession *session.Session, ingress string) (*elbv2.LoadBalancer, error) {
	svc := elbv2.New(awsSession)
	var match *elbv2.LoadBalancer
	err := svc.DescribeLoadBalancersPages(&elbv2.DescribeLoadBalancersInput{},
		func(page *elbv2.DescribeLoadBalancersOutput, lastPage bool) bool {
			for _, loadbalancer := range page.LoadBalancers {
				if loadbalancer.DNSName != nil && *loadbalancer.DNSName == ingress {
					match = loadbalancer
					return false // found it; stop requesting further pages
				}
			}
			return true
		})
	if err != nil {
		if aerr, ok := err.(awserr.Error); ok {
			switch aerr.Code() {
			case elbv2.ErrCodeLoadBalancerNotFoundException:
				log.Error(elbv2.ErrCodeLoadBalancerNotFoundException, aerr.Error())
			default:
				log.Error(aerr.Error())
			}
		} else {
			log.Error(err.Error())
		}
		return nil, err
	}
	if match == nil {
		return nil, fmt.Errorf("did not find matching loadbalancer for ingress: %s", ingress)
	}
	log.Infof("found matching loadbalancer: %s", *match.DNSName)
	return match, nil
}
// shieldEnableLB protects a load balancer directly by its ARN. The Shield
// protection is named after the load balancer. EnableAWSShield uses this for
// every load balancer type other than "network". Any error from enableShield
// is returned unchanged.
func shieldEnableLB(awsSession *session.Session, lb *elbv2.LoadBalancer) error {
	shieldClient := shield.New(awsSession)
	arn, name := *lb.LoadBalancerArn, *lb.LoadBalancerName
	return enableShield(shieldClient, arn, name)
}
// shieldEnableNLB enables Shield on the Elastic IPs attached to a network load
// balancer. The NLB is not protected by its own ARN here. Instead, each EIP
// allocation found in the NLB's availability zones gets its own protection,
// named after the allocation ID.
//
// Addresses without an allocation ID have no Elastic IP, so there is nothing
// to protect by EIP ARN. They are skipped with a warning instead of causing a
// nil dereference. A warning is also logged when the NLB exposes no EIPs at
// all. The first error from building an ARN or creating a protection is
// returned.
func shieldEnableNLB(awsSession *session.Session, lb *elbv2.LoadBalancer) error {
	svc := shield.New(awsSession)
	protected := 0
	for _, zone := range lb.AvailabilityZones {
		for _, address := range zone.LoadBalancerAddresses {
			allocationID := aws.StringValue(address.AllocationId)
			if allocationID == "" {
				log.Warnf("skipping address %s in zone %s of loadbalancer %s: no elastic ip allocation",
					aws.StringValue(address.IpAddress), aws.StringValue(zone.ZoneName), aws.StringValue(lb.LoadBalancerName))
				continue
			}
			eipArn, err := generateEipArn(awsSession, address)
			if err != nil {
				return err
			}
			err = enableShield(svc, eipArn, allocationID)
			if err != nil {
				return err
			}
			protected++
		}
	}
	if protected == 0 {
		log.Warnf("network loadbalancer %s has no elastic ip allocations; shield was not enabled for it",
			aws.StringValue(lb.LoadBalancerName))
	}
	return nil
}
// enableShield creates a Shield protection named resourceName for the resource
// identified by resourceArn.
//
// A ResourceAlreadyExistsException from Shield counts as success, so calling
// this again for an already-protected resource is harmless. Any other error is
// logged and returned unchanged. The error is checked before the response is
// used, so a failed call never dereferences the (nil) result.
func enableShield(svc *shield.Shield, resourceArn string, resourceName string) error {
	log.Infof("enabling shield for %s", resourceName)
	input := &shield.CreateProtectionInput{
		Name:        aws.String(resourceName),
		ResourceArn: aws.String(resourceArn),
	}
	result, err := svc.CreateProtection(input)
	if err != nil {
		if aerr, ok := err.(awserr.Error); ok && aerr.Code() == shield.ErrCodeResourceAlreadyExistsException {
			log.Infof("target: %s is already protected", resourceName)
			return nil
		}
		log.Error(err.Error())
		return err
	}
	log.Infof("Resource %s is now protected: %s", resourceName, aws.StringValue(result.ProtectionId))
	return nil
}
// generateEipArn builds the ARN of the Elastic IP allocation behind address.
// The ELBv2 API reports only the allocation ID, so the ARN is assembled here.
//
// The parts come from these sources:
//   - account ID: STS GetCallerIdentity
//   - region: the STS client's session config
//   - partition: the caller identity ARN, falling back to "aws"
//
// Because the partition is taken from the caller ARN, the result is also
// correct outside the commercial partition (for example aws-us-gov or aws-cn).
//
// An error is returned when the address has no allocation ID, when the STS
// call fails (the STS error is returned unchanged), or when no region is
// configured.
func generateEipArn(awsSession *session.Session, address *elbv2.LoadBalancerAddress) (string, error) {
	allocationID := aws.StringValue(address.AllocationId)
	if allocationID == "" {
		return "", fmt.Errorf("loadbalancer address %s has no elastic ip allocation id",
			aws.StringValue(address.IpAddress))
	}

	svc := sts.New(awsSession)
	result, err := svc.GetCallerIdentity(&sts.GetCallerIdentityInput{})
	if err != nil {
		return "", err
	}

	region := aws.StringValue(svc.Config.Region)
	if region == "" {
		return "", fmt.Errorf("no AWS region configured; cannot build elastic ip arn for %s", allocationID)
	}

	// The caller ARN has the form "arn:<partition>:sts::<account>:...";
	// take the text between the first and second ':' as the partition.
	partition := "aws"
	const arnPrefix = "arn:"
	callerArn := aws.StringValue(result.Arn)
	if len(callerArn) > len(arnPrefix) && callerArn[:len(arnPrefix)] == arnPrefix {
		rest := callerArn[len(arnPrefix):]
		for i := 0; i < len(rest); i++ {
			if rest[i] == ':' {
				if i > 0 {
					partition = rest[:i]
				}
				break
			}
		}
	}

	eipArn := fmt.Sprintf(
		"arn:%s:ec2:%s:%s:eip-allocation/%s",
		partition, region, aws.StringValue(result.Account), allocationID,
	)
	return eipArn, nil
}