-
Notifications
You must be signed in to change notification settings - Fork 0
/
cors.go
113 lines (96 loc) · 2.69 KB
/
cors.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
package oauth2cors
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/driver005/oauth/client"
"github.com/driver005/oauth/config"
"github.com/driver005/oauth/helpers"
"github.com/driver005/oauth/oauth2"
"github.com/gobwas/glob"
"github.com/rs/cors"
"github.com/ory/fosite"
)
// corsRegistry is the set of registry dependencies needed by the CORS
// middleware and its helpers. It has the same method set as the parameter
// type of Middleware.
type corsRegistry interface {
	Config() *config.Provider
	helpers.RegistryLogger
	oauth2.Registry
	client.Registry
}

// Middleware returns an HTTP middleware that applies the CORS settings
// configured for the public interface.
//
// If CORS is disabled for the public interface, the returned middleware is
// the identity and the handler is returned unchanged.
//
// Otherwise an origin is accepted when any of the following holds:
//   - the configured allowed-origin list is empty or contains "*";
//   - the (lower-cased) origin matches one of the configured glob patterns;
//   - the OAuth 2.0 client identified by the request lists the origin in its
//     AllowedCORSOrigins. The client is identified by the HTTP Basic auth
//     username, or by introspecting the bearer access token.
//
// A configured origin that cannot be compiled as a glob is reported through
// reg.Logger().Fatalf. NOTE(review): presumably this terminates the process
// at startup; confirm against the logger implementation.
func Middleware(reg interface {
	Config() *config.Provider
	helpers.RegistryLogger
	oauth2.Registry
	client.Registry
}) func(h http.Handler) http.Handler {
	opts, enabled := reg.Config().CORS(config.PublicInterface)
	if !enabled {
		return func(h http.Handler) http.Handler {
			return h
		}
	}

	// An empty origin list, or an explicit "*" entry, allows every origin.
	alwaysAllow := len(opts.AllowedOrigins) == 0
	patterns := make([]glob.Glob, 0, len(opts.AllowedOrigins))
	for _, o := range opts.AllowedOrigins {
		if o == "*" {
			alwaysAllow = true
		}
		g, err := compileOriginPattern(o)
		if err != nil {
			reg.Logger().WithError(err).Fatalf("Unable to parse cors origin: %s", o)
			// Never keep a nil glob around, even if the logger did not exit.
			continue
		}
		patterns = append(patterns, g)
	}

	options := cors.Options{
		AllowedOrigins:     opts.AllowedOrigins,
		AllowedMethods:     opts.AllowedMethods,
		AllowedHeaders:     opts.AllowedHeaders,
		ExposedHeaders:     opts.ExposedHeaders,
		MaxAge:             opts.MaxAge,
		AllowCredentials:   opts.AllowCredentials,
		OptionsPassthrough: opts.OptionsPassthrough,
		Debug:              opts.Debug,
		AllowOriginRequestFunc: func(r *http.Request, origin string) bool {
			if alwaysAllow {
				return true
			}
			origin = strings.ToLower(origin)
			for _, p := range patterns {
				if p.Match(origin) {
					return true
				}
			}
			// Use the request context so cancellation and deadlines reach the
			// token introspection and client lookup.
			return clientAllowsOrigin(r.Context(), reg, r, origin)
		},
	}
	return cors.New(options).Handler
}

// compileOriginPattern compiles a configured allowed origin into a
// lower-cased glob that uses '.' as its separator.
//
// When a scheme is given but the host is a bare wildcard ("scheme://*"), the
// entry is rewritten to "scheme://**". "**" also matches across '.', so
// "http://*" matches "http://google.com".
func compileOriginPattern(origin string) (glob.Glob, error) {
	// SplitN(…, 2) splits only at the first "://", so nothing after a second
	// separator is ever discarded.
	if parts := strings.SplitN(origin, "://", 2); len(parts) == 2 && parts[1] == "*" {
		origin = fmt.Sprintf("%s://**", parts[0])
	}
	return glob.Compile(strings.ToLower(origin), '.')
}

// clientAllowsOrigin reports whether the OAuth 2.0 client identified by r
// lists origin (expected to be lower-cased already) in its AllowedCORSOrigins.
// A "*" entry allows every origin. Entries that fail to compile as globs are
// skipped rather than rejecting the request outright.
func clientAllowsOrigin(ctx context.Context, reg corsRegistry, r *http.Request, origin string) bool {
	clientID, ok := requestClientID(ctx, reg, r)
	if !ok {
		return false
	}
	cl, err := reg.ClientManager().GetConcreteClient(ctx, clientID)
	if err != nil {
		return false
	}
	for _, o := range cl.AllowedCORSOrigins {
		if o == "*" {
			return true
		}
		g, err := glob.Compile(strings.ToLower(o), '.')
		if err != nil {
			// A malformed entry must not veto the client's remaining valid origins.
			continue
		}
		if g.Match(origin) {
			return true
		}
	}
	return false
}

// requestClientID determines the client ID associated with r.
//
// A non-empty HTTP Basic auth username is used directly. Otherwise the access
// token carried by the request is introspected, and the ID of the client it
// was issued to is returned. ok is false when neither source yields a client.
func requestClientID(ctx context.Context, reg corsRegistry, r *http.Request) (clientID string, ok bool) {
	if username, _, hasBasic := r.BasicAuth(); hasBasic && username != "" {
		return username, true
	}
	token := fosite.AccessTokenFromRequest(r)
	if token == "" {
		return "", false
	}
	session := oauth2.NewSessionWithCustomClaims("", reg.Config().AllowedTopLevelClaims())
	_, ar, err := reg.OAuth2Provider().IntrospectToken(ctx, token, fosite.AccessToken, session)
	if err != nil {
		return "", false
	}
	return ar.GetClient().GetID(), true
}