forked from mweagle/Sparta
/
drift.go
155 lines (144 loc) · 4.92 KB
/
drift.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
package validator
import (
"fmt"
"strings"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/cloudformation"
sparta "github.com/mweagle/Sparta"
gocf "github.com/mweagle/go-cloudformation"
"github.com/pkg/errors"
"github.com/sirupsen/logrus"
)
// DriftDetector returns a ServiceValidationHookHandler that checks whether the
// CloudFormation stack named after the service has drifted from its template
// before a new provisioning step overwrites it.
//
// Every detected property difference (and every deleted resource) is logged.
// When errorOnDrift is true the entries are logged at Error level and the hook
// returns an error if at least one resource was modified or deleted; otherwise
// they are logged at Warn level and the hook returns nil. A stack that does not
// exist yet is treated as having no drift.
func DriftDetector(errorOnDrift bool) sparta.ServiceValidationHookHandler {
	driftDetector := func(context map[string]interface{},
		serviceName string,
		template *gocf.Template,
		S3Bucket string,
		S3Key string,
		buildID string,
		awsSession *session.Session,
		noop bool,
		logger *logrus.Logger) error {

		// Polling budget: 31 attempts, 11 seconds apart.
		const (
			maxPollAttempts = 31
			pollInterval    = 11 * time.Second
		)

		// Create a cloudformation service.
		cfSvc := cloudformation.New(awsSession)
		detectStackDrift, detectStackDriftErr := cfSvc.DetectStackDrift(&cloudformation.DetectStackDriftInput{
			StackName: aws.String(serviceName),
		})
		if detectStackDriftErr != nil {
			// If the stack doesn't exist yet there is nothing that could have drifted.
			if strings.Contains(detectStackDriftErr.Error(), "does not exist") {
				return nil
			}
			return errors.Wrap(detectStackDriftErr, "attempting to determine stack drift")
		}

		// Poll until the detection operation reaches a terminal state.
		describeDriftDetectionStatus := &cloudformation.DescribeStackDriftDetectionStatusInput{
			StackDriftDetectionId: detectStackDrift.StackDriftDetectionId,
		}
		detectionComplete := false
		for i := 0; i < maxPollAttempts && !detectionComplete; i++ {
			driftStatus, driftStatusErr := cfSvc.DescribeStackDriftDetectionStatus(describeDriftDetectionStatus)
			if driftStatusErr != nil {
				logger.WithField("error", driftStatusErr).Warn("Failed to check Stack Drift")
			}
			if driftStatusErr != nil || driftStatus == nil {
				// The output is not usable on error (its fields are unset), so
				// don't inspect it. Back off before retrying rather than spinning
				// through the whole attempt budget.
				time.Sleep(pollInterval)
				continue
			}
			status := aws.StringValue(driftStatus.DetectionStatus)
			switch status {
			case cloudformation.StackDriftDetectionStatusDetectionComplete:
				detectionComplete = true
			case cloudformation.StackDriftDetectionStatusDetectionFailed:
				// Terminal failure: waiting longer cannot change the outcome.
				return errors.Errorf("stack drift detection failed: %s",
					aws.StringValue(driftStatus.DetectionStatusReason))
			default:
				logger.WithField("Status", status).
					Info("Waiting for drift detection to complete")
				time.Sleep(pollInterval)
			}
		}
		if !detectionComplete {
			return errors.New("Stack drift detection did not complete in time")
		}

		// golangFuncName returns the "golangFunc" metadata value recorded in the
		// template for the given logical resource, or "" if there is none.
		golangFuncName := func(logicalResourceID string) string {
			templateRes, templateResExists := template.Resources[logicalResourceID]
			if !templateResExists {
				return ""
			}
			// Indexing a nil map is safe, so no need to allocate a placeholder.
			golangFunc, golangFuncExists := templateRes.Metadata["golangFunc"]
			if !golangFuncExists {
				return ""
			}
			if typedFunc, isString := golangFunc.(string); isString {
				return typedFunc
			}
			return fmt.Sprintf("%#v", golangFunc)
		}

		// isDrifted reports whether a resource actually differs from the template.
		// The result set may also contain resources that were checked and found
		// to be in sync, so the mere presence of an entry is not drift.
		isDrifted := func(drift *cloudformation.StackResourceDrift) bool {
			if drift == nil {
				return false
			}
			switch aws.StringValue(drift.StackResourceDriftStatus) {
			case cloudformation.StackResourceDriftStatusModified,
				cloudformation.StackResourceDriftStatusDeleted:
				return true
			}
			// Fall back to the reported differences if the status is missing.
			return len(drift.PropertyDifferences) != 0
		}

		// emitDrift logs a drift entry at the level implied by errorOnDrift.
		emitDrift := func(entry *logrus.Entry) {
			if errorOnDrift {
				entry.Error("Stack drift detected")
			} else {
				entry.Warn("Stack drift detected")
			}
		}

		// Log the drifts
		logDrifts := func(stackResourceDrifts []*cloudformation.StackResourceDrift) {
			for _, eachDrift := range stackResourceDrifts {
				if eachDrift == nil {
					continue
				}
				logicalID := aws.StringValue(eachDrift.LogicalResourceId)
				driftState := aws.StringValue(eachDrift.StackResourceDriftStatus)
				// A deleted resource has no property differences to report, but
				// it is still drift and would otherwise fail the hook silently.
				if len(eachDrift.PropertyDifferences) == 0 {
					if driftState == cloudformation.StackResourceDriftStatusDeleted {
						emitDrift(logger.WithFields(logrus.Fields{
							"Resource":       logicalID,
							"Status":         driftState,
							"LambdaFuncName": golangFuncName(logicalID),
						}))
					}
					continue
				}
				for _, eachDiff := range eachDrift.PropertyDifferences {
					if eachDiff == nil {
						continue
					}
					emitDrift(logger.WithFields(logrus.Fields{
						"Resource":       logicalID,
						"Actual":         aws.StringValue(eachDiff.ActualValue),
						"Expected":       aws.StringValue(eachDiff.ExpectedValue),
						"Relation":       aws.StringValue(eachDiff.DifferenceType),
						"PropertyPath":   aws.StringValue(eachDiff.PropertyPath),
						"LambdaFuncName": golangFuncName(logicalID),
					}))
				}
			}
		}

		// Fetch all the drift results, following pagination tokens.
		var stackResourceDrifts []*cloudformation.StackResourceDrift
		input := &cloudformation.DescribeStackResourceDriftsInput{
			MaxResults: aws.Int64(100),
			StackName:  aws.String(serviceName),
		}
		// There can't be more than 200 resources in the template
		// https://docs.aws.amazon.com/AWSCloudFormation/latest/UserGuide/cloudformation-limits.html
		loopCounter := 0
		for {
			driftResults, driftResultsErr := cfSvc.DescribeStackResourceDrifts(input)
			if driftResultsErr != nil {
				return errors.Wrap(driftResultsErr, "attempting to describe stack drift")
			}
			stackResourceDrifts = append(stackResourceDrifts, driftResults.StackResourceDrifts...)
			if driftResults.NextToken == nil {
				break
			}
			loopCounter++
			// If there is more than 10 (1k total) something is seriously wrong...
			if loopCounter >= 10 {
				logDrifts(stackResourceDrifts)
				return errors.Errorf("Exceeded maximum number of Stack resource drifts: %d", len(stackResourceDrifts))
			}
			// Same query, next page.
			input.NextToken = driftResults.NextToken
		}

		// Log them
		logDrifts(stackResourceDrifts)
		if !errorOnDrift {
			return nil
		}
		// Only block provisioning when a resource really drifted.
		for _, eachDrift := range stackResourceDrifts {
			if isDrifted(eachDrift) {
				return errors.Errorf("stack %s operation prevented due to stack drift", serviceName)
			}
		}
		return nil
	}
	return sparta.ServiceValidationHookFunc(driftDetector)
}