-
Notifications
You must be signed in to change notification settings - Fork 4.2k
/
audited_headers.go
161 lines (130 loc) · 4.11 KB
/
audited_headers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
package vault
import (
"context"
"fmt"
"strings"
"sync"
"github.com/hashicorp/vault/sdk/logical"
)
// N.B.: While we could use textproto to get the canonical mime header, HTTP/2
// requires all headers to be converted to lower case, so we just do that.
// Storage locations for the persisted audited-headers configuration.
const (
// Key used in the BarrierView to store and retrieve the header config.
// The same key is used when the config is written (add/remove) and when
// it is read back (setupAuditedHeadersConfig).
auditedHeadersEntry = "audited-headers"
// Path used to create a sub view off of BarrierView; the config entry
// above is stored inside that sub view.
auditedHeadersSubPath = "audited-headers-config/"
)
// auditedHeaderSettings holds the per-header options of the audited headers
// config. It is the value type that is JSON-encoded into the barrier view.
type auditedHeaderSettings struct {
// HMAC, when true, makes ApplyConfig run every value of the header
// through the supplied hash function instead of returning it verbatim.
HMAC bool `json:"hmac"`
}
// AuditedHeadersConfig is used by the Audit Broker to write only approved
// headers to the audit logs. It uses a BarrierView to persist the settings.
type AuditedHeadersConfig struct {
// Headers maps a header name to its settings. add and
// setupAuditedHeadersConfig store the names lower-cased, and ApplyConfig
// relies on that to match request headers case-insensitively.
Headers map[string]*auditedHeaderSettings
// view is the barrier sub view that add and remove write the config to.
view *BarrierView
// The embedded lock guards Headers: add and remove take the write lock,
// ApplyConfig takes the read lock.
sync.RWMutex
}
// add adds or overwrites a header in the config and updates the barrier view.
//
// The header name is stored lower-cased so that lookups are case-insensitive;
// hmac selects whether ApplyConfig hashes the header's values. An empty header
// name is rejected.
//
// The change is staged on a copy of the current config and only published to
// a.Headers once it has been written to the barrier view. If encoding or
// writing fails, an error is returned and the in-memory config is left exactly
// as it was, so memory never gets ahead of what is persisted.
func (a *AuditedHeadersConfig) add(ctx context.Context, header string, hmac bool) error {
	if header == "" {
		return fmt.Errorf("header value cannot be empty")
	}

	// Grab a write lock
	a.Lock()
	defer a.Unlock()

	// Stage the update on a copy; ranging over a nil map is a no-op, so this
	// also covers the first header ever added.
	updated := make(map[string]*auditedHeaderSettings, len(a.Headers)+1)
	for name, settings := range a.Headers {
		updated[name] = settings
	}
	updated[strings.ToLower(header)] = &auditedHeaderSettings{HMAC: hmac}

	entry, err := logical.StorageEntryJSON(auditedHeadersEntry, updated)
	if err != nil {
		return fmt.Errorf("failed to persist audited headers config: %w", err)
	}

	if err := a.view.Put(ctx, entry); err != nil {
		return fmt.Errorf("failed to persist audited headers config: %w", err)
	}

	// The write succeeded; make the new config visible to readers.
	a.Headers = updated

	return nil
}
// remove deletes a header out of the header config and updates the barrier
// view.
//
// The header name is matched lower-cased. An empty header name is rejected.
// When the config holds no headers at all there is nothing to delete and
// nothing is written.
//
// The deletion is staged on a copy of the current config and only published
// to a.Headers once the new config has been written to the barrier view. If
// encoding or writing fails, an error is returned and the in-memory config
// still contains the header, matching what is persisted.
func (a *AuditedHeadersConfig) remove(ctx context.Context, header string) error {
	if header == "" {
		return fmt.Errorf("header value cannot be empty")
	}

	// Grab a write lock
	a.Lock()
	defer a.Unlock()

	// Nothing to delete
	if len(a.Headers) == 0 {
		return nil
	}

	// Build the post-removal config without touching the live map.
	target := strings.ToLower(header)
	updated := make(map[string]*auditedHeaderSettings, len(a.Headers))
	for name, settings := range a.Headers {
		if name != target {
			updated[name] = settings
		}
	}

	entry, err := logical.StorageEntryJSON(auditedHeadersEntry, updated)
	if err != nil {
		return fmt.Errorf("failed to persist audited headers config: %w", err)
	}

	if err := a.view.Put(ctx, entry); err != nil {
		return fmt.Errorf("failed to persist audited headers config: %w", err)
	}

	// The write succeeded; make the new config visible to readers.
	a.Headers = updated

	return nil
}
// ApplyConfig filters the given request headers down to the ones listed in
// the config and returns them keyed by lower-cased header name. Values of
// headers configured with HMAC are replaced by the output of hashFunc; all
// other values are returned as plaintext. The input map and its slices are
// never modified. If hashFunc fails, a nil map and that error are returned.
func (a *AuditedHeadersConfig) ApplyConfig(ctx context.Context, headers map[string][]string, hashFunc func(context.Context, string) (string, error)) (result map[string][]string, retErr error) {
	// Hold the read lock for the whole pass over the config
	a.RLock()
	defer a.RUnlock()

	// Index the incoming headers by lower-cased name so that they can be
	// matched against the (lower-cased) config keys regardless of case
	normalized := make(map[string][]string, len(headers))
	for name, values := range headers {
		normalized[strings.ToLower(name)] = values
	}

	result = make(map[string][]string, len(a.Headers))
	for name, cfg := range a.Headers {
		values, present := normalized[name]
		if !present {
			continue
		}

		// Work on a private copy so the caller's slice stays untouched
		audited := append(make([]string, 0, len(values)), values...)

		// Hash each value in place when the header is configured for it
		if cfg.HMAC {
			for idx := range audited {
				hashed, err := hashFunc(ctx, audited[idx])
				if err != nil {
					return nil, err
				}
				audited[idx] = hashed
			}
		}

		result[name] = audited
	}

	return result, nil
}
// setupAuditedHeadersConfig initializes the headers config by loading it from
// the barrier view and storing the result in c.auditedHeaders.
//
// A missing storage entry yields an empty config. Header names are lower-cased
// on load, and entries whose stored settings are null are dropped so that no
// nil settings pointer ends up in the config. Read and decode failures are
// returned wrapped, and c.auditedHeaders is left unchanged in that case.
func (c *Core) setupAuditedHeadersConfig(ctx context.Context) error {
	// Create a sub-view
	view := c.systemBarrierView.SubView(auditedHeadersSubPath)

	// Load the persisted config, if there is one
	out, err := view.Get(ctx, auditedHeadersEntry)
	if err != nil {
		return fmt.Errorf("failed to read config: %w", err)
	}

	headers := make(map[string]*auditedHeaderSettings)
	if out != nil {
		if err := out.DecodeJSON(&headers); err != nil {
			return fmt.Errorf("failed to decode config: %w", err)
		}
	}

	// Ensure that we are able to case-insensitively access the headers;
	// necessary for the upgrade case
	lowerHeaders := make(map[string]*auditedHeaderSettings, len(headers))
	for name, settings := range headers {
		// A JSON null decodes to a nil pointer, which ApplyConfig would
		// dereference; skip such entries instead of auditing them with
		// guessed settings.
		if settings == nil {
			continue
		}
		lowerHeaders[strings.ToLower(name)] = settings
	}

	c.auditedHeaders = &AuditedHeadersConfig{
		Headers: lowerHeaders,
		view:    view,
	}

	return nil
}