/
http_response_headers.go
129 lines (109 loc) · 3.31 KB
/
http_response_headers.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
package configutil
import (
"fmt"
"net/textproto"
"strconv"
"strings"
"github.com/hashicorp/go-secure-stdlib/strutil"
)
// ValidCustomStatusCodeCollection holds the non-numeric keys that
// IsValidStatusCode accepts in the "custom_response_headers" configuration:
// "default" plus the status class keys "1xx" through "5xx".
var ValidCustomStatusCodeCollection = []string{"default", "1xx", "2xx", "3xx", "4xx", "5xx"}

// StrictTransportSecurity is the Strict-Transport-Security header value that
// ParseCustomResponseHeaders applies under the "default" key whenever the
// configuration does not set that header there itself.
const StrictTransportSecurity = "max-age=31536000; includeSubDomains"
// ParseCustomResponseHeaders validates the raw "custom_response_headers"
// configuration value and converts it into a map from status code key to a
// map of canonical header name to header value.
//
// The raw value must be a []map[string]interface{}. Each entry maps a status
// code key (checked with IsValidStatusCode) to a []map[string]interface{}
// holding exactly one header map, which is parsed with parseHeaders. If the
// same status code key appears in more than one entry, the later entry
// replaces the earlier one.
//
// On success the returned map always contains a "default" entry, and
// Strict-Transport-Security is set there to StrictTransportSecurity unless
// the configuration already provides that header under "default".
// A nil input is accepted and yields only that default header.
func ParseCustomResponseHeaders(responseHeaders interface{}) (map[string]map[string]string, error) {
	h := make(map[string]map[string]string)

	// Even when nothing is configured, the default headers must still be set.
	if responseHeaders == nil {
		h["default"] = map[string]string{"Strict-Transport-Security": StrictTransportSecurity}
		return h, nil
	}

	customResponseHeader, ok := responseHeaders.([]map[string]interface{})
	if !ok {
		return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps")
	}

	for _, crh := range customResponseHeader {
		for statusCode, responseHeader := range crh {
			headerValList, ok := responseHeader.([]map[string]interface{})
			if !ok {
				return nil, fmt.Errorf("response headers were not configured correctly. please make sure they're in a slice of maps")
			}
			if !IsValidStatusCode(statusCode) {
				return nil, fmt.Errorf("invalid status code found in the server configuration: %v", statusCode)
			}
			// Exactly one header map is expected per status code.
			if len(headerValList) != 1 {
				return nil, fmt.Errorf("invalid number of response headers exist for status code %q: expected 1, got %d", statusCode, len(headerValList))
			}
			headerVal, err := parseHeaders(headerValList[0])
			if err != nil {
				return nil, fmt.Errorf("invalid response headers for status code %q: %w", statusCode, err)
			}
			h[statusCode] = headerVal
		}
	}

	// Strict-Transport-Security is always present as a default header, but a
	// value configured explicitly under "default" takes precedence.
	if h["default"] == nil {
		h["default"] = make(map[string]string)
	}
	if _, ok := h["default"]["Strict-Transport-Security"]; !ok {
		h["default"]["Strict-Transport-Security"] = StrictTransportSecurity
	}
	return h, nil
}
// IsValidStatusCode reports whether sc may be used as a status code key in
// the "custom_response_headers" configuration.
//
// sc is valid if it is an entry of ValidCustomStatusCodeCollection, or if it
// is the canonical decimal spelling of an integer in [100, 599]. Spellings
// that strconv.Atoi would also accept, such as "+200" or "0200", are
// rejected. ParseCustomResponseHeaders stores keys verbatim, so those forms
// would end up under a key that differs from the plain "200" form.
// NOTE(review): they are presumably never matched by lookups keyed on the
// numeric code; confirm against the consumer.
func IsValidStatusCode(sc string) bool {
	if strutil.StrListContains(ValidCustomStatusCodeCollection, sc) {
		return true
	}
	code, err := strconv.Atoi(sc)
	if err != nil {
		return false
	}
	// Round-trip through Itoa so only the plain decimal spelling is accepted.
	if strconv.Itoa(code) != sc {
		return false
	}
	return code >= 100 && code < 600
}
// parseHeaders converts one configured header map into a map of canonical
// header name to header value.
//
// Each key is normalized with textproto.CanonicalMIMEHeaderKey, and each
// value is converted by parseHeaderValues. Keys that differ only in case
// (for example "x-foo" and "X-Foo") normalize to the same header name. Keeping
// just one of them would depend on Go's random map iteration order, so such a
// duplicate is reported as an error instead.
func parseHeaders(in map[string]interface{}) (map[string]string, error) {
	hvMap := make(map[string]string, len(in))
	for k, v := range in {
		headerName := textproto.CanonicalMIMEHeaderKey(k)
		if _, dup := hvMap[headerName]; dup {
			return nil, fmt.Errorf("duplicate response header %q found (header names are case-insensitive)", headerName)
		}
		s, err := parseHeaderValues(v)
		if err != nil {
			return nil, fmt.Errorf("header %q: %w", headerName, err)
		}
		hvMap[headerName] = s
	}
	return hvMap, nil
}
// parseHeaderValues joins a configured list of header values into a single
// header value string.
//
// header must be a []interface{} whose elements are all strings, or a
// []string. Each value is trimmed of surrounding whitespace, and values that
// are empty after trimming are skipped. The remaining values are joined with
// "; ". An empty list yields "". Any other input type, or a non-string element,
// is an error.
func parseHeaderValues(header interface{}) (string, error) {
	var raw []string
	switch vals := header.(type) {
	case []string:
		raw = vals
	case []interface{}:
		raw = make([]string, 0, len(vals))
		for _, vh := range vals {
			s, ok := vh.(string)
			if !ok {
				return "", fmt.Errorf("found a non-string header value: %v", vh)
			}
			raw = append(raw, s)
		}
	default:
		return "", fmt.Errorf("headers must be given in a list of strings")
	}

	sl := make([]string, 0, len(raw))
	for _, v := range raw {
		if v = strings.TrimSpace(v); v != "" {
			sl = append(sl, v)
		}
	}
	return strings.Join(sl, "; "), nil
}