forked from russellhaering/gosaml2
-
Notifications
You must be signed in to change notification settings - Fork 5
/
utils.go
154 lines (135 loc) · 4.08 KB
/
utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
package providertests
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"encoding/xml"
"fmt"
"io/ioutil"
"sort"
"testing"
"time"
"github.com/jonboulle/clockwork"
"github.com/russellhaering/gosaml2"
"github.com/russellhaering/gosaml2/types"
"github.com/russellhaering/goxmldsig"
"github.com/stretchr/testify/require"
)
// scenarioIndexes returns, in ascending order, every scenario index that has
// an entry in errs or in warns. An index present in both maps is listed only
// once, so callers that iterate over the result run each scenario a single
// time.
func scenarioIndexes(errs map[int]string, warns map[int]scenarioWarnings) (idxs []int) {
	for idx := range errs {
		idxs = append(idxs, idx)
	}
	for idx := range warns {
		// Skip indexes already collected from errs to avoid duplicates.
		if _, inErrs := errs[idx]; !inErrs {
			idxs = append(idxs, idx)
		}
	}
	sort.Ints(idxs)
	return idxs
}

// scenarioWarnings holds the expected values of the InvalidTime and
// NotInAudience flags of a saml2.WarningInfo for a single scenario. The zero
// value expects both flags to be false.
type scenarioWarnings struct {
	// InvalidTime is the expected value of WarningInfo.InvalidTime.
	InvalidTime bool
	// NotInAudience is the expected value of WarningInfo.NotInAudience.
	NotInAudience bool
}
// scenarioErrorChecker builds a func(*testing.T, error) check for scenario i.
// The returned function reads scenarioErrors[i] each time it is invoked. A
// non-empty message means err must have exactly that text; a missing or
// empty entry means err must be nil.
func scenarioErrorChecker(i int, scenarioErrors map[int]string) func(*testing.T, error) {
	return func(t *testing.T, err error) {
		// A missing key yields "", which is treated like an explicit "".
		wantMsg := scenarioErrors[i]
		if wantMsg == "" {
			require.NoError(t, err)
			return
		}
		require.EqualError(t, err, wantMsg, "Expected error message")
	}
}
// scenarioWarningChecker builds a func(*testing.T, *saml2.WarningInfo) check
// for scenario i. The returned function compares the InvalidTime and
// NotInAudience flags of warningInfo with scenarioWarns[i]. A scenario with
// no entry expects both flags to be false.
//
// A nil warningInfo is reported as a test failure. Dereferencing it would
// instead panic, and a panic aborts the whole test binary.
func scenarioWarningChecker(i int, scenarioWarns map[int]scenarioWarnings) func(*testing.T, *saml2.WarningInfo) {
	return func(t *testing.T, warningInfo *saml2.WarningInfo) {
		require.NotNil(t, warningInfo, "WarningInfo is nil")
		expectedWarnings := scenarioWarns[i]
		require.Equal(t, expectedWarnings.InvalidTime, warningInfo.InvalidTime, "InvalidTime mismatch")
		require.Equal(t, expectedWarnings.NotInAudience, warningInfo.NotInAudience, "NotInAudience mismatch")
	}
}
// LoadXMLResponse reads the file at path and returns its contents encoded
// with standard base64. It panics with an error naming path if the file
// cannot be read.
func LoadXMLResponse(path string) string {
	// Named data rather than xml so the encoding/xml package is not shadowed.
	data, err := ioutil.ReadFile(path)
	if err != nil {
		panic(fmt.Errorf("%v: cannot read: %w", path, err))
	}
	return base64.StdEncoding.EncodeToString(data)
}
// LoadRawResponse returns the contents of the file at path as a string,
// without any encoding. It panics with an error naming path if the file
// cannot be read.
func LoadRawResponse(path string) string {
	data, err := ioutil.ReadFile(path)
	if err != nil {
		panic(fmt.Errorf("%v: cannot read: %w", path, err))
	}
	return string(data)
}
// LoadKeyStore reads a certificate from certPath and its private key from
// keyPath. It combines them with tls.X509KeyPair and returns the pair as a
// dsig.TLSCertKeyStore. It panics if either file cannot be read or the pair
// cannot be built; the panic message names the offending path(s).
func LoadKeyStore(certPath, keyPath string) (ks dsig.TLSCertKeyStore) {
	certBytes, readErr := ioutil.ReadFile(certPath)
	if readErr != nil {
		panic(fmt.Errorf("%v: cannot read: %v", certPath, readErr))
	}

	keyBytes, readErr := ioutil.ReadFile(keyPath)
	if readErr != nil {
		panic(fmt.Errorf("%v: cannot read: %v", keyPath, readErr))
	}

	pair, pairErr := tls.X509KeyPair(certBytes, keyBytes)
	if pairErr != nil {
		panic(fmt.Errorf("%v/%v: cannot create key pair: %v", certPath, keyPath, pairErr))
	}
	return dsig.TLSCertKeyStore(pair)
}
// LoadCertificateStore reads the PEM file at path, parses its first PEM block
// as an X.509 certificate, and returns a dsig.MemoryX509CertificateStore with
// that certificate as its only root. Any PEM blocks after the first are
// ignored. It panics with an error naming path if the file cannot be read,
// contains no PEM block, or the block is not a parseable certificate.
func LoadCertificateStore(path string) dsig.X509CertificateStore {
	encoded, err := ioutil.ReadFile(path)
	if err != nil {
		panic(fmt.Errorf("%v: cannot read: %w", path, err))
	}
	// Only the first PEM block is used; the remainder is discarded.
	block, _ := pem.Decode(encoded)
	if block == nil {
		panic(fmt.Errorf("%v: no certificate block found", path))
	}
	cert, err := x509.ParseCertificate(block.Bytes)
	if err != nil {
		panic(fmt.Errorf("%v: cannot parse certificate: %w", path, err))
	}
	return &dsig.MemoryX509CertificateStore{
		Roots: []*x509.Certificate{cert},
	}
}
// ProviderTestScenario groups the inputs and checks for a single provider
// test case.
type ProviderTestScenario struct {
	// ScenarioName labels the case. Response holds the SAML response payload;
	// spAtTime accepts it either base64-encoded or as raw XML.
	ScenarioName, Response string
	// ServiceProvider is the SAML service provider used for this case.
	ServiceProvider *saml2.SAMLServiceProvider
	// CheckError inspects the error produced for this case.
	CheckError func(*testing.T, error)
	// CheckWarningInfo inspects the WarningInfo produced for this case.
	CheckWarningInfo func(*testing.T, *saml2.WarningInfo)
}
// getAtTime returns the instant configured for scenario idx in
// scenarioAtTimes, parsed as RFC 3339. A missing or empty entry yields the
// zero time.Time, which spAtTime treats as "derive the time from the
// response".
//
// A non-empty entry that is not valid RFC 3339 panics. Silently returning the
// zero time would let a typo make the scenario run at an unintended instant.
func getAtTime(idx int, scenarioAtTimes map[int]string) (atTime time.Time) {
	strAtTime := scenarioAtTimes[idx]
	if strAtTime == "" {
		return time.Time{}
	}
	parsed, err := time.Parse(time.RFC3339, strAtTime)
	if err != nil {
		panic(fmt.Errorf("scenario %d: cannot parse atTime %q as RFC 3339: %w", idx, strAtTime, err))
	}
	return parsed
}
// spAtTime returns a shallow copy of template whose Clock is replaced by a
// fake clock fixed at atTime. All other fields are copied unchanged from
// template. NOTE(review): presumably this makes time-dependent response
// validation run as of atTime — confirm against SAMLServiceProvider.
//
// rawResp is base64-decoded when it is valid base64 and otherwise used as
// raw XML. The result must unmarshal into a types.Response.
//
// When atTime is the zero time, it is taken from the response: the first
// Assertion's IssueInstant when present and non-zero, otherwise the
// Response's IssueInstant.
//
// It panics if template is nil, rawResp is empty, the XML cannot be parsed,
// or no time can be determined.
func spAtTime(template *saml2.SAMLServiceProvider, atTime time.Time, rawResp string) *saml2.SAMLServiceProvider {
	if template == nil {
		panic(fmt.Errorf("nil template service provider"))
	}
	if rawResp == "" {
		panic(fmt.Errorf("empty rawResp"))
	}

	respBytes, err := base64.StdEncoding.DecodeString(rawResp)
	if err != nil {
		// Not valid base64: treat the input as raw XML.
		respBytes = []byte(rawResp)
	}
	resp := &types.Response{}
	if err := xml.Unmarshal(respBytes, resp); err != nil {
		panic(fmt.Errorf("cannot parse Response XML: %v", err))
	}

	if atTime.IsZero() {
		// Prefer the Assertion IssueInstant over the Response IssueInstant
		// (the Assertion will be signed, either individually or as part of
		// the Response).
		switch {
		case len(resp.Assertions) > 0 && !resp.Assertions[0].IssueInstant.IsZero():
			atTime = resp.Assertions[0].IssueInstant
		case !resp.IssueInstant.IsZero():
			atTime = resp.IssueInstant
		default:
			panic(fmt.Errorf("could not determine atTime"))
		}
	}

	// Copy every field of template; only the clock is overridden below.
	sp := *template
	sp.Clock = dsig.NewFakeClock(clockwork.NewFakeClockAt(atTime))
	return &sp
}