forked from crewjam/saml
/
samlsp.go
142 lines (126 loc) · 3.53 KB
/
samlsp.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Package samlsp provides helpers that can be used to protect web
// services using SAML.
package samlsp
import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/xml"
	"fmt"
	"io/ioutil"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/bilcus/saml"
	"github.com/bilcus/saml/logger"
)
// defaultTokenMaxAge is the Middleware.TokenMaxAge that New falls back to
// when Options.CookieMaxAge is left at zero.
const defaultTokenMaxAge time.Duration = 60 * time.Minute
// Options represents the parameters for creating a new middleware.
//
// New reads these fields to build a Middleware. Where New applies a default
// for a zero-valued field, it is noted on that field below.
type Options struct {
// URL is the base URL of this service provider. New appends "/saml/metadata"
// and "/saml/acs" to its path to form the metadata and ACS URLs, and uses its
// Host as the Domain of the ClientCookies store.
URL url.URL
// Key is passed through to saml.ServiceProvider.Key.
Key *rsa.PrivateKey
// Logger is passed to the ServiceProvider and used for retry messages while
// fetching IDP metadata. When nil, logger.DefaultLogger is used.
Logger logger.Interface
// Certificate is passed through to saml.ServiceProvider.Certificate.
Certificate *x509.Certificate
// AllowIDPInitiated is copied to Middleware.AllowIDPInitiated.
AllowIDPInitiated bool
// IDPMetadata is used as the ServiceProvider's IDP metadata. It is replaced
// by the fetched document when IDPMetadataURL is set and the fetch succeeds.
IDPMetadata *saml.EntityDescriptor
// IDPMetadataURL, when non-nil, makes New download and parse the IDP
// metadata from this URL before returning, retrying failed downloads.
IDPMetadataURL *url.URL
// HTTPClient is the client used to fetch IDPMetadataURL. When nil,
// http.DefaultClient is used.
HTTPClient *http.Client
// CookieMaxAge becomes Middleware.TokenMaxAge. When zero, New uses
// defaultTokenMaxAge (one hour).
CookieMaxAge time.Duration
// CookieSecure sets the Secure field of the ClientCookies store.
CookieSecure bool
// ForceAuthn is referenced by pointer from saml.ServiceProvider.ForceAuthn
// (the pointer targets New's own copy of Options).
ForceAuthn bool
// EntityID is passed through to saml.ServiceProvider.EntityID.
EntityID string
// NoDestinationCheck is passed through to
// saml.ServiceProvider.NoDestinationCheck.
NoDestinationCheck bool
}
// New creates a new Middleware from opts.
//
// The service provider's metadata and assertion consumer service (ACS) URLs
// are derived from opts.URL by appending "/saml/metadata" and "/saml/acs" to
// its path (a trailing slash on that path is dropped first so the result
// never contains "//"). A single ClientCookies value, whose Domain is
// opts.URL.Host, serves as both the ClientState and the ClientToken store.
//
// If opts.IDPMetadataURL is non-nil, the IDP metadata is fetched
// synchronously using opts.HTTPClient (http.DefaultClient when nil) and
// replaces opts.IDPMetadata. Transport errors, non-200 responses and body
// read errors are retried every 5 seconds, for up to 12 attempts in total;
// the last error is returned if every attempt fails. Errors building the
// request or parsing the metadata are returned immediately.
func New(opts Options) (*Middleware, error) {
	// Drop one trailing slash so "https://sp.example/" yields
	// "/saml/metadata" instead of "//saml/metadata".
	basePath := strings.TrimSuffix(opts.URL.Path, "/")

	metadataURL := opts.URL
	metadataURL.Path = basePath + "/saml/metadata"
	acsURL := opts.URL
	acsURL.Path = basePath + "/saml/acs"

	logr := opts.Logger
	if logr == nil {
		logr = logger.DefaultLogger
	}

	tokenMaxAge := opts.CookieMaxAge
	if tokenMaxAge == 0 {
		tokenMaxAge = defaultTokenMaxAge
	}

	m := &Middleware{
		ServiceProvider: saml.ServiceProvider{
			Key:                opts.Key,
			Logger:             logr,
			Certificate:        opts.Certificate,
			MetadataURL:        metadataURL,
			AcsURL:             acsURL,
			IDPMetadata:        opts.IDPMetadata,
			ForceAuthn:         &opts.ForceAuthn,
			EntityID:           opts.EntityID,
			NoDestinationCheck: opts.NoDestinationCheck,
		},
		AllowIDPInitiated: opts.AllowIDPInitiated,
		TokenMaxAge:       tokenMaxAge,
	}

	cookieStore := ClientCookies{
		ServiceProvider: &m.ServiceProvider,
		Name:            defaultCookieName,
		Domain:          opts.URL.Host,
		Secure:          opts.CookieSecure,
	}
	m.ClientState = &cookieStore
	m.ClientToken = &cookieStore

	// fetch the IDP metadata if needed.
	if opts.IDPMetadataURL == nil {
		return m, nil
	}

	c := opts.HTTPClient
	if c == nil {
		c = http.DefaultClient
	}

	req, err := http.NewRequest("GET", opts.IDPMetadataURL.String(), nil)
	if err != nil {
		return nil, err
	}

	// Some providers (like OneLogin) do not work properly unless the User-Agent header is specified.
	// Setting the user agent prevents the 403 Forbidden errors.
	req.Header.Set("User-Agent", "Golang; github.com/bilcus/saml")

	// fetch performs one download attempt. The response body is closed on
	// every path, including non-200 responses, so failed attempts do not
	// leak connections.
	fetch := func() ([]byte, error) {
		resp, err := c.Do(req)
		if err != nil {
			return nil, err
		}
		defer resp.Body.Close()
		if resp.StatusCode != http.StatusOK {
			// resp.Status already contains the numeric code, so format the
			// code and its canonical text rather than repeating the code.
			return nil, fmt.Errorf("%d %s", resp.StatusCode, http.StatusText(resp.StatusCode))
		}
		return ioutil.ReadAll(resp.Body)
	}

	// Attempts are numbered 0..11: the first 11 failures are logged and
	// retried after a pause; the 12th failure is returned.
	var data []byte
	for i := 0; ; i++ {
		data, err = fetch()
		if err == nil {
			break
		}
		if i > 10 {
			return nil, err
		}
		logr.Printf("ERROR: %s: %s (will retry)", opts.IDPMetadataURL, err)
		time.Sleep(5 * time.Second)
	}

	entity := &saml.EntityDescriptor{}
	err = xml.Unmarshal(data, entity)

	// this comparison is ugly, but it is how the error is generated in encoding/xml
	if err != nil && err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
		entities := &saml.EntitiesDescriptor{}
		if err := xml.Unmarshal(data, entities); err != nil {
			return nil, err
		}

		// NOTE(review): there is no break, so when several entities carry an
		// IDPSSODescriptor the LAST one is selected. This matches the previous
		// behaviour; confirm whether the first match would be more appropriate.
		err = fmt.Errorf("no entity found with IDPSSODescriptor")
		for i := range entities.EntityDescriptors {
			if len(entities.EntityDescriptors[i].IDPSSODescriptors) > 0 {
				entity = &entities.EntityDescriptors[i]
				err = nil
			}
		}
	}
	if err != nil {
		return nil, err
	}

	m.ServiceProvider.IDPMetadata = entity
	return m, nil
}