// cors_filter.go (forked from emicklei/go-restful)

package restful
// Copyright 2013 Ernest Micklei. All rights reserved.
// Use of this source code is governed by a license
// that can be found in the LICENSE file.
import "strings"
// CrossOriginResourceSharing is used to create a Container Filter that implements CORS.
// Cross-origin resource sharing (CORS) is a mechanism that allows JavaScript on a web page
// to make XMLHttpRequests to another domain, not the domain the JavaScript originated from.
//
// Install its Filter method as a filter; the fields below configure which
// response headers that filter adds.
//
// http://en.wikipedia.org/wiki/Cross-origin_resource_sharing
// http://enable-cors.org/server.html
// http://www.html5rocks.com/en/tutorials/cors/#toc-handling-a-not-so-simple-request
type CrossOriginResourceSharing struct {
// ExposeHeaders, when non-empty, is joined with "," and added to the response
// under HEADER_AccessControlExposeHeaders.
ExposeHeaders []string // list of Header names
// AllowedHeaders is the set of header names a preflight request may ask for;
// each requested name must match one entry, ignoring case.
AllowedHeaders []string // list of Header names
// AllowedDomains lists the origins that receive an allow-origin header.
// Entries are compared for exact equality with the request's Origin value.
// An empty list allows every non-empty origin.
AllowedDomains []string
// CookiesAllowed, when true, adds HEADER_AccessControlAllowCredentials with
// the value "true" to the response.
CookiesAllowed bool
// Container supplies the allowed methods for a preflight request through its
// computeAllowedMethods method.
// NOTE(review): presumably must be non-nil before preflight requests are served — confirm.
Container *Container
}
// Filter is the filter function that routes a request through the CORS flow
// documented on http://enable-cors.org/server.html and
// http://www.html5rocks.com/static/images/cors_server_flowchart.png
//
// Dispatch rules:
//   - no HEADER_Origin header: the request is passed down the chain untouched;
//   - an OPTIONS request carrying HEADER_AccessControlRequestMethod: handled
//     as a preflight request;
//   - anything else: handled as an actual (non-preflight) CORS request.
func (c CrossOriginResourceSharing) Filter(req *Request, resp *Response, chain *FilterChain) {
	headers := req.Request.Header
	switch {
	case headers.Get(HEADER_Origin) == "":
		// Not a cross-origin request; nothing to decorate.
		chain.ProcessFilter(req, resp)
	case req.Request.Method == "OPTIONS" && headers.Get(HEADER_AccessControlRequestMethod) != "":
		c.doPreflightRequest(req, resp, chain)
	default:
		c.doActualRequest(req, resp, chain)
	}
}
// doActualRequest decorates the response of a non-preflight CORS request and
// then lets the remaining filters and the route handle the request.
//
// It adds, in order: the expose-headers header (only when ExposeHeaders is
// non-empty), the allow-origin header (only when the request origin is
// allowed) and the allow-credentials header (only when CookiesAllowed is set).
//
// AllowedHeaders is intentionally not written to
// HEADER_AccessControlExposeHeaders: it is the allow-list for headers a
// preflight request may ask for, not the list of response headers a script
// may read. Sending it there also produced an empty header value whenever
// AllowedHeaders was empty.
func (c CrossOriginResourceSharing) doActualRequest(req *Request, resp *Response, chain *FilterChain) {
	c.checkAndSetExposeHeaders(resp)
	c.setAllowOriginHeader(req, resp)
	c.checkAndSetAllowCredentials(resp)
	// continue processing the response
	chain.ProcessFilter(req, resp)
}
// doPreflightRequest answers a CORS preflight request.
//
// The method named in HEADER_AccessControlRequestMethod must be among the
// methods the Container computes for this request, and every name listed in
// HEADER_AccessControlRequestHeaders must pass
// isValidAccessControlRequestHeader. If either check fails the request is
// handed to the rest of the chain without any CORS headers.
//
// On success the allow-methods, allow-headers (echoing the requested list),
// allow-origin and allow-credentials headers are added and the chain is NOT
// invoked, so processing ends here.
func (c CrossOriginResourceSharing) doPreflightRequest(req *Request, resp *Response, chain *FilterChain) {
	methods := c.Container.computeAllowedMethods(req)
	requestedMethod := req.Request.Header.Get(HEADER_AccessControlRequestMethod)
	if !c.isValidAccessControlRequestMethod(requestedMethod, methods) {
		chain.ProcessFilter(req, resp)
		return
	}

	requestedHeaders := req.Request.Header.Get(HEADER_AccessControlRequestHeaders)
	// Skip validation for an absent header: splitting "" would yield one
	// empty name that could never be valid.
	if requestedHeaders != "" {
		names := strings.Split(requestedHeaders, ",")
		for i := range names {
			name := strings.Trim(names[i], " ")
			if !c.isValidAccessControlRequestHeader(name) {
				chain.ProcessFilter(req, resp)
				return
			}
		}
	}

	resp.AddHeader(HEADER_AccessControlAllowMethods, strings.Join(methods, ","))
	resp.AddHeader(HEADER_AccessControlAllowHeaders, requestedHeaders)
	c.setAllowOriginHeader(req, resp)
	c.checkAndSetAllowCredentials(resp)
	// return http 200 response, no body
}
// isOriginAllowed reports whether origin may receive an allow-origin header.
// An empty origin is never allowed. With no AllowedDomains configured every
// non-empty origin is allowed; otherwise origin must equal one of the
// configured entries exactly.
func (c CrossOriginResourceSharing) isOriginAllowed(origin string) bool {
	switch {
	case origin == "":
		return false
	case len(c.AllowedDomains) == 0:
		return true
	}
	for i := range c.AllowedDomains {
		if c.AllowedDomains[i] == origin {
			return true
		}
	}
	return false
}
// setAllowOriginHeader echoes the request's HEADER_Origin value back in
// HEADER_AccessControlAllowOrigin, but only when isOriginAllowed accepts it.
// For any other origin the response is left untouched.
func (c CrossOriginResourceSharing) setAllowOriginHeader(req *Request, resp *Response) {
	requestOrigin := req.Request.Header.Get(HEADER_Origin)
	if !c.isOriginAllowed(requestOrigin) {
		return
	}
	resp.AddHeader(HEADER_AccessControlAllowOrigin, requestOrigin)
}
// checkAndSetExposeHeaders adds HEADER_AccessControlExposeHeaders to the
// response, carrying the comma-joined ExposeHeaders. Nothing is added when
// ExposeHeaders is empty.
func (c CrossOriginResourceSharing) checkAndSetExposeHeaders(resp *Response) {
	if len(c.ExposeHeaders) == 0 {
		return
	}
	exposed := strings.Join(c.ExposeHeaders, ",")
	resp.AddHeader(HEADER_AccessControlExposeHeaders, exposed)
}
// checkAndSetAllowCredentials adds HEADER_AccessControlAllowCredentials with
// the value "true" when CookiesAllowed is set; otherwise it does nothing.
func (c CrossOriginResourceSharing) checkAndSetAllowCredentials(resp *Response) {
	if !c.CookiesAllowed {
		return
	}
	resp.AddHeader(HEADER_AccessControlAllowCredentials, "true")
}
// isValidAccessControlRequestMethod reports whether method appears in
// allowedMethods. The comparison is exact (case-sensitive); an empty or nil
// allowedMethods never matches.
func (c CrossOriginResourceSharing) isValidAccessControlRequestMethod(method string, allowedMethods []string) bool {
	found := false
	for i := 0; i < len(allowedMethods) && !found; i++ {
		found = allowedMethods[i] == method
	}
	return found
}
// isValidAccessControlRequestHeader reports whether header matches one of the
// configured AllowedHeaders. Header names are compared case-insensitively.
// With no AllowedHeaders configured, no header is valid.
func (c CrossOriginResourceSharing) isValidAccessControlRequestHeader(header string) bool {
	for _, each := range c.AllowedHeaders {
		// EqualFold compares without allocating lower-cased copies of both
		// strings on every iteration.
		if strings.EqualFold(each, header) {
			return true
		}
	}
	return false
}